# model.py (1.07 KB)
import tensorflow as tf


class F3Classification:
    """Pre/post-processing helpers for an image classification model.

    Holds the ordered label vocabulary and a few model identifiers, converts a
    single image into a model input batch, and converts raw per-class scores
    back into a ``{'label': ..., 'confidence': ...}`` prediction.
    """

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        """Store label metadata and default model identifiers.

        Args:
            class_name_list: Ordered class labels. When ``class_other_first``
                is true, entry 0 is the catch-all label reported for
                low-confidence predictions and has no score of its own; the
                remaining entries line up with the model's output scores.
                Otherwise every entry lines up with an output score.
            class_other_first: Whether ``class_name_list[0]`` is that
                catch-all label.
            *args, **kwargs: Accepted and ignored (kept for call-site
                compatibility).
        """
        self.class_name_list = class_name_list
        # Remembered so reprocess_output can map score indices to labels.
        self.class_other_first = class_other_first
        # Number of scores the model is expected to emit: the catch-all label
        # (when present) is not one of them.
        self.class_count = (
            len(class_name_list) - 1 if class_other_first else len(class_name_list)
        )
        # Identifiers not used inside this class; presumably consumed by the
        # code that sends requests to the model server — confirm with callers.
        self.model_name = 'classification_model'
        self.signature_name = 'serving_default'
        self.server_name = 'server_1'

    @staticmethod
    def preprocess_input(image, size=(224, 224)):
        """Turn one image into a batch of size 1 ready for the model.

        Resizes ``image`` to ``size``, applies MobileNetV2's
        ``preprocess_input`` scaling, then prepends a batch axis.

        Args:
            image: Image accepted by ``tf.image.resize``. Assumed to be
                (height, width, channels) with pixel values in [0, 255], as
                MobileNetV2 preprocessing expects — TODO confirm with callers.
            size: Target (height, width). Defaults to 224x224, the original
                hard-coded size.

        Returns:
            A tensor with a leading batch dimension of length 1.
        """
        resized = tf.image.resize(image, size)
        scaled = tf.keras.applications.mobilenet_v2.preprocess_input(resized)
        return tf.expand_dims(scaled, axis=0)

    def reprocess_output(self, outputs, thresholds=0.5):
        """Convert raw model scores into a labelled prediction.

        Only the first score vector in ``outputs`` is used.

        Args:
            outputs: Iterable of per-sample score vectors (e.g. nested lists,
                NumPy arrays or eager tensors). Each vector holds one score per
                scored class, in ``class_name_list`` order after skipping the
                catch-all label when ``class_other_first`` is true.
            thresholds: Minimum top score required to accept the predicted
                class.

        Returns:
            dict with ``'label'`` and ``'confidence'`` (the top score as a
            Python float). When the top score is below ``thresholds``, the
            label is ``class_name_list[0]`` if ``class_other_first`` is true,
            otherwise ``None`` (there is no catch-all label to report).

        Raises:
            ValueError: If ``outputs`` or its first score vector is empty.
        """
        try:
            output = next(iter(outputs))
        except StopIteration:
            # Previously this surfaced as an UnboundLocalError.
            raise ValueError('outputs is empty; expected at least one score vector') from None

        scores = [float(score) for score in output]
        if not scores:
            raise ValueError('score vector is empty')

        # First index of the maximum score (deterministic on ties).
        best = max(range(len(scores)), key=scores.__getitem__)
        confidence = scores[best]

        if confidence < thresholds:
            label = self.class_name_list[0] if self.class_other_first else None
        else:
            # Scores skip the catch-all label only when it is present, so the
            # index shift must depend on class_other_first.
            offset = 1 if self.class_other_first else 0
            label = self.class_name_list[best + offset]

        return {
            'label': label,
            'confidence': confidence,
        }