import os

import tensorflow as tf
import tensorflow_addons as tfa
from keras.applications.mobilenet_v2 import MobileNetV2
from keras import layers, models, optimizers, callbacks, applications
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report

from .base_class import BaseModel
from .metrics import CustomMetric
from .utils import history_save, plot_confusion_matrix


class F3Classification(BaseModel):

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_name_list = class_name_list
        # When the catch-all "others" class comes first, it is not a model
        # output: it maps to label -1 (an all-zero one-hot vector), so it is
        # excluded from class_count.
        self.class_count = len(class_name_list) if not class_other_first else len(class_name_list) - 1
        self.class_label_map = self.get_class_label_map(class_name_list, class_other_first)
        self.image_ext_set = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
        self.model = None

    @staticmethod
    def gpu_config(gpu_idx=0):
        # Restrict TensorFlow to a single GPU; the index is machine-specific.
        gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
        tf.config.set_visible_devices(devices=gpus[gpu_idx], device_type='GPU')

    @staticmethod
    def get_class_label_map(class_name_list, class_other_first=False):
        # With class_other_first, the first entry ("others") gets label -1 and
        # the remaining classes get 0..class_count-1.
        return {cn_name: idx - 1 if class_other_first else idx for idx, cn_name in enumerate(class_name_list)}

    def get_image_label_list(self, dataset_dir):
        image_path_list = []
        label_list = []
        for class_name in os.listdir(dataset_dir):
            class_dir_path = os.path.join(dataset_dir, class_name)
            if not os.path.isdir(class_dir_path):
                continue
            if class_name not in self.class_label_map:
                continue
            label = self.class_label_map[class_name]
            for file_name in os.listdir(class_dir_path):
                if os.path.splitext(file_name)[1].lower() not in self.image_ext_set:
                    continue
                file_path = os.path.join(class_dir_path, file_name)
                image_path_list.append(file_path)
                # tf.one_hot maps label -1 ("others") to an all-zero vector.
                label_list.append(tf.one_hot(label, depth=self.class_count))
        return image_path_list, label_list

    # The augmentations below use tf.random/tf.cond instead of Python's
    # `random` module: `Dataset.map` traces the function into a graph once,
    # so a plain `random.random()` would be evaluated a single time at trace
    # time and the "random" branch frozen for every element.

    @staticmethod
    def random_rgb_2_bgr(image, label):
        # Swap channel order with probability 1/2.
        image = tf.cond(
            tf.random.uniform([]) < 0.5,
            lambda: image[:, :, ::-1],
            lambda: image,
        )
        return image, label

    @staticmethod
    def rgb_2_bgr(image, label):
        image = image[:, :, ::-1]
        return image, label

    @staticmethod
    def random_grayscale_expand(image, label):
        # With probability 0.1, collapse to grayscale and expand back to
        # three channels so the tensor shape stays (H, W, 3).
        image = tf.cond(
            tf.random.uniform([]) < 0.1,
            lambda: tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)),
            lambda: image,
        )
        return image, label

    @staticmethod
    def random_flip_left_right(image, label):
        # Flips with probability 1/2 internally.
        image = tf.image.random_flip_left_right(image)
        return image, label

    @staticmethod
    def random_flip_up_down(image, label):
        image = tf.image.random_flip_up_down(image)
        return image, label

    @staticmethod
    def random_rot90(image, label):
        # With probability 0.3, rotate by 90, 180 or 270 degrees.
        image = tf.cond(
            tf.random.uniform([]) < 0.3,
            lambda: tf.image.rot90(image, k=tf.random.uniform([], minval=1, maxval=4, dtype=tf.int32)),
            lambda: image,
        )
        return image, label

    @staticmethod
    def load_image(image_path, label):
        image = tf.io.read_file(image_path)
        # decode_image is avoided because it returns a tensor without a static
        # shape (and a 4-D tensor for GIFs), which breaks the resize below.
        # image = tf.image.decode_image(image, channels=3)
        # image = tf.image.decode_png(image, channels=3)
        # INTEGER_ACCURATE uses a slower but more precise IDCT, so JPEG
        # decoding matches PIL/OpenCV more closely.
        image = tf.image.decode_jpeg(image, channels=3, dct_method='INTEGER_ACCURATE')
        image = tf.image.resize(image, [224, 224])
        return image, label

    @staticmethod
    def preprocess_input(image, label):
        # Scales pixel values to [-1, 1], as MobileNetV2 expects.
        image = applications.mobilenet_v2.preprocess_input(image)
        return image, label

    def load_dataset(self, dataset_dir, name, batch_size=128, augmentation_methods=None):
        # Avoid a mutable default argument for the augmentation list.
        augmentation_methods = augmentation_methods or []
        image_path_list, label_list = self.get_image_label_list(dataset_dir)
        tensor_slice_dataset = tf.data.Dataset.from_tensor_slices((image_path_list, label_list), name=name)
        dataset = tensor_slice_dataset.shuffle(len(image_path_list), reshuffle_each_iteration=True)
        dataset = dataset.map(self.load_image, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
        for augmentation_method in augmentation_methods:
            dataset = dataset.map(
                getattr(self, augmentation_method),
                num_parallel_calls=tf.data.AUTOTUNE,
                deterministic=False)
        dataset = dataset.map(self.preprocess_input, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
        parallel_batch_dataset = dataset.batch(
            batch_size=batch_size,
            drop_remainder=True,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
            name=name,
        ).prefetch(tf.data.AUTOTUNE)
        return parallel_batch_dataset

    def load_model(self, for_training=False, load_weights_path=None):
        if self.model is not None:
            raise Exception('Model is already loaded; if you really want to reload it, set `self.model = None` first')
        base_model = MobileNetV2(
            input_shape=(224, 224, 3),
            # alpha=0.35,
            alpha=0.5,
            # alpha=1,
            include_top=False,
            weights='imagenet',
            pooling='avg',
        )
        x = base_model.output
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(256, activation='sigmoid', name='dense')(x)
        x = layers.Dropout(0.5)(x)
        x = layers.Dense(self.class_count, activation='sigmoid', name='output')(x)
        self.model = models.Model(inputs=base_model.input, outputs=x)
        if isinstance(load_weights_path, str) and os.path.isfile(load_weights_path):
            self.model.load_weights(load_weights_path, by_name=True, skip_mismatch=True)
        elif for_training:
            # Freeze the backbone up to and including block_16_project_BN;
            # everything after it (plus the new head) stays trainable.
            freeze = True
            for layer in self.model.layers:
                layer.trainable = not freeze
                if freeze and layer.name == 'block_16_project_BN':
                    freeze = False

    def train(self, dataset_dir, epoch, batch_size, ckpt_path, history_save_path, load_weights_path=None,
              train_dir_name='train', validate_dir_name='test', thresholds=0.5, metrics_name='accuracy'):
        self.gpu_config(1)  # machine-specific GPU index
        self.load_model(for_training=True, load_weights_path=load_weights_path)
        self.model.summary()
        self.model.compile(
            # optimizer=optimizers.Adam(learning_rate=3e-4),
            optimizer=optimizers.Adam(learning_rate=1e-4),
            # Sigmoid focal loss treats each class as an independent binary
            # problem and down-weights easy examples.  # TODO ?
            loss=tfa.losses.SigmoidFocalCrossEntropy(),
            metrics=[CustomMetric(thresholds, name=metrics_name)],
            loss_weights=None,
            weighted_metrics=None,
            run_eagerly=None,
            steps_per_execution=None,
            jit_compile=None,
        )
        train_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, train_dir_name),
            name=train_dir_name,
            batch_size=batch_size,
            # augmentation_methods=[],
            augmentation_methods=[
                'random_flip_left_right',
                'random_flip_up_down',
                'random_rot90',
                # 'random_rgb_2_bgr',
                # 'rgb_2_bgr',
                'random_grayscale_expand',
            ],
        )
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[
                'rgb_2_bgr',
            ],
        )
        ckpt_callback = callbacks.ModelCheckpoint(ckpt_path, save_best_only=True)
        es_callback = callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
        history = self.model.fit(
            train_dataset,
            epochs=epoch,
            validation_data=validate_dataset,
            callbacks=[
                ckpt_callback,
                # es_callback,
            ],
        )
        history_save(history, history_save_path, metrics_name)

    def evaluation(self, load_weights_path, confusion_matrix_save_path, dataset_dir, batch_size,
                   validate_dir_name='test', thresholds=0.5):
        self.gpu_config(3)  # machine-specific GPU index
        self.load_model(load_weights_path=load_weights_path)
        self.model.summary()
        validate_dataset = self.load_dataset(
            dataset_dir=os.path.join(dataset_dir, validate_dir_name),
            name=validate_dir_name,
            batch_size=batch_size,
            augmentation_methods=[
                'rgb_2_bgr',
            ],
        )
        label_true_list = []
        label_pred_list = []
        custom_metric = CustomMetric(thresholds)
        for image_batch, y_true_batch in validate_dataset:
            y_pred_batch = self.model.predict(image_batch)
            # Map the all-zero "others" encoding back to explicit class ids.
            label_true_batch_with_others = custom_metric.y_true_with_others(y_true_batch)
            label_pred_batch_with_others = custom_metric.y_pred_with_others(y_pred_batch)
            label_true_list.extend(label_true_batch_with_others.numpy())
            label_pred_list.extend(label_pred_batch_with_others.numpy())
        acc = accuracy_score(label_true_list, label_pred_list)
        cm = confusion_matrix(label_true_list, label_pred_list)
        report = classification_report(label_true_list, label_pred_list)
        print(acc)
        print(cm)
        print(report)
        plot_confusion_matrix(cm, list(range(len(self.class_name_list))), confusion_matrix_save_path)

    def predict(self, image, thresholds=0.5):
        # `image` is expected to be a 224x224x3 array; no resizing is done here.
        if self.model is None:
            raise Exception("The model hasn't been loaded yet; run `self.load_model()` first")
        input_image, _ = self.preprocess_input(image, None)
        input_images = tf.expand_dims(input_image, axis=0)
        outputs = self.model.predict(input_images)
        output = outputs[0]
        idx = int(tf.math.argmax(output))
        confidence = float(output[idx])
        if confidence < thresholds:
            # Below the threshold, fall back to the "others" class.
            idx = -1
        # The +1 offset assumes class_other_first ordering: model output j
        # corresponds to class_name_list[j + 1], and -1 to class_name_list[0].
        label = self.class_name_list[idx + 1]
        res = {
            'label': label,
            'confidence': confidence,
        }
        return res

    def test(self):
        y_true = [
            [0, 1, 0],
            [0, 1, 0],
            [0, 0, 1],
            [0, 0, 0],
        ]
        y_pred = [
            [0.1, 0.8, 0.9],  # TODO multi_label
            [0.2, 0.8, 0.1],
            [0.2, 0.1, 0.85],
            [0.2, 0.4, 0.1],
        ]
        # x = tf.argmax(y_pred, axis=1)
        # y = tf.reduce_sum(y_pred, axis=1)
        # print(x)
        # print(y)
        # m = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
        m = CustomMetric(0.5)
        m.update_state(y_true, y_pred)
        print(m.result().numpy())
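

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal driver for the class above.
# The class names, dataset layout (dataset_dir/{train,test}/<class>/*.jpg) and
# all file paths below are hypothetical placeholders, not values from this
# repository, and it is assumed that BaseModel takes no required arguments.
# Because this module uses relative imports, run it as a package module,
# e.g. `python -m <package>.<module>`. Note that train() pins GPU index 1
# via gpu_config, so adjust that for your machine.
if __name__ == '__main__':
    classifier = F3Classification(
        class_name_list=['others', 'class_a', 'class_b'],  # hypothetical classes
        class_other_first=True,
    )
    classifier.train(
        dataset_dir='/data/f3_dataset',                 # hypothetical path
        epoch=100,
        batch_size=128,
        ckpt_path='/data/ckpt/f3_classification.h5',    # hypothetical path
        history_save_path='/data/ckpt/f3_history.png',  # hypothetical path
    )
    # After training, reload the best checkpoint and evaluate:
    # classifier.model = None
    # classifier.evaluation(
    #     load_weights_path='/data/ckpt/f3_classification.h5',
    #     confusion_matrix_save_path='/data/ckpt/f3_cm.png',
    #     dataset_dir='/data/f3_dataset',
    #     batch_size=128,
    # )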