83048d22 by 周伟奇

add augmentation methods

1 parent 37a9d47e
...@@ -15,6 +15,7 @@ class F3Classification(BaseModel): ...@@ -15,6 +15,7 @@ class F3Classification(BaseModel):
15 super().__init__(*args, **kwargs) 15 super().__init__(*args, **kwargs)
16 self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1 16 self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
17 self.class_label_map = self.get_class_label_map(class_name_list, class_other_first) 17 self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
18 self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
18 19
19 @staticmethod 20 @staticmethod
20 def history_save(history, save_path): 21 def history_save(history, save_path):
...@@ -60,6 +61,8 @@ class F3Classification(BaseModel): ...@@ -60,6 +61,8 @@ class F3Classification(BaseModel):
60 label = self.class_label_map[class_name] 61 label = self.class_label_map[class_name]
61 for file_name in os.listdir(class_dir_path): 62 for file_name in os.listdir(class_dir_path):
62 # TODO image check 63 # TODO image check
64 if os.path.splitext(file_name)[1] not in self.image_ext_set:
65 continue
63 file_path = os.path.join(class_dir_path, file_name) 66 file_path = os.path.join(class_dir_path, file_name)
64 image_path_list.append(file_path) 67 image_path_list.append(file_path)
65 label_list.append(tf.one_hot(label, depth=self.class_count)) 68 label_list.append(tf.one_hot(label, depth=self.class_count))
...@@ -68,21 +71,42 @@ class F3Classification(BaseModel): ...@@ -68,21 +71,42 @@ class F3Classification(BaseModel):
68 @staticmethod 71 @staticmethod
69 # @tf.function 72 # @tf.function
70 def random_rgb_2_bgr(image, label): 73 def random_rgb_2_bgr(image, label):
71 if random.random() > 0.2: 74 # 1/5
72 return image, label 75 if random.random() < 0.1:
73 image = image[:, :, ::-1] 76 image = image[:, :, ::-1]
74 return image, label 77 return image, label
75 78
76 @staticmethod 79 @staticmethod
77 # @tf.function 80 # @tf.function
78 def random_grayscale_expand(image, label): 81 def random_grayscale_expand(image, label):
79 if random.random() > 0.1: 82 # 1/10
80 return image, label 83 if random.random() < 0.1:
81 image = tf.image.rgb_to_grayscale(image) 84 image = tf.image.rgb_to_grayscale(image)
82 image = tf.image.grayscale_to_rgb(image) 85 image = tf.image.grayscale_to_rgb(image)
83 return image, label 86 return image, label
84 87
85 @staticmethod 88 @staticmethod
89 def random_flip_left_right(image, label):
90 # 1/10
91 if random.random() < 0.2:
92 image = tf.image.random_flip_left_right(image)
93 return image
94
95 @staticmethod
96 def random_flip_up_down(image, label):
97 # 1/10
98 if random.random() < 0.2:
99 image = tf.image.random_flip_up_down(image)
100 return image
101
102 @staticmethod
103 def random_rot90(image, label):
104 # 1/10
105 if random.random() < 0.1:
106 image = tf.image.rot90(image, k=random.randint(1, 3))
107 return image
108
109 @staticmethod
86 # @tf.function 110 # @tf.function
87 def load_image(image_path, label): 111 def load_image(image_path, label):
88 image = tf.io.read_file(image_path) 112 image = tf.io.read_file(image_path)
...@@ -163,7 +187,13 @@ class F3Classification(BaseModel): ...@@ -163,7 +187,13 @@ class F3Classification(BaseModel):
163 name=train_dir_name, 187 name=train_dir_name,
164 batch_size=batch_size, 188 batch_size=batch_size,
165 # augmentation_methods=[], 189 # augmentation_methods=[],
166 augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'], 190 augmentation_methods=[
191 'random_flip_left_right',
192 'random_flip_up_down',
193 'random_rot90',
194 'random_rgb_2_bgr',
195 'random_grayscale_expand'
196 ],
167 ) 197 )
168 validate_dataset = self.load_dataset( 198 validate_dataset = self.load_dataset(
169 dataset_dir=os.path.join(dataset_dir, validate_dir_name), 199 dataset_dir=os.path.join(dataset_dir, validate_dir_name),
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!