auth from done
Showing 9 changed files with 433 additions and 0 deletions
authorization_from/README.md
0 → 100644
## Usage

**F3: field extraction for personal and corporate letters of authorization**

```python
from retriever import Retriever
import const

# Personal authorization letter -> {'姓名': 'xxx', '个人身份证件号码': 'xxx', '签字': '有'}
r = Retriever(const.TARGET_FIELD_INDIVIDUALS)

# Corporate authorization letter -> {'经销商名称': 'xx', '经销商代码-宝马中国': 'xx', '管理人员姓名-总经理': 'xx', '公司公章': '有', '法定代表人签章': '有'}
# r = Retriever(const.TARGET_FIELD_COMPANIES)
res = r.get_target_fields(go_res, signature_res)
```
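The snippet above assumes two inputs whose formats the README does not spell out; the shapes below are inferred from how `retriever.py` unpacks them, and the values are made up for illustration. `go_res` maps arbitrary ids to an 8-value quad box plus the recognized text; `signature_res` is a list of detection dicts carrying a `label`:

```python
# Hypothetical input shapes, inferred from retriever.py.
go_res = {
    0: ((10, 10, 60, 10, 60, 30, 10, 30), '姓名'),  # (x0,y0, x1,y0, x1,y1, x0,y1), text
    1: ((10, 40, 80, 40, 80, 60, 10, 60), '张三'),
}
signature_res = [{'label': 'signature'}]  # labels used in const.py: 'signature', 'circle', 'rectangle'
```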
authorization_from/const.py
0 → 100644
TARGET_FIELD_INDIVIDUALS = {
    'keys': {
        '姓名': [('姓名', 'top1', {})],
        '个人身份证件号码': [('个人身份证件号码', 'top1', {})],
    },
    'value': {
        '姓名': ('under', {'left_padding': 1, 'right_padding': 1}, ''),
        '个人身份证件号码': ('under', {'left_padding': 0.5, 'right_padding': 0.5}, '')
    },
    'signature': {
        '签字': {'signature', }
    }
}

TARGET_FIELD_COMPANIES = {
    'keys': {
        '经销商名称': [
            ('经销商名称', 'top1', {})
        ],
        '经销商代码-宝马中国': [
            ('经销商代码', 'top1', {}),
            ('宝马中国', 'right', {'top_padding': 1.5, 'bottom_padding': 0})
        ],
        '管理人员姓名-总经理': [
            ('管理人员姓名', 'top1', {}),
            ('总经理', 'right', {'top_padding': 1, 'bottom_padding': 0})
        ],
    },
    'value': {
        '经销商名称': ('right', {'top_padding': 1, 'bottom_padding': 1}, ''),
        '经销商代码-宝马中国': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, ''),
        '管理人员姓名-总经理': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, '')
    },
    'signature': {
        '公司公章': {'circle', },
        '法定代表人签章': {'signature', 'rectangle'}
    }
}
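Reading the config: each `keys` entry is a chain of (anchor text, locator, kwargs) steps that `retriever.py` applies in order; each `value` entry is (search direction, padding kwargs, default); each `signature` entry is the set of detection labels that satisfy the field. A sketch of how one corporate rule is interpreted (annotation only, not part of const.py):

```python
# '经销商代码-宝马中国': [
#     ('经销商代码', 'top1', {}),          # step 1: topmost box whose text is 经销商代码
#     ('宝马中国', 'right',                # step 2: nearest 宝马中国 to its right,
#      {'top_padding': 1.5, 'bottom_padding': 0}),  # within a band padded by 1.5x box height
# ]
# value rule ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, ''):
# take the nearest text box to the right of the final anchor; fall back to ''.
```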
authorization_from/retriever.py
0 → 100644
class Retriever:

    def __init__(self, target_fields):
        self.keys_str = 'keys'
        self.value_str = 'value'
        self.signature_str = 'signature'
        self.signature_have_str = '有'      # present
        self.signature_have_not_str = '无'  # absent
        self.target_fields = target_fields
        self.key_text_set = self.get_key_text_set(target_fields)

    def get_key_text_set(self, target_fields):
        key_text_set = set()
        for key_text_list in target_fields[self.keys_str].values():
            for key_text, _, _ in key_text_list:
                key_text_set.add(key_text)
        return key_text_set

    @staticmethod
    def key_top1(coordinates_list, key_coordinates):
        # Pick the topmost candidate; key_coordinates is unused here but kept
        # for a uniform key_* signature.
        coordinates_list.sort(key=lambda x: x[1])
        return coordinates_list[0]

    @staticmethod
    def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding):
        # Pick the nearest candidate to the right of key_coordinates, within a
        # vertical band padded by multiples of the key box height.
        if len(coordinates_list) == 1:
            return coordinates_list[0]
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)
        x = key_coordinates[2]

        x_min = None
        key_coordinates = None
        for x0, y0, x1, y1 in coordinates_list:
            if y0 > y_min and y1 < y_max and x0 > x:
                if x_min is None or x0 < x_min:
                    x_min = x0
                    key_coordinates = (x0, y0, x1, y1)
        return key_coordinates

    @staticmethod
    def value_right(go_res, key_coordinates, top_padding, bottom_padding):
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)
        x = key_coordinates[2]

        x_min = None
        value = None
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if y0 > y_min and y1 < y_max and x0 > x:
                if x_min is None or x0 < x_min:
                    x_min = x0
                    value = text
        return value

    @staticmethod
    def value_under(go_res, key_coordinates, left_padding, right_padding):
        width = key_coordinates[2] - key_coordinates[0]
        x_min = key_coordinates[0] - (width * left_padding)
        x_max = key_coordinates[2] + (width * right_padding)
        y = key_coordinates[-1]

        y_min = None
        value = None
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if x0 > x_min and x1 < x_max and y0 > y:
                if y_min is None or y0 < y_min:
                    y_min = y0
                    value = text
        return value

    def get_target_fields(self, go_res, signature_res_list):
        # Collect the boxes of every anchor key text
        key_text_info = dict()
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if text in self.key_text_set:
                key_text_info.setdefault(text, list()).append((x0, y0, x1, y1))

        # Resolve each field's final anchor box by chaining its key steps
        key_coordinates_info = dict()
        for field, key_text_list in self.target_fields[self.keys_str].items():
            pre_key_coordinates = None
            for key_text, direction, kwargs in key_text_list:
                if key_text not in key_text_info:
                    break
                key_coordinates = getattr(self, 'key_{0}'.format(direction))(
                    key_text_info[key_text],
                    pre_key_coordinates,
                    **kwargs)
                if not isinstance(key_coordinates, tuple):
                    break
                pre_key_coordinates = key_coordinates
            else:
                key_coordinates_info[field] = pre_key_coordinates

        # Extract field values relative to the resolved anchors
        res = dict()
        for field, (direction, kwargs, default_value) in self.target_fields[self.value_str].items():
            if not isinstance(key_coordinates_info.get(field), tuple):
                res[field] = default_value
                continue  # was `break`, which would skip all remaining fields
            value = getattr(self, 'value_{0}'.format(direction))(
                go_res,
                key_coordinates_info[field],
                **kwargs
            )
            if not isinstance(value, str):
                res[field] = default_value
            else:
                res[field] = value

        # Match signatures/stamps: each field consumes at most one detection
        # of an acceptable label
        tmp_signature_count = dict()
        for signature_dict in signature_res_list:
            if signature_dict['label'] in tmp_signature_count:
                tmp_signature_count[signature_dict['label']] += 1
            else:
                tmp_signature_count[signature_dict['label']] = 1
        for field, signature_type_set in self.target_fields[self.signature_str].items():
            for signature_type in signature_type_set:
                if tmp_signature_count.get(signature_type, 0) > 0:
                    res[field] = self.signature_have_str
                    tmp_signature_count[signature_type] -= 1
                    break
            else:
                res[field] = self.signature_have_not_str

        return res
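A minimal smoke test, using the hypothetical `go_res`/`signature_res` sketched in the README section above (the expected output assumes the individual config and the `continue` fix noted in the value loop):

```python
import const
from retriever import Retriever

# Hypothetical inputs: one '姓名' anchor with '张三' below it, one detected signature.
go_res = {
    0: ((10, 10, 60, 10, 60, 30, 10, 30), '姓名'),
    1: ((10, 40, 80, 40, 80, 60, 10, 60), '张三'),
}
signature_res = [{'label': 'signature'}]

r = Retriever(const.TARGET_FIELD_INDIVIDUALS)
print(r.get_target_fields(go_res, signature_res))
# -> {'姓名': '张三', '个人身份证件号码': '', '签字': '有'}
```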
classification/base_class.py
0 → 100644
class BaseModel:
    """
    All Model classes should extend BaseModel.
    """

    def load_model(self):
        """
        Define the network structure and return the model.
        """
        raise NotImplementedError(".load_model() must be overridden.")

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
        """
        Model training process.
        """
        raise NotImplementedError(".train() must be overridden.")
classification/const.py
0 → 100644
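The diff view shows this file as empty, but `main.py` imports `CLASS_CN_LIST` and `CLASS_OTHER_FIRST` from it, so it must at least define those two names. A hypothetical sketch (only 银行卡 is attested, in `model.py`'s test comment; the other entries are placeholders):

```python
# Hypothetical contents -- only the two names imported by main.py are known.
CLASS_OTHER_FIRST = True  # if True, index 0 is the catch-all 'other' class

CLASS_CN_LIST = [
    '其他',        # 'other'; maps to label -1 (all-zero one-hot) when CLASS_OTHER_FIRST
    '个人授权书',  # personal authorization letter (placeholder entry)
    '企业授权书',  # corporate authorization letter (placeholder entry)
    '银行卡',      # bank card; referenced in model.py's test comment
]
```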
classification/main.py
0 → 100644
import os
from datetime import datetime
from model import F3Classification
import const


if __name__ == '__main__':
    base_dir = os.path.dirname(os.path.abspath(__file__))

    m = F3Classification(
        class_name_list=const.CLASS_CN_LIST,
        class_other_first=const.CLASS_OTHER_FIRST
    )

    # m.test()

    dataset_dir = '/home/zwq/data/data_224'
    ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
    epoch = 100
    batch_size = 128

    m.train(dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test')
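For reference, `get_image_label_list` in `model.py` walks one sub-folder per class name under each split, so `dataset_dir` presumably looks like the tree below (folder names must match `CLASS_CN_LIST` entries; file names are illustrative):

```
/home/zwq/data/data_224
├── train
│   ├── 银行卡
│   │   └── bc_1.jpg      # path referenced in model.py's test comment
│   └── <one folder per CLASS_CN_LIST entry>
└── test
    └── <same class folders>
```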
classification/model.py
0 → 100644
import os
import tensorflow as tf
import tensorflow_addons as tfa
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, losses, metrics, callbacks, applications
import matplotlib.pyplot as plt

from base_class import BaseModel


@tf.function
def random_rgb_2_bgr(image, label):
    # Use tf.random inside a tf.function: Python's random.random() would be
    # evaluated once at trace time, freezing the "coin flip" for every batch.
    image = tf.cond(tf.random.uniform(()) > 0.5,
                    lambda: image,
                    lambda: image[:, :, ::-1])
    return image, label


@tf.function
def random_grayscale_expand(image, label):
    # Same trace-time pitfall as above; keep the 10% probability graph-side.
    image = tf.cond(tf.random.uniform(()) > 0.1,
                    lambda: image,
                    lambda: tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)))
    return image, label


@tf.function
def load_image(image_path, label):
    image = tf.io.read_file(image_path)
    # expand_animations=False keeps a static 3-D shape, which tf.image.resize needs
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    return image, label


@tf.function
def preprocess_input(image, label):
    image = tf.image.resize(image, [224, 224])
    image = applications.mobilenet_v2.preprocess_input(image)
    return image, label


class F3Classification(BaseModel):

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # When the catch-all 'other' class comes first, it is excluded from the
        # output dimension; its label maps to -1 below.
        self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)

    @staticmethod
    def get_class_label_map(class_name_list, class_other_first=False):
        return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}

    def get_image_label_list(self, dataset_dir):
        image_path_list = []
        label_list = []
        for class_name in os.listdir(dataset_dir):
            class_dir_path = os.path.join(dataset_dir, class_name)
            if not os.path.isdir(class_dir_path):
                continue
            if class_name not in self.class_label_map:
                continue
            label = self.class_label_map[class_name]
            for file_name in os.listdir(class_dir_path):
                # TODO image check
                file_path = os.path.join(class_dir_path, file_name)
                image_path_list.append(file_path)
                # tf.one_hot(-1) yields an all-zero vector, so the 'other'
                # class is encoded as "none of the real classes"
                label_list.append(tf.one_hot(label, depth=self.class_count))
        return image_path_list, label_list

    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=()):
        image_and_label_list = self.get_image_label_list(dataset_dir)
        dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name)
        # tf.data transformations return new datasets; the results must be
        # reassigned or the shuffle/map calls are silently dropped.
        dataset = dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
        dataset = dataset.map(load_image,
                              num_parallel_calls=tf.data.AUTOTUNE,
                              deterministic=False)
        for augmentation_method in augmentation_methods:
            # The augmentation functions are module-level, not methods, so look
            # them up in globals() rather than on self.
            dataset = dataset.map(globals()[augmentation_method],
                                  num_parallel_calls=tf.data.AUTOTUNE,
                                  deterministic=False)
        dataset = dataset.map(preprocess_input,
                              num_parallel_calls=tf.data.AUTOTUNE,
                              deterministic=False)
        parallel_batch_dataset = dataset.batch(
            batch_size=batch_size,
            drop_remainder=True,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
            name=name,
        ).prefetch(tf.data.AUTOTUNE)
        return parallel_batch_dataset

    def load_model(self):
        base_model = MobileNetV2(
            input_shape=(224, 224, 3),
            alpha=0.35,
            include_top=False,
            weights='imagenet',
            pooling='avg',
        )
        x = base_model.output
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
        model = models.Model(inputs=base_model.input, outputs=x)

        # Freeze the backbone up to and including block_16_project_BN; only the
        # layers after it (and the new head) stay trainable.
        freeze = True
        for layer in model.layers:
            layer.trainable = not freeze
            if freeze and layer.name == 'block_16_project_BN':
                freeze = False
        return model

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
        # model = self.load_model()
        # model.summary()
        #
        # model.compile(
        #     optimizer=optimizers.Adam(learning_rate=3e-4),
        #     loss=tfa.losses.SigmoidFocalCrossEntropy(),
        #     metrics=['accuracy', ],
        #
        #     loss_weights=None,
        #     weighted_metrics=None,
        #     run_eagerly=None,
        #     steps_per_execution=None,
        #     jit_compile=None,
        # )

        train_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, train_dir_name),
            name=train_dir_name,
            batch_size=batch_size,
            augmentation_methods=[],
            # augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
        )
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[]
        )

        # ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
        #
        # history = model.fit(
        #     train_dataset,
        #     epochs=epoch,
        #     validation_data=validate_dataset,
        #     callbacks=[ckpt_callback, ],
        # )
        #
        # acc = history.history['accuracy']
        # val_acc = history.history['val_accuracy']
        #
        # loss = history.history['loss']
        # val_loss = history.history['val_loss']
        #
        # plt.figure(figsize=(8, 8))
        # plt.subplot(2, 1, 1)
        # plt.plot(acc, label='Training Accuracy')
        # plt.plot(val_acc, label='Validation Accuracy')
        # plt.legend(loc='lower right')
        # plt.ylabel('Accuracy')
        # plt.ylim([min(plt.ylim()), 1])
        # plt.title('Training and Validation Accuracy')
        #
        # plt.subplot(2, 1, 2)
        # plt.plot(loss, label='Training Loss')
        # plt.plot(val_loss, label='Validation Loss')
        # plt.legend(loc='upper right')
        # plt.ylabel('Cross Entropy')
        # plt.ylim([0, 1.0])
        # plt.title('Training and Validation Loss')
        # plt.xlabel('epoch')
        # plt.show()

    def test(self):
        print(self.class_label_map)
        print(self.class_count)
        # path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
        # label = 5
        # image, label = self.load_image(path, label)
        # print(image.shape)
        # image, label = self.preprocess_input(image, label)
        # print(image.shape)
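A note on the label scheme: with `class_other_first=True`, the first ('other') class maps to label -1, and `tf.one_hot(-1, depth=n)` returns an all-zero vector, so 'other' images act as pure negatives for the sigmoid/focal-loss head. A minimal check, assuming TensorFlow 2.x:

```python
import tensorflow as tf

# Index -1 is out of range for one_hot, so every position stays at off_value 0:
print(tf.one_hot(-1, depth=3).numpy())  # -> [0. 0. 0.]
print(tf.one_hot(1, depth=3).numpy())   # -> [0. 1. 0.]
```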
classification/train.py
deleted
100644 → 0