# model.py (7.64 KB)
import os
import random
import tensorflow as tf
import tensorflow_addons as tfa
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, losses, metrics, callbacks, applications
import matplotlib.pyplot as plt

from base_class import BaseModel


class F3Classification(BaseModel):
    """Image classifier built on a MobileNetV2 backbone with per-class sigmoid outputs.

    Training data is read from ``<dataset_dir>/<class_name>/<image file>``; only
    folders whose name appears in ``class_name_list`` are used.

    When ``class_other_first`` is True, the first name in ``class_name_list`` is an
    "other" class: it gets label -1 (which ``tf.one_hot`` turns into an all-zeros
    target) and has no output unit of its own, so ``class_count`` is one less than
    the number of names.
    """

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_count = len(class_name_list) - 1 if class_other_first else len(class_name_list)
        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
        # Lower-case extensions that tf.io.decode_image can decode (BMP/GIF/JPEG/PNG).
        # TIFF is not supported by TensorFlow's image decoders, so .tif/.tiff files are
        # not collected (they used to make the input pipeline fail at decode time).
        self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp"}

    @staticmethod
    def history_save(history, save_path):
        """Plot training/validation accuracy and loss curves and save them to ``save_path``.

        Args:
            history: The object returned by ``model.fit``; its ``history`` dict is read
                for the 'accuracy', 'val_accuracy', 'loss' and 'val_loss' series
                (a missing series is plotted as empty).
            save_path: Destination file passed to ``Figure.savefig``.
        """
        curves = history.history
        figure = plt.figure(figsize=(8, 8))
        try:
            plt.subplot(2, 1, 1)
            plt.plot(curves.get('accuracy', []), label='Training Accuracy')
            plt.plot(curves.get('val_accuracy', []), label='Validation Accuracy')
            plt.legend(loc='lower right')
            plt.ylabel('Accuracy')
            plt.ylim([min(plt.ylim()), 1])
            plt.title('Training and Validation Accuracy')

            plt.subplot(2, 1, 2)
            plt.plot(curves.get('loss', []), label='Training Loss')
            plt.plot(curves.get('val_loss', []), label='Validation Loss')
            plt.legend(loc='upper right')
            plt.ylabel('Cross Entropy')
            plt.ylim([0, 1.0])
            plt.title('Training and Validation Loss')
            plt.xlabel('epoch')
            figure.savefig(save_path)
        finally:
            # pyplot keeps every open figure alive until it is explicitly closed.
            plt.close(figure)

    @staticmethod
    def get_class_label_map(class_name_list, class_other_first=False):
        """Map each class name to its integer label.

        Labels follow list order starting at 0. With ``class_other_first`` every label
        is shifted down by one, so the first ("other") class maps to -1.
        """
        offset = 1 if class_other_first else 0
        return {class_name: idx - offset for idx, class_name in enumerate(class_name_list)}

    def get_image_label_list(self, dataset_dir):
        """Collect image paths and their one-hot targets from ``dataset_dir``.

        Sub-folders whose name is not in ``class_label_map`` and files whose
        (case-insensitive) extension is not in ``image_ext_set`` are skipped.

        Returns:
            ``(image_path_list, label_list)``: parallel lists of file paths and
            ``tf.one_hot`` tensors of depth ``class_count``.
        """
        image_path_list = []
        label_list = []
        # Sorted so the collected order does not depend on os.listdir's arbitrary order.
        for class_name in sorted(os.listdir(dataset_dir)):
            class_dir_path = os.path.join(dataset_dir, class_name)
            if not os.path.isdir(class_dir_path) or class_name not in self.class_label_map:
                continue
            # Every file in this folder shares the same target; build it once.
            one_hot_label = tf.one_hot(self.class_label_map[class_name], depth=self.class_count)
            for file_name in sorted(os.listdir(class_dir_path)):
                # TODO: validate image content, not only the extension.
                if os.path.splitext(file_name)[1].lower() not in self.image_ext_set:
                    continue
                image_path_list.append(os.path.join(class_dir_path, file_name))
                label_list.append(one_hot_label)
        return image_path_list, label_list

    @staticmethod
    def _random_apply(image, probability, augment):
        """Return ``augment(image)`` with the given probability, otherwise ``image``.

        The decision is drawn with TensorFlow's RNG so it is re-sampled for every
        element. Dataset.map traces its function once, so a Python ``random`` draw
        would be frozen into the graph and apply to every element (or none).
        """
        return tf.cond(
            tf.random.uniform([]) < probability,
            lambda: augment(image),
            lambda: image,
        )

    @staticmethod
    def random_rgb_2_bgr(image, label):
        """Reverse the channel order (RGB <-> BGR) with probability 1/10."""
        image = F3Classification._random_apply(image, 0.1, lambda img: img[:, :, ::-1])
        return image, label

    @staticmethod
    def random_grayscale_expand(image, label):
        """Convert to grayscale (kept as 3 channels) with probability 1/10."""
        image = F3Classification._random_apply(
            image, 0.1,
            lambda img: tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(img)))
        return image, label

    @staticmethod
    def random_flip_left_right(image, label):
        """Flip horizontally with probability 1/10.

        Same effective rate as the original 0.2 gate combined with a 50% random flip.
        """
        image = F3Classification._random_apply(image, 0.1, tf.image.flip_left_right)
        return image, label

    @staticmethod
    def random_flip_up_down(image, label):
        """Flip vertically with probability 1/10.

        Same effective rate as the original 0.2 gate combined with a 50% random flip.
        """
        image = F3Classification._random_apply(image, 0.1, tf.image.flip_up_down)
        return image, label

    @staticmethod
    def random_rot90(image, label):
        """With probability 1/10, rotate by 90, 180 or 270 degrees (chosen uniformly)."""
        image = F3Classification._random_apply(
            image, 0.1,
            lambda img: tf.image.rot90(
                img, k=tf.random.uniform([], minval=1, maxval=4, dtype=tf.int32)))
        return image, label

    @staticmethod
    def load_image(image_path, label):
        """Read and decode the image file at ``image_path`` into a 3-channel uint8 tensor."""
        raw = tf.io.read_file(image_path)
        # decode_image handles BMP/GIF/JPEG/PNG (decode_png cannot decode BMP).
        # expand_animations=False makes it return a 3-D image instead of a possibly
        # 4-D (animated GIF) tensor. This gives downstream ops such as resize/rot90 a
        # usable static rank inside Dataset.map, which is presumably why the plain
        # decode_image call did not work before.
        image = tf.io.decode_image(raw, channels=3, expand_animations=False)
        return image, label

    @staticmethod
    def preprocess_input(image, label):
        """Resize to 224x224 and apply MobileNetV2's input preprocessing."""
        image = tf.image.resize(image, [224, 224])
        image = applications.mobilenet_v2.preprocess_input(image)
        return image, label

    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=None,
                     drop_remainder=True):
        """Build a shuffled, batched, prefetched ``tf.data`` pipeline over ``dataset_dir``.

        Args:
            dataset_dir: Directory laid out as ``<class_name>/<image file>``.
            name: Name passed to the source and batch datasets.
            batch_size: Number of examples per batch.
            augmentation_methods: Iterable of names of augmentation methods on this
                class (e.g. ``'random_rot90'``), applied in order after decoding.
                ``None`` or empty means no augmentation.
            drop_remainder: Whether to drop the final partial batch. Defaults to True
                (the previous behavior); pass False for evaluation so no example is skipped.

        Returns:
            A dataset yielding ``(images, labels)`` batches.

        Raises:
            ValueError: If no matching image files were found.
        """
        image_path_list, label_list = self.get_image_label_list(dataset_dir)
        if not image_path_list:
            raise ValueError(
                f"No images for the configured classes were found under {dataset_dir!r}")

        dataset = tf.data.Dataset.from_tensor_slices((image_path_list, label_list), name=name)
        # The buffer covers the whole file list, so each epoch is a full reshuffle.
        dataset = dataset.shuffle(len(image_path_list), reshuffle_each_iteration=True)
        dataset = dataset.map(
            self.load_image, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
        for augmentation_method in augmentation_methods or ():
            dataset = dataset.map(
                getattr(self, augmentation_method),
                num_parallel_calls=tf.data.AUTOTUNE,
                deterministic=False)
        dataset = dataset.map(
            self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
        return dataset.batch(
            batch_size=batch_size,
            drop_remainder=drop_remainder,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
            name=name,
        ).prefetch(tf.data.AUTOTUNE)

    def load_model(self):
        """Build MobileNetV2 (alpha=0.35, ImageNet weights) with a small sigmoid head.

        All backbone layers up to and including 'block_16_project_BN' are frozen; the
        layers after it and the new head stay trainable.

        Raises:
            ValueError: If the freeze-boundary layer is not present in the model.
        """
        base_model = MobileNetV2(
            input_shape=(224, 224, 3),
            alpha=0.35,
            include_top=False,
            weights='imagenet',
            pooling='avg',
        )
        x = base_model.output
        x = layers.Dropout(0.5)(x)
        # NOTE(review): sigmoid in a hidden layer is unusual (relu is typical); kept as-is.
        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
        x = layers.Dropout(0.5)(x)
        # Independent sigmoid per class; an all-zeros target encodes the "other" class.
        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
        model = models.Model(inputs=base_model.input, outputs=x)

        freeze_until = 'block_16_project_BN'
        freeze = True
        for layer in model.layers:
            layer.trainable = not freeze
            if freeze and layer.name == freeze_until:
                freeze = False
        if freeze:
            # Without this check a renamed layer would silently freeze the whole model.
            raise ValueError(f"Layer {freeze_until!r} not found; cannot set fine-tuning boundary")
        return model

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
              train_dir_name='train', validate_dir_name='test'):
        """Train on ``<dataset_dir>/<train_dir_name>`` and validate on ``<dataset_dir>/<validate_dir_name>``.

        Only the best model by the checkpoint callback's default monitor is saved to
        ``ckpt_path``, and the accuracy/loss curves are written to ``history_save_path``.
        """
        model = self.load_model()
        model.summary()

        model.compile(
            optimizer=optimizers.Adam(learning_rate=3e-4),
            loss=tfa.losses.SigmoidFocalCrossEntropy(),
            metrics=['accuracy', ],

            loss_weights=None,
            weighted_metrics=None,
            run_eagerly=None,
            steps_per_execution=None,
            jit_compile=None,
        )

        train_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, train_dir_name),
            name=train_dir_name,
            batch_size=batch_size,
            augmentation_methods=[
                'random_flip_left_right',
                'random_flip_up_down',
                'random_rot90',
                'random_rgb_2_bgr',
                'random_grayscale_expand'
            ],
        )
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[],
            # Keep the last partial batch. Dropping it skips validation samples, and a
            # test set smaller than batch_size would yield no validation batches at all.
            drop_remainder=False,
        )

        ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)

        history = model.fit(
            train_dataset,
            epochs=epoch,
            validation_data=validate_dataset,
            callbacks=[ckpt_callback, ],
        )

        self.history_save(history, history_save_path)

    def test(self):
        """Print the class-to-label mapping and the number of output classes."""
        print(self.class_label_map)
        print(self.class_count)