# model.py 11.1 KB
import os
import random
import cv2
import tensorflow as tf
import tensorflow_addons as tfa

from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, callbacks, applications
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

from .base_class import BaseModel
from .metrics import CustomMetric
from .utils import history_save, plot_confusion_matrix


class F3Classification(BaseModel):

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        """Set up class bookkeeping; the Keras model itself is built later by `load_model`.

        Args:
            class_name_list: Class names; directory names in a dataset are
                matched against these.
            class_other_first: When truthy, the first entry of
                `class_name_list` is mapped to label -1 and is not counted
                in `class_count`.
            *args: Forwarded to the base class constructor.
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(*args, **kwargs)
        self.class_name_list = class_name_list

        # Number of labels that get their own slot in the one-hot vector /
        # output layer; the leading "other" entry (if any) is excluded.
        name_total = len(class_name_list)
        self.class_count = name_total - 1 if class_other_first else name_total

        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
        # File extensions accepted when scanning dataset directories.
        self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
        # Populated by `load_model`.
        self.model = None

    @staticmethod
    def gpu_config(gpu_idx=0):
        """Make a single physical GPU visible to TensorFlow.

        Args:
            gpu_idx: Index into the list of physical GPUs reported by
                TensorFlow. Negative indices count from the end, as with
                ordinary list indexing.

        If no GPU exists at `gpu_idx` (for example on a CPU-only host or a
        machine with fewer GPUs than expected), a warning is emitted and
        device visibility is left unchanged instead of raising IndexError.
        """
        import warnings

        gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
        if not -len(gpus) <= gpu_idx < len(gpus):
            warnings.warn(
                f'GPU index {gpu_idx} is not available ({len(gpus)} GPU(s) detected); '
                'leaving TensorFlow device visibility unchanged.'
            )
            return
        tf.config.set_visible_devices(devices=gpus[gpu_idx], device_type='GPU')

    @staticmethod
    def get_class_label_map(class_name_list, class_other_first=False):
        """Map each class name to its integer label.

        Args:
            class_name_list: Class names in label order.
            class_other_first: When truthy, numbering starts at -1 so the
                first name gets label -1 and the rest get 0, 1, 2, ...

        Returns:
            Dict from class name to integer label.
        """
        first_label = -1 if class_other_first else 0
        return {
            cn_name: label
            for label, cn_name in enumerate(class_name_list, start=first_label)
        }

    def get_image_label_list(self, dataset_dir):
        """Collect image paths and one-hot labels from a class-per-directory dataset.

        Each sub-directory of `dataset_dir` whose name is a key of
        `self.class_label_map` is scanned (non-recursively) for files whose
        extension is in `self.image_ext_set`. The extension comparison is
        case-insensitive, so files such as `IMG_0001.JPG` are included.

        Args:
            dataset_dir: Directory containing one sub-directory per class.

        Returns:
            Tuple `(image_path_list, label_list)` of equal length, where each
            label is `tf.one_hot(label, depth=self.class_count)`.
        """
        image_path_list = []
        label_list = []
        for class_name in os.listdir(dataset_dir):
            class_dir_path = os.path.join(dataset_dir, class_name)
            # Ignore stray files and directories that are not known classes.
            if not os.path.isdir(class_dir_path) or class_name not in self.class_label_map:
                continue
            # The label is the same for every file of the class, so build the
            # one-hot tensor once per class rather than once per file.
            one_hot_label = tf.one_hot(self.class_label_map[class_name], depth=self.class_count)
            for file_name in os.listdir(class_dir_path):
                extension = os.path.splitext(file_name)[1].lower()
                if extension not in self.image_ext_set:
                    continue
                image_path_list.append(os.path.join(class_dir_path, file_name))
                label_list.append(one_hot_label)
        return image_path_list, label_list

    @staticmethod
    def random_rgb_2_bgr(image, label):
        """Reverse the channel order of `image` with probability 1/2.

        The random decision is drawn with a TensorFlow op so that it is
        re-evaluated for every element when this function is used inside
        `tf.data.Dataset.map`. Python's `random` module would be evaluated
        only once, while the function is traced, fixing the outcome for the
        whole dataset.

        Args:
            image: Image tensor with channels on the third axis.
            label: Passed through unchanged.

        Returns:
            Tuple `(image, label)`.
        """
        should_swap = tf.random.uniform([]) < 0.5
        swapped = tf.cond(should_swap, lambda: image[:, :, ::-1], lambda: image)
        return swapped, label

    @staticmethod
    def rgb_2_bgr(image, label):
        """Reverse the order of the third axis (channels) of `image`.

        Args:
            image: Image indexed as [row, column, channel].
            label: Passed through unchanged.

        Returns:
            Tuple `(image_with_reversed_channels, label)`.
        """
        return image[:, :, ::-1], label

    @staticmethod
    def random_grayscale_expand(image, label):
        """With probability 0.1, replace `image` by its grayscale version expanded back to 3 channels.

        The random decision is drawn with a TensorFlow op so that it is
        re-evaluated for every element when this function is used inside
        `tf.data.Dataset.map`. Python's `random` module would be evaluated
        only once, while the function is traced, fixing the outcome for the
        whole dataset.

        Args:
            image: RGB image tensor.
            label: Passed through unchanged.

        Returns:
            Tuple `(image, label)`.
        """
        def to_gray_rgb():
            gray = tf.image.rgb_to_grayscale(image)
            return tf.image.grayscale_to_rgb(gray)

        should_convert = tf.random.uniform([]) < 0.1
        converted = tf.cond(should_convert, to_gray_rgb, lambda: image)
        return converted, label

    @staticmethod
    def random_flip_left_right(image, label):
        """Apply `tf.image.random_flip_left_right` to `image`.

        Args:
            image: Image tensor.
            label: Passed through unchanged.

        Returns:
            Tuple `(image, label)`.
        """
        return tf.image.random_flip_left_right(image), label

    @staticmethod
    def random_flip_up_down(image, label):
        """Apply `tf.image.random_flip_up_down` to `image`.

        Args:
            image: Image tensor.
            label: Passed through unchanged.

        Returns:
            Tuple `(image, label)`.
        """
        return tf.image.random_flip_up_down(image), label

    @staticmethod
    def random_rot90(image, label):
        """With probability 0.3, rotate `image` by 90, 180 or 270 degrees.

        Both the decision and the number of quarter turns are drawn with
        TensorFlow ops so that they are re-evaluated for every element when
        this function is used inside `tf.data.Dataset.map`. Python's `random`
        module would be evaluated only once, while the function is traced,
        applying one fixed choice to the whole dataset.

        Args:
            image: Image tensor.
            label: Passed through unchanged.

        Returns:
            Tuple `(image, label)`.
        """
        should_rotate = tf.random.uniform([]) < 0.3
        # maxval is exclusive for integer dtypes, so this yields 1, 2 or 3.
        quarter_turns = tf.random.uniform([], minval=1, maxval=4, dtype=tf.int32)
        rotated = tf.cond(
            should_rotate,
            lambda: tf.image.rot90(image, k=quarter_turns),
            lambda: image,
        )
        return rotated, label

    @staticmethod
    def load_image(image_path, label):
        """Read an image file, decode it to 3 channels and resize it to 224x224.

        Args:
            image_path: Path of the image file to read.
            label: Passed through unchanged.

        Returns:
            Tuple `(image, label)` with the resized image.
        """
        raw_bytes = tf.io.read_file(image_path)
        # TODO: consider tf.image.decode_image / decode_png here.
        # NOTE(review): `image_ext_set` also admits .png/.bmp/.tif files, but
        # the decoder used is the JPEG one - confirm it copes with those formats.
        decoded = tf.image.decode_jpeg(raw_bytes, channels=3, dct_method='INTEGER_ACCURATE')
        return tf.image.resize(decoded, [224, 224]), label

    @staticmethod
    def preprocess_input(image, label):
        """Apply the MobileNetV2 input preprocessing to `image`.

        No resizing happens here; the image is handed to
        `applications.mobilenet_v2.preprocess_input` as-is.

        Args:
            image: Image to preprocess.
            label: Passed through unchanged.

        Returns:
            Tuple `(preprocessed_image, label)`.
        """
        return applications.mobilenet_v2.preprocess_input(image), label

    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=()):
        """Build a shuffled, batched and prefetched `tf.data` pipeline for a dataset directory.

        Pipeline order: shuffle paths -> `load_image` -> each augmentation in
        the given order -> `preprocess_input` -> batch -> prefetch.

        Args:
            dataset_dir: Directory with one sub-directory per class (see
                `get_image_label_list`).
            name: Name given to the dataset ops.
            batch_size: Number of samples per batch. Incomplete final batches
                are dropped (`drop_remainder=True`).
            augmentation_methods: Iterable of method names on this object;
                each is looked up with `getattr` and mapped over the dataset.

        Returns:
            The batched, prefetched dataset yielding `(images, labels)`.

        Raises:
            ValueError: If no image was found under `dataset_dir`. Without
                this check the failure would surface later as an opaque
                TensorFlow error from a zero-sized shuffle buffer.
        """
        image_path_list, label_list = self.get_image_label_list(dataset_dir)
        if not image_path_list:
            raise ValueError(f'No images found for dataset {name!r} in {dataset_dir!r}')

        dataset = tf.data.Dataset.from_tensor_slices((image_path_list, label_list), name=name)
        # Only paths and labels are held in the shuffle buffer, so a buffer as
        # large as the dataset gives a full shuffle.
        dataset = dataset.shuffle(len(image_path_list), reshuffle_each_iteration=True)

        map_functions = [self.load_image]
        map_functions.extend(getattr(self, method_name) for method_name in augmentation_methods)
        map_functions.append(self.preprocess_input)
        for map_function in map_functions:
            dataset = dataset.map(
                map_function,
                num_parallel_calls=tf.data.AUTOTUNE,
                deterministic=False)

        parallel_batch_dataset = dataset.batch(
            batch_size=batch_size,
            drop_remainder=True,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
            name=name,
        ).prefetch(tf.data.AUTOTUNE)
        return parallel_batch_dataset

    def load_model(self, for_training=False, load_weights_path=None):
        """Build the MobileNetV2-based classifier and store it in `self.model`.

        The network is MobileNetV2 (alpha=0.5, ImageNet weights, global
        average pooling) followed by Dropout -> Dense(256, sigmoid) ->
        Dropout -> Dense(`self.class_count`, sigmoid).

        Args:
            for_training: When true and no weights file is loaded, every layer
                up to and including `block_16_project_BN` is frozen and the
                layers after it are left trainable.
            load_weights_path: Optional path (str or `os.PathLike`) of a
                weights file, loaded by layer name with mismatching layers
                skipped. `None` or an empty path means "no weights file".

        Raises:
            RuntimeError: If a model has already been loaded.
            FileNotFoundError: If `load_weights_path` is given but is not an
                existing file and `for_training` is false. Previously the
                path was silently ignored, leaving an untrained head for
                inference.
        """
        import warnings

        if self.model is not None:
            raise RuntimeError('Model is loaded, if you are sure to reload the model, set `self.model = None` first')

        weights_path = None
        if isinstance(load_weights_path, (str, os.PathLike)) and os.fspath(load_weights_path):
            weights_path = os.fspath(load_weights_path)
        weights_found = weights_path is not None and os.path.isfile(weights_path)

        # Check before building the network so a bad path fails fast.
        if weights_path is not None and not weights_found:
            if not for_training:
                raise FileNotFoundError(f'Weights file not found: {weights_path}')
            warnings.warn(
                f'Weights file not found: {weights_path}; training starts from the ImageNet backbone.'
            )

        base_model = MobileNetV2(
            input_shape=(224, 224, 3),
            alpha=0.5,
            include_top=False,
            weights='imagenet',
            pooling='avg',
        )
        x = base_model.output
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
        self.model = models.Model(inputs=base_model.input, outputs=x)

        if weights_found:
            self.model.load_weights(weights_path, by_name=True, skip_mismatch=True)
        elif for_training:
            # Walk the layers in order: freeze everything up to and including
            # 'block_16_project_BN', leave the remaining layers trainable.
            freeze = True
            for layer in self.model.layers:
                layer.trainable = not freeze
                if freeze and layer.name == 'block_16_project_BN':
                    freeze = False

    def train(self,
              dataset_dir,
              epoch,
              batch_size,
              ckpt_path,
              history_save_path,
              load_weights_path=None,
              train_dir_name='train',
              validate_dir_name='test',
              thresholds=0.5,
              metrics_name='accuracy',
              gpu_idx=1,
              early_stopping_patience=None):
        """Build, compile and fit the model, then save the training history.

        Args:
            dataset_dir: Root directory containing the train and validation
                sub-directories.
            epoch: Number of epochs passed to `Model.fit`.
            batch_size: Batch size for both datasets.
            ckpt_path: Path given to `ModelCheckpoint` (`save_best_only=True`).
            history_save_path: Passed to `history_save` with the fit history.
            load_weights_path: Optional weights file forwarded to `load_model`.
            train_dir_name: Sub-directory of `dataset_dir` with training data.
            validate_dir_name: Sub-directory of `dataset_dir` with validation data.
            thresholds: Threshold passed to `CustomMetric`.
            metrics_name: Name of the metric, also passed to `history_save`.
            gpu_idx: GPU index passed to `gpu_config` (default keeps the
                previously hard-coded value 1).
            early_stopping_patience: When not `None`, adds an `EarlyStopping`
                callback on `val_loss` with this patience and
                `restore_best_weights=True`. Default `None` trains for all
                epochs, as before.
        """
        self.gpu_config(gpu_idx)

        self.load_model(for_training=True, load_weights_path=load_weights_path)
        self.model.summary()

        self.model.compile(
            optimizer=optimizers.Adam(learning_rate=1e-4),
            loss=tfa.losses.SigmoidFocalCrossEntropy(),  # TODO ?
            metrics=[CustomMetric(thresholds, name=metrics_name), ],

            loss_weights=None,
            weighted_metrics=None,
            run_eagerly=None,
            steps_per_execution=None,
            jit_compile=None,
        )

        train_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, train_dir_name),
            name=train_dir_name,
            batch_size=batch_size,
            augmentation_methods=[
                'random_flip_left_right',
                'random_flip_up_down',
                'random_rot90',
                'random_grayscale_expand'
            ],
        )
        # NOTE(review): validation images have their channels reversed
        # ('rgb_2_bgr') while the training pipeline above does not reverse
        # them - confirm this mismatch is intentional.
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[
                'rgb_2_bgr'
            ]
        )

        fit_callbacks = [callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)]
        if early_stopping_patience is not None:
            fit_callbacks.append(
                tf.keras.callbacks.EarlyStopping(
                    monitor='val_loss',
                    patience=early_stopping_patience,
                    restore_best_weights=True,
                )
            )

        history = self.model.fit(
            train_dataset,
            epochs=epoch,
            validation_data=validate_dataset,
            callbacks=fit_callbacks,
        )

        history_save(history, history_save_path, metrics_name)

    def evaluation(self,
                   load_weights_path,
                   confusion_matrix_save_path,
                   dataset_dir,
                   batch_size,
                   validate_dir_name='test',
                   thresholds=0.5,
                   gpu_idx=3):
        """Evaluate saved weights on a dataset; print metrics and plot the confusion matrix.

        Prints the accuracy, the confusion matrix and the classification
        report, then passes the confusion matrix to `plot_confusion_matrix`.

        Args:
            load_weights_path: Weights file forwarded to `load_model`.
            confusion_matrix_save_path: Passed to `plot_confusion_matrix`.
            dataset_dir: Root directory containing `validate_dir_name`.
            batch_size: Batch size of the evaluation dataset. `load_dataset`
                drops an incomplete final batch, so up to `batch_size - 1`
                samples are not evaluated.
            validate_dir_name: Sub-directory of `dataset_dir` to evaluate.
            thresholds: Threshold passed to `CustomMetric`.
            gpu_idx: GPU index passed to `gpu_config` (default keeps the
                previously hard-coded value 3).
        """
        self.gpu_config(gpu_idx)

        self.load_model(load_weights_path=load_weights_path)
        self.model.summary()

        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[
                'rgb_2_bgr'
            ]
        )

        label_true_list = []
        label_pred_list = []
        custom_metric = CustomMetric(thresholds)
        for image_batch, y_true_batch in validate_dataset:
            # The data is already batched; predict_on_batch avoids the
            # per-call overhead of Model.predict inside a loop.
            y_pred_batch = self.model.predict_on_batch(image_batch)
            label_true_batch_with_others = custom_metric.y_true_with_others(y_true_batch)
            label_pred_batch_with_others = custom_metric.y_pred_with_others(y_pred_batch)
            label_true_list.extend(label_true_batch_with_others.numpy())
            label_pred_list.extend(label_pred_batch_with_others.numpy())

        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)
        plot_confusion_matrix(cm, list(range(len(self.class_name_list))), confusion_matrix_save_path)

    def predict(self, image, thresholds=0.5):
        """Classify a single image with the loaded model.

        The image is only passed through `preprocess_input` (no resizing is
        done here), so it must already match the model input of 224x224x3.

        Args:
            image: One image, without a batch dimension.
            thresholds: Minimum score for the top class to be accepted.

        Returns:
            Dict with keys `'label'` and `'confidence'`. `'confidence'` is the
            highest output score. `'label'` is the matching class name, or,
            when the score is below `thresholds`, the leading "other" entry of
            `class_name_list` if there is one and `None` otherwise.

        Raises:
            RuntimeError: If `load_model` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("The model hasn't loaded yet, run `self.load_model()` first")
        input_image, _ = self.preprocess_input(image, None)
        input_images = tf.expand_dims(input_image, axis=0)
        # A batch of one was submitted, so only the first row is meaningful.
        scores = self.model.predict(input_images)[0]

        best_idx = int(tf.math.argmax(scores))
        confidence = scores[best_idx]

        # 1 when class_name_list starts with an "other" entry that has no
        # output unit of its own, 0 when every name has an output unit.
        # Always adding 1 mislabels classes in the latter case.
        other_offset = len(self.class_name_list) - self.class_count
        if confidence < thresholds:
            label = self.class_name_list[0] if other_offset else None
        else:
            label = self.class_name_list[best_idx + other_offset]

        res = {
            'label': label,
            'confidence': confidence
        }
        return res

    def test(self):
        """Run `CustomMetric(0.5)` on a small hand-written batch and print its result."""
        # Last row is all zeros: a sample with no positive class.
        expected_labels = [
            [0, 1, 0],
            [0, 1, 0],
            [0, 0, 1],
            [0, 0, 0],
        ]
        predicted_scores = [
            [0.1, 0.8, 0.9],  # TODO multi_label
            [0.2, 0.8, 0.1],
            [0.2, 0.1, 0.85],
            [0.2, 0.4, 0.1],
        ]

        metric = CustomMetric(0.5)
        metric.update_state(expected_labels, predicted_scores)
        metric_value = metric.result().numpy()
        print(metric_value)