Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
test_on_pytorch
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
331c7717
authored
2022-12-23 16:51:58 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
add fix pred text
1 parent
e573fab0
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
80 additions
and
7 deletions
data/create_dataset2.py
solver/sl_solver.py
utils/__init__.py
utils/fix_pred.py
data/create_dataset2.py
View file @
331c771
...
...
@@ -239,7 +239,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
[
label_bbox
[
0
],
label_bbox
[
3
]],
]
iou
=
bbox_iou
(
go_bbox_rebuild
,
label_bbox_rebuild
)
if
iou
>=
0.
5
:
if
iou
>=
0.
4
:
label_idx_dict
[
go_idx
]
=
label_idx
X
=
list
()
...
...
@@ -250,7 +250,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
# text_vec_max_lens = 15 * 50
# dim = 1 + 5 + 8 + text_vec_max_lens
max_jieba_char
=
8
max_jieba_char
=
4
text_vec_max_lens
=
max_jieba_char
*
100
dim
=
1
+
5
+
8
+
text_vec_max_lens
...
...
@@ -333,7 +333,7 @@ def build_dataset(img_dir, go_res_dir, label_dir, top_text_list, skip_list, save
if
__name__
==
'__main__'
:
base_dir
=
'/Users/zhouweiqi/Downloads/gcfp/data'
go_dir
=
os
.
path
.
join
(
base_dir
,
'go_res'
)
dataset_save_dir
=
os
.
path
.
join
(
base_dir
,
'dataset160x
8
14'
)
dataset_save_dir
=
os
.
path
.
join
(
base_dir
,
'dataset160x
4
14'
)
label_dir
=
os
.
path
.
join
(
base_dir
,
'labeled'
)
train_go_path
=
os
.
path
.
join
(
go_dir
,
'train'
)
...
...
solver/sl_solver.py
View file @
331c771
...
...
@@ -10,8 +10,7 @@ from data import build_dataloader
from
loss
import
build_loss
from
model
import
build_model
from
optimizer
import
build_lr_scheduler
,
build_optimizer
from
utils
import
SOLVER_REGISTRY
,
get_logger_and_log_dir
from
utils
import
sequence_mask
from
utils
import
SOLVER_REGISTRY
,
get_logger_and_log_dir
,
sequence_mask
,
fix_text_obj
from
sklearn.metrics
import
confusion_matrix
,
accuracy_score
,
classification_report
...
...
@@ -223,6 +222,18 @@ class SLSolver(object):
map_key_text
=
'find_top_text'
map_key_value
=
'find_value'
test_group_id
=
[
1
,
2
,
5
,
9
,
20
,
15
,
16
,
22
,
24
,
28
]
fix_pred_methods
=
[
(
'only_date'
,
{}),
(
'only_digit'
,
{}),
(
'do_nothing'
,
{}),
(
'do_nothing'
,
{}),
(
'remove_start'
,
{
'start_char'
:
'电话'
}),
(
'only_digit_alpha'
,
{}),
(
'do_nothing'
,
{}),
(
'remove_start'
,
{
'start_char'
:
'账号'
}),
(
'remove_bank'
,
{}),
(
'only_amount'
,
{}),
]
group_cn_list
=
[
'其他'
,
'开票日期'
,
'发票代码'
,
'机打号码'
,
'车辆类型'
,
'电话'
,
'发动机号码'
,
'车架号'
,
'帐号'
,
'开户银行'
,
'小写'
]
skip_list_valid
=
[
# 'CH-B102897920-2.jpg',
...
...
@@ -338,12 +349,17 @@ class SLSolver(object):
group_text_list
.
append
(
None
)
for
idx
,
text
in
enumerate
(
group_text_list
):
if
'#'
in
text
:
continue
key_cn
=
group_cn_list
[
idx
+
1
]
pred_idx_list
=
bbox_text_dict
.
get
(
idx
)
if
isinstance
(
pred_idx_list
,
list
):
pred_text_list
=
[
go_res_list
[
idx
][
-
1
]
for
idx
in
pred_idx_list
]
pred_text
=
' '
.
join
(
pred_text_list
)
pred_text_src
=
''
.
join
(
pred_text_list
)
# pred_text = pred_text_src
pred_text
=
getattr
(
fix_text_obj
,
fix_pred_methods
[
idx
][
0
])(
pred_text_src
,
**
fix_pred_methods
[
idx
][
1
])
else
:
pred_text
=
None
...
...
@@ -356,7 +372,7 @@ class SLSolver(object):
# break
for
key_cn
,
(
correct_count
,
all_count
)
in
data_dict
.
items
():
print
(
'{0}: {1}'
.
format
(
key_cn
,
round
(
correct_count
/
all_count
,
2
)))
print
(
'{0}: {1}'
.
format
(
key_cn
,
round
(
correct_count
/
all_count
,
4
)))
print
(
'==========================='
)
...
...
utils/__init__.py
View file @
331c771
import
torch
from
.registery
import
*
from
.logger
import
get_logger_and_log_dir
from
.fix_pred
import
fix_text_obj
__all__
=
[
'Registry'
,
...
...
utils/fix_pred.py
0 → 100644
View file @
331c771
import
re
class FixText:
    """Collection of text clean-up rules applied to predicted field text.

    Every rule is a static method that takes the raw predicted string as its
    first argument and returns a string. Rules that search for a pattern
    return the input unchanged when the pattern is not found, so a rule
    never raises merely because the text has an unexpected layout.
    """

    @staticmethod
    def do_nothing(pred_text_src: str) -> str:
        """Return ``pred_text_src`` unchanged."""
        return pred_text_src

    @staticmethod
    def only_date(pred_text_src: str) -> str:
        """Return the text from the first ``20`` up to the end of the line.

        Falls back to the unchanged input when ``20`` does not occur.
        """
        re_se = re.search(r'20.*', pred_text_src)
        if re_se:
            return re_se.group()
        else:
            return pred_text_src

    @staticmethod
    def only_digit(pred_text_src: str) -> str:
        """Return the first run of consecutive digits.

        Falls back to the unchanged input when it contains no digit.
        """
        re_se = re.search(r'\d+', pred_text_src)
        if re_se:
            return re_se.group()
        else:
            return pred_text_src

    @staticmethod
    def remove_start(pred_text_src: str, start_char: str = '电话') -> str:
        """Strip ``start_char`` from the beginning of the text, if present.

        Only the leading occurrence is removed; any later occurrence of
        ``start_char`` inside the text is preserved.

        Args:
            pred_text_src: Raw predicted text.
            start_char: Prefix to remove when the text starts with it.

        Returns:
            The text without the leading ``start_char``, or the unchanged
            input when it does not start with ``start_char``.
        """
        if pred_text_src.startswith(start_char):
            # Slice rather than str.replace(): replace() would also delete
            # every later occurrence of the prefix inside the value.
            return pred_text_src[len(start_char):]
        else:
            return pred_text_src

    @staticmethod
    def only_digit_alpha(pred_text_src: str) -> str:
        """Return the first run of ``\\w`` characters.

        Falls back to the unchanged input when there is no such run.
        """
        # NOTE(review): for str patterns ``\w`` is Unicode-aware, so it also
        # matches CJK characters and "_" -- confirm whether an ASCII-only
        # class such as [0-9A-Za-z] was intended here.
        re_se = re.search(r'\w+', pred_text_src)
        if re_se:
            return re_se.group()
        else:
            return pred_text_src

    @staticmethod
    def remove_bank(pred_text_src: str) -> str:
        """Return whatever follows the first ``户银行`` on the same line.

        Falls back to the unchanged input when ``户银行`` does not occur.
        """
        re_se = re.search(r'户银行(.*)', pred_text_src)
        if re_se:
            return re_se.group(1)
        else:
            return pred_text_src

    @staticmethod
    def only_amount(pred_text_src: str) -> str:
        """Return the first ``digits<sep>digits`` match with ``.`` as separator.

        ``<sep>`` may be ``-``, ``,`` or ``.``; a ``-`` or ``,`` in the
        matched text is rewritten to ``.``. Falls back to the unchanged
        input when no such pattern is found.
        """
        re_se = re.search(r'\d+[-,\.]\d+', pred_text_src)
        if re_se:
            return re_se.group().replace('-', '.').replace(',', '.')
        else:
            return pred_text_src


# Shared module-level instance; rules can be looked up on it by name.
fix_text_obj = FixText()
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment