add CustomMetric
Showing
4 changed files
with
80 additions
and
9 deletions
| ... | @@ -21,5 +21,5 @@ if __name__ == '__main__': | ... | @@ -21,5 +21,5 @@ if __name__ == '__main__': |
| 21 | batch_size = 128 | 21 | batch_size = 128 |
| 22 | 22 | ||
| 23 | m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path, | 23 | m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path, |
| 24 | train_dir_name='train', validate_dir_name='test') | 24 | train_dir_name='train', validate_dir_name='test', thresholds=const.OTHER_THRESHOLDS) |
| 25 | 25 | ... | ... |
| ... | @@ -9,6 +9,43 @@ import matplotlib.pyplot as plt | ... | @@ -9,6 +9,43 @@ import matplotlib.pyplot as plt |
| 9 | from base_class import BaseModel | 9 | from base_class import BaseModel |
| 10 | 10 | ||
| 11 | 11 | ||
| 12 | class CustomMetric(metrics.Metric): | ||
| 13 | |||
| 14 | def __init__(self, thresholds=0.5, name="custom_metric", **kwargs): | ||
| 15 | super(CustomMetric, self).__init__(name=name, **kwargs) | ||
| 16 | self.thresholds = thresholds | ||
| 17 | self.true_positives = self.add_weight(name="ctp", initializer="zeros") | ||
| 18 | self.count = self.add_weight(name="count", initializer="zeros", dtype='int32') | ||
| 19 | |||
| 20 | def update_state(self, y_true, y_pred, sample_weight=None): | ||
| 21 | y_true_idx = tf.argmax(y_true, axis=1) + 1 | ||
| 22 | y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64") | ||
| 23 | y_true = tf.math.multiply(y_true_idx, y_true_is_other) | ||
| 24 | |||
| 25 | y_pred_idx = tf.argmax(y_pred, axis=1) + 1 | ||
| 26 | y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64') | ||
| 27 | y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other) | ||
| 28 | |||
| 29 | print(y_true) | ||
| 30 | print(y_pred) | ||
| 31 | |||
| 32 | values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32") | ||
| 33 | values = tf.cast(values, "float32") | ||
| 34 | if sample_weight is not None: | ||
| 35 | sample_weight = tf.cast(sample_weight, "float32") | ||
| 36 | values = tf.multiply(values, sample_weight) | ||
| 37 | self.true_positives.assign_add(tf.reduce_sum(values)) | ||
| 38 | self.count.assign_add(tf.shape(y_true)[0]) | ||
| 39 | |||
| 40 | def result(self): | ||
| 41 | return self.true_positives / tf.cast(self.count, 'float32') | ||
| 42 | |||
| 43 | def reset_state(self): | ||
| 44 | # The state of the metric will be reset at the start of each epoch. | ||
| 45 | self.true_positives.assign(0.0) | ||
| 46 | self.count.assign(0) | ||
| 47 | |||
| 48 | |||
| 12 | class F3Classification(BaseModel): | 49 | class F3Classification(BaseModel): |
| 13 | 50 | ||
| 14 | def __init__(self, class_name_list, class_other_first, *args, **kwargs): | 51 | def __init__(self, class_name_list, class_other_first, *args, **kwargs): |
| ... | @@ -18,6 +55,12 @@ class F3Classification(BaseModel): | ... | @@ -18,6 +55,12 @@ class F3Classification(BaseModel): |
| 18 | self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} | 55 | self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} |
| 19 | 56 | ||
| 20 | @staticmethod | 57 | @staticmethod |
| 58 | def gpu_config(): | ||
| 59 | gpus = tf.config.experimental.list_physical_devices(device_type='GPU') | ||
| 60 | # print(gpus) | ||
| 61 | tf.config.set_visible_devices(devices=gpus[1], device_type='GPU') | ||
| 62 | |||
| 63 | @staticmethod | ||
| 21 | def history_save(history, save_path): | 64 | def history_save(history, save_path): |
| 22 | acc = history.history['accuracy'] | 65 | acc = history.history['accuracy'] |
| 23 | val_acc = history.history['val_accuracy'] | 66 | val_acc = history.history['val_accuracy'] |
| ... | @@ -90,21 +133,21 @@ class F3Classification(BaseModel): | ... | @@ -90,21 +133,21 @@ class F3Classification(BaseModel): |
| 90 | # 1/10 | 133 | # 1/10 |
| 91 | if random.random() < 0.2: | 134 | if random.random() < 0.2: |
| 92 | image = tf.image.random_flip_left_right(image) | 135 | image = tf.image.random_flip_left_right(image) |
| 93 | return image | 136 | return image, label |
| 94 | 137 | ||
| 95 | @staticmethod | 138 | @staticmethod |
| 96 | def random_flip_up_down(image, label): | 139 | def random_flip_up_down(image, label): |
| 97 | # 1/10 | 140 | # 1/10 |
| 98 | if random.random() < 0.2: | 141 | if random.random() < 0.2: |
| 99 | image = tf.image.random_flip_up_down(image) | 142 | image = tf.image.random_flip_up_down(image) |
| 100 | return image | 143 | return image, label |
| 101 | 144 | ||
| 102 | @staticmethod | 145 | @staticmethod |
| 103 | def random_rot90(image, label): | 146 | def random_rot90(image, label): |
| 104 | # 1/10 | 147 | # 1/10 |
| 105 | if random.random() < 0.1: | 148 | if random.random() < 0.1: |
| 106 | image = tf.image.rot90(image, k=random.randint(1, 3)) | 149 | image = tf.image.rot90(image, k=random.randint(1, 3)) |
| 107 | return image | 150 | return image, label |
| 108 | 151 | ||
| 109 | @staticmethod | 152 | @staticmethod |
| 110 | # @tf.function | 153 | # @tf.function |
| ... | @@ -166,14 +209,17 @@ class F3Classification(BaseModel): | ... | @@ -166,14 +209,17 @@ class F3Classification(BaseModel): |
| 166 | return model | 209 | return model |
| 167 | 210 | ||
| 168 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, | 211 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, |
| 169 | train_dir_name='train', validate_dir_name='test'): | 212 | train_dir_name='train', validate_dir_name='test', thresholds=0.5): |
| 213 | |||
| 214 | self.gpu_config() | ||
| 215 | |||
| 170 | model = self.load_model() | 216 | model = self.load_model() |
| 171 | model.summary() | 217 | model.summary() |
| 172 | 218 | ||
| 173 | model.compile( | 219 | model.compile( |
| 174 | optimizer=optimizers.Adam(learning_rate=3e-4), | 220 | optimizer=optimizers.Adam(learning_rate=3e-4), |
| 175 | loss=tfa.losses.SigmoidFocalCrossEntropy(), | 221 | loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO >>> |
| 176 | metrics=['accuracy', ], | 222 | metrics=[CustomMetric(thresholds), ], |
| 177 | 223 | ||
| 178 | loss_weights=None, | 224 | loss_weights=None, |
| 179 | weighted_metrics=None, | 225 | weighted_metrics=None, |
| ... | @@ -214,5 +260,25 @@ class F3Classification(BaseModel): | ... | @@ -214,5 +260,25 @@ class F3Classification(BaseModel): |
| 214 | self.history_save(history, history_save_path) | 260 | self.history_save(history, history_save_path) |
| 215 | 261 | ||
| 216 | def test(self): | 262 | def test(self): |
| 217 | print(self.class_label_map) | 263 | y_true = [ |
| 218 | print(self.class_count) | 264 | [0, 1, 0], |
| 265 | [0, 1, 0], | ||
| 266 | [0, 0, 1], | ||
| 267 | [0, 0, 0], | ||
| 268 | ] | ||
| 269 | y_pre = [ | ||
| 270 | [0.1, 0.8, 0.9], # TODO multi_label | ||
| 271 | [0.2, 0.8, 0.1], | ||
| 272 | [0.2, 0.1, 0.85], | ||
| 273 | [0.2, 0.4, 0.1], | ||
| 274 | ] | ||
| 275 | |||
| 276 | # x = tf.argmax(y_pre, axis=1) | ||
| 277 | # y = tf.reduce_sum(y_pre, axis=1) | ||
| 278 | # print(x) | ||
| 279 | # print(y) | ||
| 280 | |||
| 281 | # m = tf.keras.metrics.TopKCategoricalAccuracy(k=1) | ||
| 282 | m = CustomMetric(0.5) | ||
| 283 | m.update_state(y_true, y_pre) | ||
| 284 | print(m.result().numpy()) | ... | ... |
-
Please register or sign in to post a comment