auth from done

Showing 9 changed files with 433 additions and 0 deletions
authorization_from/README.md
0 → 100644
## Usage

**Information extraction for F3 personal and company authorization letters**

```python
from retriever import Retriever
import const

# Personal authorization letter -> {'姓名': 'xxx', '个人身份证件号码': 'xxx', '签字': '有'}
r = Retriever(const.TARGET_FIELD_INDIVIDUALS)

# Company authorization letter -> {'经销商名称': 'xx', '经销商代码-宝马中国': 'xx', '管理人员姓名-总经理': 'xx', '公司公章': '有', '法定代表人签章': '有'}
# r = Retriever(const.TARGET_FIELD_COMPANIES)
res = r.get_target_fields(go_res, signature_res)
```
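Note: the snippet above leaves `go_res` and `signature_res` undefined. A minimal sketch of the shapes `retriever.py` expects, with invented ids, coordinates, and texts:

```python
# go_res: OCR results keyed by an arbitrary id; each value pairs an
# 8-number quadrilateral (top-left, top-right, bottom-right, bottom-left
# corner coordinates) with the recognized text.
go_res = {
    0: ((100, 50, 180, 50, 180, 80, 100, 80), '姓名'),
    1: ((100, 90, 200, 90, 200, 120, 100, 120), '张三'),
}

# signature_res: a list of detection dicts; only the 'label' key is read.
signature_res = [
    {'label': 'signature'},
]
```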
authorization_from/const.py
0 → 100644
TARGET_FIELD_INDIVIDUALS = {
    'keys': {
        '姓名': [('姓名', 'top1', {})],
        '个人身份证件号码': [('个人身份证件号码', 'top1', {})],
    },
    'value': {
        '姓名': ('under', {'left_padding': 1, 'right_padding': 1}, ''),
        '个人身份证件号码': ('under', {'left_padding': 0.5, 'right_padding': 0.5}, '')
    },
    'signature': {
        '签字': {'signature', }
    }
}

TARGET_FIELD_COMPANIES = {
    'keys': {
        '经销商名称': [
            ('经销商名称', 'top1', {})
        ],
        '经销商代码-宝马中国': [
            ('经销商代码', 'top1', {}),
            ('宝马中国', 'right', {'top_padding': 1.5, 'bottom_padding': 0})
        ],
        '管理人员姓名-总经理': [
            ('管理人员姓名', 'top1', {}),
            ('总经理', 'right', {'top_padding': 1, 'bottom_padding': 0})
        ],
    },
    'value': {
        '经销商名称': ('right', {'top_padding': 1, 'bottom_padding': 1}, ''),
        '经销商代码-宝马中国': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, ''),
        '管理人员姓名-总经理': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, '')
    },
    'signature': {
        '公司公章': {'circle', },
        '法定代表人签章': {'signature', 'rectangle'}
    }
}
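Each config follows the same schema: `'keys'` maps a field to an ordered list of `(key_text, strategy, kwargs)` steps for locating the key on the page, `'value'` maps the field to a `(strategy, kwargs, default)` triple for extracting its value relative to the key, and `'signature'` maps a field to the set of detector labels that count as present. A sketch of how the schema extends, using an invented field:

```python
# Hypothetical config for a new field; the field and key text '日期'
# are invented for illustration and are not part of this change.
TARGET_FIELD_EXAMPLE = {
    'keys': {
        # Locate the topmost OCR box whose text is exactly '日期'.
        '日期': [('日期', 'top1', {})],
    },
    'value': {
        # Read the value to the right of the key; fall back to ''.
        '日期': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, ''),
    },
    'signature': {},
}
```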
authorization_from/retriever.py
0 → 100644
class Retriever:

    def __init__(self, target_fields):
        self.keys_str = 'keys'
        self.value_str = 'value'
        self.signature_str = 'signature'
        self.signature_have_str = '有'
        self.signature_have_not_str = '无'
        self.target_fields = target_fields
        self.key_text_set = self.get_key_text_set(target_fields)

    def get_key_text_set(self, target_fields):
        # Collect every key text that can appear in the OCR results
        key_text_set = set()
        for key_text_list in target_fields[self.keys_str].values():
            for key_text, _, _ in key_text_list:
                key_text_set.add(key_text)
        return key_text_set

    @staticmethod
    def key_top1(coordinates_list, key_coordinates):
        # Return the topmost candidate; key_coordinates is unused here but
        # kept so every key_* strategy shares the same signature
        coordinates_list.sort(key=lambda x: x[1])
        return coordinates_list[0]

    @staticmethod
    def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding):
        # Return the candidate nearest to the right of key_coordinates,
        # inside a vertical band padded by multiples of the key height
        if len(coordinates_list) == 1:
            return coordinates_list[0]
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)
        x = key_coordinates[2]

        x_min = None
        result = None
        for x0, y0, x1, y1 in coordinates_list:
            if y0 > y_min and y1 < y_max and x0 > x:
                if x_min is None or x0 < x_min:
                    x_min = x0
                    result = (x0, y0, x1, y1)
        return result

    @staticmethod
    def value_right(go_res, key_coordinates, top_padding, bottom_padding):
        # Return the text of the nearest OCR box to the right of the key
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)
        x = key_coordinates[2]

        x_min = None
        value = None
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if y0 > y_min and y1 < y_max and x0 > x:
                if x_min is None or x0 < x_min:
                    x_min = x0
                    value = text
        return value

    @staticmethod
    def value_under(go_res, key_coordinates, left_padding, right_padding):
        # Return the text of the nearest OCR box below the key
        width = key_coordinates[2] - key_coordinates[0]
        x_min = key_coordinates[0] - (width * left_padding)
        x_max = key_coordinates[2] + (width * right_padding)
        y = key_coordinates[-1]

        y_min = None
        value = None
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if x0 > x_min and x1 < x_max and y0 > y:
                if y_min is None or y0 < y_min:
                    y_min = y0
                    value = text
        return value

    def get_target_fields(self, go_res, signature_res_list):
        # Collect the coordinates of every candidate key text
        key_text_info = dict()
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if text in self.key_text_set:
                key_text_info.setdefault(text, list()).append((x0, y0, x1, y1))

        # Locate each field's key by chaining its lookup steps
        key_coordinates_info = dict()
        for field, key_text_list in self.target_fields[self.keys_str].items():
            pre_key_coordinates = None
            for key_text, direction, kwargs in key_text_list:
                if key_text not in key_text_info:
                    break
                key_coordinates = getattr(self, 'key_{0}'.format(direction))(
                    key_text_info[key_text],
                    pre_key_coordinates,
                    **kwargs)
                if not isinstance(key_coordinates, tuple):
                    break
                pre_key_coordinates = key_coordinates
            else:
                key_coordinates_info[field] = pre_key_coordinates

        # Extract each field's value relative to its key; fields whose key
        # was not found fall back to their default instead of aborting the loop
        res = dict()
        for field, (direction, kwargs, default_value) in self.target_fields[self.value_str].items():
            if not isinstance(key_coordinates_info.get(field), tuple):
                res[field] = default_value
                continue
            value = getattr(self, 'value_{0}'.format(direction))(
                go_res,
                key_coordinates_info[field],
                **kwargs
            )
            if not isinstance(value, str):
                res[field] = default_value
            else:
                res[field] = value

        # Match detected seals/signatures against each field's accepted labels
        tmp_signature_count = dict()
        for signature_dict in signature_res_list:
            if signature_dict['label'] in tmp_signature_count:
                tmp_signature_count[signature_dict['label']] += 1
            else:
                tmp_signature_count[signature_dict['label']] = 1
        for field, signature_type_set in self.target_fields[self.signature_str].items():
            for signature_type in signature_type_set:
                if tmp_signature_count.get(signature_type, 0) > 0:
                    res[field] = self.signature_have_str
                    tmp_signature_count[signature_type] -= 1
                    break
            else:
                res[field] = self.signature_have_not_str

        return res
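A minimal end-to-end sketch under the input assumptions above; all ids, coordinates, and texts are toy values invented for illustration:

```python
import const
from retriever import Retriever

# Toy OCR output: '姓名' with '张三' directly under it, and the ID-number
# key with its value below; quads are (TL, TR, BR, BL) corner pairs.
go_res = {
    0: ((100, 50, 160, 50, 160, 80, 100, 80), '姓名'),
    1: ((100, 90, 160, 90, 160, 120, 100, 120), '张三'),
    2: ((100, 130, 260, 130, 260, 160, 100, 160), '个人身份证件号码'),
    3: ((100, 170, 260, 170, 260, 200, 100, 200), '110101199001011234'),
}
signature_res = [{'label': 'signature'}]

r = Retriever(const.TARGET_FIELD_INDIVIDUALS)
print(r.get_target_fields(go_res, signature_res))
# -> {'姓名': '张三', '个人身份证件号码': '110101199001011234', '签字': '有'}
```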
classification/base_class.py
0 → 100644
class BaseModel:
    """
    All model classes should extend BaseModel.
    """

    def load_model(self):
        """
        Define the network structure and return the model.
        """
        raise NotImplementedError(".load_model() must be overridden.")

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
        """
        Model training process.
        """
        raise NotImplementedError(".train() must be overridden.")
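A minimal sketch of the contract a subclass fulfils; the tiny Keras network below is a placeholder, not the project's architecture:

```python
from keras import layers, models

from base_class import BaseModel


class TinyClassifier(BaseModel):

    def load_model(self):
        # Placeholder network; real subclasses define their own structure
        return models.Sequential([
            layers.Input(shape=(224, 224, 3)),
            layers.GlobalAveragePooling2D(),
            layers.Dense(10, activation='softmax'),
        ])

    def train(self, dataset_dir, epoch, batch_size, ckpt_path,
              train_dir_name='train', validate_dir_name='test'):
        model = self.load_model()
        model.compile(optimizer='adam', loss='categorical_crossentropy')
        # ... build datasets from dataset_dir and call model.fit(...)
```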
classification/const.py
0 → 100644
classification/main.py
0 → 100644
import os
from datetime import datetime
from model import F3Classification
import const


if __name__ == '__main__':
    base_dir = os.path.dirname(os.path.abspath(__file__))

    m = F3Classification(
        class_name_list=const.CLASS_CN_LIST,
        class_other_first=const.CLASS_OTHER_FIRST
    )

    # m.test()

    dataset_dir = '/home/zwq/data/data_224'
    ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
    epoch = 100
    batch_size = 128

    m.train(dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test')
classification/model.py
0 → 100644
import os
import tensorflow as tf
import tensorflow_addons as tfa
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, losses, metrics, callbacks, applications
import matplotlib.pyplot as plt

from base_class import BaseModel


@tf.function
def random_rgb_2_bgr(image, label):
    # Use tf.random so the coin flip is re-drawn on every call; python's
    # random.random() would be frozen at tf.function trace time
    image = tf.cond(tf.random.uniform(()) > 0.5,
                    lambda: image,
                    lambda: image[:, :, ::-1])
    return image, label


@tf.function
def random_grayscale_expand(image, label):
    # With probability 0.1, collapse to grayscale and expand back to 3 channels
    image = tf.cond(tf.random.uniform(()) > 0.1,
                    lambda: image,
                    lambda: tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)))
    return image, label


@tf.function
def load_image(image_path, label):
    image = tf.io.read_file(image_path)
    # expand_animations=False keeps a statically-ranked 3-D tensor,
    # which tf.image.resize requires
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    return image, label


@tf.function
def preprocess_input(image, label):
    image = tf.image.resize(image, [224, 224])
    image = applications.mobilenet_v2.preprocess_input(image)
    return image, label


class F3Classification(BaseModel):

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)

    @staticmethod
    def get_class_label_map(class_name_list, class_other_first=False):
        # With class_other_first, the leading 'other' class maps to -1, and
        # tf.one_hot(-1) yields an all-zero target vector for it
        return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}

    def get_image_label_list(self, dataset_dir):
        image_path_list = []
        label_list = []
        for class_name in os.listdir(dataset_dir):
            class_dir_path = os.path.join(dataset_dir, class_name)
            if not os.path.isdir(class_dir_path):
                continue
            if class_name not in self.class_label_map:
                continue
            label = self.class_label_map[class_name]
            for file_name in os.listdir(class_dir_path):
                # TODO image check
                file_path = os.path.join(class_dir_path, file_name)
                image_path_list.append(file_path)
                label_list.append(tf.one_hot(label, depth=self.class_count))
        return image_path_list, label_list

    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=None):
        # tf.data transformations return new datasets, so each result must be
        # reassigned; the original calls dropped the return values, leaving
        # the pipeline unshuffled and unmapped
        image_and_label_list = self.get_image_label_list(dataset_dir)
        dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name)
        dataset = dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
        dataset = dataset.map(load_image,
                              num_parallel_calls=tf.data.AUTOTUNE,
                              deterministic=False)
        for augmentation_method in (augmentation_methods or []):
            # Augmentation functions are module-level, so look them up in
            # globals() rather than on self
            dataset = dataset.map(globals()[augmentation_method],
                                  num_parallel_calls=tf.data.AUTOTUNE,
                                  deterministic=False)
        dataset = dataset.map(preprocess_input,
                              num_parallel_calls=tf.data.AUTOTUNE,
                              deterministic=False)
        parallel_batch_dataset = dataset.batch(
            batch_size=batch_size,
            drop_remainder=True,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
            name=name,
        ).prefetch(tf.data.AUTOTUNE)
        return parallel_batch_dataset

    def load_model(self):
        base_model = MobileNetV2(
            input_shape=(224, 224, 3),
            alpha=0.35,
            include_top=False,
            weights='imagenet',
            pooling='avg',
        )
        x = base_model.output
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
        model = models.Model(inputs=base_model.input, outputs=x)

        # Freeze the backbone up to and including 'block_16_project_BN';
        # every layer after it stays trainable
        freeze = True
        for layer in model.layers:
            layer.trainable = not freeze
            if freeze and layer.name == 'block_16_project_BN':
                freeze = False
        return model

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
        # model = self.load_model()
        # model.summary()
        #
        # model.compile(
        #     optimizer=optimizers.Adam(learning_rate=3e-4),
        #     loss=tfa.losses.SigmoidFocalCrossEntropy(),
        #     metrics=['accuracy', ],
        #
        #     loss_weights=None,
        #     weighted_metrics=None,
        #     run_eagerly=None,
        #     steps_per_execution=None,
        #     jit_compile=None,
        # )

        train_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, train_dir_name),
            name=train_dir_name,
            batch_size=batch_size,
            augmentation_methods=[],
            # augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
        )
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[]
        )

        # ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
        #
        # history = model.fit(
        #     train_dataset,
        #     epochs=epoch,
        #     validation_data=validate_dataset,
        #     callbacks=[ckpt_callback, ],
        # )
        #
        # acc = history.history['accuracy']
        # val_acc = history.history['val_accuracy']
        #
        # loss = history.history['loss']
        # val_loss = history.history['val_loss']
        #
        # plt.figure(figsize=(8, 8))
        # plt.subplot(2, 1, 1)
        # plt.plot(acc, label='Training Accuracy')
        # plt.plot(val_acc, label='Validation Accuracy')
        # plt.legend(loc='lower right')
        # plt.ylabel('Accuracy')
        # plt.ylim([min(plt.ylim()), 1])
        # plt.title('Training and Validation Accuracy')
        #
        # plt.subplot(2, 1, 2)
        # plt.plot(loss, label='Training Loss')
        # plt.plot(val_loss, label='Validation Loss')
        # plt.legend(loc='upper right')
        # plt.ylabel('Cross Entropy')
        # plt.ylim([0, 1.0])
        # plt.title('Training and Validation Loss')
        # plt.xlabel('epoch')
        # plt.show()

    def test(self):
        print(self.class_label_map)
        print(self.class_count)
        # path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
        # label = 5
        # image, label = self.load_image(path, label)
        # print(image.shape)
        # image, label = self.preprocess_input(image, label)
        # print(image.shape)
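The `# TODO image check` above is still open. A minimal sketch of one way to fill it, verifying that each file decodes before it enters the dataset; the helper name `is_valid_image` and the PIL dependency are assumptions, not part of this change:

```python
from PIL import Image


def is_valid_image(file_path):
    # Returns True if PIL can parse and verify the file; catches truncated
    # or non-image files before tf.io.decode_image fails mid-training
    try:
        with Image.open(file_path) as img:
            img.verify()
        return True
    except Exception:
        return False

# Hypothetical use inside get_image_label_list:
# if not is_valid_image(file_path):
#     continue
```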
classification/train.py
deleted
100644 → 0
File mode changed