Merge pull request #48 from kaseris/fix/environment-build
Fix/environment build
kaseris authored Dec 12, 2023
2 parents dabc95f + a26df9b commit 699dbd2
Showing 7 changed files with 229 additions and 66 deletions.
65 changes: 65 additions & 0 deletions configs/pvred.yaml
@@ -0,0 +1,65 @@
dataset:
  name: NTURGBDDataset
  args:
    data_directory: /home/kaseris/Documents/mount/data_ntu_rgbd
    label_file: /home/kaseris/Documents/dev/skelcast/data/labels.txt
    missing_files_dir: /home/kaseris/Documents/dev/skelcast/data/missing
    max_context_window: 10
    max_number_of_bodies: 1
    max_duration: 300
    n_joints: 25
    cache_file: /home/kaseris/Documents/mount/dataset_cache.pkl

transforms:
  - name: MinMaxScaleTransform
    args:
      feature_scale: [0.0, 1.0]
  - name: SomeOtherTransform
    args:
      some_arg: some_value

loss:
  name: MSELoss
  args:
    reduction: mean

collate_fn:
  name: NTURGBDCollateFnWithRandomSampledContextWindow
  args:
    block_size: 25

logger:
  name: TensorboardLogger
  args:
    save_dir: runs

optimizer:
  name: AdamW
  args:
    lr: 0.0001
    weight_decay: 0.0001

model:
  name: PositionalVelocityRecurrentEncoderDecoder
  args:
    input_dim: 75
    enc_hidden_dim: 64
    dec_hidden_dim: 64
    enc_type: lstm
    dec_type: lstm
    include_velocity: false
    pos_enc: add
    batch_first: true
    std_thresh: 0.0001
    use_std_mask: false
    use_padded_len_mask: true
    observe_until: 20

runner:
  name: Runner
  args:
    train_batch_size: 32
    val_batch_size: 32
    block_size: 8
    log_gradient_info: true
    device: cuda
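
For orientation, a minimal sketch of how this file is meant to be consumed, using the read_config and build_object_from_config helpers added in src/skelcast/core/config.py below (the MODELS import path is an assumption, inferred from the registries referenced in environment.py):

    from skelcast.core.config import read_config, build_object_from_config
    from skelcast.models import MODELS  # assumed import path

    # Parse the YAML into an EnvironmentConfig with one attribute per top-level section.
    cfgs = read_config('configs/pvred.yaml')

    # Each section can then be instantiated through the matching registry.
    model = build_object_from_config(cfgs.model_config, MODELS)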
143 changes: 143 additions & 0 deletions src/skelcast/core/config.py
@@ -0,0 +1,143 @@
import re
import logging
import yaml

from collections import OrderedDict
from typing import List

from skelcast.core.registry import Registry


class Config:
    def __init__(self):
        self._config = OrderedDict()
        self._config['name'] = None
        self._config['args'] = {}

    def get(self, key):
        return self._config[key]

    def set(self, key, value):
        if isinstance(value, list):
            self._config[key] = []
            for v in value:
                self._config[key].append(v)
        else:
            self._config[key] = value

    def __str__(self) -> str:
        s = self.__class__.__name__ + '(\n'
        for key, val in self._config.items():
            if isinstance(val, dict):
                s += f'\t{key}: \n'
                for k, v in val.items():
                    s += f'\t\t{k}: {v}\n'
                s += '\t\n'
            else:
                s += f'\t{key}: {val}\n'
        s += ')'
        return s


class ModelConfig(Config):
    def __init__(self):
        super(ModelConfig, self).__init__()


class DatasetConfig(Config):
    def __init__(self):
        super(DatasetConfig, self).__init__()


class TransformsConfig(Config):
    def __init__(self):
        super(TransformsConfig, self).__init__()


class LoggerConfig(Config):
    def __init__(self):
        super(LoggerConfig, self).__init__()


class OptimizerConfig(Config):
    def __init__(self):
        super(OptimizerConfig, self).__init__()


class SchedulerConfig(Config):
    def __init__(self):
        super(SchedulerConfig, self).__init__()


class CriterionConfig(Config):
    def __init__(self):
        super(CriterionConfig, self).__init__()


class CollateFnConfig(Config):
    def __init__(self):
        super(CollateFnConfig, self).__init__()


class RunnerConfig(Config):
    def __init__(self):
        super(RunnerConfig, self).__init__()


class EnvironmentConfig:
    def __init__(self, *args) -> None:
        for arg in args:
            # Convert the class name to snake_case, e.g. ModelConfig -> model_config,
            # and expose each config object as an attribute under that name.
            name = arg.__class__.__name__
            split_name = re.findall('[A-Z][^A-Z]*', name)
            name = '_'.join([s.lower() for s in split_name])
            setattr(self, name, arg)

    def __str__(self) -> str:
        s = self.__class__.__name__ + '(\n'
        for key, val in self.__dict__.items():
            s += f'\t{key}: {val}\n'
        s += ')'
        return s


def build_object_from_config(config: Config, registry: Registry, **kwargs):
    _name = config.get('name')
    _args = config.get('args')
    _args.update(kwargs)  # extra kwargs override/extend the YAML-provided args
    return registry.get_module(_name)(**_args)


def summarize_config(configs: List[Config]):
    with open('/home/kaseris/Documents/mount/config.txt', 'w') as f:
        for config in configs:
            f.write(str(config))
            f.write('\n\n')


CONFIG_MAPPING = {
    'model': ModelConfig,
    'dataset': DatasetConfig,
    'transforms': TransformsConfig,
    'logger': LoggerConfig,
    'optimizer': OptimizerConfig,
    'scheduler': SchedulerConfig,
    'loss': CriterionConfig,
    'collate_fn': CollateFnConfig,
    'runner': RunnerConfig
}

def read_config(config_path: str):
    with open(config_path, 'r') as f:
        data = yaml.safe_load(f)
    cfgs = []
    for key in data:
        config = CONFIG_MAPPING[key]()
        if key == 'transforms':
            # The transforms section is a list; each transform is stored under its
            # own name rather than the usual 'name'/'args' keys, so it needs special
            # handling downstream (see the Compose TODO in environment.py).
            for element in data[key]:
                logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.')
                config.set(element['name'], element['args'])
        else:
            logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.')
            config.set('name', data[key]['name'])
            config.set('args', data[key]['args'])
        logging.debug(config)
        cfgs.append(config)
    return EnvironmentConfig(*cfgs)
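
The CamelCase-to-snake_case conversion in EnvironmentConfig determines the attribute names that environment.py relies on below; a small illustration with placeholder, empty configs:

    env = EnvironmentConfig(ModelConfig(), DatasetConfig())
    print(env.model_config)    # -> the ModelConfig instance
    print(env.dataset_config)  # -> the DatasetConfig instance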
72 changes: 8 additions & 64 deletions src/skelcast/core/environment.py
@@ -14,6 +14,7 @@
from skelcast.logger import LOGGERS

from skelcast.experiments.runner import Runner
+from skelcast.core.config import read_config, build_object_from_config

torch.manual_seed(133742069)

@@ -66,71 +67,14 @@ def experiment_name(self) -> str:
        return self._experiment_name

    def build_from_file(self, config_path: str) -> None:
-        config = self._parse_file(config_path)
-        self.config = config
-        logging.log(logging.INFO, f'Building environment from {config_path}.')
-        self._build_dataset()
-        self._build_model()
-        self._build_logger()
-        self._build_runner()
-
-    def _build_model(self) -> None:
-        logging.log(logging.INFO, 'Building model.')
-        model_config = self.config['model']
-        _name = model_config.get('name')
-        _args = model_config.get('args')
-        self._model = MODELS.get_module(_name)(**_args)
-        logging.log(logging.INFO, f'Model creation complete.')
-
-    def _build_dataset(self) -> None:
-        logging.log(logging.INFO, 'Building dataset.')
-        dataset_config = self.config['dataset']
-        _name = dataset_config.get('name')
-        _args = dataset_config.get('args')
-        _transforms_cfg = dataset_config.get('args').get('transforms')
-        _transforms = TRANSFORMS.get_module(_transforms_cfg.get('name'))(**_transforms_cfg.get('args'))
-        _args['transforms'] = _transforms
-        self._dataset = DATASETS.get_module(_name)(self.data_dir, **_args)
-        # Split the dataset
-        _train_len = int(self.config['train_data_percentage'] * len(self._dataset))
-        self._train_dataset, self._val_dataset = random_split(self._dataset, [_train_len, len(self._dataset) - _train_len])
-        logging.log(logging.INFO, f'Train set size: {len(self._train_dataset)}')
-
-    def _build_logger(self) -> None:
-        logging.log(logging.INFO, 'Building logger.')
-        logger_config = self.config['runner']['args'].get('logger')
-        logdir = os.path.join(logger_config['args']['save_dir'], self.experiment_name)
-        self._logger = LOGGERS.get_module(logger_config['name'])(logdir)
-        logging.log(logging.INFO, f'Logging to {logdir}.')
-
-    def _build_runner(self) -> None:
-        logging.log(logging.INFO, 'Building runner.')
-        runner_config = self.config['runner']
-        _args = runner_config.get('args')
-        _args['logger'] = self._logger
-        _args['optimizer'] = optim.AdamW(self._model.parameters(), lr=_args.get('lr'))
-        _args['train_set'] = self._train_dataset
-        _args['val_set'] = self._val_dataset
-        _args['model'] = self._model
-        _args['train_set'] = self._train_dataset
-        _args['val_set'] = self._val_dataset
-        _args['checkpoint_dir'] = os.path.join(self.checkpoint_dir, self._experiment_name)
-        self._create_checkpoint_dir()
-        self._runner = Runner(**_args)
-        self._runner.setup()
-        logging.log(logging.INFO, 'Runner setup complete.')
-
-    def _create_checkpoint_dir(self) -> None:
-        if os.path.exists(os.path.join(self.checkpoint_dir, self._experiment_name)):
-            raise ValueError(f'Checkpoint directory {os.path.join(self.checkpoint_dir, self._experiment_name)} already exists.')
-        else:
-            logging.log(logging.INFO, f'Creating checkpoint directory: {os.path.join(self.checkpoint_dir, self._experiment_name)}.')
-            os.mkdir(os.path.join(self.checkpoint_dir, self._experiment_name))
-
-    def _parse_file(self, fname: str) -> None:
-        with open(fname, 'r') as f:
-            config = yaml.safe_load(f)
-        return config
+        cfgs = read_config(config_path=config_path)
+        # TODO: Handle the case of transforms with a Compose object
+        self._transforms = build_object_from_config(cfgs.transforms_config, TRANSFORMS)
+        # TODO: Add support for random splits
+        self._dataset = build_object_from_config(cfgs.dataset_config, DATASETS)

    def run(self) -> None:
        # Must check if there is a checkpoint directory
4 changes: 4 additions & 0 deletions src/skelcast/core/optimizers.py
@@ -0,0 +1,4 @@
import torch.optim as optim


# Expose every optimizer class defined in torch.optim by name,
# e.g. PYTORCH_OPTIMIZERS['AdamW'] -> torch.optim.AdamW.
PYTORCH_OPTIMIZERS = {
    name: getattr(optim, name)
    for name in dir(optim)
    if isinstance(getattr(optim, name), type) and issubclass(getattr(optim, name), optim.Optimizer)
}
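
A hedged usage sketch for the optimizer table; the module is a placeholder, and the hyperparameters mirror the optimizer section of configs/pvred.yaml:

    import torch.nn as nn
    from skelcast.core.optimizers import PYTORCH_OPTIMIZERS

    model = nn.Linear(75, 75)  # placeholder module
    optimizer = PYTORCH_OPTIMIZERS['AdamW'](model.parameters(), lr=0.0001, weight_decay=0.0001)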
3 changes: 3 additions & 0 deletions src/skelcast/core/registry.py
@@ -39,3 +39,6 @@ def get_module(self, module_name):

    def __str__(self):
        return str(self._module_dict)
+
+    def __contains__(self, module_name):
+        return module_name in self._module_dict
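
With __contains__ in place, membership checks against a registry read naturally; a sketch, assuming the DATASETS registry import path (the dataset name comes from configs/pvred.yaml):

    from skelcast.data import DATASETS  # assumed import path

    if 'NTURGBDDataset' in DATASETS:
        dataset_cls = DATASETS.get_module('NTURGBDDataset')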
5 changes: 3 additions & 2 deletions src/skelcast/experiments/runner.py
Expand Up @@ -73,13 +73,14 @@ def __init__(self,
                 checkpoint_dir: str = None,
                 checkpoint_frequency: int = 1,
                 logger: BaseLogger = None,
-                 log_gradient_info: bool = False) -> None:
+                 log_gradient_info: bool = False,
+                 collate_fn=None) -> None:
        self.train_set = train_set
        self.val_set = val_set
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.block_size = block_size
-        self._collate_fn = NTURGBDCollateFn(block_size=self.block_size, is_packed=True)
+        self._collate_fn = collate_fn if collate_fn is not None else NTURGBDCollateFn(block_size=self.block_size)
        self.train_loader = DataLoader(dataset=self.train_set, batch_size=self.train_batch_size, shuffle=True, collate_fn=self._collate_fn)
        self.val_loader = DataLoader(dataset=self.val_set, batch_size=self.val_batch_size, shuffle=False, collate_fn=self._collate_fn)
        self.model = model
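
The new collate_fn keyword lets callers inject the collate function named in the config instead of the hard-coded default; a sketch in which train_set, val_set, and model are placeholders, and the collate class and its block_size come from configs/pvred.yaml:

    collate = NTURGBDCollateFnWithRandomSampledContextWindow(block_size=25)
    runner = Runner(train_set=train_set, val_set=val_set, model=model,
                    train_batch_size=32, val_batch_size=32,
                    collate_fn=collate)
    runner.setup()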
3 changes: 3 additions & 0 deletions src/skelcast/losses/torch_losses.py
@@ -0,0 +1,3 @@
import torch.nn as nn

# Expose every loss class defined in torch.nn by name,
# e.g. PYTORCH_LOSSES['MSELoss'] -> torch.nn.MSELoss.
PYTORCH_LOSSES = {
    name: getattr(nn, name)
    for name in dir(nn)
    if isinstance(getattr(nn, name), type) and issubclass(getattr(nn, name), nn.Module) and 'Loss' in name
}
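
And the matching sketch for the loss table, mirroring the loss section of configs/pvred.yaml:

    from skelcast.losses.torch_losses import PYTORCH_LOSSES

    criterion = PYTORCH_LOSSES['MSELoss'](reduction='mean')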
