update model
Showing
1 changed file
with
6 additions
and
4 deletions
... | @@ -107,7 +107,7 @@ class F3Classification(BaseModel): | ... | @@ -107,7 +107,7 @@ class F3Classification(BaseModel): |
107 | image = applications.mobilenet_v2.preprocess_input(image) | 107 | image = applications.mobilenet_v2.preprocess_input(image) |
108 | return image, label | 108 | return image, label |
109 | 109 | ||
110 | def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=[]): | 110 | def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=[], drop_remainder=True): |
111 | image_and_label_list = self.get_image_label_list(dataset_dir) | 111 | image_and_label_list = self.get_image_label_list(dataset_dir) |
112 | tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name) | 112 | tensor_slice_dataset = tf.data.Dataset.from_tensor_slices(image_and_label_list, name=name) |
113 | dataset = tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True) | 113 | dataset = tensor_slice_dataset.shuffle(len(image_and_label_list[0]), reshuffle_each_iteration=True) |
... | @@ -122,7 +122,7 @@ class F3Classification(BaseModel): | ... | @@ -122,7 +122,7 @@ class F3Classification(BaseModel): |
122 | self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False) | 122 | self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False) |
123 | parallel_batch_dataset = dataset.batch( | 123 | parallel_batch_dataset = dataset.batch( |
124 | batch_size=batch_size, | 124 | batch_size=batch_size, |
125 | drop_remainder=True, | 125 | drop_remainder=drop_remainder, |
126 | num_parallel_calls=tf.data.AUTOTUNE, | 126 | num_parallel_calls=tf.data.AUTOTUNE, |
127 | deterministic=False, | 127 | deterministic=False, |
128 | name=name, | 128 | name=name, |
... | @@ -144,7 +144,8 @@ class F3Classification(BaseModel): | ... | @@ -144,7 +144,8 @@ class F3Classification(BaseModel): |
144 | ) | 144 | ) |
145 | x = base_model.output | 145 | x = base_model.output |
146 | x = layers.Dropout(0.5)(x) | 146 | x = layers.Dropout(0.5)(x) |
147 | x = layers.Dense(256, activation='sigmoid', name='dense')(x) | 147 | # x = layers.Dense(256, activation='sigmoid', name='dense')(x) |
148 | x = layers.Dense(256, activation='relu', name='dense')(x) | ||
148 | x = layers.Dropout(0.5)(x) | 149 | x = layers.Dropout(0.5)(x) |
149 | x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x) | 150 | x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x) |
150 | self.model = models.Model(inputs=base_model.input, outputs=x) | 151 | self.model = models.Model(inputs=base_model.input, outputs=x) |
... | @@ -243,7 +244,8 @@ class F3Classification(BaseModel): | ... | @@ -243,7 +244,8 @@ class F3Classification(BaseModel): |
243 | batch_size=batch_size, | 244 | batch_size=batch_size, |
244 | augmentation_methods=[ | 245 | augmentation_methods=[ |
245 | 'rgb_2_bgr' | 246 | 'rgb_2_bgr' |
246 | ] | 247 | ], |
248 | drop_remainder=False, | ||
247 | ) | 249 | ) |
248 | 250 | ||
249 | label_true_list = [] | 251 | label_true_list = [] | ... | ... |
-
Please register or sign in to post a comment