Cleanup (#11)
* cifar wrapper

* move load_state_dict to nbdt.models.utils

* train loss oops

* cleaner dataset kwarg attempt

* aesthetic changes

* decorator for analyzer

* imports and docstring

* moar aesthetics

* more aesthetic
alvinwan authored Sep 10, 2020
1 parent b686d00 commit 5af3c15
Showing 6 changed files with 129 additions and 136 deletions.
163 changes: 34 additions & 129 deletions main.py
@@ -1,50 +1,33 @@
"""
Neural-Backed Decision Trees training script on CIFAR10, CIFAR100, TinyImagenet200
The original version of this `main.py` was taken from
https://github.com/kuangliu/pytorch-cifar
and extended in
https://github.com/alvinwan/pytorch-cifar-plus
Neural-Backed Decision Trees training on CIFAR10, CIFAR100, TinyImagenet200
The original version of this `main.py` was taken from kuangliu/pytorch-cifar.
The script has since been heavily modified to support a number of different
configurations and options. See the current repository for a full description
of its bells and whistles.
https://github.com/alvinwan/neural-backed-decision-trees
configurations and options: alvinwan/neural-backed-decision-trees
"""
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch import nn, optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from nbdt import data, analysis, loss, models, metrics

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import numpy as np

from nbdt.utils import (
progress_bar, generate_checkpoint_fname, generate_kwargs, Colors
)
from nbdt import data, analysis, loss, models, metrics
from nbdt.utils import progress_bar, generate_checkpoint_fname, generate_kwargs, Colors
from nbdt.thirdparty.wn import maybe_install_wordnet
from nbdt.models.utils import load_state_dict, make_kwarg_optional

maybe_install_wordnet()

datasets = ('CIFAR10', 'CIFAR100') + data.imagenet.names + data.custom.names


parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('--batch-size', default=512, type=int,
help='Batch size used for training')
parser.add_argument('--epochs', '-e', default=200, type=int,
help='By default, lr schedule is scaled accordingly')
parser.add_argument('--dataset', default='CIFAR10', choices=datasets)
parser.add_argument('--dataset', default='CIFAR10', choices=data.cifar.names + data.imagenet.names + data.custom.names)
parser.add_argument('--arch', default='ResNet18', choices=list(models.get_model_choices()))
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
@@ -57,22 +40,16 @@
parser.add_argument('--pretrained', action='store_true',
help='Download pretrained model. Not all models support this.')
parser.add_argument('--eval', help='eval only', action='store_true')

# options specific to this project and its dataloaders
parser.add_argument('--loss', choices=loss.names, default='CrossEntropyLoss')
parser.add_argument('--metric', choices=metrics.names, default='top1')
parser.add_argument('--analysis', choices=analysis.names, help='Run analysis after each epoch')
parser.add_argument('--input-size', type=int,
help='Set transform train and val. Samples are resized to '
'input-size + 32.')
parser.add_argument('--lr-decay-every', type=int, default=0)

# other dataset, loss or analysis specific options
data.custom.add_arguments(parser)
loss.add_arguments(parser)
analysis.add_arguments(parser)

args = parser.parse_args()

loss.set_default_values(args)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -81,40 +58,9 @@

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

dataset = getattr(data, args.dataset)

if args.dataset in ('TinyImagenet200', 'Imagenet1000'):
default_input_size = 64 if args.dataset == 'TinyImagenet200' else 224
input_size = args.input_size or default_input_size
transform_train = dataset.transform_train(input_size)
transform_test = dataset.transform_val(input_size)
elif args.input_size is not None and args.input_size > 32:
transform_train = transforms.Compose([
transforms.Resize(args.input_size + 32),
transforms.RandomCrop(args.input_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
transforms.Resize(args.input_size + 32),
transforms.CenterCrop(args.input_size),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_train = dataset.transform_train()
transform_test = dataset.transform_val()

dataset_kwargs = generate_kwargs(args, dataset,
name=f'Dataset {args.dataset}',
@@ -138,16 +84,7 @@

if args.pretrained:
print('==> Loading pretrained model..')
try:
net = model(pretrained=True, dataset=args.dataset, **model_kwargs)
except TypeError as e: # likely because `dataset` not allowed arg
print(e)

try:
net = model(pretrained=True, **model_kwargs)
except Exception as e:
Colors.red(f'Fatal error: {e}')
exit()
net = make_kwarg_optional(model, dataset=args.dataset)(pretrained=True, **model_kwargs)
else:
net = model(**model_kwargs)

@@ -160,19 +97,6 @@
checkpoint_path = './checkpoint/{}.pth'.format(checkpoint_fname)
print(f'==> Checkpoints will be saved to: {checkpoint_path}')


# TODO(alvin): fix checkpoint structure so that this isn't needed
def load_state_dict(state_dict):
try:
net.load_state_dict(state_dict)
except RuntimeError as e:
if 'Missing key(s) in state_dict:' in str(e):
net.load_state_dict({
key.replace('module.', '', 1): value
for key, value in state_dict.items()
})


resume_path = args.path_resume or checkpoint_path
if args.resume:
# Load checkpoint.
@@ -184,13 +108,13 @@ def load_state_dict(state_dict):
checkpoint = torch.load(resume_path, map_location=torch.device(device))

if 'net' in checkpoint:
load_state_dict(checkpoint['net'])
load_state_dict(net, checkpoint['net'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']
Colors.cyan(f'==> Checkpoint found for epoch {start_epoch} with accuracy '
f'{best_acc} at {resume_path}')
else:
load_state_dict(checkpoint)
load_state_dict(net, checkpoint)
Colors.cyan(f'==> Checkpoint found at {resume_path}')


@@ -205,19 +129,25 @@ def load_state_dict(state_dict):
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

def adjust_learning_rate(epoch, lr):
if args.lr_decay_every:
steps = epoch // args.lr_decay_every
return lr / (10 ** steps)
if epoch <= 150 / 350. * args.epochs: # 32k iterations
return lr
elif epoch <= 250 / 350. * args.epochs: # 48k iterations
return lr/10
else:
return lr/100


class_analysis = getattr(analysis, args.analysis or 'Noop')
analyzer_kwargs = generate_kwargs(args, class_analysis,
name=f'Analyzer {args.analysis}',
keys=analysis.keys,
globals=globals())
analyzer = class_analysis(**analyzer_kwargs)


# Training
@analyzer.train_function
def train(epoch, analyzer):
analyzer.start_train(epoch)
lr = adjust_learning_rate(epoch, args.lr)
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

@@ -235,19 +165,14 @@ def train(epoch, analyzer):

train_loss += loss.item()
metric.forward(outputs, targets)

stat = analyzer.update_batch(outputs, targets)
extra = f'| {stat}' if stat else ''

progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d) %s'
% (test_loss/(batch_idx+1), 100. * metric.report(),
metric.correct, metric.total, extra))
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d) %s' % (
train_loss / (batch_idx + 1), 100. * metric.report(), metric.correct, metric.total, f'| {stat}' if stat else ''))

analyzer.end_train(epoch)

@analyzer.test_function
def test(epoch, analyzer, checkpoint=True):
analyzer.start_test(epoch)

global best_acc
net.eval()
test_loss = 0
@@ -260,17 +185,10 @@ def test(epoch, analyzer, checkpoint=True):

test_loss += loss.item()
metric.forward(outputs, targets)

if device == 'cuda':
outputs = outputs.cpu()
targets = targets.cpu()

stat = analyzer.update_batch(outputs, targets)
extra = f'| {stat}' if stat else ''

progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d) %s'
% (test_loss/(batch_idx+1), 100. * metric.report(),
metric.correct, metric.total, extra))
progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d) %s' % (
test_loss / (batch_idx + 1), 100. * metric.report(), metric.correct, metric.total, f'| {stat}' if stat else ''))

# Save checkpoint.
acc = 100. * metric.report()
@@ -288,16 +206,6 @@ def test(epoch, analyzer, checkpoint=True):
torch.save(state, f'./checkpoint/{checkpoint_fname}.pth')
best_acc = acc

analyzer.end_test(epoch)


class_analysis = getattr(analysis, args.analysis or 'Noop')
analyzer_kwargs = generate_kwargs(args, class_analysis,
name=f'Analyzer {args.analysis}',
keys=analysis.keys,
globals=globals())
analyzer = class_analysis(**analyzer_kwargs)


if args.eval:
if not args.resume and not args.pretrained:
@@ -306,16 +214,13 @@ def test(epoch, analyzer, checkpoint=True):

analyzer.start_epoch(0)
test(0, analyzer, checkpoint=False)
analyzer.end_epoch(0)
exit()


for epoch in range(start_epoch, args.epochs):
analyzer.start_epoch(epoch)
train(epoch, analyzer)
test(epoch, analyzer)
analyzer.end_epoch(epoch)

if args.epochs == 0:
analyzer.start_epoch(0)
test(0, analyzer)
analyzer.end_epoch(0)
print(f'Best accuracy: {best_acc} // Checkpoint name: {checkpoint_fname}')
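
Note: the two helpers newly imported from `nbdt.models.utils` are not shown in this diff. Below is a sketch reconstructed from the code the commit removes from `main.py`: `load_state_dict` is the deleted inline helper with the network made an explicit first argument, and `make_kwarg_optional` stands in for the deleted try/except around the pretrained-model constructor; its exact body here is an assumption.

```python
# Sketch of the nbdt/models/utils.py helpers (assumed, not part of this diff).
import functools


def load_state_dict(net, state_dict):
    """Load weights, tolerating checkpoints saved through nn.DataParallel."""
    try:
        net.load_state_dict(state_dict)
    except RuntimeError as e:
        if 'Missing key(s) in state_dict:' in str(e):
            # nn.DataParallel prefixes every key with 'module.' -- strip it.
            net.load_state_dict({
                key.replace('module.', '', 1): value
                for key, value in state_dict.items()
            })


def make_kwarg_optional(f, **optional):
    """Wrap f so the given kwargs are dropped if f does not accept them."""
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs, **optional)
        except TypeError:  # likely because an optional kwarg is not allowed
            return f(*args, **kwargs)
    return wrapper
```

With this, `make_kwarg_optional(model, dataset=args.dataset)(pretrained=True, **model_kwargs)` first tries to forward `dataset` and silently falls back for architectures that do not accept it, matching the behavior of the removed try/except.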
27 changes: 27 additions & 0 deletions nbdt/analysis.py
@@ -4,6 +4,7 @@
HardEmbeddedDecisionRules as HardRules
)
from nbdt import metrics
import functools
import numpy as np


@@ -17,6 +18,20 @@ def add_arguments(parser):
pass


def start_end_decorator(obj, name):
start = getattr(obj, f'start_{name}', None)
end = getattr(obj, f'end_{name}', None)
assert start and end
def decorator(f):
@functools.wraps(f)
def wrapper(epoch, *args, **kwargs):
start(epoch)
f(epoch, *args, **kwargs)
end(epoch)
return wrapper
return decorator


class Noop:

accepts_classes = lambda trainset, **kwargs: trainset.classes
@@ -28,6 +43,18 @@ def __init__(self, classes=()):
self.num_classes = len(classes)
self.epoch = None

@property
def epoch_function(self):
return start_end_decorator(self, 'epoch')

@property
def train_function(self):
return start_end_decorator(self, 'train')

@property
def test_function(self):
return start_end_decorator(self, 'test')

def start_epoch(self, epoch):
self.epoch = epoch

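The new `train_function`/`test_function`/`epoch_function` properties let `main.py` swap its explicit `start_*`/`end_*` bookkeeping calls for decorators, as in the diff above. A minimal usage sketch (assuming `Noop` defines the `start_train`/`end_train` pair that the assert in `start_end_decorator` requires, as the old `main.py` call sites imply):

```python
from nbdt.analysis import Noop

analyzer = Noop(classes=('cat', 'dog'))

@analyzer.train_function
def train(epoch, analyzer):
    print(f'training epoch {epoch}')

# Runs analyzer.start_train(0), then the body, then analyzer.end_train(0).
train(0, analyzer)
```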
1 change: 1 addition & 0 deletions nbdt/data/__init__.py
@@ -4,3 +4,4 @@
from .lip import *
from .ade20k import *
from torchvision.datasets import *
from .cifar import *
31 changes: 31 additions & 0 deletions nbdt/data/cifar.py
@@ -0,0 +1,31 @@
"""Wrappers around CIFAR datasets"""

from torchvision import datasets, transforms

__all__ = names = ('CIFAR10', 'CIFAR100')


class CIFAR:

@staticmethod
def transform_train():
return transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

@staticmethod
def transform_val():
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


class CIFAR10(datasets.CIFAR10, CIFAR):
pass

class CIFAR100(datasets.CIFAR100, CIFAR):
pass
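
These wrappers are what let `main.py` call `dataset.transform_train()` and `dataset.transform_val()` uniformly instead of special-casing the CIFAR transforms inline. A usage sketch mirroring how `main.py` resolves the `--dataset` flag (the constructor kwargs are torchvision's own):

```python
from nbdt import data

# main.py does `dataset = getattr(data, args.dataset)`; since nbdt.data now
# re-exports this module, 'CIFAR10' resolves to the wrapper above.
dataset = getattr(data, 'CIFAR10')
trainset = dataset(root='./data', train=True, download=True,
                   transform=dataset.transform_train())
```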