41252450 by 周伟奇

add statistics

1 parent fb66a889
...@@ -37,6 +37,7 @@ solver: ...@@ -37,6 +37,7 @@ solver:
37 base_on: null 37 base_on: null
38 model_path: null 38 model_path: null
39 val_image_path: '/labeled/valid/image' 39 val_image_path: '/labeled/valid/image'
40 val_label_path: '/labeled/valid/label'
40 val_go_path: '/go_res/valid' 41 val_go_path: '/go_res/valid'
41 val_map_path: '/dataset160x14/create_map.json' 42 val_map_path: '/dataset160x14/create_map.json'
42 draw_font_path: '/dataset160x14/STZHONGS.TTF' 43 draw_font_path: '/dataset160x14/STZHONGS.TTF'
......
...@@ -38,6 +38,7 @@ class SLSolver(object): ...@@ -38,6 +38,7 @@ class SLSolver(object):
38 self.base_on = self.hyper_params['base_on'] 38 self.base_on = self.hyper_params['base_on']
39 self.model_path = self.hyper_params['model_path'] 39 self.model_path = self.hyper_params['model_path']
40 self.val_image_path = self.hyper_params['val_image_path'] 40 self.val_image_path = self.hyper_params['val_image_path']
41 self.val_label_path = self.hyper_params['val_label_path']
41 self.val_go_path = self.hyper_params['val_go_path'] 42 self.val_go_path = self.hyper_params['val_go_path']
42 self.val_map_path = self.hyper_params['val_map_path'] 43 self.val_map_path = self.hyper_params['val_map_path']
43 self.draw_font_path = self.hyper_params['draw_font_path'] 44 self.draw_font_path = self.hyper_params['draw_font_path']
...@@ -198,6 +199,10 @@ class SLSolver(object): ...@@ -198,6 +199,10 @@ class SLSolver(object):
198 print('Warn: val_image_path not exists: {0}'.format(self.val_image_path)) 199 print('Warn: val_image_path not exists: {0}'.format(self.val_image_path))
199 return 200 return
200 201
202 if not os.path.isdir(self.val_label_path):
203 print('Warn: val_label_path not exists: {0}'.format(self.val_label_path))
204 return
205
201 if not os.path.isdir(self.val_go_path): 206 if not os.path.isdir(self.val_go_path):
202 print('Warn: val_go_path not exists: {0}'.format(self.val_go_path)) 207 print('Warn: val_go_path not exists: {0}'.format(self.val_go_path))
203 return 208 return
...@@ -217,6 +222,7 @@ class SLSolver(object): ...@@ -217,6 +222,7 @@ class SLSolver(object):
217 map_key_input = 'x_y_valid_lens' 222 map_key_input = 'x_y_valid_lens'
218 map_key_text = 'find_top_text' 223 map_key_text = 'find_top_text'
219 map_key_value = 'find_value' 224 map_key_value = 'find_value'
225 test_group_id = [1, 2, 5, 9, 20, 15, 16, 22, 24, 28]
220 group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] 226 group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
221 skip_list_valid = [ 227 skip_list_valid = [
222 # 'CH-B102897920-2.jpg', 228 # 'CH-B102897920-2.jpg',
...@@ -235,6 +241,8 @@ class SLSolver(object): ...@@ -235,6 +241,8 @@ class SLSolver(object):
235 with open(self.val_map_path, 'r') as fp: 241 with open(self.val_map_path, 'r') as fp:
236 val_map = json.load(fp) 242 val_map = json.load(fp)
237 243
244 data_dict = {key_cn: [0, 0] for key_cn in group_cn_list[1:]}
245 failed_dict = dict()
238 for img_name in sorted(os.listdir(self.val_image_path)): 246 for img_name in sorted(os.listdir(self.val_image_path)):
239 if img_name in skip_list_valid: 247 if img_name in skip_list_valid:
240 continue 248 continue
...@@ -281,7 +289,11 @@ class SLSolver(object): ...@@ -281,7 +289,11 @@ class SLSolver(object):
281 289
282 correct = 0 290 correct = 0
283 bbox_draw_dict = dict() 291 bbox_draw_dict = dict()
292 bbox_text_dict = dict()
284 for i in range(valid_lens_scalar): 293 for i in range(valid_lens_scalar):
294 if pred[i] != 0:
295 bbox_text_dict.setdefault(test_group_id[pred[i]-1], list()).append(i)
296
285 if pred[i] == label[i]: 297 if pred[i] == label[i]:
286 correct += 1 298 correct += 1
287 if pred[i] != 0: 299 if pred[i] != 0:
...@@ -311,8 +323,46 @@ class SLSolver(object): ...@@ -311,8 +323,46 @@ class SLSolver(object):
311 323
312 img_pil.save(os.path.join(save_dir, img_name)) 324 img_pil.save(os.path.join(save_dir, img_name))
313 325
314 # break 326 # 统计准确率
327 label_json_path = os.path.join(self.val_label_path, '{0}.json'.format(base_image_name))
328 with open(label_json_path, 'r') as fp:
329 label_res = json.load(fp)
330
331 group_text_list = []
332 for group_id in test_group_id:
333 for item in label_res.get("shapes", []):
334 if item.get("group_id") == group_id:
335 group_text_list.append(item['label'])
336 break
337 else:
338 group_text_list.append(None)
315 339
316 340 for idx, text in enumerate(group_text_list):
341 key_cn = group_cn_list[idx+1]
317 342
318 343 pred_idx_list = bbox_text_dict.get(idx)
344 if isinstance(pred_idx_list, list):
345 pred_text_list = [go_res_list[idx][-1] for idx in pred_idx_list]
346 pred_text = ' '.join(pred_text_list)
347 else:
348 pred_text = None
349
350 data_dict[key_cn][-1] += 1
351 if pred_text == text:
352 data_dict[key_cn][0] += 1
353 else:
354 failed_dict.setdefault(key_cn, list()).append((text, pred_text))
355
356 # break
357
358 for key_cn, (correct_count, all_count) in data_dict.ietms():
359 print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 2)))
360
361 print('===========================')
362
363 for key_cn, failed_list in failed_dict.items():
364 print(key_cn)
365 for text, pred_text in failed_list:
366 print('label: {0} pred: {1}'.format(text, pred_text))
367 print('----------------------------------')
368
...\ No newline at end of file ...\ No newline at end of file
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!