import os
import random

import tensorflow as tf
import tensorflow_addons as tfa
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, losses, metrics, callbacks, applications
import matplotlib.pyplot as plt

from base_class import BaseModel


class F3Classification(BaseModel):
    """MobileNetV2-based image classifier trained with sigmoid focal loss.

    Images are read from ``<dataset_dir>/<class_name>/<file>``; each
    sub-directory whose name appears in ``class_name_list`` is one class.
    When ``class_other_first`` is true, the first entry of
    ``class_name_list`` is an "other" bucket: it is mapped to label index
    -1, which ``tf.one_hot`` encodes as an all-zero target vector, and it
    does not count toward ``class_count``.
    """

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        """
        Args:
            class_name_list: ordered class (sub-directory) names.
            class_other_first: if true, ``class_name_list[0]`` is the
                "other" class, encoded as an all-zero label.
            *args, **kwargs: forwarded to ``BaseModel.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)

    @staticmethod
    def history_save(history, save_path):
        """Plot training/validation accuracy and loss curves and save them to a file.

        Args:
            history: the ``History`` object returned by ``model.fit``; its
                ``history`` dict must contain 'accuracy', 'val_accuracy',
                'loss' and 'val_loss'.
            save_path: output image path passed to ``plt.savefig``.
        """
        acc = history.history['accuracy']
        val_acc = history.history['val_accuracy']
        loss = history.history['loss']
        val_loss = history.history['val_loss']

        fig = plt.figure(figsize=(8, 8))
        try:
            plt.subplot(2, 1, 1)
            plt.plot(acc, label='Training Accuracy')
            plt.plot(val_acc, label='Validation Accuracy')
            plt.legend(loc='lower right')
            plt.ylabel('Accuracy')
            plt.ylim([min(plt.ylim()), 1])
            plt.title('Training and Validation Accuracy')

            plt.subplot(2, 1, 2)
            plt.plot(loss, label='Training Loss')
            plt.plot(val_loss, label='Validation Loss')
            plt.legend(loc='upper right')
            plt.ylabel('Cross Entropy')
            plt.ylim([0, 1.0])
            plt.title('Training and Validation Loss')
            plt.xlabel('epoch')
            plt.savefig(save_path)
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)

    @staticmethod
    def get_class_label_map(class_name_list, class_other_first=False):
        """Map each class name to its integer label.

        With ``class_other_first`` every index is shifted down by one, so the
        first name gets -1 (an all-zero one-hot vector) and the rest start at 0.
        """
        return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}

    def get_image_label_list(self, dataset_dir):
        """Collect image file paths and one-hot labels under ``dataset_dir``.

        Only sub-directories whose name is a key of ``self.class_label_map``
        are scanned; any other entries are ignored.

        Returns:
            ``(image_path_list, label_list)``: parallel lists of file paths
            and ``tf.one_hot`` label tensors of depth ``self.class_count``.
        """
        image_path_list = []
        label_list = []
        for class_name in os.listdir(dataset_dir):
            class_dir_path = os.path.join(dataset_dir, class_name)
            if not os.path.isdir(class_dir_path):
                continue
            if class_name not in self.class_label_map:
                continue
            # Same label for every file in this class: build the tensor once.
            one_hot_label = tf.one_hot(self.class_label_map[class_name], depth=self.class_count)
            for file_name in os.listdir(class_dir_path):
                file_path = os.path.join(class_dir_path, file_name)
                # Nested directories cannot be decoded as images.
                # TODO: also verify that the file is a decodable image.
                if not os.path.isfile(file_path):
                    continue
                image_path_list.append(file_path)
                label_list.append(one_hot_label)
        return image_path_list, label_list

    @staticmethod
    def random_rgb_2_bgr(image, label, probability=0.2):
        """With probability ``probability``, reverse the channel order (RGB <-> BGR).

        The coin flip uses ``tf.random`` so that it is drawn per element when
        ``Dataset.map`` traces this function. A Python ``random.random()``
        call would run only once, at trace time, and apply the same decision
        to every image.
        """
        swap = tf.random.uniform([]) < probability
        swapped = tf.cond(swap, lambda: image[:, :, ::-1], lambda: image)
        return swapped, label

    @staticmethod
    def random_grayscale_expand(image, label, probability=0.1):
        """With probability ``probability``, convert to grayscale, replicated back to 3 channels.

        Uses ``tf.random`` for a per-element draw inside ``Dataset.map``;
        see ``random_rgb_2_bgr``.
        """
        to_gray = tf.random.uniform([]) < probability
        converted = tf.cond(
            to_gray,
            lambda: tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)),
            lambda: image,
        )
        return converted, label

    @staticmethod
    def load_image(image_path, label):
        """Read an image file and decode it to a 3-channel uint8 tensor.

        ``expand_animations=False`` guarantees a rank-3 result (animated GIFs
        yield their first frame). With the default ``expand_animations=True``
        the rank is unknown when ``Dataset.map`` traces the pipeline, and the
        later ``tf.image.resize`` call rejects it.
        """
        image = tf.io.read_file(image_path)
        image = tf.io.decode_image(image, channels=3, expand_animations=False)
        return image, label

    @staticmethod
    def preprocess_input(image, label):
        """Resize to 224x224 and apply MobileNetV2 input scaling."""
        image = tf.image.resize(image, [224, 224])
        image = applications.mobilenet_v2.preprocess_input(image)
        return image, label

    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=None,
                     drop_remainder=True):
        """Build a shuffled, optionally augmented, batched ``tf.data`` pipeline.

        Args:
            dataset_dir: directory laid out as ``<class_name>/<image files>``.
            name: name attached to the source and batch datasets.
            batch_size: number of images per batch.
            augmentation_methods: names of methods on this object, each mapping
                ``(image, label) -> (image, label)``. They are applied in order
                to the decoded uint8 image, before ``preprocess_input``.
            drop_remainder: drop the final incomplete batch. Pass False for
                evaluation sets; a set smaller than ``batch_size`` otherwise
                yields no batches at all.

        Returns:
            A prefetched ``tf.data.Dataset`` of ``(image, label)`` batches.

        Raises:
            ValueError: if no image files were found for the configured classes.
        """
        if augmentation_methods is None:
            augmentation_methods = []

        image_path_list, label_list = self.get_image_label_list(dataset_dir)
        if not image_path_list:
            raise ValueError(f'no images found for the configured classes under {dataset_dir!r}')

        tensor_slice_dataset = tf.data.Dataset.from_tensor_slices((image_path_list, label_list), name=name)
        # Shuffle the (cheap) paths with a full-size buffer before decoding.
        dataset = tensor_slice_dataset.shuffle(len(image_path_list), reshuffle_each_iteration=True)
        dataset = dataset.map(self.load_image, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
        for augmentation_method in augmentation_methods:
            dataset = dataset.map(
                getattr(self, augmentation_method), num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
        dataset = dataset.map(self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)

        return dataset.batch(
            batch_size=batch_size,
            drop_remainder=drop_remainder,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
            name=name,
        ).prefetch(tf.data.AUTOTUNE)

    def load_model(self):
        """Build MobileNetV2 (alpha=0.35, ImageNet weights) with a new sigmoid head.

        Every layer up to and including ``block_16_project_BN`` is frozen. The
        remaining backbone layers and the new head layers are trainable.

        Returns:
            An uncompiled model mapping 224x224x3 images to
            ``self.class_count`` independent sigmoid scores.
        """
        base_model = MobileNetV2(
            input_shape=(224, 224, 3),
            alpha=0.35,
            include_top=False,
            weights='imagenet',
            pooling='avg',
        )
        x = base_model.output
        x = layers.Dropout(0.5)(x)
        # NOTE(review): a sigmoid hidden activation is unusual (relu is the
        # common choice); kept as-is to preserve existing behaviour.
        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
        x = layers.Dropout(0.5)(x)
        # Independent per-class sigmoids, so an all-zero "other" target is representable.
        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
        model = models.Model(inputs=base_model.input, outputs=x)

        freeze = True
        for layer in model.layers:
            layer.trainable = not freeze
            if freeze and layer.name == 'block_16_project_BN':
                freeze = False
        return model

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
              train_dir_name='train', validate_dir_name='test'):
        """Train the model, checkpoint the best epoch and save the training curves.

        Args:
            dataset_dir: root directory containing the train and validation sub-directories.
            epoch: number of training epochs.
            batch_size: batch size for both datasets.
            ckpt_path: ``ModelCheckpoint`` path; only the best epoch is saved.
            history_save_path: image path for the accuracy/loss plot.
            train_dir_name: training sub-directory name.
            validate_dir_name: validation sub-directory name.
        """
        model = self.load_model()
        model.summary()
        model.compile(
            optimizer=optimizers.Adam(learning_rate=3e-4),
            # NOTE(review): tfa's SigmoidFocalCrossEntropy defaults to
            # reduction=NONE (per-sample losses); confirm this is intended.
            loss=tfa.losses.SigmoidFocalCrossEntropy(),
            # NOTE(review): with one-hot targets 'accuracy' likely resolves to
            # categorical accuracy, which scores all-zero "other" labels as
            # class 0; verify the metric is meaningful for this setup.
            metrics=['accuracy'],
        )

        train_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, train_dir_name),
            name=train_dir_name,
            batch_size=batch_size,
            augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
        )
        # Keep the final partial batch so small validation sets still produce metrics.
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[],
            drop_remainder=False,
        )

        ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
        history = model.fit(
            train_dataset,
            epochs=epoch,
            validation_data=validate_dataset,
            callbacks=[ckpt_callback],
        )
        self.history_save(history, history_save_path)

    def test(self):
        """Print the class-to-label mapping and the number of output classes."""
        print(self.class_label_map)
        print(self.class_count)