# yolov5_infer.py (4.7 KB)
import cv2
import numpy as np
import tritonclient.grpc as grpcclient


def keep_resize_padding(image):
    """
    Pad an image to a square with gray (114) borders, then resize it to 640x640.

    The model input must have a fixed 640x640 size, whereas the official
    inference code uses a minimal-scaling-ratio resize (for speed) that yields
    variable input sizes. This resize therefore pads the shorter side first so
    the result is always 640x640.

    Args:
        image: array whose shape unpacks as (height, width, channels); the
            borders are built with 3 channels.

    Returns:
        A pair ``(new_image, padding_info)``. ``padding_info`` is
        ``[pad1, pad2, flag]``: the border widths (in pixels of the original
        image, before resizing) added before and after the image, and a flag
        that is 0 when the borders were added left/right (height >= width)
        and 1 when they were added top/bottom.
    """
    height, width, _ = image.shape

    # Split the missing length as evenly as possible between both sides.
    gap = abs(height - width)
    lead = gap // 2
    trail = gap - lead

    if height >= width:
        # Tall or square image: widen it with left/right borders.
        before = np.full((height, lead, 3), 114, dtype=np.uint8)
        after = np.full((height, trail, 3), 114, dtype=np.uint8)
        squared = np.hstack((before, image, after))
        flag = 0
    else:
        # Wide image: heighten it with top/bottom borders.
        before = np.full((lead, width, 3), 114, dtype=np.uint8)
        after = np.full((trail, width, 3), 114, dtype=np.uint8)
        squared = np.vstack((before, image, after))
        flag = 1

    return cv2.resize(squared, (640, 640)), [lead, trail, flag]


# remove padding
def extract_authentic_bboxes(image, padding_info, bboxes):
    """
    Map boxes from the padded 640x640 frame back to original image coordinates.

    The first four columns of every row are multiplied by
    ``max(height, width) / 640`` and the leading padding is then subtracted
    from column 0 (left/right padding) or column 1 (top/bottom padding).
    Remaining columns are passed through unchanged.

    Args:
        image: the original image; only ``image.shape[:2]`` is read, so
            2-D (single channel) arrays are accepted as well.
        padding_info: ``[pad1, pad2, pad_type]``; ``pad_type`` 0 means the
            padding was added along the width, anything else along the height.
        bboxes: sequence/array of rows with at least four leading columns.

    Returns:
        A list of rows (lists of floats); an empty list when there are no
        boxes.
    """
    # Only the leading pad shifts coordinates; the trailing pad is irrelevant.
    pad1, _pad2, pad_type = padding_info
    h, w = image.shape[:2]

    bboxes = np.array(bboxes)
    if bboxes.size == 0:
        # An empty list becomes a 1-D array, which the 2-D indexing below
        # cannot handle.
        return []
    if not np.issubdtype(bboxes.dtype, np.floating):
        # Scaling in place into an integer array would truncate the result.
        bboxes = bboxes.astype(np.float64)

    # The padded square has side max(h, w) and was resized to 640.
    scale = max(h, w) / 640
    bboxes[:, :4] = bboxes[:, :4] * scale

    shifted_column = 0 if pad_type == 0 else 1
    bboxes[:, shifted_column] -= pad1
    return bboxes.tolist()


# NMS
def py_nms_cpu(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Rows whose column 4 does not exceed ``conf_thres`` are dropped first.
    The survivors are visited from the highest to the lowest value of
    column 5; each visited row is kept and every remaining row whose IoU
    with it is ``>= iou_thres`` is discarded.

    Args:
        prediction: numpy array whose last axis holds
            ``[cx, cy, w, h, conf, score, ...]`` (centre-format boxes).
        conf_thres: minimum value of column 4 for a row to be considered.
        iou_thres: IoU at or above which a lower-scored row is suppressed.

    Returns:
         (n, C) array of the kept rows in their ORIGINAL column layout
         (centre x, centre y, width, height, ...), ordered by descending
         score. The corner coordinates are used internally only.
    """
    candidates = prediction[..., 4] > conf_thres
    prediction = prediction[candidates]

    # Centre/size -> corner coordinates, needed for the overlap computation.
    x1 = prediction[..., 0] - prediction[..., 2] / 2
    y1 = prediction[..., 1] - prediction[..., 3] / 2
    x2 = prediction[..., 0] + prediction[..., 2] / 2
    y2 = prediction[..., 1] + prediction[..., 3] / 2

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    score = prediction[..., 5]
    # NMS must start from the best box: argsort is ascending, so reverse it.
    order = np.argsort(score)[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]

        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        ww = np.maximum(0, xx2 - xx1 + 1)
        hh = np.maximum(0, yy2 - yy1 + 1)
        inter = ww * hh

        over = inter / (areas[i] + areas[rest] - inter)

        # Continue only with boxes that do not overlap the kept one too much.
        order = rest[over < iou_thres]

    # Explicit integer dtype keeps the indexing valid when nothing was kept.
    return prediction[np.asarray(keep, dtype=np.intp)]


def client_init(url='localhost:8001',
                ssl=False,
                private_key=None,
                root_certificates=None,
                certificate_chain=None,
                verbose=False):
    """
    Build a Triton gRPC ``InferenceServerClient`` from the given settings.

    Every argument is forwarded unchanged, under the same keyword, to
    ``grpcclient.InferenceServerClient``. ``verbose`` enables detailed
    output and is off by default.

    Returns:
        The newly constructed client.
    """
    connection_options = dict(
        url=url,
        verbose=verbose,
        ssl=ssl,
        root_certificates=root_certificates,
        private_key=private_key,
        certificate_chain=certificate_chain,
    )
    return grpcclient.InferenceServerClient(**connection_options)


# Module-level client used by grpc_detect; it is constructed as soon as this
# module is imported.
triton_client = client_init('localhost:8001')
# Forwarded as the `compression_algorithm` argument of triton_client.infer.
compression_algorithm = None
# Input tensor name, requested output name and model name used to build the
# inference request in grpc_detect.
input_name = 'images'
output_name = 'output0'
model_name = 'yolov5'


def grpc_detect(img):
    """
    Run one image through the Triton-served model and return its detections.

    The image is padded/resized with ``keep_resize_padding``, converted to a
    float32 tensor scaled by 1/255 with a leading batch axis, sent to the
    server through the module-level ``triton_client``, filtered with
    ``py_nms_cpu`` and finally mapped back to the coordinates of ``img`` by
    ``extract_authentic_bboxes``.

    Args:
        img: image array accepted by ``keep_resize_padding``.

    Returns:
        Whatever ``extract_authentic_bboxes`` returns for the kept detections.
    """
    padded, padding_info = keep_resize_padding(img)

    # Channels-last -> channels-first, then reverse the channel order
    # (presumably BGR -> RGB for images loaded with OpenCV).
    tensor = padded.transpose((2, 0, 1))[::-1]
    tensor = tensor.astype(np.float32) / 255.0
    if tensor.ndim == 3:
        # Add the batch axis.
        tensor = tensor[None]

    # Dynamic input: the request shape is taken from the tensor itself.
    infer_input = grpcclient.InferInput(input_name, tensor.shape, 'FP32')
    infer_input.set_data_from_numpy(tensor.astype(np.float32))
    requested_output = grpcclient.InferRequestedOutput(output_name)

    response = triton_client.infer(
        model_name=model_name,
        inputs=[infer_input],
        outputs=[requested_output],
        compression_algorithm=compression_algorithm
    )
    raw_prediction = response.as_numpy(output_name).copy()
    kept = py_nms_cpu(raw_prediction)
    return extract_authentic_bboxes(img, padding_info, kept)


def plot_label(img, result_bboxes):
    """
    Print the detections, draw them on ``img`` and show the image.

    Each entry of ``result_bboxes`` must unpack into exactly six values:
    centre x, centre y, width, height and two further values that are not
    used for drawing. Rectangles are drawn on ``img`` in place, in colour
    (0, 0, 255) with thickness 2; the window named 'im' then blocks until a
    key is pressed.
    """
    print(result_bboxes)
    for box in result_bboxes:
        cx, cy, bw, bh, _conf, _cls = box
        half_w = bw // 2
        half_h = bh // 2
        top_left = (int(cx - half_w), int(cy - half_h))
        bottom_right = (int(cx + half_w), int(cy + half_h))
        cv2.rectangle(img, top_left, bottom_right, (0, 0, 255), 2)
    cv2.imshow('im', img)
    cv2.waitKey(0)


if __name__ == '__main__':
    # Demo: detect on one sample image and display the resulting boxes.
    image_path = '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg'
    img = cv2.imread(image_path)
    # cv2.imread signals an unreadable/missing file by returning None
    # instead of raising, so fail here with a clear message.
    if img is None:
        raise FileNotFoundError('Could not read image: {}'.format(image_path))

    result_bboxes = grpc_detect(img)
    # plot_label needs the image to draw on as its first argument.
    plot_label(img, result_bboxes)