add iou
Showing 2 changed files with 83 additions and 52 deletions
... | @@ -6,9 +6,49 @@ import uuid | ... | @@ -6,9 +6,49 @@ import uuid |
6 | 6 | ||
7 | import cv2 | 7 | import cv2 |
8 | import pandas as pd | 8 | import pandas as pd |
9 | import numpy as np | ||
10 | from shapely.geometry import Polygon, MultiPoint | ||
9 | from tools import get_file_paths, load_json | 11 | from tools import get_file_paths, load_json |
10 | from word2vec import jwq_word2vec, simple_word2vec | 12 | from word2vec import jwq_word2vec, simple_word2vec |
11 | 13 | ||
14 | def bbox_iou(go_bbox, label_bbox, mode='iou'): | ||
15 | # 所有点的最小凸的表示形式,四边形对象,会自动计算四个点,最后顺序为:左上 左下 右下 右上 左上 | ||
16 | go_poly = Polygon(go_bbox).convex_hull | ||
17 | label_poly = Polygon(label_bbox).convex_hull | ||
18 | if not go_poly.is_valid or not label_poly.is_valid: | ||
19 | print('formatting errors for boxes!!!! ') | ||
20 | return 0 | ||
21 | if go_poly.area == 0 or label_poly.area == 0 : | ||
22 | return 0 | ||
23 | |||
24 | inter = Polygon(go_poly).intersection(Polygon(label_poly)).area | ||
25 | go_area = Polygon(go_poly).area | ||
26 | |||
27 | return inter / go_area | ||
28 | |||
29 | # if mode == 'iou': | ||
30 | # union = go_poly.area + label_poly.area - inter | ||
31 | # elif mode =='tiou': | ||
32 | # union_poly = np.concatenate((go_bbox, label_bbox)) #合并两个box坐标,变为8*2 | ||
33 | # union = MultiPoint(union_poly).convex_hull.area | ||
34 | # # coors = MultiPoint(union_poly).convex_hull.wkt | ||
35 | # elif mode == 'giou': | ||
36 | # union_poly = np.concatenate((go_bbox, label_bbox)) | ||
37 | # union = MultiPoint(union_poly).envelope.area | ||
38 | # # coors = MultiPoint(union_poly).envelope.wkt | ||
39 | # elif mode == 'r_giou': | ||
40 | # union_poly = np.concatenate((go_bbox, label_bbox)) | ||
41 | # union = MultiPoint(union_poly).minimum_rotated_rectangle.area | ||
42 | # # coors = MultiPoint(union_poly).minimum_rotated_rectangle.wkt | ||
43 | # else: | ||
44 | # raise Exception('incorrect mode!') | ||
45 | |||
46 | # if union == 0: | ||
47 | # return 0 | ||
48 | # else: | ||
49 | # return inter / union | ||
50 | |||
51 | |||
12 | 52 | ||
13 | def clean_go_res(go_res_dir): | 53 | def clean_go_res(go_res_dir): |
14 | go_res_json_paths = get_file_paths(go_res_dir, ['.json', ]) | 54 | go_res_json_paths = get_file_paths(go_res_dir, ['.json', ]) |
... | @@ -32,7 +72,6 @@ def clean_go_res(go_res_dir): | ... | @@ -32,7 +72,6 @@ def clean_go_res(go_res_dir): |
32 | json.dump(go_res_list, fp) | 72 | json.dump(go_res_list, fp) |
33 | print('Rerewirte {0}'.format(go_res_json_path)) | 73 | print('Rerewirte {0}'.format(go_res_json_path)) |
34 | 74 | ||
35 | |||
36 | def char_length_statistics(go_res_dir): | 75 | def char_length_statistics(go_res_dir): |
37 | max_char_length = None | 76 | max_char_length = None |
38 | target_file_name = None | 77 | target_file_name = None |
... | @@ -151,40 +190,35 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -151,40 +190,35 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
151 | for group_id in test_group_id: | 190 | for group_id in test_group_id: |
152 | for item in label_res.get("shapes", []): | 191 | for item in label_res.get("shapes", []): |
153 | if item.get("group_id") == group_id: | 192 | if item.get("group_id") == group_id: |
154 | x_list = [] | 193 | label_bbox = list() |
155 | y_list = [] | ||
156 | for point in item['points']: | 194 | for point in item['points']: |
157 | x_list.append(point[0]) | 195 | label_bbox.extend(point) |
158 | y_list.append(point[1]) | 196 | group_list.append(label_bbox) |
159 | group_list.append([min(x_list) + (max(x_list) - min(x_list))/2, min(y_list) + (max(y_list) - min(y_list))/2]) | ||
160 | break | 197 | break |
161 | else: | 198 | else: |
162 | group_list.append(None) | 199 | group_list.append(None) |
163 | 200 | ||
164 | go_center_list = [] | ||
165 | for (x0, y0, x1, y1, x2, y2, x3, y3), _ in go_res_list: | ||
166 | xmin = min(x0, x1, x2, x3) | ||
167 | ymin = min(y0, y1, y2, y3) | ||
168 | xmax = max(x0, x1, x2, x3) | ||
169 | ymax = max(y0, y1, y2, y3) | ||
170 | xcenter = xmin + (xmax - xmin)/2 | ||
171 | ycenter = ymin + (ymax - ymin)/2 | ||
172 | go_center_list.append((xcenter, ycenter)) | ||
173 | |||
174 | label_idx_dict = dict() | 201 | label_idx_dict = dict() |
175 | for label_idx, label_center_list in enumerate(group_list): | 202 | for label_idx, label_bbox in enumerate(group_list): |
176 | if isinstance(label_center_list, list): | 203 | if isinstance(label_bbox, list): |
177 | min_go_key = None | 204 | for go_idx, (go_bbox, _) in enumerate(go_res_list): |
178 | min_length = None | ||
179 | for go_idx, (go_x_center, go_y_center) in enumerate(go_center_list): | ||
180 | if go_idx in top_text_idx_set or go_idx in label_idx_dict: | 205 | if go_idx in top_text_idx_set or go_idx in label_idx_dict: |
181 | continue | 206 | continue |
182 | length = abs(go_x_center-label_center_list[0])+abs(go_y_center-label_center_list[1]) | 207 | go_bbox_rebuild = [ |
183 | if min_go_key is None or length < min_length: | 208 | [go_bbox[0], go_bbox[1]], |
184 | min_go_key = go_idx | 209 | [go_bbox[2], go_bbox[3]], |
185 | min_length = length | 210 | [go_bbox[4], go_bbox[5]], |
186 | if min_go_key is not None: | 211 | [go_bbox[6], go_bbox[7]], |
187 | label_idx_dict[min_go_key] = label_idx | 212 | ] |
213 | label_bbox_rebuild = [ | ||
214 | [label_bbox[0], label_bbox[1]], | ||
215 | [label_bbox[2], label_bbox[1]], | ||
216 | [label_bbox[2], label_bbox[3]], | ||
217 | [label_bbox[0], label_bbox[3]], | ||
218 | ] | ||
219 | iou = bbox_iou(go_bbox_rebuild, label_bbox_rebuild) | ||
220 | if iou >= 0.5: | ||
221 | label_idx_dict[go_idx] = label_idx | ||
188 | 222 | ||
189 | X = list() | 223 | X = list() |
190 | y_true = list() | 224 | y_true = list() |
... | @@ -239,19 +273,16 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -239,19 +273,16 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
239 | create_map[img_name] = { | 273 | create_map[img_name] = { |
240 | 'x_y_valid_lens': save_json_name, | 274 | 'x_y_valid_lens': save_json_name, |
241 | 'find_top_text': [go_res_list[i][-1] for i in top_text_idx_set], | 275 | 'find_top_text': [go_res_list[i][-1] for i in top_text_idx_set], |
242 | 'find_value': {group_cn_list[v]: go_res_list[k][-1] for k, v in label_idx_dict.items()} | 276 | 'find_value': {go_res_list[k][-1]: group_cn_list[v] for k, v in label_idx_dict.items()} |
243 | } | 277 | } |
244 | 278 | ||
245 | |||
246 | # break | ||
247 | |||
248 | # print(create_map) | 279 | # print(create_map) |
249 | # print(is_create_map) | 280 | # print(is_create_map) |
250 | if create_map: | 281 | if create_map: |
282 | # print(create_map) | ||
251 | with open(os.path.join(os.path.dirname(save_dir), 'create_map.json'), 'w') as fp: | 283 | with open(os.path.join(os.path.dirname(save_dir), 'create_map.json'), 'w') as fp: |
252 | json.dump(create_map, fp) | 284 | json.dump(create_map, fp) |
253 | 285 | ||
254 | |||
255 | # print('top text find:') | 286 | # print('top text find:') |
256 | # for i in top_text_idx_set: | 287 | # for i in top_text_idx_set: |
257 | # _, text = go_res_list[i] | 288 | # _, text = go_res_list[i] |
... | @@ -269,7 +300,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save | ... | @@ -269,7 +300,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save |
269 | if __name__ == '__main__': | 300 | if __name__ == '__main__': |
270 | base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' | 301 | base_dir = '/Users/zhouweiqi/Downloads/gcfp/data' |
271 | go_dir = os.path.join(base_dir, 'go_res') | 302 | go_dir = os.path.join(base_dir, 'go_res') |
272 | dataset_save_dir = os.path.join(base_dir, 'dataset160x14') | 303 | dataset_save_dir = os.path.join(base_dir, 'dataset160x14-pro') |
273 | label_dir = os.path.join(base_dir, 'labeled') | 304 | label_dir = os.path.join(base_dir, 'labeled') |
274 | 305 | ||
275 | train_go_path = os.path.join(go_dir, 'train') | 306 | train_go_path = os.path.join(go_dir, 'train') |
... | @@ -329,23 +360,23 @@ if __name__ == '__main__': | ... | @@ -329,23 +360,23 @@ if __name__ == '__main__': |
329 | ] | 360 | ] |
330 | 361 | ||
331 | skip_list_train = [ | 362 | skip_list_train = [ |
332 | 'CH-B101910792-page-12.jpg', | 363 | # 'CH-B101910792-page-12.jpg', |
333 | 'CH-B101655312-page-13.jpg', | 364 | # 'CH-B101655312-page-13.jpg', |
334 | 'CH-B102278656.jpg', | 365 | # 'CH-B102278656.jpg', |
335 | 'CH-B101846620_page_1_img_0.jpg', | 366 | # 'CH-B101846620_page_1_img_0.jpg', |
336 | 'CH-B103062528-0.jpg', | 367 | # 'CH-B103062528-0.jpg', |
337 | 'CH-B102613120-3.jpg', | 368 | # 'CH-B102613120-3.jpg', |
338 | 'CH-B102997980-3.jpg', | 369 | # 'CH-B102997980-3.jpg', |
339 | 'CH-B102680060-3.jpg', | 370 | # 'CH-B102680060-3.jpg', |
340 | # 'CH-B102995500-2.jpg', # 没value | 371 | # # 'CH-B102995500-2.jpg', # 没value |
341 | ] | 372 | ] |
342 | 373 | ||
343 | skip_list_valid = [ | 374 | skip_list_valid = [ |
344 | 'CH-B102897920-2.jpg', | 375 | # 'CH-B102897920-2.jpg', |
345 | 'CH-B102551284-0.jpg', | 376 | # 'CH-B102551284-0.jpg', |
346 | 'CH-B102879376-2.jpg', | 377 | # 'CH-B102879376-2.jpg', |
347 | 'CH-B101509488-page-16.jpg', | 378 | # 'CH-B101509488-page-16.jpg', |
348 | 'CH-B102708352-2.jpg', | 379 | # 'CH-B102708352-2.jpg', |
349 | ] | 380 | ] |
350 | 381 | ||
351 | build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir) | 382 | build_dataset(train_image_path, train_go_path, train_label_path, filter_from_top_text_list, skip_list_train, train_dataset_dir) | ... | ... |
... | @@ -219,11 +219,11 @@ class SLSolver(object): | ... | @@ -219,11 +219,11 @@ class SLSolver(object): |
219 | map_key_value = 'find_value' | 219 | map_key_value = 'find_value' |
220 | group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] | 220 | group_cn_list = ['其他', '开票日期', '发票代码', '机打号码', '车辆类型', '电话', '发动机号码', '车架号', '帐号', '开户银行', '小写'] |
221 | skip_list_valid = [ | 221 | skip_list_valid = [ |
222 | 'CH-B102897920-2.jpg', | 222 | # 'CH-B102897920-2.jpg', |
223 | 'CH-B102551284-0.jpg', | 223 | # 'CH-B102551284-0.jpg', |
224 | 'CH-B102879376-2.jpg', | 224 | # 'CH-B102879376-2.jpg', |
225 | 'CH-B101509488-page-16.jpg', | 225 | # 'CH-B101509488-page-16.jpg', |
226 | 'CH-B102708352-2.jpg', | 226 | # 'CH-B102708352-2.jpg', |
227 | ] | 227 | ] |
228 | 228 | ||
229 | dataset_base_dir = os.path.dirname(self.val_map_path) | 229 | dataset_base_dir = os.path.dirname(self.val_map_path) | ... | ... |
-
Please register or sign in to post a comment