This repository has been archived by the owner on Aug 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
86 lines (76 loc) · 3.39 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import torch
import torchvision
import torchvision.transforms as transforms
from trainer import CapsNetTrainer
import yaml, argparse
from utils.util import ensure_dir
from logger.logger import Logger #
def main(args):
    """Train a capsule network on MNIST or CIFAR-10 as described by a YAML config.

    Args:
        args: argparse namespace with at least `config` (path to YAML file),
              `multi_gpu` (bool) and `disable_gpu` (bool) attributes.

    Raises:
        ValueError: if the configured dataset is neither MNIST nor CIFAR.
    """
    # safe_load never constructs arbitrary Python objects from the YAML, and
    # the with-block guarantees the config file handle is closed.
    with open(args.config) as f:
        conf = yaml.safe_load(f)
    # Flatten the model-specific sub-section into the top-level config.
    conf.update(conf[conf['model']])
    if args.multi_gpu:
        # Scale the batch size so each GPU keeps its per-device batch size.
        conf['batch_size'] *= torch.cuda.device_count()
    datasets = {
        'MNIST': torchvision.datasets.MNIST,
        'CIFAR': torchvision.datasets.CIFAR10
    }
    dataset_name = conf['dataset'].upper()
    if dataset_name == 'MNIST':
        conf['data_path'] = os.path.join(conf['data_path'], 'MNIST')
        size = 28
        classes = list(range(10))
        # Standard per-channel normalization stats for MNIST.
        mean, std = ((0.1307,), (0.3081,))
    elif dataset_name == 'CIFAR':
        conf['data_path'] = os.path.join(conf['data_path'], 'CIFAR')
        size = 32
        classes = ['plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
        mean, std = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    else:
        raise ValueError('Dataset must be either MNIST or CIFAR!')
    transform = transforms.Compose([
        transforms.RandomCrop(size, padding=2),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    # NOTE(review): the same augmenting transform (RandomCrop) is applied to
    # the test split as well — presumably intentional here; confirm upstream.
    trainset = datasets[dataset_name](root=conf['data_path'],
                                      train=True, download=True, transform=transform)
    testset = datasets[dataset_name](root=conf['data_path'],
                                     train=False, download=True, transform=transform)
    loaders = {
        'train': torch.utils.data.DataLoader(trainset,
                                             batch_size=conf['batch_size'], shuffle=True, num_workers=4),
        'test': torch.utils.data.DataLoader(testset,
                                            batch_size=conf['batch_size'], shuffle=False, num_workers=4),
    }
    print(9*'#', 'Using {} dataset'.format(conf['dataset']), 9*'#')
    # Training
    use_gpu = not args.disable_gpu and torch.cuda.is_available()
    caps_net = CapsNetTrainer(loaders,
                              conf['model'],
                              conf['lr'],
                              conf['lr_decay'],
                              conf['num_classes'],
                              conf['num_routing'],
                              conf['loss'],
                              use_gpu=use_gpu,
                              multi_gpu=args.multi_gpu)
    ensure_dir('logs')  # make sure the TensorBoard-style log directory exists
    logger = {}
    logger['train'] = Logger('logs/{}-train'.format(conf['dataset']))
    logger['test'] = Logger('logs/{}-test'.format(conf['dataset']))
    ensure_dir(conf['save_dir'])  # checkpoint directory
    caps_net.train(conf['epochs'], classes, conf['save_dir'], logger)
if __name__ == '__main__':
    # Command-line entry point: parse options, then hand off to main().
    parser = argparse.ArgumentParser(description='PyTorch Capsules Networks')
    parser.add_argument('-c', '--config', default='config.yaml', type=str,
                        help='config file path (default: config.yaml)')
    # NOTE(review): --resume is parsed but never consumed by main(); either
    # wire it into checkpoint loading or drop it.
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='path to latest checkpoint (default: None)')
    # '--multi_gpu' stores multi_gpu as True when present.
    parser.add_argument('--multi_gpu', action='store_true',
                        help='Flag whether to use multiple GPUs.')
    # Fixed garbled help text ("to use disable" -> "to disable").
    parser.add_argument('--disable_gpu', action='store_true',
                        help='Flag whether to disable GPU')
    args = parser.parse_args()
    main(args)