cbeebc6d by 周伟奇

auth from done

1 parent 5586ffb0
.idea

# Byte-compiled / optimized / DLL files
*.[oa]
*~
*.py[cod]
*$py.class
**/*.py[cod]

# Hidden
.*
!.gitignore

test.py

## Usage
**Information extraction for F3 individual and company authorization letters**

```python
from retriever import Retriever
import const

# Individual authorization letter -> {'姓名': 'xxx', '个人身份证件号码': 'xxx', '签字': '有'}
r = Retriever(const.TARGET_FIELD_INDIVIDUALS)

# Company authorization letter -> {'经销商名称': 'xx', '经销商代码-宝马中国': 'xx', '管理人员姓名-总经理': 'xx', '公司公章': '有', '法定代表人签章': '有'}
# r = Retriever(const.TARGET_FIELD_COMPANIES)
res = r.get_target_fields(go_res, signature_res)
```
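
`go_res` and `signature_res` are the outputs of the upstream OCR and signature-detection steps. A minimal sketch of the expected shapes, inferred from `Retriever` (the ids and values below are illustrative):

```python
# OCR result: {id: (quad, text)}, where quad holds 8 numbers and positions
# 0-1 / 4-5 are read as the top-left / bottom-right corners of the text box
go_res = {
    0: ((10, 10, 60, 10, 60, 30, 10, 30), '姓名'),
    1: ((10, 40, 80, 40, 80, 60, 10, 60), '张三'),
}

# Signature/seal detection result: a list of {'label': ...} dicts; the labels
# referenced by the configs are 'signature', 'circle' and 'rectangle'
signature_res = [{'label': 'signature'}]
```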
# const.py

TARGET_FIELD_INDIVIDUALS = {
    'keys': {
        '姓名': [('姓名', 'top1', {})],
        '个人身份证件号码': [('个人身份证件号码', 'top1', {})],
    },
    'value': {
        '姓名': ('under', {'left_padding': 1, 'right_padding': 1}, ''),
        '个人身份证件号码': ('under', {'left_padding': 0.5, 'right_padding': 0.5}, '')
    },
    'signature': {
        '签字': {'signature', }
    }
}

TARGET_FIELD_COMPANIES = {
    'keys': {
        '经销商名称': [
            ('经销商名称', 'top1', {})
        ],
        '经销商代码-宝马中国': [
            ('经销商代码', 'top1', {}),
            ('宝马中国', 'right', {'top_padding': 1.5, 'bottom_padding': 0})
        ],
        '管理人员姓名-总经理': [
            ('管理人员姓名', 'top1', {}),
            ('总经理', 'right', {'top_padding': 1, 'bottom_padding': 0})
        ],
    },
    'value': {
        '经销商名称': ('right', {'top_padding': 1, 'bottom_padding': 1}, ''),
        '经销商代码-宝马中国': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, ''),
        '管理人员姓名-总经理': ('right', {'top_padding': 0.5, 'bottom_padding': 0.5}, '')
    },
    'signature': {
        '公司公章': {'circle', },
        '法定代表人签章': {'signature', 'rectangle'}
    }
}
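
# Schema notes (inferred from retriever.Retriever):
# - 'keys': field -> list of (key_text, direction, kwargs) anchors resolved in
#   order; 'top1' takes the topmost occurrence of key_text, 'right' takes the
#   occurrence nearest to the right of the previous anchor, within a vertical
#   band padded by top_padding/bottom_padding multiples of the anchor height.
# - 'value': field -> (direction, kwargs, default); 'right'/'under' pick the
#   OCR text nearest to the right of / below the resolved key box, falling
#   back to the default when nothing qualifies.
# - 'signature': field -> set of detector labels that mark the field as
#   present ('有'); each detection is consumed at most once.
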
# retriever.py

class Retriever:

    def __init__(self, target_fields):
        self.keys_str = 'keys'
        self.value_str = 'value'
        self.signature_str = 'signature'
        self.signature_have_str = '有'
        self.signature_have_not_str = '无'
        self.target_fields = target_fields
        self.key_text_set = self.get_key_text_set(target_fields)

    def get_key_text_set(self, target_fields):
        key_text_set = set()
        for key_text_list in target_fields[self.keys_str].values():
            for key_text, _, _ in key_text_list:
                key_text_set.add(key_text)
        return key_text_set

    @staticmethod
    def key_top1(coordinates_list, key_coordinates):
        # Topmost occurrence wins; key_coordinates is unused but kept so all
        # key_* handlers share one signature (dispatched via getattr below)
        coordinates_list.sort(key=lambda x: x[1])
        return coordinates_list[0]

    @staticmethod
    def key_right(coordinates_list, key_coordinates, top_padding, bottom_padding):
        # Nearest occurrence to the right of the previous anchor, within a
        # vertical band padded by multiples of the anchor height
        if len(coordinates_list) == 1:
            return coordinates_list[0]
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)
        x = key_coordinates[2]

        x_min = None
        key_coordinates = None
        for x0, y0, x1, y1 in coordinates_list:
            if y0 > y_min and y1 < y_max and x0 > x:
                if x_min is None or x0 < x_min:
                    x_min = x0
                    key_coordinates = (x0, y0, x1, y1)
        return key_coordinates

    @staticmethod
    def value_right(go_res, key_coordinates, top_padding, bottom_padding):
        # Value = nearest OCR text to the right of the key box, within the padded band
        height = key_coordinates[-1] - key_coordinates[1]
        y_min = key_coordinates[1] - (top_padding * height)
        y_max = key_coordinates[-1] + (bottom_padding * height)
        x = key_coordinates[2]

        x_min = None
        value = None
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if y0 > y_min and y1 < y_max and x0 > x:
                if x_min is None or x0 < x_min:
                    x_min = x0
                    value = text
        return value

    @staticmethod
    def value_under(go_res, key_coordinates, left_padding, right_padding):
        # Value = nearest OCR text below the key box, within a horizontal band
        # padded by multiples of the key box width
        width = key_coordinates[2] - key_coordinates[0]
        x_min = key_coordinates[0] - (width * left_padding)
        x_max = key_coordinates[2] + (width * right_padding)
        y = key_coordinates[-1]

        y_min = None
        value = None
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if x0 > x_min and x1 < x_max and y0 > y:
                if y_min is None or y0 < y_min:
                    y_min = y0
                    value = text
        return value

    def get_target_fields(self, go_res, signature_res_list):
        # Collect every occurrence of each key text in the OCR results
        key_text_info = dict()
        for (x0, y0, _, _, x1, y1, _, _), text in go_res.values():
            if text in self.key_text_set:
                key_text_info.setdefault(text, list()).append((x0, y0, x1, y1))

        # Resolve each field's key coordinates, anchor by anchor
        key_coordinates_info = dict()
        for field, key_text_list in self.target_fields[self.keys_str].items():
            pre_key_coordinates = None
            for key_text, direction, kwargs in key_text_list:
                if key_text not in key_text_info:
                    break
                key_coordinates = getattr(self, 'key_{0}'.format(direction))(
                    key_text_info[key_text],
                    pre_key_coordinates,
                    **kwargs)
                if not isinstance(key_coordinates, tuple):
                    break
                pre_key_coordinates = key_coordinates
            else:
                key_coordinates_info[field] = pre_key_coordinates

        # Retrieve the field values
        res = dict()
        for field, (direction, kwargs, default_value) in self.target_fields[self.value_str].items():
            if not isinstance(key_coordinates_info.get(field), tuple):
                # Key not found: record the default and move on to the next field
                res[field] = default_value
                continue
            value = getattr(self, 'value_{0}'.format(direction))(
                go_res,
                key_coordinates_info[field],
                **kwargs
            )
            if not isinstance(value, str):
                res[field] = default_value
            else:
                res[field] = value

        # Check signatures and seals
        tmp_signature_count = dict()
        for signature_dict in signature_res_list:
            if signature_dict['label'] in tmp_signature_count:
                tmp_signature_count[signature_dict['label']] += 1
            else:
                tmp_signature_count[signature_dict['label']] = 1
        for field, signature_type_set in self.target_fields[self.signature_str].items():
            for signature_type in signature_type_set:
                if tmp_signature_count.get(signature_type, 0) > 0:
                    res[field] = self.signature_have_str
                    tmp_signature_count[signature_type] -= 1
                    break
            else:
                res[field] = self.signature_have_not_str

        return res
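

# Minimal smoke test with hypothetical inputs (illustrative only; real inputs
# come from the upstream OCR and signature-detection services)
if __name__ == '__main__':
    import const

    go_res = {
        0: ((10, 10, 60, 10, 60, 30, 10, 30), '姓名'),
        1: ((10, 40, 80, 40, 80, 60, 10, 60), '张三'),
        2: ((10, 70, 160, 70, 160, 90, 10, 90), '个人身份证件号码'),
        3: ((10, 100, 200, 100, 200, 120, 10, 120), '110101199001011234'),
    }
    signature_res_list = [{'label': 'signature'}]

    r = Retriever(const.TARGET_FIELD_INDIVIDUALS)
    print(r.get_target_fields(go_res, signature_res_list))
    # {'姓名': '张三', '个人身份证件号码': '110101199001011234', '签字': '有'}
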
# base_class.py

class BaseModel:
    """
    All model classes should extend BaseModel.
    """

    def load_model(self):
        """
        Define the network structure and return the model.
        """
        raise NotImplementedError(".load_model() must be overridden.")

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
        """
        Model training process.
        """
        raise NotImplementedError(".train() must be overridden.")

# const.py

CLASS_OTHER_CN = '其他'  # catch-all "other" class

CLASS_OTHER_FIRST = True  # '其他' sits first in the list and is excluded from the label space

# classes: 其他 = other, 身份证 = ID card, 营业执照 = business license,
# 经销商授权书 = dealer authorization letter, 个人授权书 = individual authorization letter
CLASS_CN_LIST = [CLASS_OTHER_CN, '身份证', '营业执照', '经销商授权书', '个人授权书']

# training entry script

import os
from datetime import datetime
from model import F3Classification
import const


if __name__ == '__main__':
    base_dir = os.path.dirname(os.path.abspath(__file__))

    m = F3Classification(
        class_name_list=const.CLASS_CN_LIST,
        class_other_first=const.CLASS_OTHER_FIRST
    )

    # m.test()

    dataset_dir = '/home/zwq/data/data_224'
    ckpt_path = os.path.join(base_dir, 'ckpt_{0}.h5'.format(datetime.now().strftime('%Y-%m-%d_%H:%M:%S')))
    epoch = 100
    batch_size = 128

    m.train(dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test')

# model.py

import os
import tensorflow as tf
import tensorflow_addons as tfa
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, losses, metrics, callbacks, applications
import matplotlib.pyplot as plt

from base_class import BaseModel


@tf.function
def random_rgb_2_bgr(image, label):
    # Swap the channel order for roughly half of the samples. tf.random is used
    # instead of the stdlib random module: inside a tf.function, Python-level
    # randomness runs once at trace time and would freeze the decision
    image = tf.cond(
        tf.random.uniform([]) > 0.5,
        lambda: image,
        lambda: image[:, :, ::-1],
    )
    return image, label


@tf.function
def random_grayscale_expand(image, label):
    # Collapse to grayscale and expand back to 3 channels for ~10% of the samples
    image = tf.cond(
        tf.random.uniform([]) > 0.1,
        lambda: image,
        lambda: tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)),
    )
    return image, label


@tf.function
def load_image(image_path, label):
    image = tf.io.read_file(image_path)
    # expand_animations=False guarantees a rank-3 tensor, which tf.image.resize expects
    image = tf.image.decode_image(image, channels=3, expand_animations=False)
    return image, label


@tf.function
def preprocess_input(image, label):
    image = tf.image.resize(image, [224, 224])
    image = applications.mobilenet_v2.preprocess_input(image)
    return image, label


class F3Classification(BaseModel):

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # With class_other_first, the catch-all '其他' is excluded from the label
        # space: it maps to -1, which tf.one_hot encodes as an all-zero vector
        self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)

    @staticmethod
    def get_class_label_map(class_name_list, class_other_first=False):
        return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}

    def get_image_label_list(self, dataset_dir):
        image_path_list = []
        label_list = []
        for class_name in os.listdir(dataset_dir):
            class_dir_path = os.path.join(dataset_dir, class_name)
            if not os.path.isdir(class_dir_path):
                continue
            if class_name not in self.class_label_map:
                continue
            label = self.class_label_map[class_name]
            for file_name in os.listdir(class_dir_path):
                # TODO image check
                file_path = os.path.join(class_dir_path, file_name)
                image_path_list.append(file_path)
                label_list.append(tf.one_hot(label, depth=self.class_count))
        return image_path_list, label_list

    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=None):
        # tf.data transformations return new datasets rather than mutating in
        # place, so every result has to be reassigned
        augmentation_methods = augmentation_methods or []
        image_and_label_list = self.get_image_label_list(dataset_dir)
        dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name)
        dataset = dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True)
        dataset = dataset.map(load_image,
                              num_parallel_calls=tf.data.AUTOTUNE,
                              deterministic=False)
        for augmentation_method in augmentation_methods:
            # The augmentation functions are module-level, so resolve them from
            # globals() rather than as attributes of self
            dataset = dataset.map(globals()[augmentation_method],
                                  num_parallel_calls=tf.data.AUTOTUNE,
                                  deterministic=False)
        dataset = dataset.map(preprocess_input,
                              num_parallel_calls=tf.data.AUTOTUNE,
                              deterministic=False)
        parallel_batch_dataset = dataset.batch(
            batch_size=batch_size,
            drop_remainder=True,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
            name=name,
        ).prefetch(tf.data.AUTOTUNE)
        return parallel_batch_dataset

    def load_model(self):
        base_model = MobileNetV2(
            input_shape=(224, 224, 3),
            alpha=0.35,
            include_top=False,
            weights='imagenet',
            pooling='avg',
        )
        x = base_model.output
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
        model = models.Model(inputs=base_model.input, outputs=x)

        # Freeze everything up to and including 'block_16_project_BN'; only the
        # layers after it stay trainable
        freeze = True
        for layer in model.layers:
            layer.trainable = not freeze
            if freeze and layer.name == 'block_16_project_BN':
                freeze = False
        return model

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
        # model = self.load_model()
        # model.summary()
        #
        # model.compile(
        #     optimizer=optimizers.Adam(learning_rate=3e-4),
        #     loss=tfa.losses.SigmoidFocalCrossEntropy(),
        #     metrics=['accuracy', ],
        #
        #     loss_weights=None,
        #     weighted_metrics=None,
        #     run_eagerly=None,
        #     steps_per_execution=None,
        #     jit_compile=None,
        # )

        train_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, train_dir_name),
            name=train_dir_name,
            batch_size=batch_size,
            augmentation_methods=[],
            # augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
        )
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[]
        )

        # ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
        #
        # history = model.fit(
        #     train_dataset,
        #     epochs=epoch,
        #     validation_data=validate_dataset,
        #     callbacks=[ckpt_callback, ],
        # )
        #
        # acc = history.history['accuracy']
        # val_acc = history.history['val_accuracy']
        #
        # loss = history.history['loss']
        # val_loss = history.history['val_loss']
        #
        # plt.figure(figsize=(8, 8))
        # plt.subplot(2, 1, 1)
        # plt.plot(acc, label='Training Accuracy')
        # plt.plot(val_acc, label='Validation Accuracy')
        # plt.legend(loc='lower right')
        # plt.ylabel('Accuracy')
        # plt.ylim([min(plt.ylim()), 1])
        # plt.title('Training and Validation Accuracy')
        #
        # plt.subplot(2, 1, 2)
        # plt.plot(loss, label='Training Loss')
        # plt.plot(val_loss, label='Validation Loss')
        # plt.legend(loc='upper right')
        # plt.ylabel('Cross Entropy')
        # plt.ylim([0, 1.0])
        # plt.title('Training and Validation Loss')
        # plt.xlabel('epoch')
        # plt.show()

    def test(self):
        print(self.class_label_map)
        print(self.class_count)
        # path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
        # label = 5
        # image, label = self.load_image(path, label)
        # print(image.shape)
        # image, label = self.preprocess_input(image, label)
        # print(image.shape)