add fix pred text
Showing
4 changed files
with
80 additions
and
7 deletions
... | @@ -239,7 +239,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -239,7 +239,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
239 | [label_bbox[0], label_bbox[3]], | 239 | [label_bbox[0], label_bbox[3]], |
240 | ] | 240 | ] |
241 | iou = bbox_iou(go_bbox_rebuild, label_bbox_rebuild) | 241 | iou = bbox_iou(go_bbox_rebuild, label_bbox_rebuild) |
242 | if iou >= 0.5: | 242 | if iou >= 0.4: |
243 | label_idx_dict[go_idx] = label_idx | 243 | label_idx_dict[go_idx] = label_idx |
244 | 244 | ||
245 | X = list() | 245 | X = list() |
... | @@ -250,7 +250,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -250,7 +250,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
250 | # text_vec_max_lens = 15 * 50 | 250 | # text_vec_max_lens = 15 * 50 |
251 | # dim = 1 + 5 + 8 + text_vec_max_lens | 251 | # dim = 1 + 5 + 8 + text_vec_max_lens |
252 | 252 | ||
253 | max_jieba_char = 8 | 253 | max_jieba_char = 4 |
254 | text_vec_max_lens = max_jieba_char * 100 | 254 | text_vec_max_lens = max_jieba_char * 100 |
255 | dim = 1 + 5 + 8 + text_vec_max_lens | 255 | dim = 1 + 5 + 8 + text_vec_max_lens |
256 | 256 | ||
... | @@ -333,7 +333,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -333,7 +333,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
333 | if __name__ == '__main__': | 333 | if __name__ == '__main__': |
334 | base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' | 334 | base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' |
335 | go_dir = os.path.join(base_dir, 'go_res') | 335 | go_dir = os.path.join(base_dir, 'go_res') |
336 | dataset_save_dir = os.path.join(base_dir, 'dataset160x814') | 336 | dataset_save_dir = os.path.join(base_dir, 'dataset160x414') |
337 | label_dir = os.path.join(base_dir, 'labeled') | 337 | label_dir = os.path.join(base_dir, 'labeled') |
338 | 338 | ||
339 | train_go_path = os.path.join(go_dir, 'train') | 339 | train_go_path = os.path.join(go_dir, 'train') | ... | ... |
... | @@ -10,8 +10,7 @@ from data import build_dataloader | ... | @@ -10,8 +10,7 @@ from data import build_dataloader |
10 | from loss import build_loss | 10 | from loss import build_loss |
11 | from model import build_model | 11 | from model import build_model |
12 | from optimizer import build_lr_scheduler, build_optimizer | 12 | from optimizer import build_lr_scheduler, build_optimizer |
13 | from utils import SOLVER_REGISTRY, get_logger_and_log_dir | 13 | from utils import SOLVER_REGISTRY, get_logger_and_log_dir, sequence_mask, fix_text_obj |
14 | from utils import sequence_mask | ||
15 | from sklearn.metrics import confusion_matrix, accuracy_score, classification_report | 14 | from sklearn.metrics import confusion_matrix, accuracy_score, classification_report |
16 | 15 | ||
17 | 16 | ||
... | @@ -223,6 +222,18 @@ class SLSolver(object): | ... | @@ -223,6 +222,18 @@ class SLSolver(object): |
223 | map_key_text = 'find_top_text' | 222 | map_key_text = 'find_top_text' |
224 | map_key_value = 'find_value' | 223 | map_key_value = 'find_value' |
225 | test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28] | 224 | test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28] |
225 | fix_pred_methods = [ | ||
226 | ('only_date', {}), | ||
227 | ('only_digit', {}), | ||
228 | ('do_nothing', {}), | ||
229 | ('do_nothing', {}), | ||
230 | ('remove_start', {'start_char': '电话'}), | ||
231 | ('only_digit_alpha', {}), | ||
232 | ('do_nothing', {}), | ||
233 | ('remove_start', {'start_char': '账号'}), | ||
234 | ('remove_bank', {}), | ||
235 | ('only_amount', {}), | ||
236 | ] | ||
226 | group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] | 237 | group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] |
227 | skip_list_valid = [ | 238 | skip_list_valid = [ |
228 | # 'CH-B102897920-2.jpg', | 239 | # 'CH-B102897920-2.jpg', |
... | @@ -338,12 +349,17 @@ class SLSolver(object): | ... | @@ -338,12 +349,17 @@ class SLSolver(object): |
338 | group_text_list.append(None) | 349 | group_text_list.append(None) |
339 | 350 | ||
340 | for idx, text in enumerate(group_text_list): | 351 | for idx, text in enumerate(group_text_list): |
352 | if '#' in text: | ||
353 | continue | ||
354 | |||
341 | key_cn = group_cn_list[idx+1] | 355 | key_cn = group_cn_list[idx+1] |
342 | 356 | ||
343 | pred_idx_list = bbox_text_dict.get(idx) | 357 | pred_idx_list = bbox_text_dict.get(idx) |
344 | if isinstance(pred_idx_list, list): | 358 | if isinstance(pred_idx_list, list): |
345 | pred_text_list = [go_res_list[idx][-1] for idx in pred_idx_list] | 359 | pred_text_list = [go_res_list[idx][-1] for idx in pred_idx_list] |
346 | pred_text = ' '.join(pred_text_list) | 360 | pred_text_src = ''.join(pred_text_list) |
361 | # pred_text = pred_text_src | ||
362 | pred_text = getattr(fix_text_obj, fix_pred_methods[idx][0])(pred_text_src, **fix_pred_methods[idx][1]) | ||
347 | else: | 363 | else: |
348 | pred_text = None | 364 | pred_text = None |
349 | 365 | ||
... | @@ -356,7 +372,7 @@ class SLSolver(object): | ... | @@ -356,7 +372,7 @@ class SLSolver(object): |
356 | # break | 372 | # break |
357 | 373 | ||
358 | for key_cn, (correct_count, all_count) in data_dict.items(): | 374 | for key_cn, (correct_count, all_count) in data_dict.items(): |
359 | print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 2))) | 375 | print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 4))) |
360 | 376 | ||
361 | print('===========================') | 377 | print('===========================') |
362 | 378 | ... | ... |
utils/fix_pred.py
0 → 100644
1 | import re | ||
2 | |||
class FixText:
    """Post-processing helpers that clean up a predicted text string.

    Every method takes the raw predicted string as its first argument and
    returns a cleaned string.  When the expected pattern is not found the
    input is returned unchanged, so a cleaner never turns a prediction into
    ``None`` or raises on unexpected text.

    All methods are static so they can be looked up by name, e.g.
    ``getattr(fix_text_obj, 'only_digit')(text)``.
    """

    @staticmethod
    def do_nothing(pred_text_src):
        """Return ``pred_text_src`` unchanged."""
        return pred_text_src

    @staticmethod
    def only_date(pred_text_src):
        """Return the text from the first ``'20'`` to the end of the line.

        Falls back to the unchanged input when no ``'20'`` is present.
        """
        match = re.search(r'20.*', pred_text_src)
        return match.group() if match else pred_text_src

    @staticmethod
    def only_digit(pred_text_src):
        """Return the first run of digits, or the input if there is none."""
        match = re.search(r'\d+', pred_text_src)
        return match.group() if match else pred_text_src

    @staticmethod
    def remove_start(pred_text_src, start_char='电话'):
        """Strip ``start_char`` from the beginning of the text, if present.

        Only the leading occurrence is removed; later occurrences of
        ``start_char`` inside the value are preserved.
        """
        if pred_text_src.startswith(start_char):
            # Slice instead of str.replace(): replace() would also delete
            # every later occurrence of the prefix inside the value.
            return pred_text_src[len(start_char):]
        return pred_text_src

    @staticmethod
    def only_digit_alpha(pred_text_src):
        """Return the first run of ASCII letters/digits, or the input if none.

        An explicit ASCII class is used rather than ``\\w`` because in
        Python 3 ``\\w`` also matches CJK characters and ``_``, which would
        keep a leading non-ASCII label attached to the value.
        """
        match = re.search(r'[0-9A-Za-z]+', pred_text_src)
        return match.group() if match else pred_text_src

    @staticmethod
    def remove_bank(pred_text_src):
        """Return the text following the first ``'户银行'``.

        Falls back to the unchanged input when the marker is absent.
        """
        match = re.search(r'户银行(.*)', pred_text_src)
        return match.group(1) if match else pred_text_src

    @staticmethod
    def only_amount(pred_text_src):
        """Return the first ``digits<sep>digits`` group with ``.`` as separator.

        ``<sep>`` may be ``-``, ``,`` or ``.``; whichever one is found is
        rewritten to ``.``.  Falls back to the unchanged input when no such
        group exists.
        """
        match = re.search(r'\d+[-,\.]\d+', pred_text_src)
        if match:
            return match.group().replace('-', '.').replace(',', '.')
        return pred_text_src


# Shared instance so callers can dispatch by method name with getattr().
fix_text_obj = FixText()
-
Please register or sign in to post a comment