cbeebc6d by 周伟奇

auth form done

1 parent 5586ffb0
.idea
# Byte-compiled / optimized / DLL files
*.[oa]
*~
*.py[cod]
*$py.class
**/*.py[cod]
# Hidden
.*
!.gitignore
test.py
\ No newline at end of file
## Usage
**F3: information extraction for personal and company authorization letters**
```python
from retriever import Retriever
import const
# Personal authorization letter -> {'姓名': 'xxx', '个人身份证件号码': 'xxx', '签字': '有'}
r = Retriever(const.TARGET_FIELD_INDIVIDUALS)
# Company authorization letter -> {'经销商名称': 'xx', '经销商代码-宝马中国': 'xx', '管理人员姓名-总经理': 'xx', '公司公章': '有', '法定代表人签章': '有'}
# r = Retriever(const.TARGET_FIELD_COMPANIES)
# go_res: OCR output; signature_res: signature/seal detection output (see below)
res = r.get_target_fields(go_res, signature_res)
```
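
The exact input formats are produced upstream and are not documented in this commit; judging from `Retriever.get_target_fields`, `go_res` maps OCR line ids to a flattened 4-point box plus the recognized text, and `signature_res` is a list of detection dicts carrying a `label`. A minimal sketch of those assumed shapes:

```python
# Hypothetical inputs, shaped the way retriever.py consumes them
go_res = {
    0: ((100, 50, 160, 50, 160, 70, 100, 70), '姓名'),   # (x, y) * 4 corners + OCR text
    1: ((100, 80, 200, 80, 200, 100, 100, 100), '张三'),
}
signature_res = [{'label': 'signature'}]
```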
# Schema: for each target field,
#   'keys'      - a chain of (key_text, locator, locator_kwargs) steps that pin
#                 down the key's coordinates; e.g. '经销商代码-宝马中国' first
#                 finds '经销商代码' (top1), then the '宝马中国' text to its right.
#   'value'     - (direction, padding_kwargs, default) for reading the value
#                 located relative to the resolved key; paddings are fractions
#                 of the key box's height ('right') or width ('under').
#   'signature' - the set of detection labels that mark the field as present.
TARGET_FIELD_INDIVIDUALS = {
    'keys': {
        '姓名': [('姓名', 'top1', {})],
        '个人身份证件号码': [('个人身份证件号码', 'top1', {})],
    },
    'value': {
        '姓名': ('under', {'left_padding': 1, 'right_padding': 1}, ''),
        '个人身份证件号码': ('under', {'left_padding': 0.5, 'right_padding': 0.5}, '')
    },
    'signature': {
        '签字': {'signature', }
    }
}

TARGET_FIELD_COMPANIES = {
    'keys': {
        '经销商名称': [
            ('经销商名称', 'top1', {})
        ],
        '经销商代码-宝马中国': [
            ('经销商代码', 'top1', {}),
            ('宝马中国', 'right', {'top_padding': 1.5, 'bottom_padding': 0})
        ],
        '管理人员姓名-总经理': [
            ('管理人员姓名', 'top1', {}),
            ('总经理', 'right', {'top_padding': 1, 'bottom_padding': 0})
        ],
    },
    'value': {
        '经销商名称': ('right', {'top_padding': 1, 'bottom_padding': 1}, ''),
        '经销商代码-宝马中国': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, ''),
        '管理人员姓名-总经理': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, '')
    },
    'signature': {
        '公司公章': {'circle', },
        '法定代表人签章': {'signature', 'rectangle'}
    }
}
\ No newline at end of file
class Retriever:
    def __init__(self, target_fields):
        self.keys_str = 'keys'
        self.value_str = 'value'
        self.signature_str = 'signature'
        self.signature_have_str = '有'
        self.signature_have_not_str = '无'
        self.target_fields = target_fields
        self.key_text_set = self.get_key_text_set(target_fields)

    def get_key_text_set(self, target_fields):
        key_text_set = set()
        for key_text_list in target_fields[self.keys_str].values():
            for key_text, _, _ in key_text_list:
                key_text_set.add(key_text)
        return key_text_set

    @staticmethod
    def key_top1(coordinates_list, key_coordinates):
        # Take the topmost occurrence of the key text.
        coordinates_list.sort(key=lambda x: x[1])
        return coordinates_list[0]

    @staticmethod
    def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding):
        # Take the occurrence nearest to the right of the previous key, within
        # a vertical band derived from the previous key's height.
        if len(coordinates_list) == 1:
            return coordinates_list[0]
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)
        x = key_coordinates[2]
        x_min = None
        key_coordinates = None
        for x0, y0, x1, y1 in coordinates_list:
            if y0 > y_min and y1 < y_max and x0 > x:
                if x_min is None or x0 < x_min:
                    x_min = x0
                    key_coordinates = (x0, y0, x1, y1)
        return key_coordinates

    @staticmethod
    def value_right(go_res, key_coordinates, top_padding, bottom_padding):
        # Nearest text to the right of the key, within a padded vertical band.
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)
        x = key_coordinates[2]
        x_min = None
        value = None
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if y0 > y_min and y1 < y_max and x0 > x:
                if x_min is None or x0 < x_min:
                    x_min = x0
                    value = text
        return value

    @staticmethod
    def value_under(go_res, key_coordinates, left_padding, right_padding):
        # Nearest text below the key, within a padded horizontal band.
        width = key_coordinates[2] - key_coordinates[0]
        x_min = key_coordinates[0] - (width * left_padding)
        x_max = key_coordinates[2] + (width * right_padding)
        y = key_coordinates[-1]
        y_min = None
        value = None
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if x0 > x_min and x1 < x_max and y0 > y:
                if y_min is None or y0 < y_min:
                    y_min = y0
                    value = text
        return value
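
    # Example with hypothetical numbers: for a key box of (100, 120, 260, 140)
    # and left/right padding of 0.5, value_under searches 20 < x < 340 below
    # y = 140 and keeps the candidate whose top edge y0 is smallest.
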
    def get_target_fields(self, go_res, signature_res_list):
        # Collect the coordinates of every occurrence of each key text
        key_text_info = dict()
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if text in self.key_text_set:
                key_text_info.setdefault(text, list()).append((x0, y0, x1, y1))

        # Resolve each field's key coordinates by walking its key chain
        key_coordinates_info = dict()
        for field, key_text_list in self.target_fields[self.keys_str].items():
            pre_key_coordinates = None
            for key_text, direction, kwargs in key_text_list:
                if key_text not in key_text_info:
                    break
                key_coordinates = getattr(self, 'key_{0}'.format(direction))(
                    key_text_info[key_text],
                    pre_key_coordinates,
                    **kwargs)
                if not isinstance(key_coordinates, tuple):
                    break
                pre_key_coordinates = key_coordinates
            else:
                key_coordinates_info[field] = pre_key_coordinates

        # Read each field's value relative to its resolved key
        res = dict()
        for field, (direction, kwargs, default_value) in self.target_fields[self.value_str].items():
            if not isinstance(key_coordinates_info.get(field), tuple):
                # Key not found: record the default and move on to the next field
                res[field] = default_value
                continue
            value = getattr(self, 'value_{0}'.format(direction))(
                go_res,
                key_coordinates_info[field],
                **kwargs
            )
            if not isinstance(value, str):
                res[field] = default_value
            else:
                res[field] = value

        # Check for signatures and seals
        tmp_signature_count = dict()
        for signature_dict in signature_res_list:
            if signature_dict['label'] in tmp_signature_count:
                tmp_signature_count[signature_dict['label']] += 1
            else:
                tmp_signature_count[signature_dict['label']] = 1
        for field, signature_type_set in self.target_fields[self.signature_str].items():
            for signature_type in signature_type_set:
                if tmp_signature_count.get(signature_type, 0) > 0:
                    res[field] = self.signature_have_str
                    tmp_signature_count[signature_type] -= 1
                    break
            else:
                res[field] = self.signature_have_not_str
        return res
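
if __name__ == '__main__':
    # Minimal smoke test with hypothetical boxes: go_res is assumed to map OCR
    # line ids to ((x0, y0, ..., x1, y1, ...), recognized_text) quads as
    # consumed above; a real pipeline would supply detector output instead.
    import const

    demo_go_res = {
        0: ((100, 50, 160, 50, 160, 70, 100, 70), '姓名'),
        1: ((100, 80, 200, 80, 200, 100, 100, 100), '张三'),
        2: ((100, 120, 260, 120, 260, 140, 100, 140), '个人身份证件号码'),
        3: ((100, 150, 300, 150, 300, 170, 100, 170), '110101199003070000'),
    }
    demo_signature_res = [{'label': 'signature'}]
    r = Retriever(const.TARGET_FIELD_INDIVIDUALS)
    print(r.get_target_fields(demo_go_res, demo_signature_res))
    # -> {'姓名': '张三', '个人身份证件号码': '110101199003070000', '签字': '有'}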
class BaseModel:
    """
    All Model classes should extend BaseModel.
    """

    def load_model(self):
        """
        Define the network structure and return the model.
        """
        raise NotImplementedError(".load_model() must be overridden.")

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
        """
        Model training process.
        """
        raise NotImplementedError(".train() must be overridden.")
CLASS_OTHER_CN = '其他'
CLASS_OTHER_FIRST = True
CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
import os
from datetime import datetime

from model import F3Classification

import const

if __name__ == '__main__':
    base_dir = os.path.dirname(os.path.abspath(__file__))

    m = F3Classification(
        class_name_list=const.CLASS_CN_LIST,
        class_other_first=const.CLASS_OTHER_FIRST
    )
    # m.test()

    dataset_dir = '/home/zwq/data/data_224'
    ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
    epoch = 100
    batch_size = 128

    m.train(dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test')
import os

import tensorflow as tf
import tensorflow_addons as tfa
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, losses, metrics, callbacks, applications
import matplotlib.pyplot as plt

from base_class import BaseModel


@tf.function
def random_rgb_2_bgr(image, label):
    # tf.random is used rather than random.random(): a Python random call
    # inside a tf.function is evaluated once at trace time, which would make
    # the "random" choice constant across the whole dataset.
    if tf.random.uniform([]) > 0.5:
        return image, label
    image = image[:, :, ::-1]
    return image, label


@tf.function
def random_grayscale_expand(image, label):
    if tf.random.uniform([]) > 0.1:
        return image, label
    image = tf.image.rgb_to_grayscale(image)
    image = tf.image.grayscale_to_rgb(image)
    return image, label


@tf.function
def load_image(image_path, label):
    image = tf.io.read_file(image_path)
    # expand_animations=False keeps the result a 3-D tensor so that
    # tf.image.resize works downstream.
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    return image, label


@tf.function
def preprocess_input(image, label):
    image = tf.image.resize(image, [224, 224])
    image = applications.mobilenet_v2.preprocess_input(image)
    return image, label
class F3Classification(BaseModel):
    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)

    @staticmethod
    def get_class_label_map(class_name_list, class_other_first=False):
        return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}
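
    # Example: with CLASS_CN_LIST and class_other_first=True this yields
    # {'其他': -1, '身份证': 0, '营业执照': 1, '经销商授权书': 2, '个人授权书': 3};
    # class_count is then 4 and the catch-all label -1 one-hots to an
    # all-zero target vector.
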
    def get_image_label_list(self, dataset_dir):
        image_path_list = []
        label_list = []
        for class_name in os.listdir(dataset_dir):
            class_dir_path = os.path.join(dataset_dir, class_name)
            if not os.path.isdir(class_dir_path):
                continue
            if class_name not in self.class_label_map:
                continue
            label = self.class_label_map[class_name]
            for file_name in os.listdir(class_dir_path):
                # TODO image check
                file_path = os.path.join(class_dir_path, file_name)
                image_path_list.append(file_path)
                label_list.append(tf.one_hot(label, depth=self.class_count))
        return image_path_list, label_list
    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=()):
        image_and_label_list = self.get_image_label_list(dataset_dir)
        # tf.data requires a tuple for a two-component structure (a list would
        # be converted to a single tensor), and every transformation returns a
        # new dataset, so each result must be reassigned.
        tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(tuple(image_and_label_list), name=name)
        tensor_slice_dataset = tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
        tensor_slice_dataset = tensor_slice_dataset.map(load_image,
                                                        num_parallel_calls=tf.data.AUTOTUNE,
                                                        deterministic=False)
        for augmentation_method in augmentation_methods:
            # The augmentation functions are module-level, not methods on self
            tensor_slice_dataset = tensor_slice_dataset.map(globals()[augmentation_method],
                                                            num_parallel_calls=tf.data.AUTOTUNE,
                                                            deterministic=False)
        tensor_slice_dataset = tensor_slice_dataset.map(preprocess_input,
                                                        num_parallel_calls=tf.data.AUTOTUNE,
                                                        deterministic=False)
        parallel_batch_dataset = tensor_slice_dataset.batch(
            batch_size=batch_size,
            drop_remainder=True,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
            name=name,
        ).prefetch(tf.data.AUTOTUNE)
        return parallel_batch_dataset
    def load_model(self):
        base_model = MobileNetV2(
            input_shape=(224, 224, 3),
            alpha=0.35,
            include_top=False,
            weights='imagenet',
            pooling='avg',
        )
        x = base_model.output
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
        model = models.Model(inputs=base_model.input, outputs=x)

        # Freeze everything up to and including 'block_16_project_BN'; only the
        # final MobileNetV2 block and the new head remain trainable.
        freeze = True
        for layer in model.layers:
            layer.trainable = not freeze
            if freeze and layer.name == 'block_16_project_BN':
                freeze = False
        return model
    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
        # model = self.load_model()
        # model.summary()
        #
        # model.compile(
        #     optimizer=optimizers.Adam(learning_rate=3e-4),
        #     loss=tfa.losses.SigmoidFocalCrossEntropy(),
        #     metrics=['accuracy', ],
        #     loss_weights=None,
        #     weighted_metrics=None,
        #     run_eagerly=None,
        #     steps_per_execution=None,
        #     jit_compile=None,
        # )

        train_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, train_dir_name),
            name=train_dir_name,
            batch_size=batch_size,
            augmentation_methods=[],
            # augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
        )
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[]
        )

        # ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
        #
        # history = model.fit(
        #     train_dataset,
        #     epochs=epoch,
        #     validation_data=validate_dataset,
        #     callbacks=[ckpt_callback, ],
        # )
        #
        # acc = history.history['accuracy']
        # val_acc = history.history['val_accuracy']
        #
        # loss = history.history['loss']
        # val_loss = history.history['val_loss']
        #
        # plt.figure(figsize=(8, 8))
        # plt.subplot(2, 1, 1)
        # plt.plot(acc, label='Training Accuracy')
        # plt.plot(val_acc, label='Validation Accuracy')
        # plt.legend(loc='lower right')
        # plt.ylabel('Accuracy')
        # plt.ylim([min(plt.ylim()), 1])
        # plt.title('Training and Validation Accuracy')
        #
        # plt.subplot(2, 1, 2)
        # plt.plot(loss, label='Training Loss')
        # plt.plot(val_loss, label='Validation Loss')
        # plt.legend(loc='upper right')
        # plt.ylabel('Cross Entropy')
        # plt.ylim([0, 1.0])
        # plt.title('Training and Validation Loss')
        # plt.xlabel('epoch')
        # plt.show()
    def test(self):
        print(self.class_label_map)
        print(self.class_count)
        # path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
        # label = 5
        # image, label = load_image(path, label)
        # print(image.shape)
        # image, label = preprocess_input(image, label)
        # print(image.shape)