331c7717 by 周伟奇

add fix pred text

1 parent e573fab0
......@@ -239,7 +239,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
[label_bbox[0], label_bbox[3]],
]
iou = bbox_iou(go_bbox_rebuild, label_bbox_rebuild)
if iou >= 0.5:
if iou >= 0.4:
label_idx_dict[go_idx] = label_idx
X = list()
......@@ -250,7 +250,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
# text_vec_max_lens = 15 * 50
# dim = 1 + 5 + 8 + text_vec_max_lens
max_jieba_char = 8
max_jieba_char = 4
text_vec_max_lens = max_jieba_char * 100
dim = 1 + 5 + 8 + text_vec_max_lens
......@@ -333,7 +333,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
if __name__ == '__main__':
base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
go_dir = os.path.join(base_dir, 'go_res')
dataset_save_dir = os.path.join(base_dir, 'dataset160x814')
dataset_save_dir = os.path.join(base_dir, 'dataset160x414')
label_dir = os.path.join(base_dir, 'labeled')
train_go_path = os.path.join(go_dir, 'train')
......
......@@ -10,8 +10,7 @@ from data import build_dataloader
from loss import build_loss
from model import build_model
from optimizer import build_lr_scheduler, build_optimizer
from utils import SOLVER_REGISTRY, get_logger_and_log_dir
from utils import sequence_mask
from utils import SOLVER_REGISTRY, get_logger_and_log_dir, sequence_mask, fix_text_obj
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
......@@ -223,6 +222,18 @@ class SLSolver(object):
map_key_text = 'find_top_text'
map_key_value = 'find_value'
test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]
fix_pred_methods = [
('only_date', {}),
('only_digit', {}),
('do_nothing', {}),
('do_nothing', {}),
('remove_start', {'start_char': '电话'}),
('only_digit_alpha', {}),
('do_nothing', {}),
('remove_start', {'start_char': '账号'}),
('remove_bank', {}),
('only_amount', {}),
]
group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
skip_list_valid = [
# 'CH-B102897920-2.jpg',
......@@ -338,12 +349,17 @@ class SLSolver(object):
group_text_list.append(None)
for idx, text in enumerate(group_text_list):
if '#' in text:
continue
key_cn = group_cn_list[idx+1]
pred_idx_list = bbox_text_dict.get(idx)
if isinstance(pred_idx_list, list):
pred_text_list = [go_res_list[idx][-1] for idx in pred_idx_list]
pred_text = ' '.join(pred_text_list)
pred_text_src = ''.join(pred_text_list)
# pred_text = pred_text_src
pred_text = getattr(fix_text_obj, fix_pred_methods[idx][0])(pred_text_src, **fix_pred_methods[idx][1])
else:
pred_text = None
......@@ -356,7 +372,7 @@ class SLSolver(object):
# break
for key_cn, (correct_count, all_count) in data_dict.items():
print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 2)))
print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 4)))
print('===========================')
......
import torch
from .registery import *
from .logger import get_logger_and_log_dir
from .fix_pred import fix_text_obj
__all__ = [
'Registry',
......
import re
class FixText:
    """Static helpers that trim a recognised text string down to its value.

    Every helper takes the raw text as ``pred_text_src`` and returns a string.
    When a helper's pattern does not apply, the input is returned unchanged.
    """

    @staticmethod
    def do_nothing(pred_text_src):
        """Return ``pred_text_src`` untouched."""
        return pred_text_src

    @staticmethod
    def only_date(pred_text_src):
        """Return the text from the first ``'20'`` onward (up to a line break)."""
        re_se = re.search(r'20.*', pred_text_src)
        if re_se:
            return re_se.group()
        return pred_text_src

    @staticmethod
    def only_digit(pred_text_src):
        """Return the first run of consecutive digits in the text."""
        re_se = re.search(r'\d+', pred_text_src)
        if re_se:
            return re_se.group()
        return pred_text_src

    @staticmethod
    def remove_start(pred_text_src, start_char='电话'):
        """Strip a leading ``start_char`` marker from the text.

        Only the prefix is removed; later occurrences of ``start_char`` inside
        the value are preserved.
        """
        if start_char and pred_text_src.startswith(start_char):
            # Slice instead of str.replace(): replace() would also delete any
            # later occurrence of the marker inside the value itself.
            return pred_text_src[len(start_char):]
        return pred_text_src

    @staticmethod
    def only_digit_alpha(pred_text_src):
        """Return the first run of ASCII letters and digits in the text."""
        # An explicit ASCII class is used because ``\w`` also matches CJK
        # characters and underscores, which would keep a leading label.
        re_se = re.search(r'[0-9A-Za-z]+', pred_text_src)
        if re_se:
            return re_se.group()
        return pred_text_src

    @staticmethod
    def remove_bank(pred_text_src):
        """Return whatever follows the first ``'户银行'`` in the text."""
        re_se = re.search(r'户银行(.*)', pred_text_src)
        if re_se:
            return re_se.group(1)
        return pred_text_src

    @staticmethod
    def only_amount(pred_text_src):
        """Return the first ``digits<sep>digits`` match with ``<sep>`` set to ``'.'``.

        ``<sep>`` may be ``-``, ``,`` or ``.`` in the input; ``-`` and ``,``
        are rewritten to ``.`` in the result.
        """
        re_se = re.search(r'\d+[-,\.]\d+', pred_text_src)
        if re_se:
            return re_se.group().replace('-', '.').replace(',', '.')
        return pred_text_src


# Shared instance; helpers can be looked up on it by name with getattr().
fix_text_obj = FixText()
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!