4d8ada78 by 乔峰昇

Modify the zoom

1 parent b8d9812e
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.8 (gan)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="25">
<item index="0" class="java.lang.String" itemvalue="tqdm" />
<item index="1" class="java.lang.String" itemvalue="easydict" />
<item index="2" class="java.lang.String" itemvalue="scikit_image" />
<item index="3" class="java.lang.String" itemvalue="matplotlib" />
<item index="4" class="java.lang.String" itemvalue="tensorboardX" />
<item index="5" class="java.lang.String" itemvalue="torch" />
<item index="6" class="java.lang.String" itemvalue="numpy" />
<item index="7" class="java.lang.String" itemvalue="pycocotools" />
<item index="8" class="java.lang.String" itemvalue="skimage" />
<item index="9" class="java.lang.String" itemvalue="Pillow" />
<item index="10" class="java.lang.String" itemvalue="scipy" />
<item index="11" class="java.lang.String" itemvalue="torchvision" />
<item index="12" class="java.lang.String" itemvalue="opencv_python" />
<item index="13" class="java.lang.String" itemvalue="onnxruntime" />
<item index="14" class="java.lang.String" itemvalue="onnx-simplifier" />
<item index="15" class="java.lang.String" itemvalue="onnx" />
<item index="16" class="java.lang.String" itemvalue="opencv-contrib-python" />
<item index="17" class="java.lang.String" itemvalue="numba" />
<item index="18" class="java.lang.String" itemvalue="opencv-python" />
<item index="19" class="java.lang.String" itemvalue="librosa" />
<item index="20" class="java.lang.String" itemvalue="tensorboard" />
<item index="21" class="java.lang.String" itemvalue="dill" />
<item index="22" class="java.lang.String" itemvalue="pandas" />
<item index="23" class="java.lang.String" itemvalue="scikit_learn" />
<item index="24" class="java.lang.String" itemvalue="pytorch-gradual-warmup-lr" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (gan)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/face_mask_classifier.iml" filepath="$PROJECT_DIR$/.idea/face_mask_classifier.iml" />
</modules>
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="8255f694-c607-4431-a530-39f2e0df4506" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<list default="true" id="a64a78af-ad8c-4647-b359-632e9aeec7f0" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/face_mask_classifier.iml" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/Project_Default.xml" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/inspectionProfiles/profiles_settings.xml" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/modules.xml" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/vcs.xml" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/cls_abnormal_face_mnn_1.0.0_v0.0.1.mnn" beforeDir="false" afterPath="$PROJECT_DIR$/cls_abnormal_face_mnn_1.0.0_v0.0.1.mnn" afterDir="false" />
<change beforePath="$PROJECT_DIR$/cls_abnormal_face_onnx_1.0.0_v0.0.1.onnx" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/infer_mnn.py" beforeDir="false" afterPath="$PROJECT_DIR$/infer_mnn.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/mobilenet_v2.mnn" beforeDir="false" afterPath="$PROJECT_DIR$/cls_abnormal_face_mnn_1.0.0_v0.0.1.mnn" afterDir="false" />
<change beforePath="$PROJECT_DIR$/mobilenet_v2.onnx" beforeDir="false" afterPath="$PROJECT_DIR$/cls_abnormal_face_onnx_1.0.0_v0.0.1.onnx" afterDir="false" />
<change beforePath="$PROJECT_DIR$/mobilenet_v2/logs/events.out.tfevents.1644369407.USER-20210707NI.5140.0" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/mobilenet_v2/weight/best.pth" beforeDir="false" afterPath="$PROJECT_DIR$/mobilenet_v2/weight/best.pth" afterDir="false" />
<change beforePath="$PROJECT_DIR$/mobilenet_v2/weight/last.pth" beforeDir="false" afterPath="$PROJECT_DIR$/mobilenet_v2/weight/last.pth" afterDir="false" />
<change beforePath="$PROJECT_DIR$/model/dataset/dataset.py" beforeDir="false" afterPath="$PROJECT_DIR$/model/dataset/dataset.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/model/utils/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/model/utils/utils.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
......@@ -22,15 +33,10 @@
<component name="Git.Settings">
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
</component>
<component name="GitSEFilterConfiguration">
<file-type-list>
<filtered-out-file-type name="LOCAL_BRANCH" />
<filtered-out-file-type name="REMOTE_BRANCH" />
<filtered-out-file-type name="TAG" />
<filtered-out-file-type name="COMMIT_BY_MESSAGE" />
</file-type-list>
<component name="ProjectId" id="24qs6nJrg6hqGGCf0I6cjksimJm" />
<component name="ProjectLevelVcsManager" settingsEditedManually="true">
<ConfirmationsSetting value="1" id="Add" />
</component>
<component name="ProjectId" id="24r53vDpNxHFy2XFaKewyUFOSTL" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showExcludedFiles" value="false" />
......@@ -39,12 +45,42 @@
<component name="PropertiesComponent">
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="last_opened_file_path" value="$PROJECT_DIR$/../Pytorch-Image-Classifier-Collection" />
<property name="last_opened_file_path" value="$PROJECT_DIR$" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component>
<component name="RecentsManager">
<key name="CopyFile.RECENT_KEYS">
<recent name="F:\WorkSpace\Pytorch-Image-Classifier-Collection" />
</key>
<key name="MoveFile.RECENT_KEYS">
<recent name="F:\WorkSpace\Pytorch-Image-Classifier-Collection\test_image" />
<recent name="F:\WorkSpace\Pytorch-Image-Classifier-Collection" />
</key>
</component>
<component name="RunManager" selected="Python.infer_mnn">
<configuration name="cc" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="face_mask_classifier" />
<configuration name="dataset" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="pytorch-image-classifier-collection" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/model/dataset" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/model/dataset/dataset.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="infer_mnn" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="pytorch-image-classifier-collection" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
......@@ -55,7 +91,7 @@
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/cc.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/infer_mnn.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
......@@ -64,8 +100,8 @@
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="ddd" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="face_mask_classifier" />
<configuration name="mnn_infer" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="pytorch-image-classifier-collection" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
......@@ -76,7 +112,7 @@
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/ddd.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/mnn_infer.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
......@@ -85,8 +121,8 @@
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="infer_mnn" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="face_mask_classifier" />
<configuration name="train" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="pytorch-image-classifier-collection" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
......@@ -97,36 +133,32 @@
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/infer_mnn.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/train.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="true" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<list>
<item itemvalue="Python.cc" />
<item itemvalue="Python.ddd" />
<item itemvalue="Python.infer_mnn" />
</list>
<recent_temporary>
<list>
<item itemvalue="Python.infer_mnn" />
<item itemvalue="Python.ddd" />
<item itemvalue="Python.cc" />
<item itemvalue="Python.train" />
<item itemvalue="Python.dataset" />
<item itemvalue="Python.mnn_infer" />
</list>
</recent_temporary>
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="8255f694-c607-4431-a530-39f2e0df4506" name="Default Changelist" comment="" />
<created>1644375669136</created>
<changelist id="a64a78af-ad8c-4647-b359-632e9aeec7f0" name="Default Changelist" comment="" />
<created>1644369279294</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1644375669136</updated>
<updated>1644369279294</updated>
</task>
<servers />
</component>
......
import os
import cv2
from PIL import Image, ImageFont, ImageDraw
import MNN
import numpy as np
def keep_shape_resize(frame, size=128):
    """Letterbox a PIL image onto a black square canvas, then resize.

    Pads the shorter dimension so the content keeps its aspect ratio and
    is centered, and returns a new (size, size) RGB image.

    :param frame: PIL.Image to resize (must expose .size as (w, h))
    :param size: side length of the square output image
    :return: new PIL.Image of shape (size, size)
    """
    width, height = frame.size
    side = max(width, height)
    canvas = Image.new('RGB', (side, side), (0, 0, 0))
    # Center the original image on the square canvas.
    if width >= height:
        offset = (0, (width - height) // 2)
    else:
        offset = ((height - width) // 2, 0)
    canvas.paste(frame, offset)
    return canvas.resize((size, size))
from torchvision import transforms
import MNN
def image_infer_mnn(mnn_model_path, image_path, class_list):
image = Image.open(image_path)
input_image = keep_shape_resize(image)
input_data = np.array(input_image).astype(np.float32).transpose((2, 0, 1)) / 255
image = cv2.imread(image_path)
input_image = cv2.resize(image,(128,128))
input_data = input_image.astype(np.float32).transpose((2, 0, 1)) / 255
interpreter = MNN.Interpreter(mnn_model_path)
session = interpreter.createSession()
input_tensor = interpreter.getSessionInput(session)
......@@ -30,9 +20,8 @@ def image_infer_mnn(mnn_model_path, image_path, class_list):
infer_result = interpreter.getSessionOutput(session)
output_data = infer_result.getData()
out = output_data.index(max(output_data))
draw = ImageDraw.Draw(image)
font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)
draw.text((10, 10), class_list[int(out)], font=font, fill='red')
cv2.putText(image,class_list[int(out)],(50, 50),cv2.FONT_HERSHEY_SIMPLEX,2,(0,0,255))
return image
......@@ -44,10 +33,8 @@ def video_infer_mnn(mnn_model_path, video_path):
while True:
_, frame = cap.read()
if _:
image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
image_data = Image.fromarray(image_data)
image_data = keep_shape_resize(image_data, 128)
input_data = np.array(image_data).astype(np.float32).transpose((2, 0, 1)) / 255
input_image = cv2.resize(frame, (128, 128))
input_data = input_image.astype(np.float32).transpose((2, 0, 1)) / 255
tmp_input = MNN.Tensor((1, 3, 128, 128), MNN.Halide_Type_Float, input_data, MNN.Tensor_DimensionType_Caffe)
input_tensor.copyFrom(tmp_input)
interpreter.runSession(session)
......@@ -70,10 +57,8 @@ def camera_infer_mnn(mnn_model_path, camera_id):
while True:
_, frame = cap.read()
if _:
image_data = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
image_data = Image.fromarray(image_data)
image_data = keep_shape_resize(image_data, 128)
input_data = np.array(image_data).astype(np.float32).transpose((2, 0, 1)) / 255
input_image = cv2.resize(frame, (128, 128))
input_data = input_image.astype(np.float32).transpose((2, 0, 1)) / 255
tmp_input = MNN.Tensor((1, 3, 128, 128), MNN.Halide_Type_Float, input_data, MNN.Tensor_DimensionType_Caffe)
input_tensor.copyFrom(tmp_input)
interpreter.runSession(session)
......@@ -90,10 +75,13 @@ def camera_infer_mnn(mnn_model_path, camera_id):
if __name__ == '__main__':
class_list = ['mask', 'no_mask']
image_path = 'test_image/mask_2995.jpg'
image_path = 'test_image/mask_2997.jpg'
mnn_model_path = 'cls_abnormal_face_mnn_1.0.0_v0.0.1.mnn'
# image
image = image_infer_mnn(mnn_model_path, image_path, class_list)
image.show()
# for i in os.listdir('test_image'):
# image=image_infer_mnn(mnn_model_path,os.path.join('test_image',i),class_list)
# cv2.imshow('image',image)
# cv2.waitKey(0)
# camera
# camera_infer_mnn(mnn_model_path,0)
camera_infer_mnn(mnn_model_path, 0)
......
......@@ -7,6 +7,7 @@
'''
import os
import cv2
from PIL import Image
from torch.utils.data import *
from model.utils import utils
......@@ -17,7 +18,6 @@ class ClassDataset(Dataset):
def __init__(self, data_dir, config):
self.config = config
self.transform = transforms.Compose([
transforms.RandomRotation(60),
transforms.ToTensor()
])
self.dataset = []
......@@ -35,6 +35,7 @@ class ClassDataset(Dataset):
def __getitem__(self, index):
data = self.dataset[index]
image_path, image_label = data
image = Image.open(image_path)
image = utils.keep_shape_resize(image, self.config['image_size'])
image = cv2.imread(image_path)
image = cv2.resize(image, (self.config['image_size'],self.config['image_size']))
return self.transform(image), image_label
......
......@@ -17,18 +17,5 @@ def load_config_util(config_path):
return config_data
def keep_shape_resize(frame, size=128):
    """Pad a PIL image to a black square (keeping aspect ratio, centered),
    then resize it to (size, size).

    :param frame: PIL.Image input (``Image`` is imported elsewhere in this
        module — outside this view)
    :param size: side length of the square output
    :return: resized PIL.Image
    """
    w, h = frame.size
    temp = max(w, h)
    # Square black canvas big enough for the longer side.
    mask = Image.new('RGB', (temp, temp), (0, 0, 0))
    # Offset that centers the image along the shorter axis.
    if w >= h:
        position = (0, (w - h) // 2)
    else:
        position = ((h - w) // 2, 0)
    mask.paste(frame, position)
    mask = mask.resize((size, size))
    return mask
def label_one_hot(label):
    # One-hot encode an integer class label as a tensor.
    # NOTE(review): `one_hot` is presumably torch.nn.functional.one_hot
    # imported elsewhere in this module — the import is outside this view;
    # confirm before relying on its output dtype/shape.
    return one_hot(torch.tensor(label))
......
'''
_*_coding:utf-8 _*_
@Time :2022/1/30 10:28
@Author : qiaofengsheng
@File :pytorch_onnx_infer.py
@Software :PyCharm
'''
import os
import sys
import numpy as np
sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
import cv2
import onnxruntime
import argparse
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
import torch
from model.utils import utils
# Command-line interface for the ONNX inference entry points below.
# (Help strings are user-facing runtime text and are left untranslated.)
parse = argparse.ArgumentParser(description='onnx model infer!')
# Positional mode selector: 'image', 'video' or 'camera'.
parse.add_argument('demo', type=str, help='推理类型支持:image/video/camera')
parse.add_argument('--config_path', type=str, help='配置文件存放地址')
parse.add_argument('--onnx_path', type=str, default=None, help='onnx包存放路径')
parse.add_argument('--image_path', type=str, default='', help='图片存放路径')
parse.add_argument('--video_path', type=str, default='', help='视频路径')
parse.add_argument('--camera_id', type=int, default=0, help='摄像头id')
# Only 'cpu' is implemented; 'cuda' branches are currently no-ops.
parse.add_argument('--device', type=str, default='cpu', help='默认设备cpu (暂未完善GPU代码)')
def to_numpy(tensor):
    """Convert a torch tensor to a numpy array on the host.

    Detaches from the autograd graph first when the tensor tracks
    gradients, since ``.numpy()`` refuses to run on such tensors.
    """
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    return tensor.cpu().numpy()
def onnx_infer_image(args, config):
    """Classify a single image with an ONNX model and display the result.

    Loads ``args.image_path``, letterboxes it to ``config['image_size']``,
    runs the ONNX session on CPU, draws the predicted class name onto the
    image and shows it. The 'cuda' branch is a no-op placeholder; any other
    device exits the process.
    """
    ort_session = onnxruntime.InferenceSession(args.onnx_path)
    transform = transforms.Compose([transforms.ToTensor()])
    image = Image.open(args.image_path)
    # Aspect-ratio-preserving pad+resize to the model's input size.
    image_data = utils.keep_shape_resize(image, config['image_size'])
    image_data = transform(image_data)
    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    image_data = torch.unsqueeze(image_data, dim=0)
    if args.device == 'cpu':
        ort_input = {ort_session.get_inputs()[0].name: to_numpy(image_data)}
        ort_out = ort_session.run(None, ort_input)
        # argmax over the class axis of the first (only) output.
        out = np.argmax(ort_out[0], axis=1)
        result = config['class_names'][int(out)]
        draw = ImageDraw.Draw(image)
        # NOTE(review): hard-coded Windows font path — breaks on other
        # platforms; consider making it configurable.
        font = ImageFont.truetype(r"C:\Windows\Fonts\BRITANIC.TTF", 35)
        draw.text((10, 10), result, font=font, fill='red')
        image.show()
    elif args.device == 'cuda':
        pass
    else:
        exit(0)
def onnx_infer_video(args, config):
    """Classify every frame of a video file with an ONNX model.

    Reads frames from ``args.video_path``, letterboxes each one to
    ``config['image_size']``, runs CPU inference and overlays the predicted
    class name. Pressing 'q' stops playback; end of stream or an
    unsupported device exits the process.
    """
    session = onnxruntime.InferenceSession(args.onnx_path)
    to_tensor = transforms.Compose([transforms.ToTensor()])
    cap = cv2.VideoCapture(args.video_path)
    while True:
        ok, frame = cap.read()
        if not ok:
            # Stream exhausted (or unreadable) — mirror original behavior.
            exit(0)
        # OpenCV delivers BGR; convert to RGB for the PIL-based pipeline.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_frame = Image.fromarray(rgb)
        pil_frame = utils.keep_shape_resize(pil_frame, config['image_size'])
        batch = torch.unsqueeze(to_tensor(pil_frame), dim=0)
        if args.device == 'cpu':
            feed = {session.get_inputs()[0].name: to_numpy(batch)}
            scores = session.run(None, feed)
            predicted = np.argmax(scores[0], axis=1)
            label = config['class_names'][int(predicted)]
            cv2.putText(frame, label, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
            cv2.imshow('frame', frame)
            if cv2.waitKey(24) & 0XFF == ord('q'):
                break
        elif args.device == 'cuda':
            pass
        else:
            exit(0)
def onnx_infer_camera(args, config):
    """Classify a live camera feed with an ONNX model.

    Identical pipeline to the video variant but reads from the device
    index ``args.camera_id``. Pressing 'q' stops the loop; a failed grab
    or an unsupported device exits the process.
    """
    session = onnxruntime.InferenceSession(args.onnx_path)
    to_tensor = transforms.Compose([transforms.ToTensor()])
    cap = cv2.VideoCapture(args.camera_id)
    while True:
        grabbed, frame = cap.read()
        if not grabbed:
            exit(0)
        # BGR -> RGB, then letterbox to the model's square input size.
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        resized = utils.keep_shape_resize(pil_image, config['image_size'])
        batch = torch.unsqueeze(to_tensor(resized), dim=0)
        if args.device == 'cuda':
            pass
        elif args.device == 'cpu':
            feed = {session.get_inputs()[0].name: to_numpy(batch)}
            outputs = session.run(None, feed)
            label = config['class_names'][int(np.argmax(outputs[0], axis=1))]
            cv2.putText(frame, label, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), thickness=2)
            cv2.imshow('frame', frame)
            if cv2.waitKey(24) & 0XFF == ord('q'):
                break
        else:
            exit(0)
if __name__ == '__main__':
    # Dispatch the requested inference mode; unknown modes exit quietly,
    # matching the original if/elif/else chain.
    args = parse.parse_args()
    config = utils.load_config_util(args.config_path)
    handlers = {
        'image': onnx_infer_image,
        'video': onnx_infer_video,
        'camera': onnx_infer_camera,
    }
    handler = handlers.get(args.demo)
    if handler is None:
        exit(0)
    handler(args, config)
'''
_*_coding:utf-8 _*_
@Time :2022/1/29 19:00
@Author : qiaofengsheng
@File :pytorch_to_onnx.py
@Software :PyCharm
'''
import os
import sys
sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
import numpy as np
import torch.onnx
import torch.cuda
import onnx, onnxruntime
from model.net.net import *
from model.utils import utils
import argparse
# Command-line interface for the ONNX export script.
# (Help strings are user-facing runtime text and are left untranslated.)
parse = argparse.ArgumentParser(description='pack onnx model')
parse.add_argument('--config_path', type=str, default='', help='配置文件存放地址')
parse.add_argument('--weights_path', type=str, default='', help='模型权重文件地址')
def pack_onnx(model_path, config):
    """Export a trained classifier checkpoint to ONNX and sanity-check it.

    Loads weights from ``model_path`` into a ``ClassifierNet`` built from
    ``config``, exports it as ``<net_type>.onnx`` with fixed input shape
    (1, 3, 128, 128), validates the file with ``onnx.checker`` and compares
    ONNXRuntime output against the PyTorch forward pass.

    :param model_path: path to the .pth state-dict checkpoint
    :param config: dict with at least 'net_type' and 'class_names'
    :raises AssertionError: if ONNXRuntime output diverges from PyTorch
    """
    model = ClassifierNet(config['net_type'], len(config['class_names']),
                          False)
    # Map weights onto CPU unless CUDA is present.
    map_location = lambda storage, loc: storage
    if torch.cuda.is_available():
        map_location = None
    model.load_state_dict(torch.load(model_path, map_location=map_location))
    model.eval()
    batch_size = 1
    input = torch.randn(batch_size, 3, 128, 128, requires_grad=True)
    # Single forward pass; reused below for the numerical comparison
    # (the original ran it a second time after export for no reason).
    output = model(input)
    torch.onnx.export(model,
                      input,
                      config['net_type'] + '.onnx',
                      verbose=True,
                      # export_params=True,
                      # opset_version=11,
                      # do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'],
                      # dynamic_axes={
                      #     'input': {0: 'batch_size'},
                      #     'output': {0: 'batch_size'}
                      # }
                      )
    print('onnx打包成功!')
    # Structural validation of the exported graph.
    onnx_model = onnx.load(config['net_type'] + '.onnx')
    onnx.checker.check_model(onnx_model)
    ort_session = onnxruntime.InferenceSession(config['net_type'] + '.onnx')

    def to_numpy(tensor):
        # BUG FIX: the original returned the bound method `tensor.cpu().numpy`
        # (missing call parentheses), so the ONNXRuntime feed and the
        # assert_allclose comparison received a method object, not an array.
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    ort_input = {ort_session.get_inputs()[0].name: to_numpy(input)}
    ort_output = ort_session.run(None, ort_input)
    np.testing.assert_allclose(to_numpy(output), ort_output[0], rtol=1e-03, atol=1e-05)
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")
if __name__ == '__main__':
    # Parse CLI arguments, load the run configuration, then export the
    # checkpoint at --weights_path to ONNX and verify it.
    args = parse.parse_args()
    config = utils.load_config_util(args.config_path)
    pack_onnx(args.weights_path, config)
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!