# infer.py — run the FaceMaskNet classifier on a single image or on live camera frames.
import os

import cv2

import torch
from PIL import Image
import numpy as np
from torchvision import transforms
from net import *



def _letterbox(frame, size=128):
    """Pad a BGR frame onto a square black canvas and resize it.

    The image is centred along its shorter axis so the aspect ratio is
    preserved, then the square is resized to ``size`` x ``size``.

    Args:
        frame: BGR image as returned by OpenCV (``cv2.imread`` / ``VideoCapture.read``).
        size: side length of the returned square image.

    Returns:
        An RGB ``numpy`` array of shape (size, size, 3).
    """
    # OpenCV delivers BGR; convert once so the network sees RGB.
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    w, h = image.size
    side = max(w, h)
    canvas = Image.new('RGB', (side, side))  # black background
    # One of the two offsets is always 0; the other centres the short axis.
    canvas.paste(image, ((side - w) // 2, (side - h) // 2))
    return np.array(canvas.resize((size, size)))


def _predict_label(net, frame):
    """Classify one BGR frame with ``net`` and return its class name.

    Relies on the module-level ``transform`` and ``classes_names`` globals,
    which are only assigned in the ``__main__`` block of this script.
    NOTE(review): assumes ``transform`` maps an HxWx3 uint8 array to a CxHxW
    tensor (``ToTensor`` in ``__main__``) — confirm if reused elsewhere.
    Callers should put ``net`` in eval mode (the ``__main__`` block does).
    """
    batch = torch.unsqueeze(transform(_letterbox(frame)), dim=0)
    # Inference only: skip building an autograd graph for every frame.
    with torch.no_grad():
        out = net(batch)
    print(out)
    # Batch of one, so argmax over the class axis yields a single index.
    return classes_names[int(torch.argmax(out, dim=1).item())]


def _draw_label(frame, label):
    """Draw ``label`` in green in the top-left corner of ``frame`` (in place)."""
    cv2.putText(frame, label, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=2)


def video(net):
    """Classify frames from camera device 0 and display them live.

    Each frame is annotated with the predicted class name. Press ``q`` to
    quit. The loop also stops if a frame cannot be read (e.g. no camera).
    The camera and window are released even if an exception interrupts the loop.

    Args:
        net: the classification model, already loaded and in eval mode.
    """
    cap = cv2.VideoCapture(0)
    try:
        while True:
            ok, frame = cap.read()
            if not ok or frame is None:
                # Camera unavailable or stream ended; previously this crashed.
                print('failed to read a frame from the camera')
                break
            _draw_label(frame, _predict_label(net, frame))
            cv2.imshow('frame', frame)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()


def image_cls(net, path):
    """Classify the image at ``path``, display it annotated, and return the label.

    Blocks until a key is pressed in the display window.

    Args:
        net: the classification model, already loaded and in eval mode.
        path: filesystem path of the image to classify.

    Returns:
        The predicted class name.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from ``path``.
    """
    frame = cv2.imread(path)
    if frame is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'could not read image: {path}')
    result = _predict_label(net, frame)
    _draw_label(frame, result)
    cv2.imshow('frame', frame)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    return result

if __name__ == '__main__':
    # Preprocessing used by video() / image_cls(): HWC uint8 array -> CHW float tensor.
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    net = FaceMaskNet()
    weights_path = r'params/new_face_mobilenet_v2.pth'
    # Index i of this list is the name reported for argmax index i of the model output.
    classes_names = ['normal', 'mask']

    if os.path.exists(weights_path):
        # Load onto the CPU: the model's parameters live on the CPU and input
        # tensors are never moved to a GPU, so mapping to 'cuda:0' only made
        # the script fail on machines without CUDA.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        net.load_state_dict(torch.load(weights_path, map_location='cpu'))
        print('successfully loading weights!')
    else:
        # Previously silent: predictions would come from untrained weights.
        print(f'warning: weights file {weights_path!r} not found; using untrained weights')
    net.eval()

    image_cls(net, 'image/img_1.png')