fb66a889 by 周伟奇

add iou

1 parent 092baca7
...@@ -6,9 +6,49 @@ import uuid ...@@ -6,9 +6,49 @@ import uuid
6 6
7 import cv2 7 import cv2
8 import pandas as pd 8 import pandas as pd
9 import numpy as np
10 from shapely.geometry import Polygon, MultiPoint
9 from tools import get_file_paths, load_json 11 from tools import get_file_paths, load_json
10 from word2vec import jwq_word2vec, simple_word2vec 12 from word2vec import jwq_word2vec, simple_word2vec
11 13
14 def bbox_iou(go_bbox, label_bbox, mode='iou'):
15 # 所有点的最小凸的表示形式,四边形对象,会自动计算四个点,最后顺序为:左上 左下 右下 右上 左上
16 go_poly = Polygon(go_bbox).convex_hull
17 label_poly = Polygon(label_bbox).convex_hull
18 if not go_poly.is_valid or not label_poly.is_valid:
19 print('formatting errors for boxes!!!! ')
20 return 0
21 if go_poly.area == 0 or label_poly.area == 0 :
22 return 0
23
24 inter = Polygon(go_poly).intersection(Polygon(label_poly)).area
25 go_area = Polygon(go_poly).area
26
27 return inter / go_area
28
29 # if mode == 'iou':
30 # union = go_poly.area + label_poly.area - inter
31 # elif mode =='tiou':
32 # union_poly = np.concatenate((go_bbox, label_bbox)) #合并两个box坐标,变为8*2
33 # union = MultiPoint(union_poly).convex_hull.area
34 # # coors = MultiPoint(union_poly).convex_hull.wkt
35 # elif mode == 'giou':
36 # union_poly = np.concatenate((go_bbox, label_bbox))
37 # union = MultiPoint(union_poly).envelope.area
38 # # coors = MultiPoint(union_poly).envelope.wkt
39 # elif mode == 'r_giou':
40 # union_poly = np.concatenate((go_bbox, label_bbox))
41 # union = MultiPoint(union_poly).minimum_rotated_rectangle.area
42 # # coors = MultiPoint(union_poly).minimum_rotated_rectangle.wkt
43 # else:
44 # raise Exception('incorrect mode!')
45
46 # if union == 0:
47 # return 0
48 # else:
49 # return inter / union
50
51
12 52
13 def clean_go_res(go_res_dir): 53 def clean_go_res(go_res_dir):
14 go_res_json_paths = get_file_paths(go_res_dir, ['.json', ]) 54 go_res_json_paths = get_file_paths(go_res_dir, ['.json', ])
...@@ -32,7 +72,6 @@ def clean_go_res(go_res_dir): ...@@ -32,7 +72,6 @@ def clean_go_res(go_res_dir):
32 json.dump(go_res_list, fp) 72 json.dump(go_res_list, fp)
33 print('Rerewirte {0}'.format(go_res_json_path)) 73 print('Rerewirte {0}'.format(go_res_json_path))
34 74
35
36 def char_length_statistics(go_res_dir): 75 def char_length_statistics(go_res_dir):
37 max_char_length = None 76 max_char_length = None
38 target_file_name = None 77 target_file_name = None
...@@ -151,40 +190,35 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save ...@@ -151,40 +190,35 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
151 for group_id in test_group_id: 190 for group_id in test_group_id:
152 for item in label_res.get("shapes", []): 191 for item in label_res.get("shapes", []):
153 if item.get("group_id") == group_id: 192 if item.get("group_id") == group_id:
154 x_list = [] 193 label_bbox = list()
155 y_list = []
156 for point in item['points']: 194 for point in item['points']:
157 x_list.append(point[0]) 195 label_bbox.extend(point)
158 y_list.append(point[1]) 196 group_list.append(label_bbox)
159 group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2])
160 break 197 break
161 else: 198 else:
162 group_list.append(None) 199 group_list.append(None)
163 200
164 go_center_list = []
165 for (x0, y0, x1, y1, x2, y2, x3, y3), _ in go_res_list:
166 xmin = min(x0, x1, x2, x3)
167 ymin = min(y0, y1, y2, y3)
168 xmax = max(x0, x1, x2, x3)
169 ymax = max(y0, y1, y2, y3)
170 xcenter = xmin + (xmax - xmin)/2
171 ycenter = ymin + (ymax - ymin)/2
172 go_center_list.append((xcenter, ycenter))
173
174 label_idx_dict = dict() 201 label_idx_dict = dict()
175 for label_idx, label_center_list in enumerate(group_list): 202 for label_idx, label_bbox in enumerate(group_list):
176 if isinstance(label_center_list, list): 203 if isinstance(label_bbox, list):
177 min_go_key = None 204 for go_idx, (go_bbox, _) in enumerate(go_res_list):
178 min_length = None
179 for go_idx, (go_x_center, go_y_center) in enumerate(go_center_list):
180 if go_idx in top_text_idx_set or go_idx in label_idx_dict: 205 if go_idx in top_text_idx_set or go_idx in label_idx_dict:
181 continue 206 continue
182 length = abs(go_x_center-label_center_list[0])+abs(go_y_center-label_center_list[1]) 207 go_bbox_rebuild = [
183 if min_go_key is None or length < min_length: 208 [go_bbox[0], go_bbox[1]],
184 min_go_key = go_idx 209 [go_bbox[2], go_bbox[3]],
185 min_length = length 210 [go_bbox[4], go_bbox[5]],
186 if min_go_key is not None: 211 [go_bbox[6], go_bbox[7]],
187 label_idx_dict[min_go_key] = label_idx 212 ]
213 label_bbox_rebuild = [
214 [label_bbox[0], label_bbox[1]],
215 [label_bbox[2], label_bbox[1]],
216 [label_bbox[2], label_bbox[3]],
217 [label_bbox[0], label_bbox[3]],
218 ]
219 iou = bbox_iou(go_bbox_rebuild, label_bbox_rebuild)
220 if iou >= 0.5:
221 label_idx_dict[go_idx] = label_idx
188 222
189 X = list() 223 X = list()
190 y_true = list() 224 y_true = list()
...@@ -239,19 +273,16 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save ...@@ -239,19 +273,16 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
239 create_map[img_name] = { 273 create_map[img_name] = {
240 'x_y_valid_lens': save_json_name, 274 'x_y_valid_lens': save_json_name,
241 'find_top_text': [go_res_list[i][-1] for i in top_text_idx_set], 275 'find_top_text': [go_res_list[i][-1] for i in top_text_idx_set],
242 'find_value': {group_cn_list[v]: go_res_list[k][-1] for k, v in label_idx_dict.items()} 276 'find_value': {go_res_list[k][-1]: group_cn_list[v] for k, v in label_idx_dict.items()}
243 } 277 }
244 278
245
246 # break
247
248 # print(create_map) 279 # print(create_map)
249 # print(is_create_map) 280 # print(is_create_map)
250 if create_map: 281 if create_map:
282 # print(create_map)
251 with open(os.path.join(os.path.dirname(save_dir), 'create_map.json'), 'w') as fp: 283 with open(os.path.join(os.path.dirname(save_dir), 'create_map.json'), 'w') as fp:
252 json.dump(create_map, fp) 284 json.dump(create_map, fp)
253 285
254
255 # print('top text find:') 286 # print('top text find:')
256 # for i in top_text_idx_set: 287 # for i in top_text_idx_set:
257 # _, text = go_res_list[i] 288 # _, text = go_res_list[i]
...@@ -269,7 +300,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save ...@@ -269,7 +300,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
269 if __name__ == '__main__': 300 if __name__ == '__main__':
270 base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' 301 base_dir = '/Users/zhouweiqi/Downloads/gcfp/data'
271 go_dir = os.path.join(base_dir, 'go_res') 302 go_dir = os.path.join(base_dir, 'go_res')
272 dataset_save_dir = os.path.join(base_dir, 'dataset160x14') 303 dataset_save_dir = os.path.join(base_dir, 'dataset160x14-pro')
273 label_dir = os.path.join(base_dir, 'labeled') 304 label_dir = os.path.join(base_dir, 'labeled')
274 305
275 train_go_path = os.path.join(go_dir, 'train') 306 train_go_path = os.path.join(go_dir, 'train')
...@@ -329,23 +360,23 @@ if __name__ == '__main__': ...@@ -329,23 +360,23 @@ if __name__ == '__main__':
329 ] 360 ]
330 361
331 skip_list_train = [ 362 skip_list_train = [
332 'CH-B101910792-page-12.jpg', 363 # 'CH-B101910792-page-12.jpg',
333 'CH-B101655312-page-13.jpg', 364 # 'CH-B101655312-page-13.jpg',
334 'CH-B102278656.jpg', 365 # 'CH-B102278656.jpg',
335 'CH-B101846620_page_1_img_0.jpg', 366 # 'CH-B101846620_page_1_img_0.jpg',
336 'CH-B103062528-0.jpg', 367 # 'CH-B103062528-0.jpg',
337 'CH-B102613120-3.jpg', 368 # 'CH-B102613120-3.jpg',
338 'CH-B102997980-3.jpg', 369 # 'CH-B102997980-3.jpg',
339 'CH-B102680060-3.jpg', 370 # 'CH-B102680060-3.jpg',
340 # 'CH-B102995500-2.jpg', # 没value 371 # # 'CH-B102995500-2.jpg', # 没value
341 ] 372 ]
342 373
343 skip_list_valid = [ 374 skip_list_valid = [
344 'CH-B102897920-2.jpg', 375 # 'CH-B102897920-2.jpg',
345 'CH-B102551284-0.jpg', 376 # 'CH-B102551284-0.jpg',
346 'CH-B102879376-2.jpg', 377 # 'CH-B102879376-2.jpg',
347 'CH-B101509488-page-16.jpg', 378 # 'CH-B101509488-page-16.jpg',
348 'CH-B102708352-2.jpg', 379 # 'CH-B102708352-2.jpg',
349 ] 380 ]
350 381
351 build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir) 382 build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir)
......
...@@ -219,11 +219,11 @@ class SLSolver(object): ...@@ -219,11 +219,11 @@ class SLSolver(object):
219 map_key_value = 'find_value' 219 map_key_value = 'find_value'
220 group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] 220 group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写']
221 skip_list_valid = [ 221 skip_list_valid = [
222 'CH-B102897920-2.jpg', 222 # 'CH-B102897920-2.jpg',
223 'CH-B102551284-0.jpg', 223 # 'CH-B102551284-0.jpg',
224 'CH-B102879376-2.jpg', 224 # 'CH-B102879376-2.jpg',
225 'CH-B101509488-page-16.jpg', 225 # 'CH-B101509488-page-16.jpg',
226 'CH-B102708352-2.jpg', 226 # 'CH-B102708352-2.jpg',
227 ] 227 ]
228 228
229 dataset_base_dir = os.path.dirname(self.val_map_path) 229 dataset_base_dir = os.path.dirname(self.val_map_path)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!