331c7717 by 周伟奇

add fix pred text

1 parent e573fab0
...@@ -239,7 +239,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save ...@@ -239,7 +239,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
239 [label_bbox[0], label_bbox[3]], 239 [label_bbox[0], label_bbox[3]],
240 ] 240 ]
241 iou = bbox_iou(go_bbox_rebuild, label_bbox_rebuild) 241 iou = bbox_iou(go_bbox_rebuild, label_bbox_rebuild)
242 if iou >= 0.5: 242 if iou >= 0.4:
243 label_idx_dict[go_idx] = label_idx 243 label_idx_dict[go_idx] = label_idx
244 244
245 X = list() 245 X = list()
...@@ -250,7 +250,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save ...@@ -250,7 +250,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
250 # text_vec_max_lens = 15 * 50 250 # text_vec_max_lens = 15 * 50
251 # dim = 1 + 5 + 8 + text_vec_max_lens 251 # dim = 1 + 5 + 8 + text_vec_max_lens
252 252
253 max_jieba_char = 8 253 max_jieba_char = 4
254 text_vec_max_lens = max_jieba_char * 100 254 text_vec_max_lens = max_jieba_char * 100
255 dim = 1 + 5 + 8 + text_vec_max_lens 255 dim = 1 + 5 + 8 + text_vec_max_lens
256 256
...@@ -333,7 +333,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save ...@@ -333,7 +333,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
333 if __name__ == '__main__': 333 if __name__ == '__main__':
334 base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' 334 base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
335 go_dir = os.path.join(base_dir, 'go_res') 335 go_dir = os.path.join(base_dir, 'go_res')
336 dataset_save_dir = os.path.join(base_dir, 'dataset160x814') 336 dataset_save_dir = os.path.join(base_dir, 'dataset160x414')
337 label_dir = os.path.join(base_dir, 'labeled') 337 label_dir = os.path.join(base_dir, 'labeled')
338 338
339 train_go_path = os.path.join(go_dir, 'train') 339 train_go_path = os.path.join(go_dir, 'train')
......
...@@ -10,8 +10,7 @@ from data import build_dataloader ...@@ -10,8 +10,7 @@ from data import build_dataloader
10 from loss import build_loss 10 from loss import build_loss
11 from model import build_model 11 from model import build_model
12 from optimizer import build_lr_scheduler, build_optimizer 12 from optimizer import build_lr_scheduler, build_optimizer
13 from utils import SOLVER_REGISTRY, get_logger_and_log_dir 13 from utils import SOLVER_REGISTRY, get_logger_and_log_dir, sequence_mask, fix_text_obj
14 from utils import sequence_mask
15 from sklearn.metrics import confusion_matrix, accuracy_score, classification_report 14 from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
16 15
17 16
...@@ -223,6 +222,18 @@ class SLSolver(object): ...@@ -223,6 +222,18 @@ class SLSolver(object):
223 map_key_text = 'find_top_text' 222 map_key_text = 'find_top_text'
224 map_key_value = 'find_value' 223 map_key_value = 'find_value'
225 test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28] 224 test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]
225 fix_pred_methods = [
226 ('only_date', {}),
227 ('only_digit', {}),
228 ('do_nothing', {}),
229 ('do_nothing', {}),
230 ('remove_start', {'start_char': '电话'}),
231 ('only_digit_alpha', {}),
232 ('do_nothing', {}),
233 ('remove_start', {'start_char': '账号'}),
234 ('remove_bank', {}),
235 ('only_amount', {}),
236 ]
226 group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] 237 group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
227 skip_list_valid = [ 238 skip_list_valid = [
228 # 'CH-B102897920-2.jpg', 239 # 'CH-B102897920-2.jpg',
...@@ -338,12 +349,17 @@ class SLSolver(object): ...@@ -338,12 +349,17 @@ class SLSolver(object):
338 group_text_list.append(None) 349 group_text_list.append(None)
339 350
340 for idx, text in enumerate(group_text_list): 351 for idx, text in enumerate(group_text_list):
352 if '#' in text:
353 continue
354
341 key_cn = group_cn_list[idx+1] 355 key_cn = group_cn_list[idx+1]
342 356
343 pred_idx_list = bbox_text_dict.get(idx) 357 pred_idx_list = bbox_text_dict.get(idx)
344 if isinstance(pred_idx_list, list): 358 if isinstance(pred_idx_list, list):
345 pred_text_list = [go_res_list[idx][-1] for idx in pred_idx_list] 359 pred_text_list = [go_res_list[idx][-1] for idx in pred_idx_list]
346 pred_text = ' '.join(pred_text_list) 360 pred_text_src = ''.join(pred_text_list)
361 # pred_text = pred_text_src
362 pred_text = getattr(fix_text_obj, fix_pred_methods[idx][0])(pred_text_src, **fix_pred_methods[idx][1])
347 else: 363 else:
348 pred_text = None 364 pred_text = None
349 365
...@@ -356,7 +372,7 @@ class SLSolver(object): ...@@ -356,7 +372,7 @@ class SLSolver(object):
356 # break 372 # break
357 373
358 for key_cn, (correct_count, all_count) in data_dict.items(): 374 for key_cn, (correct_count, all_count) in data_dict.items():
359 print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 2))) 375 print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 4)))
360 376
361 print('===========================') 377 print('===========================')
362 378
......
1 import torch 1 import torch
2 from .registery import * 2 from .registery import *
3 from .logger import get_logger_and_log_dir 3 from .logger import get_logger_and_log_dir
4 from .fix_pred import fix_text_obj
4 5
5 __all__ = [ 6 __all__ = [
6 'Registry', 7 'Registry',
......
1 import re
2
3 class FixText:
4
5 @staticmethod
6 def do_nothing(pred_text_src):
7 return pred_text_src
8
9 @staticmethod
10 def only_date(pred_text_src):
11 re_se = re.search(r'20.*', pred_text_src)
12 if re_se:
13 return re_se.group()
14 else:
15 return pred_text_src
16
17 @staticmethod
18 def only_digit(pred_text_src):
19 re_se = re.search(r'\d+', pred_text_src)
20 if re_se:
21 return re_se.group()
22 else:
23 return pred_text_src
24
25 @staticmethod
26 def remove_start(pred_text_src, start_char='电话'):
27 if pred_text_src.startswith(start_char):
28 return pred_text_src.replace(start_char, '')
29 else:
30 return pred_text_src
31
32 @staticmethod
33 def only_digit_alpha(pred_text_src):
34 re_se = re.search(r'\w+', pred_text_src)
35 if re_se:
36 return re_se.group()
37 else:
38 return pred_text_src
39
40 @staticmethod
41 def remove_bank(pred_text_src):
42 re_se = re.search(r'户银行(.*)', pred_text_src)
43 if re_se:
44 return re_se.group(1)
45 else:
46 return pred_text_src
47
48 @staticmethod
49 def only_amount(pred_text_src):
50 re_se = re.search(r'\d+[-,\.]\d+', pred_text_src)
51 if re_se:
52 return re_se.group().replace('-', '.').replace(',', '.')
53 else:
54 return pred_text_src
55
56 fix_text_obj = FixText()
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!