add predict

Showing 11 changed files with 258 additions and 130 deletions
README.md · new file (0 → 100644)
## Usage

### Classification
```python
import cv2
from classification import classifier

img_path = 'xxx'
img = cv2.imread(img_path)

print(classifier.class_name_list)
res = classifier.predict(img)
print(res)  # {'label': '营业执照', 'confidence': 0.988462}
```

### Authorization letter information extraction
```python
from authorization_from import retriever_individuals, retriever_companies

# Personal letter of authorization
res = retriever_individuals.get_target_fields(go_res, signature_res)
print(res)

# Corporate letter of authorization
# res = retriever_companies.get_target_fields(go_res, signature_res)
# print(res)
```
authorization_from/README.md · deleted (100644 → 0)
## Usage
**F3: information extraction from personal and corporate letters of authorization**

```python
from retriever import Retriever
import const

# Personal letter of authorization {'姓名': 'xxx', '个人身份证件号码': 'xxx', '签字': '有'}
r = Retriever(const.TARGET_FIELD_INDIVIDUALS)

# Corporate letter of authorization {'经销商名称': 'xx', '经销商代码-宝马中国': 'xx', '管理人员姓名-总经理': 'xx', '公司公章': '有', '法定代表人签章': '有'}
# r = Retriever(const.TARGET_FIELD_COMPANIES)
res = r.get_target_fields(go_res, signature_res)
```
authorization_from/__init__.py · new file, empty (0 → 100644)
classification/__init__.py · new file (0 → 100644)

import os.path

from .model import F3Classification
from .const import CLASS_CN_LIST, CLASS_OTHER_FIRST

classifier = F3Classification(
    class_name_list=CLASS_CN_LIST,
    class_other_first=CLASS_OTHER_FIRST
)

classifier.load_model(load_weights_path=os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'ckpt_prod.h5'))
classification/base_class.py

@@ -3,14 +3,14 @@ class BaseModel:
     All Model classes should extend BaseModel.
     """
 
-    def load_model(self):
+    def load_model(self, for_training=False, load_weights_path=None):
         """
         Defining the network structure and return
         """
         raise NotImplementedError(".load() must be overridden.")
 
     def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
-              train_dir_name='train', validate_dir_name='test'):
+              train_dir_name='train', validate_dir_name='test', thresholds=0.5, metrics_name='accuracy'):
         """
         Model training process
         """
classification/const.py

@@ -2,7 +2,8 @@ CLASS_OTHER_CN = '其他'
 
 CLASS_OTHER_FIRST = True
 
-CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
+# CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
+CLASS_CN_LIST = [CLASS_OTHER_CN, '营业执照', '经销商授权书', '个人授权书']
 
 OTHER_THRESHOLDS = 0.5
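For context, the label map that model.py later builds from this list via get_class_label_map assigns -1 to the leading "other" entry when CLASS_OTHER_FIRST is true. A quick illustration (not repo code) of the map produced for the updated list:

```python
# Illustration only: what get_class_label_map yields for the updated
# CLASS_CN_LIST with class_other_first=True ("other" -> -1, real classes 0..n-1).
class_name_list = ['其他', '营业执照', '经销商授权书', '个人授权书']
label_map = {name: idx - 1 for idx, name in enumerate(class_name_list)}
print(label_map)  # {'其他': -1, '营业执照': 0, '经销商授权书': 1, '个人授权书': 2}
```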
classification/main.py · deleted (100644 → 0)
import os
from datetime import datetime
from model import F3Classification
import const


if __name__ == '__main__':
    base_dir = os.path.dirname(os.path.abspath(__file__))

    m = F3Classification(
        class_name_list=const.CLASS_CN_LIST,
        class_other_first=const.CLASS_OTHER_FIRST
    )

    # m.test()

    dataset_dir = '/home/zwq/data/data_224_f3'
    ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
    history_save_path = os.path.join(base_dir, 'history_{0}.jpg'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
    epoch = 100
    batch_size = 128

    m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
            train_dir_name='train', validate_dir_name='test', thresholds=const.OTHER_THRESHOLDS)
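With main.py gone, training is driven by calling F3Classification.train directly under its new signature. A minimal sketch reusing the deleted script's dataset path and hyperparameters; note that importing the package now runs classification/__init__.py, which loads ckpt_prod.h5:

```python
# Sketch only: driving the new train() signature after main.py was removed.
# Dataset path and hyperparameters are carried over from the deleted script.
from datetime import datetime

# NB: importing classification.* executes classification/__init__.py,
# which builds the default classifier and loads ckpt_prod.h5.
from classification.model import F3Classification
from classification.const import CLASS_CN_LIST, CLASS_OTHER_FIRST, OTHER_THRESHOLDS

m = F3Classification(class_name_list=CLASS_CN_LIST, class_other_first=CLASS_OTHER_FIRST)

stamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
m.train(
    dataset_dir='/home/zwq/data/data_224_f3',
    epoch=100,
    batch_size=128,
    ckpt_path='ckpt_{0}.h5'.format(stamp),
    history_save_path='history_{0}.jpg'.format(stamp),
    thresholds=OTHER_THRESHOLDS,
)
```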
classification/metrics.py · new file (0 → 100644)
import tensorflow as tf
from keras import metrics


class CustomMetric(metrics.Metric):

    def __init__(self, thresholds=0.5, name="custom_metric", **kwargs):
        super(CustomMetric, self).__init__(name=name, **kwargs)
        self.thresholds = thresholds
        self.true_positives = self.add_weight(name="ctp", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros", dtype='int32')

    @staticmethod
    def y_true_with_others(y_true):
        y_true_idx = tf.argmax(y_true, axis=1) + 1
        y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64")
        y_true = tf.math.multiply(y_true_idx, y_true_is_other)
        return y_true

    def y_pred_with_others(self, y_pred):
        y_pred_idx = tf.argmax(y_pred, axis=1) + 1
        y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64')
        y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other)
        return y_pred

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = self.y_true_with_others(y_true)
        y_pred = self.y_pred_with_others(y_pred)

        values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
        values = tf.cast(values, "float32")
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, "float32")
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))
        self.count.assign_add(tf.shape(y_true)[0])

    def result(self):
        return self.true_positives / tf.cast(self.count, 'float32')

    def reset_state(self):
        # The state of the metric will be reset at the start of each epoch.
        self.true_positives.assign(0.0)
        self.count.assign(0)
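Since this metric is new in the commit, a small worked example helps show the "other" handling: an all-zero one-hot row in y_true maps to label 0, and a prediction whose max probability falls below thresholds is forced to 0 as well. Toy tensors, run from inside the classification/ directory so the package __init__ is not triggered:

```python
# Toy check of CustomMetric's "other" folding (values invented for illustration).
import tensorflow as tf
from metrics import CustomMetric  # run from classification/ to import directly

m = CustomMetric(thresholds=0.5)

y_true = tf.constant([[1., 0., 0.],     # one-hot class index 0 -> mapped label 1
                      [0., 0., 0.]])    # all zeros -> "other" -> mapped label 0
y_pred = tf.constant([[0.9, 0.05, 0.05],  # max prob 0.9 >= 0.5 -> mapped label 1
                      [0.4, 0.3, 0.3]])   # max prob 0.4 < 0.5 -> "other" (0)

print(CustomMetric.y_true_with_others(y_true).numpy())  # [1 0]
print(m.y_pred_with_others(y_pred).numpy())             # [1 0]

m.update_state(y_true, y_pred)
print(m.result().numpy())  # 1.0; both samples agree after the mapping
```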
classification/model.py

@@ -2,57 +2,25 @@ import os
 import random
 import tensorflow as tf
 import tensorflow_addons as tfa
-from keras.applications.mobilenet_v2 import MobileNetV2
-from keras import layers, models, optimizers, losses, metrics, callbacks, applications
-import matplotlib.pyplot as plt
 
-from base_class import BaseModel
-
-
-class CustomMetric(metrics.Metric):
-
-    def __init__(self, thresholds=0.5, name="custom_metric", **kwargs):
-        super(CustomMetric, self).__init__(name=name, **kwargs)
-        self.thresholds = thresholds
-        self.true_positives = self.add_weight(name="ctp", initializer="zeros")
-        self.count = self.add_weight(name="count", initializer="zeros", dtype='int32')
-
-    def update_state(self, y_true, y_pred, sample_weight=None):
-        y_true_idx = tf.argmax(y_true, axis=1) + 1
-        y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64")
-        y_true = tf.math.multiply(y_true_idx, y_true_is_other)
-
-        y_pred_idx = tf.argmax(y_pred, axis=1) + 1
-        y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64')
-        y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other)
-
-        print(y_true)
-        print(y_pred)
-
-        values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
-        values = tf.cast(values, "float32")
-        if sample_weight is not None:
-            sample_weight = tf.cast(sample_weight, "float32")
-            values = tf.multiply(values, sample_weight)
-        self.true_positives.assign_add(tf.reduce_sum(values))
-        self.count.assign_add(tf.shape(y_true)[0])
-
-    def result(self):
-        return self.true_positives / tf.cast(self.count, 'float32')
-
-    def reset_state(self):
-        # The state of the metric will be reset at the start of each epoch.
-        self.true_positives.assign(0.0)
-        self.count.assign(0)
+from keras.applications.mobilenet_v2 import MobileNetV2
+from keras import layers, models, optimizers, callbacks, applications
+from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
+
+from .base_class import BaseModel
+from .metrics import CustomMetric
+from .utils import history_save, plot_confusion_matrix
 
 
 class F3Classification(BaseModel):
 
     def __init__(self, class_name_list, class_other_first, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.class_name_list = class_name_list
         self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
         self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
         self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
+        self.model = None
 
     @staticmethod
     def gpu_config():
@@ -61,34 +29,6 @@ class F3Classification(BaseModel):
         tf.config.set_visible_devices(devices=gpus[1], device_type='GPU')
 
     @staticmethod
-    def history_save(history, save_path):
-        acc = history.history['accuracy']
-        val_acc = history.history['val_accuracy']
-
-        loss = history.history['loss']
-        val_loss = history.history['val_loss']
-
-        plt.figure(figsize=(8, 8))
-        plt.subplot(2, 1, 1)
-        plt.plot(acc, label='Training Accuracy')
-        plt.plot(val_acc, label='Validation Accuracy')
-        plt.legend(loc='lower right')
-        plt.ylabel('Accuracy')
-        plt.ylim([min(plt.ylim()), 1])
-        plt.title('Training and Validation Accuracy')
-
-        plt.subplot(2, 1, 2)
-        plt.plot(loss, label='Training Loss')
-        plt.plot(val_loss, label='Validation Loss')
-        plt.legend(loc='upper right')
-        plt.ylabel('Cross Entropy')
-        plt.ylim([0, 1.0])
-        plt.title('Training and Validation Loss')
-        plt.xlabel('epoch')
-        # plt.show()
-        plt.savefig(save_path)
-
-    @staticmethod
     def get_class_label_map(class_name_list, class_other_first=False):
         return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}
@@ -103,7 +43,6 @@ class F3Classification(BaseModel):
                 continue
             label = self.class_label_map[class_name]
             for file_name in os.listdir(class_dir_path):
-                # TODO image check
                 if os.path.splitext(file_name)[1] not in self.image_ext_set:
                     continue
                 file_path = os.path.join(class_dir_path, file_name)
@@ -153,7 +92,7 @@ class F3Classification(BaseModel):
         # @tf.function
         def load_image(image_path, label):
             image = tf.io.read_file(image_path)
-            # image = tf.image.decode_image(image, channels=3)  # TODO why doesn't this work?
+            # image = tf.image.decode_image(image, channels=3)  # TODO ?
             image = tf.image.decode_png(image, channels=3)
             return image, label
 
@@ -186,7 +125,10 @@ class F3Classification(BaseModel):
         ).prefetch(tf.data.AUTOTUNE)
         return parallel_batch_dataset
 
-    def load_model(self):
+    def load_model(self, for_training=False, load_weights_path=None):
+        if self.model is not None:
+            raise Exception('Model is loaded, if you are sure to reload the model, set `self.model = None` first')
+
         base_model = MobileNetV2(
             input_shape=(224, 224, 3),
             alpha=0.35,
@@ -199,27 +141,41 @@ class F3Classification(BaseModel):
         x = layers.Dense(256, activation='sigmoid', name='dense')(x)
         x = layers.Dropout(0.5)(x)
         x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
-        model = models.Model(inputs=base_model.input, outputs=x)
+        self.model = models.Model(inputs=base_model.input, outputs=x)
 
-        freeze = True
-        for layer in model.layers:
-            layer.trainable = not freeze
-            if freeze and layer.name == 'block_16_project_BN':
-                freeze = False
-        return model
+        if for_training:
+            freeze = True
+            for layer in self.model.layers:
+                layer.trainable = not freeze
+                if freeze and layer.name == 'block_16_project_BN':
+                    freeze = False
+
+        if isinstance(load_weights_path, str):
+            if not os.path.isfile(load_weights_path):
+                raise Exception('load_weights_path can not find')
+            self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True)
 
-    def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
-              train_dir_name='train', validate_dir_name='test', thresholds=0.5):
+    def train(self,
+              dataset_dir,
+              epoch,
+              batch_size,
+              ckpt_path,
+              history_save_path,
+              load_weights_path=None,
+              train_dir_name='train',
+              validate_dir_name='test',
+              thresholds=0.5,
+              metrics_name='accuracy'):
 
         self.gpu_config()
 
-        model = self.load_model()
-        model.summary()
+        self.load_model(for_training=True, load_weights_path=load_weights_path)
+        self.model.summary()
 
-        model.compile(
+        self.model.compile(
             optimizer=optimizers.Adam(learning_rate=3e-4),
-            loss=tfa.losses.SigmoidFocalCrossEntropy(),  # TODO >>>
-            metrics=[CustomMetric(thresholds), ],
+            loss=tfa.losses.SigmoidFocalCrossEntropy(),  # TODO ?
+            metrics=[CustomMetric(thresholds, name=metrics_name), ],
 
             loss_weights=None,
             weighted_metrics=None,
250 | 206 | ||
251 | ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True) | 207 | ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True) |
252 | 208 | ||
253 | history = model.fit( | 209 | history = self.model.fit( |
254 | train_dataset, | 210 | train_dataset, |
255 | epochs=epoch, | 211 | epochs=epoch, |
256 | validation_data=validate_dataset, | 212 | validation_data=validate_dataset, |
257 | callbacks=[ckpt_callback, ], | 213 | callbacks=[ckpt_callback, ], |
258 | ) | 214 | ) |
259 | 215 | ||
260 | self.history_save(history, history_save_path) | 216 | history_save(history, history_save_path, metrics_name) |
217 | |||
218 | def evaluation(self, | ||
219 | load_weights_path, | ||
220 | confusion_matrix_save_path, | ||
221 | dataset_dir, | ||
222 | batch_size, | ||
223 | validate_dir_name='test', | ||
224 | thresholds=0.5): | ||
225 | self.gpu_config() | ||
226 | |||
227 | self.load_model(load_weights_path=load_weights_path) | ||
228 | self.model.summary() | ||
229 | |||
230 | validate_dataset = self.load_dataset( | ||
231 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), | ||
232 | name=validate_dir_name, | ||
233 | batch_size=batch_size, | ||
234 | augmentation_methods=[] | ||
235 | ) | ||
236 | |||
237 | label_true_list = [] | ||
238 | label_pred_list = [] | ||
239 | custom_metric = CustomMetric(thresholds) | ||
240 | for image_batch, y_true_batch in validate_dataset: | ||
241 | y_pred_batch = self.model.predict(image_batch) | ||
242 | label_true_batch_with_others = custom_metric.y_true_with_others(y_true_batch) | ||
243 | label_pred_batch_with_others = custom_metric.y_pred_with_others(y_pred_batch) | ||
244 | label_true_list.extend(label_true_batch_with_others.numpy()) | ||
245 | label_pred_list.extend(label_pred_batch_with_others.numpy()) | ||
246 | acc = accuracy_score(label_true_list, label_pred_list) | ||
247 | cm = confusion_matrix(label_true_list, label_pred_list) | ||
248 | report = classification_report(label_true_list, label_pred_list) | ||
249 | print(acc) | ||
250 | print(cm) | ||
251 | print(report) | ||
252 | plot_confusion_matrix(cm, [idx for idx in range(len(self.class_name_list))], confusion_matrix_save_path) | ||
253 | |||
254 | def predict(self, image, thresholds=0.5): | ||
255 | if self.model is None: | ||
256 | raise Exception("The model hasn't loaded yet, run `self.load_model()` first") | ||
257 | input_image, _ = self.preprocess_input(image, None) | ||
258 | input_images = tf.expand_dims(input_image, axis=0) | ||
259 | outputs = self.model.predict(input_images) | ||
260 | |||
261 | for output in outputs: | ||
262 | idx = tf.math.argmax(output) | ||
263 | confidence = output[idx] | ||
264 | if confidence < thresholds: | ||
265 | idx = -1 | ||
266 | label = self.class_name_list[idx + 1] | ||
267 | break | ||
268 | |||
269 | res = { | ||
270 | 'label': label, | ||
271 | 'confidence': confidence | ||
272 | } | ||
273 | return res | ||
261 | 274 | ||
262 | def test(self): | 275 | def test(self): |
263 | y_true = [ | 276 | y_true = [ | ... | ... |
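The new evaluation() entry point is not covered by the README; a hedged sketch of how it could be invoked (all paths and the batch size below are placeholders, not values from the commit):

```python
# Sketch only: exercising the new evaluation() method on a held-out set.
# NB: importing classification.* runs the package __init__, which loads ckpt_prod.h5.
from classification.model import F3Classification
from classification.const import CLASS_CN_LIST, CLASS_OTHER_FIRST

m = F3Classification(class_name_list=CLASS_CN_LIST, class_other_first=CLASS_OTHER_FIRST)
m.evaluation(
    load_weights_path='ckpt_best.h5',      # placeholder checkpoint path
    confusion_matrix_save_path='cm.jpg',   # placeholder figure path
    dataset_dir='/path/to/data_224_f3',    # placeholder dataset root
    batch_size=128,                        # placeholder
)
# Prints accuracy, the confusion matrix and a classification report, then
# renders the matrix via plot_confusion_matrix from classification/utils.py.
```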
classification/utils.py · new file (0 → 100644)
import numpy as np
import itertools
import matplotlib.pyplot as plt


def history_save(history, save_path, metrics_name='accuracy'):
    acc = history.history[metrics_name]
    val_acc = history.history['val_{0}'.format(metrics_name)]

    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 1])
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.ylim([0, 1.0])
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')
    # plt.show()
    plt.savefig(save_path)


def plot_confusion_matrix(cm, class_names, save_path):
    """
    Returns a matplotlib figure containing the plotted confusion matrix.

    Args:
        cm (array, shape = [n, n]): a confusion matrix of integer classes
        class_names (array, shape = [n]): String names of the integer classes
        save_path (str): figure save path
    """
    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)

    # Compute the labels from the normalized confusion matrix.
    labels = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)
    # labels = cm.astype('int')

    # Use white text if squares are dark; otherwise black.
    threshold = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > threshold else "black"
        plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig(save_path)
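A quick usage sketch for plot_confusion_matrix; the 3x3 counts and English class names are invented for illustration:

```python
# Sketch only: rendering a toy confusion matrix with the helper above.
import numpy as np
from utils import plot_confusion_matrix  # run from classification/ to import directly

cm = np.array([[50,  2,  1],   # rows are true labels, columns are predictions
               [ 3, 45,  0],
               [ 0,  4, 40]])
plot_confusion_matrix(cm, ['other', 'license', 'authorization'], 'confusion_matrix.jpg')
```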