Merge pull request #25 from banctilrobitaille/develop
Develop
Benoit Anctil-Robitaille authored Sep 12, 2019
2 parents 1b81c18 + 317dffe commit cfaea2a
Showing 38 changed files with 1,306 additions and 653 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -105,4 +105,5 @@ venv.bak/

idea
.idea
/data/*
/data/*
/tests/functionals/distributed/files/
14 changes: 8 additions & 6 deletions README.md
@@ -11,22 +11,24 @@ if __name__ == "__main__":

model_trainer_config, training_config = YamlConfigurationParser.parse(CONFIG_FILE_PATH)

train_loader = DataLoader(torchvision.datasets.MNIST('/files/', train=True, download=True, transform=Compose(
train_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=True, download=True, transform=Compose(
[ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_train, shuffle=True)

test_loader = DataLoader(torchvision.datasets.MNIST('/files/', train=False, download=True, transform=Compose(
test_loader = DataLoader(torchvision.datasets.MNIST('./files/', train=False, download=True, transform=Compose(
[ToTensor(), Normalize((0.1307,), (0.3081,))])), batch_size=training_config.batch_size_valid, shuffle=True)

# Initialize the loggers
visdom_logger = VisdomLogger(VisdomConfiguration.from_yml(CONFIG_FILE_PATH))

# Initialize the model trainers
model_trainer = ModelTrainerFactory(model=SimpleNet()).create(model_trainer_config)
model_trainer = ModelTrainerFactory(model=SimpleNet()).create(model_trainer_config, RunConfiguration(use_amp=False))

# Train with the training strategy
SimpleTrainer("MNIST Trainer", train_loader, test_loader, model_trainer) \
.with_event_handler(ConsoleLogger(), Event.ON_EPOCH_END) \
.with_event_handler(visdom_logger, Event.ON_EPOCH_END, PlotAllModelStateVariables()) \
trainer = SimpleTrainer("MNIST Trainer", train_loader, test_loader, model_trainer) \
.with_event_handler(PrintTrainingStatus(every=100), Event.ON_BATCH_END) \
.with_event_handler(PrintModelTrainersStatus(every=100), Event.ON_BATCH_END) \
.with_event_handler(PlotAllModelStateVariables(visdom_logger), Event.ON_EPOCH_END) \
.with_event_handler(PlotGradientFlow(visdom_logger, every=100), Event.ON_TRAIN_BATCH_END) \
.train(training_config.nb_epochs)
```

2 changes: 1 addition & 1 deletion deploy.sh
@@ -1,2 +1,2 @@
python3 setup.py sdist
python3 -m twine upload dist/torch-kerosene-0.0.85.tar.gz
python3 -m twine upload dist/torch-kerosene-0.0.86.tar.gz
13 changes: 10 additions & 3 deletions kerosene/config/trainers.py
@@ -20,13 +20,12 @@

class RunConfiguration(object):
def __init__(self, use_amp: bool = True, amp_opt_level: str = 'O2', local_rank: int = 0):
use_cuda = torch.cuda.is_available()

self._use_amp = use_amp
self._amp_opt_level = amp_opt_level
self._devices = ([torch.device("cuda:{}".format(device_id)) for device_id in
range(torch.cuda.device_count())]) if use_cuda else [torch.device("cpu")]
range(torch.cuda.device_count())]) if torch.cuda.is_available() else [torch.device("cpu")]
self._local_rank = local_rank
self._device = self._devices[self._local_rank]

@property
def use_amp(self):
@@ -44,6 +43,10 @@ def devices(self):
def local_rank(self):
return self._local_rank

@property
def device(self):
return self._device

def with_amp_opt_level(self, amp_opt_level: str):
self._amp_opt_level = amp_opt_level
return self
@@ -60,6 +63,10 @@ def with_local_rank(self, local_rank: int):
self._local_rank = local_rank
return self

def with_device(self, device: torch.device):
self._device = device
return self


class TrainerConfiguration(object):
def __init__(self, config_dict):
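The new `device` property and `with_device()` builder expose the device selected for the local rank and allow overriding it. A minimal usage sketch, assuming nothing beyond the constructor and builders shown in the hunks above:

```python
import torch

from kerosene.config.trainers import RunConfiguration

# Default: AMP on, devices discovered from CUDA (falls back to a single CPU device),
# and `device` resolved from the local rank.
run_config = RunConfiguration()
print(run_config.device)

# Explicit override: disable AMP and pin everything to the CPU.
cpu_config = RunConfiguration(use_amp=False).with_device(torch.device("cpu"))
print(cpu_config.device)  # cpu
```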
File renamed without changes.
40 changes: 40 additions & 0 deletions kerosene/dataloaders/factories.py
@@ -0,0 +1,40 @@
import multiprocessing
from typing import Callable

import torch
from torch.utils.data import Dataset

from kerosene.config.trainers import RunConfiguration, TrainerConfiguration
from kerosene.utils.devices import on_single_device


class DataloaderFactory(object):
def __init__(self, train_dataset: Dataset, valid_dataset: Dataset):
self._train_dataset = train_dataset
self._valid_dataset = valid_dataset

def create(self, run_config: RunConfiguration, training_config: TrainerConfiguration, collate_fn: Callable = None):
devices = run_config.devices
if not on_single_device(devices):
torch.distributed.init_process_group(backend='nccl', init_method='env://', rank=run_config.local_rank)
train_sampler = torch.utils.data.distributed.DistributedSampler(self._train_dataset)
valid_sampler = torch.utils.data.distributed.DistributedSampler(self._valid_dataset)

train_loader = torch.utils.data.DataLoader(dataset=self._train_dataset,
batch_size=training_config.batch_size,
shuffle=False if not on_single_device(devices) else True,
num_workers=multiprocessing.cpu_count() // 2 if not on_single_device(
devices) else multiprocessing.cpu_count(),
sampler=train_sampler if not on_single_device(devices) else None,
collate_fn=collate_fn,
pin_memory=torch.cuda.is_available())

valid_loader = torch.utils.data.DataLoader(dataset=self._valid_dataset,
batch_size=training_config.batch_size,
shuffle=False if not on_single_device(devices) else True,
num_workers=multiprocessing.cpu_count() // 2 if not on_single_device(
devices) else multiprocessing.cpu_count(),
sampler=valid_sampler if not on_single_device(devices) else None,
collate_fn=collate_fn,
pin_memory=torch.cuda.is_available())
return train_loader, valid_loader
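For context, a minimal sketch of driving the new factory on a single machine. The `SimpleNamespace` stand-in below is an assumption: it only mimics the one attribute (`batch_size`) that `create()` reads from the training configuration, since the real `TrainerConfiguration` is built from the YAML file and not shown here.

```python
from types import SimpleNamespace

from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor

from kerosene.config.trainers import RunConfiguration
from kerosene.dataloaders.factories import DataloaderFactory

transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
train_dataset = MNIST('./files/', train=True, download=True, transform=transform)
valid_dataset = MNIST('./files/', train=False, download=True, transform=transform)

# Stand-in for the YAML-parsed training configuration.
training_config = SimpleNamespace(batch_size=32)

# On a single device this returns plain shuffled loaders; with several devices it
# initializes the process group and wraps both datasets in a DistributedSampler.
train_loader, valid_loader = DataloaderFactory(train_dataset, valid_dataset).create(
    RunConfiguration(use_amp=False), training_config)
```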
42 changes: 34 additions & 8 deletions kerosene/events/__init__.py
@@ -1,7 +1,23 @@
from enum import Enum


class Event(Enum):
class BaseEvent(Enum):
def __str__(self):
return self.value


class BaseVariable(Enum):
def __str__(self):
return self.value

def __eq__(self, other):
if isinstance(other, str):
return self.value == other
elif isinstance(other, BaseVariable):
return self.value == other.value


class Event(BaseEvent):
ON_TRAINING_BEGIN = "training_begin"
ON_TRAINING_END = "training_end"
ON_EPOCH_BEGIN = "epoch_begin"
@@ -18,13 +34,23 @@ class Event(Enum):
ON_BATCH_END = "batch_end"


class Monitor(Enum):
TRAINING_LOSS = "TrainingLoss"
TRAINING_METRIC = "TrainingMetric"
VALIDATION_LOSS = "ValidationLoss"
VALIDATION_METRIC = "ValidationMetric"
class Monitor(BaseVariable):
TRAINING_LOSS = "train_loss"
TRAINING_METRIC = "train_metric"
VALIDATION_LOSS = "valid_loss"
VALIDATION_METRIC = "valid_metric"

def is_loss(self):
return "loss" in self.value

def is_metric(self):
return "metric" in self.value


class MonitorMode(Enum):
MIN = "min"
MAX = "max"
MIN = -1
MAX = 1
AUTO = "auto"

def __str__(self):
return self.value
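Since `BaseVariable` now compares equal to plain strings, monitor names (for example, values taken straight from a YAML configuration) can be matched against the enum members without conversion. A quick sketch of the resulting behaviour:

```python
from kerosene.events import Monitor, MonitorMode

assert Monitor.TRAINING_LOSS == "train_loss"   # string comparison via BaseVariable.__eq__
assert Monitor.TRAINING_LOSS.is_loss()
assert Monitor.VALIDATION_METRIC.is_metric()

# MIN and MAX now map to the signs -1 and +1 instead of the strings "min"/"max".
assert MonitorMode.MIN.value == -1 and MonitorMode.MAX.value == 1
```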
6 changes: 3 additions & 3 deletions kerosene/events/generators/base_generator.py
@@ -16,7 +16,7 @@
from abc import ABC, abstractmethod
from typing import Callable

from kerosene.events import Event
from kerosene.events import BaseEvent


class EventGenerator(ABC):
@@ -29,10 +29,10 @@ def state(self):
raise NotImplementedError()

@abstractmethod
def with_event_handler(self, handler, event: Event, preprocessor: Callable):
def with_event_handler(self, handler, event: BaseEvent):
raise NotImplementedError()

def fire(self, event: Event):
def fire(self, event: BaseEvent):
if event in self._event_handlers.keys():
state = self.state

20 changes: 20 additions & 0 deletions kerosene/events/handlers/base_handler.py
@@ -15,9 +15,29 @@
# ==============================================================================
from abc import ABC, abstractmethod

from kerosene.events import Event


class EventHandler(ABC):

def __init__(self, every=1):
self._every = every

@property
def every(self):
return self._every

def should_handle_epoch_data(self, event, trainer):
return (event in [Event.ON_EPOCH_BEGIN, Event.ON_EPOCH_END, Event.ON_TRAIN_EPOCH_BEGIN,
Event.ON_TRAIN_EPOCH_END, Event.ON_VALID_EPOCH_BEGIN, Event.ON_VALID_EPOCH_END]) and (
trainer.epoch % self._every == 0)

def should_handle_step_data(self, event, trainer):
if event in [Event.ON_TRAIN_BATCH_BEGIN, Event.ON_TRAIN_BATCH_END, Event.ON_BATCH_END]:
return trainer.current_train_step % self._every == 0
elif event in [Event.ON_VALID_BATCH_BEGIN, Event.ON_VALID_BATCH_END, Event.ON_BATCH_END]:
return trainer.current_valid_step % self._every == 0

@abstractmethod
def __call__(self, *inputs):
raise NotImplementedError()
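As a sketch of how the new `every` counter and the `should_handle_*` helpers are meant to combine, here is a hypothetical handler; its name and its `__call__` signature are assumptions, while `current_train_step` is the trainer attribute this diff itself relies on:

```python
from kerosene.events import Event
from kerosene.events.handlers.base_handler import EventHandler


class PrintTrainStep(EventHandler):
    """Hypothetical handler that prints every `every`-th training step."""

    def __init__(self, every=100):
        super().__init__(every=every)

    def __call__(self, event, trainer):
        # Only acts when the trainer's current training step is a multiple of `every`.
        if self.should_handle_step_data(event, trainer):
            print("train step {}".format(trainer.current_train_step))
```

It would then be registered like the handlers in the README snippet above, e.g. `.with_event_handler(PrintTrainStep(every=100), Event.ON_TRAIN_BATCH_END)`.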
93 changes: 93 additions & 0 deletions kerosene/events/handlers/base_monitor_watcher.py
@@ -0,0 +1,93 @@
from abc import ABC
from typing import Dict

from kerosene.events import BaseVariable, MonitorMode, Monitor
from kerosene.events.handlers.base_handler import EventHandler


class MonitorPatienceExceeded(Exception):
pass


class MonitorInspection(object):
def __init__(self, value=0, inspection_num=0):
self._value = value
self._inspection_num = inspection_num

@property
def value(self):
return self._value

@value.setter
def value(self, new_value):
self._value = new_value

@property
def inspection_num(self):
return self._inspection_num

def add_inspection(self):
self._inspection_num = self._inspection_num + 1

def reset_inspection_num(self):
self._inspection_num = 0
return self

def with_value(self, value):
self._value = value
return self


class MonitorWatcher(EventHandler, ABC):
def __init__(self, monitor: BaseVariable, mode: MonitorMode = MonitorMode.AUTO, min_delta=0.01, patience=3):
assert isinstance(monitor,
Monitor) or mode is not MonitorMode.AUTO, "Auto mode is not allowed with custom variables"

self._monitor = monitor
self._mode = mode
self._min_delta = min_delta
self._patience = patience
self._monitor_values: Dict[str, MonitorInspection] = {}

if mode is MonitorMode.AUTO:
self._mode = MonitorWatcher.get_mode_for(monitor)

@property
def monitor(self):
return self._monitor

@property
def mode(self):
return self._mode

@property
def min_delta(self):
return self._min_delta

@property
def patience(self):
return self._patience

@property
def monitor_values(self):
return self._monitor_values

def watch(self, source_name, current_monitor_value):
if source_name not in self._monitor_values.keys():
self._monitor_values[source_name] = MonitorInspection(value=current_monitor_value)
else:
if self._mode is MonitorMode.MIN:
delta = self._monitor_values[source_name].value - current_monitor_value
else:
delta = current_monitor_value - self._monitor_values[source_name].value

if delta >= self._min_delta:
self._monitor_values[source_name].with_value(current_monitor_value).reset_inspection_num()
else:
self._monitor_values[source_name].add_inspection()
if self._monitor_values[source_name].inspection_num >= self._patience:
raise MonitorPatienceExceeded()

@staticmethod
def get_mode_for(monitor: Monitor):
return MonitorMode.MIN if monitor.is_loss() else MonitorMode.MAX
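To show how `watch()` and `MonitorPatienceExceeded` fit together, here is a hypothetical plateau detector built on `MonitorWatcher`; the class name and the way it receives the validation loss are illustrative assumptions, not part of this commit:

```python
from kerosene.events import Monitor, MonitorMode
from kerosene.events.handlers.base_monitor_watcher import MonitorPatienceExceeded, MonitorWatcher


class StopOnPlateau(MonitorWatcher):
    """Hypothetical watcher flagging models whose validation loss stops improving."""

    def __init__(self, patience=3, min_delta=0.01):
        # AUTO resolves to MIN here because the monitored variable is a loss.
        super().__init__(Monitor.VALIDATION_LOSS, mode=MonitorMode.AUTO,
                         min_delta=min_delta, patience=patience)
        self.stopped = set()

    def __call__(self, model_name, valid_loss):
        try:
            # An improvement of at least `min_delta` resets the inspection counter;
            # otherwise the counter grows until it reaches `patience`.
            self.watch(model_name, valid_loss)
        except MonitorPatienceExceeded:
            self.stopped.add(model_name)


watcher = StopOnPlateau(patience=2, min_delta=0.05)
for loss in [1.00, 0.80, 0.79, 0.79]:
    watcher("simple_net", loss)
print(watcher.stopped)  # {'simple_net'} once the patience of 2 is used up
```

What the caller does with the flag (stop training, reload a checkpoint, lower the learning rate) is left open; this commit only provides the bookkeeping.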
