# partialconv (GitHub wiki page; sngyo edited this page Nov 4, 2019 · 2 revisions)
#
# Export script for partialconv: exports a pretrained model to ONNX.

import argparse
import os   
import torch
import torchvision.models as models_baseline # networks with zero padding
import models as models_partial # partial conv based padding 


def _public_model_names(module):
    """Return the sorted model-constructor names exported by *module*.

    A name qualifies when it is all lowercase, is not a dunder, and is bound
    to a callable in the module namespace. The lowercase check drops
    CamelCase classes, and the callable check drops non-callable entries
    such as submodules.
    """
    return sorted(
        name
        for name, obj in vars(module).items()
        if name.islower() and not name.startswith("__") and callable(obj)
    )


model_baseline_names = _public_model_names(models_baseline)

model_partial_names = _public_model_names(models_partial)

# main() resolves a name present in both modules to the baseline model, so
# list each name only once in the CLI choices/help. Baseline names come first.
model_names = list(dict.fromkeys(model_baseline_names + model_partial_names))

parser = argparse.ArgumentParser(description='Export a pretrained ImageNet model to ONNX')

parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet50)')


def main(argv=None):
    """Load the pretrained model named by ``--arch`` and export it to ONNX.

    The architecture name is looked up first in ``models_baseline``
    (torchvision, zero padding) and otherwise in ``models_partial``
    (partial-conv padding). The exported graph is written to
    ``<arch>.onnx`` in the current working directory using opset 10.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, which is the original
            command-line behaviour.
    """
    # Kept global so the parsed namespace stays visible at module level, as before.
    global args
    args = parser.parse_args(argv)

    print("=> using pre-trained model '{}'".format(args.arch))
    if args.arch in models_baseline.__dict__:
        model = models_baseline.__dict__[args.arch](pretrained=True)
    else:
        model = models_partial.__dict__[args.arch](pretrained=True)

    print(model)
    # Switch to eval mode before tracing so layers such as dropout and
    # batch-norm use inference behaviour in the exported graph.
    model.eval()

    # Dummy input for the sanity pass and for export tracing. A plain tensor
    # is enough; the old torch.autograd.Variable wrapper is deprecated.
    # NOTE(review): assumes every selectable arch accepts a 1x3x224x224
    # batch — confirm for architectures with other input sizes.
    dummy = torch.randn(1, 3, 224, 224)

    # Sanity forward pass: fail fast with a clear PyTorch error before
    # exporting. No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        model(dummy)

    torch.onnx.export(model, dummy, args.arch + '.onnx', verbose=True, opset_version=10)
    print('Export is done')


if __name__ == '__main__':
    main()
# Clone this wiki locally