# model.py (1.07 KB)
import tensorflow as tf
class F3Classification:
    """Pre- and post-processing helpers for an image classification model.

    Images are prepared with MobileNetV2-style preprocessing, and the model's
    per-class scores are mapped back to names from ``class_name_list``.
    """

    def __init__(self, class_name_list, class_other_first, *args, **kwargs):
        """Store the label list and serving identifiers.

        Args:
            class_name_list: Label names, in model output order. When
                ``class_other_first`` is true, entry 0 is a fallback label
                that has no model output unit of its own; it is reported
                for predictions below the confidence threshold.
            class_other_first: Whether ``class_name_list[0]`` is that
                fallback label.
            *args, **kwargs: Accepted and ignored (kept for call compatibility).
        """
        self.class_name_list = class_name_list
        # Needed by reprocess_output to map output indices to labels.
        self.class_other_first = class_other_first
        # Number of scores the model produces per image: the fallback entry
        # (if present) is not scored directly.
        self.class_count = (
            len(class_name_list) - 1 if class_other_first else len(class_name_list)
        )
        # NOTE(review): presumably identifiers for a model-serving endpoint;
        # they are not read anywhere in this class.
        self.model_name = 'classification_model'
        self.signature_name = 'serving_default'
        self.server_name = 'server_1'

    @staticmethod
    def preprocess_input(image):
        """Resize one image to 224x224, apply MobileNetV2 preprocessing, and batch it.

        Args:
            image: A single image tensor (assumed HxWxC — TODO confirm).

        Returns:
            A tensor with a leading batch dimension of size 1.
        """
        image = tf.image.resize(image, [224, 224])
        image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
        input_images = tf.expand_dims(image, axis=0)
        return input_images

    def reprocess_output(self, outputs, thresholds=0.5):
        """Turn the model's scores for the first image into a label.

        Args:
            outputs: Iterable of per-image score vectors (e.g. a batched
                model output). Only the first vector is used, matching the
                batch of one built by ``preprocess_input``.
            thresholds: Minimum top score required to report the top class.

        Returns:
            ``{'label': ..., 'confidence': float}``. ``confidence`` is the top
            score. When it is below ``thresholds``, ``label`` is the fallback
            ``class_name_list[0]`` if ``class_other_first`` is set, otherwise
            ``None``.

        Raises:
            ValueError: If ``outputs`` or its first score vector is empty.
        """
        output = next(iter(outputs), None)
        if output is None:
            raise ValueError('outputs is empty; expected at least one score vector')

        # Work with plain Python floats so list indexing and comparisons do
        # not depend on tensor __index__/__bool__ behaviour.
        scores = [float(score) for score in output]
        if not scores:
            raise ValueError('score vector is empty')

        idx = max(range(len(scores)), key=scores.__getitem__)
        confidence = scores[idx]

        if confidence < thresholds:
            label = self.class_name_list[0] if self.class_other_first else None
        else:
            # Skip the fallback entry, which has no output unit of its own.
            offset = 1 if self.class_other_first else 0
            label = self.class_name_list[idx + offset]

        return {
            'label': label,
            'confidence': confidence,
        }