add CustomMetric
Showing
4 changed files
with
81 additions
and
10 deletions
... | @@ -21,5 +21,5 @@ if __name__ == '__main__': | ... | @@ -21,5 +21,5 @@ if __name__ == '__main__': |
21 | batch_size = 128 | 21 | batch_size = 128 |
22 | 22 | ||
23 | m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path, | 23 | m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path, |
24 | train_dir_name='train', validate_dir_name='test') | 24 | train_dir_name='train', validate_dir_name='test', thresholds=const.OTHER_THRESHOLDS) |
25 | 25 | ... | ... |
... | @@ -9,6 +9,43 @@ import matplotlib.pyplot as plt | ... | @@ -9,6 +9,43 @@ import matplotlib.pyplot as plt |
9 | from base_class import BaseModel | 9 | from base_class import BaseModel |
10 | 10 | ||
11 | 11 | ||
12 | class CustomMetric(metrics.Metric): | ||
13 | |||
14 | def __init__(self, thresholds=0.5, name="custom_metric", **kwargs): | ||
15 | super(CustomMetric, self).__init__(name=name, **kwargs) | ||
16 | self.thresholds = thresholds | ||
17 | self.true_positives = self.add_weight(name="ctp", initializer="zeros") | ||
18 | self.count = self.add_weight(name="count", initializer="zeros", dtype='int32') | ||
19 | |||
20 | def update_state(self, y_true, y_pred, sample_weight=None): | ||
21 | y_true_idx = tf.argmax(y_true, axis=1) + 1 | ||
22 | y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64") | ||
23 | y_true = tf.math.multiply(y_true_idx, y_true_is_other) | ||
24 | |||
25 | y_pred_idx = tf.argmax(y_pred, axis=1) + 1 | ||
26 | y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64') | ||
27 | y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other) | ||
28 | |||
29 | print(y_true) | ||
30 | print(y_pred) | ||
31 | |||
32 | values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32") | ||
33 | values = tf.cast(values, "float32") | ||
34 | if sample_weight is not None: | ||
35 | sample_weight = tf.cast(sample_weight, "float32") | ||
36 | values = tf.multiply(values, sample_weight) | ||
37 | self.true_positives.assign_add(tf.reduce_sum(values)) | ||
38 | self.count.assign_add(tf.shape(y_true)[0]) | ||
39 | |||
40 | def result(self): | ||
41 | return self.true_positives / tf.cast(self.count, 'float32') | ||
42 | |||
43 | def reset_state(self): | ||
44 | # The state of the metric will be reset at the start of each epoch. | ||
45 | self.true_positives.assign(0.0) | ||
46 | self.count.assign(0) | ||
47 | |||
48 | |||
12 | class F3Classification(BaseModel): | 49 | class F3Classification(BaseModel): |
13 | 50 | ||
14 | def __init__(self, class_name_list, class_other_first, *args, **kwargs): | 51 | def __init__(self, class_name_list, class_other_first, *args, **kwargs): |
... | @@ -18,6 +55,12 @@ class F3Classification(BaseModel): | ... | @@ -18,6 +55,12 @@ class F3Classification(BaseModel): |
18 | self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} | 55 | self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} |
19 | 56 | ||
20 | @staticmethod | 57 | @staticmethod |
58 | def gpu_config(): | ||
59 | gpus = tf.config.experimental.list_physical_devices(device_type='GPU') | ||
60 | # print(gpus) | ||
61 | tf.config.set_visible_devices(devices=gpus[1], device_type='GPU') | ||
62 | |||
63 | @staticmethod | ||
21 | def history_save(history, save_path): | 64 | def history_save(history, save_path): |
22 | acc = history.history['accuracy'] | 65 | acc = history.history['accuracy'] |
23 | val_acc = history.history['val_accuracy'] | 66 | val_acc = history.history['val_accuracy'] |
... | @@ -90,21 +133,21 @@ class F3Classification(BaseModel): | ... | @@ -90,21 +133,21 @@ class F3Classification(BaseModel): |
90 | # 1/10 | 133 | # 1/10 |
91 | if random.random() < 0.2: | 134 | if random.random() < 0.2: |
92 | image = tf.image.random_flip_left_right(image) | 135 | image = tf.image.random_flip_left_right(image) |
93 | return image | 136 | return image, label |
94 | 137 | ||
95 | @staticmethod | 138 | @staticmethod |
96 | def random_flip_up_down(image, label): | 139 | def random_flip_up_down(image, label): |
97 | # 1/10 | 140 | # 1/10 |
98 | if random.random() < 0.2: | 141 | if random.random() < 0.2: |
99 | image = tf.image.random_flip_up_down(image) | 142 | image = tf.image.random_flip_up_down(image) |
100 | return image | 143 | return image, label |
101 | 144 | ||
102 | @staticmethod | 145 | @staticmethod |
103 | def random_rot90(image, label): | 146 | def random_rot90(image, label): |
104 | # 1/10 | 147 | # 1/10 |
105 | if random.random() < 0.1: | 148 | if random.random() < 0.1: |
106 | image = tf.image.rot90(image, k=random.randint(1, 3)) | 149 | image = tf.image.rot90(image, k=random.randint(1, 3)) |
107 | return image | 150 | return image, label |
108 | 151 | ||
109 | @staticmethod | 152 | @staticmethod |
110 | # @tf.function | 153 | # @tf.function |
... | @@ -166,14 +209,17 @@ class F3Classification(BaseModel): | ... | @@ -166,14 +209,17 @@ class F3Classification(BaseModel): |
166 | return model | 209 | return model |
167 | 210 | ||
168 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, | 211 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, |
169 | train_dir_name='train', validate_dir_name='test'): | 212 | train_dir_name='train', validate_dir_name='test', thresholds=0.5): |
213 | |||
214 | self.gpu_config() | ||
215 | |||
170 | model = self.load_model() | 216 | model = self.load_model() |
171 | model.summary() | 217 | model.summary() |
172 | 218 | ||
173 | model.compile( | 219 | model.compile( |
174 | optimizer=optimizers.Adam(learning_rate=3e-4), | 220 | optimizer=optimizers.Adam(learning_rate=3e-4), |
175 | loss=tfa.losses.SigmoidFocalCrossEntropy(), | 221 | loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO >>> |
176 | metrics=['accuracy', ], | 222 | metrics=[CustomMetric(thresholds), ], |
177 | 223 | ||
178 | loss_weights=None, | 224 | loss_weights=None, |
179 | weighted_metrics=None, | 225 | weighted_metrics=None, |
... | @@ -214,5 +260,25 @@ class F3Classification(BaseModel): | ... | @@ -214,5 +260,25 @@ class F3Classification(BaseModel): |
214 | self.history_save(history, history_save_path) | 260 | self.history_save(history, history_save_path) |
215 | 261 | ||
216 | def test(self): | 262 | def test(self): |
217 | print(self.class_label_map) | 263 | y_true = [ |
218 | print(self.class_count) | 264 | [0, 1, 0], |
265 | [0, 1, 0], | ||
266 | [0, 0, 1], | ||
267 | [0, 0, 0], | ||
268 | ] | ||
269 | y_pre = [ | ||
270 | [0.1, 0.8, 0.9], # TODO multi_label | ||
271 | [0.2, 0.8, 0.1], | ||
272 | [0.2, 0.1, 0.85], | ||
273 | [0.2, 0.4, 0.1], | ||
274 | ] | ||
275 | |||
276 | # x = tf.argmax(y_pre, axis=1) | ||
277 | # y = tf.reduce_sum(y_pre, axis=1) | ||
278 | # print(x) | ||
279 | # print(y) | ||
280 | |||
281 | # m = tf.keras.metrics.TopKCategoricalAccuracy(k=1) | ||
282 | m = CustomMetric(0.5) | ||
283 | m.update_state(y_true, y_pre) | ||
284 | print(m.result().numpy()) | ... | ... |
-
Please register or sign in to post a comment