1ea84670 by 周伟奇

add CustomMetric

1 parent 83048d22
...@@ -11,4 +11,7 @@ ...@@ -11,4 +11,7 @@
11 .* 11 .*
12 !.gitignore 12 !.gitignore
13 13
14 test.py
...\ No newline at end of file ...\ No newline at end of file
14 test.py
15 *.h5
16 *.jpg
17 *.out
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -4,3 +4,5 @@ CLASS_OTHER_FIRST = True ...@@ -4,3 +4,5 @@ CLASS_OTHER_FIRST = True
4 4
5 CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书'] 5 CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
6 6
7 OTHER_THRESHOLDS = 0.5
8
......
...@@ -21,5 +21,5 @@ if __name__ == '__main__': ...@@ -21,5 +21,5 @@ if __name__ == '__main__':
21 batch_size = 128 21 batch_size = 128
22 22
23 m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path, 23 m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
24 train_dir_name='train', validate_dir_name='test') 24 train_dir_name='train', validate_dir_name='test', thresholds=const.OTHER_THRESHOLDS)
25 25
......
...@@ -9,6 +9,43 @@ import matplotlib.pyplot as plt ...@@ -9,6 +9,43 @@ import matplotlib.pyplot as plt
9 from base_class import BaseModel 9 from base_class import BaseModel
10 10
11 11
12 class CustomMetric(metrics.Metric):
13
14 def __init__(self, thresholds=0.5, name="custom_metric", **kwargs):
15 super(CustomMetric, self).__init__(name=name, **kwargs)
16 self.thresholds = thresholds
17 self.true_positives = self.add_weight(name="ctp", initializer="zeros")
18 self.count = self.add_weight(name="count", initializer="zeros", dtype='int32')
19
20 def update_state(self, y_true, y_pred, sample_weight=None):
21 y_true_idx = tf.argmax(y_true, axis=1) + 1
22 y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64")
23 y_true = tf.math.multiply(y_true_idx, y_true_is_other)
24
25 y_pred_idx = tf.argmax(y_pred, axis=1) + 1
26 y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64')
27 y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other)
28
29 print(y_true)
30 print(y_pred)
31
32 values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
33 values = tf.cast(values, "float32")
34 if sample_weight is not None:
35 sample_weight = tf.cast(sample_weight, "float32")
36 values = tf.multiply(values, sample_weight)
37 self.true_positives.assign_add(tf.reduce_sum(values))
38 self.count.assign_add(tf.shape(y_true)[0])
39
40 def result(self):
41 return self.true_positives / tf.cast(self.count, 'float32')
42
43 def reset_state(self):
44 # The state of the metric will be reset at the start of each epoch.
45 self.true_positives.assign(0.0)
46 self.count.assign(0)
47
48
12 class F3Classification(BaseModel): 49 class F3Classification(BaseModel):
13 50
14 def __init__(self, class_name_list, class_other_first, *args, **kwargs): 51 def __init__(self, class_name_list, class_other_first, *args, **kwargs):
...@@ -18,6 +55,12 @@ class F3Classification(BaseModel): ...@@ -18,6 +55,12 @@ class F3Classification(BaseModel):
18 self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} 55 self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
19 56
20 @staticmethod 57 @staticmethod
58 def gpu_config():
59 gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
60 # print(gpus)
61 tf.config.set_visible_devices(devices=gpus[1], device_type='GPU')
62
63 @staticmethod
21 def history_save(history, save_path): 64 def history_save(history, save_path):
22 acc = history.history['accuracy'] 65 acc = history.history['accuracy']
23 val_acc = history.history['val_accuracy'] 66 val_acc = history.history['val_accuracy']
...@@ -90,21 +133,21 @@ class F3Classification(BaseModel): ...@@ -90,21 +133,21 @@ class F3Classification(BaseModel):
90 # 1/10 133 # 1/10
91 if random.random() < 0.2: 134 if random.random() < 0.2:
92 image = tf.image.random_flip_left_right(image) 135 image = tf.image.random_flip_left_right(image)
93 return image 136 return image, label
94 137
95 @staticmethod 138 @staticmethod
96 def random_flip_up_down(image, label): 139 def random_flip_up_down(image, label):
97 # 1/10 140 # 1/10
98 if random.random() < 0.2: 141 if random.random() < 0.2:
99 image = tf.image.random_flip_up_down(image) 142 image = tf.image.random_flip_up_down(image)
100 return image 143 return image, label
101 144
102 @staticmethod 145 @staticmethod
103 def random_rot90(image, label): 146 def random_rot90(image, label):
104 # 1/10 147 # 1/10
105 if random.random() < 0.1: 148 if random.random() < 0.1:
106 image = tf.image.rot90(image, k=random.randint(1, 3)) 149 image = tf.image.rot90(image, k=random.randint(1, 3))
107 return image 150 return image, label
108 151
109 @staticmethod 152 @staticmethod
110 # @tf.function 153 # @tf.function
...@@ -166,14 +209,17 @@ class F3Classification(BaseModel): ...@@ -166,14 +209,17 @@ class F3Classification(BaseModel):
166 return model 209 return model
167 210
168 def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, 211 def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
169 train_dir_name='train', validate_dir_name='test'): 212 train_dir_name='train', validate_dir_name='test', thresholds=0.5):
213
214 self.gpu_config()
215
170 model = self.load_model() 216 model = self.load_model()
171 model.summary() 217 model.summary()
172 218
173 model.compile( 219 model.compile(
174 optimizer=optimizers.Adam(learning_rate=3e-4), 220 optimizer=optimizers.Adam(learning_rate=3e-4),
175 loss=tfa.losses.SigmoidFocalCrossEntropy(), 221 loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO >>>
176 metrics=['accuracy', ], 222 metrics=[CustomMetric(thresholds), ],
177 223
178 loss_weights=None, 224 loss_weights=None,
179 weighted_metrics=None, 225 weighted_metrics=None,
...@@ -214,5 +260,25 @@ class F3Classification(BaseModel): ...@@ -214,5 +260,25 @@ class F3Classification(BaseModel):
214 self.history_save(history, history_save_path) 260 self.history_save(history, history_save_path)
215 261
216 def test(self): 262 def test(self):
217 print(self.class_label_map) 263 y_true = [
218 print(self.class_count) 264 [0, 1, 0],
265 [0, 1, 0],
266 [0, 0, 1],
267 [0, 0, 0],
268 ]
269 y_pre = [
270 [0.1, 0.8, 0.9], # TODO multi_label
271 [0.2, 0.8, 0.1],
272 [0.2, 0.1, 0.85],
273 [0.2, 0.4, 0.1],
274 ]
275
276 # x = tf.argmax(y_pre, axis=1)
277 # y = tf.reduce_sum(y_pre, axis=1)
278 # print(x)
279 # print(y)
280
281 # m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
282 m = CustomMetric(0.5)
283 m.update_state(y_true, y_pre)
284 print(m.result().numpy())
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!