-
Notifications
You must be signed in to change notification settings - Fork 0
/
torch2onnx.py
80 lines (65 loc) · 2.06 KB
/
torch2onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import torch
import torch.onnx
from basicsr.archs.rrdbnet_arch import RRDBNet
def parse_args(argv=None):
    """Parse the command-line options of the PyTorch-to-ONNX converter.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``input``, ``output``,
        ``params``, ``fp16`` and ``opset``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        type=str,
        required=True,
        help='input pytorch model path',
    )
    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help='output onnx model path',
    )
    # store_true so the flag matches its help text: passing --params makes
    # args.params True (use 'params'); omitting it leaves it False
    # (use 'params_ema'). The previous store_false inverted that meaning.
    parser.add_argument(
        '--params',
        action='store_true',
        help='use params instead of params_ema',
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help='use float16 precision',
    )
    parser.add_argument(
        '--opset',
        type=int,
        default=17,
        help='onnx opset version',
    )
    return parser.parse_args(argv)
def main(args):
    """Load an RRDBNet checkpoint and export it to an ONNX file.

    Args:
        args: Namespace providing ``input`` (checkpoint path), ``output``
            (ONNX path), ``params`` (truthy selects the ``'params'`` entry of
            the checkpoint, otherwise ``'params_ema'``), ``fp16`` (export in
            half precision) and ``opset`` (ONNX opset version).

    Raises:
        KeyError: If the selected state-dict entry is not in the checkpoint.
    """
    # BasicSR's RRDBNet keyword is ``num_feat``; ``num_feats`` raises TypeError.
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=4,
    )

    keyname = 'params' if args.params else 'params_ema'

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    # map_location='cpu' avoids failing when the checkpoint was saved on a
    # device that does not exist on this machine; the model is moved to CUDA
    # below anyway.
    checkpoint = torch.load(args.input, map_location='cpu')
    if keyname not in checkpoint:
        raise KeyError(
            f'checkpoint {args.input!r} has no {keyname!r} entry; '
            f'available keys: {list(checkpoint)}'
        )
    model.load_state_dict(checkpoint[keyname])

    # Inference mode for tracing, then cast (optionally) and move to the GPU.
    model.eval()
    if args.fp16:
        model.half()
    model.cuda()

    # Dummy input used to trace the network; dtype must match the model's.
    dummy = torch.rand(1, 3, 256, 256)
    if args.fp16:
        dummy = dummy.half()
    dummy = dummy.cuda()

    torch.onnx.export(
        model,
        dummy,
        args.output,
        verbose=True,
        input_names=['input'],
        output_names=['output'],
        opset_version=args.opset,
        export_params=True,
    )
if __name__ == '__main__':
    # main() takes the parsed namespace as a required argument; calling it
    # with no arguments raised TypeError.
    args = parse_args()
    main(args)