# infer.py

import os

from PIL import Image, ImageDraw, ImageFont
import cv2
import torch
from model.utils import utils
from torchvision import transforms
from model.net.net import ClassifierNet
import argparse

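# Example invocations (file paths below are placeholders):
#   python infer.py image  --weights_path weights/best.pth --image_path test.jpg
#   python infer.py video  --weights_path weights/best.pth --video_path test.mp4
#   python infer.py camera --weights_path weights/best.pth --camera_id 0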
parse = argparse.ArgumentParser('infer models')
parse.add_argument('demo', type=str, choices=('image', 'video', 'camera'), help='inference mode: image / video / camera')
parse.add_argument('--weights_path', type=str, default='', help='path to the model weights file')
parse.add_argument('--image_path', type=str, default='', help='path to the input image')
parse.add_argument('--video_path', type=str, default='', help='path to the input video')
parse.add_argument('--camera_id', type=int, default=0, help='camera device id')


class ModelInfer:
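    """Wraps a trained ClassifierNet for image, video and camera inference."""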
    def __init__(self, config, weights_path):
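        """
        Args:
            config: parsed config dict; the keys used here are 'net_type', 'class_names' and 'image_size'.
            weights_path: path to a .pth state_dict produced during training.
        """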
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.net = ClassifierNet(self.config['net_type'], len(self.config['class_names']),
                                 False).to(self.device)
        if weights_path:
            if os.path.exists(weights_path):
                # map_location keeps CPU-only machines from failing on CUDA-saved checkpoints
                self.net.load_state_dict(torch.load(weights_path, map_location=self.device))
                print('successfully loaded model weights!')
            else:
                print(f'weights file not found: {weights_path}')
                exit(1)
        else:
            print('please provide --weights_path!')
            exit(1)
        self.net.eval()

    def image_infer(self, image_path):
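        """Classify a single image and display it with the predicted class drawn on top."""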
        image = Image.open(image_path).convert('RGB')
        image_data = utils.keep_shape_resize(image, self.config['image_size'])
        image_data = self.transform(image_data)
        image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
        with torch.no_grad():
            out = self.net(image_data)
        result = self.config['class_names'][int(torch.argmax(out))]
        draw = ImageDraw.Draw(image)
        try:
            font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)
        except OSError:
            # fall back to PIL's built-in font when the Windows font is unavailable
            font = ImageFont.load_default()
        draw.text((10, 10), result, font=font, fill='red')
        image.show()

    def video_infer(self, video_path):
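        """Classify a video frame by frame and show the annotated stream; press 'q' to quit."""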
        cap = cv2.VideoCapture(video_path)
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV delivers BGR frames; convert to RGB before handing them to PIL / the network
            image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image_data = Image.fromarray(image_data)
            image_data = utils.keep_shape_resize(image_data, self.config['image_size'])
            image_data = self.transform(image_data)
            image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
            with torch.no_grad():
                out = self.net(image_data)
            result = self.config['class_names'][int(torch.argmax(out))]
            cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
            cv2.imshow('frame', frame)
            if cv2.waitKey(24) & 0xFF == ord('q'):
                break
        cap.release()
        cv2.destroyAllWindows()

    def camera_infer(self, camera_id):
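        """Classify a live camera stream frame by frame; press 'q' to quit."""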
        cap = cv2.VideoCapture(camera_id)
        while True:
            ret, frame = cap.read()
            # stop as soon as the camera fails to deliver a frame
            if not ret:
                break
            image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image_data = Image.fromarray(image_data)
            image_data = utils.keep_shape_resize(image_data, self.config['image_size'])
            image_data = self.transform(image_data)
            image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
            with torch.no_grad():
                out = self.net(image_data)
            result = self.config['class_names'][int(torch.argmax(out))]
            cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
            cv2.imshow('frame', frame)
            if cv2.waitKey(24) & 0xFF == ord('q'):
                break
        cap.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    args = parse.parse_args()
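    # config/config.yaml is expected to provide 'net_type', 'class_names' and 'image_size'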
    config = utils.load_config_util('config/config.yaml')
    model = ModelInfer(config, args.weights_path)
    if args.demo == 'image':
        model.image_infer(args.image_path)
    elif args.demo == 'video':
        model.video_infer(args.video_path)
    elif args.demo == 'camera':
        model.camera_infer(args.camera_id)
    else:
        parse.error(f'unsupported demo type: {args.demo}')