41252450 by 周伟奇

add statistics

1 parent fb66a889
......@@ -37,6 +37,7 @@ solver:
base_on: null
model_path: null
val_image_path: '/labeled/valid/image'
val_label_path: '/labeled/valid/label'
val_go_path: '/go_res/valid'
val_map_path: '/dataset160x14/create_map.json'
draw_font_path: '/dataset160x14/STZHONGS.TTF'
......
......@@ -38,6 +38,7 @@ class SLSolver(object):
self.base_on = self.hyper_params['base_on']
self.model_path = self.hyper_params['model_path']
self.val_image_path = self.hyper_params['val_image_path']
self.val_label_path = self.hyper_params['val_label_path']
self.val_go_path = self.hyper_params['val_go_path']
self.val_map_path = self.hyper_params['val_map_path']
self.draw_font_path = self.hyper_params['draw_font_path']
......@@ -198,6 +199,10 @@ class SLSolver(object):
print('Warn: val_image_path not exists: {0}'.format(self.val_image_path))
return
if not os.path.isdir(self.val_label_path):
print('Warn: val_label_path not exists: {0}'.format(self.val_label_path))
return
if not os.path.isdir(self.val_go_path):
print('Warn: val_go_path not exists: {0}'.format(self.val_go_path))
return
......@@ -217,6 +222,7 @@ class SLSolver(object):
map_key_input = 'x_y_valid_lens'
map_key_text = 'find_top_text'
map_key_value = 'find_value'
test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]
group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
skip_list_valid = [
# 'CH-B102897920-2.jpg',
......@@ -235,6 +241,8 @@ class SLSolver(object):
with open(self.val_map_path, 'r') as fp:
val_map = json.load(fp)
data_dict = {key_cn: [0, 0] for key_cn in group_cn_list[1:]}
failed_dict = dict()
for img_name in sorted(os.listdir(self.val_image_path)):
if img_name in skip_list_valid:
continue
......@@ -281,7 +289,11 @@ class SLSolver(object):
correct = 0
bbox_draw_dict = dict()
bbox_text_dict = dict()
for i in range(valid_lens_scalar):
if pred[i] != 0:
bbox_text_dict.setdefault(test_group_id[pred[i]-1], list()).append(i)
if pred[i] == label[i]:
correct += 1
if pred[i] != 0:
......@@ -311,8 +323,46 @@ class SLSolver(object):
img_pil.save(os.path.join(save_dir, img_name))
# break
# 统计准确率
label_json_path = os.path.join(self.val_label_path, '{0}.json'.format(base_image_name))
with open(label_json_path, 'r') as fp:
label_res = json.load(fp)
group_text_list = []
for group_id in test_group_id:
for item in label_res.get("shapes", []):
if item.get("group_id") == group_id:
group_text_list.append(item['label'])
break
else:
group_text_list.append(None)
for idx, text in enumerate(group_text_list):
key_cn = group_cn_list[idx+1]
pred_idx_list = bbox_text_dict.get(idx)
if isinstance(pred_idx_list, list):
pred_text_list = [go_res_list[idx][-1] for idx in pred_idx_list]
pred_text = ' '.join(pred_text_list)
else:
pred_text = None
data_dict[key_cn][-1] += 1
if pred_text == text:
data_dict[key_cn][0] += 1
else:
failed_dict.setdefault(key_cn, list()).append((text, pred_text))
# break
for key_cn, (correct_count, all_count) in data_dict.ietms():
print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 2)))
print('===========================')
for key_cn, failed_list in failed_dict.items():
print(key_cn)
for text, pred_text in failed_list:
print('label: {0} pred: {1}'.format(text, pred_text))
print('----------------------------------')
\ No newline at end of file
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!