Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
fb66a889
authored
2022-12-22 17:15:57 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
add iou
1 parent
092baca7
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
52 deletions
data/create_dataset2.py
solver/sl_solver.py
data/create_dataset2.py
View file @
fb66a88
...
...
@@ -6,9 +6,49 @@ import uuid
import
cv2
import
pandas
as
pd
import
numpy
as
np
from
shapely.geometry
import
Polygon
,
MultiPoint
from
tools
import
get_file_paths
,
load_json
from
word2vec
import
jwq_word2vec
,
simple_word2vec
def bbox_iou(go_bbox, label_bbox, mode='iou'):
    """Return the share of the ``go_bbox`` area that ``label_bbox`` overlaps.

    Each box is replaced by the convex hull of its points before measuring.
    Despite the function name, the value returned is
    ``intersection_area / go_area`` and not intersection over union.

    Args:
        go_bbox: Point sequence accepted by ``shapely.geometry.Polygon``.
        label_bbox: Point sequence accepted by ``shapely.geometry.Polygon``.
        mode: Currently ignored. The union-based variants that used it
            ('iou', 'tiou', 'giou', 'r_giou') are disabled, so every mode
            yields the same result.

    Returns:
        0 when either hull is invalid or has zero area; otherwise the
        intersection area divided by the area of the ``go_bbox`` hull.
    """
    # The convex hull is the smallest convex shape containing all the points.
    # Per the original author's note, for a quadrilateral the hull's corners
    # come out ordered top-left, bottom-left, bottom-right, top-right and
    # back to top-left (not verified here).
    go_hull = Polygon(go_bbox).convex_hull
    label_hull = Polygon(label_bbox).convex_hull

    if not (go_hull.is_valid and label_hull.is_valid):
        print('formatting errors for boxes!!!! ')
        return 0

    # A degenerate box has no area to overlap, and would divide by zero below.
    if go_hull.area == 0 or label_hull.area == 0:
        return 0

    overlap_area = go_hull.intersection(label_hull).area
    return overlap_area / go_hull.area
def
clean_go_res
(
go_res_dir
):
go_res_json_paths
=
get_file_paths
(
go_res_dir
,
[
'.json'
,
])
...
...
@@ -32,7 +72,6 @@ def clean_go_res(go_res_dir):
json
.
dump
(
go_res_list
,
fp
)
print
(
'Rerewirte {0}'
.
format
(
go_res_json_path
))
def
char_length_statistics
(
go_res_dir
):
max_char_length
=
None
target_file_name
=
None
...
...
@@ -151,40 +190,35 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
for
group_id
in
test_group_id
:
for
item
in
label_res
.
get
(
"shapes"
,
[]):
if
item
.
get
(
"group_id"
)
==
group_id
:
x_list
=
[]
y_list
=
[]
label_bbox
=
list
()
for
point
in
item
[
'points'
]:
x_list
.
append
(
point
[
0
])
y_list
.
append
(
point
[
1
])
group_list
.
append
([
min
(
x_list
)
+
(
max
(
x_list
)
-
min
(
x_list
))
/
2
,
min
(
y_list
)
+
(
max
(
y_list
)
-
min
(
y_list
))
/
2
])
label_bbox
.
extend
(
point
)
group_list
.
append
(
label_bbox
)
break
else
:
group_list
.
append
(
None
)
go_center_list
=
[]
for
(
x0
,
y0
,
x1
,
y1
,
x2
,
y2
,
x3
,
y3
),
_
in
go_res_list
:
xmin
=
min
(
x0
,
x1
,
x2
,
x3
)
ymin
=
min
(
y0
,
y1
,
y2
,
y3
)
xmax
=
max
(
x0
,
x1
,
x2
,
x3
)
ymax
=
max
(
y0
,
y1
,
y2
,
y3
)
xcenter
=
xmin
+
(
xmax
-
xmin
)
/
2
ycenter
=
ymin
+
(
ymax
-
ymin
)
/
2
go_center_list
.
append
((
xcenter
,
ycenter
))
label_idx_dict
=
dict
()
for
label_idx
,
label_center_list
in
enumerate
(
group_list
):
if
isinstance
(
label_center_list
,
list
):
min_go_key
=
None
min_length
=
None
for
go_idx
,
(
go_x_center
,
go_y_center
)
in
enumerate
(
go_center_list
):
for
label_idx
,
label_bbox
in
enumerate
(
group_list
):
if
isinstance
(
label_bbox
,
list
):
for
go_idx
,
(
go_bbox
,
_
)
in
enumerate
(
go_res_list
):
if
go_idx
in
top_text_idx_set
or
go_idx
in
label_idx_dict
:
continue
length
=
abs
(
go_x_center
-
label_center_list
[
0
])
+
abs
(
go_y_center
-
label_center_list
[
1
])
if
min_go_key
is
None
or
length
<
min_length
:
min_go_key
=
go_idx
min_length
=
length
if
min_go_key
is
not
None
:
label_idx_dict
[
min_go_key
]
=
label_idx
go_bbox_rebuild
=
[
[
go_bbox
[
0
],
go_bbox
[
1
]],
[
go_bbox
[
2
],
go_bbox
[
3
]],
[
go_bbox
[
4
],
go_bbox
[
5
]],
[
go_bbox
[
6
],
go_bbox
[
7
]],
]
label_bbox_rebuild
=
[
[
label_bbox
[
0
],
label_bbox
[
1
]],
[
label_bbox
[
2
],
label_bbox
[
1
]],
[
label_bbox
[
2
],
label_bbox
[
3
]],
[
label_bbox
[
0
],
label_bbox
[
3
]],
]
iou
=
bbox_iou
(
go_bbox_rebuild
,
label_bbox_rebuild
)
if
iou
>=
0.5
:
label_idx_dict
[
go_idx
]
=
label_idx
X
=
list
()
y_true
=
list
()
...
...
@@ -239,19 +273,16 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
create_map
[
img_name
]
=
{
'x_y_valid_lens'
:
save_json_name
,
'find_top_text'
:
[
go_res_list
[
i
][
-
1
]
for
i
in
top_text_idx_set
],
'find_value'
:
{
g
roup_cn_list
[
v
]:
go_res_list
[
k
][
-
1
]
for
k
,
v
in
label_idx_dict
.
items
()}
'find_value'
:
{
g
o_res_list
[
k
][
-
1
]:
group_cn_list
[
v
]
for
k
,
v
in
label_idx_dict
.
items
()}
}
# break
# print(create_map)
# print(is_create_map)
if
create_map
:
# print(create_map)
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
save_dir
),
'create_map.json'
),
'w'
)
as
fp
:
json
.
dump
(
create_map
,
fp
)
# print('top text find:')
# for i in top_text_idx_set:
# _, text = go_res_list[i]
...
...
@@ -269,7 +300,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
if
__name__
==
'__main__'
:
base_dir
=
'/Users/zhouweiqi/Downloads/gcfp/data'
go_dir
=
os
.
path
.
join
(
base_dir
,
'go_res'
)
dataset_save_dir
=
os
.
path
.
join
(
base_dir
,
'dataset160x14'
)
dataset_save_dir
=
os
.
path
.
join
(
base_dir
,
'dataset160x14
-pro
'
)
label_dir
=
os
.
path
.
join
(
base_dir
,
'labeled'
)
train_go_path
=
os
.
path
.
join
(
go_dir
,
'train'
)
...
...
@@ -329,23 +360,23 @@ if __name__ == '__main__':
]
skip_list_train
=
[
'CH-B101910792-page-12.jpg'
,
'CH-B101655312-page-13.jpg'
,
'CH-B102278656.jpg'
,
'CH-B101846620_page_1_img_0.jpg'
,
'CH-B103062528-0.jpg'
,
'CH-B102613120-3.jpg'
,
'CH-B102997980-3.jpg'
,
'CH-B102680060-3.jpg'
,
# 'CH-B102995500-2.jpg', # 没value
#
'CH-B101910792-page-12.jpg',
#
'CH-B101655312-page-13.jpg',
#
'CH-B102278656.jpg',
#
'CH-B101846620_page_1_img_0.jpg',
#
'CH-B103062528-0.jpg',
#
'CH-B102613120-3.jpg',
#
'CH-B102997980-3.jpg',
#
'CH-B102680060-3.jpg',
#
#
'CH-B102995500-2.jpg', # 没value
]
skip_list_valid
=
[
'CH-B102897920-2.jpg'
,
'CH-B102551284-0.jpg'
,
'CH-B102879376-2.jpg'
,
'CH-B101509488-page-16.jpg'
,
'CH-B102708352-2.jpg'
,
#
'CH-B102897920-2.jpg',
#
'CH-B102551284-0.jpg',
#
'CH-B102879376-2.jpg',
#
'CH-B101509488-page-16.jpg',
#
'CH-B102708352-2.jpg',
]
build_dataset
(
train_image_path
,
train_go_path
,
train_label_path
,
filter_from_top_text_list
,
skip_list_train
,
train_dataset_dir
)
...
...
solver/sl_solver.py
View file @
fb66a88
...
...
@@ -219,11 +219,11 @@ class SLSolver(object):
map_key_value
=
'find_value'
group_cn_list
=
[
'其他'
,
'开票日期'
,
'发票代码'
,
'机打号码'
,
'车辆类型'
,
'电话'
,
'发动机号码'
,
'车架号'
,
'帐号'
,
'开户银行'
,
'小写'
]
skip_list_valid
=
[
'CH-B102897920-2.jpg'
,
'CH-B102551284-0.jpg'
,
'CH-B102879376-2.jpg'
,
'CH-B101509488-page-16.jpg'
,
'CH-B102708352-2.jpg'
,
#
'CH-B102897920-2.jpg',
#
'CH-B102551284-0.jpg',
#
'CH-B102879376-2.jpg',
#
'CH-B101509488-page-16.jpg',
#
'CH-B102708352-2.jpg',
]
dataset_base_dir
=
os
.
path
.
dirname
(
self
.
val_map_path
)
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment