Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
part_of_F3_OCR
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
1ea84670
authored
2022-06-30 12:59:19 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
add CustomMetric
1 parent
83048d22
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
81 additions
and
10 deletions
.gitignore
classification/const.py
classification/main.py
classification/model.py
.gitignore
View file @
1ea8467
...
...
@@ -11,4 +11,7 @@
.*
!.gitignore
test.py
\ No newline at end of file
test.py
*.h5
*.jpg
*.out
\ No newline at end of file
...
...
classification/const.py
View file @
1ea8467
...
...
@@ -4,3 +4,5 @@ CLASS_OTHER_FIRST = True
CLASS_CN_LIST
=
[
CLASS_OTHER_CN
,
'身份证'
,
'营业执照'
,
'经销商授权书'
,
'个人授权书'
]
OTHER_THRESHOLDS
=
0.5
...
...
classification/main.py
View file @
1ea8467
...
...
@@ -21,5 +21,5 @@ if __name__ == '__main__':
batch_size
=
128
m
.
train
(
dataset_dir
,
epoch
,
batch_size
,
ckpt_path
,
history_save_path
,
train_dir_name
=
'train'
,
validate_dir_name
=
'test'
)
train_dir_name
=
'train'
,
validate_dir_name
=
'test'
,
thresholds
=
const
.
OTHER_THRESHOLDS
)
...
...
classification/model.py
View file @
1ea8467
...
...
@@ -9,6 +9,43 @@ import matplotlib.pyplot as plt
from
base_class
import
BaseModel
class
CustomMetric
(
metrics
.
Metric
):
def
__init__
(
self
,
thresholds
=
0.5
,
name
=
"custom_metric"
,
**
kwargs
):
super
(
CustomMetric
,
self
)
.
__init__
(
name
=
name
,
**
kwargs
)
self
.
thresholds
=
thresholds
self
.
true_positives
=
self
.
add_weight
(
name
=
"ctp"
,
initializer
=
"zeros"
)
self
.
count
=
self
.
add_weight
(
name
=
"count"
,
initializer
=
"zeros"
,
dtype
=
'int32'
)
def
update_state
(
self
,
y_true
,
y_pred
,
sample_weight
=
None
):
y_true_idx
=
tf
.
argmax
(
y_true
,
axis
=
1
)
+
1
y_true_is_other
=
tf
.
cast
(
tf
.
math
.
reduce_sum
(
y_true
,
axis
=
1
),
"int64"
)
y_true
=
tf
.
math
.
multiply
(
y_true_idx
,
y_true_is_other
)
y_pred_idx
=
tf
.
argmax
(
y_pred
,
axis
=
1
)
+
1
y_pred_is_other
=
tf
.
cast
(
tf
.
math
.
greater_equal
(
tf
.
math
.
reduce_max
(
y_pred
,
axis
=
1
),
self
.
thresholds
),
'int64'
)
y_pred
=
tf
.
math
.
multiply
(
y_pred_idx
,
y_pred_is_other
)
print
(
y_true
)
print
(
y_pred
)
values
=
tf
.
cast
(
y_true
,
"int32"
)
==
tf
.
cast
(
y_pred
,
"int32"
)
values
=
tf
.
cast
(
values
,
"float32"
)
if
sample_weight
is
not
None
:
sample_weight
=
tf
.
cast
(
sample_weight
,
"float32"
)
values
=
tf
.
multiply
(
values
,
sample_weight
)
self
.
true_positives
.
assign_add
(
tf
.
reduce_sum
(
values
))
self
.
count
.
assign_add
(
tf
.
shape
(
y_true
)[
0
])
def
result
(
self
):
return
self
.
true_positives
/
tf
.
cast
(
self
.
count
,
'float32'
)
def
reset_state
(
self
):
# The state of the metric will be reset at the start of each epoch.
self
.
true_positives
.
assign
(
0.0
)
self
.
count
.
assign
(
0
)
class
F3Classification
(
BaseModel
):
def
__init__
(
self
,
class_name_list
,
class_other_first
,
*
args
,
**
kwargs
):
...
...
@@ -18,6 +55,12 @@ class F3Classification(BaseModel):
self
.
image_ext_set
=
{
".jpg"
,
".jpeg"
,
".png"
,
".bmp"
,
".tif"
,
".tiff"
}
@staticmethod
def
gpu_config
():
gpus
=
tf
.
config
.
experimental
.
list_physical_devices
(
device_type
=
'GPU'
)
# print(gpus)
tf
.
config
.
set_visible_devices
(
devices
=
gpus
[
1
],
device_type
=
'GPU'
)
@staticmethod
def
history_save
(
history
,
save_path
):
acc
=
history
.
history
[
'accuracy'
]
val_acc
=
history
.
history
[
'val_accuracy'
]
...
...
@@ -90,21 +133,21 @@ class F3Classification(BaseModel):
# 1/10
if
random
.
random
()
<
0.2
:
image
=
tf
.
image
.
random_flip_left_right
(
image
)
return
image
return
image
,
label
@staticmethod
def
random_flip_up_down
(
image
,
label
):
# 1/10
if
random
.
random
()
<
0.2
:
image
=
tf
.
image
.
random_flip_up_down
(
image
)
return
image
return
image
,
label
@staticmethod
def
random_rot90
(
image
,
label
):
# 1/10
if
random
.
random
()
<
0.1
:
image
=
tf
.
image
.
rot90
(
image
,
k
=
random
.
randint
(
1
,
3
))
return
image
return
image
,
label
@staticmethod
# @tf.function
...
...
@@ -166,14 +209,17 @@ class F3Classification(BaseModel):
return
model
def
train
(
self
,
dataset_dir
,
epoch
,
batch_size
,
ckpt_path
,
history_save_path
,
train_dir_name
=
'train'
,
validate_dir_name
=
'test'
):
train_dir_name
=
'train'
,
validate_dir_name
=
'test'
,
thresholds
=
0.5
):
self
.
gpu_config
()
model
=
self
.
load_model
()
model
.
summary
()
model
.
compile
(
optimizer
=
optimizers
.
Adam
(
learning_rate
=
3e-4
),
loss
=
tfa
.
losses
.
SigmoidFocalCrossEntropy
(),
metrics
=
[
'accuracy'
,
],
loss
=
tfa
.
losses
.
SigmoidFocalCrossEntropy
(),
# TODO >>>
metrics
=
[
CustomMetric
(
thresholds
)
,
],
loss_weights
=
None
,
weighted_metrics
=
None
,
...
...
@@ -214,5 +260,25 @@ class F3Classification(BaseModel):
self
.
history_save
(
history
,
history_save_path
)
def
test
(
self
):
print
(
self
.
class_label_map
)
print
(
self
.
class_count
)
y_true
=
[
[
0
,
1
,
0
],
[
0
,
1
,
0
],
[
0
,
0
,
1
],
[
0
,
0
,
0
],
]
y_pre
=
[
[
0.1
,
0.8
,
0.9
],
# TODO multi_label
[
0.2
,
0.8
,
0.1
],
[
0.2
,
0.1
,
0.85
],
[
0.2
,
0.4
,
0.1
],
]
# x = tf.argmax(y_pre, axis=1)
# y = tf.reduce_sum(y_pre, axis=1)
# print(x)
# print(y)
# m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
m
=
CustomMetric
(
0.5
)
m
.
update_state
(
y_true
,
y_pre
)
print
(
m
.
result
()
.
numpy
())
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment