From 41252450bb79b46d9f6621531519db0e8965f420 Mon Sep 17 00:00:00 2001
From: zhouweiqi <zhouweiqi@situdata.com>
Date: Thu, 22 Dec 2022 18:29:08 +0800
Subject: [PATCH] add statistics

---
 config/sl.yaml      |  1 +
 solver/sl_solver.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 54 insertions(+), 3 deletions(-)

diff --git a/config/sl.yaml b/config/sl.yaml
index 39b74fc..5ae1745 100644
--- a/config/sl.yaml
+++ b/config/sl.yaml
@@ -37,6 +37,7 @@ solver:
     base_on: null
     model_path: null
     val_image_path: '/labeled/valid/image'
+    val_label_path: '/labeled/valid/label'
     val_go_path: '/go_res/valid'
     val_map_path: '/dataset160x14/create_map.json'
     draw_font_path: '/dataset160x14/STZHONGS.TTF'
diff --git a/solver/sl_solver.py b/solver/sl_solver.py
index 708ae46..6764fbc 100644
--- a/solver/sl_solver.py
+++ b/solver/sl_solver.py
@@ -38,6 +38,7 @@ class SLSolver(object):
         self.base_on = self.hyper_params['base_on']
         self.model_path = self.hyper_params['model_path']
         self.val_image_path = self.hyper_params['val_image_path']
+        self.val_label_path = self.hyper_params['val_label_path']
         self.val_go_path = self.hyper_params['val_go_path']
         self.val_map_path = self.hyper_params['val_map_path']
         self.draw_font_path = self.hyper_params['draw_font_path']
@@ -198,6 +199,10 @@ class SLSolver(object):
             print('Warn: val_image_path not exists: {0}'.format(self.val_image_path))    
             return
 
+        if not os.path.isdir(self.val_label_path):
+            print('Warn: val_label_path not exists: {0}'.format(self.val_label_path))    
+            return
+
         if not os.path.isdir(self.val_go_path):
             print('Warn: val_go_path not exists: {0}'.format(self.val_go_path))    
             return
@@ -217,6 +222,7 @@ class SLSolver(object):
         map_key_input = 'x_y_valid_lens'
         map_key_text = 'find_top_text'
         map_key_value = 'find_value'
+        test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]
         group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
         skip_list_valid = [
             # 'CH-B102897920-2.jpg',
@@ -235,6 +241,8 @@ class SLSolver(object):
         with open(self.val_map_path, 'r') as fp:
             val_map = json.load(fp) 
         
+        data_dict = {key_cn: [0, 0] for key_cn in group_cn_list[1:]}
+        failed_dict = dict()
         for img_name in sorted(os.listdir(self.val_image_path)):
             if img_name in skip_list_valid:
                 continue
@@ -281,7 +289,11 @@ class SLSolver(object):
 
             correct = 0
             bbox_draw_dict = dict()
+            bbox_text_dict = dict()
             for i in range(valid_lens_scalar):
+                if pred[i] != 0:
+                    bbox_text_dict.setdefault(test_group_id[pred[i]-1], list()).append(i) 
+
                 if pred[i] == label[i]:
                     correct += 1
                     if pred[i] != 0:
@@ -311,8 +323,46 @@ class SLSolver(object):
 
             img_pil.save(os.path.join(save_dir, img_name))
 
-            # break
+            # Per-field accuracy: compare the predicted text with this image's labelled text.
+            label_json_path = os.path.join(self.val_label_path, '{0}.json'.format(base_image_name))
+            with open(label_json_path, 'r') as fp:
+                label_res = json.load(fp)
+
+            group_text_list = []  # labelled text per entry of test_group_id (None when absent)
+            for group_id in test_group_id:
+                for item in label_res.get("shapes", []):
+                    if item.get("group_id") == group_id:
+                        group_text_list.append(item['label'])
+                        break
+                else:
+                    group_text_list.append(None)
 
-            
+            for idx, text in enumerate(group_text_list):
+                key_cn = group_cn_list[idx+1]
 
-            
+                pred_idx_list = bbox_text_dict.get(test_group_id[idx])  # dict is keyed by group id, not position
+                if isinstance(pred_idx_list, list):
+                    pred_text_list = [go_res_list[box_idx][-1] for box_idx in pred_idx_list]
+                    pred_text = ' '.join(pred_text_list)
+                else:
+                    pred_text = None
+
+                data_dict[key_cn][-1] += 1
+                if pred_text == text:
+                    data_dict[key_cn][0] += 1 
+                else:
+                    failed_dict.setdefault(key_cn, list()).append((text, pred_text))
+
+            # break
+        
+        for key_cn, (correct_count, all_count) in data_dict.items():
+            print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 2) if all_count else 0.0))
+    
+        print('===========================')
+
+        for key_cn, failed_list in failed_dict.items():
+            print(key_cn)
+            for text, pred_text in failed_list:
+                print('label: {0} pred: {1}'.format(text, pred_text))
+            print('----------------------------------') 
+            
\ No newline at end of file
--
libgit2 0.24.0