1ea84670 by 周伟奇

add CustomMetric

1 parent 83048d22
......@@ -12,3 +12,6 @@
!.gitignore
test.py
*.h5
*.jpg
*.out
\ No newline at end of file
......
......@@ -4,3 +4,5 @@ CLASS_OTHER_FIRST = True
CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
OTHER_THRESHOLDS = 0.5
......
......@@ -21,5 +21,5 @@ if __name__ == '__main__':
batch_size = 128
m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
train_dir_name='train', validate_dir_name='test')
train_dir_name='train', validate_dir_name='test', thresholds=const.OTHER_THRESHOLDS)
......
......@@ -9,6 +9,43 @@ import matplotlib.pyplot as plt
from base_class import BaseModel
class CustomMetric(metrics.Metric):
def __init__(self, thresholds=0.5, name="custom_metric", **kwargs):
super(CustomMetric, self).__init__(name=name, **kwargs)
self.thresholds = thresholds
self.true_positives = self.add_weight(name="ctp", initializer="zeros")
self.count = self.add_weight(name="count", initializer="zeros", dtype='int32')
def update_state(self, y_true, y_pred, sample_weight=None):
y_true_idx = tf.argmax(y_true, axis=1) + 1
y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64")
y_true = tf.math.multiply(y_true_idx, y_true_is_other)
y_pred_idx = tf.argmax(y_pred, axis=1) + 1
y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64')
y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other)
print(y_true)
print(y_pred)
values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
values = tf.cast(values, "float32")
if sample_weight is not None:
sample_weight = tf.cast(sample_weight, "float32")
values = tf.multiply(values, sample_weight)
self.true_positives.assign_add(tf.reduce_sum(values))
self.count.assign_add(tf.shape(y_true)[0])
def result(self):
return self.true_positives / tf.cast(self.count, 'float32')
def reset_state(self):
# The state of the metric will be reset at the start of each epoch.
self.true_positives.assign(0.0)
self.count.assign(0)
class F3Classification(BaseModel):
def __init__(self, class_name_list, class_other_first, *args, **kwargs):
......@@ -18,6 +55,12 @@ class F3Classification(BaseModel):
self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
@staticmethod
def gpu_config():
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
# print(gpus)
tf.config.set_visible_devices(devices=gpus[1], device_type='GPU')
@staticmethod
def history_save(history, save_path):
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
......@@ -90,21 +133,21 @@ class F3Classification(BaseModel):
# 1/10
if random.random() < 0.2:
image = tf.image.random_flip_left_right(image)
return image
return image, label
@staticmethod
def random_flip_up_down(image, label):
# 1/10
if random.random() < 0.2:
image = tf.image.random_flip_up_down(image)
return image
return image, label
@staticmethod
def random_rot90(image, label):
# 1/10
if random.random() < 0.1:
image = tf.image.rot90(image, k=random.randint(1, 3))
return image
return image, label
@staticmethod
# @tf.function
......@@ -166,14 +209,17 @@ class F3Classification(BaseModel):
return model
def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
train_dir_name='train', validate_dir_name='test'):
train_dir_name='train', validate_dir_name='test', thresholds=0.5):
self.gpu_config()
model = self.load_model()
model.summary()
model.compile(
optimizer=optimizers.Adam(learning_rate=3e-4),
loss=tfa.losses.SigmoidFocalCrossEntropy(),
metrics=['accuracy', ],
loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO >>>
metrics=[CustomMetric(thresholds), ],
loss_weights=None,
weighted_metrics=None,
......@@ -214,5 +260,25 @@ class F3Classification(BaseModel):
self.history_save(history, history_save_path)
def test(self):
print(self.class_label_map)
print(self.class_count)
y_true = [
[0, 1, 0],
[0, 1, 0],
[0, 0, 1],
[0, 0, 0],
]
y_pre = [
[0.1, 0.8, 0.9], # TODO multi_label
[0.2, 0.8, 0.1],
[0.2, 0.1, 0.85],
[0.2, 0.4, 0.1],
]
# x = tf.argmax(y_pre, axis=1)
# y = tf.reduce_sum(y_pre, axis=1)
# print(x)
# print(y)
# m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
m = CustomMetric(0.5)
m.update_state(y_true, y_pre)
print(m.result().numpy())
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!