add predict
Showing
11 changed files
with
258 additions
and
130 deletions
README.md
0 → 100644
| 1 | ## Useage | ||
| 2 | |||
| 3 | ### 分类 | ||
| 4 | ```python | ||
| 5 | import cv2 | ||
| 6 | from classification import classifier | ||
| 7 | |||
| 8 | img_path = 'xxx' | ||
| 9 | img = cv2.imread(img_path) | ||
| 10 | |||
| 11 | print(classifier.class_name_list) | ||
| 12 | res = classifier.predict(img) | ||
| 13 | print(res) # {'label': '营业执照', 'confidence': 0.988462} | ||
| 14 | ``` | ||
| 15 | |||
| 16 | ### 授权书信息提取 | ||
| 17 | ```python | ||
| 18 | from authorization_from import retriever_individuals, retriever_companies | ||
| 19 | |||
| 20 | # 个人授权书 | ||
| 21 | res = retriever_companies.get_target_fields(go_res, signature_res) | ||
| 22 | print(res) | ||
| 23 | |||
| 24 | # 公司授权书 | ||
| 25 | # res = retriever_individuals.get_target_fields(go_res, signature_res) | ||
| 26 | # print(res) | ||
| 27 | ``` |
authorization_from/README.md
deleted
100644 → 0
| 1 | ## Useage | ||
| 2 | **F3个人授权书和企业授权书的信息提取** | ||
| 3 | |||
| 4 | ```python | ||
| 5 | from retriever import Retriever | ||
| 6 | import const | ||
| 7 | |||
| 8 | # 个人授权书 {'姓名': 'xxx', '个人身份证件号码': 'xxx', '签字': '有'} | ||
| 9 | r = Retriever(const.TARGET_FIELD_INDIVIDUALS) | ||
| 10 | |||
| 11 | # 企业授权书 {'经销商名称': 'xx', '经销商代码-宝马中国': 'xx', '管理人员姓名-总经理': 'xx', '公司公章': '有', '法定代表人签章': '有'} | ||
| 12 | # r = Retriever(const.TARGET_FIELD_COMPANIES) | ||
| 13 | res = r.get_target_fields(go_res, signature_res) | ||
| 14 | ``` | ||
| 15 | |||
| 16 | |||
| 17 | |||
| 18 |
authorization_from/__init__.py
0 → 100644
classification/__init__.py
0 → 100644
| 1 | import os.path | ||
| 2 | |||
| 3 | from .model import F3Classification | ||
| 4 | from .const import CLASS_CN_LIST, CLASS_OTHER_FIRST | ||
| 5 | |||
| 6 | classifier = F3Classification( | ||
| 7 | class_name_list=CLASS_CN_LIST, | ||
| 8 | class_other_first=CLASS_OTHER_FIRST | ||
| 9 | ) | ||
| 10 | |||
| 11 | classifier.load_model(load_weights_path=os.path.join( | ||
| 12 | os.path.dirname(os.path.abspath(__file__)), 'ckpt_prod.h5')) | ||
| 13 |
| ... | @@ -3,14 +3,14 @@ class BaseModel: | ... | @@ -3,14 +3,14 @@ class BaseModel: |
| 3 | All Model classes should extend BaseModel. | 3 | All Model classes should extend BaseModel. |
| 4 | """ | 4 | """ |
| 5 | 5 | ||
| 6 | def load_model(self): | 6 | def load_model(self, for_training=False, load_weights_path=None): |
| 7 | """ | 7 | """ |
| 8 | Defining the network structure and return | 8 | Defining the network structure and return |
| 9 | """ | 9 | """ |
| 10 | raise NotImplementedError(".load() must be overridden.") | 10 | raise NotImplementedError(".load() must be overridden.") |
| 11 | 11 | ||
| 12 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, | 12 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, |
| 13 | train_dir_name='train', validate_dir_name='test'): | 13 | train_dir_name='train', validate_dir_name='test', thresholds=0.5, metrics_name='accuracy'): |
| 14 | """ | 14 | """ |
| 15 | Model training process | 15 | Model training process |
| 16 | """ | 16 | """ | ... | ... |
| ... | @@ -2,7 +2,8 @@ CLASS_OTHER_CN = '其他' | ... | @@ -2,7 +2,8 @@ CLASS_OTHER_CN = '其他' |
| 2 | 2 | ||
| 3 | CLASS_OTHER_FIRST = True | 3 | CLASS_OTHER_FIRST = True |
| 4 | 4 | ||
| 5 | CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书'] | 5 | # CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书'] |
| 6 | CLASS_CN_LIST = [CLASS_OTHER_CN, '营业执照', '经销商授权书', '个人授权书'] | ||
| 6 | 7 | ||
| 7 | OTHER_THRESHOLDS = 0.5 | 8 | OTHER_THRESHOLDS = 0.5 |
| 8 | 9 | ... | ... |
classification/main.py
deleted
100644 → 0
| 1 | import os | ||
| 2 | from datetime import datetime | ||
| 3 | from model import F3Classification | ||
| 4 | import const | ||
| 5 | |||
| 6 | |||
| 7 | if __name__ == '__main__': | ||
| 8 | base_dir = os.path.dirname(os.path.abspath(__file__)) | ||
| 9 | |||
| 10 | m = F3Classification( | ||
| 11 | class_name_list=const.CLASS_CN_LIST, | ||
| 12 | class_other_first=const.CLASS_OTHER_FIRST | ||
| 13 | ) | ||
| 14 | |||
| 15 | # m.test() | ||
| 16 | |||
| 17 | dataset_dir = '/home/zwq/data/data_224_f3' | ||
| 18 | ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))) | ||
| 19 | history_save_path = os.path.join(base_dir, 'history_{0}.jpg'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))) | ||
| 20 | epoch = 100 | ||
| 21 | batch_size = 128 | ||
| 22 | |||
| 23 | m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path, | ||
| 24 | train_dir_name='train', validate_dir_name='test', thresholds=const.OTHER_THRESHOLDS) | ||
| 25 |
classification/metrics.py
0 → 100644
| 1 | import tensorflow as tf | ||
| 2 | from keras import metrics | ||
| 3 | |||
| 4 | |||
| 5 | class CustomMetric(metrics.Metric): | ||
| 6 | |||
| 7 | def __init__(self, thresholds=0.5, name="custom_metric", **kwargs): | ||
| 8 | super(CustomMetric, self).__init__(name=name, **kwargs) | ||
| 9 | self.thresholds = thresholds | ||
| 10 | self.true_positives = self.add_weight(name="ctp", initializer="zeros") | ||
| 11 | self.count = self.add_weight(name="count", initializer="zeros", dtype='int32') | ||
| 12 | |||
| 13 | @staticmethod | ||
| 14 | def y_true_with_others(y_true): | ||
| 15 | y_true_idx = tf.argmax(y_true, axis=1) + 1 | ||
| 16 | y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64") | ||
| 17 | y_true = tf.math.multiply(y_true_idx, y_true_is_other) | ||
| 18 | return y_true | ||
| 19 | |||
| 20 | def y_pred_with_others(self, y_pred): | ||
| 21 | y_pred_idx = tf.argmax(y_pred, axis=1) + 1 | ||
| 22 | y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64') | ||
| 23 | y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other) | ||
| 24 | return y_pred | ||
| 25 | |||
| 26 | def update_state(self, y_true, y_pred, sample_weight=None): | ||
| 27 | y_true = self.y_true_with_others(y_true) | ||
| 28 | y_pred = self.y_pred_with_others(y_pred) | ||
| 29 | |||
| 30 | # print(y_true) | ||
| 31 | # print(y_pred) | ||
| 32 | |||
| 33 | values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32") | ||
| 34 | values = tf.cast(values, "float32") | ||
| 35 | if sample_weight is not None: | ||
| 36 | sample_weight = tf.cast(sample_weight, "float32") | ||
| 37 | values = tf.multiply(values, sample_weight) | ||
| 38 | self.true_positives.assign_add(tf.reduce_sum(values)) | ||
| 39 | self.count.assign_add(tf.shape(y_true)[0]) | ||
| 40 | |||
| 41 | def result(self): | ||
| 42 | return self.true_positives / tf.cast(self.count, 'float32') | ||
| 43 | |||
| 44 | def reset_state(self): | ||
| 45 | # The state of the metric will be reset at the start of each epoch. | ||
| 46 | self.true_positives.assign(0.0) | ||
| 47 | self.count.assign(0) |
| ... | @@ -2,57 +2,25 @@ import os | ... | @@ -2,57 +2,25 @@ import os |
| 2 | import random | 2 | import random |
| 3 | import tensorflow as tf | 3 | import tensorflow as tf |
| 4 | import tensorflow_addons as tfa | 4 | import tensorflow_addons as tfa |
| 5 | from keras.applications.mobilenet_v2 import MobileNetV2 | ||
| 6 | from keras import layers, models, optimizers, losses, metrics, callbacks, applications | ||
| 7 | import matplotlib.pyplot as plt | ||
| 8 | |||
| 9 | from base_class import BaseModel | ||
| 10 | |||
| 11 | |||
| 12 | class CustomMetric(metrics.Metric): | ||
| 13 | |||
| 14 | def __init__(self, thresholds=0.5, name="custom_metric", **kwargs): | ||
| 15 | super(CustomMetric, self).__init__(name=name, **kwargs) | ||
| 16 | self.thresholds = thresholds | ||
| 17 | self.true_positives = self.add_weight(name="ctp", initializer="zeros") | ||
| 18 | self.count = self.add_weight(name="count", initializer="zeros", dtype='int32') | ||
| 19 | |||
| 20 | def update_state(self, y_true, y_pred, sample_weight=None): | ||
| 21 | y_true_idx = tf.argmax(y_true, axis=1) + 1 | ||
| 22 | y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64") | ||
| 23 | y_true = tf.math.multiply(y_true_idx, y_true_is_other) | ||
| 24 | 5 | ||
| 25 | y_pred_idx = tf.argmax(y_pred, axis=1) + 1 | 6 | from keras.applications.mobilenet_v2 import MobileNetV2 |
| 26 | y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64') | 7 | from keras import layers, models, optimizers, callbacks, applications |
| 27 | y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other) | 8 | from sklearn.metrics import confusion_matrix, accuracy_score, classification_report |
| 28 | |||
| 29 | print(y_true) | ||
| 30 | print(y_pred) | ||
| 31 | |||
| 32 | values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32") | ||
| 33 | values = tf.cast(values, "float32") | ||
| 34 | if sample_weight is not None: | ||
| 35 | sample_weight = tf.cast(sample_weight, "float32") | ||
| 36 | values = tf.multiply(values, sample_weight) | ||
| 37 | self.true_positives.assign_add(tf.reduce_sum(values)) | ||
| 38 | self.count.assign_add(tf.shape(y_true)[0]) | ||
| 39 | |||
| 40 | def result(self): | ||
| 41 | return self.true_positives / tf.cast(self.count, 'float32') | ||
| 42 | 9 | ||
| 43 | def reset_state(self): | 10 | from .base_class import BaseModel |
| 44 | # The state of the metric will be reset at the start of each epoch. | 11 | from .metrics import CustomMetric |
| 45 | self.true_positives.assign(0.0) | 12 | from .utils import history_save, plot_confusion_matrix |
| 46 | self.count.assign(0) | ||
| 47 | 13 | ||
| 48 | 14 | ||
| 49 | class F3Classification(BaseModel): | 15 | class F3Classification(BaseModel): |
| 50 | 16 | ||
| 51 | def __init__(self, class_name_list, class_other_first, *args, **kwargs): | 17 | def __init__(self, class_name_list, class_other_first, *args, **kwargs): |
| 52 | super().__init__(*args, **kwargs) | 18 | super().__init__(*args, **kwargs) |
| 19 | self.class_name_list = class_name_list | ||
| 53 | self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1 | 20 | self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1 |
| 54 | self.class_label_map = self.get_class_label_map(class_name_list, class_other_first) | 21 | self.class_label_map = self.get_class_label_map(class_name_list, class_other_first) |
| 55 | self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} | 22 | self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} |
| 23 | self.model = None | ||
| 56 | 24 | ||
| 57 | @staticmethod | 25 | @staticmethod |
| 58 | def gpu_config(): | 26 | def gpu_config(): |
| ... | @@ -61,34 +29,6 @@ class F3Classification(BaseModel): | ... | @@ -61,34 +29,6 @@ class F3Classification(BaseModel): |
| 61 | tf.config.set_visible_devices(devices=gpus[1], device_type='GPU') | 29 | tf.config.set_visible_devices(devices=gpus[1], device_type='GPU') |
| 62 | 30 | ||
| 63 | @staticmethod | 31 | @staticmethod |
| 64 | def history_save(history, save_path): | ||
| 65 | acc = history.history['accuracy'] | ||
| 66 | val_acc = history.history['val_accuracy'] | ||
| 67 | |||
| 68 | loss = history.history['loss'] | ||
| 69 | val_loss = history.history['val_loss'] | ||
| 70 | |||
| 71 | plt.figure(figsize=(8, 8)) | ||
| 72 | plt.subplot(2, 1, 1) | ||
| 73 | plt.plot(acc, label='Training Accuracy') | ||
| 74 | plt.plot(val_acc, label='Validation Accuracy') | ||
| 75 | plt.legend(loc='lower right') | ||
| 76 | plt.ylabel('Accuracy') | ||
| 77 | plt.ylim([min(plt.ylim()), 1]) | ||
| 78 | plt.title('Training and Validation Accuracy') | ||
| 79 | |||
| 80 | plt.subplot(2, 1, 2) | ||
| 81 | plt.plot(loss, label='Training Loss') | ||
| 82 | plt.plot(val_loss, label='Validation Loss') | ||
| 83 | plt.legend(loc='upper right') | ||
| 84 | plt.ylabel('Cross Entropy') | ||
| 85 | plt.ylim([0, 1.0]) | ||
| 86 | plt.title('Training and Validation Loss') | ||
| 87 | plt.xlabel('epoch') | ||
| 88 | # plt.show() | ||
| 89 | plt.savefig(save_path) | ||
| 90 | |||
| 91 | @staticmethod | ||
| 92 | def get_class_label_map(class_name_list, class_other_first=False): | 32 | def get_class_label_map(class_name_list, class_other_first=False): |
| 93 | return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)} | 33 | return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)} |
| 94 | 34 | ||
| ... | @@ -103,7 +43,6 @@ class F3Classification(BaseModel): | ... | @@ -103,7 +43,6 @@ class F3Classification(BaseModel): |
| 103 | continue | 43 | continue |
| 104 | label = self.class_label_map[class_name] | 44 | label = self.class_label_map[class_name] |
| 105 | for file_name in os.listdir(class_dir_path): | 45 | for file_name in os.listdir(class_dir_path): |
| 106 | # TODO image check | ||
| 107 | if os.path.splitext(file_name)[1] not in self.image_ext_set: | 46 | if os.path.splitext(file_name)[1] not in self.image_ext_set: |
| 108 | continue | 47 | continue |
| 109 | file_path = os.path.join(class_dir_path, file_name) | 48 | file_path = os.path.join(class_dir_path, file_name) |
| ... | @@ -153,7 +92,7 @@ class F3Classification(BaseModel): | ... | @@ -153,7 +92,7 @@ class F3Classification(BaseModel): |
| 153 | # @tf.function | 92 | # @tf.function |
| 154 | def load_image(image_path, label): | 93 | def load_image(image_path, label): |
| 155 | image = tf.io.read_file(image_path) | 94 | image = tf.io.read_file(image_path) |
| 156 | # image = tf.image.decode_image(image, channels=3) # TODO 为什么不行 | 95 | # image = tf.image.decode_image(image, channels=3) # TODO ? |
| 157 | image = tf.image.decode_png(image, channels=3) | 96 | image = tf.image.decode_png(image, channels=3) |
| 158 | return image, label | 97 | return image, label |
| 159 | 98 | ||
| ... | @@ -186,7 +125,10 @@ class F3Classification(BaseModel): | ... | @@ -186,7 +125,10 @@ class F3Classification(BaseModel): |
| 186 | ).prefetch(tf.data.AUTOTUNE) | 125 | ).prefetch(tf.data.AUTOTUNE) |
| 187 | return parallel_batch_dataset | 126 | return parallel_batch_dataset |
| 188 | 127 | ||
| 189 | def load_model(self): | 128 | def load_model(self, for_training=False, load_weights_path=None): |
| 129 | if self.model is not None: | ||
| 130 | raise Exception('Model is loaded, if you are sure to reload the model, set `self.model = None` first') | ||
| 131 | |||
| 190 | base_model = MobileNetV2( | 132 | base_model = MobileNetV2( |
| 191 | input_shape=(224, 224, 3), | 133 | input_shape=(224, 224, 3), |
| 192 | alpha=0.35, | 134 | alpha=0.35, |
| ... | @@ -199,27 +141,41 @@ class F3Classification(BaseModel): | ... | @@ -199,27 +141,41 @@ class F3Classification(BaseModel): |
| 199 | x = layers.Dense(256, activation='sigmoid', name='dense')(x) | 141 | x = layers.Dense(256, activation='sigmoid', name='dense')(x) |
| 200 | x = layers.Dropout(0.5)(x) | 142 | x = layers.Dropout(0.5)(x) |
| 201 | x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x) | 143 | x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x) |
| 202 | model = models.Model(inputs=base_model.input, outputs=x) | 144 | self.model = models.Model(inputs=base_model.input, outputs=x) |
| 203 | 145 | ||
| 146 | if for_training: | ||
| 204 | freeze = True | 147 | freeze = True |
| 205 | for layer in model.layers: | 148 | for layer in self.model.layers: |
| 206 | layer.trainable = not freeze | 149 | layer.trainable = not freeze |
| 207 | if freeze and layer.name == 'block_16_project_BN': | 150 | if freeze and layer.name == 'block_16_project_BN': |
| 208 | freeze = False | 151 | freeze = False |
| 209 | return model | ||
| 210 | 152 | ||
| 211 | def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, | 153 | if isinstance(load_weights_path, str): |
| 212 | train_dir_name='train', validate_dir_name='test', thresholds=0.5): | 154 | if not os.path.isfile(load_weights_path): |
| 155 | raise Exception('load_weights_path can not find') | ||
| 156 | self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True) | ||
| 157 | |||
| 158 | def train(self, | ||
| 159 | dataset_dir, | ||
| 160 | epoch, | ||
| 161 | batch_size, | ||
| 162 | ckpt_path, | ||
| 163 | history_save_path, | ||
| 164 | load_weights_path=None, | ||
| 165 | train_dir_name='train', | ||
| 166 | validate_dir_name='test', | ||
| 167 | thresholds=0.5, | ||
| 168 | metrics_name='accuracy'): | ||
| 213 | 169 | ||
| 214 | self.gpu_config() | 170 | self.gpu_config() |
| 215 | 171 | ||
| 216 | model = self.load_model() | 172 | self.load_model(for_training=True, load_weights_path=load_weights_path) |
| 217 | model.summary() | 173 | self.model.summary() |
| 218 | 174 | ||
| 219 | model.compile( | 175 | self.model.compile( |
| 220 | optimizer=optimizers.Adam(learning_rate=3e-4), | 176 | optimizer=optimizers.Adam(learning_rate=3e-4), |
| 221 | loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO >>> | 177 | loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO ? |
| 222 | metrics=[CustomMetric(thresholds), ], | 178 | metrics=[CustomMetric(thresholds, name=metrics_name), ], |
| 223 | 179 | ||
| 224 | loss_weights=None, | 180 | loss_weights=None, |
| 225 | weighted_metrics=None, | 181 | weighted_metrics=None, |
| ... | @@ -250,14 +206,71 @@ class F3Classification(BaseModel): | ... | @@ -250,14 +206,71 @@ class F3Classification(BaseModel): |
| 250 | 206 | ||
| 251 | ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True) | 207 | ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True) |
| 252 | 208 | ||
| 253 | history = model.fit( | 209 | history = self.model.fit( |
| 254 | train_dataset, | 210 | train_dataset, |
| 255 | epochs=epoch, | 211 | epochs=epoch, |
| 256 | validation_data=validate_dataset, | 212 | validation_data=validate_dataset, |
| 257 | callbacks=[ckpt_callback, ], | 213 | callbacks=[ckpt_callback, ], |
| 258 | ) | 214 | ) |
| 259 | 215 | ||
| 260 | self.history_save(history, history_save_path) | 216 | history_save(history, history_save_path, metrics_name) |
| 217 | |||
| 218 | def evaluation(self, | ||
| 219 | load_weights_path, | ||
| 220 | confusion_matrix_save_path, | ||
| 221 | dataset_dir, | ||
| 222 | batch_size, | ||
| 223 | validate_dir_name='test', | ||
| 224 | thresholds=0.5): | ||
| 225 | self.gpu_config() | ||
| 226 | |||
| 227 | self.load_model(load_weights_path=load_weights_path) | ||
| 228 | self.model.summary() | ||
| 229 | |||
| 230 | validate_dataset = self.load_dataset( | ||
| 231 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), | ||
| 232 | name=validate_dir_name, | ||
| 233 | batch_size=batch_size, | ||
| 234 | augmentation_methods=[] | ||
| 235 | ) | ||
| 236 | |||
| 237 | label_true_list = [] | ||
| 238 | label_pred_list = [] | ||
| 239 | custom_metric = CustomMetric(thresholds) | ||
| 240 | for image_batch, y_true_batch in validate_dataset: | ||
| 241 | y_pred_batch = self.model.predict(image_batch) | ||
| 242 | label_true_batch_with_others = custom_metric.y_true_with_others(y_true_batch) | ||
| 243 | label_pred_batch_with_others = custom_metric.y_pred_with_others(y_pred_batch) | ||
| 244 | label_true_list.extend(label_true_batch_with_others.numpy()) | ||
| 245 | label_pred_list.extend(label_pred_batch_with_others.numpy()) | ||
| 246 | acc = accuracy_score(label_true_list, label_pred_list) | ||
| 247 | cm = confusion_matrix(label_true_list, label_pred_list) | ||
| 248 | report = classification_report(label_true_list, label_pred_list) | ||
| 249 | print(acc) | ||
| 250 | print(cm) | ||
| 251 | print(report) | ||
| 252 | plot_confusion_matrix(cm, [idx for idx in range(len(self.class_name_list))], confusion_matrix_save_path) | ||
| 253 | |||
| 254 | def predict(self, image, thresholds=0.5): | ||
| 255 | if self.model is None: | ||
| 256 | raise Exception("The model hasn't loaded yet, run `self.load_model()` first") | ||
| 257 | input_image, _ = self.preprocess_input(image, None) | ||
| 258 | input_images = tf.expand_dims(input_image, axis=0) | ||
| 259 | outputs = self.model.predict(input_images) | ||
| 260 | |||
| 261 | for output in outputs: | ||
| 262 | idx = tf.math.argmax(output) | ||
| 263 | confidence = output[idx] | ||
| 264 | if confidence < thresholds: | ||
| 265 | idx = -1 | ||
| 266 | label = self.class_name_list[idx + 1] | ||
| 267 | break | ||
| 268 | |||
| 269 | res = { | ||
| 270 | 'label': label, | ||
| 271 | 'confidence': confidence | ||
| 272 | } | ||
| 273 | return res | ||
| 261 | 274 | ||
| 262 | def test(self): | 275 | def test(self): |
| 263 | y_true = [ | 276 | y_true = [ | ... | ... |
classification/utils.py
0 → 100644
| 1 | import numpy as np | ||
| 2 | import itertools | ||
| 3 | import matplotlib.pyplot as plt | ||
| 4 | |||
| 5 | |||
| 6 | def history_save(history, save_path, metrics_name='accuracy'): | ||
| 7 | acc = history.history[metrics_name] | ||
| 8 | val_acc = history.history['val_{0}'.format(metrics_name)] | ||
| 9 | |||
| 10 | loss = history.history['loss'] | ||
| 11 | val_loss = history.history['val_loss'] | ||
| 12 | |||
| 13 | plt.figure(figsize=(8, 8)) | ||
| 14 | plt.subplot(2, 1, 1) | ||
| 15 | plt.plot(acc, label='Training Accuracy') | ||
| 16 | plt.plot(val_acc, label='Validation Accuracy') | ||
| 17 | plt.legend(loc='lower right') | ||
| 18 | plt.ylabel('Accuracy') | ||
| 19 | plt.ylim([min(plt.ylim()), 1]) | ||
| 20 | plt.title('Training and Validation Accuracy') | ||
| 21 | |||
| 22 | plt.subplot(2, 1, 2) | ||
| 23 | plt.plot(loss, label='Training Loss') | ||
| 24 | plt.plot(val_loss, label='Validation Loss') | ||
| 25 | plt.legend(loc='upper right') | ||
| 26 | plt.ylabel('Cross Entropy') | ||
| 27 | plt.ylim([0, 1.0]) | ||
| 28 | plt.title('Training and Validation Loss') | ||
| 29 | plt.xlabel('epoch') | ||
| 30 | # plt.show() | ||
| 31 | plt.savefig(save_path) | ||
| 32 | |||
| 33 | |||
| 34 | def plot_confusion_matrix(cm, class_names, save_path): | ||
| 35 | """ | ||
| 36 | Returns a matplotlib figure containing the plotted confusion matrix. | ||
| 37 | |||
| 38 | Args: | ||
| 39 | cm (array, shape = [n, n]): a confusion matrix of integer classes | ||
| 40 | class_names (array, shape = [n]): String names of the integer classes | ||
| 41 | save_path (str): figure save path | ||
| 42 | """ | ||
| 43 | figure = plt.figure(figsize=(8, 8)) | ||
| 44 | plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) | ||
| 45 | plt.title("Confusion matrix") | ||
| 46 | plt.colorbar() | ||
| 47 | tick_marks = np.arange(len(class_names)) | ||
| 48 | plt.xticks(tick_marks, class_names, rotation=45) | ||
| 49 | plt.yticks(tick_marks, class_names) | ||
| 50 | |||
| 51 | # Compute the labels from the normalized confusion matrix. | ||
| 52 | labels = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2) | ||
| 53 | # labels = cm.astype('int') | ||
| 54 | |||
| 55 | # Use white text if squares are dark; otherwise black. | ||
| 56 | threshold = cm.max() / 2. | ||
| 57 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): | ||
| 58 | color = "white" if cm[i, j] > threshold else "black" | ||
| 59 | plt.text(j, i, labels[i, j], horizontalalignment="center", color=color) | ||
| 60 | |||
| 61 | plt.tight_layout() | ||
| 62 | plt.ylabel('True label') | ||
| 63 | plt.xlabel('Predicted label') | ||
| 64 | plt.savefig(save_path) |
-
Please register or sign in to post a comment