From cbeebc6da73215a3014c51b6737c08b3dd7ed018 Mon Sep 17 00:00:00 2001
From: zhouweiqi <zhouweiqi@situdata.com>
Date: Wed, 29 Jun 2022 11:38:28 +0800
Subject: [PATCH] auth from done

---
 .gitignore                      |  13 +++++++++++++
 authorization_from/README.md    |  18 ++++++++++++++++++
 authorization_from/const.py     |  38 ++++++++++++++++++++++++++++++++++++++
 authorization_from/retriever.py | 130 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 classification/base_class.py    |  17 +++++++++++++++++
 classification/const.py         |   6 ++++++
 classification/main.py          |  23 +++++++++++++++++++++++
 classification/model.py         | 188 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 classification/train.py         |   0
 9 files changed, 433 insertions(+), 0 deletions(-)
 create mode 100644 authorization_from/README.md
 create mode 100644 authorization_from/const.py
 create mode 100644 authorization_from/retriever.py
 create mode 100644 classification/base_class.py
 create mode 100644 classification/const.py
 create mode 100644 classification/main.py
 create mode 100644 classification/model.py
 delete mode 100644 classification/train.py

diff --git a/.gitignore b/.gitignore
index 485dee6..af8a18f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,14 @@
 .idea
+
+# Byte-compiled / optimized / DLL files
+*.[oa]
+*~
+*.py[cod]
+*$py.class
+**/*.py[cod]
+
+#Hidden
+.*
+!.gitignore
+
+test.py
\ No newline at end of file
diff --git a/authorization_from/README.md b/authorization_from/README.md
new file mode 100644
index 0000000..a16489d
--- /dev/null
+++ b/authorization_from/README.md
@@ -0,0 +1,18 @@
+## Useage
+**F3个人授权书和企业授权书的信息提取**
+
+```python
+from retriever import Retriever
+import const
+
+# 个人授权书 {'姓名': 'xxx', '个人身份证件号码': 'xxx', '签字': '有'}
+r = Retriever(const.TARGET_FIELD_INDIVIDUALS)  
+
+# 企业授权书 {'经销商名称': 'xx', '经销商代码-宝马中国': 'xx', '管理人员姓名-总经理': 'xx', '公司公章': '有', '法定代表人签章': '有'}
+# r = Retriever(const.TARGET_FIELD_COMPANIES) 
+res = r.get_target_fields(go_res, signature_res)
+```
+
+
+
+
diff --git a/authorization_from/const.py b/authorization_from/const.py
new file mode 100644
index 0000000..3d86498
--- /dev/null
+++ b/authorization_from/const.py
@@ -0,0 +1,38 @@
+TARGET_FIELD_INDIVIDUALS = {
+    'keys': {
+        '姓名': [('姓名', 'top1', {})],
+        '个人身份证件号码': [('个人身份证件号码', 'top1', {})],
+    },
+    'value': {
+        '姓名': ('under', {'left_padding': 1, 'right_padding': 1}, ''),
+        '个人身份证件号码': ('under', {'left_padding': 0.5, 'right_padding': 0.5}, '')
+    },
+    'signature': {
+        '签字': {'signature', }
+    }
+}
+
+TARGET_FIELD_COMPANIES = {
+    'keys': {
+        '经销商名称': [
+            ('经销商名称', 'top1', {})
+        ],
+        '经销商代码-宝马中国': [
+            ('经销商代码', 'top1', {}),
+            ('宝马中国', 'right', {'top_padding': 1.5, 'bottom_padding': 0})
+        ],
+        '管理人员姓名-总经理': [
+            ('管理人员姓名', 'top1', {}),
+            ('总经理', 'right', {'top_padding': 1, 'bottom_padding': 0})
+        ],
+    },
+    'value': {
+        '经销商名称': ('right', {'top_padding': 1, 'bottom_padding': 1}, ''),
+        '经销商代码-宝马中国': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, ''),
+        '管理人员姓名-总经理': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, '')
+    },
+    'signature': {
+        '公司公章': {'circle', },
+        '法定代表人签章': {'signature', 'rectangle'}
+    }
+}
\ No newline at end of file
diff --git a/authorization_from/retriever.py b/authorization_from/retriever.py
new file mode 100644
index 0000000..43dc74c
--- /dev/null
+++ b/authorization_from/retriever.py
@@ -0,0 +1,130 @@
+class Retriever:
+
+    def __init__(self, target_fields):
+        self.keys_str = 'keys'
+        self.value_str = 'value'
+        self.signature_str = 'signature'
+        self.signature_have_str = '有'
+        self.signature_have_not_str = '无'
+        self.target_fields = target_fields
+        self.key_text_set = self.get_key_text_set(target_fields)
+
+    def get_key_text_set(self, target_fields):
+        key_text_set = set()
+        for key_text_list in target_fields[self.keys_str].values():
+            for key_text, _, _ in key_text_list:
+                key_text_set.add(key_text)
+        return key_text_set
+
+    @staticmethod
+    def key_top1(coordinates_list, key_coordinates):
+        coordinates_list.sort(key=lambda x: x[1])
+        return coordinates_list[0]
+
+    @staticmethod
+    def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding):
+        if len(coordinates_list) == 1:
+            return coordinates_list[0]
+        height = key_coordinates[-1] - key_coordinates[1]
+        y_min = key_coordinates[1] - (top_padding * height)
+        y_max = key_coordinates[-1] + (bottom_padding * height)
+        x = key_coordinates[2]
+
+        x_min = None
+        key_coordinates = None
+        for x0, y0, x1, y1 in coordinates_list:
+            if y0 > y_min and y1 < y_max and x0 > x:
+                if x_min is None or x0 < x_min:
+                    x_min = x0
+                    key_coordinates = (x0, y0, x1, y1)
+        return key_coordinates
+
+    @staticmethod
+    def value_right(go_res, key_coordinates, top_padding, bottom_padding):
+        height = key_coordinates[-1] - key_coordinates[1]
+        y_min = key_coordinates[1] - (top_padding * height)
+        y_max = key_coordinates[-1] + (bottom_padding * height)
+        x = key_coordinates[2]
+
+        x_min = None
+        value = None
+        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
+            if y0 > y_min and y1 < y_max and x0 > x:
+                if x_min is None or x0 < x_min:
+                    x_min = x0
+                    value = text
+        return value
+
+    @staticmethod
+    def value_under(go_res, key_coordinates, left_padding, right_padding):
+        width = key_coordinates[2] - key_coordinates[0]
+        x_min = key_coordinates[0] - (width * left_padding)
+        x_max = key_coordinates[2] + (width * right_padding)
+        y = key_coordinates[-1]
+
+        y_min = None
+        value = None
+        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
+            if x0 > x_min and x1 < x_max and y0 > y:
+                if y_min is None or y0 < y_min:
+                    y_min = y0
+                    value = text
+        return value
+
+    def get_target_fields(self, go_res, signature_res_list):
+        # 搜索关键词
+        key_text_info = dict()
+        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
+            if text in self.key_text_set:
+                key_text_info.setdefault(text, list()).append((x0, y0, x1, y1))
+
+        # 搜索关键词
+        key_coordinates_info = dict()
+        for field, key_text_list in self.target_fields[self.keys_str].items():
+            pre_key_coordinates = None
+            for key_text, direction, kwargs in key_text_list:
+                if key_text not in key_text_info:
+                    break
+                key_coordinates = getattr(self, 'key_{0}'.format(direction))(
+                    key_text_info[key_text],
+                    pre_key_coordinates,
+                    **kwargs)
+                if not isinstance(key_coordinates, tuple):
+                    break
+                pre_key_coordinates = key_coordinates
+            else:
+                key_coordinates_info[field] = pre_key_coordinates
+
+        # 搜索字段值
+        res = dict()
+        for field, (direction, kwargs, default_value) in self.target_fields[self.value_str].items():
+            if not isinstance(key_coordinates_info.get(field), tuple):
+                res[field] = default_value
+                break
+            value = getattr(self, 'value_{0}'.format(direction))(
+                go_res,
+                key_coordinates_info[field],
+                **kwargs
+            )
+            if not isinstance(value, str):
+                res[field] = default_value
+            else:
+                res[field] = value
+
+        # 搜索签章
+        tmp_signature_count = dict()
+        for signature_dict in signature_res_list:
+            if signature_dict['label'] in tmp_signature_count:
+                tmp_signature_count[signature_dict['label']] += 1
+            else:
+                tmp_signature_count[signature_dict['label']] = 1
+        for field, signature_type_set in self.target_fields[self.signature_str].items():
+            for signature_type in signature_type_set:
+                if tmp_signature_count.get(signature_type, 0) > 0:
+                    res[field] = self.signature_have_str
+                    tmp_signature_count[signature_type] -= 1
+                    break
+                else:
+                    res[field] = self.signature_have_not_str
+
+        return res
diff --git a/classification/base_class.py b/classification/base_class.py
new file mode 100644
index 0000000..6eef066
--- /dev/null
+++ b/classification/base_class.py
@@ -0,0 +1,17 @@
+class BaseModel:
+    """
+    All Model classes should extend BaseModel.
+    """
+
+    def load_model(self):
+        """
+        Defining the network structure and return
+        """
+        raise NotImplementedError(".load() must be overridden.")
+
+    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
+        """
+        Model training process
+        """
+        raise NotImplementedError(".train() must be overridden.")
+
diff --git a/classification/const.py b/classification/const.py
new file mode 100644
index 0000000..cbcefe6
--- /dev/null
+++ b/classification/const.py
@@ -0,0 +1,6 @@
+CLASS_OTHER_CN = '其他'
+
+CLASS_OTHER_FIRST = True
+
+CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']
+
diff --git a/classification/main.py b/classification/main.py
new file mode 100644
index 0000000..b899c34
--- /dev/null
+++ b/classification/main.py
@@ -0,0 +1,23 @@
+import os
+from datetime import datetime
+from model import F3Classification
+import const
+
+
+if __name__ == '__main__':
+    base_dir = os.path.dirname(os.path.abspath(__file__))
+
+    m = F3Classification(
+        class_name_list=const.CLASS_CN_LIST,
+        class_other_first=const.CLASS_OTHER_FIRST
+    )
+
+    # m.test()
+
+    dataset_dir = '/home/zwq/data/data_224'
+    ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
+    epoch = 100
+    batch_size = 128
+
+    m.train(dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test')
+
diff --git a/classification/model.py b/classification/model.py
new file mode 100644
index 0000000..b0070ff
--- /dev/null
+++ b/classification/model.py
@@ -0,0 +1,188 @@
+import os
+import random
+import tensorflow as tf
+import tensorflow_addons as tfa
+from keras.applications.mobilenet_v2 import MobileNetV2
+from keras import layers, models, optimizers, losses, metrics, callbacks, applications
+import matplotlib.pyplot as plt
+
+from base_class import BaseModel
+
+
+@tf.function
+def random_rgb_2_bgr(image, label):
+    if random.random() > 0.5:
+        return image, label
+    image = image[:, :, ::-1]
+    return image, label
+
+
+@tf.function
+def random_grayscale_expand(image, label):
+    if random.random() > 0.1:
+        return image, label
+    image = tf.image.rgb_to_grayscale(image)
+    image = tf.image.grayscale_to_rgb(image)
+    return image, label
+
+
+@tf.function
+def load_image(image_path, label):
+    image = tf.io.read_file(image_path)
+    image = tf.image.decode_image(image, channels=3)
+    return image, label
+
+
+@tf.function
+def preprocess_input(image, label):
+    image = tf.image.resize(image, [224, 224])
+    image = applications.mobilenet_v2.preprocess_input(image)
+    return image, label
+
+
+class F3Classification(BaseModel):
+
+    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
+        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
+
+    @staticmethod
+    def get_class_label_map(class_name_list, class_other_first=False):
+        return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}
+
+    def get_image_label_list(self, dataset_dir):
+        image_path_list = []
+        label_list = []
+        for class_name in os.listdir(dataset_dir):
+            class_dir_path = os.path.join(dataset_dir, class_name)
+            if not os.path.isdir(class_dir_path):
+                continue
+            if class_name not in self.class_label_map:
+                continue
+            label = self.class_label_map[class_name]
+            for file_name in os.listdir(class_dir_path):
+                # TODO image check
+                file_path = os.path.join(class_dir_path, file_name)
+                image_path_list.append(file_path)
+                label_list.append(tf.one_hot(label, depth=self.class_count))
+        return image_path_list, label_list
+
+    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=[]):
+        image_and_label_list = self.get_image_label_list(dataset_dir)
+        tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name)
+        tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
+        tensor_slice_dataset.map(load_image,
+                                 num_parallel_calls=tf.data.AUTOTUNE,
+                                 deterministic=False)
+        for augmentation_method in augmentation_methods:
+            tensor_slice_dataset.map(getattr(self, augmentation_method),
+                                     num_parallel_calls=tf.data.AUTOTUNE,
+                                     deterministic=False)
+        tensor_slice_dataset.map(preprocess_input,
+                                 num_parallel_calls=tf.data.AUTOTUNE,
+                                 deterministic=False)
+        parallel_batch_dataset = tensor_slice_dataset.batch(
+            batch_size=batch_size,
+            drop_remainder=True,
+            num_parallel_calls=tf.data.AUTOTUNE,
+            deterministic=False,
+            name=name,
+        ).prefetch(tf.data.AUTOTUNE)
+        return parallel_batch_dataset
+
+    def load_model(self):
+        base_model = MobileNetV2(
+            input_shape=(224, 224, 3),
+            alpha=0.35,
+            include_top=False,
+            weights='imagenet',
+            pooling='avg',
+        )
+        x = base_model.output
+        x = layers.Dropout(0.5)(x)
+        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
+        x = layers.Dropout(0.5)(x)
+        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
+        model = models.Model(inputs=base_model.input, outputs=x)
+
+        freeze = True
+        for layer in model.layers:
+            layer.trainable = not freeze
+            if freeze and layer.name == 'block_16_project_BN':
+                freeze = False
+        return model
+
+    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
+        # model = self.load_model()
+        # model.summary()
+        #
+        # model.compile(
+        #     optimizer=optimizers.Adam(learning_rate=3e-4),
+        #     loss=tfa.losses.SigmoidFocalCrossEntropy(),
+        #     metrics=['accuracy', ],
+        #
+        #     loss_weights=None,
+        #     weighted_metrics=None,
+        #     run_eagerly=None,
+        #     steps_per_execution=None,
+        #     jit_compile=None,
+        # )
+
+        train_dataset = self.load_dataset(
+            dataset_dir=os.path.join(dataset_dir, train_dir_name),
+            name=train_dir_name,
+            batch_size=batch_size,
+            augmentation_methods=[],
+            # augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
+        )
+        validate_dataset = self.load_dataset(
+            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
+            name=validate_dir_name,
+            batch_size=batch_size,
+            augmentation_methods=[]
+        )
+
+        # ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
+        #
+        # history = model.fit(
+        #     train_dataset,
+        #     epochs=epoch,
+        #     validation_data=validate_dataset,
+        #     callbacks=[ckpt_callback, ],
+        # )
+        #
+        # acc = history.history['accuracy']
+        # val_acc = history.history['val_accuracy']
+        #
+        # loss = history.history['loss']
+        # val_loss = history.history['val_loss']
+        #
+        # plt.figure(figsize=(8, 8))
+        # plt.subplot(2, 1, 1)
+        # plt.plot(acc, label='Training Accuracy')
+        # plt.plot(val_acc, label='Validation Accuracy')
+        # plt.legend(loc='lower right')
+        # plt.ylabel('Accuracy')
+        # plt.ylim([min(plt.ylim()), 1])
+        # plt.title('Training and Validation Accuracy')
+        #
+        # plt.subplot(2, 1, 2)
+        # plt.plot(loss, label='Training Loss')
+        # plt.plot(val_loss, label='Validation Loss')
+        # plt.legend(loc='upper right')
+        # plt.ylabel('Cross Entropy')
+        # plt.ylim([0, 1.0])
+        # plt.title('Training and Validation Loss')
+        # plt.xlabel('epoch')
+        # plt.show()
+
+    def test(self):
+        print(self.class_label_map)
+        print(self.class_count)
+        # path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
+        # label = 5
+        # image, label = self.load_image(path, label)
+        # print(image.shape)
+        # image, label = self.preprocess_input(image, label)
+        # print(image.shape)
diff --git a/classification/train.py b/classification/train.py
deleted file mode 100644
index e69de29..0000000
--- a/classification/train.py
+++ /dev/null
--
libgit2 0.24.0