-
Notifications
You must be signed in to change notification settings - Fork 332
partialconv
sngyo edited this page on Nov 4, 2019 · 2 revisions
ONNX export script for partialconv models
import argparse
import os
import torch
import torchvision.models as models_baseline # networks with zero padding
import models as models_partial # partial conv based padding
def _list_model_names(module):
    """Return the sorted model-constructor names exported by *module*.

    A name qualifies when it is all-lowercase, does not start with ``__`` and
    refers to a callable object. CamelCase names (e.g. classes) fail the
    lowercase test; non-callable attributes such as submodules fail the
    callable test.
    """
    return sorted(
        name
        for name, obj in vars(module).items()
        if name.islower() and not name.startswith("__") and callable(obj)
    )


# Constructors using standard zero padding (torchvision) and partial-conv
# based padding (the local ``models`` package).
model_baseline_names = _list_model_names(models_baseline)
model_partial_names = _list_model_names(models_partial)
# Baseline names come first. A name defined by both packages is listed only
# once, because it always resolves to the torchvision constructor.
model_names = list(dict.fromkeys(model_baseline_names + model_partial_names))
# Command-line interface. The only option selects which architecture to load
# with pre-trained weights and export.
parser = argparse.ArgumentParser(
    description='Export a pre-trained PyTorch model to ONNX')
parser.add_argument(
    '--arch', '-a', metavar='ARCH', default='resnet50',
    choices=model_names,
    # argparse substitutes %(default)s, so the default lives in one place.
    help='model architecture: ' + ' | '.join(model_names)
         + ' (default: %(default)s)')
def _load_pretrained(arch):
    """Instantiate model *arch* with ``pretrained=True``.

    torchvision (zero-padding) constructors take precedence; any other name
    is looked up in the partial-conv ``models`` package.
    """
    if arch in models_baseline.__dict__:
        constructor = models_baseline.__dict__[arch]
    else:
        constructor = models_partial.__dict__[arch]
    return constructor(pretrained=True)


def _export_onnx(model, output_path, opset_version=10):
    """Trace *model* on a random 1x3x224x224 tensor and save it as ONNX.

    Args:
        model: the module to export; it is switched to eval mode first.
        output_path: destination file for the ONNX graph.
        opset_version: ONNX opset passed to ``torch.onnx.export``.
    """
    model.eval()
    # A plain tensor suffices; torch.autograd.Variable is a deprecated no-op.
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        # Sanity-check forward pass so model errors surface before tracing.
        model(dummy)
        torch.onnx.export(model, dummy, output_path, verbose=True,
                          opset_version=opset_version)


def main():
    """Parse arguments, load the pre-trained model and export ``<arch>.onnx``."""
    args = parser.parse_args()
    print("=> using pre-trained model '{}'".format(args.arch))
    model = _load_pretrained(args.arch)
    print(model)
    _export_onnx(model, args.arch + '.onnx')
    print('Export is done')
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
(c) 2019 ax Inc. & AXELL CORPORATION