add augmentation methods
Showing
1 changed file
with
38 additions
and
8 deletions
| ... | @@ -15,6 +15,7 @@ class F3Classification(BaseModel): | ... | @@ -15,6 +15,7 @@ class F3Classification(BaseModel): |
| 15 | super().__init__(*args, **kwargs) | 15 | super().__init__(*args, **kwargs) |
| 16 | self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1 | 16 | self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1 |
| 17 | self.class_label_map = self.get_class_label_map(class_name_list, class_other_first) | 17 | self.class_label_map = self.get_class_label_map(class_name_list, class_other_first) |
| 18 | self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} | ||
| 18 | 19 | ||
| 19 | @staticmethod | 20 | @staticmethod |
| 20 | def history_save(history, save_path): | 21 | def history_save(history, save_path): |
| ... | @@ -60,6 +61,8 @@ class F3Classification(BaseModel): | ... | @@ -60,6 +61,8 @@ class F3Classification(BaseModel): |
| 60 | label = self.class_label_map[class_name] | 61 | label = self.class_label_map[class_name] |
| 61 | for file_name in os.listdir(class_dir_path): | 62 | for file_name in os.listdir(class_dir_path): |
| 62 | # TODO image check | 63 | # TODO image check |
| 64 | if os.path.splitext(file_name)[1] not in self.image_ext_set: | ||
| 65 | continue | ||
| 63 | file_path = os.path.join(class_dir_path, file_name) | 66 | file_path = os.path.join(class_dir_path, file_name) |
| 64 | image_path_list.append(file_path) | 67 | image_path_list.append(file_path) |
| 65 | label_list.append(tf.one_hot(label, depth=self.class_count)) | 68 | label_list.append(tf.one_hot(label, depth=self.class_count)) |
| ... | @@ -68,21 +71,42 @@ class F3Classification(BaseModel): | ... | @@ -68,21 +71,42 @@ class F3Classification(BaseModel): |
| 68 | @staticmethod | 71 | @staticmethod |
| 69 | # @tf.function | 72 | # @tf.function |
| 70 | def random_rgb_2_bgr(image, label): | 73 | def random_rgb_2_bgr(image, label): |
| 71 | if random.random() > 0.2: | 74 | # 1/5 |
| 72 | return image, label | 75 | if random.random() < 0.1: |
| 73 | image = image[:, :, ::-1] | 76 | image = image[:, :, ::-1] |
| 74 | return image, label | 77 | return image, label |
| 75 | 78 | ||
| 76 | @staticmethod | 79 | @staticmethod |
| 77 | # @tf.function | 80 | # @tf.function |
| 78 | def random_grayscale_expand(image, label): | 81 | def random_grayscale_expand(image, label): |
| 79 | if random.random() > 0.1: | 82 | # 1/10 |
| 80 | return image, label | 83 | if random.random() < 0.1: |
| 81 | image = tf.image.rgb_to_grayscale(image) | 84 | image = tf.image.rgb_to_grayscale(image) |
| 82 | image = tf.image.grayscale_to_rgb(image) | 85 | image = tf.image.grayscale_to_rgb(image) |
| 83 | return image, label | 86 | return image, label |
| 84 | 87 | ||
| 85 | @staticmethod | 88 | @staticmethod |
| 89 | def random_flip_left_right(image, label): | ||
| 90 | # 1/10 | ||
| 91 | if random.random() < 0.2: | ||
| 92 | image = tf.image.random_flip_left_right(image) | ||
| 93 | return image | ||
| 94 | |||
| 95 | @staticmethod | ||
| 96 | def random_flip_up_down(image, label): | ||
| 97 | # 1/10 | ||
| 98 | if random.random() < 0.2: | ||
| 99 | image = tf.image.random_flip_up_down(image) | ||
| 100 | return image | ||
| 101 | |||
| 102 | @staticmethod | ||
| 103 | def random_rot90(image, label): | ||
| 104 | # 1/10 | ||
| 105 | if random.random() < 0.1: | ||
| 106 | image = tf.image.rot90(image, k=random.randint(1, 3)) | ||
| 107 | return image | ||
| 108 | |||
| 109 | @staticmethod | ||
| 86 | # @tf.function | 110 | # @tf.function |
| 87 | def load_image(image_path, label): | 111 | def load_image(image_path, label): |
| 88 | image = tf.io.read_file(image_path) | 112 | image = tf.io.read_file(image_path) |
| ... | @@ -163,7 +187,13 @@ class F3Classification(BaseModel): | ... | @@ -163,7 +187,13 @@ class F3Classification(BaseModel): |
| 163 | name=train_dir_name, | 187 | name=train_dir_name, |
| 164 | batch_size=batch_size, | 188 | batch_size=batch_size, |
| 165 | # augmentation_methods=[], | 189 | # augmentation_methods=[], |
| 166 | augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'], | 190 | augmentation_methods=[ |
| 191 | 'random_flip_left_right', | ||
| 192 | 'random_flip_up_down', | ||
| 193 | 'random_rot90', | ||
| 194 | 'random_rgb_2_bgr', | ||
| 195 | 'random_grayscale_expand' | ||
| 196 | ], | ||
| 167 | ) | 197 | ) |
| 168 | validate_dataset = self.load_dataset( | 198 | validate_dataset = self.load_dataset( |
| 169 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), | 199 | dataset_dir=os.path.join(dataset_dir, validate_dir_name), | ... | ... |
-
Please register or sign in to post a comment