386478a4 by 乔峰昇

MNN inference

1 parent 8a4e9838
.idea/.gitignore:

# Default ignored files
/shelf/
/workspace.xml
.idea/face_mask_classifier.iml:

@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="inheritedJdk" />
+    <orderEntry type="jdk" jdkName="Python 3.8 (gan)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
 </module>
\ No newline at end of file
.idea/misc.xml:

 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (gan)" project-jdk-type="Python SDK" />
 </project>
\ No newline at end of file
.idea/workspace.xml (new file):

<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ChangeListManager">
    <list default="true" id="8255f694-c607-4431-a530-39f2e0df4506" name="Default Changelist" comment="">
      <change beforePath="$PROJECT_DIR$/.idea/.gitignore" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/.idea/face_mask_classifier.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/face_mask_classifier.iml" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/Project_Default.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/Project_Default.xml" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/dataset.py" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/image/img.png" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/image/img_1.png" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/infer.py" beforeDir="false" afterPath="$PROJECT_DIR$/infer.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/net.py" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/params/cls_face_mobilenet_v2.pth" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/test.py" beforeDir="false" />
      <change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" />
      <change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" />
    </list>
    <option name="SHOW_DIALOG" value="false" />
    <option name="HIGHLIGHT_CONFLICTS" value="true" />
    <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
    <option name="LAST_RESOLUTION" value="IGNORE" />
  </component>
  <component name="Git.Settings">
    <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
  </component>
  <component name="ProjectId" id="24r53vDpNxHFy2XFaKewyUFOSTL" />
  <component name="ProjectViewState">
    <option name="hideEmptyMiddlePackages" value="true" />
    <option name="showExcludedFiles" value="false" />
    <option name="showLibraryContents" value="true" />
  </component>
  <component name="PropertiesComponent">
    <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
    <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
    <property name="last_opened_file_path" value="$PROJECT_DIR$" />
    <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
  </component>
  <component name="RunManager">
    <configuration name="infer_mnn" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
      <module name="face_mask_classifier" />
      <option name="INTERPRETER_OPTIONS" value="" />
      <option name="PARENT_ENVS" value="true" />
      <envs>
        <env name="PYTHONUNBUFFERED" value="1" />
      </envs>
      <option name="SDK_HOME" value="" />
      <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
      <option name="IS_MODULE_SDK" value="true" />
      <option name="ADD_CONTENT_ROOTS" value="true" />
      <option name="ADD_SOURCE_ROOTS" value="true" />
      <option name="SCRIPT_NAME" value="$PROJECT_DIR$/infer_mnn.py" />
      <option name="PARAMETERS" value="" />
      <option name="SHOW_COMMAND_LINE" value="false" />
      <option name="EMULATE_TERMINAL" value="false" />
      <option name="MODULE_MODE" value="false" />
      <option name="REDIRECT_INPUT" value="false" />
      <option name="INPUT_FILE" value="" />
      <method v="2" />
    </configuration>
    <recent_temporary>
      <list>
        <item itemvalue="Python.infer_mnn" />
      </list>
    </recent_temporary>
  </component>
  <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
  <component name="TaskManager">
    <task active="true" id="Default" summary="Default task">
      <changelist id="8255f694-c607-4431-a530-39f2e0df4506" name="Default Changelist" comment="" />
      <created>1644375669136</created>
      <option name="number" value="Default" />
      <option name="presentableId" value="Default" />
      <updated>1644375669136</updated>
    </task>
    <servers />
  </component>
</project>
config/config.yaml:

data_dir: "F:/data/face_mask" # dataset root directory
train_rate: 0.8 # dataset split: fraction of samples used for training
image_size: 128 # input image size fed to the network
net_type: "mobilenet_v2"
# supported models: [resnet18,resnet34,resnet50,resnet101,resnet152,resnext101_32x8d,resnext50_32x4d,wide_resnet50_2,wide_resnet101_2,
#   densenet121,densenet161,densenet169,densenet201,vgg11,vgg13,vgg13_bn,vgg19,vgg19_bn,vgg16,vgg16_bn,inception_v3,
#   mobilenet_v2,mobilenet_v3_small,mobilenet_v3_large,shufflenet_v2_x0_5,shufflenet_v2_x1_0,shufflenet_v2_x1_5,
#   shufflenet_v2_x2_0,alexnet,googlenet,mnasnet0_5,mnasnet1_0,mnasnet1_3,mnasnet0_75,squeezenet1_0,squeezenet1_1]
#   efficientnet-b0 ... efficientnet-b7
pretrained: False # whether to load pretrained weights
batch_size: 50 # batch size
init_lr: 0.01 # initial learning rate
optimizer: 'Adam' # optimizer [SGD,ASGD,Adam,AdamW,Adamax,Adagrad,Adadelta,SparseAdam,LBFGS,Rprop,RMSprop]
class_names: [ 'mask','no_mask' ] # your class names; must match the class folder names under the data directory
epochs: 25 # total training epochs
loss_type: "cross_entropy" # loss function: mse / l1 / smooth_l1 / cross_entropy
model_dir: "./mobilenet_v2/weight/" # directory where weights are saved
log_dir: "./mobilenet_v2/logs/" # directory for tensorboard logs
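For orientation, train.py and infer.py below consume this file through model/utils/utils.load_config_util. A minimal standalone sketch of the same load, assuming only that PyYAML is installed:

import yaml

# mirrors model/utils/utils.load_config_util, shown here for quick experiments
with open('config/config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
print(config['net_type'], config['class_names'])  # mobilenet_v2 ['mask', 'no_mask']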
dataset.py (removed in this commit):

import os
import random

import cv2
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from utils import *
from torchvision import transforms

classes_names = ['normal', 'mask']


class FaceMaskDataset(Dataset):
    def __init__(self, root_path):
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.dataset = []
        class_names = os.listdir(root_path)
        for cls in class_names:
            image_names = os.listdir(os.path.join(root_path, cls))
            for image in image_names:
                self.dataset.append([os.path.join(root_path, cls, image), classes_names.index(cls)])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        lights = [0.6, 0.8, 1, 1.2, 1.4, 1.6]
        data = self.dataset[index]
        image_path = data[0]
        image_data = keep_resize_image(image_path)
        # fixed: the original used lights[random.randint(0, 4)], which could
        # never select the last brightness factor (1.6)
        image_data = cv2.convertScaleAbs(image_data, alpha=random.choice(lights))
        image_label = data[1]
        return self.transform(image_data), image_label


if __name__ == '__main__':
    d = FaceMaskDataset('image')
    for image, label in d:
        print(image.shape, label)
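As a quick illustration of the brightness augmentation above: cv2.convertScaleAbs computes saturate_cast<uint8>(|alpha*src + beta|), so alpha below 1 darkens a frame and alpha above 1 brightens it, with values clipped at 255. A tiny self-contained check:

import cv2
import numpy as np

patch = np.full((2, 2, 3), 200, dtype=np.uint8)      # uniform gray patch
print(cv2.convertScaleAbs(patch, alpha=0.6)[0, 0])   # [120 120 120] - darkened
print(cv2.convertScaleAbs(patch, alpha=1.6)[0, 0])   # [255 255 255] - 320 clipped to 255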
infer.py (old version, replaced in this commit):

import os

import cv2

import torch
from PIL import Image
import numpy as np
from torchvision import transforms
from net import *


def video(net):
    cap = cv2.VideoCapture(0)
    while True:
        _, frame = cap.read()
        image = Image.fromarray(frame)
        w, h = image.size
        temp = max(w, h)
        mask = Image.new('RGB', (temp, temp))
        if w >= h:
            mask.paste(image, (0, (w - h) // 2))
        else:
            mask.paste(image, ((h - w) // 2, 0))
        mask = mask.resize((128, 128))
        mask = np.array(mask)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR)
        mask_image = torch.unsqueeze(transform(mask), dim=0)
        out = net(mask_image)
        print(out)
        out = torch.argmax(out, dim=1)
        result = classes_names[int(out.item())]
        cv2.putText(frame, result, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=2)
        cv2.imshow('frame', frame)

        if cv2.waitKey(1) & 0XFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()


def image_cls(net, path):
    frame = cv2.imread(path)
    image = Image.fromarray(frame)
    w, h = image.size
    temp = max(w, h)
    mask = Image.new('RGB', (temp, temp))
    if w >= h:
        mask.paste(image, (0, (w - h) // 2))
    else:
        mask.paste(image, ((h - w) // 2, 0))
    mask = mask.resize((128, 128))
    mask = np.array(mask)
    mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR)
    mask_image = torch.unsqueeze(transform(mask), dim=0)
    out = net(mask_image)
    print(out)
    out = torch.argmax(out, dim=1)
    result = classes_names[int(out.item())]
    cv2.putText(frame, result, (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), thickness=2)
    cv2.imshow('frame', frame)
    cv2.waitKey(0)


if __name__ == '__main__':
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    net = FaceMaskNet()
    weights_path = r'params/new_face_mobilenet_v2.pth'
    classes_names = ['normal', 'mask']

    if os.path.exists(weights_path):
        net.load_state_dict(torch.load(weights_path, map_location='cuda:0'))
        print('successfully loading weights!')
    net.eval()

    image_cls(net, 'image/img_1.png')


infer.py (new version):

import os

from PIL import Image, ImageDraw, ImageFont
import cv2
import torch
from model.utils import utils
from torchvision import transforms
from model.net.net import *
import argparse

parse = argparse.ArgumentParser('infer models')
parse.add_argument('demo', type=str, help='inference type: image/video/camera')
parse.add_argument('--weights_path', type=str, default='', help='path to the model weights')
parse.add_argument('--image_path', type=str, default='', help='path to the input image')
parse.add_argument('--video_path', type=str, default='', help='path to the input video')
parse.add_argument('--camera_id', type=int, default=0, help='camera id')


class ModelInfer:
    def __init__(self, config, weights_path):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.net = ClassifierNet(self.config['net_type'], len(self.config['class_names']),
                                 False).to(self.device)
        if weights_path is not None:
            if os.path.exists(weights_path):
                self.net.load_state_dict(torch.load(weights_path))
                print('successfully loading model weights!')
            else:
                print('no loading model weights!')
        else:
            print('please input weights_path!')
            exit(0)
        self.net.eval()

    def image_infer(self, image_path):
        image = Image.open(image_path)
        image_data = utils.keep_shape_resize(image, self.config['image_size'])
        image_data = self.transform(image_data)
        image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
        out = self.net(image_data)
        out = torch.argmax(out)
        result = self.config['class_names'][int(out)]
        draw = ImageDraw.Draw(image)
        font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)
        draw.text((10, 10), result, font=font, fill='red')
        image.show()

    def video_infer(self, video_path):
        cap = cv2.VideoCapture(video_path)
        while True:
            _, frame = cap.read()
            if _:
                image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image_data = Image.fromarray(image_data)
                image_data = utils.keep_shape_resize(image_data, self.config['image_size'])
                image_data = self.transform(image_data)
                image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
                out = self.net(image_data)
                out = torch.argmax(out)
                result = self.config['class_names'][int(out)]
                cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
                cv2.imshow('frame', frame)
                if cv2.waitKey(24) & 0XFF == ord('q'):
                    break
            else:
                break

    def camera_infer(self, camera_id):
        cap = cv2.VideoCapture(camera_id)
        while True:
            _, frame = cap.read()
            h, w, c = frame.shape
            if _:
                image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image_data = Image.fromarray(image_data)
                image_data = utils.keep_shape_resize(image_data, self.config['image_size'])
                image_data = self.transform(image_data)
                image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
                out = self.net(image_data)
                out = torch.argmax(out)
                result = self.config['class_names'][int(out)]
                cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
                cv2.imshow('frame', frame)
                if cv2.waitKey(24) & 0XFF == ord('q'):
                    break
            else:
                break


if __name__ == '__main__':
    args = parse.parse_args()
    config = utils.load_config_util('config/config.yaml')
    model = ModelInfer(config, args.weights_path)
    if args.demo == 'image':
        model.image_infer(args.image_path)
    elif args.demo == 'video':
        model.video_infer(args.video_path)
    elif args.demo == 'camera':
        model.camera_infer(args.camera_id)
    else:
        exit(0)
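The rewritten script is driven by the positional demo argument; typical invocations look like the following (the weight and media paths are placeholders, not files shipped with this commit):

python infer.py image --weights_path mobilenet_v2/weight/best.pth --image_path test_image/mask_2997.jpg
python infer.py video --weights_path mobilenet_v2/weight/best.pth --video_path demo.mp4
python infer.py camera --weights_path mobilenet_v2/weight/best.pth --camera_id 0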
infer_mnn.py (new file):

import cv2
from PIL import Image, ImageFont, ImageDraw
import torch
from torchvision import transforms
import MNN


def keep_shape_resize(frame, size=128):
    w, h = frame.size
    temp = max(w, h)
    mask = Image.new('RGB', (temp, temp), (0, 0, 0))
    if w >= h:
        position = (0, (w - h) // 2)
    else:
        position = ((h - w) // 2, 0)
    mask.paste(frame, position)
    mask = mask.resize((size, size))
    return mask


def image_infer_mnn(mnn_model_path, image_path, class_list):
    image = Image.open(image_path)
    input_image = keep_shape_resize(image)
    preprocess = transforms.Compose([transforms.ToTensor()])
    input_data = preprocess(input_image)
    interpreter = MNN.Interpreter(mnn_model_path)
    session = interpreter.createSession()
    input_tensor = interpreter.getSessionInput(session)
    input_data = input_data.cpu().numpy().squeeze()
    tmp_input = MNN.Tensor((1, 3, 128, 128), MNN.Halide_Type_Float, input_data, MNN.Tensor_DimensionType_Caffe)
    input_tensor.copyFrom(tmp_input)
    interpreter.runSession(session)
    infer_result = interpreter.getSessionOutput(session)
    output_data = infer_result.getData()
    out = output_data.index(max(output_data))
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)
    draw.text((10, 10), class_list[int(out)], font=font, fill='red')
    return image


def video_infer_mnn(mnn_model_path, video_path, class_list):
    # class_list is now a parameter; the original read it from a module-level
    # global, which raises a NameError when the function is imported elsewhere
    cap = cv2.VideoCapture(video_path)
    interpreter = MNN.Interpreter(mnn_model_path)
    session = interpreter.createSession()
    input_tensor = interpreter.getSessionInput(session)
    while True:
        _, frame = cap.read()
        if _:
            image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image_data = Image.fromarray(image_data)
            image_data = keep_shape_resize(image_data, 128)
            preprocess = transforms.Compose([transforms.ToTensor()])
            input_data = preprocess(image_data)
            input_data = input_data.cpu().numpy().squeeze()
            tmp_input = MNN.Tensor((1, 3, 128, 128), MNN.Halide_Type_Float, input_data, MNN.Tensor_DimensionType_Caffe)
            input_tensor.copyFrom(tmp_input)
            interpreter.runSession(session)
            infer_result = interpreter.getSessionOutput(session)
            output_data = infer_result.getData()
            out = output_data.index(max(output_data))
            cv2.putText(frame, class_list[int(out)], (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
            cv2.imshow('frame', frame)
            if cv2.waitKey(24) & 0XFF == ord('q'):
                break
        else:
            break


def camera_infer_mnn(mnn_model_path, camera_id, class_list):
    # same fix as video_infer_mnn: take class_list explicitly
    cap = cv2.VideoCapture(camera_id)
    interpreter = MNN.Interpreter(mnn_model_path)
    session = interpreter.createSession()
    input_tensor = interpreter.getSessionInput(session)
    while True:
        _, frame = cap.read()
        if _:
            image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image_data = Image.fromarray(image_data)
            image_data = keep_shape_resize(image_data, 128)
            preprocess = transforms.Compose([transforms.ToTensor()])
            input_data = preprocess(image_data)
            input_data = input_data.cpu().numpy().squeeze()
            tmp_input = MNN.Tensor((1, 3, 128, 128), MNN.Halide_Type_Float, input_data, MNN.Tensor_DimensionType_Caffe)
            input_tensor.copyFrom(tmp_input)
            interpreter.runSession(session)
            infer_result = interpreter.getSessionOutput(session)
            output_data = infer_result.getData()
            out = output_data.index(max(output_data))
            cv2.putText(frame, class_list[int(out)], (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
            cv2.imshow('frame', frame)
            if cv2.waitKey(24) & 0XFF == ord('q'):
                break
        else:
            break


if __name__ == '__main__':
    class_list = ['mask', 'no_mask']
    image_path = 'test_image/mask_2997.jpg'
    mnn_model_path = 'mobilenet_v2.mnn'
    # image
    # image = image_infer_mnn(mnn_model_path, image_path, class_list)
    # image.show()

    # camera
    camera_infer_mnn(mnn_model_path, 0, class_list)
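infer_mnn.py assumes mobilenet_v2.mnn already exists, but the commit does not include the conversion step. A minimal sketch of one way to produce it, assuming a checkpoint trained by train.py at mobilenet_v2/weight/best.pth and the usual PyTorch-to-ONNX-to-MNN route:

import torch
from model.net.net import ClassifierNet

net = ClassifierNet('mobilenet_v2', num_classes=2, pretrained=False)
net.load_state_dict(torch.load('mobilenet_v2/weight/best.pth', map_location='cpu'))
net.eval()
# shape matches image_size: 128 in the config and the MNN.Tensor shape above
dummy = torch.randn(1, 3, 128, 128)
torch.onnx.export(net, dummy, 'mobilenet_v2.onnx', input_names=['input'], output_names=['output'])
# then convert with MNN's converter tool, e.g.:
#   MNNConvert -f ONNX --modelFile mobilenet_v2.onnx --MNNModel mobilenet_v2.mnn --bizCode MNN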
model/dataset/dataset.py (new file):

'''
_*_coding:utf-8 _*_
@Time :2022/1/28 19:00
@Author : qiaofengsheng
@File :dataset.py
@Software :PyCharm
'''
import os

from PIL import Image
from torch.utils.data import *
from model.utils import utils
from torchvision import transforms


class ClassDataset(Dataset):
    def __init__(self, data_dir, config):
        self.config = config
        self.transform = transforms.Compose([
            transforms.RandomRotation(60),
            transforms.ToTensor()
        ])
        self.dataset = []
        class_dirs = os.listdir(data_dir)
        for class_dir in class_dirs:
            image_names = os.listdir(os.path.join(data_dir, class_dir))
            for image_name in image_names:
                self.dataset.append(
                    [os.path.join(data_dir, class_dir, image_name),
                     int(config['class_names'].index(class_dir))])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data = self.dataset[index]
        image_path, image_label = data
        image = Image.open(image_path)
        image = utils.keep_shape_resize(image, self.config['image_size'])
        return self.transform(image), image_label
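A short smoke test for this dataset class; the inline config dict stands in for config/config.yaml, and the data path is the one from that file, both assumptions rather than repo code:

from torch.utils.data import DataLoader
from model.dataset.dataset import ClassDataset

config = {'image_size': 128, 'class_names': ['mask', 'no_mask']}
dataset = ClassDataset('F:/data/face_mask', config)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels)  # torch.Size([4, 3, 128, 128]) and a batch of class indices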
model/loss/loss_fun.py (new file):

'''
_*_coding:utf-8 _*_
@Time :2022/1/28 19:05
@Author : qiaofengsheng
@File :loss_fun.py
@Software :PyCharm
'''

from torch import nn


class Loss:
    def __init__(self, loss_type='mse'):
        self.loss_fun = nn.MSELoss()
        if loss_type == 'mse':
            self.loss_fun = nn.MSELoss()
        elif loss_type == 'l1':
            self.loss_fun = nn.L1Loss()
        elif loss_type == 'smooth_l1':
            self.loss_fun = nn.SmoothL1Loss()
        elif loss_type == 'cross_entropy':
            self.loss_fun = nn.CrossEntropyLoss()

    def get_loss_fun(self):
        return self.loss_fun
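The selector matters downstream because CrossEntropyLoss consumes integer class indices while the mse/l1 variants need targets shaped like the logits; that is why train.py one-hot encodes labels for the non-cross-entropy branches. A brief illustration, assuming this module is importable:

import torch
from torch.nn.functional import one_hot
from model.loss.loss_fun import Loss

logits = torch.randn(4, 2)            # batch of 4, two classes
labels = torch.tensor([0, 1, 1, 0])   # integer class indices
ce = Loss('cross_entropy').get_loss_fun()(logits, labels)
mse = Loss('mse').get_loss_fun()(logits, one_hot(labels, num_classes=2).float())
print(ce.item(), mse.item())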
model/net/net.py (new file):

'''
_*_coding:utf-8 _*_
@Time :2022/1/28 19:05
@Author : qiaofengsheng
@File :net.py
@Software :PyCharm
'''

import torch
from torchvision import models
from torch import nn
from efficientnet_pytorch import EfficientNet

# torchvision classifiers selectable via config['net_type']; each name matches
# the torchvision.models factory function of the same name
TORCHVISION_NETS = [
    'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'resnext101_32x8d', 'resnext50_32x4d', 'wide_resnet50_2', 'wide_resnet101_2',
    'densenet121', 'densenet161', 'densenet169', 'densenet201',
    'vgg11', 'vgg13', 'vgg13_bn', 'vgg19', 'vgg19_bn', 'vgg16', 'vgg16_bn',
    'inception_v3', 'mobilenet_v2', 'mobilenet_v3_small', 'mobilenet_v3_large',
    'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
    'alexnet', 'googlenet',
    'mnasnet0_5', 'mnasnet1_0', 'mnasnet1_3', 'mnasnet0_75',
    'squeezenet1_0', 'squeezenet1_1',
]
# efficientnet-b7 added here; the original list stopped at b6 although the
# config comment advertises b0 ... b7
EFFICIENT_NETS = [f'efficientnet-b{i}' for i in range(8)]


class ClassifierNet(nn.Module):
    def __init__(self, net_type='resnet18', num_classes=10, pretrained=False):
        super(ClassifierNet, self).__init__()
        self.layer = None
        if net_type in TORCHVISION_NETS:
            # equivalent to the original per-name if-chain: look the factory up by name
            self.layer = nn.Sequential(getattr(models, net_type)(pretrained=pretrained, num_classes=num_classes))
        elif net_type in EFFICIENT_NETS:
            if pretrained:
                self.layer = nn.Sequential(EfficientNet.from_pretrained(net_type, num_classes=num_classes))
            else:
                self.layer = nn.Sequential(EfficientNet.from_name(net_type, num_classes=num_classes))

    def forward(self, x):
        return self.layer(x)


if __name__ == '__main__':
    net = ClassifierNet('mnasnet1_0', pretrained=False)
    x = torch.randn(1, 3, 125, 125)
    print(net(x).shape)
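One caveat with the torchvision branch: pretrained=True together with num_classes != 1000 fails with a weight-shape mismatch, since the downloaded checkpoints carry a 1000-way ImageNet head. A common workaround, sketched here as a suggestion rather than what this file does, is to load the backbone with its default head and then swap in a smaller classifier:

import torch
from torch import nn
from torchvision import models

net = models.mobilenet_v2(pretrained=True)          # loads the 1000-class ImageNet head
net.classifier[1] = nn.Linear(net.last_channel, 2)  # replace it with a 2-class head
print(net(torch.randn(1, 3, 128, 128)).shape)       # torch.Size([1, 2])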
model/optimizer/optim.py (new file):

'''
_*_coding:utf-8 _*_
@Time :2022/1/28 19:06
@Author : qiaofengsheng
@File :optim.py
@Software :PyCharm
'''

from torch import optim


class Optimizer:
    def __init__(self, net, opt_type='Adam'):
        super(Optimizer, self).__init__()
        self.opt = optim.Adam(net.parameters())
        if opt_type == 'SGD':
            self.opt = optim.SGD(net.parameters(), lr=0.01)
        elif opt_type == 'ASGD':
            self.opt = optim.ASGD(net.parameters())
        elif opt_type == 'Adam':
            self.opt = optim.Adam(net.parameters())
        elif opt_type == 'AdamW':
            self.opt = optim.AdamW(net.parameters())
        elif opt_type == 'Adamax':
            self.opt = optim.Adamax(net.parameters())
        elif opt_type == 'Adagrad':
            self.opt = optim.Adagrad(net.parameters())
        elif opt_type == 'Adadelta':
            self.opt = optim.Adadelta(net.parameters())
        elif opt_type == 'SparseAdam':
            self.opt = optim.SparseAdam(net.parameters())
        elif opt_type == 'LBFGS':
            self.opt = optim.LBFGS(net.parameters())
        elif opt_type == 'Rprop':
            self.opt = optim.Rprop(net.parameters())
        elif opt_type == 'RMSprop':
            self.opt = optim.RMSprop(net.parameters())

    def get_optimizer(self):
        return self.opt
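Note that config['init_lr'] from config.yaml never reaches this class: only the SGD branch sets an explicit learning rate, and every other optimizer runs with its library default. A sketch of one way to thread the value through without touching each branch (this helper is an assumption, not repo code):

def build_optimizer(net, config):
    # build with defaults as above, then override lr via param_groups,
    # which works uniformly across all the optimizers listed here
    opt = Optimizer(net, config['optimizer']).get_optimizer()
    for group in opt.param_groups:
        group['lr'] = config['init_lr']
    return opt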
model/utils/utils.py (new file):

'''
_*_coding:utf-8 _*_
@Time :2022/1/28 19:58
@Author : qiaofengsheng
@File :utils.py
@Software :PyCharm
'''
import torch
import yaml
from PIL import Image
from torch.nn.functional import one_hot


def load_config_util(config_path):
    # use safe_load with a context manager; a bare yaml.load() without a
    # Loader argument is deprecated (and later an error) in PyYAML >= 5.1
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return yaml.safe_load(config_file)


def keep_shape_resize(frame, size=128):
    w, h = frame.size
    temp = max(w, h)
    mask = Image.new('RGB', (temp, temp), (0, 0, 0))
    if w >= h:
        position = (0, (w - h) // 2)
    else:
        position = ((h - w) // 2, 0)
    mask.paste(frame, position)
    mask = mask.resize((size, size))
    return mask


def label_one_hot(label):
    return one_hot(torch.tensor(label))
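A caveat on label_one_hot: one_hot infers the encoding width from the largest label it sees, so a batch that happens to contain only class 0 yields a 1-column encoding and breaks the mse/l1 loss shapes in train.py. A safer variant, with the explicit num_classes parameter added here as a suggested change:

def label_one_hot(label, num_classes=2):
    # pinning num_classes keeps the encoding width stable across batches
    return one_hot(torch.tensor(label), num_classes=num_classes)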
net.py (removed in this commit):

import torch
from torch import nn
from torchvision import models


class FaceMaskNet(nn.Module):
    def __init__(self):
        super(FaceMaskNet, self).__init__()
        self.layer = nn.Sequential(
            models.mobilenet_v2(pretrained=True)
        )
        self.classifier = nn.Sequential(
            nn.Linear(1000, 2)
        )

    def forward(self, x):
        return self.classifier(self.layer(x))


if __name__ == '__main__':
    net = FaceMaskNet()
    x = torch.randn(5, 3, 128, 128)
    print(net(x).shape)
test.py (removed in this commit):

import torch

from net import *
from dataset import *

transform = transforms.Compose([
    transforms.ToTensor()
])

data = FaceMaskDataset('/data2/face_mask')
d = DataLoader(data, batch_size=1000, shuffle=True)
# build and load the network once; the original recreated and reloaded it
# inside the loop on every batch
net = FaceMaskNet().cuda()
net.load_state_dict(torch.load('params/face_mobilenet_v2.pth'))
net.eval()
with torch.no_grad():
    for i, (image, label) in enumerate(d):
        out = net(image.cuda())
        out = torch.argmax(out, dim=1)
        acc = torch.mean(torch.eq(label.cuda(), out).float()).item()
        print(acc)
train.py (old version, replaced in this commit):

import os.path

from torch import nn, optim
import torch
from dataset import *
from torch.utils.data import random_split
from net import *
import tqdm
import time

if __name__ == '__main__':
    train_rate = 0.8
    batch_size = 50
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    epochs = 50

    datasets = FaceMaskDataset('/data2/new_face_mask')
    train_datasets, test_datasets = random_split(
        datasets,
        [int(len(datasets) * train_rate), len(datasets) - int(len(datasets) * train_rate)],
    )
    print(f'train_datasets:{len(train_datasets)} test_datasets:{len(test_datasets)}')
    train_data_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True)
    test_data_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=True)
    loss_fun = nn.CrossEntropyLoss()

    net = FaceMaskNet().to(device)
    if os.path.exists('params/new_face_mobilenet_v2.pth'):
        net.load_state_dict(torch.load('params/new_face_mobilenet_v2.pth'))
        print('successfully loading weights!')
    opt = optim.Adam(net.parameters())

    for epoch in range(1, epochs):

        with tqdm.tqdm(train_data_loader) as t1:
            for i, (image_data, image_label) in enumerate(train_data_loader):
                net.train()
                image_data, image_label = image_data.to(device), image_label.to(device)
                out = net(image_data)
                train_loss = loss_fun(out, image_label)
                opt.zero_grad()
                train_loss.backward()
                opt.step()
                t1.set_description(f'Epoch {epoch} train')
                t1.set_postfix(train_loss=train_loss.item(),
                               train_acc=torch.mean(torch.eq(image_label, torch.argmax(out, dim=1)).float()).item())
                time.sleep(0.1)
                t1.update(1)
                if (i + 1) % 10 == 0:
                    torch.save(net.state_dict(), 'params/new_face_mobilenet_v2.pth')
                    print(f'epoch : {epoch} {i} successfully save weights!')

        acc, temp = 0, 0
        with torch.no_grad():
            net.eval()
            with tqdm.tqdm(test_data_loader) as t2:
                for j, (image_data, image_label) in enumerate(test_data_loader):
                    image_data, image_label = image_data.to(device), image_label.to(device)
                    out = net(image_data)
                    test_loss = loss_fun(out, image_label)

                    t2.set_description(f'Epoch {epoch} test')
                    out = torch.argmax(out, dim=1)
                    t2.set_postfix(test_loss=test_loss.item(),
                                   test_acc=torch.mean(torch.eq(image_label, out).float()).item())
                    time.sleep(0.1)
                    t2.update(1)
                    acc += torch.mean(torch.eq(image_label, out).float()).item()
                    temp += 1
            print(f'epoch : {epoch} avg acc : ', acc / temp)

    # (a commented-out copy of the evaluation loop followed here in the old file)


train.py (new version):

import os.path
import time
import torch
import tqdm
from torch.utils.tensorboard import SummaryWriter
from model.net.net import ClassifierNet
from model.loss.loss_fun import *
from model.optimizer.optim import *
from model.dataset.dataset import *
import argparse

parse = argparse.ArgumentParser(description='train_demo of argparse')
parse.add_argument('--weights_path', default=None)


class Train:
    def __init__(self, config):
        self.config = config
        if not os.path.exists(config['model_dir']):
            os.makedirs(config['model_dir'])
        self.summary_writer = SummaryWriter(config['log_dir'])
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = ClassifierNet(self.config['net_type'], len(self.config['class_names']),
                                 self.config['pretrained']).to(self.device)
        self.loss_fun = Loss(self.config['loss_type']).get_loss_fun()
        self.optimizer = Optimizer(self.net, self.config['optimizer']).get_optimizer()
        self.dataset = ClassDataset(self.config['data_dir'], config)
        self.train_dataset, self.test_dataset = random_split(self.dataset,
                                                             [int(len(self.dataset) * config['train_rate']),
                                                              len(self.dataset) - int(
                                                                  len(self.dataset) * config['train_rate'])]
                                                             )
        self.train_data_loader = DataLoader(self.train_dataset, batch_size=self.config['batch_size'], shuffle=True)
        self.test_data_loader = DataLoader(self.test_dataset, batch_size=self.config['batch_size'], shuffle=True)

    def train(self, weights_path):
        print(f'device:{self.device} train set:{len(self.train_dataset)} test set:{len(self.test_dataset)}')
        if weights_path is not None:
            if os.path.exists(weights_path):
                self.net.load_state_dict(torch.load(weights_path))
                print('successfully loading model weights!')
            else:
                print('no loading model weights')
        temp_acc = 0
        for epoch in range(1, self.config['epochs'] + 1):
            self.net.train()
            with tqdm.tqdm(self.train_data_loader) as t1:
                for i, (image_data, image_label) in enumerate(self.train_data_loader):
                    image_data, image_label = image_data.to(self.device), image_label.to(self.device)
                    out = self.net(image_data)
                    if self.config['loss_type'] == 'cross_entropy':
                        train_loss = self.loss_fun(out, image_label)
                    else:
                        train_loss = self.loss_fun(out, utils.label_one_hot(image_label).type(torch.FloatTensor).to(
                            self.device))
                    t1.set_description(f'Train-Epoch {epoch} batch {i}: ')
                    t1.set_postfix(train_loss=train_loss.item())
                    time.sleep(0.1)
                    t1.update(1)
                    self.optimizer.zero_grad()
                    train_loss.backward()
                    self.optimizer.step()
                    if i % 10 == 0:
                        torch.save(self.net.state_dict(), os.path.join(self.config['model_dir'], 'last.pth'))
            self.summary_writer.add_scalar('train_loss', train_loss.item(), epoch)

            self.net.eval()
            acc, temp = 0, 0
            with torch.no_grad():
                with tqdm.tqdm(self.test_data_loader) as t2:
                    for j, (image_data, image_label) in enumerate(self.test_data_loader):
                        image_data, image_label = image_data.to(self.device), image_label.to(self.device)
                        out = self.net(image_data)
                        if self.config['loss_type'] == 'cross_entropy':
                            test_loss = self.loss_fun(out, image_label)
                        else:
                            test_loss = self.loss_fun(out, utils.label_one_hot(image_label).type(torch.FloatTensor).to(
                                self.device))
                        out = torch.argmax(out, dim=1)
                        test_acc = torch.mean(torch.eq(out, image_label).float()).item()
                        acc += test_acc
                        temp += 1
                        t2.set_description(f'Test-Epoch {epoch} batch {j}: ')
                        t2.set_postfix(test_loss=test_loss.item(), test_acc=test_acc)
                        time.sleep(0.1)
                        t2.update(1)
            print(f'Test-Epoch {epoch} average accuracy: {acc / temp}')
            if (acc / temp) > temp_acc:
                temp_acc = acc / temp
                torch.save(self.net.state_dict(), os.path.join(self.config['model_dir'], 'best.pth'))
            else:
                torch.save(self.net.state_dict(), os.path.join(self.config['model_dir'], 'last.pth'))
            self.summary_writer.add_scalar('test_loss', test_loss.item(), epoch)
            self.summary_writer.add_scalar('test_acc', acc / temp, epoch)


if __name__ == '__main__':
    args = parse.parse_args()
    config = utils.load_config_util('config/config.yaml')
    Train(config).train(args.weights_path)
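After the refactor, training is configured entirely by config/config.yaml; typical invocations look like this (the checkpoint path for resuming is a placeholder):

python train.py
python train.py --weights_path mobilenet_v2/weight/last.pth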
utils.py (removed in this commit):

import cv2
import numpy as np
from PIL import Image


# pad to a square with black borders, then resize without distorting the aspect ratio
def keep_resize_image(image_path, size=(128, 128)):
    image = Image.open(image_path)
    w, h = image.size
    temp = max(w, h)
    mask = Image.new('RGB', (temp, temp))
    if w >= h:
        mask.paste(image, (0, (w - h) // 2))
    else:
        mask.paste(image, ((h - w) // 2, 0))
    mask = mask.resize(size)
    mask = np.array(mask)
    mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR)
    return mask


if __name__ == '__main__':
    keep_resize_image('image/mask/img.png')