From 7be36c4c8988cf8ac1623a39c18f5a953a1ed114 Mon Sep 17 00:00:00 2001
From: Michail Kaseris
Date: Fri, 9 Feb 2024 15:32:26 +0200
Subject: [PATCH] Add the sample class and a custom collate fn for the Human3.6m

---
 src/skelcast/core/environment.py       | 88 ++++++++++++++++++--------
 src/skelcast/data/human36m/human36m.py | 63 ++++++++++++++++--
 2 files changed, 118 insertions(+), 33 deletions(-)

diff --git a/src/skelcast/core/environment.py b/src/skelcast/core/environment.py
index 1f08fe6..bfadc33 100644
--- a/src/skelcast/core/environment.py
+++ b/src/skelcast/core/environment.py
@@ -21,11 +21,12 @@
 
 torch.manual_seed(133742069)
 
+
 class Environment:
     """
     The Environment class is designed to set up and manage the environment for training machine learning models.
     It includes methods for building models, datasets, loggers, and runners based on specified configurations.
-    
+
     Attributes:
         _experiment_name (str): A randomly generated name for the experiment.
         checkpoint_dir (str): Directory path for storing model checkpoints.
@@ -52,13 +53,17 @@ class Environment:
     This class is highly dependent on external modules and configurations. Ensure that all required
     modules and configurations are properly set up before using this class.
     """
-    def __init__(self, data_dir: str = '/home/kaseris/Documents/data_ntu_rbgd',
-                 checkpoint_dir = '/home/kaseris/Documents/mount/checkpoints_forecasting',
-                 train_set_size = 0.8) -> None:
+
+    def __init__(
+        self,
+        data_dir: str = "/home/kaseris/Documents/data_ntu_rbgd",
+        checkpoint_dir="/home/kaseris/Documents/mount/checkpoints_forecasting",
+        train_set_size=0.8,
+    ) -> None:
         self._experiment_name = randomname.get_name()
         self.checkpoint_dir = os.path.join(checkpoint_dir, self._experiment_name)
         os.mkdir(self.checkpoint_dir)
-        logging.info(f'Created checkpoint directory at {self.checkpoint_dir}')
+        logging.info(f"Created checkpoint directory at {self.checkpoint_dir}")
         self.data_dir = data_dir
         self.train_set_size = train_set_size
         self.config = None
@@ -72,53 +77,80 @@ def __init__(self, data_dir: str = '/home/kaseris/Documents/data_ntu_rbgd',
         self._optimizer = None
         self._collate_fn = None
 
-
     @property
     def experiment_name(self) -> str:
         return self._experiment_name
 
     def build_from_file(self, config_path: str) -> None:
-        logging.log(logging.INFO, f'Building environment from {config_path}.')
+        logging.log(logging.INFO, f"Building environment from {config_path}.")
         cfgs = read_config(config_path=config_path)
         # Build tranforms first, because they are used in the dataset
-        self._transforms = build_object_from_config(cfgs.transforms_config, TRANSFORMS)
+        if hasattr(cfgs, "transforms_config"):
+            self._transforms = build_object_from_config(
+                cfgs.transforms_config, TRANSFORMS
+            )
+        else:
+            self._transforms = None
         # TODO: Add support for random splits. Maybe as external parameter?
-        self._dataset = build_object_from_config(cfgs.dataset_config, DATASETS, transforms=self._transforms)
-        logging.info(f'Loaded dataset from {self.data_dir}.')
+        self._dataset = build_object_from_config(
+            cfgs.dataset_config, DATASETS, transforms=self._transforms
+        )
+        logging.info(f"Loaded dataset from {self.data_dir}.")
         # Build the loss first, because it is used in the model
-        loss_registry = LOSSES if cfgs.criterion_config.get('name') not in PYTORCH_LOSSES else PYTORCH_LOSSES
+        loss_registry = (
+            LOSSES
+            if cfgs.criterion_config.get("name") not in PYTORCH_LOSSES
+            else PYTORCH_LOSSES
+        )
         self._loss = build_object_from_config(cfgs.criterion_config, loss_registry)
         logging.info(f'Loaded loss function {cfgs.criterion_config.get("name")}.')
-        self._model = build_object_from_config(cfgs.model_config, MODELS, loss_fn=self._loss)
+        self._model = build_object_from_config(
+            cfgs.model_config, MODELS, loss_fn=self._loss
+        )
         logging.info(f'Loaded model {cfgs.model_config.get("name")}.')
         # Build the optimizer
-        self._optimizer = build_object_from_config(cfgs.optimizer_config, PYTORCH_OPTIMIZERS, params=self._model.parameters())
+        self._optimizer = build_object_from_config(
+            cfgs.optimizer_config, PYTORCH_OPTIMIZERS, params=self._model.parameters()
+        )
         logging.info(f'Loaded optimizer {cfgs.optimizer_config.get("name")}.')
         # Build the logger
-        cfgs.logger_config.get('args').update({'log_dir': os.path.join(cfgs.logger_config.get('args').get('log_dir'), self._experiment_name)})
+        cfgs.logger_config.get("args").update(
+            {
+                "log_dir": os.path.join(
+                    cfgs.logger_config.get("args").get("log_dir"), self._experiment_name
+                )
+            }
+        )
         self._logger = build_object_from_config(cfgs.logger_config, LOGGERS)
-        logging.info(f'Created runs directory at {cfgs.logger_config.get("args").get("log_dir")}')
+        logging.info(
+            f'Created runs directory at {cfgs.logger_config.get("args").get("log_dir")}'
+        )
         # Build the collate_fn
-        self._collate_fn = build_object_from_config(cfgs.collate_fn_config, COLLATE_FUNCS)
+        self._collate_fn = build_object_from_config(
+            cfgs.collate_fn_config, COLLATE_FUNCS
+        )
        logging.info(f'Loaded collate function {cfgs.collate_fn_config.get("name")}.')
         # Split the dataset into training and validation sets
         train_size = int(self.train_set_size * len(self._dataset))
         val_size = len(self._dataset) - train_size
-        self._train_dataset, self._val_dataset = random_split(self._dataset, [train_size, val_size])
+        self._train_dataset, self._val_dataset = random_split(
+            self._dataset, [train_size, val_size]
+        )
         # Build the runner
-        self._runner = Runner(model=self._model,
-                              optimizer=self._optimizer,
-                              logger=self._logger,
-                              collate_fn=self._collate_fn,
-                              train_set=self._train_dataset,
-                              val_set=self._val_dataset,
-                              checkpoint_dir=self.checkpoint_dir,
-                              **cfgs.runner_config.get('args'))
-        logging.info(f'Finished building environment from {config_path}.')
+        self._runner = Runner(
+            model=self._model,
+            optimizer=self._optimizer,
+            logger=self._logger,
+            collate_fn=self._collate_fn,
+            train_set=self._train_dataset,
+            val_set=self._val_dataset,
+            checkpoint_dir=self.checkpoint_dir,
+            **cfgs.runner_config.get("args"),
+        )
+        logging.info(f"Finished building environment from {config_path}.")
         self._runner.setup()
-        logging.info(f'Set up runner.')
+        logging.info(f"Set up runner.")
 
-
     def run(self) -> None:
         # Must check if there is a checkpoint directory
         # If there is, load the latest checkpoint and continue training
diff --git a/src/skelcast/data/human36m/human36m.py b/src/skelcast/data/human36m/human36m.py
index 1613d2b..b17b6f2 100644
--- a/src/skelcast/data/human36m/human36m.py
+++ b/src/skelcast/data/human36m/human36m.py
@@ -1,9 +1,11 @@
 import copy
 
+from dataclasses import dataclass
+
 import numpy as np
 import torch
 
-from skelcast.data import DATASETS
+from skelcast.data import DATASETS, COLLATE_FUNCS
 from skelcast.data.human36m.camera import normalize_screen_coordinates
 from skelcast.data.human36m.skeleton import Skeleton
 
@@ -153,13 +155,64 @@ def skeleton(self):
         return self._skeleton
 
 
+@dataclass
+class Human36MSample:
+    x: torch.Tensor
+    y: torch.Tensor
+    mask: torch.Tensor = None
+
+
+@COLLATE_FUNCS.register_module()
+class Human36MCollateFnWithRandomSampledContextWindow:
+    """
+    Custom collate function for batched variable-length sequences.
+    During the __call__ function, we create a `block_size`-long context window for each sequence in the batch.
+    If is_packed is True, we pack the padded sequences, otherwise we return the padded sequences as is.
+
+    Args:
+    - block_size (int): Sequence's context length.
+    - is_packed (bool): Whether to pack the padded sequence or not.
+
+    Returns:
+
+    The batched padded sequences ready to be fed to a transformer or an lstm model.
+    """
+
+    def __init__(self, block_size: int) -> None:
+        self.block_size = block_size
+
+    def __call__(self, batch) -> Human36MSample:
+        # Pick a random index for each element of the batch and create a context window of size `block_size`
+        # around that index
+        # If the batch element's sequence length is less than `block_size`, then we sample the entire sequence
+        # Pick the random index using pytorch
+        seq_lens = [sample.shape[0] for sample in batch]
+        pre_batch = []
+        pre_mask = []
+        for sample in batch:
+            if sample.shape[0] <= self.block_size:
+                # Sample the entire sequence
+                pre_batch.append(torch.from_numpy(sample))
+                pre_mask.append(torch.ones(sample.shape))
+            else:
+                # Sample a random index
+                idx = torch.randint(
+                    low=0, high=sample.shape[0] - self.block_size, size=(1,)
+                ).item()
+                pre_batch.append(torch.from_numpy(sample[idx : idx + self.block_size, ...]))
+        # Pad the sequences to the maximum sequence length in the batch
+        batch_x = torch.nn.utils.rnn.pad_sequence(pre_batch, batch_first=True)
+        # Generate masks
+        return Human36MSample(x=batch_x, y=batch_x, mask=None)
+
+
 @DATASETS.register_module()
 class Human36MDataset(MocapDataset):
     """
     TODO: Possibly add a flatten_data method for easy accessing with a single index.
     """
 
-    def __init__(self, path, seq_len=27):
+    def __init__(self, path, seq_len=27, **kwargs):
         skeleton = Skeleton(
             offsets=[
                 [0.0, 0.0, 0.0],
@@ -235,12 +288,12 @@ def __init__(self, path, seq_len=27):
         super().__init__(path, skeleton, fps=50)
         self.compute_positions()
         self._dataset_flat = []
-        for subject in ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11']:
-            for action in list(self._data[subject].keys()):
-                self._dataset_flat.append(self._data[subject][action]['rotations'])
+        for subject in ["S1", "S5", "S6", "S7", "S8", "S9", "S11"]:
+            for action in list(self._data[subject].keys()):
+                self._dataset_flat.append(self._data[subject][action]["rotations"])
 
     def __getitem__(self, index):
         return self._dataset_flat[index]
 
     def __len__(self):
-        return len(self._dataset_flat)
\ No newline at end of file
+        return len(self._dataset_flat)
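
Usage note (not part of the patch): a minimal sketch of how the new collate function might be wired to a DataLoader. Only the class names and constructor arguments come from the diff above; the dataset path and the batch-size / block-size values are illustrative assumptions.

    from torch.utils.data import DataLoader

    from skelcast.data.human36m.human36m import (
        Human36MCollateFnWithRandomSampledContextWindow,
        Human36MDataset,
    )

    # Hypothetical path to a preprocessed Human3.6M archive; adjust to your setup.
    dataset = Human36MDataset(path="data/h36m.npz", seq_len=27)

    # Each sequence is cropped to a randomly placed window of at most `block_size`
    # frames, then the batch is padded to the longest window it contains.
    collate_fn = Human36MCollateFnWithRandomSampledContextWindow(block_size=27)

    loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

    batch = next(iter(loader))
    # `batch` is a Human36MSample: `x` and `y` hold the same padded tensor and
    # `mask` is still None, since mask generation is left as a TODO in the patch.
    print(batch.x.shape)

This mirrors what Environment.build_from_file does when it resolves `collate_fn_config` through the COLLATE_FUNCS registry, just without the config indirection.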
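
A second hedged sketch of how a training step could consume the resulting Human36MSample; the model and loss here are placeholders, not the skelcast Runner or model API.

    import torch
    import torch.nn.functional as F


    def training_step(model: torch.nn.Module, batch) -> torch.Tensor:
        """Illustrative only: unpack a Human36MSample the way a runner might.

        `model` is a stand-in module; the real forward/loss interface is defined
        by the skelcast models and Runner, not by this sketch.
        """
        x, y = batch.x, batch.y   # identical padded tensors, shape (B, T, ...)
        prediction = model(x)     # hypothetical forward pass
        return F.mse_loss(prediction, y)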