Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
part_of_F3_OCR
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
37a9d47e
authored
2022-06-29 16:46:28 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
classification train
1 parent
cbeebc6d
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
109 additions
and
101 deletions
authorization_from/retriever.py
classification/base_class.py
classification/main.py
classification/model.py
authorization_from/retriever.py
View file @
37a9d47
...
...
@@ -10,6 +10,7 @@ class Retriever:
self
.
key_text_set
=
self
.
get_key_text_set
(
target_fields
)
def
get_key_text_set
(
self
,
target_fields
):
# 关键词集合
key_text_set
=
set
()
for
key_text_list
in
target_fields
[
self
.
keys_str
]
.
values
():
for
key_text
,
_
,
_
in
key_text_list
:
...
...
@@ -18,11 +19,13 @@ class Retriever:
@staticmethod
def key_top1(coordinates_list, key_coordinates):
    """Return the entry of ``coordinates_list`` with the smallest index-1 value.

    Keyword search direction: topmost.

    Args:
        coordinates_list: Non-empty sequence of indexable entries. Entries
            are compared by their element at index 1 (presumably the top
            y coordinate of a box -- confirm against callers).
        key_coordinates: Unused here; kept so the call signature is unchanged.

    Returns:
        The first entry whose index-1 value is minimal.

    Raises:
        IndexError: If ``coordinates_list`` is empty.
    """
    if not coordinates_list:
        raise IndexError('coordinates_list is empty')
    # min() yields the first minimal entry -- the same element a stable
    # sort followed by [0] would pick -- without reordering the caller's list.
    return min(coordinates_list, key=lambda entry: entry[1])
@staticmethod
def
key_right
(
coordinates_list
,
key_coordinates
,
top_padding
,
bottom_padding
):
# 关键词查找方向:右侧
if
len
(
coordinates_list
)
==
1
:
return
coordinates_list
[
0
]
height
=
key_coordinates
[
-
1
]
-
key_coordinates
[
1
]
...
...
@@ -41,6 +44,7 @@ class Retriever:
@staticmethod
def
value_right
(
go_res
,
key_coordinates
,
top_padding
,
bottom_padding
):
# 字段值查找方向:右侧
height
=
key_coordinates
[
-
1
]
-
key_coordinates
[
1
]
y_min
=
key_coordinates
[
1
]
-
(
top_padding
*
height
)
y_max
=
key_coordinates
[
-
1
]
+
(
bottom_padding
*
height
)
...
...
@@ -57,6 +61,7 @@ class Retriever:
@staticmethod
def
value_under
(
go_res
,
key_coordinates
,
left_padding
,
right_padding
):
# 字段值查找方向:下方
width
=
key_coordinates
[
2
]
-
key_coordinates
[
0
]
x_min
=
key_coordinates
[
0
]
-
(
width
*
left_padding
)
x_max
=
key_coordinates
[
2
]
+
(
width
*
right_padding
)
...
...
classification/base_class.py
View file @
37a9d47
...
...
@@ -9,7 +9,8 @@ class BaseModel:
"""
raise
NotImplementedError
(
".load() must be overridden."
)
def
train
(
self
,
dataset_dir
,
epoch
,
batch_size
,
ckpt_path
,
train_dir_name
=
'train'
,
validate_dir_name
=
'test'
):
def
train
(
self
,
dataset_dir
,
epoch
,
batch_size
,
ckpt_path
,
history_save_path
,
train_dir_name
=
'train'
,
validate_dir_name
=
'test'
):
"""
Model training process
"""
...
...
classification/main.py
View file @
37a9d47
...
...
@@ -14,10 +14,12 @@ if __name__ == '__main__':
# m.test()
dataset_dir
=
'/home/zwq/data/data_224'
dataset_dir
=
'/home/zwq/data/data_224
_f3
'
ckpt_path
=
os
.
path
.
join
(
base_dir
,
'ckpt_{0}.h5'
.
format
(
datetime
.
now
()
.
strftime
(
'
%
Y-
%
m-
%
d_
%
H:
%
M:
%
S'
)))
history_save_path
=
os
.
path
.
join
(
base_dir
,
'history_{0}.jpg'
.
format
(
datetime
.
now
()
.
strftime
(
'
%
Y-
%
m-
%
d_
%
H:
%
M:
%
S'
)))
epoch
=
100
batch_size
=
128
m
.
train
(
dataset_dir
,
epoch
,
batch_size
,
ckpt_path
,
train_dir_name
=
'train'
,
validate_dir_name
=
'test'
)
m
.
train
(
dataset_dir
,
epoch
,
batch_size
,
ckpt_path
,
history_save_path
,
train_dir_name
=
'train'
,
validate_dir_name
=
'test'
)
...
...
classification/model.py
View file @
37a9d47
...
...
@@ -9,37 +9,6 @@ import matplotlib.pyplot as plt
from
base_class
import
BaseModel
@tf.function
def
random_rgb_2_bgr
(
image
,
label
):
if
random
.
random
()
>
0.5
:
return
image
,
label
image
=
image
[:,
:,
::
-
1
]
return
image
,
label
@tf.function
def
random_grayscale_expand
(
image
,
label
):
if
random
.
random
()
>
0.1
:
return
image
,
label
image
=
tf
.
image
.
rgb_to_grayscale
(
image
)
image
=
tf
.
image
.
grayscale_to_rgb
(
image
)
return
image
,
label
@tf.function
def
load_image
(
image_path
,
label
):
image
=
tf
.
io
.
read_file
(
image_path
)
image
=
tf
.
image
.
decode_image
(
image
,
channels
=
3
)
return
image
,
label
@tf.function
def
preprocess_input
(
image
,
label
):
image
=
tf
.
image
.
resize
(
image
,
[
224
,
224
])
image
=
applications
.
mobilenet_v2
.
preprocess_input
(
image
)
return
image
,
label
class
F3Classification
(
BaseModel
):
def
__init__
(
self
,
class_name_list
,
class_other_first
,
*
args
,
**
kwargs
):
...
...
@@ -48,6 +17,34 @@ class F3Classification(BaseModel):
self
.
class_label_map
=
self
.
get_class_label_map
(
class_name_list
,
class_other_first
)
@staticmethod
def history_save(history, save_path):
    """Plot accuracy and loss curves from a training history and save the image.

    Two stacked subplots are drawn: training/validation accuracy on top and
    training/validation loss below. The figure is written to ``save_path``
    and then closed so repeated calls do not accumulate open figures.

    Args:
        history: Object whose ``history`` attribute is a mapping containing
            the series 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
        save_path: Destination handed to ``savefig``.

    Raises:
        KeyError: If one of the four series is missing from the history.
    """
    series = history.history
    acc = series['accuracy']
    val_acc = series['val_accuracy']
    loss = series['loss']
    val_loss = series['val_loss']

    figure = plt.figure(figsize=(8, 8))
    try:
        # Top panel: accuracy curves.
        plt.subplot(2, 1, 1)
        plt.plot(acc, label='Training Accuracy')
        plt.plot(val_acc, label='Validation Accuracy')
        plt.legend(loc='lower right')
        plt.ylabel('Accuracy')
        # Keep the automatically chosen lower bound, pin the upper bound to 1.
        plt.ylim([min(plt.ylim()), 1])
        plt.title('Training and Validation Accuracy')

        # Bottom panel: loss curves.
        plt.subplot(2, 1, 2)
        plt.plot(loss, label='Training Loss')
        plt.plot(val_loss, label='Validation Loss')
        plt.legend(loc='upper right')
        plt.ylabel('Cross Entropy')
        plt.ylim([0, 1.0])
        plt.title('Training and Validation Loss')
        plt.xlabel('epoch')

        figure.savefig(save_path)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(figure)
@staticmethod
def get_class_label_map(class_name_list, class_other_first=False):
    """Build a mapping from class name to integer label.

    Args:
        class_name_list: Iterable of class names, in label order.
        class_other_first: When truthy, every label is shifted down by one,
            so the first name maps to -1 and the second to 0.

    Returns:
        dict mapping each name to its (possibly shifted) position.
    """
    shift = 1 if class_other_first else 0
    label_map = {}
    for position, class_name in enumerate(class_name_list):
        label_map[class_name] = position - shift
    return label_map
...
...
@@ -68,21 +65,52 @@ class F3Classification(BaseModel):
label_list
.
append
(
tf
.
one_hot
(
label
,
depth
=
self
.
class_count
))
return
image_path_list
,
label_list
@staticmethod
# @tf.function
def random_rgb_2_bgr(image, label):
    """Randomly reverse the channel order (RGB -> BGR) of ``image``.

    The last axis is reversed with probability 0.2; otherwise the image is
    returned unchanged. ``label`` is passed through untouched.

    The draw uses ``tf.random.uniform`` with ``tf.cond`` instead of Python's
    ``random.random()``. When this function is traced into a graph (as
    ``Dataset.map`` does), a Python-level draw runs only once at trace time,
    so every element would receive the same decision.

    Args:
        image: Image tensor indexed as [:, :, channel].
        label: Any value; returned as-is.

    Returns:
        Tuple ``(image, label)``.
    """
    keep_original = tf.random.uniform([]) > 0.2
    augmented = tf.cond(
        keep_original,
        lambda: image,
        lambda: image[:, :, ::-1],
    )
    return augmented, label
@staticmethod
# @tf.function
def random_grayscale_expand(image, label):
    """Randomly replace ``image`` with its grayscale version, kept as 3 channels.

    With probability 0.1 the image is converted to grayscale and expanded
    back to three channels; otherwise it is returned unchanged. ``label`` is
    passed through untouched.

    The draw uses ``tf.random.uniform`` with ``tf.cond`` instead of Python's
    ``random.random()``. When this function is traced into a graph (as
    ``Dataset.map`` does), a Python-level draw runs only once at trace time,
    so every element would receive the same decision.

    Args:
        image: RGB image tensor whose last dimension has size 3.
        label: Any value; returned as-is.

    Returns:
        Tuple ``(image, label)``.
    """

    def _gray_as_rgb():
        # Collapse to one channel, then replicate it back to three.
        gray = tf.image.rgb_to_grayscale(image)
        return tf.image.grayscale_to_rgb(gray)

    keep_original = tf.random.uniform([]) > 0.1
    augmented = tf.cond(keep_original, lambda: image, _gray_as_rgb)
    return augmented, label
@staticmethod
# @tf.function
def load_image(image_path, label):
    """Read the file at ``image_path`` and decode it as a 3-channel image.

    Args:
        image_path: Path handed to ``tf.io.read_file``.
        label: Any value; returned as-is.

    Returns:
        Tuple ``(decoded_image, label)``.
    """
    raw_bytes = tf.io.read_file(image_path)
    # TODO: why does tf.image.decode_image(raw_bytes, channels=3) not work here?
    # NOTE(review): presumably because decode_image yields a tensor without a
    # static shape, which a later tf.image.resize rejects -- confirm.
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    return decoded, label
@staticmethod
# @tf.function
def preprocess_input(image, label):
    """Resize ``image`` to 224x224 and apply MobileNetV2 input preprocessing.

    Args:
        image: Image tensor accepted by ``tf.image.resize``.
        label: Any value; returned as-is.

    Returns:
        Tuple ``(preprocessed_image, label)``.
    """
    resized = tf.image.resize(image, [224, 224])
    prepared = applications.mobilenet_v2.preprocess_input(resized)
    return prepared, label
def
load_dataset
(
self
,
dataset_dir
,
name
,
batch_size
=
128
,
augmentation_methods
=
[]):
image_and_label_list
=
self
.
get_image_label_list
(
dataset_dir
)
tensor_slice_dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
image_and_label_list
,
name
=
name
)
tensor_slice_dataset
.
shuffle
(
len
(
image_and_label_list
[
0
]),
reshuffle_each_iteration
=
True
)
tensor_slice_dataset
.
map
(
load_image
,
num_parallel_calls
=
tf
.
data
.
AUTOTUNE
,
deterministic
=
False
)
dataset
=
tensor_slice_dataset
.
shuffle
(
len
(
image_and_label_list
[
0
]),
reshuffle_each_iteration
=
True
)
dataset
=
dataset
.
map
(
self
.
load_image
,
num_parallel_calls
=
tf
.
data
.
AUTOTUNE
,
deterministic
=
False
)
for
augmentation_method
in
augmentation_methods
:
tensor_slice_dataset
.
map
(
getattr
(
self
,
augmentation_method
),
num_parallel_calls
=
tf
.
data
.
AUTOTUNE
,
deterministic
=
False
)
tensor_slice_dataset
.
map
(
preprocess_input
,
num_parallel_calls
=
tf
.
data
.
AUTOTUNE
,
deterministic
=
False
)
parallel_batch_dataset
=
tensor_slice_
dataset
.
batch
(
dataset
=
dataset
.
map
(
getattr
(
self
,
augmentation_method
)
,
num_parallel_calls
=
tf
.
data
.
AUTOTUNE
,
deterministic
=
False
)
dataset
=
dataset
.
map
(
self
.
preprocess_input
,
num_parallel_calls
=
tf
.
data
.
AUTOTUNE
,
deterministic
=
False
)
parallel_batch_dataset
=
dataset
.
batch
(
batch_size
=
batch_size
,
drop_remainder
=
True
,
num_parallel_calls
=
tf
.
data
.
AUTOTUNE
,
...
...
@@ -113,28 +141,29 @@ class F3Classification(BaseModel):
freeze
=
False
return
model
def
train
(
self
,
dataset_dir
,
epoch
,
batch_size
,
ckpt_path
,
train_dir_name
=
'train'
,
validate_dir_name
=
'test'
):
# model = self.load_model()
# model.summary()
#
# model.compile(
# optimizer=optimizers.Adam(learning_rate=3e-4),
# loss=tfa.losses.SigmoidFocalCrossEntropy(),
# metrics=['accuracy', ],
#
# loss_weights=None,
# weighted_metrics=None,
# run_eagerly=None,
# steps_per_execution=None,
# jit_compile=None,
# )
def
train
(
self
,
dataset_dir
,
epoch
,
batch_size
,
ckpt_path
,
history_save_path
,
train_dir_name
=
'train'
,
validate_dir_name
=
'test'
):
model
=
self
.
load_model
()
model
.
summary
()
model
.
compile
(
optimizer
=
optimizers
.
Adam
(
learning_rate
=
3e-4
),
loss
=
tfa
.
losses
.
SigmoidFocalCrossEntropy
(),
metrics
=
[
'accuracy'
,
],
loss_weights
=
None
,
weighted_metrics
=
None
,
run_eagerly
=
None
,
steps_per_execution
=
None
,
jit_compile
=
None
,
)
train_dataset
=
self
.
load_dataset
(
dataset_dir
=
os
.
path
.
join
(
dataset_dir
,
train_dir_name
),
name
=
train_dir_name
,
batch_size
=
batch_size
,
augmentation_methods
=
[],
#
augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
#
augmentation_methods=[],
augmentation_methods
=
[
'random_rgb_2_bgr'
,
'random_grayscale_expand'
],
)
validate_dataset
=
self
.
load_dataset
(
dataset_dir
=
os
.
path
.
join
(
dataset_dir
,
validate_dir_name
),
...
...
@@ -143,46 +172,17 @@ class F3Classification(BaseModel):
augmentation_methods
=
[]
)
# ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
#
# history = model.fit(
# train_dataset,
# epochs=epoch,
# validation_data=validate_dataset,
# callbacks=[ckpt_callback, ],
# )
#
# acc = history.history['accuracy']
# val_acc = history.history['val_accuracy']
#
# loss = history.history['loss']
# val_loss = history.history['val_loss']
#
# plt.figure(figsize=(8, 8))
# plt.subplot(2, 1, 1)
# plt.plot(acc, label='Training Accuracy')
# plt.plot(val_acc, label='Validation Accuracy')
# plt.legend(loc='lower right')
# plt.ylabel('Accuracy')
# plt.ylim([min(plt.ylim()), 1])
# plt.title('Training and Validation Accuracy')
#
# plt.subplot(2, 1, 2)
# plt.plot(loss, label='Training Loss')
# plt.plot(val_loss, label='Validation Loss')
# plt.legend(loc='upper right')
# plt.ylabel('Cross Entropy')
# plt.ylim([0, 1.0])
# plt.title('Training and Validation Loss')
# plt.xlabel('epoch')
# plt.show()
ckpt_callback
=
callbacks
.
ModelCheckpoint
(
ckpt_path
,
save_best_only
=
True
)
history
=
model
.
fit
(
train_dataset
,
epochs
=
epoch
,
validation_data
=
validate_dataset
,
callbacks
=
[
ckpt_callback
,
],
)
self
.
history_save
(
history
,
history_save_path
)
def
test
(
self
):
print
(
self
.
class_label_map
)
print
(
self
.
class_count
)
# path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
# label = 5
# image, label = self.load_image(path, label)
# print(image.shape)
# image, label = self.preprocess_input(image, label)
# print(image.shape)
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment