update model
Showing
1 changed file
with
6 additions
and
4 deletions
| ... | @@ -107,7 +107,7 @@ class F3Classification(BaseModel): | ... | @@ -107,7 +107,7 @@ class F3Classification(BaseModel): |
| 107 | image = applications.mobilenet_v2.preprocess_input(image) | 107 | image = applications.mobilenet_v2.preprocess_input(image) |
| 108 | return image, label | 108 | return image, label |
| 109 | 109 | ||
| 110 | def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=[]): | 110 | def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=[], drop_remainder=True): |
| 111 | image_and_label_list = self.get_image_label_list(dataset_dir) | 111 | image_and_label_list = self.get_image_label_list(dataset_dir) |
| 112 | tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name) | 112 | tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name) |
| 113 | dataset = tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True) | 113 | dataset = tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True) |
| ... | @@ -122,7 +122,7 @@ class F3Classification(BaseModel): | ... | @@ -122,7 +122,7 @@ class F3Classification(BaseModel): |
| 122 | self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False) | 122 | self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False) |
| 123 | parallel_batch_dataset = dataset.batch( | 123 | parallel_batch_dataset = dataset.batch( |
| 124 | batch_size=batch_size, | 124 | batch_size=batch_size, |
| 125 | drop_remainder=True, | 125 | drop_remainder=drop_remainder, |
| 126 | num_parallel_calls=tf.data.AUTOTUNE, | 126 | num_parallel_calls=tf.data.AUTOTUNE, |
| 127 | deterministic=False, | 127 | deterministic=False, |
| 128 | name=name, | 128 | name=name, |
| ... | @@ -144,7 +144,8 @@ class F3Classification(BaseModel): | ... | @@ -144,7 +144,8 @@ class F3Classification(BaseModel): |
| 144 | ) | 144 | ) |
| 145 | x = base_model.output | 145 | x = base_model.output |
| 146 | x = layers.Dropout(0.5)(x) | 146 | x = layers.Dropout(0.5)(x) |
| 147 | x = layers.Dense(256, activation='sigmoid', name='dense')(x) | 147 | # x = layers.Dense(256, activation='sigmoid', name='dense')(x) |
| 148 | x = layers.Dense(256, activation='relu', name='dense')(x) | ||
| 148 | x = layers.Dropout(0.5)(x) | 149 | x = layers.Dropout(0.5)(x) |
| 149 | x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x) | 150 | x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x) |
| 150 | self.model = models.Model(inputs=base_model.input, outputs=x) | 151 | self.model = models.Model(inputs=base_model.input, outputs=x) |
| ... | @@ -243,7 +244,8 @@ class F3Classification(BaseModel): | ... | @@ -243,7 +244,8 @@ class F3Classification(BaseModel): |
| 243 | batch_size=batch_size, | 244 | batch_size=batch_size, |
| 244 | augmentation_methods=[ | 245 | augmentation_methods=[ |
| 245 | 'rgb_2_bgr' | 246 | 'rgb_2_bgr' |
| 246 | ] | 247 | ], |
| 248 | drop_remainder=False, | ||
| 247 | ) | 249 | ) |
| 248 | 250 | ||
| 249 | label_true_list = [] | 251 | label_true_list = [] | ... | ... |
-
Please register or sign in to post a comment