mnn推理
Showing
37 changed files
with
615 additions
and
259 deletions
.idea/.gitignore
deleted
100644 → 0
... | @@ -2,7 +2,7 @@ | ... | @@ -2,7 +2,7 @@ |
2 | <module type="PYTHON_MODULE" version="4"> | 2 | <module type="PYTHON_MODULE" version="4"> |
3 | <component name="NewModuleRootManager"> | 3 | <component name="NewModuleRootManager"> |
4 | <content url="file://$MODULE_DIR$" /> | 4 | <content url="file://$MODULE_DIR$" /> |
5 | <orderEntry type="inheritedJdk" /> | 5 | <orderEntry type="jdk" jdkName="Python 3.8 (gan)" jdkType="Python SDK" /> |
6 | <orderEntry type="sourceFolder" forTests="false" /> | 6 | <orderEntry type="sourceFolder" forTests="false" /> |
7 | </component> | 7 | </component> |
8 | </module> | 8 | </module> |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
1 | <?xml version="1.0" encoding="UTF-8"?> | 1 | <?xml version="1.0" encoding="UTF-8"?> |
2 | <project version="4"> | 2 | <project version="4"> |
3 | <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" /> | 3 | <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (gan)" project-jdk-type="Python SDK" /> |
4 | </project> | 4 | </project> |
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
.idea/workspace.xml
0 → 100644
1 | <?xml version="1.0" encoding="UTF-8"?> | ||
2 | <project version="4"> | ||
3 | <component name="ChangeListManager"> | ||
4 | <list default="true" id="8255f694-c607-4431-a530-39f2e0df4506" name="Default Changelist" comment=""> | ||
5 | <change beforePath="$PROJECT_DIR$/.idea/.gitignore" beforeDir="false" /> | ||
6 | <change beforePath="$PROJECT_DIR$/.idea/face_mask_classifier.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/face_mask_classifier.iml" afterDir="false" /> | ||
7 | <change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/Project_Default.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/Project_Default.xml" afterDir="false" /> | ||
8 | <change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" afterDir="false" /> | ||
9 | <change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" /> | ||
10 | <change beforePath="$PROJECT_DIR$/dataset.py" beforeDir="false" /> | ||
11 | <change beforePath="$PROJECT_DIR$/image/img.png" beforeDir="false" /> | ||
12 | <change beforePath="$PROJECT_DIR$/image/img_1.png" beforeDir="false" /> | ||
13 | <change beforePath="$PROJECT_DIR$/infer.py" beforeDir="false" afterPath="$PROJECT_DIR$/infer.py" afterDir="false" /> | ||
14 | <change beforePath="$PROJECT_DIR$/net.py" beforeDir="false" /> | ||
15 | <change beforePath="$PROJECT_DIR$/params/cls_face_mobilenet_v2.pth" beforeDir="false" /> | ||
16 | <change beforePath="$PROJECT_DIR$/test.py" beforeDir="false" /> | ||
17 | <change beforePath="$PROJECT_DIR$/train.py" beforeDir="false" afterPath="$PROJECT_DIR$/train.py" afterDir="false" /> | ||
18 | <change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" /> | ||
19 | </list> | ||
20 | <option name="SHOW_DIALOG" value="false" /> | ||
21 | <option name="HIGHLIGHT_CONFLICTS" value="true" /> | ||
22 | <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" /> | ||
23 | <option name="LAST_RESOLUTION" value="IGNORE" /> | ||
24 | </component> | ||
25 | <component name="Git.Settings"> | ||
26 | <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" /> | ||
27 | </component> | ||
28 | <component name="ProjectId" id="24r53vDpNxHFy2XFaKewyUFOSTL" /> | ||
29 | <component name="ProjectViewState"> | ||
30 | <option name="hideEmptyMiddlePackages" value="true" /> | ||
31 | <option name="showExcludedFiles" value="false" /> | ||
32 | <option name="showLibraryContents" value="true" /> | ||
33 | </component> | ||
34 | <component name="PropertiesComponent"> | ||
35 | <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" /> | ||
36 | <property name="RunOnceActivity.ShowReadmeOnStart" value="true" /> | ||
37 | <property name="last_opened_file_path" value="$PROJECT_DIR$" /> | ||
38 | <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" /> | ||
39 | </component> | ||
40 | <component name="RunManager"> | ||
41 | <configuration name="infer_mnn" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true"> | ||
42 | <module name="face_mask_classifier" /> | ||
43 | <option name="INTERPRETER_OPTIONS" value="" /> | ||
44 | <option name="PARENT_ENVS" value="true" /> | ||
45 | <envs> | ||
46 | <env name="PYTHONUNBUFFERED" value="1" /> | ||
47 | </envs> | ||
48 | <option name="SDK_HOME" value="" /> | ||
49 | <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" /> | ||
50 | <option name="IS_MODULE_SDK" value="true" /> | ||
51 | <option name="ADD_CONTENT_ROOTS" value="true" /> | ||
52 | <option name="ADD_SOURCE_ROOTS" value="true" /> | ||
53 | <option name="SCRIPT_NAME" value="$PROJECT_DIR$/infer_mnn.py" /> | ||
54 | <option name="PARAMETERS" value="" /> | ||
55 | <option name="SHOW_COMMAND_LINE" value="false" /> | ||
56 | <option name="EMULATE_TERMINAL" value="false" /> | ||
57 | <option name="MODULE_MODE" value="false" /> | ||
58 | <option name="REDIRECT_INPUT" value="false" /> | ||
59 | <option name="INPUT_FILE" value="" /> | ||
60 | <method v="2" /> | ||
61 | </configuration> | ||
62 | <recent_temporary> | ||
63 | <list> | ||
64 | <item itemvalue="Python.infer_mnn" /> | ||
65 | </list> | ||
66 | </recent_temporary> | ||
67 | </component> | ||
68 | <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" /> | ||
69 | <component name="TaskManager"> | ||
70 | <task active="true" id="Default" summary="Default task"> | ||
71 | <changelist id="8255f694-c607-4431-a530-39f2e0df4506" name="Default Changelist" comment="" /> | ||
72 | <created>1644375669136</created> | ||
73 | <option name="number" value="Default" /> | ||
74 | <option name="presentableId" value="Default" /> | ||
75 | <updated>1644375669136</updated> | ||
76 | </task> | ||
77 | <servers /> | ||
78 | </component> | ||
79 | </project> | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
__pycache__/net.cpython-38.pyc
deleted
100644 → 0
No preview for this file type
config/config.yaml
0 → 100644
data_dir: "F:/data/face_mask" # dataset root directory
train_rate: 0.8 # dataset split: fraction used for training
image_size: 128 # side length of the image fed to the network
net_type: "mobilenet_v2"
# Supported models: [resnet18,resnet34,resnet50,resnet101,resnet152,resnext101_32x8d,resnext50_32x4d,wide_resnet50_2,wide_resnet101_2,
# densenet121,densenet161,densenet169,densenet201,vgg11,vgg13,vgg13_bn,vgg19,vgg19_bn,vgg16,vgg16_bn,inception_v3,
# mobilenet_v2,mobilenet_v3_small,mobilenet_v3_large,shufflenet_v2_x0_5,shufflenet_v2_x1_0,shufflenet_v2_x1_5,
# shufflenet_v2_x2_0,alexnet,googlenet,mnasnet0_5,mnasnet1_0,mnasnet1_3,mnasnet0_75,squeezenet1_0,squeezenet1_1]
# efficientnet-b0 ... efficientnet-b7
pretrained: False # whether to load pretrained weights
batch_size: 50 # batch size
init_lr: 0.01 # initial learning rate
optimizer: 'Adam' # optimizer [SGD,ASGD,Adam,AdamW,Adamax,Adagrad,Adadelta,SparseAdam,LBFGS,Rprop,RMSprop]
class_names: [ 'mask','no_mask' ] # your class names; must match the class folder names under the data directory
epochs: 25 # total number of training epochs
loss_type: "cross_entropy" # loss function: mse / l1 / smooth_l1 / cross_entropy
model_dir: "./mobilenet_v2/weight/" # directory where weights are stored
log_dir: "./mobilenet_v2/logs/" # directory where TensorBoard visualisation files are stored
... | \ No newline at end of file | ... | \ No newline at end of file |
dataset.py
deleted
100644 → 0
1 | import os | ||
2 | import random | ||
3 | |||
4 | import cv2 | ||
5 | import torch | ||
6 | from PIL import Image | ||
7 | from torch.utils.data import Dataset, DataLoader | ||
8 | from utils import * | ||
9 | from torchvision import transforms | ||
10 | |||
11 | classes_names = ['normal', 'mask'] | ||
12 | |||
13 | |||
class FaceMaskDataset(Dataset):
    """Image dataset with one sub-directory of ``root_path`` per class.

    Each sub-directory name must appear in the module-level
    ``classes_names`` list; its position in that list is the integer label.
    Samples are returned as ``(tensor, label)`` pairs.
    """

    # Brightness gains picked at random for each sample as augmentation.
    LIGHTS = (0.6, 0.8, 1, 1.2, 1.4, 1.6)

    def __init__(self, root_path):
        """Index every image path below ``root_path`` together with its label."""
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.dataset = []
        for cls in os.listdir(root_path):
            cls_dir = os.path.join(root_path, cls)
            for image in os.listdir(cls_dir):
                self.dataset.append([os.path.join(cls_dir, image), classes_names.index(cls)])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Load, resize and randomly brighten/darken the ``index``-th image."""
        image_path, image_label = self.dataset[index]
        image_data = keep_resize_image(image_path)
        # random.choice covers every gain; the previous randint(0, 4) index
        # could never select the last entry (1.6).
        image_data = cv2.convertScaleAbs(image_data, alpha=random.choice(self.LIGHTS))
        return self.transform(image_data), image_label
37 | |||
38 | |||
if __name__ == '__main__':
    # Smoke test: walk the whole dataset once so every image gets loaded.
    import tqdm
    dataset = FaceMaskDataset('image')
    for sample in dataset:
        sample
44 |
image/img.png
deleted
100644 → 0

55 KB
image/img_1.png
deleted
100644 → 0

90.2 KB
1 | |||
1 | import os | 2 | import os |
2 | 3 | ||
4 | from PIL import Image, ImageDraw, ImageFont | ||
3 | import cv2 | 5 | import cv2 |
4 | |||
5 | import torch | 6 | import torch |
6 | from PIL import Image | 7 | from model.utils import utils |
7 | import numpy as np | ||
8 | from torchvision import transforms | 8 | from torchvision import transforms |
9 | from net import * | 9 | from model.net.net import * |
10 | import argparse | ||
10 | 11 | ||
# Command-line interface of the inference script.
parse = argparse.ArgumentParser('infer models')
# Restricting ``demo`` to the supported modes makes a typo fail with a usage
# message instead of the script exiting silently with status 0.
parse.add_argument('demo', type=str, choices=['image', 'video', 'camera'],
                   help='推理类型支持:image/video/camera')
parse.add_argument('--weights_path', type=str, default='', help='模型权重路径')
parse.add_argument('--image_path', type=str, default='', help='图片存放路径')
parse.add_argument('--video_path', type=str, default='', help='视频路径')
parse.add_argument('--camera_id', type=int, default=0, help='摄像头id')
11 | 18 | ||
12 | 19 | ||
class ModelInfer:
    """Run a trained ``ClassifierNet`` on a single image, a video file or a camera.

    Args:
        config: Mapping read from the YAML config; the keys ``net_type``,
            ``class_names`` and ``image_size`` are used here.
        weights_path: Path of the ``state_dict`` file to load. ``None`` aborts
            the process; a path that does not exist only prints a notice and
            leaves the network with its initial weights.
    """

    def __init__(self, config, weights_path):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])
        self.net = ClassifierNet(self.config['net_type'], len(self.config['class_names']),
                                 False).to(self.device)
        if weights_path is not None:
            if os.path.exists(weights_path):
                # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
                self.net.load_state_dict(torch.load(weights_path, map_location=self.device))
                print('successfully loading model weights!')
            else:
                print('no loading model weights!')
        else:
            print('please input weights_path!')
            exit(0)
        self.net.eval()

    def _predict(self, image):
        """Return the class name predicted for a PIL image."""
        image_data = utils.keep_shape_resize(image, self.config['image_size'])
        image_data = self.transform(image_data)
        image_data = torch.unsqueeze(image_data, dim=0).to(self.device)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            out = self.net(image_data)
        return self.config['class_names'][int(torch.argmax(out))]

    def _stream_infer(self, source):
        """Classify every frame of an OpenCV capture source until it ends or 'q' is pressed."""
        cap = cv2.VideoCapture(source)
        try:
            while True:
                ok, frame = cap.read()
                # Check the flag before touching ``frame``: it is None on a failed read.
                if not ok:
                    break
                image_data = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                result = self._predict(image_data)
                cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
                cv2.imshow('frame', frame)
                if cv2.waitKey(24) & 0XFF == ord('q'):
                    break
        finally:
            cap.release()
            cv2.destroyAllWindows()

    def image_infer(self, image_path):
        """Classify the image at ``image_path`` and display it with the label drawn on it."""
        image = Image.open(image_path)
        result = self._predict(image)
        draw = ImageDraw.Draw(image)
        try:
            font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)
        except OSError:
            # The font above only exists on Windows; fall back to Pillow's built-in font.
            font = ImageFont.load_default()
        draw.text((10, 10), result, font=font, fill='red')
        image.show()

    def video_infer(self, video_path):
        """Classify the frames of the video file at ``video_path``."""
        self._stream_infer(video_path)

    def camera_infer(self, camera_id):
        """Classify the frames delivered by the camera with index ``camera_id``."""
        self._stream_infer(camera_id)
77 | 94 | ||
78 | 95 | ||
if __name__ == '__main__':
    # Parse the CLI, build the model from the YAML config, then dispatch on the demo type.
    args = parse.parse_args()
    config = utils.load_config_util('config/config.yaml')
    model = ModelInfer(config, args.weights_path)
    demo_type = args.demo
    if demo_type == 'image':
        model.image_infer(args.image_path)
    elif demo_type == 'video':
        model.video_infer(args.video_path)
    elif demo_type == 'camera':
        model.camera_infer(args.camera_id)
    else:
        exit(0)
infer_mnn.py
0 → 100644
1 | |||
2 | import os | ||
3 | |||
4 | import cv2 | ||
5 | from PIL import Image, ImageFont, ImageDraw | ||
6 | import torch | ||
7 | from torchvision import transforms | ||
8 | import MNN | ||
9 | |||
10 | |||
def keep_shape_resize(frame, size=128):
    """Pad a PIL image to a square with black borders, then resize it to ``size`` x ``size``.

    The image is pasted onto a black RGB canvas whose side equals the longer
    edge of ``frame``, so the aspect ratio of the content is preserved.
    """
    width, height = frame.size
    side = max(width, height)
    canvas = Image.new('RGB', (side, side), (0, 0, 0))
    # Offset along the shorter axis so the padding is split between both sides.
    if width >= height:
        offset = (0, (width - height) // 2)
    else:
        offset = ((height - width) // 2, 0)
    canvas.paste(frame, offset)
    return canvas.resize((size, size))
22 | |||
23 | |||
def image_infer_mnn(mnn_model_path, image_path, class_list):
    """Classify one image with an MNN model and draw the predicted label on it.

    Args:
        mnn_model_path: Path of the ``.mnn`` model file.
        image_path: Path of the image to classify.
        class_list: Class names, indexed by the position of the largest model output.

    Returns:
        The opened PIL image with the class name drawn in its top-left corner.
    """
    # NOTE(review): presumably the resolution the model was exported with — confirm.
    size = 128
    image = Image.open(image_path)
    input_image = keep_shape_resize(image, size)
    input_data = transforms.ToTensor()(input_image).cpu().numpy().squeeze()

    interpreter = MNN.Interpreter(mnn_model_path)
    session = interpreter.createSession()
    input_tensor = interpreter.getSessionInput(session)
    # One shared ``size`` keeps the declared tensor shape in step with the resize above.
    tmp_input = MNN.Tensor((1, 3, size, size), MNN.Halide_Type_Float, input_data, MNN.Tensor_DimensionType_Caffe)
    input_tensor.copyFrom(tmp_input)
    interpreter.runSession(session)
    output_data = interpreter.getSessionOutput(session).getData()
    out = output_data.index(max(output_data))

    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)
    except OSError:
        # The font above only exists on Windows; fall back to Pillow's built-in font.
        font = ImageFont.load_default()
    draw.text((10, 10), class_list[int(out)], font=font, fill='red')
    return image
43 | |||
44 | |||
# Labels used when a caller passes no class list and the module defines none.
_DEFAULT_CLASS_LIST = ('mask', 'no_mask')


def _stream_infer_mnn(mnn_model_path, source, class_list=None, size=128):
    """Classify every frame of an OpenCV capture source with an MNN model.

    Each frame is shown in a window with the predicted label drawn on it.
    The loop ends when the source stops delivering frames or 'q' is pressed.
    """
    if class_list is None:
        # Backward compatibility: older callers relied on a module-level
        # ``class_list`` that only exists when this file runs as a script.
        class_list = globals().get('class_list', _DEFAULT_CLASS_LIST)

    interpreter = MNN.Interpreter(mnn_model_path)
    session = interpreter.createSession()
    input_tensor = interpreter.getSessionInput(session)
    preprocess = transforms.ToTensor()  # loop-invariant, so build it once

    cap = cv2.VideoCapture(source)
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            image_data = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            image_data = keep_shape_resize(image_data, size)
            input_data = preprocess(image_data).cpu().numpy().squeeze()
            tmp_input = MNN.Tensor((1, 3, size, size), MNN.Halide_Type_Float, input_data,
                                   MNN.Tensor_DimensionType_Caffe)
            input_tensor.copyFrom(tmp_input)
            interpreter.runSession(session)
            output_data = interpreter.getSessionOutput(session).getData()
            out = output_data.index(max(output_data))
            cv2.putText(frame, class_list[int(out)], (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
            cv2.imshow('frame', frame)
            if cv2.waitKey(24) & 0XFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()


def video_infer_mnn(mnn_model_path, video_path, class_list=None):
    """Run MNN classification on the video file at ``video_path``.

    Args:
        mnn_model_path: Path of the ``.mnn`` model file.
        video_path: Path of the video to read.
        class_list: Optional class names indexed by the model output position.
    """
    _stream_infer_mnn(mnn_model_path, video_path, class_list)


def camera_infer_mnn(mnn_model_path, camera_id, class_list=None):
    """Run MNN classification on the camera with index ``camera_id``.

    Args:
        mnn_model_path: Path of the ``.mnn`` model file.
        camera_id: Index of the camera passed to ``cv2.VideoCapture``.
        class_list: Optional class names indexed by the model output position.
    """
    _stream_infer_mnn(mnn_model_path, camera_id, class_list)
99 | |||
100 | |||
101 | |||
if __name__ == '__main__':
    # Demo settings; ``class_list`` must stay a module-level name.
    mnn_model_path = 'mobilenet_v2.mnn'
    image_path = 'test_image/mask_2997.jpg'
    class_list = ['mask', 'no_mask']
    # image
    # image=image_infer_mnn(mnn_model_path,image_path,class_list)
    # image.show()

    # camera
    camera_infer_mnn(mnn_model_path, 0)
mobilenet_v2.mnn
0 → 100644
No preview for this file type
mobilenet_v2.onnx
0 → 100644
No preview for this file type
No preview for this file type
mobilenet_v2/weight/best.pth
0 → 100644
No preview for this file type
mobilenet_v2/weight/last.pth
0 → 100644
No preview for this file type
No preview for this file type
model/dataset/dataset.py
0 → 100644
1 | ''' | ||
2 | _*_coding:utf-8 _*_ | ||
3 | @Time :2022/1/28 19:00 | ||
4 | @Author : qiaofengsheng | ||
5 | @File :dataset.py | ||
6 | @Software :PyCharm | ||
7 | ''' | ||
8 | import os | ||
9 | |||
10 | from PIL import Image | ||
11 | from torch.utils.data import * | ||
12 | from model.utils import utils | ||
13 | from torchvision import transforms | ||
14 | |||
15 | |||
class ClassDataset(Dataset):
    """Classification dataset with one sub-directory of ``data_dir`` per class.

    Args:
        data_dir: Root directory; every sub-directory is treated as one class.
        config: Mapping with ``class_names`` (directory names, whose list
            position is the label) and ``image_size``.

    Raises:
        ValueError: If a class directory containing images is not listed in
            ``config['class_names']``.
    """

    def __init__(self, data_dir, config):
        self.config = config
        self.transform = transforms.Compose([
            transforms.RandomRotation(60),
            transforms.ToTensor()
        ])
        self.dataset = []
        class_names = config['class_names']
        # Sorted listings make the sample order independent of the file system.
        for class_dir in sorted(os.listdir(data_dir)):
            class_path = os.path.join(data_dir, class_dir)
            if not os.path.isdir(class_path):
                # Stray files next to the class folders would make os.listdir fail.
                continue
            for image_name in sorted(os.listdir(class_path)):
                self.dataset.append(
                    [os.path.join(class_path, image_name),
                     int(class_names.index(class_dir))])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(tensor, label)`` for the ``index``-th image."""
        image_path, image_label = self.dataset[index]
        # keep_shape_resize returns a new image, so the file can be closed
        # as soon as it has been pasted.
        with Image.open(image_path) as image:
            image = utils.keep_shape_resize(image, self.config['image_size'])
        return self.transform(image), image_label
No preview for this file type
model/loss/loss_fun.py
0 → 100644
1 | ''' | ||
2 | _*_coding:utf-8 _*_ | ||
3 | @Time :2022/1/28 19:05 | ||
4 | @Author : qiaofengsheng | ||
5 | @File :loss_fun.py | ||
6 | @Software :PyCharm | ||
7 | ''' | ||
8 | |||
9 | from torch import nn | ||
10 | |||
11 | |||
class Loss:
    """Select a ``torch.nn`` loss module by name.

    Recognised names are ``'mse'``, ``'l1'``, ``'smooth_l1'`` and
    ``'cross_entropy'``; any other value yields ``nn.MSELoss``.
    """

    def __init__(self, loss_type='mse'):
        known_losses = (
            ('mse', nn.MSELoss),
            ('l1', nn.L1Loss),
            ('smooth_l1', nn.SmoothL1Loss),
            ('cross_entropy', nn.CrossEntropyLoss),
        )
        # MSE doubles as the fallback for names that are not recognised.
        chosen = nn.MSELoss
        for name, factory in known_losses:
            if loss_type == name:
                chosen = factory
                break
        self.loss_fun = chosen()

    def get_loss_fun(self):
        """Return the loss module selected at construction time."""
        return self.loss_fun
model/net/__pycache__/net.cpython-38.pyc
0 → 100644
No preview for this file type
model/net/net.py
0 → 100644
1 | ''' | ||
2 | _*_coding:utf-8 _*_ | ||
3 | @Time :2022/1/28 19:05 | ||
4 | @Author : qiaofengsheng | ||
5 | @File :net.py | ||
6 | @Software :PyCharm | ||
7 | ''' | ||
8 | |||
9 | import torch | ||
10 | from torchvision import models | ||
11 | from torch import nn | ||
12 | from efficientnet_pytorch import EfficientNet | ||
13 | |||
14 | |||
class ClassifierNet(nn.Module):
    """Image classifier that wraps a torchvision or EfficientNet backbone.

    Args:
        net_type: Architecture name, one of ``TORCHVISION_NETS`` or
            ``EFFICIENT_NETS``.
        num_classes: Number of output classes passed to the backbone.
        pretrained: Whether to request pretrained weights from the backbone
            constructor.

    Raises:
        ValueError: If ``net_type`` is not supported. Previously this left
            ``self.layer`` as ``None`` and only failed later inside ``forward``.
    """

    # Names that map directly to a constructor in ``torchvision.models``.
    TORCHVISION_NETS = (
        'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
        'resnext101_32x8d', 'resnext50_32x4d', 'wide_resnet50_2', 'wide_resnet101_2',
        'densenet121', 'densenet161', 'densenet169', 'densenet201',
        'vgg11', 'vgg13', 'vgg13_bn', 'vgg19', 'vgg19_bn', 'vgg16', 'vgg16_bn',
        'inception_v3',
        'mobilenet_v2', 'mobilenet_v3_small', 'mobilenet_v3_large',
        'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
        'alexnet', 'googlenet',
        'mnasnet0_5', 'mnasnet1_0', 'mnasnet1_3', 'mnasnet0_75',
        'squeezenet1_0', 'squeezenet1_1',
    )
    # Names handled by ``efficientnet_pytorch``; b7 is included because the
    # project config lists it as supported.
    EFFICIENT_NETS = (
        'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3',
        'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7',
    )

    def __init__(self, net_type='resnet18', num_classes=10, pretrained=False):
        super(ClassifierNet, self).__init__()
        if net_type in self.TORCHVISION_NETS:
            backbone = getattr(models, net_type)(pretrained=pretrained, num_classes=num_classes)
        elif net_type in self.EFFICIENT_NETS:
            if pretrained:
                backbone = EfficientNet.from_pretrained(net_type, num_classes=num_classes)
            else:
                backbone = EfficientNet.from_name(net_type, num_classes=num_classes)
        else:
            raise ValueError('unsupported net_type: {!r}'.format(net_type))
        # The Sequential wrapper is kept so parameter names stay ``layer.0.*``
        # and state dicts saved by earlier versions still load.
        self.layer = nn.Sequential(backbone)

    def forward(self, x):
        """Apply the backbone to ``x`` and return its output."""
        return self.layer(x)
78 | |||
if __name__ == '__main__':
    # Smoke test: build one backbone and print the shape of its output.
    model = ClassifierNet('mnasnet1_0', pretrained=False)
    sample = torch.randn(1, 3, 125, 125)
    print(model(sample).shape)
No preview for this file type
model/optimizer/optim.py
0 → 100644
1 | ''' | ||
2 | _*_coding:utf-8 _*_ | ||
3 | @Time :2022/1/28 19:06 | ||
4 | @Author : qiaofengsheng | ||
5 | @File :optim.py | ||
6 | @Software :PyCharm | ||
7 | ''' | ||
8 | |||
9 | from torch import optim | ||
10 | |||
11 | |||
class Optimizer:
    """Build a ``torch.optim`` optimizer for a network by name.

    Args:
        net: Module whose ``parameters()`` are optimised.
        opt_type: One of ``SGD``, ``ASGD``, ``Adam``, ``AdamW``, ``Adamax``,
            ``Adagrad``, ``Adadelta``, ``SparseAdam``, ``LBFGS``, ``Rprop``,
            ``RMSprop``. Any other value falls back to ``Adam``.
        lr: Optional learning rate. When ``None`` each optimizer keeps its
            own default, except ``SGD`` which uses 0.01 as before.
    """

    def __init__(self, net, opt_type='Adam', lr=None):
        super(Optimizer, self).__init__()
        factories = {
            'SGD': optim.SGD,
            'ASGD': optim.ASGD,
            'Adam': optim.Adam,
            'AdamW': optim.AdamW,
            'Adamax': optim.Adamax,
            'Adagrad': optim.Adagrad,
            'Adadelta': optim.Adadelta,
            'SparseAdam': optim.SparseAdam,
            'LBFGS': optim.LBFGS,
            'Rprop': optim.Rprop,
            'RMSprop': optim.RMSprop,
        }
        factory = factories.get(opt_type, optim.Adam)
        options = {}
        if lr is not None:
            options['lr'] = lr
        elif factory is optim.SGD:
            # Historical default of this wrapper for SGD.
            options['lr'] = 0.01
        self.opt = factory(net.parameters(), **options)

    def get_optimizer(self):
        """Return the optimizer built at construction time."""
        return self.opt
model/utils/__pycache__/utils.cpython-38.pyc
0 → 100644
No preview for this file type
model/utils/utils.py
0 → 100644
1 | ''' | ||
2 | _*_coding:utf-8 _*_ | ||
3 | @Time :2022/1/28 19:58 | ||
4 | @Author : qiaofengsheng | ||
5 | @File :utils.py | ||
6 | @Software :PyCharm | ||
7 | ''' | ||
8 | import torch | ||
9 | import yaml | ||
10 | from PIL import Image | ||
11 | from torch.nn.functional import one_hot | ||
12 | |||
13 | |||
def load_config_util(config_path):
    """Read the YAML file at ``config_path`` and return its parsed content.

    ``yaml.safe_load`` is used because ``yaml.load`` without an explicit
    ``Loader`` raises ``TypeError`` on PyYAML 6 and can construct arbitrary
    objects on older versions. The file is closed once parsing finishes.
    """
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return yaml.safe_load(config_file)
18 | |||
19 | |||
def keep_shape_resize(frame, size=128):
    """Letterbox a PIL image onto a black square and scale it to ``size`` x ``size``.

    The square's side equals the longer edge of ``frame``, so the content
    keeps its aspect ratio and is padded only along the shorter axis.
    """
    w, h = frame.size
    longest = max(w, h)
    square = Image.new('RGB', (longest, longest), (0, 0, 0))
    # Shift along the shorter axis by half of the size difference.
    position = (0, (w - h) // 2) if w >= h else ((h - w) // 2, 0)
    square.paste(frame, position)
    resized = square.resize((size, size))
    return resized
31 | |||
32 | |||
def label_one_hot(label, num_classes=-1):
    """One-hot encode integer class labels.

    Args:
        label: Integer class indices, as a tensor or anything
            ``torch.as_tensor`` accepts (e.g. a list of ints).
        num_classes: Width of the encoding. The default ``-1`` keeps the
            previous behaviour of inferring it from the largest label present;
            pass the real class count when encoding batches, otherwise a batch
            that lacks the highest class yields a narrower tensor.

    Returns:
        Integer tensor with one extra trailing dimension of size
        ``num_classes``.
    """
    # ``as_tensor`` avoids the copy (and the warning) that ``torch.tensor``
    # produces when ``label`` is already a tensor.
    return one_hot(torch.as_tensor(label), num_classes=num_classes)
net.py
deleted
100644 → 0
1 | import torch | ||
2 | from torch import nn | ||
3 | from torchvision import models | ||
4 | |||
5 | |||
class FaceMaskNet(nn.Module):
    """Two-class classifier: a pretrained MobileNetV2 followed by a linear head.

    The head maps 1000 input features to 2 outputs; the 1000 presumably
    corresponds to the backbone's default output width (confirm against the
    installed torchvision version).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and the Sequential wrappers are kept as-is because
        # they determine the keys of the saved state_dict.
        backbone = models.mobilenet_v2(pretrained=True)
        self.layer = nn.Sequential(backbone)
        head = nn.Linear(1000, 2)
        self.classifier = nn.Sequential(head)

    def forward(self, x):
        """Run ``x`` through the backbone, then through the linear head."""
        backbone_out = self.layer(x)
        return self.classifier(backbone_out)
18 | |||
19 | |||
20 | |||
if __name__ == '__main__':
    # Smoke test: push a random batch of five 3x128x128 inputs through the
    # network and print the shape of the result.
    model = FaceMaskNet()
    dummy_batch = torch.randn(5, 3, 128, 128)
    print(model(dummy_batch).shape)
params/cls_face_mobilenet_v2.pth
deleted
100644 → 0
This file is too large to display.
test.py
deleted
100644 → 0
1 | import os | ||
2 | |||
3 | import cv2 | ||
4 | import torch | ||
5 | |||
6 | from net import * | ||
7 | from dataset import * | ||
8 | |||
transform = transforms.Compose([
    transforms.ToTensor()
])

data = FaceMaskDataset('/data2/face_mask')
d = DataLoader(data, batch_size=1000, shuffle=True)

# Build the network and read the checkpoint once. Doing this inside the batch
# loop re-created the model and re-read the file from disk for every batch.
net = FaceMaskNet().cuda()
net.load_state_dict(torch.load('params/face_mobilenet_v2.pth'))
net.eval()

# Print the accuracy of each batch; no gradients are needed for evaluation.
with torch.no_grad():
    for i, (image, label) in enumerate(d):
        out = net(image.cuda())
        out = torch.argmax(out, dim=1)
        acc = torch.mean(torch.eq(label.cuda(), out).float()).item()
        print(acc)
test_image/mask_2995.jpg
0 → 100644

8.57 KB
test_image/mask_2996.jpg
0 → 100644

10.7 KB
test_image/mask_2997.jpg
0 → 100644

12.5 KB
test_image/normal_33787.jpg
0 → 100644

11.2 KB
test_image/normal_33788.jpg
0 → 100644

54.6 KB
test_image/normal_33790.jpg
0 → 100644

49.2 KB
1 | import os.path | ||
2 | 1 | ||
3 | from torch import nn, optim | 2 | import os.path |
3 | import time | ||
4 | import torch | 4 | import torch |
5 | from dataset import * | ||
6 | from torch.utils.data import random_split | ||
7 | from net import * | ||
8 | import tqdm | 5 | import tqdm |
9 | import time | 6 | from torch.utils.tensorboard import SummaryWriter |
10 | 7 | from model.net.net import ClassifierNet | |
11 | if __name__ == '__main__': | 8 | from model.loss.loss_fun import * |
12 | train_rate=0.8 | 9 | from model.optimizer.optim import * |
13 | batch_size=50 | 10 | from model.dataset.dataset import * |
14 | device=torch.device('cuda' if torch.cuda.is_available() else 'cpu') | 11 | import argparse |
15 | print(device) | ||
16 | epochs=50 | ||
17 | 12 | ||
18 | datasets = FaceMaskDataset('/data2/new_face_mask') | 13 | parse = argparse.ArgumentParser(description='train_demo of argparse') |
19 | train_datasets, test_datasets = random_split( | 14 | parse.add_argument('--weights_path', default=None) |
20 | datasets, | ||
21 | [int(len(datasets) * train_rate), len(datasets) - int(len(datasets) * train_rate)], | ||
22 | ) | ||
23 | print(f'train_datasets:{len(train_datasets)} test_datasets:{len(test_datasets)}') | ||
24 | train_data_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True) | ||
25 | test_data_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=True) | ||
26 | loss_fun = nn.CrossEntropyLoss() | ||
27 | 15 | ||
28 | net = FaceMaskNet().to(device) | ||
29 | if os.path.exists('params/new_face_mobilenet_v2.pth'): | ||
30 | net.load_state_dict(torch.load('params/new_face_mobilenet_v2.pth')) | ||
31 | print('successfully loading weights!') | ||
32 | opt = optim.Adam(net.parameters()) | ||
33 | 16 | ||
class Train:
    """Config-driven training and evaluation loop for ``ClassifierNet``.

    Args:
        config: Mapping that provides the keys read below: ``model_dir``,
            ``log_dir``, ``net_type``, ``class_names``, ``pretrained``,
            ``loss_type``, ``optimizer``, ``data_dir``, ``train_rate``,
            ``batch_size`` and ``epochs``.
    """

    def __init__(self, config):
        self.config = config
        os.makedirs(config['model_dir'], exist_ok=True)
        self.summary_writer = SummaryWriter(config['log_dir'])
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = ClassifierNet(config['net_type'], len(config['class_names']),
                                 config['pretrained']).to(self.device)
        self.loss_fun = Loss(config['loss_type']).get_loss_fun()
        self.optimizer = Optimizer(self.net, config['optimizer']).get_optimizer()
        self.dataset = ClassDataset(config['data_dir'], config)
        # The first ``train_rate`` share of the samples goes to training and
        # the remainder to testing; the split itself is random.
        train_size = int(len(self.dataset) * config['train_rate'])
        self.train_dataset, self.test_dataset = random_split(
            self.dataset, [train_size, len(self.dataset) - train_size])
        self.train_data_loader = DataLoader(self.train_dataset, batch_size=config['batch_size'], shuffle=True)
        self.test_data_loader = DataLoader(self.test_dataset, batch_size=config['batch_size'], shuffle=True)

    def _compute_loss(self, out, image_label):
        """Return the configured loss for one batch of outputs and labels."""
        if self.config['loss_type'] == 'cross_entropy':
            return self.loss_fun(out, image_label)
        # Other losses receive float one-hot targets. The width is fixed to
        # the configured class count so that a batch missing the highest class
        # still gets a full-width target; ``.float()`` keeps the tensor on its
        # current device.
        target = torch.nn.functional.one_hot(
            image_label, num_classes=len(self.config['class_names']))
        return self.loss_fun(out, target.float())

    def train(self, weights_path):
        """Train for ``config['epochs']`` epochs, evaluating after each one.

        Writes ``last.pth`` every 10 training batches and after every epoch,
        and ``best.pth`` whenever the test accuracy improves. Per-epoch mean
        losses and test accuracy are logged to the summary writer.

        Args:
            weights_path: Optional checkpoint to resume from; ignored when it
                is ``None`` or the file does not exist.
        """
        print(f'device:{self.device} 训练集:{len(self.train_dataset)} 测试集:{len(self.test_dataset)}')
        last_path = os.path.join(self.config['model_dir'], 'last.pth')
        best_path = os.path.join(self.config['model_dir'], 'best.pth')

        if weights_path is not None and os.path.exists(weights_path):
            # map_location lets a checkpoint saved on GPU load on a CPU host.
            self.net.load_state_dict(torch.load(weights_path, map_location=self.device))
            print('successfully loading model weights!')
        else:
            print('no loading model weights')

        best_acc = 0
        for epoch in range(1, self.config['epochs'] + 1):
            self.net.train()
            loss_sum, batch_count = 0.0, 0
            # Iterate the progress bar itself so it advances with the data.
            with tqdm.tqdm(self.train_data_loader) as t1:
                for i, (image_data, image_label) in enumerate(t1):
                    image_data, image_label = image_data.to(self.device), image_label.to(self.device)
                    out = self.net(image_data)
                    train_loss = self._compute_loss(out, image_label)
                    self.optimizer.zero_grad()
                    train_loss.backward()
                    self.optimizer.step()
                    loss_sum += train_loss.item()
                    batch_count += 1
                    t1.set_description(f'Train-Epoch {epoch} 轮 {i} 批次 : ')
                    t1.set_postfix(train_loss=train_loss.item())
                    if i % 10 == 0:
                        torch.save(self.net.state_dict(), last_path)
            if batch_count:
                self.summary_writer.add_scalar('train_loss', loss_sum / batch_count, epoch)

            self.net.eval()
            correct, seen, loss_sum, batch_count = 0, 0, 0.0, 0
            with torch.no_grad(), tqdm.tqdm(self.test_data_loader) as t2:
                for j, (image_data, image_label) in enumerate(t2):
                    image_data, image_label = image_data.to(self.device), image_label.to(self.device)
                    out = self.net(image_data)
                    test_loss = self._compute_loss(out, image_label)
                    predicted = torch.argmax(out, dim=1)
                    hits = torch.eq(predicted, image_label).sum().item()
                    batch_size = image_label.numel()
                    # Count samples rather than averaging per-batch means, so
                    # a short final batch is not over-weighted.
                    correct += hits
                    seen += batch_size
                    loss_sum += test_loss.item()
                    batch_count += 1
                    t2.set_description(f'Test-Epoch {epoch} 轮 {j} 批次 : ')
                    t2.set_postfix(test_loss=test_loss.item(), test_acc=hits / batch_size)

            # Always keep the most recent weights, whether or not they are best.
            torch.save(self.net.state_dict(), last_path)
            if not seen:
                # Empty test split: nothing to score, so skip the metrics
                # instead of dividing by zero.
                continue
            epoch_acc = correct / seen
            print(f'Test-Epoch {epoch} 轮准确率为 : {epoch_acc}')
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(self.net.state_dict(), best_path)
            self.summary_writer.add_scalar('test_loss', loss_sum / batch_count, epoch)
            self.summary_writer.add_scalar('test_acc', epoch_acc, epoch)
55 | 96 | ||
56 | acc, temp = 0, 0 | ||
57 | with torch.no_grad(): | ||
58 | net.eval() | ||
59 | with tqdm.tqdm(test_data_loader) as t2: | ||
60 | for j, (image_data, image_label) in enumerate(test_data_loader): | ||
61 | image_data, image_label = image_data.to(device), image_label.to(device) | ||
62 | out = net(image_data) | ||
63 | test_loss = loss_fun(out, image_label) | ||
64 | 97 | ||
65 | t2.set_description(f'Epoch {epoch} test') | ||
66 | out = torch.argmax(out, dim=1) | ||
67 | t2.set_postfix(test_loss=test_loss.item(), | ||
68 | test_acc=torch.mean(torch.eq(image_label, out).float()).item()) | ||
69 | time.sleep(0.1) | ||
70 | t2.update(1) | ||
71 | acc += torch.mean(torch.eq(image_label, out).float()).item() | ||
72 | temp += 1 | ||
73 | print(f'epoch : {epoch} avg acc : ', acc / temp) | ||
74 | |||
75 | # acc,temp=0,0 | ||
76 | # with torch.no_grad(): | ||
77 | # net.eval() | ||
78 | # for i, (image_data, image_label) in enumerate(tqdm.tqdm(test_data_loader)): | ||
79 | # image_data, image_label = image_data.to(device), image_label.to(device) | ||
80 | # out = net(image_data) | ||
81 | # test_loss = loss_fun(out, image_label) | ||
82 | # | ||
83 | # out = torch.argmax(out, dim=1) | ||
84 | # | ||
85 | # acc += torch.mean(torch.eq(image_label, out).float()).item() | ||
86 | # temp+=1 | ||
87 | # if i % 5 == 0: | ||
88 | # print(f'epoch : {epoch} {i} test_loss : ', test_loss.item()) | ||
89 | # print(f'epoch : {epoch} {i} test acc : ',torch.mean(torch.eq(image_label, out).float()).item()) | ||
90 | # print(f'epoch : {epoch} avg acc : ',acc/temp) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
if __name__ == '__main__':
    # Parse the command line, load the YAML configuration and start training,
    # optionally resuming from the checkpoint given via --weights_path.
    cli_args = parse.parse_args()
    train_config = utils.load_config_util('config/config.yaml')
    trainer = Train(train_config)
    trainer.train(cli_args.weights_path)
utils.py
deleted
100644 → 0
1 | import cv2 | ||
2 | import numpy as np | ||
3 | from PIL import Image | ||
# Pad with black borders, then scale while keeping the aspect ratio.
def keep_resize_image(image_path, size=(128, 128)):
    """Load an image, letterbox it onto a square and return it as a BGR array.

    Args:
        image_path: Path of the image file to open.
        size: ``(width, height)`` passed to ``resize`` for the result.

    Returns:
        The resized picture as a NumPy array converted from RGB to BGR
        channel order.
    """
    # The context manager releases the file handle that ``Image.open`` keeps.
    with Image.open(image_path) as image:
        w, h = image.size
        temp = max(w, h)
        mask = Image.new('RGB', (temp, temp))
        # Offset only along the shorter axis so the picture sits in the
        # middle of the square canvas.
        if w >= h:
            mask.paste(image, (0, (w - h) // 2))
        else:
            mask.paste(image, ((h - w) // 2, 0))
    mask = mask.resize(size)
    mask = np.array(mask)
    mask = cv2.cvtColor(mask, cv2.COLOR_RGB2BGR)
    return mask


if __name__ == '__main__':
    keep_resize_image('image/mask/img.png')
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or sign in to post a comment