bc96c928 by 周伟奇

add predict

1 parent 1ea84670
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
11 .* 11 .*
12 !.gitignore 12 !.gitignore
13 13
14 test.py 14 test*
15 *.h5 15 *.h5
16 *.jpg 16 *.jpg
17 *.out 17 *.out
...\ No newline at end of file ...\ No newline at end of file
......
1 ## Useage
2
3 ### 分类
4 ```python
5 import cv2
6 from classification import classifier
7
8 img_path = 'xxx'
9 img = cv2.imread(img_path)
10
11 print(classifier.class_name_list)
12 res = classifier.predict(img)
13 print(res) # {'label': '营业执照', 'confidence': 0.988462}
14 ```
15
16 ### 授权书信息提取
17 ```python
18 from authorization_from import retriever_individuals, retriever_companies
19
20 # 个人授权书
21 res = retriever_companies.get_target_fields(go_res, signature_res)
22 print(res)
23
24 # 公司授权书
25 # res = retriever_individuals.get_target_fields(go_res, signature_res)
26 # print(res)
27 ```
1 ## Useage
2 **F3个人授权书和企业授权书的信息提取**
3
4 ```python
5 from retriever import Retriever
6 import const
7
8 # 个人授权书 {'姓名': 'xxx', '个人身份证件号码': 'xxx', '签字': '有'}
9 r = Retriever(const.TARGET_FIELD_INDIVIDUALS)
10
11 # 企业授权书 {'经销商名称': 'xx', '经销商代码-宝马中国': 'xx', '管理人员姓名-总经理': 'xx', '公司公章': '有', '法定代表人签章': '有'}
12 # r = Retriever(const.TARGET_FIELD_COMPANIES)
13 res = r.get_target_fields(go_res, signature_res)
14 ```
15
16
17
18
1 from .retriever import Retriever
2 from .const import TARGET_FIELD_INDIVIDUALS, TARGET_FIELD_COMPANIES
3
4 retriever_individuals = Retriever(const.TARGET_FIELD_INDIVIDUALS)
5 retriever_companies = Retriever(const.TARGET_FIELD_COMPANIES)
6
1 import os.path
2
3 from .model import F3Classification
4 from .const import CLASS_CN_LIST, CLASS_OTHER_FIRST
5
6 classifier = F3Classification(
7 class_name_list=CLASS_CN_LIST,
8 class_other_first=CLASS_OTHER_FIRST
9 )
10
11 classifier.load_model(load_weights_path=os.path.join(
12 os.path.dirname(os.path.abspath(__file__)), 'ckpt_prod.h5'))
13
...@@ -3,14 +3,14 @@ class BaseModel: ...@@ -3,14 +3,14 @@ class BaseModel:
3 All Model classes should extend BaseModel. 3 All Model classes should extend BaseModel.
4 """ 4 """
5 5
6 def load_model(self): 6 def load_model(self, for_training=False, load_weights_path=None):
7 """ 7 """
8 Defining the network structure and return 8 Defining the network structure and return
9 """ 9 """
10 raise NotImplementedError(".load() must be overridden.") 10 raise NotImplementedError(".load() must be overridden.")
11 11
12 def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, 12 def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
13 train_dir_name='train', validate_dir_name='test'): 13 train_dir_name='train', validate_dir_name='test', thresholds=0.5, metrics_name='accuracy'):
14 """ 14 """
15 Model training process 15 Model training process
16 """ 16 """
......
...@@ -2,7 +2,8 @@ CLASS_OTHER_CN = '其他' ...@@ -2,7 +2,8 @@ CLASS_OTHER_CN = '其他'
2 2
3 CLASS_OTHER_FIRST = True 3 CLASS_OTHER_FIRST = True
4 4
5 CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书'] 5 # CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
6 CLASS_CN_LIST = [CLASS_OTHER_CN, '营业执照', '经销商授权书', '个人授权书']
6 7
7 OTHER_THRESHOLDS = 0.5 8 OTHER_THRESHOLDS = 0.5
8 9
......
1 import os
2 from datetime import datetime
3 from model import F3Classification
4 import const
5
6
7 if __name__ == '__main__':
8 base_dir = os.path.dirname(os.path.abspath(__file__))
9
10 m = F3Classification(
11 class_name_list=const.CLASS_CN_LIST,
12 class_other_first=const.CLASS_OTHER_FIRST
13 )
14
15 # m.test()
16
17 dataset_dir = '/home/zwq/data/data_224_f3'
18 ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
19 history_save_path = os.path.join(base_dir, 'history_{0}.jpg'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
20 epoch = 100
21 batch_size = 128
22
23 m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
24 train_dir_name='train', validate_dir_name='test', thresholds=const.OTHER_THRESHOLDS)
25
1 import tensorflow as tf
2 from keras import metrics
3
4
5 class CustomMetric(metrics.Metric):
6
7 def __init__(self, thresholds=0.5, name="custom_metric", **kwargs):
8 super(CustomMetric, self).__init__(name=name, **kwargs)
9 self.thresholds = thresholds
10 self.true_positives = self.add_weight(name="ctp", initializer="zeros")
11 self.count = self.add_weight(name="count", initializer="zeros", dtype='int32')
12
13 @staticmethod
14 def y_true_with_others(y_true):
15 y_true_idx = tf.argmax(y_true, axis=1) + 1
16 y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64")
17 y_true = tf.math.multiply(y_true_idx, y_true_is_other)
18 return y_true
19
20 def y_pred_with_others(self, y_pred):
21 y_pred_idx = tf.argmax(y_pred, axis=1) + 1
22 y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64')
23 y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other)
24 return y_pred
25
26 def update_state(self, y_true, y_pred, sample_weight=None):
27 y_true = self.y_true_with_others(y_true)
28 y_pred = self.y_pred_with_others(y_pred)
29
30 # print(y_true)
31 # print(y_pred)
32
33 values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
34 values = tf.cast(values, "float32")
35 if sample_weight is not None:
36 sample_weight = tf.cast(sample_weight, "float32")
37 values = tf.multiply(values, sample_weight)
38 self.true_positives.assign_add(tf.reduce_sum(values))
39 self.count.assign_add(tf.shape(y_true)[0])
40
41 def result(self):
42 return self.true_positives / tf.cast(self.count, 'float32')
43
44 def reset_state(self):
45 # The state of the metric will be reset at the start of each epoch.
46 self.true_positives.assign(0.0)
47 self.count.assign(0)
...@@ -2,57 +2,25 @@ import os ...@@ -2,57 +2,25 @@ import os
2 import random 2 import random
3 import tensorflow as tf 3 import tensorflow as tf
4 import tensorflow_addons as tfa 4 import tensorflow_addons as tfa
5 from keras.applications.mobilenet_v2 import MobileNetV2
6 from keras import layers, models, optimizers, losses, metrics, callbacks, applications
7 import matplotlib.pyplot as plt
8
9 from base_class import BaseModel
10
11
12 class CustomMetric(metrics.Metric):
13
14 def __init__(self, thresholds=0.5, name="custom_metric", **kwargs):
15 super(CustomMetric, self).__init__(name=name, **kwargs)
16 self.thresholds = thresholds
17 self.true_positives = self.add_weight(name="ctp", initializer="zeros")
18 self.count = self.add_weight(name="count", initializer="zeros", dtype='int32')
19
20 def update_state(self, y_true, y_pred, sample_weight=None):
21 y_true_idx = tf.argmax(y_true, axis=1) + 1
22 y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64")
23 y_true = tf.math.multiply(y_true_idx, y_true_is_other)
24 5
25 y_pred_idx = tf.argmax(y_pred, axis=1) + 1 6 from keras.applications.mobilenet_v2 import MobileNetV2
26 y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64') 7 from keras import layers, models, optimizers, callbacks, applications
27 y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other) 8 from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
28
29 print(y_true)
30 print(y_pred)
31
32 values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
33 values = tf.cast(values, "float32")
34 if sample_weight is not None:
35 sample_weight = tf.cast(sample_weight, "float32")
36 values = tf.multiply(values, sample_weight)
37 self.true_positives.assign_add(tf.reduce_sum(values))
38 self.count.assign_add(tf.shape(y_true)[0])
39
40 def result(self):
41 return self.true_positives / tf.cast(self.count, 'float32')
42 9
43 def reset_state(self): 10 from .base_class import BaseModel
44 # The state of the metric will be reset at the start of each epoch. 11 from .metrics import CustomMetric
45 self.true_positives.assign(0.0) 12 from .utils import history_save, plot_confusion_matrix
46 self.count.assign(0)
47 13
48 14
49 class F3Classification(BaseModel): 15 class F3Classification(BaseModel):
50 16
51 def __init__(self, class_name_list, class_other_first, *args, **kwargs): 17 def __init__(self, class_name_list, class_other_first, *args, **kwargs):
52 super().__init__(*args, **kwargs) 18 super().__init__(*args, **kwargs)
19 self.class_name_list = class_name_list
53 self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1 20 self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
54 self.class_label_map = self.get_class_label_map(class_name_list, class_other_first) 21 self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
55 self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} 22 self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
23 self.model = None
56 24
57 @staticmethod 25 @staticmethod
58 def gpu_config(): 26 def gpu_config():
...@@ -61,34 +29,6 @@ class F3Classification(BaseModel): ...@@ -61,34 +29,6 @@ class F3Classification(BaseModel):
61 tf.config.set_visible_devices(devices=gpus[1], device_type='GPU') 29 tf.config.set_visible_devices(devices=gpus[1], device_type='GPU')
62 30
63 @staticmethod 31 @staticmethod
64 def history_save(history, save_path):
65 acc = history.history['accuracy']
66 val_acc = history.history['val_accuracy']
67
68 loss = history.history['loss']
69 val_loss = history.history['val_loss']
70
71 plt.figure(figsize=(8, 8))
72 plt.subplot(2, 1, 1)
73 plt.plot(acc, label='Training Accuracy')
74 plt.plot(val_acc, label='Validation Accuracy')
75 plt.legend(loc='lower right')
76 plt.ylabel('Accuracy')
77 plt.ylim([min(plt.ylim()), 1])
78 plt.title('Training and Validation Accuracy')
79
80 plt.subplot(2, 1, 2)
81 plt.plot(loss, label='Training Loss')
82 plt.plot(val_loss, label='Validation Loss')
83 plt.legend(loc='upper right')
84 plt.ylabel('Cross Entropy')
85 plt.ylim([0, 1.0])
86 plt.title('Training and Validation Loss')
87 plt.xlabel('epoch')
88 # plt.show()
89 plt.savefig(save_path)
90
91 @staticmethod
92 def get_class_label_map(class_name_list, class_other_first=False): 32 def get_class_label_map(class_name_list, class_other_first=False):
93 return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)} 33 return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}
94 34
...@@ -103,7 +43,6 @@ class F3Classification(BaseModel): ...@@ -103,7 +43,6 @@ class F3Classification(BaseModel):
103 continue 43 continue
104 label = self.class_label_map[class_name] 44 label = self.class_label_map[class_name]
105 for file_name in os.listdir(class_dir_path): 45 for file_name in os.listdir(class_dir_path):
106 # TODO image check
107 if os.path.splitext(file_name)[1] not in self.image_ext_set: 46 if os.path.splitext(file_name)[1] not in self.image_ext_set:
108 continue 47 continue
109 file_path = os.path.join(class_dir_path, file_name) 48 file_path = os.path.join(class_dir_path, file_name)
...@@ -153,7 +92,7 @@ class F3Classification(BaseModel): ...@@ -153,7 +92,7 @@ class F3Classification(BaseModel):
153 # @tf.function 92 # @tf.function
154 def load_image(image_path, label): 93 def load_image(image_path, label):
155 image = tf.io.read_file(image_path) 94 image = tf.io.read_file(image_path)
156 # image = tf.image.decode_image(image, channels=3) # TODO 为什么不行 95 # image = tf.image.decode_image(image, channels=3) # TODO ?
157 image = tf.image.decode_png(image, channels=3) 96 image = tf.image.decode_png(image, channels=3)
158 return image, label 97 return image, label
159 98
...@@ -186,7 +125,10 @@ class F3Classification(BaseModel): ...@@ -186,7 +125,10 @@ class F3Classification(BaseModel):
186 ).prefetch(tf.data.AUTOTUNE) 125 ).prefetch(tf.data.AUTOTUNE)
187 return parallel_batch_dataset 126 return parallel_batch_dataset
188 127
189 def load_model(self): 128 def load_model(self, for_training=False, load_weights_path=None):
129 if self.model is not None:
130 raise Exception('Model is loaded, if you are sure to reload the model, set `self.model = None` first')
131
190 base_model = MobileNetV2( 132 base_model = MobileNetV2(
191 input_shape=(224, 224, 3), 133 input_shape=(224, 224, 3),
192 alpha=0.35, 134 alpha=0.35,
...@@ -199,27 +141,41 @@ class F3Classification(BaseModel): ...@@ -199,27 +141,41 @@ class F3Classification(BaseModel):
199 x = layers.Dense(256, activation='sigmoid', name='dense')(x) 141 x = layers.Dense(256, activation='sigmoid', name='dense')(x)
200 x = layers.Dropout(0.5)(x) 142 x = layers.Dropout(0.5)(x)
201 x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x) 143 x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
202 model = models.Model(inputs=base_model.input, outputs=x) 144 self.model = models.Model(inputs=base_model.input, outputs=x)
203 145
146 if for_training:
204 freeze = True 147 freeze = True
205 for layer in model.layers: 148 for layer in self.model.layers:
206 layer.trainable = not freeze 149 layer.trainable = not freeze
207 if freeze and layer.name == 'block_16_project_BN': 150 if freeze and layer.name == 'block_16_project_BN':
208 freeze = False 151 freeze = False
209 return model
210 152
211 def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, 153 if isinstance(load_weights_path, str):
212 train_dir_name='train', validate_dir_name='test', thresholds=0.5): 154 if not os.path.isfile(load_weights_path):
155 raise Exception('load_weights_path can not find')
156 self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True)
157
158 def train(self,
159 dataset_dir,
160 epoch,
161 batch_size,
162 ckpt_path,
163 history_save_path,
164 load_weights_path=None,
165 train_dir_name='train',
166 validate_dir_name='test',
167 thresholds=0.5,
168 metrics_name='accuracy'):
213 169
214 self.gpu_config() 170 self.gpu_config()
215 171
216 model = self.load_model() 172 self.load_model(for_training=True, load_weights_path=load_weights_path)
217 model.summary() 173 self.model.summary()
218 174
219 model.compile( 175 self.model.compile(
220 optimizer=optimizers.Adam(learning_rate=3e-4), 176 optimizer=optimizers.Adam(learning_rate=3e-4),
221 loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO >>> 177 loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO ?
222 metrics=[CustomMetric(thresholds), ], 178 metrics=[CustomMetric(thresholds, name=metrics_name), ],
223 179
224 loss_weights=None, 180 loss_weights=None,
225 weighted_metrics=None, 181 weighted_metrics=None,
...@@ -250,14 +206,71 @@ class F3Classification(BaseModel): ...@@ -250,14 +206,71 @@ class F3Classification(BaseModel):
250 206
251 ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True) 207 ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
252 208
253 history = model.fit( 209 history = self.model.fit(
254 train_dataset, 210 train_dataset,
255 epochs=epoch, 211 epochs=epoch,
256 validation_data=validate_dataset, 212 validation_data=validate_dataset,
257 callbacks=[ckpt_callback, ], 213 callbacks=[ckpt_callback, ],
258 ) 214 )
259 215
260 self.history_save(history, history_save_path) 216 history_save(history, history_save_path, metrics_name)
217
218 def evaluation(self,
219 load_weights_path,
220 confusion_matrix_save_path,
221 dataset_dir,
222 batch_size,
223 validate_dir_name='test',
224 thresholds=0.5):
225 self.gpu_config()
226
227 self.load_model(load_weights_path=load_weights_path)
228 self.model.summary()
229
230 validate_dataset = self.load_dataset(
231 dataset_dir=os.path.join(dataset_dir, validate_dir_name),
232 name=validate_dir_name,
233 batch_size=batch_size,
234 augmentation_methods=[]
235 )
236
237 label_true_list = []
238 label_pred_list = []
239 custom_metric = CustomMetric(thresholds)
240 for image_batch, y_true_batch in validate_dataset:
241 y_pred_batch = self.model.predict(image_batch)
242 label_true_batch_with_others = custom_metric.y_true_with_others(y_true_batch)
243 label_pred_batch_with_others = custom_metric.y_pred_with_others(y_pred_batch)
244 label_true_list.extend(label_true_batch_with_others.numpy())
245 label_pred_list.extend(label_pred_batch_with_others.numpy())
246 acc = accuracy_score(label_true_list, label_pred_list)
247 cm = confusion_matrix(label_true_list, label_pred_list)
248 report = classification_report(label_true_list, label_pred_list)
249 print(acc)
250 print(cm)
251 print(report)
252 plot_confusion_matrix(cm, [idx for idx in range(len(self.class_name_list))], confusion_matrix_save_path)
253
254 def predict(self, image, thresholds=0.5):
255 if self.model is None:
256 raise Exception("The model hasn't loaded yet, run `self.load_model()` first")
257 input_image, _ = self.preprocess_input(image, None)
258 input_images = tf.expand_dims(input_image, axis=0)
259 outputs = self.model.predict(input_images)
260
261 for output in outputs:
262 idx = tf.math.argmax(output)
263 confidence = output[idx]
264 if confidence < thresholds:
265 idx = -1
266 label = self.class_name_list[idx + 1]
267 break
268
269 res = {
270 'label': label,
271 'confidence': confidence
272 }
273 return res
261 274
262 def test(self): 275 def test(self):
263 y_true = [ 276 y_true = [
......
1 import numpy as np
2 import itertools
3 import matplotlib.pyplot as plt
4
5
6 def history_save(history, save_path, metrics_name='accuracy'):
7 acc = history.history[metrics_name]
8 val_acc = history.history['val_{0}'.format(metrics_name)]
9
10 loss = history.history['loss']
11 val_loss = history.history['val_loss']
12
13 plt.figure(figsize=(8, 8))
14 plt.subplot(2, 1, 1)
15 plt.plot(acc, label='Training Accuracy')
16 plt.plot(val_acc, label='Validation Accuracy')
17 plt.legend(loc='lower right')
18 plt.ylabel('Accuracy')
19 plt.ylim([min(plt.ylim()), 1])
20 plt.title('Training and Validation Accuracy')
21
22 plt.subplot(2, 1, 2)
23 plt.plot(loss, label='Training Loss')
24 plt.plot(val_loss, label='Validation Loss')
25 plt.legend(loc='upper right')
26 plt.ylabel('Cross Entropy')
27 plt.ylim([0, 1.0])
28 plt.title('Training and Validation Loss')
29 plt.xlabel('epoch')
30 # plt.show()
31 plt.savefig(save_path)
32
33
34 def plot_confusion_matrix(cm, class_names, save_path):
35 """
36 Returns a matplotlib figure containing the plotted confusion matrix.
37
38 Args:
39 cm (array, shape = [n, n]): a confusion matrix of integer classes
40 class_names (array, shape = [n]): String names of the integer classes
41 save_path (str): figure save path
42 """
43 figure = plt.figure(figsize=(8, 8))
44 plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
45 plt.title("Confusion matrix")
46 plt.colorbar()
47 tick_marks = np.arange(len(class_names))
48 plt.xticks(tick_marks, class_names, rotation=45)
49 plt.yticks(tick_marks, class_names)
50
51 # Compute the labels from the normalized confusion matrix.
52 labels = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)
53 # labels = cm.astype('int')
54
55 # Use white text if squares are dark; otherwise black.
56 threshold = cm.max() / 2.
57 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
58 color = "white" if cm[i, j] > threshold else "black"
59 plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)
60
61 plt.tight_layout()
62 plt.ylabel('True label')
63 plt.xlabel('Predicted label')
64 plt.savefig(save_path)
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!