From 9b69cc19d0d56952f21345b05847c1205ce777dc Mon Sep 17 00:00:00 2001 From: kaseris Date: Sun, 10 Dec 2023 16:19:01 +0200 Subject: [PATCH 01/15] Base config layers --- src/skelcast/core/config.py | 140 ++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 src/skelcast/core/config.py diff --git a/src/skelcast/core/config.py b/src/skelcast/core/config.py new file mode 100644 index 0000000..9eb2920 --- /dev/null +++ b/src/skelcast/core/config.py @@ -0,0 +1,140 @@ +import abc + +from skelcast.core.registry import Registry + + +class Config(metaclass=abc.ABCMeta): + def __init__(self): + self._config = {} + + @abc.abstractmethod + def get(self, key): + pass + + @abc.abstractmethod + def set(self, key, value): + pass + + +class ModelConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class DatasetConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class TransformsConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class LoggerConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class OptimizerConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class SchedulerConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class CriterionConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class CollateFnConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class RunnerConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +def build_object_from_config(config: Config, registry: Registry, **kwargs): + _name = config.get('name') + _args = config.get('args') + _args.update(kwargs) + return registry[_name](**_args) From 8dc4beed1a81a305124d7929e96c40b7cdae137c Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Sun, 10 Dec 2023 16:23:59 +0200 Subject: [PATCH 02/15] Base config layers --- src/skelcast/core/config.py | 140 ++++++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 src/skelcast/core/config.py diff --git a/src/skelcast/core/config.py b/src/skelcast/core/config.py new file mode 100644 index 0000000..9eb2920 --- /dev/null +++ b/src/skelcast/core/config.py @@ -0,0 +1,140 @@ +import abc + +from skelcast.core.registry import Registry + + +class Config(metaclass=abc.ABCMeta): + def __init__(self): + self._config = {} + + @abc.abstractmethod + def get(self, key): + pass + + @abc.abstractmethod + def set(self, key, value): + pass + + +class ModelConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class DatasetConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class TransformsConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class LoggerConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class OptimizerConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class SchedulerConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class CriterionConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class CollateFnConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +class RunnerConfig(Config): + def __init__(self): + super().__init__() + self._config['name'] = None + self._config['args'] = {} + + def get(self, key): + return self._config[key] + + def set(self, key, value): + self._config[key] = value + + +def build_object_from_config(config: Config, registry: Registry, **kwargs): + _name = config.get('name') + _args = config.get('args') + _args.update(kwargs) + return registry[_name](**_args) From 51d3ef3b0f7ff0b44383dc38eb954dd811d8e3c6 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Sun, 10 Dec 2023 16:26:47 +0200 Subject: [PATCH 03/15] Summarize config --- src/skelcast/core/config.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/skelcast/core/config.py b/src/skelcast/core/config.py index 9eb2920..536f493 100644 --- a/src/skelcast/core/config.py +++ b/src/skelcast/core/config.py @@ -138,3 +138,8 @@ def build_object_from_config(config: Config, registry: Registry, **kwargs): _args = config.get('args') _args.update(kwargs) return registry[_name](**_args) + +def summarize_config(config: Config): + _name = config.get('name') + _args = config.get('args') + return f'{_name}({_args})' From 707ae6bd377e20c49c1a6771054842694f166aa7 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Sun, 10 Dec 2023 16:49:30 +0200 Subject: [PATCH 04/15] Extract pytorch losses that are not implemented in skelcast --- src/skelcast/losses/__init__.py | 3 ++- src/skelcast/losses/torch_losses.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 src/skelcast/losses/torch_losses.py diff --git a/src/skelcast/losses/__init__.py b/src/skelcast/losses/__init__.py index 962aeb8..8e4610c 100644 --- a/src/skelcast/losses/__init__.py +++ b/src/skelcast/losses/__init__.py @@ -2,4 +2,5 @@ LOSSES = Registry() -from .logloss import LogLoss \ No newline at end of file +from .logloss import LogLoss +from .torch_losses import PYTORCH_LOSSES \ No newline at end of file diff --git a/src/skelcast/losses/torch_losses.py b/src/skelcast/losses/torch_losses.py new file mode 100644 index 0000000..d899da9 --- /dev/null +++ b/src/skelcast/losses/torch_losses.py @@ -0,0 +1,3 @@ +import torch.nn as nn + +PYTORCH_LOSSES = {name: getattr(nn, name) for name in dir(nn) if isinstance(getattr(nn, name), type) and issubclass(getattr(nn, name), nn.Module) and 'Loss' in name} From d60a01a6274e4b521e1bf8c31609fa4f12ffd1ba Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Sun, 10 Dec 2023 16:54:37 +0200 Subject: [PATCH 05/15] Fix circular imports --- src/skelcast/losses/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/skelcast/losses/__init__.py b/src/skelcast/losses/__init__.py index 8e4610c..962aeb8 100644 --- a/src/skelcast/losses/__init__.py +++ b/src/skelcast/losses/__init__.py @@ -2,5 +2,4 @@ LOSSES = Registry() -from .logloss import LogLoss -from .torch_losses import PYTORCH_LOSSES \ No newline at end of file +from .logloss import LogLoss \ No newline at end of file From a00b1e7879f415091058173b89a231544a41a28e Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Sun, 10 Dec 2023 16:58:00 +0200 Subject: [PATCH 06/15] Support `in` keyword --- src/skelcast/core/registry.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/skelcast/core/registry.py b/src/skelcast/core/registry.py index 5905d02..bd9e017 100644 --- a/src/skelcast/core/registry.py +++ b/src/skelcast/core/registry.py @@ -39,3 +39,6 @@ def get_module(self, module_name): def __str__(self): return str(self._module_dict) + + def __contains__(self, module_name): + return module_name in self._module_dict \ No newline at end of file From bdbcbec9eb1809d724c385b728c21feb16d0217d Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Mon, 11 Dec 2023 13:36:22 +0200 Subject: [PATCH 07/15] Get rid of redundant lines, read configs from yaml keys --- src/skelcast/core/config.py | 155 +++++++++++++++--------------------- 1 file changed, 66 insertions(+), 89 deletions(-) diff --git a/src/skelcast/core/config.py b/src/skelcast/core/config.py index 536f493..ec90a24 100644 --- a/src/skelcast/core/config.py +++ b/src/skelcast/core/config.py @@ -1,24 +1,15 @@ import abc +import logging +import yaml + +from typing import List from skelcast.core.registry import Registry -class Config(metaclass=abc.ABCMeta): +class Config: def __init__(self): self._config = {} - - @abc.abstractmethod - def get(self, key): - pass - - @abc.abstractmethod - def set(self, key, value): - pass - - -class ModelConfig(Config): - def __init__(self): - super().__init__() self._config['name'] = None self._config['args'] = {} @@ -28,109 +19,63 @@ def get(self, key): def set(self, key, value): self._config[key] = value + def __str__(self) -> str: + s = self.__class__.__name__ + '(\n' + for key, val in self._config.items(): + if isinstance(val, dict): + s += f'\t{key}: \n' + for k, v in val.items(): + s += f'\t\t{k}: {v}\n' + s += '\t\n' + else: + s += f'\t{key}: {val}\n' + s += ')' + return s + -class DatasetConfig(Config): +class ModelConfig(Config): def __init__(self): - super().__init__() - self._config['name'] = None - self._config['args'] = {} + super(ModelConfig, self).__init__() - def get(self, key): - return self._config[key] - def set(self, key, value): - self._config[key] = value +class DatasetConfig(Config): + def __init__(self): + super(DatasetConfig, self).__init__() class TransformsConfig(Config): def __init__(self): - super().__init__() - self._config['name'] = None - self._config['args'] = {} - - def get(self, key): - return self._config[key] - - def set(self, key, value): - self._config[key] = value + super(TransformsConfig, self).__init__() class LoggerConfig(Config): def __init__(self): - super().__init__() - self._config['name'] = None - self._config['args'] = {} - - def get(self, key): - return self._config[key] - - def set(self, key, value): - self._config[key] = value + super(LoggerConfig, self).__init__() class OptimizerConfig(Config): def __init__(self): - super().__init__() - self._config['name'] = None - self._config['args'] = {} - - def get(self, key): - return self._config[key] - - def set(self, key, value): - self._config[key] = value + super(OptimizerConfig, self).__init__() class SchedulerConfig(Config): def __init__(self): - super().__init__() - self._config['name'] = None - self._config['args'] = {} - - def get(self, key): - return self._config[key] - - def set(self, key, value): - self._config[key] = value + super(SchedulerConfig, self).__init__() class CriterionConfig(Config): def __init__(self): - super().__init__() - self._config['name'] = None - self._config['args'] = {} - - def get(self, key): - return self._config[key] - - def set(self, key, value): - self._config[key] = value + super(CriterionConfig, self).__init__() class CollateFnConfig(Config): def __init__(self): - super().__init__() - self._config['name'] = None - self._config['args'] = {} - - def get(self, key): - return self._config[key] - - def set(self, key, value): - self._config[key] = value + super(CollateFnConfig, self).__init__() class RunnerConfig(Config): def __init__(self): - super().__init__() - self._config['name'] = None - self._config['args'] = {} - - def get(self, key): - return self._config[key] - - def set(self, key, value): - self._config[key] = value + super(RunnerConfig, self).__init__() def build_object_from_config(config: Config, registry: Registry, **kwargs): @@ -139,7 +84,39 @@ def build_object_from_config(config: Config, registry: Registry, **kwargs): _args.update(kwargs) return registry[_name](**_args) -def summarize_config(config: Config): - _name = config.get('name') - _args = config.get('args') - return f'{_name}({_args})' +def summarize_config(configs: List[Config]): + with open(f'/home/kaseris/Documents/mount/config.txt', 'w') as f: + for config in configs: + f.write(str(config)) + f.write('\n\n') + + +CONFIG_MAPPING = { + 'model': ModelConfig, + 'dataset': DatasetConfig, + 'transforms': TransformsConfig, + 'logger': LoggerConfig, + 'optimizer': OptimizerConfig, + 'scheduler': SchedulerConfig, + 'loss': CriterionConfig, + 'collate_fn': CollateFnConfig, + 'runner': RunnerConfig +} + +def read_config(config_path: str): + with open(config_path, 'r') as f: + data = yaml.safe_load(f) + cfgs = [] + for key in data: + config = CONFIG_MAPPING[key]() + logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.') + config.set('name', data[key]['name']) + config.set('args', data[key]['args']) + logging.debug(config) + cfgs.append(config) + return cfgs + +if __name__ == '__main__': + log_format = '[%(asctime)s] %(levelname)s %(message)s' + logging.basicConfig(level=logging.DEBUG, format=log_format) + cfgs = read_config('configs/pvred.yaml') From cb01dd21a9a7ac2131a9f02558ae1d737e3925d1 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Mon, 11 Dec 2023 13:37:18 +0200 Subject: [PATCH 08/15] Cleanup --- src/skelcast/core/config.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/skelcast/core/config.py b/src/skelcast/core/config.py index ec90a24..ea128ed 100644 --- a/src/skelcast/core/config.py +++ b/src/skelcast/core/config.py @@ -115,8 +115,3 @@ def read_config(config_path: str): logging.debug(config) cfgs.append(config) return cfgs - -if __name__ == '__main__': - log_format = '[%(asctime)s] %(levelname)s %(message)s' - logging.basicConfig(level=logging.DEBUG, format=log_format) - cfgs = read_config('configs/pvred.yaml') From fc0e48b94d5bedd6d8f8b37bd76f299d611ddd43 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Mon, 11 Dec 2023 14:21:04 +0200 Subject: [PATCH 09/15] Added pytorch optimizers shortcut --- src/skelcast/core/optimizers.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 src/skelcast/core/optimizers.py diff --git a/src/skelcast/core/optimizers.py b/src/skelcast/core/optimizers.py new file mode 100644 index 0000000..9ed4cb1 --- /dev/null +++ b/src/skelcast/core/optimizers.py @@ -0,0 +1,4 @@ +import torch.optim as optim + + +PYTORCH_OPTIMIZERS = {name: getattr(optim, name) for name in dir(optim) if isinstance(getattr(optim, name), type) and issubclass(getattr(optim, name), optim.Optimizer)} From 945068d060b31152cc71c9776e0526999505fb24 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Mon, 11 Dec 2023 14:48:21 +0200 Subject: [PATCH 10/15] Workaround for the transforms config --- src/skelcast/core/config.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/skelcast/core/config.py b/src/skelcast/core/config.py index ea128ed..7f6134c 100644 --- a/src/skelcast/core/config.py +++ b/src/skelcast/core/config.py @@ -2,6 +2,7 @@ import logging import yaml +from collections import OrderedDict from typing import List from skelcast.core.registry import Registry @@ -9,7 +10,7 @@ class Config: def __init__(self): - self._config = {} + self._config = OrderedDict() self._config['name'] = None self._config['args'] = {} @@ -17,7 +18,12 @@ def get(self, key): return self._config[key] def set(self, key, value): - self._config[key] = value + if isinstance(value, list): + self._config[key] = [] + for v in value: + self._config[key].append(v) + else: + self._config[key] = value def __str__(self) -> str: s = self.__class__.__name__ + '(\n' @@ -109,9 +115,14 @@ def read_config(config_path: str): cfgs = [] for key in data: config = CONFIG_MAPPING[key]() - logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.') - config.set('name', data[key]['name']) - config.set('args', data[key]['args']) + if key == 'transforms': + for element in data[key]: + logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.') + config.set(element['name'], element['args']) + else: + logging.debug(f'Loading {key} config. Building {config.__class__.__name__} object.') + config.set('name', data[key]['name']) + config.set('args', data[key]['args']) logging.debug(config) cfgs.append(config) return cfgs From f7b2718182dfff1dc6426bd5d8bcf100d8773c1d Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Mon, 11 Dec 2023 15:51:58 +0200 Subject: [PATCH 11/15] Pass collate fn as a parameter --- src/skelcast/experiments/runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/skelcast/experiments/runner.py b/src/skelcast/experiments/runner.py index 60c43fe..a70c1a9 100644 --- a/src/skelcast/experiments/runner.py +++ b/src/skelcast/experiments/runner.py @@ -73,13 +73,14 @@ def __init__(self, checkpoint_dir: str = None, checkpoint_frequency: int = 1, logger: BaseLogger = None, - log_gradient_info: bool = False) -> None: + log_gradient_info: bool = False, + collate_fn = None) -> None: self.train_set = train_set self.val_set = val_set self.train_batch_size = train_batch_size self.val_batch_size = val_batch_size self.block_size = block_size - self._collate_fn = NTURGBDCollateFn(block_size=self.block_size, is_packed=True) + self._collate_fn = collate_fn if collate_fn is not None else NTURGBDCollateFn(block_size=self.block_size) self.train_loader = DataLoader(dataset=self.train_set, batch_size=self.train_batch_size, shuffle=True, collate_fn=self._collate_fn) self.val_loader = DataLoader(dataset=self.val_set, batch_size=self.val_batch_size, shuffle=False, collate_fn=self._collate_fn) self.model = model From 18bae031fc7d858e6136e06290701188e2e055da Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Mon, 11 Dec 2023 16:14:03 +0200 Subject: [PATCH 12/15] Build from config small fix --- src/skelcast/core/config.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/skelcast/core/config.py b/src/skelcast/core/config.py index 7f6134c..71ae6f3 100644 --- a/src/skelcast/core/config.py +++ b/src/skelcast/core/config.py @@ -1,9 +1,9 @@ -import abc +import re import logging import yaml from collections import OrderedDict -from typing import List +from typing import Any, List from skelcast.core.registry import Registry @@ -83,12 +83,27 @@ class RunnerConfig(Config): def __init__(self): super(RunnerConfig, self).__init__() +class EnvironmentConfig: + def __init__(self, *args) -> None: + for arg in args: + name = arg.__class__.__name__ + split_name = re.findall('[A-Z][^A-Z]*', name) + name = '_'.join([s.lower() for s in split_name]) + setattr(self, name, arg) + + def __str__(self) -> str: + s = self.__class__.__name__ + '(\n' + for key, val in self.__dict__.items(): + s += f'\t{key}: {val}\n' + s += ')' + return s + def build_object_from_config(config: Config, registry: Registry, **kwargs): _name = config.get('name') _args = config.get('args') _args.update(kwargs) - return registry[_name](**_args) + return registry.get_module(_name)(**_args) def summarize_config(configs: List[Config]): with open(f'/home/kaseris/Documents/mount/config.txt', 'w') as f: @@ -125,4 +140,4 @@ def read_config(config_path: str): config.set('args', data[key]['args']) logging.debug(config) cfgs.append(config) - return cfgs + return EnvironmentConfig(*cfgs) From ab977d355a512b956b0404e1133daf8e619f8eed Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Mon, 11 Dec 2023 16:14:27 +0200 Subject: [PATCH 13/15] Building steps for the experiment. --- src/skelcast/core/environment.py | 70 +++----------------------------- 1 file changed, 6 insertions(+), 64 deletions(-) diff --git a/src/skelcast/core/environment.py b/src/skelcast/core/environment.py index c3194d3..c3d74e0 100644 --- a/src/skelcast/core/environment.py +++ b/src/skelcast/core/environment.py @@ -14,6 +14,7 @@ from skelcast.logger import LOGGERS from skelcast.experiments.runner import Runner +from skelcast.core.config import read_config, build_object_from_config torch.manual_seed(133742069) @@ -66,71 +67,12 @@ def experiment_name(self) -> str: return self._experiment_name def build_from_file(self, config_path: str) -> None: - config = self._parse_file(config_path) - self.config = config logging.log(logging.INFO, f'Building environment from {config_path}.') - self._build_dataset() - self._build_model() - self._build_logger() - self._build_runner() - - def _build_model(self) -> None: - logging.log(logging.INFO, 'Building model.') - model_config = self.config['model'] - _name = model_config.get('name') - _args = model_config.get('args') - self._model = MODELS.get_module(_name)(**_args) - logging.log(logging.INFO, f'Model creation complete.') - - def _build_dataset(self) -> None: - logging.log(logging.INFO, 'Building dataset.') - dataset_config = self.config['dataset'] - _name = dataset_config.get('name') - _args = dataset_config.get('args') - _transforms_cfg = dataset_config.get('args').get('transforms') - _transforms = TRANSFORMS.get_module(_transforms_cfg.get('name'))(**_transforms_cfg.get('args')) - _args['transforms'] = _transforms - self._dataset = DATASETS.get_module(_name)(self.data_dir, **_args) - # Split the dataset - _train_len = int(self.config['train_data_percentage'] * len(self._dataset)) - self._train_dataset, self._val_dataset = random_split(self._dataset, [_train_len, len(self._dataset) - _train_len]) - logging.log(logging.INFO, f'Train set size: {len(self._train_dataset)}') - - def _build_logger(self) -> None: - logging.log(logging.INFO, 'Building logger.') - logger_config = self.config['runner']['args'].get('logger') - logdir = os.path.join(logger_config['args']['save_dir'], self.experiment_name) - self._logger = LOGGERS.get_module(logger_config['name'])(logdir) - logging.log(logging.INFO, f'Logging to {logdir}.') - - def _build_runner(self) -> None: - logging.log(logging.INFO, 'Building runner.') - runner_config = self.config['runner'] - _args = runner_config.get('args') - _args['logger'] = self._logger - _args['optimizer'] = optim.AdamW(self._model.parameters(), lr=_args.get('lr')) - _args['train_set'] = self._train_dataset - _args['val_set'] = self._val_dataset - _args['model'] = self._model - _args['train_set'] = self._train_dataset - _args['val_set'] = self._val_dataset - _args['checkpoint_dir'] = os.path.join(self.checkpoint_dir, self._experiment_name) - self._create_checkpoint_dir() - self._runner = Runner(**_args) - self._runner.setup() - logging.log(logging.INFO, 'Runner setup complete.') - - def _create_checkpoint_dir(self) -> None: - if os.path.exists(os.path.join(self.checkpoint_dir, self._experiment_name)): - raise ValueError(f'Checkpoint directory {os.path.join(self.checkpoint_dir, self._experiment_name)} already exists.') - else: - logging.log(logging.INFO, f'Creating checkpoint directory: {os.path.join(self.checkpoint_dir, self._experiment_name)}.') - os.mkdir(os.path.join(self.checkpoint_dir, self._experiment_name)) - - def _parse_file(self, fname: str) -> None: - with open(fname, 'r') as f: - config = yaml.safe_load(f) - return config + cfgs = read_config(config_path=config_path) + self._dataset = build_object_from_config(cfgs.dataset_config, DATASETS) + # TODO: Add support for random splits + # TODO: Handle the case of transforms with a Compose object + self._transforms = build_object_from_config(cfgs.transforms_config, TRANSFORMS) def run(self) -> None: # Must check if there is a checkpoint directory From 2cd1014d747bdc79403102715aaabc25f822c8c9 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Mon, 11 Dec 2023 16:14:51 +0200 Subject: [PATCH 14/15] New schema for experiment creation --- configs/pvred.yaml | 65 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 configs/pvred.yaml diff --git a/configs/pvred.yaml b/configs/pvred.yaml new file mode 100644 index 0000000..cf70d67 --- /dev/null +++ b/configs/pvred.yaml @@ -0,0 +1,65 @@ +dataset: + name: NTURGBDDataset + args: + data_directory: /home/kaseris/Documents/mount/data_ntu_rgbd + label_file: /home/kaseris/Documents/dev/skelcast/data/labels.txt + missing_files_dir: /home/kaseris/Documents/dev/skelcast/data/missing + max_context_window: 10 + max_number_of_bodies: 1 + max_duration: 300 + n_joints: 25 + cache_file: /home/kaseris/Documents/mount/dataset_cache.pkl + +transforms: + - name: MinMaxScaleTransform + args: + feature_scale: [0.0, 1.0] + - name: SomeOtherTransform + args: + some_arg: some_value + +loss: + name: MSELoss + args: + reduction: mean + +collate_fn: + name: NTURGBDCollateFnWithRandomSampledContextWindow + args: + block_size: 25 + +logger: + name: TensorboardLogger + args: + save_dir: runs + +optimizer: + name: AdamW + args: + lr: 0.0001 + weight_decay: 0.0001 + +model: + name: PositionalVelocityRecurrentEncoderDecoder + args: + input_dim: 75 + enc_hidden_dim: 64 + dec_hidden_dim: 64 + enc_type: lstm + dec_type: lstm + include_velocity: false + pos_enc: add + batch_first: true + std_thresh: 0.0001 + use_std_mask: false + use_padded_len_mask: true + observe_until: 20 + +runner: + name: Runner + args: + train_batch_size: 32 + val_batch_size: 32 + block_size: 8 + log_gradient_info: true + device: cuda From a26df9b5d997e20d87251b8cd9f2ff9e0a6d230e Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Tue, 12 Dec 2023 11:27:52 +0200 Subject: [PATCH 15/15] Rearrange operations --- src/skelcast/core/environment.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/skelcast/core/environment.py b/src/skelcast/core/environment.py index c3d74e0..fdc5b7b 100644 --- a/src/skelcast/core/environment.py +++ b/src/skelcast/core/environment.py @@ -69,10 +69,12 @@ def experiment_name(self) -> str: def build_from_file(self, config_path: str) -> None: logging.log(logging.INFO, f'Building environment from {config_path}.') cfgs = read_config(config_path=config_path) - self._dataset = build_object_from_config(cfgs.dataset_config, DATASETS) - # TODO: Add support for random splits # TODO: Handle the case of transforms with a Compose object self._transforms = build_object_from_config(cfgs.transforms_config, TRANSFORMS) + # TODO: Add support for random splits + self._dataset = build_object_from_config(cfgs.dataset_config, DATASETS) + + def run(self) -> None: # Must check if there is a checkpoint directory