pytorch_to_onnx.py
2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
'''
_*_coding:utf-8 _*_
@Time :2022/1/29 19:00
@Author : qiaofengsheng
@File :pytorch_to_onnx.py
@Software :PyCharm
'''
import os
import sys
sys.path.append(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
import numpy as np
import torch.onnx
import torch.cuda
import onnx, onnxruntime
from model.net.net import *
from model.utils import utils
import argparse
parse = argparse.ArgumentParser(description='pack onnx model')
parse.add_argument('--config_path', type=str, default='', help='配置文件存放地址')
parse.add_argument('--weights_path', type=str, default='', help='模型权重文件地址')
def pack_onnx(model_path, config):
model = ClassifierNet(config['net_type'], len(config['class_names']),
False)
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
map_location = None
model.load_state_dict(torch.load(model_path, map_location=map_location))
model.eval()
batch_size = 1
input = torch.randn(batch_size, 3, 128, 128, requires_grad=True)
output = model(input)
torch.onnx.export(model,
input,
config['net_type'] + '.onnx',
verbose=True,
# export_params=True,
# opset_version=11,
# do_constant_folding=True,
input_names=['input'],
output_names=['output'],
# dynamic_axes={
# 'input': {0: 'batch_size'},
# 'output': {0: 'batch_size'}
# }
)
print('onnx打包成功!')
output = model(input)
onnx_model = onnx.load(config['net_type'] + '.onnx')
onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession(config['net_type'] + '.onnx')
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy
ort_input = {ort_session.get_inputs()[0].name: to_numpy(input)}
ort_output = ort_session.run(None, ort_input)
np.testing.assert_allclose(to_numpy(output), ort_output[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
if __name__ == '__main__':
args = parse.parse_args()
config = utils.load_config_util(args.config_path)
pack_onnx(args.weights_path, config)