83048d22 by 周伟奇

add augmentation methods

1 parent 37a9d47e
......@@ -15,6 +15,7 @@ class F3Classification(BaseModel):
super().__init__(*args, **kwargs)
self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
@staticmethod
def history_save(history, save_path):
......@@ -60,6 +61,8 @@ class F3Classification(BaseModel):
label = self.class_label_map[class_name]
for file_name in os.listdir(class_dir_path):
# TODO image check
if os.path.splitext(file_name)[1] not in self.image_ext_set:
continue
file_path = os.path.join(class_dir_path, file_name)
image_path_list.append(file_path)
label_list.append(tf.one_hot(label, depth=self.class_count))
......@@ -68,21 +71,42 @@ class F3Classification(BaseModel):
@staticmethod
# @tf.function
def random_rgb_2_bgr(image, label):
if random.random() > 0.2:
return image, label
image = image[:, :, ::-1]
# 1/5
if random.random() < 0.1:
image = image[:, :, ::-1]
return image, label
@staticmethod
# @tf.function
def random_grayscale_expand(image, label):
if random.random() > 0.1:
return image, label
image = tf.image.rgb_to_grayscale(image)
image = tf.image.grayscale_to_rgb(image)
# 1/10
if random.random() < 0.1:
image = tf.image.rgb_to_grayscale(image)
image = tf.image.grayscale_to_rgb(image)
return image, label
@staticmethod
def random_flip_left_right(image, label):
# 1/10
if random.random() < 0.2:
image = tf.image.random_flip_left_right(image)
return image
@staticmethod
def random_flip_up_down(image, label):
# 1/10
if random.random() < 0.2:
image = tf.image.random_flip_up_down(image)
return image
@staticmethod
def random_rot90(image, label):
# 1/10
if random.random() < 0.1:
image = tf.image.rot90(image, k=random.randint(1, 3))
return image
@staticmethod
# @tf.function
def load_image(image_path, label):
image = tf.io.read_file(image_path)
......@@ -163,7 +187,13 @@ class F3Classification(BaseModel):
name=train_dir_name,
batch_size=batch_size,
# augmentation_methods=[],
augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
augmentation_methods=[
'random_flip_left_right',
'random_flip_up_down',
'random_rot90',
'random_rgb_2_bgr',
'random_grayscale_expand'
],
)
validate_dataset = self.load_dataset(
dataset_dir=os.path.join(dataset_dir, validate_dir_name),
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!