-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #25 from banctilrobitaille/develop
Develop
- Loading branch information
Showing
38 changed files
with
1,306 additions
and
653 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,4 +105,5 @@ venv.bak/ | |
|
||
idea | ||
.idea | ||
/data/* | ||
/data/* | ||
/tests/functionals/distributed/files/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
python3 setup.py sdist | ||
python3 -m twine upload dist/torch-kerosene-0.0.85.tar.gz | ||
python3 -m twine upload dist/torch-kerosene-0.0.86.tar.gz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import multiprocessing | ||
from typing import Callable | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
from kerosene.config.trainers import RunConfiguration, TrainerConfiguration | ||
from kerosene.utils.devices import on_single_device | ||
|
||
|
||
class DataloaderFactory(object): | ||
def __init__(self, train_dataset: Dataset, valid_dataset: Dataset): | ||
self._train_dataset = train_dataset | ||
self._valid_dataset = valid_dataset | ||
|
||
def create(self, run_config: RunConfiguration, training_config: TrainerConfiguration, collate_fn: Callable = None): | ||
devices = run_config.devices | ||
if not on_single_device(devices): | ||
torch.distributed.init_process_group(backend='nccl', init_method='env://', rank=run_config.local_rank) | ||
train_sampler = torch.utils.data.distributed.DistributedSampler(self._train_dataset) | ||
valid_sampler = torch.utils.data.distributed.DistributedSampler(self._valid_dataset) | ||
|
||
train_loader = torch.utils.data.DataLoader(dataset=self._train_dataset, | ||
batch_size=training_config.batch_size, | ||
shuffle=False if not on_single_device(devices) else True, | ||
num_workers=multiprocessing.cpu_count() // 2 if not on_single_device( | ||
devices) else multiprocessing.cpu_count(), | ||
sampler=train_sampler if not on_single_device(devices) else None, | ||
collate_fn=collate_fn, | ||
pin_memory=torch.cuda.is_available()) | ||
|
||
valid_loader = torch.utils.data.DataLoader(dataset=self._valid_dataset, | ||
batch_size=training_config.batch_size, | ||
shuffle=False if not on_single_device(devices) else True, | ||
num_workers=multiprocessing.cpu_count() // 2 if not on_single_device( | ||
devices) else multiprocessing.cpu_count(), | ||
sampler=valid_sampler if not on_single_device(devices) else None, | ||
collate_fn=collate_fn, | ||
pin_memory=torch.cuda.is_available()) | ||
return train_loader, valid_loader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from abc import ABC | ||
from typing import Dict | ||
|
||
from kerosene.events import BaseVariable, MonitorMode, Monitor | ||
from kerosene.events.handlers.base_handler import EventHandler | ||
|
||
|
||
class MonitorPatienceExceeded(Exception): | ||
pass | ||
|
||
|
||
class MonitorInspection(object): | ||
def __init__(self, value=0, inspection_num=0): | ||
self._value = value | ||
self._inspection_num = inspection_num | ||
|
||
@property | ||
def value(self): | ||
return self._value | ||
|
||
@value.setter | ||
def value(self, new_value): | ||
self._value = new_value | ||
|
||
@property | ||
def inspection_num(self): | ||
return self._inspection_num | ||
|
||
def add_inspection(self): | ||
self._inspection_num = self._inspection_num + 1 | ||
|
||
def reset_inspection_num(self): | ||
self._inspection_num = 0 | ||
return self | ||
|
||
def with_value(self, value): | ||
self._value = value | ||
return self | ||
|
||
|
||
class MonitorWatcher(EventHandler, ABC): | ||
def __init__(self, monitor: BaseVariable, mode: MonitorMode = MonitorMode.AUTO, min_delta=0.01, patience=3): | ||
assert isinstance(monitor, | ||
Monitor) or mode is not MonitorMode.AUTO, "Auto mode is not allowed with custom variables" | ||
|
||
self._monitor = monitor | ||
self._mode = mode | ||
self._min_delta = min_delta | ||
self._patience = patience | ||
self._monitor_values: Dict[str, MonitorInspection] = {} | ||
|
||
if mode is MonitorMode.AUTO: | ||
self._mode = MonitorWatcher.get_mode_for(monitor) | ||
|
||
@property | ||
def monitor(self): | ||
return self._monitor | ||
|
||
@property | ||
def mode(self): | ||
return self._mode | ||
|
||
@property | ||
def min_delta(self): | ||
return self._min_delta | ||
|
||
@property | ||
def patience(self): | ||
return self._patience | ||
|
||
@property | ||
def monitor_values(self): | ||
return self._monitor_values | ||
|
||
def watch(self, source_name, current_monitor_value): | ||
if source_name not in self._monitor_values.keys(): | ||
self._monitor_values[source_name] = MonitorInspection(value=current_monitor_value) | ||
else: | ||
if self._mode is MonitorMode.MIN: | ||
delta = self._monitor_values[source_name].value - current_monitor_value | ||
else: | ||
delta = current_monitor_value - self._monitor_values[source_name].value | ||
|
||
if delta >= self._min_delta: | ||
self._monitor_values[source_name].with_value(current_monitor_value).reset_inspection_num() | ||
else: | ||
self._monitor_values[source_name].add_inspection() | ||
if self._monitor_values[source_name].inspection_num >= self._patience: | ||
raise MonitorPatienceExceeded() | ||
|
||
@staticmethod | ||
def get_mode_for(monitor: Monitor): | ||
return MonitorMode.MIN if monitor.is_loss() else MonitorMode.MAX |
Oops, something went wrong.