bc96c928 by 周伟奇

add predict

1 parent 1ea84670
......@@ -11,7 +11,7 @@
.*
!.gitignore
test.py
test*
*.h5
*.jpg
*.out
\ No newline at end of file
......
## Usage
### Classification
```python
import cv2
from classification import classifier
img_path = 'xxx'
img = cv2.imread(img_path)
print(classifier.class_name_list)
res = classifier.predict(img)
print(res) # {'label': '营业执照', 'confidence': 0.988462}
```
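By default `predict` uses a 0.5 confidence threshold; a top score below it is reported as the fallback class '其他'. The threshold can be overridden per call (a minimal sketch based on the `predict` signature added in this commit):
```python
# minimal sketch, assuming `classifier` and `img` from the snippet above
res = classifier.predict(img, thresholds=0.6)  # stricter cut-off; low-confidence images map to '其他'
```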
### Authorization letter information extraction
```python
from authorization_from import retriever_individuals, retriever_companies
# Personal authorization letter
res = retriever_individuals.get_target_fields(go_res, signature_res)
print(res)
# Company authorization letter
# res = retriever_companies.get_target_fields(go_res, signature_res)
# print(res)
```
## Usage
**F3 information extraction for personal and company authorization letters**
```python
from retriever import Retriever
import const
# Personal authorization letter {'姓名': 'xxx', '个人身份证件号码': 'xxx', '签字': '有'}
r = Retriever(const.TARGET_FIELD_INDIVIDUALS)
# Company authorization letter {'经销商名称': 'xx', '经销商代码-宝马中国': 'xx', '管理人员姓名-总经理': 'xx', '公司公章': '有', '法定代表人签章': '有'}
# r = Retriever(const.TARGET_FIELD_COMPANIES)
res = r.get_target_fields(go_res, signature_res)
```
from .retriever import Retriever
from .const import TARGET_FIELD_INDIVIDUALS, TARGET_FIELD_COMPANIES
retriever_individuals = Retriever(TARGET_FIELD_INDIVIDUALS)
retriever_companies = Retriever(TARGET_FIELD_COMPANIES)
import os.path
from .model import F3Classification
from .const import CLASS_CN_LIST, CLASS_OTHER_FIRST
classifier = F3Classification(
class_name_list=CLASS_CN_LIST,
class_other_first=CLASS_OTHER_FIRST
)
classifier.load_model(load_weights_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'ckpt_prod.h5'))
......@@ -3,14 +3,14 @@ class BaseModel:
All Model classes should extend BaseModel.
"""
def load_model(self):
def load_model(self, for_training=False, load_weights_path=None):
"""
Define the network structure and return it.
"""
raise NotImplementedError(".load_model() must be overridden.")
def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
train_dir_name='train', validate_dir_name='test'):
train_dir_name='train', validate_dir_name='test', thresholds=0.5, metrics_name='accuracy'):
"""
Model training process
"""
......
......@@ -2,7 +2,8 @@ CLASS_OTHER_CN = '其他'
CLASS_OTHER_FIRST = True
CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
# CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
CLASS_CN_LIST = [CLASS_OTHER_CN, '营业执照', '经销商授权书', '个人授权书']
OTHER_THRESHOLDS = 0.5
......
import os
from datetime import datetime
from model import F3Classification
import const
if __name__ == '__main__':
base_dir = os.path.dirname(os.path.abspath(__file__))
m = F3Classification(
class_name_list=const.CLASS_CN_LIST,
class_other_first=const.CLASS_OTHER_FIRST
)
# m.test()
dataset_dir = '/home/zwq/data/data_224_f3'
ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
history_save_path = os.path.join(base_dir, 'history_{0}.jpg'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
epoch = 100
batch_size = 128
m.train(dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
train_dir_name='train', validate_dir_name='test', thresholds=const.OTHER_THRESHOLDS)
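A hedged sketch of driving the new `evaluation` method from the same kind of script (not part of this commit); it uses a fresh instance because `load_model` refuses to run twice on one object, and the confusion-matrix filename is a placeholder:
```python
# hedged sketch, reusing base_dir / dataset_dir / batch_size / ckpt_path from the script above
m_eval = F3Classification(
    class_name_list=const.CLASS_CN_LIST,
    class_other_first=const.CLASS_OTHER_FIRST
)
m_eval.evaluation(
    load_weights_path=ckpt_path,  # assumed: weights written by the training run above
    confusion_matrix_save_path=os.path.join(base_dir, 'confusion_matrix.jpg'),  # placeholder name
    dataset_dir=dataset_dir,
    batch_size=batch_size,
    validate_dir_name='test',
    thresholds=const.OTHER_THRESHOLDS,
)
```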
import tensorflow as tf
from keras import metrics
class CustomMetric(metrics.Metric):
def __init__(self, thresholds=0.5, name="custom_metric", **kwargs):
super(CustomMetric, self).__init__(name=name, **kwargs)
self.thresholds = thresholds
self.true_positives = self.add_weight(name="ctp", initializer="zeros")
self.count = self.add_weight(name="count", initializer="zeros", dtype='int32')
@staticmethod
def y_true_with_others(y_true):
y_true_idx = tf.argmax(y_true, axis=1) + 1
y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64")
y_true = tf.math.multiply(y_true_idx, y_true_is_other)
return y_true
def y_pred_with_others(self, y_pred):
y_pred_idx = tf.argmax(y_pred, axis=1) + 1
y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64')
y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other)
return y_pred
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = self.y_true_with_others(y_true)
y_pred = self.y_pred_with_others(y_pred)
# print(y_true)
# print(y_pred)
values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
values = tf.cast(values, "float32")
if sample_weight is not None:
sample_weight = tf.cast(sample_weight, "float32")
values = tf.multiply(values, sample_weight)
self.true_positives.assign_add(tf.reduce_sum(values))
self.count.assign_add(tf.shape(y_true)[0])
def result(self):
return self.true_positives / tf.cast(self.count, 'float32')
def reset_state(self):
# The state of the metric will be reset at the start of each epoch.
self.true_positives.assign(0.0)
self.count.assign(0)
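A hedged usage sketch of the metric (illustrative, not part of this commit; the import path assumes the package layout used elsewhere in this diff): an all-zero one-hot row counts as '其他' on the label side, and a prediction whose max score is below `thresholds` counts as '其他' on the prediction side.
```python
import tensorflow as tf
from classification.metrics import CustomMetric  # assumed module path

m = CustomMetric(thresholds=0.5)
y_true = tf.constant([[0., 0., 0.], [0., 1., 0.]])        # row 0 is '其他' (all zeros), row 1 is class 1
y_pred = tf.constant([[0.3, 0.2, 0.1], [0.1, 0.9, 0.2]])  # row 0 max < 0.5, row 1 is confident
print(m.y_true_with_others(y_true).numpy())  # [0 2] -> 0 stands for '其他'
print(m.y_pred_with_others(y_pred).numpy())  # [0 2] -> both rows match the labels
m.update_state(y_true, y_pred)
print(float(m.result()))                     # 1.0
```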
......@@ -2,57 +2,25 @@ import os
import random
import tensorflow as tf
import tensorflow_addons as tfa
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, losses, metrics, callbacks, applications
import matplotlib.pyplot as plt
from base_class import BaseModel
class CustomMetric(metrics.Metric):
def __init__(self, thresholds=0.5, name="custom_metric", **kwargs):
super(CustomMetric, self).__init__(name=name, **kwargs)
self.thresholds = thresholds
self.true_positives = self.add_weight(name="ctp", initializer="zeros")
self.count = self.add_weight(name="count", initializer="zeros", dtype='int32')
def update_state(self, y_true, y_pred, sample_weight=None):
y_true_idx = tf.argmax(y_true, axis=1) + 1
y_true_is_other = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64")
y_true = tf.math.multiply(y_true_idx, y_true_is_other)
y_pred_idx = tf.argmax(y_pred, axis=1) + 1
y_pred_is_other = tf.cast(tf.math.greater_equal(tf.math.reduce_max(y_pred, axis=1), self.thresholds), 'int64')
y_pred = tf.math.multiply(y_pred_idx, y_pred_is_other)
print(y_true)
print(y_pred)
values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
values = tf.cast(values, "float32")
if sample_weight is not None:
sample_weight = tf.cast(sample_weight, "float32")
values = tf.multiply(values, sample_weight)
self.true_positives.assign_add(tf.reduce_sum(values))
self.count.assign_add(tf.shape(y_true)[0])
def result(self):
return self.true_positives / tf.cast(self.count, 'float32')
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, callbacks, applications
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
def reset_state(self):
# The state of the metric will be reset at the start of each epoch.
self.true_positives.assign(0.0)
self.count.assign(0)
from .base_class import BaseModel
from .metrics import CustomMetric
from .utils import history_save, plot_confusion_matrix
class F3Classification(BaseModel):
def __init__(self, class_name_list, class_other_first, *args, **kwargs):
super().__init__(*args, **kwargs)
self.class_name_list = class_name_list
self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
self.model = None
@staticmethod
def gpu_config():
......@@ -61,34 +29,6 @@ class F3Classification(BaseModel):
tf.config.set_visible_devices(devices=gpus[1], device_type='GPU')
@staticmethod
def history_save(history, save_path):
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0, 1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
# plt.show()
plt.savefig(save_path)
@staticmethod
def get_class_label_map(class_name_list, class_other_first=False):
return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}
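For reference (illustrative only): with the trimmed CLASS_CN_LIST and `class_other_first=True`, every real class shifts down by one so that '其他' never occupies an output neuron.
```python
from classification.model import F3Classification  # assumed import path

label_map = F3Classification.get_class_label_map(
    ['其他', '营业执照', '经销商授权书', '个人授权书'], class_other_first=True)
print(label_map)  # {'其他': -1, '营业执照': 0, '经销商授权书': 1, '个人授权书': 2}
```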
......@@ -103,7 +43,6 @@ class F3Classification(BaseModel):
continue
label = self.class_label_map[class_name]
for file_name in os.listdir(class_dir_path):
# TODO image check
if os.path.splitext(file_name)[1] not in self.image_ext_set:
continue
file_path = os.path.join(class_dir_path, file_name)
......@@ -153,7 +92,7 @@ class F3Classification(BaseModel):
# @tf.function
def load_image(image_path, label):
image = tf.io.read_file(image_path)
# image = tf.image.decode_image(image, channels=3)  # TODO why doesn't this work
# image = tf.image.decode_image(image, channels=3) # TODO ?
image = tf.image.decode_png(image, channels=3)
return image, label
......@@ -186,7 +125,10 @@ class F3Classification(BaseModel):
).prefetch(tf.data.AUTOTUNE)
return parallel_batch_dataset
def load_model(self):
def load_model(self, for_training=False, load_weights_path=None):
if self.model is not None:
raise Exception('Model is already loaded; to reload it, set `self.model = None` first')
base_model = MobileNetV2(
input_shape=(224, 224, 3),
alpha=0.35,
......@@ -199,27 +141,41 @@ class F3Classification(BaseModel):
x = layers.Dense(256, activation='sigmoid', name='dense')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
model = models.Model(inputs=base_model.input, outputs=x)
freeze = True
for layer in model.layers:
layer.trainable = not freeze
if freeze and layer.name == 'block_16_project_BN':
freeze = False
return model
def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
train_dir_name='train', validate_dir_name='test', thresholds=0.5):
self.model = models.Model(inputs=base_model.input, outputs=x)
if for_training:
freeze = True
for layer in self.model.layers:
layer.trainable = not freeze
if freeze and layer.name == 'block_16_project_BN':
freeze = False
if isinstance(load_weights_path, str):
if not os.path.isfile(load_weights_path):
raise Exception('load_weights_path not found')
self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True)
def train(self,
dataset_dir,
epoch,
batch_size,
ckpt_path,
history_save_path,
load_weights_path=None,
train_dir_name='train',
validate_dir_name='test',
thresholds=0.5,
metrics_name='accuracy'):
self.gpu_config()
model = self.load_model()
model.summary()
self.load_model(for_training=True, load_weights_path=load_weights_path)
self.model.summary()
model.compile(
self.model.compile(
optimizer=optimizers.Adam(learning_rate=3e-4),
loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO >>>
metrics=[CustomMetric(thresholds), ],
loss=tfa.losses.SigmoidFocalCrossEntropy(), # TODO ?
metrics=[CustomMetric(thresholds, name=metrics_name), ],
loss_weights=None,
weighted_metrics=None,
......@@ -250,14 +206,71 @@ class F3Classification(BaseModel):
ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
history = model.fit(
history = self.model.fit(
train_dataset,
epochs=epoch,
validation_data=validate_dataset,
callbacks=[ckpt_callback, ],
)
self.history_save(history, history_save_path)
history_save(history, history_save_path, metrics_name)
def evaluation(self,
load_weights_path,
confusion_matrix_save_path,
dataset_dir,
batch_size,
validate_dir_name='test',
thresholds=0.5):
self.gpu_config()
self.load_model(load_weights_path=load_weights_path)
self.model.summary()
validate_dataset = self.load_dataset(
dataset_dir=os.path.join(dataset_dir, validate_dir_name),
name=validate_dir_name,
batch_size=batch_size,
augmentation_methods=[]
)
label_true_list = []
label_pred_list = []
custom_metric = CustomMetric(thresholds)
for image_batch, y_true_batch in validate_dataset:
y_pred_batch = self.model.predict(image_batch)
label_true_batch_with_others = custom_metric.y_true_with_others(y_true_batch)
label_pred_batch_with_others = custom_metric.y_pred_with_others(y_pred_batch)
label_true_list.extend(label_true_batch_with_others.numpy())
label_pred_list.extend(label_pred_batch_with_others.numpy())
acc = accuracy_score(label_true_list, label_pred_list)
cm = confusion_matrix(label_true_list, label_pred_list)
report = classification_report(label_true_list, label_pred_list)
print(acc)
print(cm)
print(report)
plot_confusion_matrix(cm, [idx for idx in range(len(self.class_name_list))], confusion_matrix_save_path)
def predict(self, image, thresholds=0.5):
if self.model is None:
raise Exception("The model hasn't loaded yet, run `self.load_model()` first")
input_image, _ = self.preprocess_input(image, None)
input_images = tf.expand_dims(input_image, axis=0)
outputs = self.model.predict(input_images)
for output in outputs:
idx = tf.math.argmax(output)
confidence = output[idx]
if confidence < thresholds:
idx = -1
label = self.class_name_list[idx + 1]
break
res = {
'label': label,
'confidence': confidence
}
return res
def test(self):
y_true = [
......
import numpy as np
import itertools
import matplotlib.pyplot as plt
def history_save(history, save_path, metrics_name='accuracy'):
acc = history.history[metrics_name]
val_acc = history.history['val_{0}'.format(metrics_name)]
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0, 1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
# plt.show()
plt.savefig(save_path)
def plot_confusion_matrix(cm, class_names, save_path):
"""
Returns a matplotlib figure containing the plotted confusion matrix.
Args:
cm (array, shape = [n, n]): a confusion matrix of integer classes
class_names (array, shape = [n]): String names of the integer classes
save_path (str): figure save path
"""
figure = plt.figure(figsize=(8, 8))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Confusion matrix")
plt.colorbar()
tick_marks = np.arange(len(class_names))
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)
# Compute the labels from the normalized confusion matrix.
labels = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)
# labels = cm.astype('int')
# Use white text if squares are dark; otherwise black.
threshold = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
color = "white" if cm[i, j] > threshold else "black"
plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.savefig(save_path)
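A hedged usage sketch for `plot_confusion_matrix` (illustrative; the label lists would normally come from `F3Classification.evaluation`, and the file name and import path are assumptions):
```python
from sklearn.metrics import confusion_matrix
from classification.utils import plot_confusion_matrix  # assumed module path

# toy labels standing in for the lists collected in F3Classification.evaluation
cm = confusion_matrix([0, 1, 2, 2], [0, 1, 2, 1])
plot_confusion_matrix(cm, class_names=[0, 1, 2], save_path='confusion_matrix.jpg')  # placeholder path
```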