Fix/environment build #48

Merged
16 commits merged on Dec 12, 2023
65 changes: 65 additions & 0 deletions configs/pvred.yaml
@@ -0,0 +1,65 @@
dataset:
name: NTURGBDDataset
args:
data_directory: /home/kaseris/Documents/mount/data_ntu_rgbd
label_file: /home/kaseris/Documents/dev/skelcast/data/labels.txt
missing_files_dir: /home/kaseris/Documents/dev/skelcast/data/missing
max_context_window: 10
max_number_of_bodies: 1
max_duration: 300
n_joints: 25
cache_file: /home/kaseris/Documents/mount/dataset_cache.pkl

transforms:
- name: MinMaxScaleTransform
args:
feature_scale: [0.0, 1.0]
- name: SomeOtherTransform
args:
some_arg: some_value

loss:
name: MSELoss
args:
reduction: mean

collate_fn:
name: NTURGBDCollateFnWithRandomSampledContextWindow
args:
block_size: 25

logger:
name: TensorboardLogger
args:
save_dir: runs

optimizer:
name: AdamW
args:
lr: 0.0001
weight_decay: 0.0001

model:
name: PositionalVelocityRecurrentEncoderDecoder
args:
input_dim: 75
enc_hidden_dim: 64
dec_hidden_dim: 64
enc_type: lstm
dec_type: lstm
include_velocity: false
pos_enc: add
batch_first: true
std_thresh: 0.0001
use_std_mask: false
use_padded_len_mask: true
observe_until: 20

runner:
name: Runner
args:
train_batch_size: 32
val_batch_size: 32
block_size: 8
log_gradient_info: true
device: cuda
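
The sections above map one-to-one onto the Config subclasses introduced in src/skelcast/core/config.py below. A minimal loading sketch, assuming the file is saved as configs/pvred.yaml:

from skelcast.core.config import read_config

# Each top-level YAML key (dataset, loss, model, ...) is parsed into its own
# Config object; EnvironmentConfig exposes them as snake_case attributes
# derived from the class names, e.g. ModelConfig -> model_config.
env_cfg = read_config('configs/pvred.yaml')
print(env_cfg.model_config.get('name'))   # PositionalVelocityRecurrentEncoderDecoder
print(env_cfg.runner_config.get('args'))  # {'train_batch_size': 32, ...}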
143 changes: 143 additions & 0 deletions src/skelcast/core/config.py
@@ -0,0 +1,143 @@
import re
import logging
import yaml

from collections import OrderedDict
from typing import Any, List

from skelcast.core.registry import Registry


class Config:
def __init__(self):
self._config = OrderedDict()
self._config['name'] = None
self._config['args'] = {}

def get(self, key):
return self._config[key]

def set(self, key, value):
if isinstance(value, list):
self._config[key] = []
for v in value:
self._config[key].append(v)
else:
self._config[key] = value

def __str__(self) -> str:
s = self.__class__.__name__ + '(\n'
for key, val in self._config.items():
if isinstance(val, dict):
s += f'\t{key}: \n'
for k, v in val.items():
s += f'\t\t{k}: {v}\n'
s += '\t\n'
else:
s += f'\t{key}: {val}\n'
s += ')'
return s


class ModelConfig(Config):
def __init__(self):
super(ModelConfig, self).__init__()


class DatasetConfig(Config):
def __init__(self):
super(DatasetConfig, self).__init__()


class TransformsConfig(Config):
def __init__(self):
super(TransformsConfig, self).__init__()


class LoggerConfig(Config):
def __init__(self):
super(LoggerConfig, self).__init__()


class OptimizerConfig(Config):
def __init__(self):
super(OptimizerConfig, self).__init__()


class SchedulerConfig(Config):
def __init__(self):
super(SchedulerConfig, self).__init__()


class CriterionConfig(Config):
def __init__(self):
super(CriterionConfig, self).__init__()


class CollateFnConfig(Config):
def __init__(self):
super(CollateFnConfig, self).__init__()


class RunnerConfig(Config):
def __init__(self):
super(RunnerConfig, self).__init__()

class EnvironmentConfig:
def __init__(self, *args) -> None:
for arg in args:
name = arg.__class__.__name__
split_name = re.findall('[A-Z][^A-Z]*', name)
name = '_'.join([s.lower() for s in split_name])
setattr(self, name, arg)

def __str__(self) -> str:
s = self.__class__.__name__ + '(\n'
for key, val in self.__dict__.items():
s += f'\t{key}: {val}\n'
s += ')'
return s


def build_object_from_config(config: Config, registry: Registry, **kwargs):
_name = config.get('name')
_args = config.get('args')
_args.update(kwargs)
return registry.get_module(_name)(**_args)

def summarize_config(configs: List[Config]):
    with open('/home/kaseris/Documents/mount/config.txt', 'w') as f:
for config in configs:
f.write(str(config))
f.write('\n\n')


CONFIG_MAPPING = {
'model': ModelConfig,
'dataset': DatasetConfig,
'transforms': TransformsConfig,
'logger': LoggerConfig,
'optimizer': OptimizerConfig,
'scheduler': SchedulerConfig,
'loss': CriterionConfig,
'collate_fn': CollateFnConfig,
'runner': RunnerConfig
}

def read_config(config_path: str):
with open(config_path, 'r') as f:
data = yaml.safe_load(f)
cfgs = []
for key in data:
config = CONFIG_MAPPING[key]()
        if key == 'transforms':
            # NOTE: each transform is stored under its own name, so the generic
            # 'name'/'args' entries stay empty and build_object_from_config
            # cannot consume a TransformsConfig yet (see the Compose TODO in
            # environment.py).
            for element in data[key]:
                logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.')
                config.set(element['name'], element['args'])
else:
logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.')
config.set('name', data[key]['name'])
config.set('args', data[key]['args'])
logging.debug(config)
cfgs.append(config)
return EnvironmentConfig(*cfgs)
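
A sketch of the round trip through these classes, using only what is defined above. Note that the YAML key loss maps to CriterionConfig, so EnvironmentConfig exposes it as criterion_config:

from skelcast.core.config import CriterionConfig, EnvironmentConfig

loss_cfg = CriterionConfig()
loss_cfg.set('name', 'MSELoss')
loss_cfg.set('args', {'reduction': 'mean'})

# EnvironmentConfig derives attribute names by splitting the class name on
# capitals: CriterionConfig -> 'criterion_config'.
env = EnvironmentConfig(loss_cfg)
print(env.criterion_config.get('name'))  # MSELoss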
72 changes: 8 additions & 64 deletions src/skelcast/core/environment.py
@@ -14,6 +14,7 @@
 from skelcast.logger import LOGGERS
 
 from skelcast.experiments.runner import Runner
+from skelcast.core.config import read_config, build_object_from_config
 
 torch.manual_seed(133742069)
 
@@ -66,71 +67,14 @@ def experiment_name(self) -> str:
         return self._experiment_name
 
     def build_from_file(self, config_path: str) -> None:
-        config = self._parse_file(config_path)
-        self.config = config
         logging.log(logging.INFO, f'Building environment from {config_path}.')
-        self._build_dataset()
-        self._build_model()
-        self._build_logger()
-        self._build_runner()
-
-    def _build_model(self) -> None:
-        logging.log(logging.INFO, 'Building model.')
-        model_config = self.config['model']
-        _name = model_config.get('name')
-        _args = model_config.get('args')
-        self._model = MODELS.get_module(_name)(**_args)
-        logging.log(logging.INFO, f'Model creation complete.')
-
-    def _build_dataset(self) -> None:
-        logging.log(logging.INFO, 'Building dataset.')
-        dataset_config = self.config['dataset']
-        _name = dataset_config.get('name')
-        _args = dataset_config.get('args')
-        _transforms_cfg = dataset_config.get('args').get('transforms')
-        _transforms = TRANSFORMS.get_module(_transforms_cfg.get('name'))(**_transforms_cfg.get('args'))
-        _args['transforms'] = _transforms
-        self._dataset = DATASETS.get_module(_name)(self.data_dir, **_args)
-        # Split the dataset
-        _train_len = int(self.config['train_data_percentage'] * len(self._dataset))
-        self._train_dataset, self._val_dataset = random_split(self._dataset, [_train_len, len(self._dataset) - _train_len])
-        logging.log(logging.INFO, f'Train set size: {len(self._train_dataset)}')
-
-    def _build_logger(self) -> None:
-        logging.log(logging.INFO, 'Building logger.')
-        logger_config = self.config['runner']['args'].get('logger')
-        logdir = os.path.join(logger_config['args']['save_dir'], self.experiment_name)
-        self._logger = LOGGERS.get_module(logger_config['name'])(logdir)
-        logging.log(logging.INFO, f'Logging to {logdir}.')
-
-    def _build_runner(self) -> None:
-        logging.log(logging.INFO, 'Building runner.')
-        runner_config = self.config['runner']
-        _args = runner_config.get('args')
-        _args['logger'] = self._logger
-        _args['optimizer'] = optim.AdamW(self._model.parameters(), lr=_args.get('lr'))
-        _args['train_set'] = self._train_dataset
-        _args['val_set'] = self._val_dataset
-        _args['model'] = self._model
-        _args['train_set'] = self._train_dataset
-        _args['val_set'] = self._val_dataset
-        _args['checkpoint_dir'] = os.path.join(self.checkpoint_dir, self._experiment_name)
-        self._create_checkpoint_dir()
-        self._runner = Runner(**_args)
-        self._runner.setup()
-        logging.log(logging.INFO, 'Runner setup complete.')
-
-    def _create_checkpoint_dir(self) -> None:
-        if os.path.exists(os.path.join(self.checkpoint_dir, self._experiment_name)):
-            raise ValueError(f'Checkpoint directory {os.path.join(self.checkpoint_dir, self._experiment_name)} already exists.')
-        else:
-            logging.log(logging.INFO, f'Creating checkpoint directory: {os.path.join(self.checkpoint_dir, self._experiment_name)}.')
-            os.mkdir(os.path.join(self.checkpoint_dir, self._experiment_name))
-
-    def _parse_file(self, fname: str) -> None:
-        with open(fname, 'r') as f:
-            config = yaml.safe_load(f)
-        return config
+        cfgs = read_config(config_path=config_path)
+        # TODO: Handle the case of transforms with a Compose object
+        self._transforms = build_object_from_config(cfgs.transforms_config, TRANSFORMS)
+        # TODO: Add support for random splits
+        self._dataset = build_object_from_config(cfgs.dataset_config, DATASETS)
+
+
 
     def run(self) -> None:
         # Must check if there is a checkpoint directory
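The rewritten build_from_file delegates all parsing to read_config. A usage sketch; the enclosing class name (here Environment) and its constructor arguments fall outside the visible hunk and are assumed for illustration:

env = Environment()                        # hypothetical class name
env.build_from_file('configs/pvred.yaml')  # builds self._transforms and self._dataset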
4 changes: 4 additions & 0 deletions src/skelcast/core/optimizers.py
@@ -0,0 +1,4 @@
import torch.optim as optim


PYTORCH_OPTIMIZERS = {name: getattr(optim, name) for name in dir(optim) if isinstance(getattr(optim, name), type) and issubclass(getattr(optim, name), optim.Optimizer)}
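
This mirrors torch.optim by name, so building an optimizer from the YAML optimizer section reduces to a dict lookup. A sketch:

import torch.nn as nn
from skelcast.core.optimizers import PYTORCH_OPTIMIZERS

model = nn.Linear(75, 75)  # stand-in model for the sketch
optimizer = PYTORCH_OPTIMIZERS['AdamW'](model.parameters(), lr=0.0001, weight_decay=0.0001)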
3 changes: 3 additions & 0 deletions src/skelcast/core/registry.py
@@ -39,3 +39,6 @@ def get_module(self, module_name):
 
     def __str__(self):
         return str(self._module_dict)
+
+    def __contains__(self, module_name):
+        return module_name in self._module_dict
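
The new __contains__ lets callers test membership before a get_module call; for example (the registry import path is assumed, following the skelcast.logger import seen in environment.py):

from skelcast.models import MODELS  # import path assumed

name = 'PositionalVelocityRecurrentEncoderDecoder'
if name in MODELS:  # dispatches to the new __contains__
    model_cls = MODELS.get_module(name)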
5 changes: 3 additions & 2 deletions src/skelcast/experiments/runner.py
@@ -73,13 +73,14 @@ def __init__(self,
                  checkpoint_dir: str = None,
                  checkpoint_frequency: int = 1,
                  logger: BaseLogger = None,
-                 log_gradient_info: bool = False) -> None:
+                 log_gradient_info: bool = False,
+                 collate_fn = None) -> None:
         self.train_set = train_set
         self.val_set = val_set
         self.train_batch_size = train_batch_size
         self.val_batch_size = val_batch_size
         self.block_size = block_size
-        self._collate_fn = NTURGBDCollateFn(block_size=self.block_size, is_packed=True)
+        self._collate_fn = collate_fn if collate_fn is not None else NTURGBDCollateFn(block_size=self.block_size)
         self.train_loader = DataLoader(dataset=self.train_set, batch_size=self.train_batch_size, shuffle=True, collate_fn=self._collate_fn)
         self.val_loader = DataLoader(dataset=self.val_set, batch_size=self.val_batch_size, shuffle=False, collate_fn=self._collate_fn)
         self.model = model
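With the new collate_fn parameter, callers can inject the collate function named in the config, here the NTURGBDCollateFnWithRandomSampledContextWindow from pvred.yaml. A sketch; its import path is assumed and the remaining Runner arguments (datasets, model, ...) are defined elsewhere:

collate = NTURGBDCollateFnWithRandomSampledContextWindow(block_size=25)
runner = Runner(train_set=train_set, val_set=val_set, model=model,
                train_batch_size=32, val_batch_size=32, block_size=8,
                collate_fn=collate)
# If collate_fn is omitted, Runner falls back to NTURGBDCollateFn(block_size=8).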
3 changes: 3 additions & 0 deletions src/skelcast/losses/torch_losses.py
@@ -0,0 +1,3 @@
import torch.nn as nn

PYTORCH_LOSSES = {name: getattr(nn, name) for name in dir(nn) if isinstance(getattr(nn, name), type) and issubclass(getattr(nn, name), nn.Module) and 'Loss' in name}
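
This gives criteria the same name-based lookup as optimizers; the loss section of pvred.yaml then resolves as:

from skelcast.losses.torch_losses import PYTORCH_LOSSES

criterion = PYTORCH_LOSSES['MSELoss'](reduction='mean')  # torch.nn.MSELoss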