fb66a889 by 周伟奇

add iou

1 parent 092baca7
......@@ -6,9 +6,49 @@ import uuid
import cv2
import pandas as pd
import numpy as np
from shapely.geometry import Polygon, MultiPoint
from tools import get_file_paths, load_json
from word2vec import jwq_word2vec, simple_word2vec
def bbox_iou(go_bbox, label_bbox, mode='iou'):
# 所有点的最小凸的表示形式,四边形对象,会自动计算四个点,最后顺序为:左上 左下 右下 右上 左上
go_poly = Polygon(go_bbox).convex_hull
label_poly = Polygon(label_bbox).convex_hull
if not go_poly.is_valid or not label_poly.is_valid:
print('formatting errors for boxes!!!! ')
return 0
if go_poly.area == 0 or label_poly.area == 0 :
return 0
inter = Polygon(go_poly).intersection(Polygon(label_poly)).area
go_area = Polygon(go_poly).area
return inter / go_area
# if mode == 'iou':
# union = go_poly.area + label_poly.area - inter
# elif mode =='tiou':
# union_poly = np.concatenate((go_bbox, label_bbox)) #合并两个box坐标,变为8*2
# union = MultiPoint(union_poly).convex_hull.area
# # coors = MultiPoint(union_poly).convex_hull.wkt
# elif mode == 'giou':
# union_poly = np.concatenate((go_bbox, label_bbox))
# union = MultiPoint(union_poly).envelope.area
# # coors = MultiPoint(union_poly).envelope.wkt
# elif mode == 'r_giou':
# union_poly = np.concatenate((go_bbox, label_bbox))
# union = MultiPoint(union_poly).minimum_rotated_rectangle.area
# # coors = MultiPoint(union_poly).minimum_rotated_rectangle.wkt
# else:
# raise Exception('incorrect mode!')
# if union == 0:
# return 0
# else:
# return inter / union
def clean_go_res(go_res_dir):
go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
......@@ -32,7 +72,6 @@ def clean_go_res(go_res_dir):
json.dump(go_res_list, fp)
print('Rerewirte {0}'.format(go_res_json_path))
def char_length_statistics(go_res_dir):
max_char_length = None
target_file_name = None
......@@ -151,40 +190,35 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
for group_id in test_group_id:
for item in label_res.get("shapes", []):
if item.get("group_id") == group_id:
x_list = []
y_list = []
label_bbox = list()
for point in item['points']:
x_list.append(point[0])
y_list.append(point[1])
group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2])
label_bbox.extend(point)
group_list.append(label_bbox)
break
else:
group_list.append(None)
go_center_list = []
for (x0, y0, x1, y1, x2, y2, x3, y3), _ in go_res_list:
xmin = min(x0, x1, x2, x3)
ymin = min(y0, y1, y2, y3)
xmax = max(x0, x1, x2, x3)
ymax = max(y0, y1, y2, y3)
xcenter = xmin + (xmax - xmin)/2
ycenter = ymin + (ymax - ymin)/2
go_center_list.append((xcenter, ycenter))
label_idx_dict = dict()
for label_idx, label_center_list in enumerate(group_list):
if isinstance(label_center_list, list):
min_go_key = None
min_length = None
for go_idx, (go_x_center, go_y_center) in enumerate(go_center_list):
for label_idx, label_bbox in enumerate(group_list):
if isinstance(label_bbox, list):
for go_idx, (go_bbox, _) in enumerate(go_res_list):
if go_idx in top_text_idx_set or go_idx in label_idx_dict:
continue
length = abs(go_x_center-label_center_list[0])+abs(go_y_center-label_center_list[1])
if min_go_key is None or length < min_length:
min_go_key = go_idx
min_length = length
if min_go_key is not None:
label_idx_dict[min_go_key] = label_idx
go_bbox_rebuild = [
[go_bbox[0], go_bbox[1]],
[go_bbox[2], go_bbox[3]],
[go_bbox[4], go_bbox[5]],
[go_bbox[6], go_bbox[7]],
]
label_bbox_rebuild = [
[label_bbox[0], label_bbox[1]],
[label_bbox[2], label_bbox[1]],
[label_bbox[2], label_bbox[3]],
[label_bbox[0], label_bbox[3]],
]
iou = bbox_iou(go_bbox_rebuild, label_bbox_rebuild)
if iou >= 0.5:
label_idx_dict[go_idx] = label_idx
X = list()
y_true = list()
......@@ -239,19 +273,16 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
create_map[img_name] = {
'x_y_valid_lens': save_json_name,
'find_top_text': [go_res_list[i][-1] for i in top_text_idx_set],
'find_value': {group_cn_list[v]: go_res_list[k][-1] for k, v in label_idx_dict.items()}
'find_value': {go_res_list[k][-1]: group_cn_list[v] for k, v in label_idx_dict.items()}
}
# break
# print(create_map)
# print(is_create_map)
if create_map:
# print(create_map)
with open(os.path.join(os.path.dirname(save_dir), 'create_map.json'), 'w') as fp:
json.dump(create_map, fp)
# print('top text find:')
# for i in top_text_idx_set:
# _, text = go_res_list[i]
......@@ -269,7 +300,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
if __name__ == '__main__':
base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
go_dir = os.path.join(base_dir, 'go_res')
dataset_save_dir = os.path.join(base_dir, 'dataset160x14')
dataset_save_dir = os.path.join(base_dir, 'dataset160x14-pro')
label_dir = os.path.join(base_dir, 'labeled')
train_go_path = os.path.join(go_dir, 'train')
......@@ -329,23 +360,23 @@ if __name__ == '__main__':
]
skip_list_train = [
'CH-B101910792-page-12.jpg',
'CH-B101655312-page-13.jpg',
'CH-B102278656.jpg',
'CH-B101846620_page_1_img_0.jpg',
'CH-B103062528-0.jpg',
'CH-B102613120-3.jpg',
'CH-B102997980-3.jpg',
'CH-B102680060-3.jpg',
# 'CH-B102995500-2.jpg', # 没value
# 'CH-B101910792-page-12.jpg',
# 'CH-B101655312-page-13.jpg',
# 'CH-B102278656.jpg',
# 'CH-B101846620_page_1_img_0.jpg',
# 'CH-B103062528-0.jpg',
# 'CH-B102613120-3.jpg',
# 'CH-B102997980-3.jpg',
# 'CH-B102680060-3.jpg',
# # 'CH-B102995500-2.jpg', # 没value
]
skip_list_valid = [
'CH-B102897920-2.jpg',
'CH-B102551284-0.jpg',
'CH-B102879376-2.jpg',
'CH-B101509488-page-16.jpg',
'CH-B102708352-2.jpg',
# 'CH-B102897920-2.jpg',
# 'CH-B102551284-0.jpg',
# 'CH-B102879376-2.jpg',
# 'CH-B101509488-page-16.jpg',
# 'CH-B102708352-2.jpg',
]
build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
......
......@@ -219,11 +219,11 @@ class SLSolver(object):
map_key_value = 'find_value'
group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
skip_list_valid = [
'CH-B102897920-2.jpg',
'CH-B102551284-0.jpg',
'CH-B102879376-2.jpg',
'CH-B101509488-page-16.jpg',
'CH-B102708352-2.jpg',
# 'CH-B102897920-2.jpg',
# 'CH-B102551284-0.jpg',
# 'CH-B102879376-2.jpg',
# 'CH-B101509488-page-16.jpg',
# 'CH-B102708352-2.jpg',
]
dataset_base_dir = os.path.dirname(self.val_map_path)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!