From cbeebc6da73215a3014c51b6737c08b3dd7ed018 Mon Sep 17 00:00:00 2001
From: zhouweiqi <zhouweiqi@situdata.com>
Date: Wed, 29 Jun 2022 11:38:28 +0800
Subject: [PATCH] auth from done

---
 .gitignore                      |  13 ++++
 authorization_from/README.md    |  18 +++++
 authorization_from/const.py     |  38 ++++++++
 authorization_from/retriever.py | 130 +++++++++++++++++++++
 classification/base_class.py    |  17 ++++
 classification/const.py         |   6 +
 classification/main.py          |  23 +++++
 classification/model.py         | 188 ++++++++++++++++++++++++++++
 classification/train.py         |   0
 9 files changed, 433 insertions(+), 0 deletions(-)
 create mode 100644 authorization_from/README.md
 create mode 100644 authorization_from/const.py
 create mode 100644 authorization_from/retriever.py
 create mode 100644 classification/base_class.py
 create mode 100644 classification/const.py
 create mode 100644 classification/main.py
 create mode 100644 classification/model.py
 delete mode 100644 classification/train.py

diff --git a/.gitignore b/.gitignore
index 485dee6..af8a18f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,14 @@
 .idea
+
+# Byte-compiled / optimized / DLL files
+*.[oa]
+*~
+*.py[cod]
+*$py.class
+**/*.py[cod]
+
+# Hidden files
+.*
+!.gitignore
+
+test.py
\ No newline at end of file
diff --git a/authorization_from/README.md b/authorization_from/README.md
new file mode 100644
index 0000000..a16489d
--- /dev/null
+++ b/authorization_from/README.md
@@ -0,0 +1,18 @@
+## Usage
+**Field extraction for F3 personal and corporate letters of authorization**
+
+```python
+from retriever import Retriever
+import const
+
+# go_res is the OCR output, signature_res the seal/signature detection output
+
+# Personal letter -> {'姓名': 'xxx', '个人身份证件号码': 'xxx', '签字': '有'}
+r = Retriever(const.TARGET_FIELD_INDIVIDUALS)
+
+# Corporate letter -> {'经销商名称': 'xx', '经销商代码-宝马中国': 'xx', '管理人员姓名-总经理': 'xx', '公司公章': '有', '法定代表人签章': '有'}
+# r = Retriever(const.TARGET_FIELD_COMPANIES)
+res = r.get_target_fields(go_res, signature_res)
+```
+
+
diff --git a/authorization_from/const.py b/authorization_from/const.py
new file mode 100644
index 0000000..3d86498
--- /dev/null
+++ b/authorization_from/const.py
@@ -0,0 +1,38 @@
+# Each target-field config has three sections:
+#   keys:      field -> ordered (key text, search method, kwargs) steps used to
+#              locate the key box ('top1' = topmost occurrence, 'right' =
+#              nearest occurrence to the right of the previously found key)
+#   value:     field -> (direction, padding kwargs, default value) describing
+#              where the value box sits relative to the located key box
+#   signature: field -> set of stamp/signature labels that satisfy the field
+
+TARGET_FIELD_INDIVIDUALS = {
+    'keys': {
+        '姓名': [('姓名', 'top1', {})],
+        '个人身份证件号码': [('个人身份证件号码', 'top1', {})],
+    },
+    'value': {
+        '姓名': ('under', {'left_padding': 1, 'right_padding': 1}, ''),
+        '个人身份证件号码': ('under', {'left_padding': 0.5, 'right_padding': 0.5}, '')
+    },
+    'signature': {
+        '签字': {'signature', }
+    }
+}
+
+TARGET_FIELD_COMPANIES = {
+    'keys': {
+        '经销商名称': [
+            ('经销商名称', 'top1', {})
+        ],
+        '经销商代码-宝马中国': [
+            ('经销商代码', 'top1', {}),
+            ('宝马中国', 'right', {'top_padding': 1.5, 'bottom_padding': 0})
+        ],
+        '管理人员姓名-总经理': [
+            ('管理人员姓名', 'top1', {}),
+            ('总经理', 'right', {'top_padding': 1, 'bottom_padding': 0})
+        ],
+    },
+    'value': {
+        '经销商名称': ('right', {'top_padding': 1, 'bottom_padding': 1}, ''),
+        '经销商代码-宝马中国': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, ''),
+        '管理人员姓名-总经理': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, '')
+    },
+    'signature': {
+        '公司公章': {'circle', },
+        '法定代表人签章': {'signature', 'rectangle'}
+    }
+}
\ No newline at end of file
diff --git a/authorization_from/retriever.py b/authorization_from/retriever.py
new file mode 100644
index 0000000..43dc74c
--- /dev/null
+++ b/authorization_from/retriever.py
@@ -0,0 +1,130 @@
+class Retriever:
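+    """
+    Extract target fields from the OCR result of an authorization letter.
+
+    Expected input shapes (inferred from usage below; the values here are
+    illustrative, not real service output):
+        go_res: dict whose values are an 8-number quad box plus the recognized
+            text, e.g. {'1': ((x0, y0, x1, y0, x1, y1, x0, y1), '姓名'), ...}
+        signature_res_list: list of detections such as [{'label': 'signature'}],
+            with labels like 'signature', 'circle' or 'rectangle'
+    """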
+
+    def __init__(self, target_fields):
+        self.keys_str = 'keys'
+        self.value_str = 'value'
+        self.signature_str = 'signature'
+        self.signature_have_str = '有'
+        self.signature_have_not_str = '无'
+        self.target_fields = target_fields
+        self.key_text_set = self.get_key_text_set(target_fields)
+
+    def get_key_text_set(self, target_fields):
+        # Collect every key text we need to look for in the OCR output
+        key_text_set = set()
+        for key_text_list in target_fields[self.keys_str].values():
+            for key_text, _, _ in key_text_list:
+                key_text_set.add(key_text)
+        return key_text_set
+
+    @staticmethod
+    def key_top1(coordinates_list, key_coordinates):
+        # Pick the topmost occurrence of the key text
+        coordinates_list.sort(key=lambda x: x[1])
+        return coordinates_list[0]
+
+    @staticmethod
+    def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding):
+        # Pick the occurrence nearest to the right of the previous key box,
+        # within a vertical band padded by multiples of the key-box height
+        if len(coordinates_list) == 1:
+            return coordinates_list[0]
+        height = key_coordinates[-1] - key_coordinates[1]
+        y_min = key_coordinates[1] - (top_padding * height)
+        y_max = key_coordinates[-1] + (bottom_padding * height)
+        x = key_coordinates[2]
+
+        x_min = None
+        nearest_coordinates = None
+        for x0, y0, x1, y1 in coordinates_list:
+            if y0 > y_min and y1 < y_max and x0 > x:
+                if x_min is None or x0 < x_min:
+                    x_min = x0
+                    nearest_coordinates = (x0, y0, x1, y1)
+        return nearest_coordinates
+
+    @staticmethod
+    def value_right(go_res, key_coordinates, top_padding, bottom_padding):
+        # Read the text box nearest to the right of the key box
+        height = key_coordinates[-1] - key_coordinates[1]
+        y_min = key_coordinates[1] - (top_padding * height)
+        y_max = key_coordinates[-1] + (bottom_padding * height)
+        x = key_coordinates[2]
+
+        x_min = None
+        value = None
+        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
+            if y0 > y_min and y1 < y_max and x0 > x:
+                if x_min is None or x0 < x_min:
+                    x_min = x0
+                    value = text
+        return value
+
+    @staticmethod
+    def value_under(go_res, key_coordinates, left_padding, right_padding):
+        # Read the text box nearest below the key box
+        width = key_coordinates[2] - key_coordinates[0]
+        x_min = key_coordinates[0] - (width * left_padding)
+        x_max = key_coordinates[2] + (width * right_padding)
+        y = key_coordinates[-1]
+
+        y_min = None
+        value = None
+        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
+            if x0 > x_min and x1 < x_max and y0 > y:
+                if y_min is None or y0 < y_min:
+                    y_min = y0
+                    value = text
+        return value
+
+    def get_target_fields(self, go_res, signature_res_list):
+        # Collect the coordinates of every occurrence of each key text
+        key_text_info = dict()
+        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
+            if text in self.key_text_set:
+                key_text_info.setdefault(text, list()).append((x0, y0, x1, y1))
+
+        # Locate the key box for each field, following its search steps in order
+        key_coordinates_info = dict()
+        for field, key_text_list in self.target_fields[self.keys_str].items():
+            pre_key_coordinates = None
+            for key_text, direction, kwargs in key_text_list:
+                if key_text not in key_text_info:
+                    break
+                key_coordinates = getattr(self, 'key_{0}'.format(direction))(
+                    key_text_info[key_text],
+                    pre_key_coordinates,
+                    **kwargs)
+                if not isinstance(key_coordinates, tuple):
+                    break
+                pre_key_coordinates = key_coordinates
+            else:
+                key_coordinates_info[field] = pre_key_coordinates
+
+        # Extract the value for each field
+        res = dict()
+        for field, (direction, kwargs, default_value) in self.target_fields[self.value_str].items():
+            if not isinstance(key_coordinates_info.get(field), tuple):
+                # Key not found: keep the default and move on to the next field
+                res[field] = default_value
+                continue
+            value = getattr(self, 'value_{0}'.format(direction))(
+                go_res,
+                key_coordinates_info[field],
+                **kwargs
+            )
+            res[field] = value if isinstance(value, str) else default_value
+
+        # Check seals and signatures; each detection is consumed at most once
+        tmp_signature_count = dict()
+        for signature_dict in signature_res_list:
+            tmp_signature_count[signature_dict['label']] = tmp_signature_count.get(signature_dict['label'], 0) + 1
+        for field, signature_type_set in self.target_fields[self.signature_str].items():
+            for signature_type in signature_type_set:
+                if tmp_signature_count.get(signature_type, 0) > 0:
+                    res[field] = self.signature_have_str
+                    tmp_signature_count[signature_type] -= 1
+                    break
+            else:
+                res[field] = self.signature_have_not_str
+
+        return res
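+
+
+if __name__ == '__main__':
+    # Minimal smoke test with hand-made inputs. The boxes and texts below are
+    # hypothetical; real go_res/signature_res come from the upstream OCR and
+    # seal/signature detection services.
+    demo_go_res = {
+        '1': ((100, 40, 160, 40, 160, 60, 100, 60), '姓名'),
+        '2': ((100, 70, 200, 70, 200, 90, 100, 90), '张三'),
+        '3': ((100, 120, 260, 120, 260, 140, 100, 140), '个人身份证件号码'),
+        '4': ((100, 150, 300, 150, 300, 170, 100, 170), '110101199001011234'),
+    }
+    demo_signature_res = [{'label': 'signature'}]
+
+    import const
+    r = Retriever(const.TARGET_FIELD_INDIVIDUALS)
+    # -> {'姓名': '张三', '个人身份证件号码': '110101199001011234', '签字': '有'}
+    print(r.get_target_fields(demo_go_res, demo_signature_res))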
diff --git a/classification/base_class.py b/classification/base_class.py
new file mode 100644
index 0000000..6eef066
--- /dev/null
+++ b/classification/base_class.py
@@ -0,0 +1,17 @@
+class BaseModel:
+    """
+    All Model classes should extend BaseModel.
+    """
+
+    def load_model(self):
+        """
+        Define the network structure and return the model.
+        """
+        raise NotImplementedError(".load_model() must be overridden.")
+
+    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
+        """
+        Model training process.
+        """
+        raise NotImplementedError(".train() must be overridden.")
+
diff --git a/classification/const.py b/classification/const.py
new file mode 100644
index 0000000..cbcefe6
--- /dev/null
+++ b/classification/const.py
@@ -0,0 +1,6 @@
+CLASS_OTHER_CN = '其他'
+
+CLASS_OTHER_FIRST = True
+
+CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
+
diff --git a/classification/main.py b/classification/main.py
new file mode 100644
index 0000000..b899c34
--- /dev/null
+++ b/classification/main.py
@@ -0,0 +1,23 @@
+import os
+from datetime import datetime
+from model import F3Classification
+import const
+
+
+if __name__ == '__main__':
+    base_dir = os.path.dirname(os.path.abspath(__file__))
+
+    m = F3Classification(
+        class_name_list=const.CLASS_CN_LIST,
+        class_other_first=const.CLASS_OTHER_FIRST
+    )
+
+    # m.test()
+
+    dataset_dir = '/home/zwq/data/data_224'
+    ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
+    epoch = 100
+    batch_size = 128
+
+    m.train(dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test')
+
diff --git a/classification/model.py b/classification/model.py
new file mode 100644
index 0000000..b0070ff
--- /dev/null
+++ b/classification/model.py
@@ -0,0 +1,188 @@
+import os
+import tensorflow as tf
+import tensorflow_addons as tfa
+from keras.applications.mobilenet_v2 import MobileNetV2
+from keras import layers, models, optimizers, losses, metrics, callbacks, applications
+import matplotlib.pyplot as plt
+
+from base_class import BaseModel
+
+
+@tf.function
+def random_rgb_2_bgr(image, label):
+    # Swap RGB/BGR with probability 0.5. tf.random.uniform is evaluated on
+    # every call; Python's random.random() would be frozen to a constant when
+    # tf.function traces the graph.
+    image = tf.cond(tf.random.uniform(()) > 0.5,
+                    lambda: image[:, :, ::-1],
+                    lambda: image)
+    return image, label
+
+
+@tf.function
+def random_grayscale_expand(image, label):
+    # With probability 0.1, collapse to grayscale and expand back to 3 channels
+    image = tf.cond(tf.random.uniform(()) < 0.1,
+                    lambda: tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)),
+                    lambda: image)
+    return image, label
+
+
+@tf.function
+def load_image(image_path, label):
+    image = tf.io.read_file(image_path)
+    # expand_animations=False guarantees a 3-D tensor so resize() works downstream
+    image = tf.image.decode_image(image, channels=3, expand_animations=False)
+    return image, label
+
+
+@tf.function
+def preprocess_input(image, label):
+    image = tf.image.resize(image, [224, 224])
+    image = applications.mobilenet_v2.preprocess_input(image)
+    return image, label
+
+
+class F3Classification(BaseModel):
+
+    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # When the first class is "other" it gets label -1 and no output neuron
+        self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
+        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
+
+    @staticmethod
+    def get_class_label_map(class_name_list, class_other_first=False):
+        return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}
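+
+    # With const.CLASS_CN_LIST and class_other_first=True this yields
+    # {'其他': -1, '身份证': 0, '营业执照': 1, '经销商授权书': 2, '个人授权书': 3};
+    # tf.one_hot(-1, depth=4) is the all-zeros vector, so "other" is encoded as
+    # an all-negative sigmoid target instead of getting its own output neuron.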
+
+    def get_image_label_list(self, dataset_dir):
+        image_path_list = []
+        label_list = []
+        for class_name in os.listdir(dataset_dir):
+            class_dir_path = os.path.join(dataset_dir, class_name)
+            if not os.path.isdir(class_dir_path):
+                continue
+            if class_name not in self.class_label_map:
+                continue
+            label = self.class_label_map[class_name]
+            for file_name in os.listdir(class_dir_path):
+                # TODO image check
+                file_path = os.path.join(class_dir_path, file_name)
+                image_path_list.append(file_path)
+                label_list.append(tf.one_hot(label, depth=self.class_count))
+        return image_path_list, label_list
+
+    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=None):
+        # tf.data transformations return new datasets, so each step must be
+        # reassigned rather than called in place
+        image_and_label_list = self.get_image_label_list(dataset_dir)
+        dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name)
+        dataset = dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
+        dataset = dataset.map(load_image,
+                              num_parallel_calls=tf.data.AUTOTUNE,
+                              deterministic=False)
+        # The augmentation functions are module-level, not methods, hence globals()
+        for augmentation_method in (augmentation_methods or []):
+            dataset = dataset.map(globals()[augmentation_method],
+                                  num_parallel_calls=tf.data.AUTOTUNE,
+                                  deterministic=False)
+        dataset = dataset.map(preprocess_input,
+                              num_parallel_calls=tf.data.AUTOTUNE,
+                              deterministic=False)
+        parallel_batch_dataset = dataset.batch(
+            batch_size=batch_size,
+            drop_remainder=True,
+            num_parallel_calls=tf.data.AUTOTUNE,
+            deterministic=False,
+            name=name,
+        ).prefetch(tf.data.AUTOTUNE)
+        return parallel_batch_dataset
+
+    def load_model(self):
+        base_model = MobileNetV2(
+            input_shape=(224, 224, 3),
+            alpha=0.35,
+            include_top=False,
+            weights='imagenet',
+            pooling='avg',
+        )
+        x = base_model.output
+        x = layers.Dropout(0.5)(x)
+        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
+        x = layers.Dropout(0.5)(x)
+        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
+        model = models.Model(inputs=base_model.input, outputs=x)
+
+        # Freeze the backbone up to and including 'block_16_project_BN';
+        # only the last blocks and the new head are trained
+        freeze = True
+        for layer in model.layers:
+            layer.trainable = not freeze
+            if freeze and layer.name == 'block_16_project_BN':
+                freeze = False
+        return model
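+
+    # Hypothetical shape sanity check (assumes the dataset layout from main.py):
+    #   m = F3Classification(const.CLASS_CN_LIST, True)
+    #   ds = m.load_dataset('/home/zwq/data/data_224/train', 'train', batch_size=8)
+    #   images, labels = next(iter(ds))  # images: (8, 224, 224, 3), labels: (8, 4)
+    #   m.load_model().predict(images)   # -> probabilities with shape (8, 4)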
+
+    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
+        # model = self.load_model()
+        # model.summary()
+        #
+        # model.compile(
+        #     optimizer=optimizers.Adam(learning_rate=3e-4),
+        #     loss=tfa.losses.SigmoidFocalCrossEntropy(),
+        #     metrics=['accuracy', ],
+        #     loss_weights=None,
+        #     weighted_metrics=None,
+        #     run_eagerly=None,
+        #     steps_per_execution=None,
+        #     jit_compile=None,
+        # )
+
+        train_dataset = self.load_dataset(
+            dataset_dir=os.path.join(dataset_dir, train_dir_name),
+            name=train_dir_name,
+            batch_size=batch_size,
+            augmentation_methods=[],
+            # augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
+        )
+        validate_dataset = self.load_dataset(
+            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
+            name=validate_dir_name,
+            batch_size=batch_size,
+            augmentation_methods=[]
+        )
+
+        # ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
+        #
+        # history = model.fit(
+        #     train_dataset,
+        #     epochs=epoch,
+        #     validation_data=validate_dataset,
+        #     callbacks=[ckpt_callback, ],
+        # )
+        #
+        # acc = history.history['accuracy']
+        # val_acc = history.history['val_accuracy']
+        #
+        # loss = history.history['loss']
+        # val_loss = history.history['val_loss']
+        #
+        # plt.figure(figsize=(8, 8))
+        # plt.subplot(2, 1, 1)
+        # plt.plot(acc, label='Training Accuracy')
+        # plt.plot(val_acc, label='Validation Accuracy')
+        # plt.legend(loc='lower right')
+        # plt.ylabel('Accuracy')
+        # plt.ylim([min(plt.ylim()), 1])
+        # plt.title('Training and Validation Accuracy')
+        #
+        # plt.subplot(2, 1, 2)
+        # plt.plot(loss, label='Training Loss')
+        # plt.plot(val_loss, label='Validation Loss')
+        # plt.legend(loc='upper right')
+        # plt.ylabel('Cross Entropy')
+        # plt.ylim([0, 1.0])
+        # plt.title('Training and Validation Loss')
+        # plt.xlabel('epoch')
+        # plt.show()
+
+    def test(self):
+        print(self.class_label_map)
+        print(self.class_count)
+        # path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
+        # label = 5
+        # image, label = self.load_image(path, label)
+        # print(image.shape)
+        # image, label = self.preprocess_input(image, label)
+        # print(image.shape)
diff --git a/classification/train.py b/classification/train.py
deleted file mode 100644
index e69de29..0000000
--- a/classification/train.py
+++ /dev/null
--
libgit2 0.24.0