From 41252450bb79b46d9f6621531519db0e8965f420 Mon Sep 17 00:00:00 2001 From: zhouweiqi <zhouweiqi@situdata.com> Date: Thu, 22 Dec 2022 18:29:08 +0800 Subject: [PATCH] add statistics --- config/sl.yaml | 1 + solver/sl_solver.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/config/sl.yaml b/config/sl.yaml index 39b74fc..5ae1745 100644 --- a/config/sl.yaml +++ b/config/sl.yaml @@ -37,6 +37,7 @@ solver: base_on: null model_path: null val_image_path: '/labeled/valid/image' + val_label_path: '/labeled/valid/label' val_go_path: '/go_res/valid' val_map_path: '/dataset160x14/create_map.json' draw_font_path: '/dataset160x14/STZHONGS.TTF' diff --git a/solver/sl_solver.py b/solver/sl_solver.py index 708ae46..6764fbc 100644 --- a/solver/sl_solver.py +++ b/solver/sl_solver.py @@ -38,6 +38,7 @@ class SLSolver(object): self.base_on = self.hyper_params['base_on'] self.model_path = self.hyper_params['model_path'] self.val_image_path = self.hyper_params['val_image_path'] + self.val_label_path = self.hyper_params['val_label_path'] self.val_go_path = self.hyper_params['val_go_path'] self.val_map_path = self.hyper_params['val_map_path'] self.draw_font_path = self.hyper_params['draw_font_path'] @@ -198,6 +199,10 @@ class SLSolver(object): print('Warn: val_image_path not exists: {0}'.format(self.val_image_path)) return + if not os.path.isdir(self.val_label_path): + print('Warn: val_label_path not exists: {0}'.format(self.val_label_path)) + return + if not os.path.isdir(self.val_go_path): print('Warn: val_go_path not exists: {0}'.format(self.val_go_path)) return @@ -217,6 +222,7 @@ class SLSolver(object): map_key_input = 'x_y_valid_lens' map_key_text = 'find_top_text' map_key_value = 'find_value' + test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28] group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] skip_list_valid = [ # 'CH-B102897920-2.jpg', @@ -235,6 +241,8 @@ class SLSolver(object): with open(self.val_map_path, 'r') as fp: val_map = json.load(fp) + data_dict = {key_cn: [0, 0] for key_cn in group_cn_list[1:]} + failed_dict = dict() for img_name in sorted(os.listdir(self.val_image_path)): if img_name in skip_list_valid: continue @@ -281,7 +289,11 @@ class SLSolver(object): correct = 0 bbox_draw_dict = dict() + bbox_text_dict = dict() for i in range(valid_lens_scalar): + if pred[i] != 0: + bbox_text_dict.setdefault(test_group_id[pred[i]-1], list()).append(i) + if pred[i] == label[i]: correct += 1 if pred[i] != 0: @@ -311,8 +323,46 @@ class SLSolver(object): img_pil.save(os.path.join(save_dir, img_name)) - # break + # 统计准确率 + label_json_path = os.path.join(self.val_label_path, '{0}.json'.format(base_image_name)) + with open(label_json_path, 'r') as fp: + label_res = json.load(fp) + + group_text_list = [] + for group_id in test_group_id: + for item in label_res.get("shapes", []): + if item.get("group_id") == group_id: + group_text_list.append(item['label']) + break + else: + group_text_list.append(None) - + for idx, text in enumerate(group_text_list): + key_cn = group_cn_list[idx+1] - + pred_idx_list = bbox_text_dict.get(idx) + if isinstance(pred_idx_list, list): + pred_text_list = [go_res_list[idx][-1] for idx in pred_idx_list] + pred_text = ' '.join(pred_text_list) + else: + pred_text = None + + data_dict[key_cn][-1] += 1 + if pred_text == text: + data_dict[key_cn][0] += 1 + else: + failed_dict.setdefault(key_cn, list()).append((text, pred_text)) + + # break + + for key_cn, (correct_count, all_count) in data_dict.ietms(): + print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 2))) + + print('===========================') + + for key_cn, failed_list in failed_dict.items(): + print(key_cn) + for text, pred_text in failed_list: + print('label: {0} pred: {1}'.format(text, pred_text)) + print('----------------------------------') + \ No newline at end of file -- libgit2 0.24.0