# model.py (9.95 KB)
import os
import random
import tensorflow as tf
import tensorflow_addons as tfa
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, losses, metrics, callbacks, applications
import matplotlib.pyplot as plt

from base_class import BaseModel


class CustomMetric(metrics.Metric):
    """Accuracy over class ids where id 0 stands for "no class".

    Each sample is reduced to one integer id:

    * ground truth: ``argmax(y_true) + 1`` multiplied by the row sum of
      ``y_true``, so an all-zero row becomes 0 (assumes rows are one-hot or
      all-zero -- a multi-hot row would scale the id);
    * prediction: ``argmax(y_pred) + 1`` if the highest score reaches
      ``thresholds``, otherwise 0.

    The metric value is the (optionally weighted) number of samples whose two
    ids agree, divided by the number of samples seen.
    """

    def __init__(self, thresholds=0.5, name="custom_metric", **kwargs):
        """Create the metric.

        Args:
            thresholds: Minimum top score for a prediction to count as a real
                class rather than id 0.
            name: Metric name reported by Keras.
            **kwargs: Forwarded to ``metrics.Metric``.
        """
        super().__init__(name=name, **kwargs)
        self.thresholds = thresholds
        self.true_positives = self.add_weight(name="ctp", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros", dtype='int32')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate matches for one batch.

        Args:
            y_true: 2-D labels, one row per sample.
            y_pred: 2-D scores, one row per sample.
            sample_weight: Optional per-sample weights applied to the match
                indicator (the sample count itself is not weighted).
        """
        # 1-based class id, zeroed for rows that sum to 0.
        true_ids = tf.argmax(y_true, axis=1) + 1
        true_has_class = tf.cast(tf.math.reduce_sum(y_true, axis=1), "int64")
        true_ids = tf.math.multiply(true_ids, true_has_class)

        # 1-based class id, zeroed when the best score is below the threshold.
        pred_ids = tf.argmax(y_pred, axis=1) + 1
        pred_confident = tf.math.greater_equal(
            tf.math.reduce_max(y_pred, axis=1), self.thresholds)
        pred_ids = tf.math.multiply(pred_ids, tf.cast(pred_confident, "int64"))

        matches = tf.cast(tf.equal(true_ids, pred_ids), "float32")
        if sample_weight is not None:
            sample_weight = tf.cast(sample_weight, "float32")
            matches = tf.multiply(matches, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(matches))
        self.count.assign_add(tf.shape(true_ids)[0])

    def result(self):
        """Return matches / samples, or 0 when no sample has been seen."""
        # divide_no_nan avoids a NaN when result() is read before any update.
        return tf.math.divide_no_nan(
            self.true_positives, tf.cast(self.count, 'float32'))

    def reset_state(self):
        """Zero both accumulators."""
        self.true_positives.assign(0.0)
        self.count.assign(0)

    def get_config(self):
        """Return the base config plus ``thresholds`` so it survives serialization."""
        config = dict(super().get_config())
        config["thresholds"] = self.thresholds
        return config


class F3Classification(BaseModel):

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        """Set up class bookkeeping for the classifier.

        Args:
            class_name_list: Class names, in label order.
            class_other_first: When truthy, the first name is not counted as
                an output class (it is mapped to label -1 by
                ``get_class_label_map``) -- presumably a catch-all "other"
                class.
            *args: Forwarded to ``BaseModel``.
            **kwargs: Forwarded to ``BaseModel``.
        """
        super().__init__(*args, **kwargs)

        name_total = len(class_name_list)
        if class_other_first:
            self.class_count = name_total - 1
        else:
            self.class_count = name_total

        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
        # File extensions (with leading dot) accepted as images.
        self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}

    @staticmethod
    def gpu_config(gpu_index=1):
        """Restrict TensorFlow to a single physical GPU.

        Args:
            gpu_index: Position of the GPU to expose, in the order TensorFlow
                lists physical GPUs. Defaults to 1, the value that used to be
                hard-coded.

        If no GPU exists at ``gpu_index`` (e.g. a CPU-only or single-GPU
        machine) a warning is issued and device visibility is left at
        TensorFlow's default instead of raising ``IndexError``.
        """
        import warnings

        gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
        if not 0 <= gpu_index < len(gpus):
            warnings.warn(
                "GPU index %d is not available (%d GPU(s) detected); "
                "keeping default device visibility." % (gpu_index, len(gpus)))
            return
        tf.config.set_visible_devices(devices=gpus[gpu_index], device_type='GPU')

    @staticmethod
    def history_save(history, save_path, metric_name=None):
        """Plot training/validation metric and loss curves and save the figure.

        Args:
            history: Object with a ``history`` dict of per-epoch value lists
                (as returned by ``model.fit``). Must contain ``loss``,
                ``val_loss``, the chosen metric and its ``val_`` counterpart.
            save_path: Destination passed to ``plt.savefig``.
            metric_name: Key of the training metric to plot. When omitted,
                ``'accuracy'`` is used if present; otherwise the first key
                that is neither ``'loss'`` nor ``val_``-prefixed is used, so
                models compiled with a differently named metric still plot.

        Raises:
            KeyError: If no suitable metric key is present in the history.
        """
        logs = history.history
        if metric_name is None:
            if 'accuracy' in logs:
                metric_name = 'accuracy'
            else:
                candidates = [
                    key for key in logs
                    if key != 'loss' and not key.startswith('val_')
                ]
                if not candidates:
                    raise KeyError('accuracy')
                metric_name = candidates[0]

        acc = logs[metric_name]
        val_acc = logs['val_' + metric_name]

        loss = logs['loss']
        val_loss = logs['val_loss']

        fig = plt.figure(figsize=(8, 8))
        plt.subplot(2, 1, 1)
        plt.plot(acc, label='Training Accuracy')
        plt.plot(val_acc, label='Validation Accuracy')
        plt.legend(loc='lower right')
        plt.ylabel('Accuracy')
        # Keep the automatic lower bound but pin the top of the axis at 1.
        plt.ylim([min(plt.ylim()), 1])
        plt.title('Training and Validation Accuracy')

        plt.subplot(2, 1, 2)
        plt.plot(loss, label='Training Loss')
        plt.plot(val_loss, label='Validation Loss')
        plt.legend(loc='upper right')
        plt.ylabel('Cross Entropy')
        plt.ylim([0, 1.0])
        plt.title('Training and Validation Loss')
        plt.xlabel('epoch')
        plt.savefig(save_path)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    @staticmethod
    def get_class_label_map(class_name_list, class_other_first=False):
        """Map each class name to its integer label.

        Labels follow the position in ``class_name_list``. When
        ``class_other_first`` is truthy every label is shifted down by one,
        so the first name gets -1 and the second gets 0.

        Returns:
            dict: class name -> label.
        """
        shift = 1 if class_other_first else 0
        label_by_name = {}
        for position, class_name in enumerate(class_name_list):
            label_by_name[class_name] = position - shift
        return label_by_name

    def get_image_label_list(self, dataset_dir):
        """Collect image paths and one-hot labels from a class-per-folder tree.

        Every sub-directory of ``dataset_dir`` whose name is a key of
        ``self.class_label_map`` is scanned (non-recursively) for files whose
        extension, compared case-insensitively, is in ``self.image_ext_set``.

        Args:
            dataset_dir: Directory containing one sub-directory per class.

        Returns:
            tuple: ``(image_path_list, label_list)`` of equal length;
            ``label_list`` holds ``tf.one_hot`` vectors of depth
            ``self.class_count``. A label of -1 yields an all-zero vector,
            since ``tf.one_hot`` leaves out-of-range indices unset.
        """
        image_path_list = []
        label_list = []
        # Sorted listings make the result independent of filesystem order.
        for class_name in sorted(os.listdir(dataset_dir)):
            class_dir_path = os.path.join(dataset_dir, class_name)
            if not os.path.isdir(class_dir_path):
                continue
            if class_name not in self.class_label_map:
                continue
            # The label only depends on the class, so build it once per folder.
            one_hot_label = tf.one_hot(
                self.class_label_map[class_name], depth=self.class_count)
            for file_name in sorted(os.listdir(class_dir_path)):
                # TODO: verify that the file is a readable image.
                extension = os.path.splitext(file_name)[1].lower()
                if extension not in self.image_ext_set:
                    continue
                image_path_list.append(os.path.join(class_dir_path, file_name))
                label_list.append(one_hot_label)
        return image_path_list, label_list

    @staticmethod
    def random_rgb_2_bgr(image, label):
        """Reverse the channel order of ``image`` with probability 0.1.

        The draw uses TensorFlow ops: when this function is passed to
        ``Dataset.map`` it is traced into a graph, and a Python-level
        ``random.random()`` would be evaluated only once at trace time,
        giving every element the same outcome.

        Args:
            image: Image tensor with channels on the last of three axes.
            label: Passed through unchanged.

        Returns:
            tuple: ``(image, label)``.
        """
        should_swap = tf.random.uniform([]) < 0.1
        swapped = tf.cond(
            should_swap,
            lambda: image[:, :, ::-1],
            lambda: image,
        )
        return swapped, label

    @staticmethod
    def random_grayscale_expand(image, label):
        """Replace ``image`` by its 3-channel grayscale version with probability 0.1.

        The draw uses TensorFlow ops: when this function is passed to
        ``Dataset.map`` it is traced into a graph, and a Python-level
        ``random.random()`` would be evaluated only once at trace time,
        giving every element the same outcome.

        Args:
            image: RGB image tensor.
            label: Passed through unchanged.

        Returns:
            tuple: ``(image, label)``.
        """
        should_convert = tf.random.uniform([]) < 0.1
        converted = tf.cond(
            should_convert,
            lambda: tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)),
            lambda: image,
        )
        return converted, label

    @staticmethod
    def random_flip_left_right(image, label):
        """Mirror ``image`` horizontally with an overall probability of 1/10.

        A gate opens with probability 0.2, and inside it
        ``tf.image.random_flip_left_right`` flips with its own 1-in-2 chance.
        The gate uses TensorFlow ops: when this function is passed to
        ``Dataset.map`` it is traced into a graph, and a Python-level
        ``random.random()`` would be evaluated only once at trace time.

        Args:
            image: Image tensor.
            label: Passed through unchanged.

        Returns:
            tuple: ``(image, label)``.
        """
        gate_open = tf.random.uniform([]) < 0.2
        flipped = tf.cond(
            gate_open,
            lambda: tf.image.random_flip_left_right(image),
            lambda: image,
        )
        return flipped, label

    @staticmethod
    def random_flip_up_down(image, label):
        """Mirror ``image`` vertically with an overall probability of 1/10.

        A gate opens with probability 0.2, and inside it
        ``tf.image.random_flip_up_down`` flips with its own 1-in-2 chance.
        The gate uses TensorFlow ops: when this function is passed to
        ``Dataset.map`` it is traced into a graph, and a Python-level
        ``random.random()`` would be evaluated only once at trace time.

        Args:
            image: Image tensor.
            label: Passed through unchanged.

        Returns:
            tuple: ``(image, label)``.
        """
        gate_open = tf.random.uniform([]) < 0.2
        flipped = tf.cond(
            gate_open,
            lambda: tf.image.random_flip_up_down(image),
            lambda: image,
        )
        return flipped, label

    @staticmethod
    def random_rot90(image, label):
        """Rotate ``image`` by 90, 180 or 270 degrees with probability 0.1.

        Both the decision and the number of quarter turns are drawn with
        TensorFlow ops: when this function is passed to ``Dataset.map`` it is
        traced into a graph, and Python-level ``random`` calls would be
        evaluated only once at trace time.

        Args:
            image: Image tensor.
            label: Passed through unchanged.

        Returns:
            tuple: ``(image, label)``.
        """
        should_rotate = tf.random.uniform([]) < 0.1
        # maxval is exclusive, so this yields 1, 2 or 3 quarter turns.
        quarter_turns = tf.random.uniform([], minval=1, maxval=4, dtype=tf.int32)
        rotated = tf.cond(
            should_rotate,
            lambda: tf.image.rot90(image, k=quarter_turns),
            lambda: image,
        )
        return rotated, label

    @staticmethod
    def load_image(image_path, label):
        """Read an image file and decode it to a 3-channel tensor.

        ``tf.io.decode_image`` picks the decoder from the file content (its
        documentation lists BMP, GIF, JPEG and PNG). ``expand_animations=False``
        makes it return a 3-D tensor for every format; without it GIFs decode
        to 4-D and the result has no static rank, which breaks the later
        ``tf.image.resize`` call.

        NOTE(review): TIFF is not among the formats documented for
        ``decode_image``, although ``.tif``/``.tiff`` are in
        ``image_ext_set`` -- confirm whether such files occur.

        Args:
            image_path: Path of the image file.
            label: Passed through unchanged.

        Returns:
            tuple: ``(image, label)`` with ``image`` of shape (H, W, 3).
        """
        raw_bytes = tf.io.read_file(image_path)
        image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
        # Make the rank and channel count explicit for downstream ops.
        image.set_shape([None, None, 3])
        return image, label

    @staticmethod
    def preprocess_input(image, label):
        """Resize to 224x224 and apply the MobileNetV2 input preprocessing.

        Args:
            image: Decoded image tensor.
            label: Passed through unchanged.

        Returns:
            tuple: ``(preprocessed_image, label)``.
        """
        target_size = [224, 224]
        resized = tf.image.resize(image, target_size)
        prepared = applications.mobilenet_v2.preprocess_input(resized)
        return prepared, label

    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=()):
        """Build a shuffled, batched and prefetched ``tf.data`` pipeline.

        Pipeline: (path, label) pairs -> full shuffle -> ``load_image`` ->
        each requested augmentation in order -> ``preprocess_input`` -> batch
        (incomplete final batch dropped) -> prefetch.

        Args:
            dataset_dir: Directory with one sub-directory per class.
            name: Name given to the source and batch dataset ops.
            batch_size: Number of samples per batch.
            augmentation_methods: Names of methods on ``self`` taking and
                returning ``(image, label)``, applied in order.

        Returns:
            The batched, prefetched dataset.

        Raises:
            ValueError: If no image is found, which would otherwise surface
                as an opaque shuffle-buffer error.
        """
        image_path_list, label_list = self.get_image_label_list(dataset_dir)
        if not image_path_list:
            raise ValueError(
                f"No images found under {dataset_dir!r} for the configured classes")

        dataset = tf.data.Dataset.from_tensor_slices(
            (image_path_list, label_list), name=name)
        # Buffer as large as the dataset gives a full shuffle each epoch.
        dataset = dataset.shuffle(len(image_path_list), reshuffle_each_iteration=True)
        dataset = dataset.map(
            self.load_image, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
        for augmentation_method in augmentation_methods:
            dataset = dataset.map(
                getattr(self, augmentation_method),
                num_parallel_calls=tf.data.AUTOTUNE,
                deterministic=False)
        dataset = dataset.map(
            self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
        parallel_batch_dataset = dataset.batch(
            batch_size=batch_size,
            drop_remainder=True,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
            name=name,
        ).prefetch(tf.data.AUTOTUNE)
        return parallel_batch_dataset

    def load_model(self):
        """Build the MobileNetV2-based classifier with a partly frozen backbone.

        The head is Dropout -> Dense(256, sigmoid) -> Dropout ->
        Dense(class_count, sigmoid) on top of the globally average-pooled
        backbone output. Every layer up to and including
        ``block_16_project_BN`` is frozen; all later layers are trainable.

        Returns:
            The assembled (uncompiled) Keras model.
        """
        backbone = MobileNetV2(
            input_shape=(224, 224, 3),
            alpha=0.35,
            include_top=False,
            weights='imagenet',
            pooling='avg',
        )

        features = layers.Dropout(0.5)(backbone.output)
        features = layers.Dense(256, activation='sigmoid', name='dense')(features)
        features = layers.Dropout(0.5)(features)
        scores = layers.Dense(self.class_count, activation='sigmoid', name='output')(features)
        model = models.Model(inputs=backbone.input, outputs=scores)

        # Layers stay frozen until the marker layer has been passed.
        past_marker = False
        for layer in model.layers:
            layer.trainable = past_marker
            if layer.name == 'block_16_project_BN':
                past_marker = True
        return model

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path,
              train_dir_name='train', validate_dir_name='test', thresholds=0.5):
        """Train the classifier, checkpoint the best model and plot the history.

        Args:
            dataset_dir: Root directory holding the train/validation folders.
            epoch: Number of epochs passed to ``model.fit``.
            batch_size: Batch size for both datasets.
            ckpt_path: Path given to ``ModelCheckpoint`` (best model only).
            history_save_path: Where the training-curve figure is written.
            train_dir_name: Sub-directory of ``dataset_dir`` with training data.
            validate_dir_name: Sub-directory with validation data.
            thresholds: Threshold forwarded to ``CustomMetric``.
        """
        self.gpu_config()

        model = self.load_model()
        model.summary()

        model.compile(
            optimizer=optimizers.Adam(learning_rate=3e-4),
            loss=tfa.losses.SigmoidFocalCrossEntropy(),  # TODO >>>
            # Named "accuracy" so the history contains the 'accuracy' and
            # 'val_accuracy' keys that history_save reads.
            metrics=[CustomMetric(thresholds, name="accuracy"), ],

            loss_weights=None,
            weighted_metrics=None,
            run_eagerly=None,
            steps_per_execution=None,
            jit_compile=None,
        )

        train_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, train_dir_name),
            name=train_dir_name,
            batch_size=batch_size,
            augmentation_methods=[
                'random_flip_left_right',
                'random_flip_up_down',
                'random_rot90',
                'random_rgb_2_bgr',
                'random_grayscale_expand'
            ],
        )
        # Validation data is loaded without augmentation.
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[]
        )

        ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)

        history = model.fit(
            train_dataset,
            epochs=epoch,
            validation_data=validate_dataset,
            callbacks=[ckpt_callback, ],
        )

        self.history_save(history, history_save_path)

    def test(self):
        """Run ``CustomMetric`` on a tiny hand-written batch and print the result.

        With a threshold of 0.5, rows 2-4 match their labels (the last row is
        all-zero and its best score is below the threshold) while the first
        row does not, so the printed value is expected to be 0.75.
        """
        labels = [
            [0, 1, 0],
            [0, 1, 0],
            [0, 0, 1],
            [0, 0, 0],
        ]
        scores = [
            [0.1, 0.8, 0.9],  # TODO multi_label
            [0.2, 0.8, 0.1],
            [0.2, 0.1, 0.85],
            [0.2, 0.4, 0.1],
        ]

        metric = CustomMetric(0.5)
        metric.update_state(labels, scores)
        outcome = metric.result()
        print(outcome.numpy())