# main.py
import os
import cv2
import time
import json
import base64
import numpy as np

import data_util
import load_util

from src import face_detector
from src import face_landmark
from src import direction_classifier
from src import face_id
from src import doc_det
from src import abnormal_face
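
# The config.yaml consumed below is expected to expose roughly the following keys
# (a sketch inferred from the lookups in this file, not an authoritative schema):
#
#   FACE_DET:          MODEL_PATH, FACE_THR
#   FACE_LANDMARK_DET: MODEL_PATH
#   FACE_RECOGNIZE:    MODEL_PATH, ID_ID_THR, ID_LIFE_THR, LIFE_LIFE_THR
#   ID_DIRECTION:      MODEL_PATH
#   DOC_DET:           MODEL_PATH
#   ABNORMAL_FACE:     MODEL_PATH
#   SERVICE:           ROT_ID, ROT_IDS, DET_ID, DET_IDS, RETURN_LOCATION,
#                      DOC_DET, ABNORMAL_FACE, SCHEMA, FORMAT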

class InitModel(object):

    def __init__(self, config):
       
        face_det_model_path = config['FACE_DET']['MODEL_PATH']
        face_landmark_model_path = config['FACE_LANDMARK_DET']['MODEL_PATH']
        face_id_model_path = config['FACE_RECOGNIZE']['MODEL_PATH']
        id_rotation_model_path = config['ID_DIRECTION']['MODEL_PATH']
        doc_det_model_path = config['DOC_DET']['MODEL_PATH']        
        abnormal_face_model_path = config['ABNORMAL_FACE']['MODEL_PATH']  
                                              
        self.face_det_thr = config['FACE_DET']['FACE_THR']
        self.is_rot_id = config['SERVICE']['ROT_ID']
        self.is_rot_ids = config['SERVICE']['ROT_IDS']
        self.is_det_id = config['SERVICE']['DET_ID']
        self.is_det_ids = config['SERVICE']['DET_IDS']
        self.return_location = config['SERVICE']['RETURN_LOCATION']
        self.doc_det = config['SERVICE']['DOC_DET']
        self.abnormal_face = config['SERVICE']['ABNORMAL_FACE']

        self.direction_classifier = direction_classifier.Direction_Classifier(id_rotation_model_path)
        self.face_recognizer = face_id.Face_Recognizer(face_id_model_path)
        self.face_detector = face_detector.Face_Detector(face_det_model_path)           
        self.face_landmark_detector = face_landmark.Landmark_Detector(face_landmark_model_path)
        self.doc_detector = doc_det.Doc_Detector(doc_det_model_path)
        self.abnormal_recognizer = abnormal_face.Abnormal_Face(abnormal_face_model_path)

    def error_return(self, result_dict, code):
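        """Fill result_dict with a failure payload (no match, zero similarity) and the
        given error code. Codes used below: 4001/4002 mean no document was detected in
        the first/second image, 4003/4004 mean no face was detected in the first/second
        image; 2000 marks a normally processed request."""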
        result_dict['match'] = False
        result_dict['face_sim'] = 0
        result_dict['more'] = False
        result_dict['code'] = code
        result_dict['sim'] = 0

        return result_dict
        
    def recognition(self, id_rgb_image, life_rgb_image, config, compare_type):
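        """Compare the face on the first image against the face(s) found in the second
        image. compare_type selects the threshold and preprocessing: 0 = id vs id,
        1 = id vs life, 2 = life vs life. Returns a result dict with a 'code', a
        'match' flag and, on success, bounding boxes and similarity details."""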
    
        result_dict = {'code':2000}
        face_bboxes = []
        face_scores = []
        faces = []
        
        # Pick the similarity threshold for this comparison schema.
        if compare_type == 0:
            thr = config['FACE_RECOGNIZE']['ID_ID_THR']
        elif compare_type == 1:
            thr = config['FACE_RECOGNIZE']['ID_LIFE_THR']
        elif compare_type == 2:
            thr = config['FACE_RECOGNIZE']['LIFE_LIFE_THR']
        else:
            raise ValueError('unsupported compare_type: {}'.format(compare_type))

        if self.doc_det:
            if compare_type == 0:
                # Both inputs are documents: crop each image to its highest-scoring
                # detected document region before running face detection.
                id_doc_boxes, id_doc_classes, id_doc_scores = self.doc_detector.detect(id_rgb_image)
                if len(id_doc_scores) == 0:
                    return self.error_return(result_dict, 4001)
                id_doc_max_idx = int(np.argmax(id_doc_scores))
                # Box coordinates may be floats; cast before using them as slice indices.
                id_doc_max_box = id_doc_boxes[id_doc_max_idx].astype(np.int32)
                id_rgb_image = id_rgb_image[id_doc_max_box[1]:id_doc_max_box[3], id_doc_max_box[0]:id_doc_max_box[2]]

                life_doc_boxes, life_doc_classes, life_doc_scores = self.doc_detector.detect(life_rgb_image)
                if len(life_doc_scores) == 0:
                    return self.error_return(result_dict, 4002)
                life_doc_max_idx = int(np.argmax(life_doc_scores))
                life_doc_max_box = life_doc_boxes[life_doc_max_idx].astype(np.int32)
                life_rgb_image = life_rgb_image[life_doc_max_box[1]:life_doc_max_box[3], life_doc_max_box[0]:life_doc_max_box[2]]

                result_dict['id_doc_bboxes'] = id_doc_boxes.astype(np.int32).tolist()
                result_dict['life_doc_bboxes'] = life_doc_boxes.astype(np.int32).tolist()

            elif compare_type == 1:
                # Only the first input is a document; crop it to the best detection.
                id_doc_boxes, id_doc_classes, id_doc_scores = self.doc_detector.detect(id_rgb_image)
                if len(id_doc_scores) == 0:
                    return self.error_return(result_dict, 4001)
                id_doc_max_idx = int(np.argmax(id_doc_scores))
                id_doc_max_box = id_doc_boxes[id_doc_max_idx].astype(np.int32)
                id_rgb_image = id_rgb_image[id_doc_max_box[1]:id_doc_max_box[3], id_doc_max_box[0]:id_doc_max_box[2]]

                result_dict['id_doc_bboxes'] = id_doc_boxes.astype(np.int32).tolist()
 
        if compare_type == 0 and self.is_rot_id:
            # Run the direction classifier on both document images; reg() returns a
            # (possibly re-oriented) image and a direction index that is recorded.
            id_rgb_image, id_direction_index = self.direction_classifier.reg(id_rgb_image)
            life_rgb_image, life_direction_index = self.direction_classifier.reg(life_rgb_image)
            result_dict['id_direction'] = int(id_direction_index)
            result_dict['life_direction'] = int(life_direction_index)
        elif compare_type == 1 and self.is_rot_id:
            id_rgb_image, id_direction_index = self.direction_classifier.reg(id_rgb_image)
            result_dict['id_direction'] = int(id_direction_index)
       
        id_h, id_w, id_c = id_rgb_image.shape
        life_h, life_w, life_c = life_rgb_image.shape
        
        id_face_bboxes, id_face_landmarks, id_max_idx = self.face_detector.detect(id_rgb_image, self.face_det_thr)
        life_face_bboxes, life_face_landmarks, life_max_idx = self.face_detector.detect(life_rgb_image, self.face_det_thr)
        
        if compare_type != 0 and self.abnormal_face:
            # Filter the life-photo faces with the abnormal-face model; only boxes
            # predicted as class 1 are kept.
            new_life_face_bboxes = []
            for box_idx, face_bbox in enumerate(life_face_bboxes):
                abnormal_pre = self.abnormal_recognizer.reg(life_rgb_image, face_bbox)
                if abnormal_pre == 1:
                    new_life_face_bboxes.append(face_bbox)

            life_face_bboxes = new_life_face_bboxes

        if len(id_face_bboxes) == 0:
            return self.error_return(result_dict, 4003)

        if len(life_face_bboxes) == 0:
            return self.error_return(result_dict, 4004)
       
        result_dict['id_face_bboxes'] = id_face_bboxes
        result_dict['id_face_landmarks'] = id_face_landmarks
        result_dict['life_face_bboxes'] = life_face_bboxes
        result_dict['life_face_landmarks'] = life_face_landmarks

        id_face_landmark, id_face = self.face_landmark_detector.detect(id_rgb_image, id_face_bboxes[id_max_idx])

        # Keep five points from the dense landmark set (indices 104, 105, 46, 84, 90);
        # these presumably correspond to the eye centres, nose tip and mouth corners
        # expected by data_util.get_norm_face. The same indices are reused for every
        # life face below.
        id_face_landmark = [id_face_landmark[104], id_face_landmark[105], id_face_landmark[46], id_face_landmark[84], id_face_landmark[90]]
        id_norm_image = data_util.get_norm_face(id_face, id_face_landmark)
        
        norm_images = [id_norm_image]

        for f, life_face_bbox in enumerate(life_face_bboxes):
            life_face_landmark, life_face = self.face_landmark_detector.detect(life_rgb_image, life_face_bbox)
            life_face_landmark = [life_face_landmark[104], life_face_landmark[105], life_face_landmark[46], life_face_landmark[84], life_face_landmark[90]]
            life_norm_image = data_util.get_norm_face(life_face, life_face_landmark)
            norm_images.append(life_norm_image)

        # Embed the aligned ID face plus every aligned life face in one batch, then
        # compare each life embedding against the ID (gallery) embedding with cosine
        # similarity. Note there is no early break, so the last face above the
        # threshold wins if several match.
        embeddings = self.face_recognizer.recognize(norm_images)
        gallery_vector = np.asarray(embeddings[0], dtype=np.float32).ravel()

        res = False
        sim = 0
        for p in range(1, len(embeddings)):
            compare_vector = np.asarray(embeddings[p], dtype=np.float32).ravel()

            dot = float(np.dot(gallery_vector, compare_vector))
            norm = float(np.linalg.norm(gallery_vector) * np.linalg.norm(compare_vector))
            sim = dot / norm

            if sim > thr:
                res = True
                result_dict['id_index'] = id_max_idx
                result_dict['life_index'] = p - 1
                result_dict['sim'] = sim
            print('sim {} : {}'.format(p, sim))

        result_dict['match'] = res
                
        return result_dict
        
        
def main(id_base64_image, life_base64_image, compare_type):
    # Relies on module-level globals set in __main__: inited_model, config and format_type.
    # Despite the parameter names, decoded RGB arrays are passed when format_type == 1.
    if format_type == 0:
        life_rgb_image = data_util.base64_to_img(life_base64_image)
        id_rgb_image = data_util.base64_to_img(id_base64_image)
        result_dict = inited_model.recognition(id_rgb_image, life_rgb_image, config, compare_type)
    elif format_type == 1:
        result_dict = inited_model.recognition(id_base64_image, life_base64_image, config, compare_type)
    else:
        raise ValueError('unsupported format_type: {}'.format(format_type))

    # Persist the last result for inspection; create tmp/ if it does not exist yet.
    os.makedirs('tmp', exist_ok=True)
    with open('tmp/test.json', 'w') as f:
        json.dump(result_dict, f)

    return result_dict
    

if __name__ == '__main__':

    image_dir = r'/data2/face_id/situ_other/pipeline_test/'
    image_list_txt_path = r'/data2/face_id/situ_other/test2.txt'
    config_path = r'config.yaml'

    config = load_util.load_config(config_path)
    
    inited_model = InitModel(config)
    
    compare_type = config['SERVICE']['SCHEMA']  # 0:id-id  1: id-life  2: life-life 
    format_type = config['SERVICE']['FORMAT']  # 0:base64  1: RGB  2: URL
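    # Note: only format_type 0 (base64) and 1 (RGB array) are handled below; the URL
    # option listed above has no code path in this test script.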

    with open(image_list_txt_path, 'r') as image_list_txt:
        image_list_txt_lines = image_list_txt.readlines()

    # Counters start from a small epsilon so the precision/recall/accuracy
    # divisions at the end can never divide by zero.
    hit = 0.01
    hit_pos = 0.01
    pre_pos = 0.01
    pre_all = 0.01
    positive_num = 0.01
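    # Each line of the test list is expected to be:
    #   <id_image_relative_path> <life_image_relative_path> <label>
    # with label '1' for a genuine (same person) pair and '0' for an impostor pair.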
    for image_list_txt_line in image_list_txt_lines:
        arr = image_list_txt_line.strip().split(' ')
        label = arr[-1]
    
        pre_all += 1
        if label == '1':
            positive_num += 1
    
        id_image_name_arr = arr[0].split('/')
        id_image_name = id_image_name_arr[-1]
        # The stored file names carry a two-character prefix that is stripped before
        # building the local image path.
        id_image_name = id_image_name[2:]
        id_image_dir = id_image_name_arr[-2]
        id_image_path = os.path.join(image_dir, id_image_dir, id_image_name)
            
        life_image_name_arr = arr[1].split('/')
        life_image_name = life_image_name_arr[-1]
        life_image_name = life_image_name[2:]
        life_image_dir = life_image_name_arr[-2]
        life_image_path = os.path.join(image_dir, life_image_dir, life_image_name)
    
        id_image = cv2.imread(id_image_path)
        life_image = cv2.imread(life_image_path)
    
        # Encode each image to JPEG and then base64 so the pair matches the base64
        # input format used when format_type == 0.
        r, id_image_str = cv2.imencode('.jpg', id_image)
        id_base64_image = base64.b64encode(id_image_str)

        r, life_image_str = cv2.imencode('.jpg', life_image)
        life_base64_image = base64.b64encode(life_image_str)
            
        st = time.time()
        result_dict = {}
        if format_type == 0:
            result_dict = main(id_base64_image, life_base64_image, compare_type)
        elif format_type == 1:
            result_dict = main(id_image, life_image, compare_type)
        print(result_dict)
        et = time.time()
        print('total time cost:{}'.format(round((et-st), 2)))
        
        print(label) 
        res = result_dict['match']
        if res:
            if label == '1':
                hit_pos += 1
                hit += 1
            pre_pos += 1
        else:
            if label == '0':
                hit += 1

    # precision: matched genuine pairs / all predicted matches; recall: matched genuine
    # pairs / all genuine pairs; accuracy: correct decisions / all pairs.
    print('precision:{}'.format(hit_pos/pre_pos))
    print('recall:{}'.format(hit_pos/positive_num))
    print('accuracy:{}'.format(hit/pre_all))