# model.py 6.87 KB
import os
import random
import tensorflow as tf
import tensorflow_addons as tfa
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, losses, metrics, callbacks, applications
import matplotlib.pyplot as plt

from base_class import BaseModel


@tf.function
def random_rgb_2_bgr(image, label):
    """Reverse the last axis of ``image`` with probability ~0.5.

    The slicing ``image[:, :, ::-1]`` requires a rank-3 tensor; per the
    function name the last axis is the channel axis (RGB -> BGR).

    The coin flip is drawn with ``tf.random.uniform`` so that it is re-sampled
    on every call. Python's ``random.random()`` would be evaluated only once,
    while the function is being traced, which bakes a single outcome into the
    graph and applies it to every sample.

    Args:
        image: Rank-3 image tensor.
        label: Passed through unchanged.

    Returns:
        Tuple ``(image, label)`` where ``image`` is either the input or the
        input with its last axis reversed.
    """
    keep_original = tf.random.uniform([]) > 0.5
    augmented = tf.cond(
        keep_original,
        lambda: image,
        lambda: image[:, :, ::-1],
    )
    return augmented, label


@tf.function
def random_grayscale_expand(image, label):
    """Replace ``image`` by its 3-channel grayscale version with probability ~0.1.

    The image is converted to grayscale and expanded back to three channels,
    so the output keeps the same shape as an RGB input.

    The random draw uses ``tf.random.uniform`` so it happens per call. Python's
    ``random.random()`` would run only once at trace time inside
    ``tf.function``, fixing the decision for every sample.

    Args:
        image: RGB image tensor (last dimension of size 3, as required by
            ``tf.image.rgb_to_grayscale``).
        label: Passed through unchanged.

    Returns:
        Tuple ``(image, label)``.
    """
    keep_original = tf.random.uniform([]) > 0.1
    augmented = tf.cond(
        keep_original,
        lambda: image,
        lambda: tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)),
    )
    return augmented, label


@tf.function
def load_image(image_path, label):
    """Read the file at ``image_path`` and decode it to a 3-channel image.

    ``expand_animations=False`` makes ``tf.image.decode_image`` always return
    a rank-3 tensor (animated GIFs are truncated to their first frame). Without
    it the result has no static rank, and the ``tf.image.resize`` call in the
    preprocessing step fails when traced inside ``Dataset.map``.

    Args:
        image_path: Scalar string tensor holding the path of the image file.
        label: Passed through unchanged.

    Returns:
        Tuple ``(image, label)`` with ``image`` being the decoded image.
    """
    raw_bytes = tf.io.read_file(image_path)
    image = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    return image, label


@tf.function
def preprocess_input(image, label):
    """Resize ``image`` to 224x224 and apply MobileNetV2 input preprocessing.

    Args:
        image: Image tensor accepted by ``tf.image.resize``.
        label: Passed through unchanged.

    Returns:
        Tuple ``(image, label)`` with the resized, preprocessed image.
    """
    resized = tf.image.resize(image, [224, 224])
    return applications.mobilenet_v2.preprocess_input(resized), label


class F3Classification(BaseModel):
    """MobileNetV2-based image classifier over a fixed list of class names.

    Labels are one-hot vectors of length ``class_count``. When
    ``class_other_first`` is true, the first entry of ``class_name_list`` is
    mapped to label ``-1`` (which ``tf.one_hot`` encodes as an all-zero
    vector) and is not counted in ``class_count``.
    """

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        """Store the class count and the class-name -> label mapping.

        Args:
            class_name_list: Ordered list of class (directory) names.
            class_other_first: If true, the first name is treated as the
                "other" class: it gets label ``-1`` and is excluded from
                ``class_count``.
            *args: Forwarded to ``BaseModel.__init__``.
            **kwargs: Forwarded to ``BaseModel.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)

    @staticmethod
    def get_class_label_map(class_name_list, class_other_first=False):
        """Map each class name to its integer label.

        Labels follow the position in ``class_name_list``; when
        ``class_other_first`` is true every label is shifted down by one, so
        the first name maps to ``-1``.
        """
        return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}

    def get_image_label_list(self, dataset_dir):
        """Collect image paths and one-hot labels from ``dataset_dir``.

        Every sub-directory of ``dataset_dir`` whose name is a key of
        ``self.class_label_map`` contributes all regular files it contains.

        Args:
            dataset_dir: Directory containing one sub-directory per class.

        Returns:
            Tuple ``(image_path_list, label_list)`` of equal length, where
            ``label_list`` holds one-hot tensors of depth ``self.class_count``.
        """
        image_path_list = []
        label_list = []
        for class_name in os.listdir(dataset_dir):
            class_dir_path = os.path.join(dataset_dir, class_name)
            if not os.path.isdir(class_dir_path):
                continue
            if class_name not in self.class_label_map:
                continue
            # The label is the same for every file of the class: build it once.
            one_hot_label = tf.one_hot(self.class_label_map[class_name], depth=self.class_count)
            for file_name in os.listdir(class_dir_path):
                # TODO image check
                file_path = os.path.join(class_dir_path, file_name)
                # Nested directories cannot be read as images; skip them.
                if not os.path.isfile(file_path):
                    continue
                image_path_list.append(file_path)
                label_list.append(one_hot_label)
        return image_path_list, label_list

    def _resolve_augmentation(self, augmentation_method):
        """Return the callable for one ``augmentation_methods`` entry.

        Accepts a callable directly, or a name that is looked up first as an
        attribute of ``self`` and then among this module's top-level names
        (where the augmentation functions of this file are defined).

        Raises:
            AttributeError: If the name cannot be resolved to a callable.
        """
        if callable(augmentation_method):
            return augmentation_method
        resolved = getattr(self, augmentation_method, None)
        if resolved is None:
            resolved = globals().get(augmentation_method)
        if not callable(resolved):
            raise AttributeError('unknown augmentation method: {!r}'.format(augmentation_method))
        return resolved

    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=()):
        """Build the shuffled, preprocessed and batched dataset for a directory.

        ``tf.data`` transformations return a new dataset instead of modifying
        the receiver, so every step is re-assigned; otherwise the shuffle and
        all map stages would be silently dropped.

        Args:
            dataset_dir: Directory with one sub-directory per class.
            name: Name given to the dataset operations.
            batch_size: Number of samples per batch; incomplete final batches
                are dropped.
            augmentation_methods: Iterable of augmentation callables or names
                (see ``_resolve_augmentation``), applied in order between
                image loading and preprocessing.

        Returns:
            A batched, prefetching ``tf.data.Dataset`` of ``(image, label)``.

        Raises:
            ValueError: If no image is found under ``dataset_dir``.
            AttributeError: If an augmentation name cannot be resolved.
        """
        image_path_list, label_list = self.get_image_label_list(dataset_dir)
        if not image_path_list:
            # An empty slice set cannot be shuffled (buffer size 0) or decoded.
            raise ValueError('no images found under {!r}'.format(dataset_dir))

        dataset = tf.data.Dataset.from_tensor_slices((image_path_list, label_list), name=name)
        # Shuffle the cheap (path, label) pairs with a buffer covering the whole set.
        dataset = dataset.shuffle(len(image_path_list), reshuffle_each_iteration=True)
        dataset = dataset.map(load_image,
                              num_parallel_calls=tf.data.AUTOTUNE,
                              deterministic=False)
        for augmentation_method in augmentation_methods:
            dataset = dataset.map(self._resolve_augmentation(augmentation_method),
                                  num_parallel_calls=tf.data.AUTOTUNE,
                                  deterministic=False)
        dataset = dataset.map(preprocess_input,
                              num_parallel_calls=tf.data.AUTOTUNE,
                              deterministic=False)
        parallel_batch_dataset = dataset.batch(
            batch_size=batch_size,
            drop_remainder=True,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
            name=name,
        ).prefetch(tf.data.AUTOTUNE)
        return parallel_batch_dataset

    def load_model(self):
        """Build the classifier: MobileNetV2 backbone plus a small dense head.

        The backbone is MobileNetV2 (alpha 0.35, ImageNet weights, global
        average pooling) followed by dropout, a 256-unit sigmoid dense layer,
        dropout, and a sigmoid output layer with ``self.class_count`` units.

        All layers up to and including ``block_16_project_BN`` are frozen;
        every layer after it is trainable.

        Returns:
            The (uncompiled) Keras model.
        """
        base_model = MobileNetV2(
            input_shape=(224, 224, 3),
            alpha=0.35,
            include_top=False,
            weights='imagenet',
            pooling='avg',
        )
        x = base_model.output
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
        model = models.Model(inputs=base_model.input, outputs=x)

        freeze = True
        for layer in model.layers:
            layer.trainable = not freeze
            # The marker layer itself is frozen; freezing stops right after it.
            if freeze and layer.name == 'block_16_project_BN':
                freeze = False
        return model

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, train_dir_name='train', validate_dir_name='test'):
        """Build the training and validation datasets.

        NOTE: model construction, compilation, fitting and plotting are
        currently commented out, so ``epoch`` and ``ckpt_path`` are unused and
        nothing is returned.

        Args:
            dataset_dir: Root directory containing the train/validate folders.
            epoch: Number of epochs (used only by the disabled ``fit`` call).
            batch_size: Batch size for both datasets.
            ckpt_path: Checkpoint path (used only by the disabled callback).
            train_dir_name: Sub-directory of ``dataset_dir`` with training data.
            validate_dir_name: Sub-directory of ``dataset_dir`` with validation
                data.
        """
        # model = self.load_model()
        # model.summary()
        #
        # model.compile(
        #     optimizer=optimizers.Adam(learning_rate=3e-4),
        #     loss=tfa.losses.SigmoidFocalCrossEntropy(),
        #     metrics=['accuracy', ],
        #
        #     loss_weights=None,
        #     weighted_metrics=None,
        #     run_eagerly=None,
        #     steps_per_execution=None,
        #     jit_compile=None,
        # )

        train_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, train_dir_name),
            name=train_dir_name,
            batch_size=batch_size,
            augmentation_methods=[],
            # augmentation_methods=['random_rgb_2_bgr', 'random_grayscale_expand'],
        )
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[]
        )

        # ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
        #
        # history = model.fit(
        #     train_dataset,
        #     epochs=epoch,
        #     validation_data=validate_dataset,
        #     callbacks=[ckpt_callback, ],
        # )
        #
        # acc = history.history['accuracy']
        # val_acc = history.history['val_accuracy']
        #
        # loss = history.history['loss']
        # val_loss = history.history['val_loss']
        #
        # plt.figure(figsize=(8, 8))
        # plt.subplot(2, 1, 1)
        # plt.plot(acc, label='Training Accuracy')
        # plt.plot(val_acc, label='Validation Accuracy')
        # plt.legend(loc='lower right')
        # plt.ylabel('Accuracy')
        # plt.ylim([min(plt.ylim()), 1])
        # plt.title('Training and Validation Accuracy')
        #
        # plt.subplot(2, 1, 2)
        # plt.plot(loss, label='Training Loss')
        # plt.plot(val_loss, label='Validation Loss')
        # plt.legend(loc='upper right')
        # plt.ylabel('Cross Entropy')
        # plt.ylim([0, 1.0])
        # plt.title('Training and Validation Loss')
        # plt.xlabel('epoch')
        # plt.show()

    def test(self):
        """Print the class-name -> label mapping and the class count."""
        print(self.class_label_map)
        print(self.class_count)
        # Disabled manual check of the image loading / preprocessing steps:
        # path = '/home/zwq/data/data_224/train/银行卡/bc_1.jpg'
        # label = 5
        # image, label = self.load_image(path, label)
        # print(image.shape)
        # image, label = self.preprocess_input(image, label)
        # print(image.shape)