Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
092baca7
authored
2022-12-21 18:08:20 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
fix bug
1 parent
60c39554
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
7 deletions
solver/sl_solver.py
solver/sl_solver.py
View file @
092baca
...
...
@@ -206,10 +206,25 @@ class SLSolver(object):
print
(
'Warn: val_map_path not exists: {0}'
.
format
(
self
.
val_map_path
))
return
if
isinstance
(
self
.
model_path
,
str
)
and
os
.
path
.
exists
(
self
.
model_path
):
self
.
model
.
load_state_dict
(
torch
.
load
(
self
.
model_path
))
self
.
logger
.
info
(
f
'==> Load Model from {self.model_path}'
)
else
:
return
self
.
model
.
eval
()
map_key_input
=
'x_y_valid_lens'
map_key_text
=
'find_top_text'
map_key_value
=
'find_value'
group_cn_list
=
[
'其他'
,
'开票日期'
,
'发票代码'
,
'机打号码'
,
'车辆类型'
,
'电话'
,
'发动机号码'
,
'车架号'
,
'帐号'
,
'开户银行'
,
'小写'
]
skip_list_valid
=
[
'CH-B102897920-2.jpg'
,
'CH-B102551284-0.jpg'
,
'CH-B102879376-2.jpg'
,
'CH-B101509488-page-16.jpg'
,
'CH-B102708352-2.jpg'
,
]
dataset_base_dir
=
os
.
path
.
dirname
(
self
.
val_map_path
)
val_dataset_dir
=
os
.
path
.
join
(
dataset_base_dir
,
'valid'
)
...
...
@@ -217,12 +232,13 @@ class SLSolver(object):
if
not
os
.
path
.
isdir
(
save_dir
):
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
self
.
model
.
eval
()
with
open
(
self
.
val_map_path
,
'r'
)
as
fp
:
val_map
=
json
.
load
(
fp
)
for
img_name
in
sorted
(
os
.
listdir
(
self
.
val_image_path
)):
if
img_name
in
skip_list_valid
:
continue
print
(
'Info: start {0}'
.
format
(
img_name
))
image_path
=
os
.
path
.
join
(
self
.
val_image_path
,
img_name
)
...
...
@@ -232,11 +248,11 @@ class SLSolver(object):
draw
=
ImageDraw
.
Draw
(
img_pil
)
if
im_h
<
im_w
:
size
=
int
(
im_h
*
0.01
5
)
size
=
int
(
im_h
*
0.01
0
)
else
:
size
=
int
(
im_w
*
0.01
5
)
if
size
<
1
4
:
size
=
1
4
size
=
int
(
im_w
*
0.01
0
)
if
size
<
1
0
:
size
=
1
0
font
=
ImageFont
.
truetype
(
self
.
draw_font_path
,
size
,
encoding
=
'utf-8'
)
green_color
=
(
0
,
255
,
0
)
...
...
@@ -253,7 +269,7 @@ class SLSolver(object):
X
=
torch
.
tensor
(
input_list
)
.
unsqueeze
(
0
)
.
to
(
self
.
device
)
y_true
=
torch
.
tensor
(
label_list
)
.
unsqueeze
(
0
)
.
float
()
.
to
(
self
.
device
)
valid_lens
=
torch
.
tenor
([
valid_lens_scalar
,
])
.
to
(
self
.
device
)
valid_lens
=
torch
.
ten
s
or
([
valid_lens_scalar
,
])
.
to
(
self
.
device
)
del
input_list
del
label_list
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment