Skip to content
Toggle navigation
Toggle navigation
This project
Loading...
Sign in
周伟奇
/
part_of_F3_OCR
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Issue Boards
Files
Commits
Network
Compare
Branches
Tags
83048d22
authored
2022-06-29 19:04:53 +0800
by
周伟奇
Browse Files
Options
Browse Files
Tag
Download
Email Patches
Plain Diff
add augmentation methods
1 parent
37a9d47e
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
5 deletions
classification/model.py
classification/model.py
View file @
83048d2
...
...
@@ -15,6 +15,7 @@ class F3Classification(BaseModel):
super
()
.
__init__
(
*
args
,
**
kwargs
)
self
.
class_count
=
len
(
class_name_list
)
if
not
class_other_first
else
len
(
class_name_list
)
-
1
self
.
class_label_map
=
self
.
get_class_label_map
(
class_name_list
,
class_other_first
)
self
.
image_ext_set
=
{
".jpg"
,
".jpeg"
,
".png"
,
".bmp"
,
".tif"
,
".tiff"
}
@staticmethod
def
history_save
(
history
,
save_path
):
...
...
@@ -60,6 +61,8 @@ class F3Classification(BaseModel):
label
=
self
.
class_label_map
[
class_name
]
for
file_name
in
os
.
listdir
(
class_dir_path
):
# TODO image check
if
os
.
path
.
splitext
(
file_name
)[
1
]
not
in
self
.
image_ext_set
:
continue
file_path
=
os
.
path
.
join
(
class_dir_path
,
file_name
)
image_path_list
.
append
(
file_path
)
label_list
.
append
(
tf
.
one_hot
(
label
,
depth
=
self
.
class_count
))
...
...
@@ -68,21 +71,42 @@ class F3Classification(BaseModel):
@staticmethod
# @tf.function
def
random_rgb_2_bgr
(
image
,
label
):
if
random
.
random
()
>
0.2
:
return
image
,
label
# 1/5
if
random
.
random
()
<
0.1
:
image
=
image
[:,
:,
::
-
1
]
return
image
,
label
@staticmethod
# @tf.function
def
random_grayscale_expand
(
image
,
label
):
if
random
.
random
()
>
0.1
:
return
image
,
label
# 1/10
if
random
.
random
()
<
0.1
:
image
=
tf
.
image
.
rgb_to_grayscale
(
image
)
image
=
tf
.
image
.
grayscale_to_rgb
(
image
)
return
image
,
label
@staticmethod
def
random_flip_left_right
(
image
,
label
):
# 1/10
if
random
.
random
()
<
0.2
:
image
=
tf
.
image
.
random_flip_left_right
(
image
)
return
image
@staticmethod
def
random_flip_up_down
(
image
,
label
):
# 1/10
if
random
.
random
()
<
0.2
:
image
=
tf
.
image
.
random_flip_up_down
(
image
)
return
image
@staticmethod
def
random_rot90
(
image
,
label
):
# 1/10
if
random
.
random
()
<
0.1
:
image
=
tf
.
image
.
rot90
(
image
,
k
=
random
.
randint
(
1
,
3
))
return
image
@staticmethod
# @tf.function
def
load_image
(
image_path
,
label
):
image
=
tf
.
io
.
read_file
(
image_path
)
...
...
@@ -163,7 +187,13 @@ class F3Classification(BaseModel):
name
=
train_dir_name
,
batch_size
=
batch_size
,
# augmentation_methods=[],
augmentation_methods
=
[
'random_rgb_2_bgr'
,
'random_grayscale_expand'
],
augmentation_methods
=
[
'random_flip_left_right'
,
'random_flip_up_down'
,
'random_rot90'
,
'random_rgb_2_bgr'
,
'random_grayscale_expand'
],
)
validate_dataset
=
self
.
load_dataset
(
dataset_dir
=
os
.
path
.
join
(
dataset_dir
,
validate_dir_name
),
...
...
Write
Preview
Styling with
Markdown
is supported
Attach a file
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to post a comment