Skip to content

Commit

Permalink
Merge pull request #49 from kaseris/compose-transforms
Browse files Browse the repository at this point in the history
Compose transforms
  • Loading branch information
kaseris committed Dec 12, 2023
2 parents 699dbd2 + 0457152 commit 4a48735
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 25 deletions.
7 changes: 4 additions & 3 deletions configs/pvred.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ transforms:
- name: MinMaxScaleTransform
args:
feature_scale: [0.0, 1.0]
- name: SomeOtherTransform
- name: CartToExpMapsTransform
args:
some_arg: some_value
parents: null

loss:
name: MSELoss
Expand All @@ -31,7 +31,7 @@ collate_fn:
logger:
name: TensorboardLogger
args:
save_dir: runs
log_dir: runs

optimizer:
name: AdamW
Expand Down Expand Up @@ -63,3 +63,4 @@ runner:
block_size: 8
log_gradient_info: true
device: cuda
n_epochs: 100
52 changes: 40 additions & 12 deletions src/skelcast/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import yaml

from collections import OrderedDict
from typing import Any, List
from typing import Any, List, Union

from skelcast.core.registry import Registry
from skelcast.data.transforms import Compose


class Config:
Expand Down Expand Up @@ -50,8 +51,23 @@ def __init__(self):


class TransformsConfig(Config):
    """Configuration holder for a list of transform specifications.

    Stores the normalized transform list under the 'args' key so that
    `build_object_from_config` can later instantiate each transform and
    chain them in a `Compose`.
    """

    def __init__(self, transforms):
        # `transforms` is the raw list parsed from the YAML `transforms:`
        # section; each entry is a dict with a 'name' and optional 'args'.
        super().__init__()
        self.set('args', self.parse_transforms(transforms))

    def parse_transforms(self, transforms):
        """Normalize each raw entry to ``{'name': ..., 'args': ...}``.

        A missing 'args' key defaults to an empty dict so downstream code
        can always splat it as keyword arguments.
        """
        return [
            {'name': spec.get('name'), 'args': spec.get('args', {})}
            for spec in transforms
        ]

    def get(self, key):
        # 'args' is read straight from the underlying config dict; every
        # other key falls back to the base-class lookup.
        if key == 'args':
            return self._config['args']
        return super().get(key)


class LoggerConfig(Config):
Expand Down Expand Up @@ -99,11 +115,21 @@ def __str__(self) -> str:
return s


def build_object_from_config(config: Config, registry: Union[Registry, dict], **kwargs):
    """Instantiate the object described by ``config`` out of ``registry``.

    For a ``TransformsConfig``, every listed transform is built and the
    results are wrapped in a ``Compose`` (in declaration order). For any
    other config, the single module named by the config is looked up —
    either in a ``Registry`` or a plain dict — and constructed with the
    config args merged with ``kwargs``.

    Args:
        config: The parsed configuration object.
        registry: A ``Registry`` instance or a plain name-to-class dict.
        **kwargs: Extra constructor arguments merged over the config args.

    Returns:
        The constructed object (a ``Compose`` for transform configs).
    """
    if isinstance(config, TransformsConfig):
        transforms = []
        for spec in config.get('args'):
            logging.debug(spec)
            transform_cls = registry.get_module(spec.get('name'))
            transforms.append(transform_cls(**spec.get('args')))
        return Compose(transforms)
    name = config.get('name')
    args = config.get('args')
    # NOTE(review): this mutates the args dict stored inside `config`.
    args.update(kwargs)
    if isinstance(registry, dict):
        return registry[name](**args)
    return registry.get_module(name)(**args)

def summarize_config(configs: List[Config]):
with open(f'/home/kaseris/Documents/mount/config.txt', 'w') as f:
Expand All @@ -129,15 +155,17 @@ def read_config(config_path: str):
data = yaml.safe_load(f)
cfgs = []
for key in data:
config = CONFIG_MAPPING[key]()
if key == 'transforms':
for element in data[key]:
logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.')
config.set(element['name'], element['args'])
# Initialize TransformsConfig with the transforms data
config = CONFIG_MAPPING[key](data[key])
logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.')
else:
# Initialize other configurations normally
config = CONFIG_MAPPING[key]()
logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.')
config.set('name', data[key]['name'])
config.set('args', data[key]['args'])

logging.debug(config)
cfgs.append(config)
return EnvironmentConfig(*cfgs)
57 changes: 49 additions & 8 deletions src/skelcast/core/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@

from skelcast.models import MODELS
from skelcast.data import DATASETS
from skelcast.data import TRANSFORMS
from skelcast.data import TRANSFORMS, COLLATE_FUNCS
from skelcast.logger import LOGGERS
from skelcast.losses import LOSSES
from skelcast.losses.torch_losses import PYTORCH_LOSSES
from skelcast.core.optimizers import PYTORCH_OPTIMIZERS

from skelcast.experiments.runner import Runner
from skelcast.core.config import read_config, build_object_from_config
Expand Down Expand Up @@ -50,17 +53,25 @@ class Environment:
and configurations are properly set up before using this class.
"""
def __init__(self, data_dir: str = '/home/kaseris/Documents/data_ntu_rbgd',
             checkpoint_dir = '/home/kaseris/Documents/mount/checkpoints_forecasting',
             train_set_size = 0.8) -> None:
    """Create the per-experiment workspace and empty component slots.

    Args:
        data_dir: Root directory of the dataset on disk.
        checkpoint_dir: Parent directory under which a per-experiment
            checkpoint folder (named after the random experiment name)
            is created.
        train_set_size: Fraction of the dataset used for training; the
            remainder becomes the validation split.
    """
    self._experiment_name = randomname.get_name()
    self.checkpoint_dir = os.path.join(checkpoint_dir, self._experiment_name)
    # os.mkdir raises if the parent is missing or the directory already
    # exists; presumably the random experiment name avoids collisions.
    os.mkdir(self.checkpoint_dir)
    logging.info(f'Created checkpoint directory at {self.checkpoint_dir}')
    self.data_dir = data_dir
    self.train_set_size = train_set_size
    # All components are populated later by build_from_file().
    self.config = None
    self._model = None
    self._dataset = None
    self._train_dataset = None
    self._val_dataset = None
    self._runner = None
    self._logger = None
    self._loss = None
    self._optimizer = None
    self._collate_fn = None


@property
def experiment_name(self) -> str:
Expand All @@ -69,12 +80,42 @@ def experiment_name(self) -> str:
def build_from_file(self, config_path: str) -> None:
    """Build every training component from a YAML configuration file.

    Construction order matters: transforms feed the dataset, the loss
    feeds the model, and the model's parameters feed the optimizer.

    Args:
        config_path: Path to the YAML configuration file.
    """
    logging.log(logging.INFO, f'Building environment from {config_path}.')
    cfgs = read_config(config_path=config_path)
    # Build transforms first, because they are used in the dataset
    self._transforms = build_object_from_config(cfgs.transforms_config, TRANSFORMS)
    # TODO: Add support for random splits. Maybe as external parameter?
    self._dataset = build_object_from_config(cfgs.dataset_config, DATASETS, transforms=self._transforms)
    logging.info(f'Loaded dataset from {self.data_dir}.')
    # Build the loss first, because it is used in the model
    loss_registry = LOSSES if cfgs.criterion_config.get('name') not in PYTORCH_LOSSES else PYTORCH_LOSSES
    self._loss = build_object_from_config(cfgs.criterion_config, loss_registry)
    logging.info(f'Loaded loss function {cfgs.criterion_config.get("name")}.')
    self._model = build_object_from_config(cfgs.model_config, MODELS, loss_fn=self._loss)
    logging.info(f'Loaded model {cfgs.model_config.get("name")}.')
    # Build the optimizer
    self._optimizer = build_object_from_config(cfgs.optimizer_config, PYTORCH_OPTIMIZERS, params=self._model.parameters())
    logging.info(f'Loaded optimizer {cfgs.optimizer_config.get("name")}.')
    # Build the logger; nest its log dir under the experiment name so
    # each run gets its own directory.
    cfgs.logger_config.get('args').update({'log_dir': os.path.join(cfgs.logger_config.get('args').get('log_dir'), self._experiment_name)})
    self._logger = build_object_from_config(cfgs.logger_config, LOGGERS)
    logging.info(f'Created runs directory at {cfgs.logger_config.get("args").get("log_dir")}')
    # Build the collate_fn
    self._collate_fn = build_object_from_config(cfgs.collate_fn_config, COLLATE_FUNCS)
    logging.info(f'Loaded collate function {cfgs.collate_fn_config.get("name")}.')
    # Split the dataset into training and validation sets
    train_size = int(self.train_set_size * len(self._dataset))
    val_size = len(self._dataset) - train_size
    self._train_dataset, self._val_dataset = random_split(self._dataset, [train_size, val_size])
    # Build the runner
    self._runner = Runner(model=self._model,
                          optimizer=self._optimizer,
                          logger=self._logger,
                          collate_fn=self._collate_fn,
                          train_set=self._train_dataset,
                          val_set=self._val_dataset,
                          checkpoint_dir=self.checkpoint_dir,
                          **cfgs.runner_config.get('args'))
    logging.info(f'Finished building environment from {config_path}.')


def run(self) -> None:
# Must check if there is a checkpoint directory
Expand Down
14 changes: 12 additions & 2 deletions src/skelcast/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,20 @@ class CartToQuaternionTransform:

def __init__(self, parents: list = None) -> None:
    """Store the skeleton parent scheme used for the Cartesian-to-
    quaternion conversion.

    Args:
        parents: Parent index per joint; defaults to the Kinect skeleton
            hierarchy when None. (None default avoids a mutable default
            argument.)
    """
    if parents is None:
        self.parents = KinectSkeleton.parent_scheme()
    else:
        self.parents = parents

def __call__(self, x) -> Any:
    """Convert Cartesian joint coordinates to quaternions.

    First maps xyz positions to exponential maps using the stored parent
    scheme, then converts those rotations to quaternions.
    """
    _exps = xyz_to_expmap(x, self.parents)
    return exps_to_quats(_exps)

class Compose:
    """Chain a sequence of transforms into a single callable.

    Each transform is applied in order; the output of one becomes the
    input of the next.
    """

    def __init__(self, transforms: list) -> None:
        # Transforms are applied in the order they appear in this list.
        self.transforms = transforms

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every stored transform to `x`, in order, and return the result."""
        result = x
        for transform in self.transforms:
            result = transform(result)
        return result

1 change: 1 addition & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
args.add_argument('--config', type=str, default='../configs/lstm_regressor_1024x1024.yaml')
args.add_argument('--data_dir', type=str, default='data')
args.add_argument('--checkpoint_dir', type=str, default='checkpoints')
args.add_argument('--train_set_size', type=float, default=0.8, required=False)

args = args.parse_args()

Expand Down

0 comments on commit 4a48735

Please sign in to comment.