fix bug
Showing
2 changed files
with
33 additions
and
33 deletions
... | @@ -300,7 +300,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -300,7 +300,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
300 | if __name__ == '__main__': | 300 | if __name__ == '__main__': |
301 | base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' | 301 | base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' |
302 | go_dir = os.path.join(base_dir, 'go_res') | 302 | go_dir = os.path.join(base_dir, 'go_res') |
303 | dataset_save_dir = os.path.join(base_dir, 'dataset160x14-pro') | 303 | dataset_save_dir = os.path.join(base_dir, 'dataset160x14-pro-all-valid') |
304 | label_dir = os.path.join(base_dir, 'labeled') | 304 | label_dir = os.path.join(base_dir, 'labeled') |
305 | 305 | ||
306 | train_go_path = os.path.join(go_dir, 'train') | 306 | train_go_path = os.path.join(go_dir, 'train') |
... | @@ -360,14 +360,14 @@ if __name__ == '__main__': | ... | @@ -360,14 +360,14 @@ if __name__ == '__main__': |
360 | ] | 360 | ] |
361 | 361 | ||
362 | skip_list_train = [ | 362 | skip_list_train = [ |
363 | # 'CH-B101910792-page-12.jpg', | 363 | 'CH-B101910792-page-12.jpg', |
364 | # 'CH-B101655312-page-13.jpg', | 364 | 'CH-B101655312-page-13.jpg', |
365 | # 'CH-B102278656.jpg', | 365 | 'CH-B102278656.jpg', |
366 | # 'CH-B101846620_page_1_img_0.jpg', | 366 | 'CH-B101846620_page_1_img_0.jpg', |
367 | # 'CH-B103062528-0.jpg', | 367 | 'CH-B103062528-0.jpg', |
368 | # 'CH-B102613120-3.jpg', | 368 | 'CH-B102613120-3.jpg', |
369 | # 'CH-B102997980-3.jpg', | 369 | 'CH-B102997980-3.jpg', |
370 | # 'CH-B102680060-3.jpg', | 370 | 'CH-B102680060-3.jpg', |
371 | # # 'CH-B102995500-2.jpg', # 没value | 371 | # # 'CH-B102995500-2.jpg', # 没value |
372 | ] | 372 | ] |
373 | 373 | ... | ... |
... | @@ -292,36 +292,36 @@ class SLSolver(object): | ... | @@ -292,36 +292,36 @@ class SLSolver(object): |
292 | bbox_text_dict = dict() | 292 | bbox_text_dict = dict() |
293 | for i in range(valid_lens_scalar): | 293 | for i in range(valid_lens_scalar): |
294 | if pred[i] != 0: | 294 | if pred[i] != 0: |
295 | bbox_text_dict.setdefault(test_group_id[pred[i]-1], list()).append(i) | 295 | bbox_text_dict.setdefault(pred[i]-1, list()).append(i) |
296 | 296 | ||
297 | if pred[i] == label[i]: | 297 | # if pred[i] == label[i]: |
298 | correct += 1 | 298 | # correct += 1 |
299 | if pred[i] != 0: | 299 | # if pred[i] != 0: |
300 | # 绿色 | 300 | # # 绿色 |
301 | bbox_draw_dict[i] = (group_cn_list[pred[i]], ) | 301 | # bbox_draw_dict[i] = (group_cn_list[pred[i]], ) |
302 | else: | 302 | # else: |
303 | # 红色:左上角label,右上角pred | 303 | # # 红色:左上角label,右上角pred |
304 | bbox_draw_dict[i] = (group_cn_list[label[i]], group_cn_list[pred[i]]) | 304 | # bbox_draw_dict[i] = (group_cn_list[label[i]], group_cn_list[pred[i]]) |
305 | 305 | ||
306 | correct_rate = correct / valid_lens_scalar | 306 | # correct_rate = correct / valid_lens_scalar |
307 | 307 | ||
308 | # 画图 | 308 | # 画图 |
309 | for idx, text_tuple in bbox_draw_dict.items(): | 309 | # for idx, text_tuple in bbox_draw_dict.items(): |
310 | (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[idx] | 310 | # (x0, y0, x1, y1, x2, y2, x3, y3), _ = go_res_list[idx] |
311 | line_color = green_color if len(text_tuple) == 1 else red_color | 311 | # line_color = green_color if len(text_tuple) == 1 else red_color |
312 | draw.polygon([(x0, y0), (x1, y1), (x2, y2), (x3, y3)], outline=line_color) | 312 | # draw.polygon([(x0, y0), (x1, y1), (x2, y2), (x3, y3)], outline=line_color) |
313 | draw.text((int(x0), int(y0)), text_tuple[0], green_color, font=font) | 313 | # draw.text((int(x0), int(y0)), text_tuple[0], green_color, font=font) |
314 | if len(text_tuple) == 2: | 314 | # if len(text_tuple) == 2: |
315 | draw.text((int(x1), int(y1)), text_tuple[1], red_color, font=font) | 315 | # draw.text((int(x1), int(y1)), text_tuple[1], red_color, font=font) |
316 | 316 | ||
317 | draw.text((0, 0), str(correct_rate), blue_color, font=font) | 317 | # draw.text((0, 0), str(correct_rate), blue_color, font=font) |
318 | 318 | ||
319 | last_y = size | 319 | # last_y = size |
320 | for k, v in val_map[img_name][map_key_value].items(): | 320 | # for k, v in val_map[img_name][map_key_value].items(): |
321 | draw.text((0, last_y), '{0}: {1}'.format(k, v), blue_color, font=font) | 321 | # draw.text((0, last_y), '{0}: {1}'.format(k, v), blue_color, font=font) |
322 | last_y += size | 322 | # last_y += size |
323 | 323 | ||
324 | img_pil.save(os.path.join(save_dir, img_name)) | 324 | # img_pil.save(os.path.join(save_dir, img_name)) |
325 | 325 | ||
326 | # 统计准确率 | 326 | # 统计准确率 |
327 | label_json_path = os.path.join(self.val_label_path, '{0}.json'.format(base_image_name)) | 327 | label_json_path = os.path.join(self.val_label_path, '{0}.json'.format(base_image_name)) |
... | @@ -355,7 +355,7 @@ class SLSolver(object): | ... | @@ -355,7 +355,7 @@ class SLSolver(object): |
355 | 355 | ||
356 | # break | 356 | # break |
357 | 357 | ||
358 | for key_cn, (correct_count, all_count) in data_dict.ietms(): | 358 | for key_cn, (correct_count, all_count) in data_dict.items(): |
359 | print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 2))) | 359 | print('{0}: {1}'.format(key_cn, round(correct_count/all_count, 2))) |
360 | 360 | ||
361 | print('===========================') | 361 | print('===========================') | ... | ... |
-
Please register or sign in to post a comment