infer.py
4.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
from PIL import Image, ImageDraw, ImageFont
import cv2
import torch
from model.utils import utils
from torchvision import transforms
from model.net.net import *
import argparse
# Command-line interface for the inference demo.
# `demo` selects the input source; restricting it with `choices` makes argparse
# reject a mistyped mode immediately (usage message, exit status 2) instead of
# letting the script build the model and then quit without doing anything.
parse = argparse.ArgumentParser('infer models')
parse.add_argument(
    'demo',
    type=str,
    choices=('image', 'video', 'camera'),
    help='推理类型支持:image/video/camera',
)
# Path to the saved model weights (state dict).
parse.add_argument('--weights_path', type=str, default='', help='模型权重路径')
# Input used when demo == 'image'.
parse.add_argument('--image_path', type=str, default='', help='图片存放路径')
# Input used when demo == 'video'.
parse.add_argument('--video_path', type=str, default='', help='视频路径')
# Device index used when demo == 'camera'.
parse.add_argument('--camera_id', type=int, default=0, help='摄像头id')
class ModelInfer:
    """Classify still images, video files, or camera frames with a ClassifierNet.

    The network is built from ``config['net_type']`` and the number of entries
    in ``config['class_names']``, placed on CUDA when available (CPU
    otherwise), optionally initialised from a weights file, and switched to
    eval mode.
    """

    # Font used to draw the label on still images. The path only exists on
    # Windows, so image_infer falls back to PIL's built-in font elsewhere.
    FONT_PATH = r"C:\Windows\Fonts\BRITANIC.TTF"
    FONT_SIZE = 35

    def __init__(self, config, weights_path):
        """Build the network and load its weights.

        Args:
            config: Mapping that provides ``'net_type'``, ``'class_names'``
                and ``'image_size'``.
            weights_path: Path to a saved state dict. If the path does not
                exist a notice is printed and the network keeps its initial
                weights.

        Raises:
            SystemExit: If ``weights_path`` is ``None``.
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.net = ClassifierNet(
            self.config['net_type'],
            len(self.config['class_names']),
            False,
        ).to(self.device)

        if weights_path is None:
            print('please input weights_path!')
            # A missing argument is a failure, so exit with a non-zero status.
            raise SystemExit(1)

        if os.path.exists(weights_path):
            # map_location lets weights saved on a GPU load on a CPU-only host.
            state_dict = torch.load(weights_path, map_location=self.device)
            self.net.load_state_dict(state_dict)
            print('successfully loading model weights!')
        else:
            print('no loading model weights!')
        self.net.eval()

    def _label_from_logits(self, logits):
        """Return the class name at the index of the largest value in *logits*."""
        return self.config['class_names'][int(torch.argmax(logits))]

    def _classify(self, pil_image):
        """Run the network on one PIL image and return the predicted class name.

        The image goes through ``utils.keep_shape_resize`` with
        ``config['image_size']`` (presumably an aspect-preserving resize --
        confirm against its definition), then ``ToTensor``, and gets a
        leading batch dimension before the forward pass.
        """
        resized = utils.keep_shape_resize(pil_image, self.config['image_size'])
        batch = torch.unsqueeze(self.transform(resized), dim=0).to(self.device)
        # No gradients are needed for inference.
        with torch.no_grad():
            logits = self.net(batch)
        return self._label_from_logits(logits)

    def _stream_infer(self, source):
        """Classify every frame read from *source* and display it with its label.

        Args:
            source: Anything ``cv2.VideoCapture`` accepts (a file path or a
                camera index).

        The loop ends when a frame cannot be read or when ``q`` is pressed in
        the preview window. The capture and the window are released either way.
        """
        cap = cv2.VideoCapture(source)
        try:
            while True:
                ok, frame = cap.read()
                if not ok:
                    # End of stream or failed grab; frame is not usable here.
                    break
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                result = self._classify(Image.fromarray(rgb_frame))
                cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2,
                            (0, 0, 255), thickness=2)
                cv2.imshow('frame', frame)
                if cv2.waitKey(24) & 0xFF == ord('q'):
                    break
        finally:
            cap.release()
            cv2.destroyAllWindows()

    def image_infer(self, image_path):
        """Classify the image at *image_path* and show it with the label drawn on it.

        Args:
            image_path: Path of the image file to open.
        """
        # Normalise to three channels, matching the RGB frames the video and
        # camera paths feed to the network (assumes the model expects RGB
        # input -- confirm against the training pipeline).
        image = Image.open(image_path).convert('RGB')
        result = self._classify(image)
        try:
            font = ImageFont.truetype(self.FONT_PATH, self.FONT_SIZE)
        except OSError:
            # Font file unavailable (e.g. not running on Windows).
            font = ImageFont.load_default()
        ImageDraw.Draw(image).text((10, 10), result, font=font, fill='red')
        image.show()

    def video_infer(self, video_path):
        """Classify and display each frame of the video file at *video_path*."""
        self._stream_infer(video_path)

    def camera_infer(self, camera_id):
        """Classify and display live frames from the camera with index *camera_id*."""
        self._stream_infer(camera_id)
if __name__ == '__main__':
    # Script entry point: validate the request first, then build the model and
    # dispatch to the matching inference routine.
    args = parse.parse_args()

    supported_demos = ('image', 'video', 'camera')
    if args.demo not in supported_demos:
        # Report the problem (usage message, exit status 2) rather than
        # exiting silently with a success status.
        parse.error('unsupported demo {!r}; expected one of: {}'.format(
            args.demo, ', '.join(supported_demos)))

    # The image and video modes cannot run without their input path; check
    # before spending time loading the config and the model weights.
    required_option = {'image': 'image_path', 'video': 'video_path'}.get(args.demo)
    if required_option is not None and not getattr(args, required_option):
        parse.error('--{} is required when demo is {!r}'.format(
            required_option, args.demo))

    config = utils.load_config_util('config/config.yaml')
    model = ModelInfer(config, args.weights_path)

    if args.demo == 'image':
        model.image_infer(args.image_path)
    elif args.demo == 'video':
        model.video_infer(args.video_path)
    else:
        model.camera_infer(args.camera_id)