# doc_det.py 8.86 KB
import cv2
import MNN
import numpy as np
from scipy.special import softmax

class Doc_Detector(object):
    """Run an MNN detection model and decode its multi-stride outputs to boxes.

    The pipeline is ``preprocess`` -> ``infer_image`` -> ``postprocess`` and is
    driven by ``detect``.  The model is expected to expose, for every stride in
    ``self.strides``, one class-score output (``cls_pred_stride_<s>``) and one
    box-distribution output (``dis_pred_stride_<s>``).
    """

    def __init__(self, model_path):
        """Load the model at *model_path* and set the decoding configuration."""
        # Down-sampling factor of each output level.
        self.strides = [8, 16, 32]
        # Network input as [height, width].
        self.input_shape = [320, 320]
        # Each box side is predicted as a distribution over reg_max + 1 bins.
        self.reg_max = 7
        # Minimum per-class score for a candidate to enter NMS.
        self.prob_threshold = 0.4
        # Boxes overlapping a kept box by more than this IoU are suppressed.
        self.iou_threshold = 0.3
        # Per-stride cap on candidates kept before NMS.
        self.num_candidate = 1000
        # Per-class cap on boxes kept by NMS; non-positive means no cap.
        self.top_k = -1
        # Per-channel normalization constants, applied in the channel order of
        # the image handed to ``preprocess``.
        self.image_mean = [103.53, 116.28, 123.675]
        self.image_std = [57.375, 57.12, 58.395]
        # Same size as input_shape but ordered (width, height).
        self.input_size = (self.input_shape[1], self.input_shape[0])
        # Index in this list == label id returned by ``detect``.
        self.class_names = [
            "ID_front",
            "ID_back",
            "zhiye_front",
            "zhiye_back",
            "doc",
            "phone",
        ]

        self.interpreter = MNN.Interpreter(model_path)
        self.session = self.interpreter.createSession()
        self.input_tensor = self.interpreter.getSessionInput(self.session)

    def get_resize_matrix(self, raw_shape, dst_shape, keep_ratio):
        """Build the 3x3 homogeneous matrix mapping raw pixels to dst pixels.

        Args:
            raw_shape: ``(width, height)`` of the source image.
            dst_shape: ``(width, height)`` of the target canvas.
            keep_ratio: If true, scale uniformly so the whole source fits
                inside the target and center it; otherwise stretch each axis
                independently.

        Returns:
            A ``(3, 3)`` float array.
        """
        r_w, r_h = raw_shape
        d_w, d_h = dst_shape
        Rs = np.eye(3)
        if keep_ratio:
            # Move the source center to the origin so scaling is about it.
            C = np.eye(3)
            C[0, 2] = -r_w / 2
            C[1, 2] = -r_h / 2

            # Pick the scale of the tighter axis so nothing is cropped.
            if r_w / r_h < d_w / d_h:
                ratio = d_h / r_h
            else:
                ratio = d_w / r_w
            Rs[0, 0] *= ratio
            Rs[1, 1] *= ratio

            # Move the origin to the center of the target canvas.
            T = np.eye(3)
            T[0, 2] = 0.5 * d_w
            T[1, 2] = 0.5 * d_h
            return T @ Rs @ C

        Rs[0, 0] *= d_w / r_w
        Rs[1, 1] *= d_h / r_h
        return Rs

    def preprocess(self, image):
        """Resize, normalize and re-layout *image* for the network.

        Args:
            image: ``(H, W, 3)`` array as returned by ``cv2.imread``.

        Returns:
            ``(image_input, resize_m)`` where ``image_input`` is a C-contiguous
            float32 array of shape ``(1, 3, input_h, input_w)`` and
            ``resize_m`` is the 3x3 matrix that was used for resizing (needed
            later to map boxes back to the original image).
        """
        # Aspect-ratio preserving resize onto the network canvas.
        resize_m = self.get_resize_matrix(
            (image.shape[1], image.shape[0]), self.input_size, True
        )
        image_resize = cv2.warpPerspective(image, resize_m, dsize=self.input_size)

        # Normalize per channel.
        image_input = image_resize.astype(np.float32) / 255
        image_mean = np.array(self.image_mean, dtype=np.float32).reshape(1, 1, 3) / 255
        image_std = np.array(self.image_std, dtype=np.float32).reshape(1, 1, 3) / 255
        image_input = (image_input - image_mean) / image_std

        # HWC -> 1CHW.  transpose only produces a strided view, so materialize
        # a contiguous float32 buffer before it is handed to the backend.
        image_input = np.transpose(image_input, [2, 0, 1])
        image_input = np.expand_dims(image_input, axis=0)
        image_input = np.ascontiguousarray(image_input, dtype=np.float32)
        return image_input, resize_m

    def postprocess(self, scores, raw_boxes, resize_m, raw_shape):
        """Decode raw network outputs into boxes on the original image.

        Args:
            scores: One array per stride, shape ``(cells, num_classes)``.
            raw_boxes: One array per stride holding, for every cell, four
                distributions of ``reg_max + 1`` bins (left, top, right,
                bottom distances).
            resize_m: Matrix returned by ``preprocess``.
            raw_shape: ``image.shape`` of the original image (height first).

        Returns:
            ``(boxes, labels, scores)``: an ``(n, 4)`` int32 array of
            ``x1, y1, x2, y2``, an ``(n,)`` integer array of class indices and
            an ``(n,)`` float array of confidences.  ``n`` is 0 when nothing
            passes the score threshold.
        """
        decode_boxes = []
        select_scores = []
        # Bin index -> distance (in stride units); identical for every level.
        reg_range = np.arange(self.reg_max + 1)
        for stride, box_distribute, score in zip(self.strides, raw_boxes, scores):
            # Cell centers of this level, in network-input pixels.
            fm_h = int(np.ceil(self.input_shape[0] / stride))
            fm_w = int(np.ceil(self.input_shape[1] / stride))
            ww, hh = np.meshgrid(np.arange(fm_w), np.arange(fm_h))
            ct_row = (hh.flatten() + 0.5) * stride
            ct_col = (ww.flatten() + 0.5) * stride
            center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)

            # Expected value of each side's distribution gives the distance.
            box_distance = np.asarray(box_distribute).reshape((-1, self.reg_max + 1))
            box_distance = softmax(box_distance, axis=1)
            box_distance = box_distance * np.expand_dims(reg_range, axis=0)
            box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
            box_distance = box_distance * stride

            # Keep only the best-scoring cells of this level.
            topk_idx = np.argsort(score.max(axis=1))[::-1]
            topk_idx = topk_idx[: self.num_candidate]
            center = center[topk_idx]
            score = score[topk_idx]
            box_distance = box_distance[topk_idx]

            # center -/+ distances -> x1, y1, x2, y2
            decode_box = center + [-1, -1, 1, 1] * box_distance

            select_scores.append(score)
            decode_boxes.append(decode_box)

        # Class-wise NMS over the candidates of all levels.
        bboxes = np.concatenate(decode_boxes, axis=0)
        confidences = np.concatenate(select_scores, axis=0)
        picked_box_probs = []
        picked_labels = []
        for class_index in range(confidences.shape[1]):
            probs = confidences[:, class_index]
            mask = probs > self.prob_threshold
            probs = probs[mask]
            if probs.shape[0] == 0:
                continue
            subset_boxes = bboxes[mask, :]
            box_probs = np.concatenate([subset_boxes, probs.reshape(-1, 1)], axis=1)
            box_probs = self.hard_nms(
                box_probs,
                iou_threshold=self.iou_threshold,
                top_k=self.top_k,
            )
            picked_box_probs.append(box_probs)
            picked_labels.extend([class_index] * box_probs.shape[0])

        if not picked_box_probs:
            # Same ranks/dtypes as the non-empty result so callers can index
            # or iterate without special-casing "no detection".
            return (
                np.zeros((0, 4), dtype=np.int32),
                np.array([], dtype=int),
                np.zeros((0,), dtype=np.float64),
            )
        picked_box_probs = np.concatenate(picked_box_probs)

        # Undo the resize so boxes are expressed in original-image pixels.
        picked_box_probs[:, :4] = self.warp_boxes(
            picked_box_probs[:, :4], np.linalg.inv(resize_m), raw_shape[1], raw_shape[0]
        )
        return (
            picked_box_probs[:, :4].astype(np.int32),
            np.array(picked_labels),
            picked_box_probs[:, 4],
        )

    def warp_boxes(self, boxes, M, width, height):
        """Apply the 3x3 transform *M* to ``x1, y1, x2, y2`` boxes.

        All four corners of every box are transformed and the axis-aligned
        bounding box of the result is returned, clipped to ``[0, width]`` x
        ``[0, height]``.

        Returns:
            An ``(n, 4)`` float32 array, or *boxes* itself when it is empty.
        """
        n = len(boxes)
        if not n:
            return boxes

        # Homogeneous corner coordinates: x1y1, x2y2, x1y2, x2y1 per box.
        xy = np.ones((n * 4, 3))
        xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)
        xy = xy @ M.T
        # Perspective divide, then regroup the four corners of each box.
        xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)

        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
        return xy.astype(np.float32)

    def hard_nms(self, box_scores, iou_threshold, top_k=-1, candidate_size=200):
        """Greedy non-maximum suppression.

        Args:
            box_scores: ``(n, 5)`` array of ``x1, y1, x2, y2, score``.
            iou_threshold: Boxes whose IoU with an already kept box exceeds
                this value are dropped.
            top_k: If positive, stop after this many boxes have been kept.
            candidate_size: Only the this-many highest scoring boxes are
                considered at all.

        Returns:
            The kept rows of *box_scores*, highest score first.
        """
        scores = box_scores[:, -1]
        boxes = box_scores[:, :-1]
        picked = []
        # Ascending order: the best remaining candidate is always the last.
        indexes = np.argsort(scores)
        indexes = indexes[-candidate_size:]
        while len(indexes) > 0:
            current = indexes[-1]
            picked.append(current)
            if 0 < top_k == len(picked) or len(indexes) == 1:
                break
            current_box = boxes[current, :]
            indexes = indexes[:-1]
            rest_boxes = boxes[indexes, :]
            iou = self.iou_of(
                rest_boxes,
                np.expand_dims(current_box, axis=0),
            )
            indexes = indexes[iou <= iou_threshold]

        return box_scores[picked, :]

    def iou_of(self, boxes0, boxes1, eps=1e-5):
        """Return the broadcast IoU of two ``(..., 4)`` box arrays.

        *eps* keeps the division defined when both boxes have zero area.
        """
        overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
        overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])

        overlap_area = self.area_of(overlap_left_top, overlap_right_bottom)
        area0 = self.area_of(boxes0[..., :2], boxes0[..., 2:])
        area1 = self.area_of(boxes1[..., :2], boxes1[..., 2:])
        return overlap_area / (area0 + area1 - overlap_area + eps)

    def area_of(self, left_top, right_bottom):
        """Return rectangle areas; inverted rectangles count as zero area."""
        hw = np.clip(right_bottom - left_top, 0.0, None)
        return hw[..., 0] * hw[..., 1]

    def detect(self, image):
        """Detect objects in *image*.

        Args:
            image: ``(H, W, 3)`` array, e.g. the result of ``cv2.imread``.

        Returns:
            ``(bbox, label, score)`` as described in ``postprocess``.

        Raises:
            ValueError: If *image* is ``None`` (which is what ``cv2.imread``
                returns for a file it cannot read).
        """
        if image is None:
            raise ValueError("image is None; check that the image was loaded correctly")
        raw_shape = image.shape
        image_input, resize_m = self.preprocess(image)
        scores, raw_boxes = self.infer_image(image_input)
        # Defensive: accept 1-D per-cell scores (single-class models) from an
        # ``infer_image`` that does not already return 2-D arrays.
        if scores[0].ndim == 1:
            scores = [x[:, None] for x in scores]
        bbox, label, score = self.postprocess(scores, raw_boxes, resize_m, raw_shape)
        return bbox, label, score

    def infer_image(self, image):
        """Run the network on a preprocessed ``(1, 3, H, W)`` float array.

        Returns:
            ``(scores, raw_boxes)``: two lists with one array per stride, of
            shapes ``(cells, num_classes)`` and ``(cells, 4 * (reg_max + 1))``.
        """
        tmp_input = MNN.Tensor(
            (1, 3, self.input_size[1], self.input_size[0]),
            MNN.Halide_Type_Float,
            image,
            MNN.Tensor_DimensionType_Caffe,
        )
        self.input_tensor.copyFrom(tmp_input)
        self.interpreter.runSession(self.session)

        # Derive names and widths from the configuration instead of repeating
        # literals, so they cannot drift from strides/class_names/reg_max.
        num_classes = len(self.class_names)
        box_width = 4 * (self.reg_max + 1)
        score_out_name = ["cls_pred_stride_{}".format(s) for s in self.strides]
        boxes_out_name = ["dis_pred_stride_{}".format(s) for s in self.strides]

        scores = [
            self.interpreter.getSessionOutput(self.session, x).getData()
            for x in score_out_name
        ]
        scores = [np.reshape(x, (-1, num_classes)) for x in scores]

        raw_boxes = [
            self.interpreter.getSessionOutput(self.session, x).getData()
            for x in boxes_out_name
        ]
        raw_boxes = [np.reshape(x, (-1, box_width)) for x in raw_boxes]

        return scores, raw_boxes


if __name__ == "__main__":
    # Manual smoke test: run the detector on a single image and print the
    # resulting boxes, class indices and scores.
    model_path = r'models/det_doc_mnn_1.0.0_v0.3.0.mnn'
    detector = Doc_Detector(model_path)
    image_path = r'/data2/face_id/situ_other/pipeline_test/59297ec0094211ecaf3d00163e514671/310faceImageContent163029410817774.jpg'
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising, so fail here with the offending path.
        raise SystemExit("Failed to read image: {}".format(image_path))
    out_boxes, out_classes, out_scores = detector.detect(image)
    print(out_boxes, out_classes, out_scores)