-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathpytorch_example.py
37 lines (29 loc) · 1.17 KB
/
pytorch_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torchvision
import onnx_tool
import torch
# Scratch ONNX file that alexnet() and convnext_large() export to and then profile.
tmpfile = "tmp.onnx"
def alexnet(*, path=None, opset_version=12):
    """Export a torchvision AlexNet (no pretrained weights) to ONNX and profile it.

    Args:
        path: Destination of the exported ONNX file. ``None`` (the default)
            uses the module-level ``tmpfile`` scratch path, read at call time.
        opset_version: ONNX opset passed to ``torch.onnx.export``. The
            original script notes that opsets 12 and 7 were tested.
    """
    if path is None:
        path = tmpfile
    model = torchvision.models.alexnet()
    model.eval()
    x = torch.rand(1, 3, 224, 224)  # dummy input of shape (1, 3, 224, 224) used for tracing
    with torch.no_grad():
        # No dynamic_axes are requested: a fully static graph keeps profiling simple.
        # The exported model is written to `path`; export's return value is not needed.
        torch.onnx.export(model, x, path, opset_version=opset_version)
    onnx_tool.model_profile(path, verbose=False)
def convnext_large(*, path=None, opset_version=12):
    """Trace and export a torchvision ConvNeXt-Large (no pretrained weights), then profile it.

    Side effect: besides the ONNX export, a TorchScript trace of the model is
    saved to ``'convx.pth'`` (relative to the current working directory).

    Args:
        path: Destination of the exported ONNX file. ``None`` (the default)
            uses the module-level ``tmpfile`` scratch path, read at call time.
        opset_version: ONNX opset passed to ``torch.onnx.export``. The
            original script notes that opsets 12 and 7 were tested.
    """
    if path is None:
        path = tmpfile
    model = torchvision.models.convnext_large()
    model.eval()
    x = torch.rand(1, 3, 224, 224)  # dummy input of shape (1, 3, 224, 224) used for tracing
    traced = torch.jit.trace(model, x)
    torch.jit.save(traced, 'convx.pth')
    with torch.no_grad():
        # No dynamic_axes are requested: a fully static graph keeps profiling simple.
        # The exported model is written to `path`; export's return value is not needed.
        torch.onnx.export(model, x, path, opset_version=opset_version)
    onnx_tool.model_profile(path, verbose=False)
def ssd300_vgg16(*, path="ssd300_vgg16.onnx", opset_version=11):
    """Export torchvision's SSD300-VGG16 detector, with its DEFAULT weights, to ONNX.

    Unlike the other examples, the exported file is not profiled here.
    Loading ``SSD300_VGG16_Weights.DEFAULT`` may make torchvision download
    the weights on first use.

    Args:
        path: Destination of the exported ONNX file.
        opset_version: ONNX opset passed to ``torch.onnx.export``.
    """
    dummy_input = torch.randn(1, 3, 300, 300)  # dummy input of shape (1, 3, 300, 300)
    model = torchvision.models.detection.ssd300_vgg16(
        weights=torchvision.models.detection.ssd.SSD300_VGG16_Weights.DEFAULT
    )
    model.eval()
    # Disable autograd during export, consistent with the other exporters in this file.
    with torch.no_grad():
        torch.onnx.export(model, dummy_input, path, opset_version=opset_version)
if __name__ == "__main__":
    # Only the SSD export runs by default. Call alexnet() or convnext_large()
    # here instead to export and profile those models.
    ssd300_vgg16()