# inference.py (4.63 KB)
import copy
import os
import sys
from pathlib import Path
import numpy as np
import torch

from utils.augmentations import letterbox

# Put the directory that contains this file on sys.path so the `models` / `utils`
# imports that follow can be resolved (assumes those packages live next to this
# file -- TODO confirm), then keep ROOT as a path relative to the working directory.
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
from models.common import DetectMultiBackend
from utils.general import (check_img_size, cv2, non_max_suppression, scale_boxes)
from utils.torch_utils import select_device, smart_inference_mode
from models.yolov5_config import config

# Default label names, indexed by integer class id; gen_result_dict falls back
# to this list when it is called without a label_list.
classes = ['tampered']


def gen_result_dict(boxes, label_list=None, std=False):
    """Convert raw detection rows into the result dictionary returned to callers.

    Args:
        boxes: Iterable of detections. Each row is indexed as
            ``[x1, y1, x2, y2, ..., score, class_index]``: the first four
            values are the box corners, the second-to-last value is the score
            and the last value is the class index.
        label_list: Class names indexed by class index. When empty or ``None``
            the module-level ``classes`` list is used instead.
        std: When true, report the class as the stringified integer index
            rather than looking up its name.

    Returns:
        A dict ``{"error_code": int, "result": list}``. ``error_code`` is 0
        when at least one box was converted and 1 when ``boxes`` was empty.
        Every entry of ``result`` has the keys ``class``, ``score``, ``left``,
        ``top``, ``width`` and ``height``, in that order.

    Raises:
        IndexError: If a class index is outside ``label_list`` (non-std mode).
    """
    # Names are only needed when classes are reported by name.
    if not std and not label_list:
        label_list = classes

    detections = []
    for box in boxes:
        class_index = int(box[-1])
        # Round each corner first so width/height are consistent with left/top.
        left = round(box[0], 0)
        top = round(box[1], 0)
        right = round(box[2], 0)
        bottom = round(box[3], 0)
        detections.append({
            "class": str(class_index) if std else label_list[class_index],
            "score": box[-2],
            "left": int(left),
            "top": int(top),
            "width": int(right - left),
            "height": int(bottom - top),
        })

    return {
        "error_code": 0 if detections else 1,
        "result": detections,
    }


def keep_resize_padding(image):
    """Pad an image to a square with gray (114) borders and resize it to 640x640.

    The shorter side is padded symmetrically (any odd leftover pixel goes to
    the trailing side), so the content stays centred and its aspect ratio is
    kept before the final resize.

    Args:
        image: Array whose shape unpacks into exactly three values
            (height, width, channels); the padding is built with 3 channels.

    Returns:
        The padded image resized to 640x640 by ``cv2.resize``.
    """
    rows, cols, _ = image.shape
    gap = abs(rows - cols)
    lead = gap // 2
    trail = gap - lead

    if rows >= cols:
        # Portrait or square: add columns on the left and right.
        axis = 1
        lead_shape, trail_shape = (rows, lead, 3), (rows, trail, 3)
    else:
        # Landscape: add rows on the top and bottom.
        axis = 0
        lead_shape, trail_shape = (lead, cols, 3), (trail, cols, 3)

    border_before = np.full(lead_shape, 114, dtype=np.uint8)
    border_after = np.full(trail_shape, 114, dtype=np.uint8)
    squared = np.concatenate((border_before, image, border_after), axis=axis)
    return cv2.resize(squared, (640, 640))


class Yolov5:
    """Wrapper that runs a YOLOv5 ``DetectMultiBackend`` model on single images."""

    def __init__(self, cfg=None):
        """Load the model described by ``cfg`` and warm it up once.

        Args:
            cfg: Configuration object. The attributes ``device``, ``weights``,
                ``data`` and ``imgsz`` are read here; ``conf_thres``,
                ``iou_thres`` and ``max_det`` are read by :meth:`detect`.
                Despite the ``None`` default, a real configuration is required.
        """
        self.cfg = cfg
        self.device = select_device(self.cfg.device)
        self.model = DetectMultiBackend(self.cfg.weights, device=self.device, dnn=False, data=self.cfg.data, fp16=False)

        # Warm up once at construction time rather than on every detect() call,
        # so each detection does not pay for an additional dummy forward pass.
        # The batch dimension is 1 because images are processed one at a time.
        imgsz = check_img_size(self.cfg.imgsz, s=self.model.stride)  # check image size
        self.model.warmup(imgsz=(1, 3, *imgsz))

    def detect(self, image):
        """Run detection on one image and return the formatted result.

        Args:
            image: HWC image array; it is treated as BGR and converted to RGB
                before inference.

        Returns:
            The dictionary built by ``gen_result_dict`` from the detections
            that survive non-max suppression, with box coordinates rescaled
            to the shape of ``image``.
        """
        original_shape = image.shape  # only the shape is needed for rescaling

        # NOTE(review): keep_resize_padding always produces a 640x640 input,
        # independent of cfg.imgsz -- confirm the two are meant to agree.
        im = keep_resize_padding(image)
        im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        im = np.ascontiguousarray(im)  # contiguous

        im = torch.from_numpy(im).to(self.model.device)
        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim

        # Gradients are never needed for inference; disabling them avoids
        # building an autograd graph.
        with torch.no_grad():
            pred = self.model(im, augment=False, visualize=False)
            pred = non_max_suppression(pred, self.cfg.conf_thres, self.cfg.iou_thres, None, False,
                                       max_det=self.cfg.max_det)

        det = pred[0]  # single image in the batch
        # Map boxes from the network input size back to the original image size.
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], original_shape).round()
        return gen_result_dict(det.cpu().numpy().tolist())

    def plot(self, image, boxes):
        """Draw a red, 2-pixel rectangle on ``image`` for every box and return it.

        Args:
            image: Image array that ``cv2.rectangle`` draws on.
            boxes: Iterable of boxes; the first four values of each are used.

        Returns:
            The same ``image`` object that was passed in.
        """
        for box in boxes:
            # NOTE(review): the box is handed to OpenCV as one 4-element tuple,
            # which presumably selects the (x, y, width, height) rectangle
            # overload -- confirm callers do not pass corner coordinates.
            cv2.rectangle(image, (box[0], box[1], box[2], box[3]), (0, 0, 255), 2)
        return image


if __name__ == "__main__":
    # Demo: run the detector on one image, draw the detections and show them.
    # The image path may be given as the first command-line argument; without
    # one, the original sample image is used.
    default_image = '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg'
    image_path = sys.argv[1] if len(sys.argv) > 1 else default_image

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None instead of raising when the file is missing
        # or cannot be decoded, so fail here with a clear message.
        sys.exit(f"Could not read image: {image_path}")

    detector = Yolov5(config)
    result = detector.detect(img)
    for item in result['result']:
        # Read the coordinates by key rather than by position in the dict.
        position = [item['left'], item['top'], item['width'], item['height']]
        print(position)
        left, top, width, height = position
        cv2.rectangle(img, (left, top), (left + width, top + height), (0, 0, 255))
    cv2.imshow('w', img)
    cv2.waitKey(0)
    print(result)