Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
60c39554
authored
2022-12-21 16:46:43 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
add drwa
1 parent
890ea78a
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
169 additions
and
32 deletions
config/sl.yaml
data/create_dataset2.py
draw.sh
eval.sh
main.py
model/seq_labeling.py
solver/sl_solver.py
train.sh
config/sl.yaml
View file @
60c3955
...
...
@@ -3,9 +3,9 @@ seed: 3407
dataset
:
name
:
'
SLData'
args
:
data_root
:
'
/
Users/zhouweiqi/Downloads/gcfp/data/dataset2
'
train_anno_file
:
'
/
Users/zhouweiqi/Downloads/gcfp/data/dataset2
/train.csv'
val_anno_file
:
'
/
Users/zhouweiqi/Downloads/gcfp/data/dataset2
/valid.csv'
data_root
:
'
/
dataset160x14
'
train_anno_file
:
'
/
dataset160x14
/train.csv'
val_anno_file
:
'
/
dataset160x14
/valid.csv'
dataloader
:
batch_size
:
8
...
...
@@ -18,7 +18,7 @@ model:
args
:
seq_lens
:
160
num_classes
:
10
embed_dim
:
9
embed_dim
:
14
depth
:
6
num_heads
:
1
mlp_ratio
:
4.0
...
...
@@ -36,6 +36,11 @@ solver:
epoch
:
100
base_on
:
null
model_path
:
null
val_image_path
:
'
/labeled/valid/image'
val_go_path
:
'
/go_res/valid'
val_map_path
:
'
/dataset160x14/create_map.json'
draw_font_path
:
'
/dataset160x14/STZHONGS.TTF'
thresholds
:
0.5
optimizer
:
name
:
'
Adam'
...
...
@@ -58,5 +63,5 @@ solver:
alpha
:
0.8
logger
:
log_root
:
'
/
Users/zhouweiqi/Downloads/test/
logs'
log_root
:
'
/logs'
suffix
:
'
sl-6-1'
\ No newline at end of file
...
...
data/create_dataset2.py
View file @
60c3955
...
...
@@ -7,7 +7,7 @@ import uuid
import
cv2
import
pandas
as
pd
from
tools
import
get_file_paths
,
load_json
from
word2vec
import
simple_word2vec
,
jwq_word2vec
from
word2vec
import
jwq_word2vec
,
simple_word2vec
def
clean_go_res
(
go_res_dir
):
...
...
@@ -101,7 +101,7 @@ def build_anno_file(dataset_dir, anno_file_path):
df
[
'name'
]
=
img_list
df
.
to_csv
(
anno_file_path
)
def
build_dataset
(
img_dir
,
go_res_dir
,
label_dir
,
top_text_list
,
skip_list
,
save_dir
):
def
build_dataset
(
img_dir
,
go_res_dir
,
label_dir
,
top_text_list
,
skip_list
,
save_dir
,
is_create_map
=
False
):
"""
Args:
img_dir: str 图片目录
...
...
@@ -121,6 +121,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
group_cn_list
=
[
'开票日期'
,
'发票代码'
,
'机打号码'
,
'车辆类型'
,
'电话'
,
'发动机号码'
,
'车架号'
,
'帐号'
,
'开户银行'
,
'小写'
]
test_group_id
=
[
1
,
2
,
5
,
9
,
20
,
15
,
16
,
22
,
24
,
28
]
create_map
=
{}
for
img_name
in
sorted
(
os
.
listdir
(
img_dir
)):
if
img_name
in
skip_list
:
print
(
'Info: skip {0}'
.
format
(
img_name
))
...
...
@@ -188,8 +189,9 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
X
=
list
()
y_true
=
list
()
text_vec_max_lens
=
15
*
50
dim
=
1
+
5
+
8
+
text_vec_max_lens
# text_vec_max_lens = 15 * 50
# dim = 1 + 5 + 8 + text_vec_max_lens
dim
=
1
+
5
+
8
num_classes
=
10
for
i
in
range
(
160
):
if
i
>=
valid_lens
:
...
...
@@ -201,7 +203,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
feature_vec
=
[
1.
]
feature_vec
.
extend
(
simple_word2vec
(
text
))
feature_vec
.
extend
([
x0
/
w
,
y0
/
h
,
x1
/
w
,
y1
/
h
,
x2
/
w
,
y2
/
h
,
x3
/
w
,
y3
/
h
])
feature_vec
.
extend
(
jwq_word2vec
(
text
,
text_vec_max_lens
))
#
feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
X
.
append
(
feature_vec
)
y_true
.
append
([
0
for
_
in
range
(
num_classes
)])
...
...
@@ -211,7 +213,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
feature_vec
=
[
0.
]
feature_vec
.
extend
(
simple_word2vec
(
text
))
feature_vec
.
extend
([
x0
/
w
,
y0
/
h
,
x1
/
w
,
y1
/
h
,
x2
/
w
,
y2
/
h
,
x3
/
w
,
y3
/
h
])
feature_vec
.
extend
(
jwq_word2vec
(
text
,
text_vec_max_lens
))
#
feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
X
.
append
(
feature_vec
)
base_label_list
=
[
0
for
_
in
range
(
num_classes
)]
...
...
@@ -222,16 +224,34 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
feature_vec
=
[
0.
]
feature_vec
.
extend
(
simple_word2vec
(
text
))
feature_vec
.
extend
([
x0
/
w
,
y0
/
h
,
x1
/
w
,
y1
/
h
,
x2
/
w
,
y2
/
h
,
x3
/
w
,
y3
/
h
])
feature_vec
.
extend
(
jwq_word2vec
(
text
,
text_vec_max_lens
))
#
feature_vec.extend(jwq_word2vec(text, text_vec_max_lens))
X
.
append
(
feature_vec
)
y_true
.
append
([
0
for
_
in
range
(
num_classes
)])
all_data
=
[
X
,
y_true
,
valid_lens
]
with
open
(
os
.
path
.
join
(
save_dir
,
'{0}.json'
.
format
(
uuid
.
uuid3
(
uuid
.
NAMESPACE_DNS
,
img_name
))),
'w'
)
as
fp
:
save_json_name
=
'{0}.json'
.
format
(
uuid
.
uuid3
(
uuid
.
NAMESPACE_DNS
,
img_name
))
with
open
(
os
.
path
.
join
(
save_dir
,
save_json_name
),
'w'
)
as
fp
:
json
.
dump
(
all_data
,
fp
)
if
is_create_map
:
create_map
[
img_name
]
=
{
'x_y_valid_lens'
:
save_json_name
,
'find_top_text'
:
[
go_res_list
[
i
][
-
1
]
for
i
in
top_text_idx_set
],
'find_value'
:
{
group_cn_list
[
v
]:
go_res_list
[
k
][
-
1
]
for
k
,
v
in
label_idx_dict
.
items
()}
}
# break
# print(create_map)
# print(is_create_map)
if
create_map
:
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
save_dir
),
'create_map.json'
),
'w'
)
as
fp
:
json
.
dump
(
create_map
,
fp
)
# print('top text find:')
# for i in top_text_idx_set:
# _, text = go_res_list[i]
...
...
@@ -249,7 +269,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
if
__name__
==
'__main__'
:
base_dir
=
'/Users/zhouweiqi/Downloads/gcfp/data'
go_dir
=
os
.
path
.
join
(
base_dir
,
'go_res'
)
dataset_save_dir
=
os
.
path
.
join
(
base_dir
,
'dataset
2
'
)
dataset_save_dir
=
os
.
path
.
join
(
base_dir
,
'dataset
160x14
'
)
label_dir
=
os
.
path
.
join
(
base_dir
,
'labeled'
)
train_go_path
=
os
.
path
.
join
(
go_dir
,
'train'
)
...
...
@@ -331,7 +351,7 @@ if __name__ == '__main__':
build_dataset
(
train_image_path
,
train_go_path
,
train_label_path
,
filter_from_top_text_list
,
skip_list_train
,
train_dataset_dir
)
build_anno_file
(
train_dataset_dir
,
train_anno_file_path
)
build_dataset
(
valid_image_path
,
valid_go_path
,
valid_label_path
,
filter_from_top_text_list
,
skip_list_valid
,
valid_dataset_dir
)
build_dataset
(
valid_image_path
,
valid_go_path
,
valid_label_path
,
filter_from_top_text_list
,
skip_list_valid
,
valid_dataset_dir
,
True
)
build_anno_file
(
valid_dataset_dir
,
valid_anno_file_path
)
# print(simple_word2vec(' fd2jk接口 额24;叁‘,。测ADF壹试!¥? '))
...
...
draw.sh
0 → 100755
View file @
60c3955
CUDA_VISIBLE_DEVICES
=
0 nohup python main.py --config
=
config/sl.yaml -d > draw.log 2>&1 &
\ No newline at end of file
eval.sh
0 → 100755
View file @
60c3955
CUDA_VISIBLE_DEVICES
=
0 nohup python main.py --config
=
config/sl.yaml -e > eval.log 2>&1 &
\ No newline at end of file
main.py
View file @
60c3955
...
...
@@ -8,6 +8,7 @@ def main():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--config'
,
default
=
'./config/mlp.yaml'
,
type
=
str
,
help
=
'config file'
)
parser
.
add_argument
(
'-e'
,
'--eval'
,
action
=
"store_true"
)
parser
.
add_argument
(
'-d'
,
'--draw'
,
action
=
"store_true"
)
args
=
parser
.
parse_args
()
cfg
=
yaml
.
load
(
open
(
args
.
config
,
'r'
)
.
read
(),
Loader
=
yaml
.
FullLoader
)
...
...
@@ -18,6 +19,8 @@ def main():
if
args
.
eval
:
solver
.
evaluate
()
elif
args
.
draw
:
solver
.
draw_val
()
else
:
solver
.
run
()
...
...
model/seq_labeling.py
View file @
60c3955
...
...
@@ -18,7 +18,7 @@ def masked_softmax(X, valid_lens):
# [batch_size, num_heads, seq_len, seq_len]
shape
=
X
.
shape
if
valid_lens
.
dim
()
==
1
:
valid_lens
=
torch
.
repeat_interleave
(
valid_lens
,
shape
[
1
])
valid_lens
=
torch
.
repeat_interleave
(
valid_lens
,
shape
[
2
])
else
:
valid_lens
=
valid_lens
.
reshape
(
-
1
)
# On the last axis, replace masked elements with a very large negative
...
...
solver/sl_solver.py
View file @
60c3955
import
copy
import
os
import
cv2
import
json
import
torch
from
PIL
import
Image
,
ImageDraw
,
ImageFont
from
data
import
build_dataloader
from
loss
import
build_loss
...
...
@@ -34,6 +37,11 @@ class SLSolver(object):
self
.
hyper_params
=
cfg
[
'solver'
][
'args'
]
self
.
base_on
=
self
.
hyper_params
[
'base_on'
]
self
.
model_path
=
self
.
hyper_params
[
'model_path'
]
self
.
val_image_path
=
self
.
hyper_params
[
'val_image_path'
]
self
.
val_go_path
=
self
.
hyper_params
[
'val_go_path'
]
self
.
val_map_path
=
self
.
hyper_params
[
'val_map_path'
]
self
.
draw_font_path
=
self
.
hyper_params
[
'draw_font_path'
]
self
.
thresholds
=
self
.
hyper_params
[
'thresholds'
]
try
:
self
.
epoch
=
self
.
hyper_params
[
'epoch'
]
except
Exception
:
...
...
@@ -41,19 +49,22 @@ class SLSolver(object):
self
.
logger
,
self
.
log_dir
=
get_logger_and_log_dir
(
**
cfg
[
'solver'
][
'logger'
])
def
accuracy
(
self
,
y_pred
,
y_true
,
valid_lens
,
thresholds
=
0.5
):
def
accuracy
(
self
,
y_pred
,
y_true
,
valid_lens
,
eval
=
False
):
# [batch_size, seq_len, num_classes]
y_pred_sigmoid
=
torch
.
nn
.
Sigmoid
()(
y_pred
)
# [batch_size, seq_len]
y_pred_idx
=
torch
.
argmax
(
y_pred_sigmoid
,
dim
=-
1
)
+
1
# [batch_size, seq_len]
y_pred_is_other
=
(
torch
.
amax
(
y_pred_sigmoid
,
dim
=-
1
)
>
thresholds
)
.
int
()
y_pred_is_other
=
(
torch
.
amax
(
y_pred_sigmoid
,
dim
=-
1
)
>
self
.
thresholds
)
.
int
()
y_pred_rebuild
=
torch
.
multiply
(
y_pred_idx
,
y_pred_is_other
)
y_true_idx
=
torch
.
argmax
(
y_true
,
dim
=-
1
)
+
1
y_true_is_other
=
torch
.
sum
(
y_true
,
dim
=-
1
)
.
int
()
y_true_rebuild
=
torch
.
multiply
(
y_true_idx
,
y_true_is_other
)
if
eval
:
return
y_pred_rebuild
,
y_true_rebuild
masked_y_true_rebuild
=
sequence_mask
(
y_true_rebuild
,
valid_lens
,
value
=-
1
)
return
torch
.
sum
((
y_pred_rebuild
==
masked_y_true_rebuild
)
.
int
())
.
item
()
...
...
@@ -168,19 +179,7 @@ class SLSolver(object):
# pred = torch.nn.Sigmoid()(self.model(X))
y_pred
=
self
.
model
(
X
,
valid_lens
)
# [batch_size, seq_len, num_classes]
y_pred_sigmoid
=
torch
.
nn
.
Sigmoid
()(
y_pred
)
# [batch_size, seq_len]
y_pred_idx
=
torch
.
argmax
(
y_pred_sigmoid
,
dim
=-
1
)
+
1
# [batch_size, seq_len]
y_pred_is_other
=
(
torch
.
amax
(
y_pred_sigmoid
,
dim
=-
1
)
>
0.5
)
.
int
()
y_pred_rebuild
=
torch
.
multiply
(
y_pred_idx
,
y_pred_is_other
)
y_true_idx
=
torch
.
argmax
(
y_true
,
dim
=-
1
)
+
1
y_true_is_other
=
torch
.
sum
(
y_true
,
dim
=-
1
)
.
int
()
y_true_rebuild
=
torch
.
multiply
(
y_true_idx
,
y_true_is_other
)
# masked_y_true_rebuild = sequence_mask(y_true_rebuild, valid_lens, value=-1)
y_pred_rebuild
,
y_true_rebuild
=
self
.
accuracy
(
y_pred
,
y_true
,
valid_lens
,
eval
=
True
)
for
idx
,
seq_result
in
enumerate
(
y_true_rebuild
.
cpu
()
.
numpy
()
.
tolist
()):
label_true_list
.
extend
(
seq_result
[:
valid_lens
.
cpu
()
.
numpy
()[
idx
]])
...
...
@@ -193,3 +192,111 @@ class SLSolver(object):
print
(
acc
)
print
(
cm
)
print
(
report
)
def
draw_val
(
self
):
if
not
os
.
path
.
isdir
(
self
.
val_image_path
):
print
(
'Warn: val_image_path not exists: {0}'
.
format
(
self
.
val_image_path
))
return
if
not
os
.
path
.
isdir
(
self
.
val_go_path
):
print
(
'Warn: val_go_path not exists: {0}'
.
format
(
self
.
val_go_path
))
return
if
not
os
.
path
.
isfile
(
self
.
val_map_path
):
print
(
'Warn: val_map_path not exists: {0}'
.
format
(
self
.
val_map_path
))
return
map_key_input
=
'x_y_valid_lens'
map_key_text
=
'find_top_text'
map_key_value
=
'find_value'
group_cn_list
=
[
'其他'
,
'开票日期'
,
'发票代码'
,
'机打号码'
,
'车辆类型'
,
'电话'
,
'发动机号码'
,
'车架号'
,
'帐号'
,
'开户银行'
,
'小写'
]
dataset_base_dir
=
os
.
path
.
dirname
(
self
.
val_map_path
)
val_dataset_dir
=
os
.
path
.
join
(
dataset_base_dir
,
'valid'
)
save_dir
=
os
.
path
.
join
(
dataset_base_dir
,
'draw_val'
)
if
not
os
.
path
.
isdir
(
save_dir
):
os
.
makedirs
(
save_dir
,
exist_ok
=
True
)
self
.
model
.
eval
()
with
open
(
self
.
val_map_path
,
'r'
)
as
fp
:
val_map
=
json
.
load
(
fp
)
for
img_name
in
sorted
(
os
.
listdir
(
self
.
val_image_path
)):
print
(
'Info: start {0}'
.
format
(
img_name
))
image_path
=
os
.
path
.
join
(
self
.
val_image_path
,
img_name
)
img
=
cv2
.
imread
(
image_path
)
im_h
,
im_w
,
_
=
img
.
shape
img_pil
=
Image
.
fromarray
(
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
))
draw
=
ImageDraw
.
Draw
(
img_pil
)
if
im_h
<
im_w
:
size
=
int
(
im_h
*
0.015
)
else
:
size
=
int
(
im_w
*
0.015
)
if
size
<
14
:
size
=
14
font
=
ImageFont
.
truetype
(
self
.
draw_font_path
,
size
,
encoding
=
'utf-8'
)
green_color
=
(
0
,
255
,
0
)
red_color
=
(
255
,
0
,
0
)
blue_color
=
(
0
,
0
,
255
)
base_image_name
,
_
=
os
.
path
.
splitext
(
img_name
)
go_res_json_path
=
os
.
path
.
join
(
self
.
val_go_path
,
'{0}.json'
.
format
(
base_image_name
))
with
open
(
go_res_json_path
,
'r'
)
as
fp
:
go_res_list
=
json
.
load
(
fp
)
with
open
(
os
.
path
.
join
(
val_dataset_dir
,
val_map
[
img_name
][
map_key_input
]),
'r'
)
as
fp
:
input_list
,
label_list
,
valid_lens_scalar
=
json
.
load
(
fp
)
X
=
torch
.
tensor
(
input_list
)
.
unsqueeze
(
0
)
.
to
(
self
.
device
)
y_true
=
torch
.
tensor
(
label_list
)
.
unsqueeze
(
0
)
.
float
()
.
to
(
self
.
device
)
valid_lens
=
torch
.
tenor
([
valid_lens_scalar
,
])
.
to
(
self
.
device
)
del
input_list
del
label_list
y_pred
=
self
.
model
(
X
,
valid_lens
)
y_pred_rebuild
,
y_true_rebuild
=
self
.
accuracy
(
y_pred
,
y_true
,
valid_lens
,
eval
=
True
)
pred
=
y_pred_rebuild
.
cpu
()
.
numpy
()
.
tolist
()[
0
]
label
=
y_true_rebuild
.
cpu
()
.
numpy
()
.
tolist
()[
0
]
correct
=
0
bbox_draw_dict
=
dict
()
for
i
in
range
(
valid_lens_scalar
):
if
pred
[
i
]
==
label
[
i
]:
correct
+=
1
if
pred
[
i
]
!=
0
:
# 绿色
bbox_draw_dict
[
i
]
=
(
group_cn_list
[
pred
[
i
]],
)
else
:
# 红色:左上角label,右上角pred
bbox_draw_dict
[
i
]
=
(
group_cn_list
[
label
[
i
]],
group_cn_list
[
pred
[
i
]])
correct_rate
=
correct
/
valid_lens_scalar
# 画图
for
idx
,
text_tuple
in
bbox_draw_dict
.
items
():
(
x0
,
y0
,
x1
,
y1
,
x2
,
y2
,
x3
,
y3
),
_
=
go_res_list
[
idx
]
line_color
=
green_color
if
len
(
text_tuple
)
==
1
else
red_color
draw
.
polygon
([(
x0
,
y0
),
(
x1
,
y1
),
(
x2
,
y2
),
(
x3
,
y3
)],
outline
=
line_color
)
draw
.
text
((
int
(
x0
),
int
(
y0
)),
text_tuple
[
0
],
green_color
,
font
=
font
)
if
len
(
text_tuple
)
==
2
:
draw
.
text
((
int
(
x1
),
int
(
y1
)),
text_tuple
[
1
],
red_color
,
font
=
font
)
draw
.
text
((
0
,
0
),
str
(
correct_rate
),
blue_color
,
font
=
font
)
last_y
=
size
for
k
,
v
in
val_map
[
img_name
][
map_key_value
]
.
items
():
draw
.
text
((
0
,
last_y
),
'{0}: {1}'
.
format
(
k
,
v
),
blue_color
,
font
=
font
)
last_y
+=
size
img_pil
.
save
(
os
.
path
.
join
(
save_dir
,
img_name
))
# break
...
...
train.sh
100644 → 100755
View file @
60c3955
CUDA_VISIBLE_DEVICES
=
0 nohup python main.py > train.log 2>&1 &
\ No newline at end of file
CUDA_VISIBLE_DEVICES
=
0 nohup python main.py --config
=
config/sl.yaml > train.log 2>&1 &
\ No newline at end of file
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment