-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #32 from kaseris/dev
Dev
- Loading branch information
Showing
13 changed files
with
305 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
dataset: | ||
name: 'NTURGBDDataset' | ||
args: | ||
missing_files_dir: 'data/missing' | ||
label_file: 'data/labels.txt' | ||
max_context_window: 10 | ||
max_number_of_bodies: 1 | ||
transforms: | ||
name: 'MinMaxScaleTransform' | ||
args: | ||
feature_scale: [0.0, 1.0] | ||
max_duration: 300 | ||
n_joints: 25 | ||
|
||
# Set the train data percentage | ||
train_data_percentage: 0.8 | ||
|
||
model: | ||
name: 'SimpleLSTMRegressor' | ||
args: | ||
hidden_size: 1024 | ||
num_layers: 2 | ||
linear_out: 1024 | ||
reduction: 'mean' | ||
batch_first: true | ||
n_joints: 25 | ||
n_dims: 3 | ||
|
||
runner: | ||
args: | ||
val_batch_size: 32 | ||
train_batch_size: 32 | ||
block_size: 8 | ||
device: 'cuda' | ||
logger: | ||
name: 'TensorboardLogger' | ||
args: | ||
save_dir: 'runs' | ||
checkpoint_dir: '/home/kaseris/Documents/checkpoints_forecasting' | ||
n_epochs: 10 | ||
lr: 0.00001 | ||
log_gradient_info: true |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import os | ||
import logging | ||
|
||
import randomname | ||
import yaml | ||
|
||
import torch | ||
import torch.optim as optim | ||
from torch.utils.data import random_split | ||
|
||
from skelcast.models import MODELS | ||
from skelcast.data import DATASETS | ||
from skelcast.data import TRANSFORMS | ||
from skelcast.logger import LOGGERS | ||
|
||
from skelcast.experiments.runner import Runner | ||
|
||
torch.manual_seed(133742069) | ||
|
||
class Environment: | ||
""" | ||
The Environment class is designed to set up and manage the environment for training machine learning models. | ||
It includes methods for building models, datasets, loggers, and runners based on specified configurations. | ||
Attributes: | ||
_experiment_name (str): A randomly generated name for the experiment. | ||
checkpoint_dir (str): Directory path for storing model checkpoints. | ||
data_dir (str): Directory path where the dataset is located. | ||
config (dict, optional): Configuration settings for the model, dataset, logger, and runner. | ||
_model (object, optional): The instantiated machine learning model. | ||
_dataset (object, optional): The complete dataset. | ||
_train_dataset (object, optional): The training subset of the dataset. | ||
_val_dataset (object, optional): The validation subset of the dataset. | ||
_runner (object, optional): The training runner. | ||
_logger (object, optional): The logger for recording experiment results. | ||
Methods: | ||
experiment_name: Property that returns the experiment name. | ||
build_from_file(config_path): Parses the configuration file and builds the dataset, model, logger, and runner. | ||
run(): Starts the training process, either from scratch or by resuming from the latest checkpoint. | ||
Usage: | ||
1. Initialize the Environment with data and checkpoint directories. | ||
2. Call `build_from_file` with the path to a configuration file. | ||
3. Use `run` to start the training process. | ||
Note: | ||
This class is highly dependent on external modules and configurations. Ensure that all required modules | ||
and configurations are properly set up before using this class. | ||
""" | ||
def __init__(self, data_dir: str = '/home/kaseris/Documents/data_ntu_rbgd', | ||
checkpoint_dir = '/home/kaseris/Documents/checkpoints_forecasting') -> None: | ||
self._experiment_name = randomname.get_name() | ||
self.checkpoint_dir = checkpoint_dir | ||
self.data_dir = data_dir | ||
self.config = None | ||
self._model = None | ||
self._dataset = None | ||
self._train_dataset = None | ||
self._val_dataset = None | ||
self._runner = None | ||
self._logger = None | ||
|
||
@property | ||
def experiment_name(self) -> str: | ||
return self._experiment_name | ||
|
||
def build_from_file(self, config_path: str) -> None: | ||
config = self._parse_file(config_path) | ||
self.config = config | ||
logging.log(logging.INFO, f'Building environment from {config_path}.') | ||
self._build_dataset() | ||
self._build_model() | ||
self._build_logger() | ||
self._build_runner() | ||
|
||
def _build_model(self) -> None: | ||
logging.log(logging.INFO, 'Building model.') | ||
model_config = self.config['model'] | ||
_name = model_config.get('name') | ||
_args = model_config.get('args') | ||
self._model = MODELS.get_module(_name)(**_args) | ||
logging.log(logging.INFO, f'Model creation complete.') | ||
|
||
def _build_dataset(self) -> None: | ||
logging.log(logging.INFO, 'Building dataset.') | ||
dataset_config = self.config['dataset'] | ||
_name = dataset_config.get('name') | ||
_args = dataset_config.get('args') | ||
_transforms_cfg = dataset_config.get('args').get('transforms') | ||
_transforms = TRANSFORMS.get_module(_transforms_cfg.get('name'))(**_transforms_cfg.get('args')) | ||
_args['transforms'] = _transforms | ||
self._dataset = DATASETS.get_module(_name)(self.data_dir, **_args) | ||
# Split the dataset | ||
_train_len = int(self.config['train_data_percentage'] * len(self._dataset)) | ||
self._train_dataset, self._val_dataset = random_split(self._dataset, [_train_len, len(self._dataset) - _train_len]) | ||
logging.log(logging.INFO, f'Train set size: {len(self._train_dataset)}') | ||
|
||
def _build_logger(self) -> None: | ||
logging.log(logging.INFO, 'Building logger.') | ||
logger_config = self.config['runner']['args'].get('logger') | ||
logdir = os.path.join(logger_config['args']['save_dir'], self.experiment_name) | ||
self._logger = LOGGERS.get_module(logger_config['name'])(logdir) | ||
logging.log(logging.INFO, f'Logging to {logdir}.') | ||
|
||
def _build_runner(self) -> None: | ||
logging.log(logging.INFO, 'Building runner.') | ||
runner_config = self.config['runner'] | ||
_args = runner_config.get('args') | ||
_args['logger'] = self._logger | ||
_args['optimizer'] = optim.AdamW(self._model.parameters(), lr=_args.get('lr')) | ||
_args['train_set'] = self._train_dataset | ||
_args['val_set'] = self._val_dataset | ||
_args['model'] = self._model | ||
_args['train_set'] = self._train_dataset | ||
_args['val_set'] = self._val_dataset | ||
_args['checkpoint_dir'] = os.path.join(self.checkpoint_dir, self._experiment_name) | ||
self._runner = Runner(**_args) | ||
self._runner.setup() | ||
logging.log(logging.INFO, 'Runner setup complete.') | ||
|
||
def _create_checkpoint_dir(self) -> None: | ||
if os.path.exists(os.path.join(self.checkpoint_dir, self._experiment_name)): | ||
raise ValueError(f'Checkpoint directory {self.checkpoint_dir} already exists.') | ||
else: | ||
logging.log(logging.INFO, f'Creating checkpoint directory: {self.checkpoint_dir}.') | ||
os.mkdir(os.path.join(self.checkpoint_dir, self._experiment_name)) | ||
|
||
def _parse_file(self, fname: str) -> None: | ||
with open(fname, 'r') as f: | ||
config = yaml.safe_load(f) | ||
return config | ||
|
||
def run(self) -> None: | ||
# Must check if there is a checkpoint directory | ||
# If there is, load the latest checkpoint and continue training | ||
# Else, create a new checkpoint directory and start training | ||
# If there's not a checkpoint directory, use the self._runner.fit() method | ||
# Otherwise, use the self._runner.resume(path_to_checkpoint) method | ||
if not os.path.exists(os.path.join(self.checkpoint_dir, self._experiment_name)): | ||
self._create_checkpoint_dir() | ||
return self._runner.fit() | ||
else: | ||
return self._runner.resume(os.path.join(self.checkpoint_dir, self._experiment_name)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
class Registry: | ||
def __init__(self): | ||
self._module_dict = dict() | ||
|
||
def register_module(self, cls=None, module_name=None): | ||
""" | ||
A decorator to register a module. | ||
Args: | ||
- cls (class, optional): The class to be registered. | ||
- module_name (str, optional): The name under which the class will be registered. | ||
Defaults to the class name if not provided. | ||
""" | ||
|
||
def _register(cls): | ||
nonlocal module_name | ||
if module_name is None: | ||
module_name = cls.__name__ | ||
if module_name in self._module_dict: | ||
raise KeyError(f"{module_name} is already registered in {self.__class__.__name__}") | ||
self._module_dict[module_name] = cls | ||
return cls | ||
|
||
if cls is not None: | ||
return _register(cls) | ||
else: | ||
return _register | ||
|
||
def get_module(self, module_name): | ||
""" | ||
Retrieves a class by its registered name. | ||
Args: | ||
- module_name (str): The name of the module to retrieve. | ||
""" | ||
if module_name not in self._module_dict: | ||
raise KeyError(f"{module_name} is not registered in {self.__class__.__name__}") | ||
return self._module_dict[module_name] | ||
|
||
def __str__(self): | ||
return str(self._module_dict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from skelcast.core.registry import Registry | ||
|
||
DATASETS = Registry() | ||
COLLATE_FUNCS = Registry() | ||
TRANSFORMS = Registry() | ||
|
||
from .dataset import NTURGBDCollateFn, NTURGBDDataset | ||
from .transforms import MinMaxScaleTransform |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from skelcast.core.registry import Registry | ||
|
||
LOGGERS = Registry() | ||
|
||
from .tensorboard_logger import TensorboardLogger |
Oops, something went wrong.