4d8ada78 by 乔峰昇

Modify the zoom

1 parent b8d9812e
1 <?xml version="1.0" encoding="UTF-8"?>
2 <module type="PYTHON_MODULE" version="4">
3 <component name="NewModuleRootManager">
4 <content url="file://$MODULE_DIR$" />
5 <orderEntry type="jdk" jdkName="Python 3.8 (gan)" jdkType="Python SDK" />
6 <orderEntry type="sourceFolder" forTests="false" />
7 </component>
8 </module>
...\ No newline at end of file ...\ No newline at end of file
1 <component name="InspectionProjectProfileManager">
2 <profile version="1.0">
3 <option name="myName" value="Project Default" />
4 <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5 <option name="ignoredPackages">
6 <value>
7 <list size="25">
8 <item index="0" class="java.lang.String" itemvalue="tqdm" />
9 <item index="1" class="java.lang.String" itemvalue="easydict" />
10 <item index="2" class="java.lang.String" itemvalue="scikit_image" />
11 <item index="3" class="java.lang.String" itemvalue="matplotlib" />
12 <item index="4" class="java.lang.String" itemvalue="tensorboardX" />
13 <item index="5" class="java.lang.String" itemvalue="torch" />
14 <item index="6" class="java.lang.String" itemvalue="numpy" />
15 <item index="7" class="java.lang.String" itemvalue="pycocotools" />
16 <item index="8" class="java.lang.String" itemvalue="skimage" />
17 <item index="9" class="java.lang.String" itemvalue="Pillow" />
18 <item index="10" class="java.lang.String" itemvalue="scipy" />
19 <item index="11" class="java.lang.String" itemvalue="torchvision" />
20 <item index="12" class="java.lang.String" itemvalue="opencv_python" />
21 <item index="13" class="java.lang.String" itemvalue="onnxruntime" />
22 <item index="14" class="java.lang.String" itemvalue="onnx-simplifier" />
23 <item index="15" class="java.lang.String" itemvalue="onnx" />
24 <item index="16" class="java.lang.String" itemvalue="opencv-contrib-python" />
25 <item index="17" class="java.lang.String" itemvalue="numba" />
26 <item index="18" class="java.lang.String" itemvalue="opencv-python" />
27 <item index="19" class="java.lang.String" itemvalue="librosa" />
28 <item index="20" class="java.lang.String" itemvalue="tensorboard" />
29 <item index="21" class="java.lang.String" itemvalue="dill" />
30 <item index="22" class="java.lang.String" itemvalue="pandas" />
31 <item index="23" class="java.lang.String" itemvalue="scikit_learn" />
32 <item index="24" class="java.lang.String" itemvalue="pytorch-gradual-warmup-lr" />
33 </list>
34 </value>
35 </option>
36 </inspection_tool>
37 </profile>
38 </component>
...\ No newline at end of file ...\ No newline at end of file
1 <component name="InspectionProjectProfileManager">
2 <settings>
3 <option name="USE_PROJECT_PROFILE" value="false" />
4 <version value="1.0" />
5 </settings>
6 </component>
...\ No newline at end of file ...\ No newline at end of file
1 <?xml version="1.0" encoding="UTF-8"?>
2 <project version="4">
3 <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (gan)" project-jdk-type="Python SDK" />
4 </project>
...\ No newline at end of file ...\ No newline at end of file
1 <?xml version="1.0" encoding="UTF-8"?>
2 <project version="4">
3 <component name="ProjectModuleManager">
4 <modules>
5 <module fileurl="file://$PROJECT_DIR$/.idea/face_mask_classifier.iml" filepath="$PROJECT_DIR$/.idea/face_mask_classifier.iml" />
6 </modules>
7 </component>
8 </project>
...\ No newline at end of file ...\ No newline at end of file
1 <?xml version="1.0" encoding="UTF-8"?>
2 <project version="4">
3 <component name="VcsDirectoryMappings">
4 <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 </component>
6 </project>
...\ No newline at end of file ...\ No newline at end of file
1 <?xml version="1.0" encoding="UTF-8"?> 1 <?xml version="1.0" encoding="UTF-8"?>
2 <project version="4"> 2 <project version="4">
3 <component name="ChangeListManager"> 3 <component name="ChangeListManager">
4 <list default="true" id="8255f694-c607-4431-a530-39f2e0df4506" name="Default Changelist" comment=""> 4 <list default="true" id="a64a78af-ad8c-4647-b359-632e9aeec7f0" name="Default Changelist" comment="">
5 <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" /> 5 <change beforePath="$PROJECT_DIR$/.idea/face_mask_classifier.iml" beforeDir="false" />
6 <change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/Project_Default.xml" beforeDir="false" />
7 <change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" beforeDir="false" />
8 <change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" />
9 <change beforePath="$PROJECT_DIR$/.idea/modules.xml" beforeDir="false" />
10 <change beforePath="$PROJECT_DIR$/.idea/vcs.xml" beforeDir="false" />
11 <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" />
12 <change beforePath="$PROJECT_DIR$/cls_abnormal_face_mnn_1.0.0_v0.0.1.mnn" beforeDir="false" afterPath="$PROJECT_DIR$/cls_abnormal_face_mnn_1.0.0_v0.0.1.mnn" afterDir="false" />
13 <change beforePath="$PROJECT_DIR$/cls_abnormal_face_onnx_1.0.0_v0.0.1.onnx" beforeDir="false" />
6 <change beforePath="$PROJECT_DIR$/infer_mnn.py" beforeDir="false" afterPath="$PROJECT_DIR$/infer_mnn.py" afterDir="false" /> 14 <change beforePath="$PROJECT_DIR$/infer_mnn.py" beforeDir="false" afterPath="$PROJECT_DIR$/infer_mnn.py" afterDir="false" />
7 <change beforePath="$PROJECT_DIR$/mobilenet_v2.mnn" beforeDir="false" afterPath="$PROJECT_DIR$/cls_abnormal_face_mnn_1.0.0_v0.0.1.mnn" afterDir="false" /> 15 <change beforePath="$PROJECT_DIR$/mobilenet_v2/logs/events.out.tfevents.1644369407.USER-20210707NI.5140.0" beforeDir="false" />
8 <change beforePath="$PROJECT_DIR$/mobilenet_v2.onnx" beforeDir="false" afterPath="$PROJECT_DIR$/cls_abnormal_face_onnx_1.0.0_v0.0.1.onnx" afterDir="false" /> 16 <change beforePath="$PROJECT_DIR$/mobilenet_v2/weight/best.pth" beforeDir="false" afterPath="$PROJECT_DIR$/mobilenet_v2/weight/best.pth" afterDir="false" />
17 <change beforePath="$PROJECT_DIR$/mobilenet_v2/weight/last.pth" beforeDir="false" afterPath="$PROJECT_DIR$/mobilenet_v2/weight/last.pth" afterDir="false" />
18 <change beforePath="$PROJECT_DIR$/model/dataset/dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/model/dataset/dataset.py" afterDir="false" />
19 <change beforePath="$PROJECT_DIR$/model/utils/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/model/utils/utils.py" afterDir="false" />
9 </list> 20 </list>
10 <option name="SHOW_DIALOG" value="false" /> 21 <option name="SHOW_DIALOG" value="false" />
11 <option name="HIGHLIGHT_CONFLICTS" value="true" /> 22 <option name="HIGHLIGHT_CONFLICTS" value="true" />
...@@ -22,15 +33,10 @@ ...@@ -22,15 +33,10 @@
22 <component name="Git.Settings"> 33 <component name="Git.Settings">
23 <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" /> 34 <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
24 </component> 35 </component>
25 <component name="GitSEFilterConfiguration"> 36 <component name="ProjectId" id="24qs6nJrg6hqGGCf0I6cjksimJm" />
26 <file-type-list> 37 <component name="ProjectLevelVcsManager" settingsEditedManually="true">
27 <filtered-out-file-type name="LOCAL_BRANCH" /> 38 <ConfirmationsSetting value="1" id="Add" />
28 <filtered-out-file-type name="REMOTE_BRANCH" />
29 <filtered-out-file-type name="TAG" />
30 <filtered-out-file-type name="COMMIT_BY_MESSAGE" />
31 </file-type-list>
32 </component> 39 </component>
33 <component name="ProjectId" id="24r53vDpNxHFy2XFaKewyUFOSTL" />
34 <component name="ProjectViewState"> 40 <component name="ProjectViewState">
35 <option name="hideEmptyMiddlePackages" value="true" /> 41 <option name="hideEmptyMiddlePackages" value="true" />
36 <option name="showExcludedFiles" value="false" /> 42 <option name="showExcludedFiles" value="false" />
...@@ -39,12 +45,42 @@ ...@@ -39,12 +45,42 @@
39 <component name="PropertiesComponent"> 45 <component name="PropertiesComponent">
40 <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" /> 46 <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
41 <property name="RunOnceActivity.ShowReadmeOnStart" value="true" /> 47 <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
42 <property name="last_opened_file_path" value="$PROJECT_DIR$/../Pytorch-Image-Classifier-Collection" /> 48 <property name="last_opened_file_path" value="$PROJECT_DIR$" />
43 <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" /> 49 <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
44 </component> 50 </component>
51 <component name="RecentsManager">
52 <key name="CopyFile.RECENT_KEYS">
53 <recent name="F:\WorkSpace\Pytorch-Image-Classifier-Collection" />
54 </key>
55 <key name="MoveFile.RECENT_KEYS">
56 <recent name="F:\WorkSpace\Pytorch-Image-Classifier-Collection\test_image" />
57 <recent name="F:\WorkSpace\Pytorch-Image-Classifier-Collection" />
58 </key>
59 </component>
45 <component name="RunManager" selected="Python.infer_mnn"> 60 <component name="RunManager" selected="Python.infer_mnn">
46 <configuration name="cc" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true"> 61 <configuration name="dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
47 <module name="face_mask_classifier" /> 62 <module name="pytorch-image-classifier-collection" />
63 <option name="INTERPRETER_OPTIONS" value="" />
64 <option name="PARENT_ENVS" value="true" />
65 <envs>
66 <env name="PYTHONUNBUFFERED" value="1" />
67 </envs>
68 <option name="SDK_HOME" value="" />
69 <option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/model/dataset" />
70 <option name="IS_MODULE_SDK" value="true" />
71 <option name="ADD_CONTENT_ROOTS" value="true" />
72 <option name="ADD_SOURCE_ROOTS" value="true" />
73 <option name="SCRIPT_NAME" value="$PROJECT_DIR$/model/dataset/dataset.py" />
74 <option name="PARAMETERS" value="" />
75 <option name="SHOW_COMMAND_LINE" value="false" />
76 <option name="EMULATE_TERMINAL" value="false" />
77 <option name="MODULE_MODE" value="false" />
78 <option name="REDIRECT_INPUT" value="false" />
79 <option name="INPUT_FILE" value="" />
80 <method v="2" />
81 </configuration>
82 <configuration name="infer_mnn" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
83 <module name="pytorch-image-classifier-collection" />
48 <option name="INTERPRETER_OPTIONS" value="" /> 84 <option name="INTERPRETER_OPTIONS" value="" />
49 <option name="PARENT_ENVS" value="true" /> 85 <option name="PARENT_ENVS" value="true" />
50 <envs> 86 <envs>
...@@ -55,7 +91,7 @@ ...@@ -55,7 +91,7 @@
55 <option name="IS_MODULE_SDK" value="true" /> 91 <option name="IS_MODULE_SDK" value="true" />
56 <option name="ADD_CONTENT_ROOTS" value="true" /> 92 <option name="ADD_CONTENT_ROOTS" value="true" />
57 <option name="ADD_SOURCE_ROOTS" value="true" /> 93 <option name="ADD_SOURCE_ROOTS" value="true" />
58 <option name="SCRIPT_NAME" value="$PROJECT_DIR$/cc.py" /> 94 <option name="SCRIPT_NAME" value="$PROJECT_DIR$/infer_mnn.py" />
59 <option name="PARAMETERS" value="" /> 95 <option name="PARAMETERS" value="" />
60 <option name="SHOW_COMMAND_LINE" value="false" /> 96 <option name="SHOW_COMMAND_LINE" value="false" />
61 <option name="EMULATE_TERMINAL" value="false" /> 97 <option name="EMULATE_TERMINAL" value="false" />
...@@ -64,8 +100,8 @@ ...@@ -64,8 +100,8 @@
64 <option name="INPUT_FILE" value="" /> 100 <option name="INPUT_FILE" value="" />
65 <method v="2" /> 101 <method v="2" />
66 </configuration> 102 </configuration>
67 <configuration name="ddd" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true"> 103 <configuration name="mnn_infer" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
68 <module name="face_mask_classifier" /> 104 <module name="pytorch-image-classifier-collection" />
69 <option name="INTERPRETER_OPTIONS" value="" /> 105 <option name="INTERPRETER_OPTIONS" value="" />
70 <option name="PARENT_ENVS" value="true" /> 106 <option name="PARENT_ENVS" value="true" />
71 <envs> 107 <envs>
...@@ -76,7 +112,7 @@ ...@@ -76,7 +112,7 @@
76 <option name="IS_MODULE_SDK" value="true" /> 112 <option name="IS_MODULE_SDK" value="true" />
77 <option name="ADD_CONTENT_ROOTS" value="true" /> 113 <option name="ADD_CONTENT_ROOTS" value="true" />
78 <option name="ADD_SOURCE_ROOTS" value="true" /> 114 <option name="ADD_SOURCE_ROOTS" value="true" />
79 <option name="SCRIPT_NAME" value="$PROJECT_DIR$/ddd.py" /> 115 <option name="SCRIPT_NAME" value="$PROJECT_DIR$/mnn_infer.py" />
80 <option name="PARAMETERS" value="" /> 116 <option name="PARAMETERS" value="" />
81 <option name="SHOW_COMMAND_LINE" value="false" /> 117 <option name="SHOW_COMMAND_LINE" value="false" />
82 <option name="EMULATE_TERMINAL" value="false" /> 118 <option name="EMULATE_TERMINAL" value="false" />
...@@ -85,8 +121,8 @@ ...@@ -85,8 +121,8 @@
85 <option name="INPUT_FILE" value="" /> 121 <option name="INPUT_FILE" value="" />
86 <method v="2" /> 122 <method v="2" />
87 </configuration> 123 </configuration>
88 <configuration name="infer_mnn" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true"> 124 <configuration name="train" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
89 <module name="face_mask_classifier" /> 125 <module name="pytorch-image-classifier-collection" />
90 <option name="INTERPRETER_OPTIONS" value="" /> 126 <option name="INTERPRETER_OPTIONS" value="" />
91 <option name="PARENT_ENVS" value="true" /> 127 <option name="PARENT_ENVS" value="true" />
92 <envs> 128 <envs>
...@@ -97,36 +133,32 @@ ...@@ -97,36 +133,32 @@
97 <option name="IS_MODULE_SDK" value="true" /> 133 <option name="IS_MODULE_SDK" value="true" />
98 <option name="ADD_CONTENT_ROOTS" value="true" /> 134 <option name="ADD_CONTENT_ROOTS" value="true" />
99 <option name="ADD_SOURCE_ROOTS" value="true" /> 135 <option name="ADD_SOURCE_ROOTS" value="true" />
100 <option name="SCRIPT_NAME" value="$PROJECT_DIR$/infer_mnn.py" /> 136 <option name="SCRIPT_NAME" value="$PROJECT_DIR$/train.py" />
101 <option name="PARAMETERS" value="" /> 137 <option name="PARAMETERS" value="" />
102 <option name="SHOW_COMMAND_LINE" value="false" /> 138 <option name="SHOW_COMMAND_LINE" value="false" />
103 <option name="EMULATE_TERMINAL" value="true" /> 139 <option name="EMULATE_TERMINAL" value="false" />
104 <option name="MODULE_MODE" value="false" /> 140 <option name="MODULE_MODE" value="false" />
105 <option name="REDIRECT_INPUT" value="false" /> 141 <option name="REDIRECT_INPUT" value="false" />
106 <option name="INPUT_FILE" value="" /> 142 <option name="INPUT_FILE" value="" />
107 <method v="2" /> 143 <method v="2" />
108 </configuration> 144 </configuration>
109 <list>
110 <item itemvalue="Python.cc" />
111 <item itemvalue="Python.ddd" />
112 <item itemvalue="Python.infer_mnn" />
113 </list>
114 <recent_temporary> 145 <recent_temporary>
115 <list> 146 <list>
116 <item itemvalue="Python.infer_mnn" /> 147 <item itemvalue="Python.infer_mnn" />
117 <item itemvalue="Python.ddd" /> 148 <item itemvalue="Python.train" />
118 <item itemvalue="Python.cc" /> 149 <item itemvalue="Python.dataset" />
150 <item itemvalue="Python.mnn_infer" />
119 </list> 151 </list>
120 </recent_temporary> 152 </recent_temporary>
121 </component> 153 </component>
122 <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" /> 154 <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
123 <component name="TaskManager"> 155 <component name="TaskManager">
124 <task active="true" id="Default" summary="Default task"> 156 <task active="true" id="Default" summary="Default task">
125 <changelist id="8255f694-c607-4431-a530-39f2e0df4506" name="Default Changelist" comment="" /> 157 <changelist id="a64a78af-ad8c-4647-b359-632e9aeec7f0" name="Default Changelist" comment="" />
126 <created>1644375669136</created> 158 <created>1644369279294</created>
127 <option name="number" value="Default" /> 159 <option name="number" value="Default" />
128 <option name="presentableId" value="Default" /> 160 <option name="presentableId" value="Default" />
129 <updated>1644375669136</updated> 161 <updated>1644369279294</updated>
130 </task> 162 </task>
131 <servers /> 163 <servers />
132 </component> 164 </component>
......
1 import os
2
1 import cv2 3 import cv2
2 from PIL import Image, ImageFont, ImageDraw 4 from PIL import Image, ImageFont, ImageDraw
3 import MNN
4 import numpy as np 5 import numpy as np
5 6 from torchvision import transforms
6 7 import MNN
7 def keep_shape_resize(frame, size=128):
8 w, h = frame.size
9 temp = max(w, h)
10 mask = Image.new('RGB', (temp, temp), (0, 0, 0))
11 if w >= h:
12 position = (0, (w - h) // 2)
13 else:
14 position = ((h - w) // 2, 0)
15 mask.paste(frame, position)
16 mask = mask.resize((size, size))
17 return mask
18 8
19 9
20 def image_infer_mnn(mnn_model_path, image_path, class_list): 10 def image_infer_mnn(mnn_model_path, image_path, class_list):
21 image = Image.open(image_path) 11 image = cv2.imread(image_path)
22 input_image = keep_shape_resize(image) 12 input_image = cv2.resize(image,(128,128))
23 input_data = np.array(input_image).astype(np.float32).transpose((2, 0, 1)) / 255 13 input_data = input_image.astype(np.float32).transpose((2, 0, 1)) / 255
24 interpreter = MNN.Interpreter(mnn_model_path) 14 interpreter = MNN.Interpreter(mnn_model_path)
25 session = interpreter.createSession() 15 session = interpreter.createSession()
26 input_tensor = interpreter.getSessionInput(session) 16 input_tensor = interpreter.getSessionInput(session)
...@@ -30,9 +20,8 @@ def image_infer_mnn(mnn_model_path, image_path, class_list): ...@@ -30,9 +20,8 @@ def image_infer_mnn(mnn_model_path, image_path, class_list):
30 infer_result = interpreter.getSessionOutput(session) 20 infer_result = interpreter.getSessionOutput(session)
31 output_data = infer_result.getData() 21 output_data = infer_result.getData()
32 out = output_data.index(max(output_data)) 22 out = output_data.index(max(output_data))
33 draw = ImageDraw.Draw(image) 23
34 font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35) 24 cv2.putText(image,class_list[int(out)],(50, 50),cv2.FONT_HERSHEY_SIMPLEX,2,(0,0,255))
35 draw.text((10, 10), class_list[int(out)], font=font, fill='red')
36 return image 25 return image
37 26
38 27
...@@ -44,10 +33,8 @@ def video_infer_mnn(mnn_model_path, video_path): ...@@ -44,10 +33,8 @@ def video_infer_mnn(mnn_model_path, video_path):
44 while True: 33 while True:
45 _, frame = cap.read() 34 _, frame = cap.read()
46 if _: 35 if _:
47 image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 36 input_image = cv2.resize(frame, (128, 128))
48 image_data = Image.fromarray(image_data) 37 input_data = input_image.astype(np.float32).transpose((2, 0, 1)) / 255
49 image_data = keep_shape_resize(image_data, 128)
50 input_data = np.array(image_data).astype(np.float32).transpose((2, 0, 1)) / 255
51 tmp_input = MNN.Tensor((1, 3, 128, 128), MNN.Halide_Type_Float, input_data, MNN.Tensor_DimensionType_Caffe) 38 tmp_input = MNN.Tensor((1, 3, 128, 128), MNN.Halide_Type_Float, input_data, MNN.Tensor_DimensionType_Caffe)
52 input_tensor.copyFrom(tmp_input) 39 input_tensor.copyFrom(tmp_input)
53 interpreter.runSession(session) 40 interpreter.runSession(session)
...@@ -70,10 +57,8 @@ def camera_infer_mnn(mnn_model_path, camera_id): ...@@ -70,10 +57,8 @@ def camera_infer_mnn(mnn_model_path, camera_id):
70 while True: 57 while True:
71 _, frame = cap.read() 58 _, frame = cap.read()
72 if _: 59 if _:
73 image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 60 input_image = cv2.resize(frame, (128, 128))
74 image_data = Image.fromarray(image_data) 61 input_data = input_image.astype(np.float32).transpose((2, 0, 1)) / 255
75 image_data = keep_shape_resize(image_data, 128)
76 input_data = np.array(image_data).astype(np.float32).transpose((2, 0, 1)) / 255
77 tmp_input = MNN.Tensor((1, 3, 128, 128), MNN.Halide_Type_Float, input_data, MNN.Tensor_DimensionType_Caffe) 62 tmp_input = MNN.Tensor((1, 3, 128, 128), MNN.Halide_Type_Float, input_data, MNN.Tensor_DimensionType_Caffe)
78 input_tensor.copyFrom(tmp_input) 63 input_tensor.copyFrom(tmp_input)
79 interpreter.runSession(session) 64 interpreter.runSession(session)
...@@ -90,10 +75,13 @@ def camera_infer_mnn(mnn_model_path, camera_id): ...@@ -90,10 +75,13 @@ def camera_infer_mnn(mnn_model_path, camera_id):
90 75
91 if __name__ == '__main__': 76 if __name__ == '__main__':
92 class_list = ['mask', 'no_mask'] 77 class_list = ['mask', 'no_mask']
93 image_path = 'test_image/mask_2995.jpg' 78 image_path = 'test_image/mask_2997.jpg'
94 mnn_model_path = 'cls_abnormal_face_mnn_1.0.0_v0.0.1.mnn' 79 mnn_model_path = 'cls_abnormal_face_mnn_1.0.0_v0.0.1.mnn'
95 # image 80 # image
96 image = image_infer_mnn(mnn_model_path, image_path, class_list) 81 # for i in os.listdir('test_image'):
97 image.show() 82 # image=image_infer_mnn(mnn_model_path,os.path.join('test_image',i),class_list)
83 # cv2.imshow('image',image)
84 # cv2.waitKey(0)
85
98 # camera 86 # camera
99 # camera_infer_mnn(mnn_model_path,0) 87 camera_infer_mnn(mnn_model_path, 0)
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
7 ''' 7 '''
8 import os 8 import os
9 9
10 import cv2
10 from PIL import Image 11 from PIL import Image
11 from torch.utils.data import * 12 from torch.utils.data import *
12 from model.utils import utils 13 from model.utils import utils
...@@ -17,7 +18,6 @@ class ClassDataset(Dataset): ...@@ -17,7 +18,6 @@ class ClassDataset(Dataset):
17 def __init__(self, data_dir, config): 18 def __init__(self, data_dir, config):
18 self.config = config 19 self.config = config
19 self.transform = transforms.Compose([ 20 self.transform = transforms.Compose([
20 transforms.RandomRotation(60),
21 transforms.ToTensor() 21 transforms.ToTensor()
22 ]) 22 ])
23 self.dataset = [] 23 self.dataset = []
...@@ -35,6 +35,7 @@ class ClassDataset(Dataset): ...@@ -35,6 +35,7 @@ class ClassDataset(Dataset):
35 def __getitem__(self, index): 35 def __getitem__(self, index):
36 data = self.dataset[index] 36 data = self.dataset[index]
37 image_path, image_label = data 37 image_path, image_label = data
38 image = Image.open(image_path) 38 image = cv2.imread(image_path)
39 image = utils.keep_shape_resize(image, self.config['image_size']) 39 image = cv2.resize(image, (self.config['image_size'],self.config['image_size']))
40 return self.transform(image), image_label 40 return self.transform(image), image_label
41
......
...@@ -17,18 +17,5 @@ def load_config_util(config_path): ...@@ -17,18 +17,5 @@ def load_config_util(config_path):
17 return config_data 17 return config_data
18 18
19 19
20 def keep_shape_resize(frame, size=128):
21 w, h = frame.size
22 temp = max(w, h)
23 mask = Image.new('RGB', (temp, temp), (0, 0, 0))
24 if w >= h:
25 position = (0, (w - h) // 2)
26 else:
27 position = ((h - w) // 2, 0)
28 mask.paste(frame, position)
29 mask = mask.resize((size, size))
30 return mask
31
32
33 def label_one_hot(label): 20 def label_one_hot(label):
34 return one_hot(torch.tensor(label)) 21 return one_hot(torch.tensor(label))
......
1 '''
2 _*_coding:utf-8 _*_
3 @Time :2022/1/30 10:28
4 @Author : qiaofengsheng
5 @File :pytorch_onnx_infer.py
6 @Software :PyCharm
7 '''
8 import os
9 import sys
10
11 import numpy as np
12
13 sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
14 import cv2
15 import onnxruntime
16 import argparse
17 from PIL import Image, ImageDraw, ImageFont
18 from torchvision import transforms
19 import torch
20 from model.utils import utils
21
22 parse = argparse.ArgumentParser(description='onnx model infer!')
23 parse.add_argument('demo', type=str, help='推理类型支持:image/video/camera')
24 parse.add_argument('--config_path', type=str, help='配置文件存放地址')
25 parse.add_argument('--onnx_path', type=str, default=None, help='onnx包存放路径')
26 parse.add_argument('--image_path', type=str, default='', help='图片存放路径')
27 parse.add_argument('--video_path', type=str, default='', help='视频路径')
28 parse.add_argument('--camera_id', type=int, default=0, help='摄像头id')
29 parse.add_argument('--device', type=str, default='cpu', help='默认设备cpu (暂未完善GPU代码)')
30
31
32 def to_numpy(tensor):
33 return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
34
35
36 def onnx_infer_image(args, config):
37 ort_session = onnxruntime.InferenceSession(args.onnx_path)
38 transform = transforms.Compose([transforms.ToTensor()])
39 image = Image.open(args.image_path)
40 image_data = utils.keep_shape_resize(image, config['image_size'])
41 image_data = transform(image_data)
42 image_data = torch.unsqueeze(image_data, dim=0)
43 if args.device == 'cpu':
44 ort_input = {ort_session.get_inputs()[0].name: to_numpy(image_data)}
45 ort_out = ort_session.run(None, ort_input)
46 out = np.argmax(ort_out[0], axis=1)
47 result = config['class_names'][int(out)]
48 draw = ImageDraw.Draw(image)
49 font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)
50 draw.text((10, 10), result, font=font, fill='red')
51 image.show()
52 elif args.device == 'cuda':
53 pass
54 else:
55 exit(0)
56
57
58 def onnx_infer_video(args, config):
59 ort_session = onnxruntime.InferenceSession(args.onnx_path)
60 transform = transforms.Compose([transforms.ToTensor()])
61 cap = cv2.VideoCapture(args.video_path)
62 while True:
63 _, frame = cap.read()
64 if _:
65 image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
66 image_data = Image.fromarray(image_data)
67 image_data = utils.keep_shape_resize(image_data, config['image_size'])
68 image_data = transform(image_data)
69 image_data = torch.unsqueeze(image_data, dim=0)
70 if args.device == 'cpu':
71 ort_input = {ort_session.get_inputs()[0].name: to_numpy(image_data)}
72 ort_out = ort_session.run(None, ort_input)
73 out = np.argmax(ort_out[0], axis=1)
74 result = config['class_names'][int(out)]
75 cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
76 cv2.imshow('frame', frame)
77 if cv2.waitKey(24) & 0XFF == ord('q'):
78 break
79 elif args.device == 'cuda':
80 pass
81 else:
82 exit(0)
83 else:
84 exit(0)
85
86
87 def onnx_infer_camera(args, config):
88 ort_session = onnxruntime.InferenceSession(args.onnx_path)
89 transform = transforms.Compose([transforms.ToTensor()])
90 cap = cv2.VideoCapture(args.camera_id)
91 while True:
92 _, frame = cap.read()
93 if _:
94 image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
95 image_data = Image.fromarray(image_data)
96 image_data = utils.keep_shape_resize(image_data, config['image_size'])
97 image_data = transform(image_data)
98 image_data = torch.unsqueeze(image_data, dim=0)
99 if args.device == 'cpu':
100 ort_input = {ort_session.get_inputs()[0].name: to_numpy(image_data)}
101 ort_out = ort_session.run(None, ort_input)
102 out = np.argmax(ort_out[0], axis=1)
103 result = config['class_names'][int(out)]
104 cv2.putText(frame, result, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
105 cv2.imshow('frame', frame)
106 if cv2.waitKey(24) & 0XFF == ord('q'):
107 break
108 elif args.device == 'cuda':
109 pass
110 else:
111 exit(0)
112 else:
113 exit(0)
114
115
116 if __name__ == '__main__':
117 args = parse.parse_args()
118 config = utils.load_config_util(args.config_path)
119 if args.demo == 'image':
120 onnx_infer_image(args, config)
121 elif args.demo == 'video':
122 onnx_infer_video(args, config)
123 elif args.demo == 'camera':
124 onnx_infer_camera(args, config)
125 else:
126 exit(0)
1 '''
2 _*_coding:utf-8 _*_
3 @Time :2022/1/29 19:00
4 @Author : qiaofengsheng
5 @File :pytorch_to_onnx.py
6 @Software :PyCharm
7 '''
8 import os
9 import sys
10
11 sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
12 import numpy as np
13 import torch.onnx
14 import torch.cuda
15 import onnx, onnxruntime
16 from model.net.net import *
17 from model.utils import utils
18 import argparse
19
20 parse = argparse.ArgumentParser(description='pack onnx model')
21 parse.add_argument('--config_path', type=str, default='', help='配置文件存放地址')
22 parse.add_argument('--weights_path', type=str, default='', help='模型权重文件地址')
23
24
25 def pack_onnx(model_path, config):
26 model = ClassifierNet(config['net_type'], len(config['class_names']),
27 False)
28 map_location = lambda storage, loc: storage
29 if torch.cuda.is_available():
30 map_location = None
31 model.load_state_dict(torch.load(model_path, map_location=map_location))
32 model.eval()
33 batch_size = 1
34 input = torch.randn(batch_size, 3, 128, 128, requires_grad=True)
35 output = model(input)
36 torch.onnx.export(model,
37 input,
38 config['net_type'] + '.onnx',
39 verbose=True,
40 # export_params=True,
41 # opset_version=11,
42 # do_constant_folding=True,
43 input_names=['input'],
44 output_names=['output'],
45 # dynamic_axes={
46 # 'input': {0: 'batch_size'},
47 # 'output': {0: 'batch_size'}
48 # }
49 )
50 print('onnx打包成功!')
51 output = model(input)
52 onnx_model = onnx.load(config['net_type'] + '.onnx')
53 onnx.checker.check_model(onnx_model)
54 ort_session = onnxruntime.InferenceSession(config['net_type'] + '.onnx')
55
56 def to_numpy(tensor):
57 return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy
58
59 ort_input = {ort_session.get_inputs()[0].name: to_numpy(input)}
60 ort_output = ort_session.run(None, ort_input)
61
62 np.testing.assert_allclose(to_numpy(output), ort_output[0], rtol=1e-03, atol=1e-05)
63 print("Exported model has been tested with ONNXRuntime, and the result looks good!")
64
65
66 if __name__ == '__main__':
67 args = parse.parse_args()
68 config = utils.load_config_util(args.config_path)
69 pack_onnx(args.weights_path, config)
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!