add augmentation methods
Showing
1 changed file
with
35 additions
and
5 deletions
... | @@ -15,6 +15,7 @@ class F3Classification(BaseModel): | ... | @@ -15,6 +15,7 @@ class F3Classification(BaseModel): |
15 | super().__init__(*args, **kwargs) | 15 | super().__init__(*args, **kwargs) |
16 | self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1 | 16 | self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1 |
17 | self.class_label_map = self.get_class_label_map(class_name_list, class_other_first) | 17 | self.class_label_map = self.get_class_label_map(class_name_list, class_other_first) |
18 | self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} | ||
18 | 19 | ||
19 | @staticmethod | 20 | @staticmethod |
20 | def history_save(history, save_path): | 21 | def history_save(history, save_path): |
... | @@ -60,6 +61,8 @@ class F3Classification(BaseModel): | ... | @@ -60,6 +61,8 @@ class F3Classification(BaseModel): |
60 | label = self.class_label_map[class_name] | 61 | label = self.class_label_map[class_name] |
61 | for file_name in os.listdir(class_dir_path): | 62 | for file_name in os.listdir(class_dir_path): |
62 | # TODO image check | 63 | # TODO image check |
64 | if os.path.splitext(file_name)[1] not in self.image_ext_set: | ||
65 | continue | ||
63 | file_path = os.path.join(class_dir_path, file_name) | 66 | file_path = os.path.join(class_dir_path, file_name) |
64 | image_path_list.append(file_path) | 67 | image_path_list.append(file_path) |
65 | label_list.append(tf.one_hot(label, depth=self.class_count)) | 68 | label_list.append(tf.one_hot(label, depth=self.class_count)) |
... | @@ -68,21 +71,42 @@ class F3Classification(BaseModel): | ... | @@ -68,21 +71,42 @@ class F3Classification(BaseModel): |
68 | @staticmethod | 71 | @staticmethod |
69 | # @tf.function | 72 | # @tf.function |
70 | def random_rgb_2_bgr(image, label): | 73 | def random_rgb_2_bgr(image, label): |
71 | if random.random() > 0.2: | 74 | # 1/5 |
72 | return image, label | 75 | if random.random() < 0.1: |
73 | image = image[:, :, ::-1] | 76 | image = image[:, :, ::-1] |
74 | return image, label | 77 | return image, label |
75 | 78 | ||
76 | @staticmethod | 79 | @staticmethod |
77 | # @tf.function | 80 | # @tf.function |
78 | def random_grayscale_expand(image, label): | 81 | def random_grayscale_expand(image, label): |
79 | if random.random() > 0.1: | 82 | # 1/10 |
80 | return image, label | 83 | if random.random() < 0.1: |
81 | image = tf.image.rgb_to_grayscale(image) | 84 | image = tf.image.rgb_to_grayscale(image) |
82 | image = tf.image.grayscale_to_rgb(image) | 85 | image = tf.image.grayscale_to_rgb(image) |
83 | return image, label | 86 | return image, label |
84 | 87 | ||
85 | @staticmethod | 88 | @staticmethod |
89 | def random_flip_left_right(image, label): | ||
90 | # 1/10 | ||
91 | if random.random() < 0.2: | ||
92 | image = tf.image.random_flip_left_right(image) | ||
93 | return image | ||
94 | |||
95 | @staticmethod | ||
96 | def random_flip_up_down(image, label): | ||
97 | # 1/10 | ||
98 | if random.random() < 0.2: | ||
99 | image = tf.image.random_flip_up_down(image) | ||
100 | return image | ||
101 | |||
102 | @staticmethod | ||
103 | def random_rot90(image, label): | ||
104 | # 1/10 | ||
105 | if random.random() < 0.1: | ||
106 | image = tf.image.rot90(image, k=random.randint(1, 3)) | ||
107 | return image | ||
108 | |||
109 | @staticmethod | ||
86 | # @tf.function | 110 | # @tf.function |
87 | def load_image(image_path, label): | 111 | def load_image(image_path, label): |
88 | image = tf.io.read_file(image_path) | 112 | image = tf.io.read_file(image_path) |
... | @@ -163,7 +187,13 @@ class F3Classification(BaseModel): | ... | @@ -163,7 +187,13 @@ class F3Classification(BaseModel): |
163 | name=train_dir_name, | 187 | name=train_dir_name, |
164 | batch_size=batch_size, | 188 | batch_size=batch_size, |
165 | # augmentation_methods=[], | 189 | # augmentation_methods=[], |
166 | augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'], | 190 | augmentation_methods=[ |
191 | 'random_flip_left_right', | ||
192 | 'random_flip_up_down', | ||
193 | 'random_rot90', | ||
194 | 'random_rgb_2_bgr', | ||
195 | 'random_grayscale_expand' | ||
196 | ], | ||
167 | ) | 197 | ) |
168 | validate_dataset = self.load_dataset( | 198 | validate_dataset = self.load_dataset( |
169 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), | 199 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), | ... | ... |
-
Please register or sign in to post a comment