Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
41252450
authored
2022-12-22 18:29:08 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
add statistics
1 parent
fb66a889
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
3 deletions
config/sl.yaml
solver/sl_solver.py
config/sl.yaml
View file @
4125245
...
...
@@ -37,6 +37,7 @@ solver:
base_on
:
null
model_path
:
null
val_image_path
:
'
/labeled/valid/image'
val_label_path
:
'
/labeled/valid/label'
val_go_path
:
'
/go_res/valid'
val_map_path
:
'
/dataset160x14/create_map.json'
draw_font_path
:
'
/dataset160x14/STZHONGS.TTF'
...
...
solver/sl_solver.py
View file @
4125245
...
...
@@ -38,6 +38,7 @@ class SLSolver(object):
self
.
base_on
=
self
.
hyper_params
[
'base_on'
]
self
.
model_path
=
self
.
hyper_params
[
'model_path'
]
self
.
val_image_path
=
self
.
hyper_params
[
'val_image_path'
]
self
.
val_label_path
=
self
.
hyper_params
[
'val_label_path'
]
self
.
val_go_path
=
self
.
hyper_params
[
'val_go_path'
]
self
.
val_map_path
=
self
.
hyper_params
[
'val_map_path'
]
self
.
draw_font_path
=
self
.
hyper_params
[
'draw_font_path'
]
...
...
@@ -198,6 +199,10 @@ class SLSolver(object):
print
(
'Warn: val_image_path not exists: {0}'
.
format
(
self
.
val_image_path
))
return
if
not
os
.
path
.
isdir
(
self
.
val_label_path
):
print
(
'Warn: val_label_path not exists: {0}'
.
format
(
self
.
val_label_path
))
return
if
not
os
.
path
.
isdir
(
self
.
val_go_path
):
print
(
'Warn: val_go_path not exists: {0}'
.
format
(
self
.
val_go_path
))
return
...
...
@@ -217,6 +222,7 @@ class SLSolver(object):
map_key_input
=
'x_y_valid_lens'
map_key_text
=
'find_top_text'
map_key_value
=
'find_value'
test_group_id
=
[
1
,
2
,
5
,
9
,
20
,
15
,
16
,
22
,
24
,
28
]
group_cn_list
=
[
'其他'
,
'开票日期'
,
'发票代码'
,
'机打号码'
,
'车辆类型'
,
'电话'
,
'发动机号码'
,
'车架号'
,
'帐号'
,
'开户银行'
,
'小写'
]
skip_list_valid
=
[
# 'CH-B102897920-2.jpg',
...
...
@@ -235,6 +241,8 @@ class SLSolver(object):
with
open
(
self
.
val_map_path
,
'r'
)
as
fp
:
val_map
=
json
.
load
(
fp
)
data_dict
=
{
key_cn
:
[
0
,
0
]
for
key_cn
in
group_cn_list
[
1
:]}
failed_dict
=
dict
()
for
img_name
in
sorted
(
os
.
listdir
(
self
.
val_image_path
)):
if
img_name
in
skip_list_valid
:
continue
...
...
@@ -281,7 +289,11 @@ class SLSolver(object):
correct
=
0
bbox_draw_dict
=
dict
()
bbox_text_dict
=
dict
()
for
i
in
range
(
valid_lens_scalar
):
if
pred
[
i
]
!=
0
:
bbox_text_dict
.
setdefault
(
test_group_id
[
pred
[
i
]
-
1
],
list
())
.
append
(
i
)
if
pred
[
i
]
==
label
[
i
]:
correct
+=
1
if
pred
[
i
]
!=
0
:
...
...
@@ -311,8 +323,46 @@ class SLSolver(object):
img_pil
.
save
(
os
.
path
.
join
(
save_dir
,
img_name
))
# break
# 统计准确率
label_json_path
=
os
.
path
.
join
(
self
.
val_label_path
,
'{0}.json'
.
format
(
base_image_name
))
with
open
(
label_json_path
,
'r'
)
as
fp
:
label_res
=
json
.
load
(
fp
)
group_text_list
=
[]
for
group_id
in
test_group_id
:
for
item
in
label_res
.
get
(
"shapes"
,
[]):
if
item
.
get
(
"group_id"
)
==
group_id
:
group_text_list
.
append
(
item
[
'label'
])
break
else
:
group_text_list
.
append
(
None
)
for
idx
,
text
in
enumerate
(
group_text_list
):
key_cn
=
group_cn_list
[
idx
+
1
]
pred_idx_list
=
bbox_text_dict
.
get
(
idx
)
if
isinstance
(
pred_idx_list
,
list
):
pred_text_list
=
[
go_res_list
[
idx
][
-
1
]
for
idx
in
pred_idx_list
]
pred_text
=
' '
.
join
(
pred_text_list
)
else
:
pred_text
=
None
data_dict
[
key_cn
][
-
1
]
+=
1
if
pred_text
==
text
:
data_dict
[
key_cn
][
0
]
+=
1
else
:
failed_dict
.
setdefault
(
key_cn
,
list
())
.
append
((
text
,
pred_text
))
# break
for
key_cn
,
(
correct_count
,
all_count
)
in
data_dict
.
ietms
():
print
(
'{0}: {1}'
.
format
(
key_cn
,
round
(
correct_count
/
all_count
,
2
)))
print
(
'==========================='
)
for
key_cn
,
failed_list
in
failed_dict
.
items
():
print
(
key_cn
)
for
text
,
pred_text
in
failed_list
:
print
(
'label: {0} pred: {1}'
.
format
(
text
,
pred_text
))
print
(
'----------------------------------'
)
\ No newline at end of file
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment