Skip to content

Commit

Permalink
Add the sample class and a custom collate fn for the Human3.6m
Browse files Browse the repository at this point in the history
  • Loading branch information
kaseris committed Feb 9, 2024
1 parent 85c6e88 commit 7be36c4
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 33 deletions.
88 changes: 60 additions & 28 deletions src/skelcast/core/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@

torch.manual_seed(133742069)


class Environment:
"""
The Environment class is designed to set up and manage the environment for training machine learning models.
It includes methods for building models, datasets, loggers, and runners based on specified configurations.
Attributes:
_experiment_name (str): A randomly generated name for the experiment.
checkpoint_dir (str): Directory path for storing model checkpoints.
Expand All @@ -52,13 +53,17 @@ class Environment:
This class is highly dependent on external modules and configurations. Ensure that all required modules
and configurations are properly set up before using this class.
"""
def __init__(self, data_dir: str = '/home/kaseris/Documents/data_ntu_rbgd',
checkpoint_dir = '/home/kaseris/Documents/mount/checkpoints_forecasting',
train_set_size = 0.8) -> None:

def __init__(
self,
data_dir: str = "/home/kaseris/Documents/data_ntu_rbgd",
checkpoint_dir="/home/kaseris/Documents/mount/checkpoints_forecasting",
train_set_size=0.8,
) -> None:
self._experiment_name = randomname.get_name()
self.checkpoint_dir = os.path.join(checkpoint_dir, self._experiment_name)
os.mkdir(self.checkpoint_dir)
logging.info(f'Created checkpoint directory at {self.checkpoint_dir}')
logging.info(f"Created checkpoint directory at {self.checkpoint_dir}")
self.data_dir = data_dir
self.train_set_size = train_set_size
self.config = None
Expand All @@ -72,53 +77,80 @@ def __init__(self, data_dir: str = '/home/kaseris/Documents/data_ntu_rbgd',
self._optimizer = None
self._collate_fn = None


@property
def experiment_name(self) -> str:
return self._experiment_name

def build_from_file(self, config_path: str) -> None:
logging.log(logging.INFO, f'Building environment from {config_path}.')
logging.log(logging.INFO, f"Building environment from {config_path}.")
cfgs = read_config(config_path=config_path)
# Build tranforms first, because they are used in the dataset
self._transforms = build_object_from_config(cfgs.transforms_config, TRANSFORMS)
if hasattr(cfgs, "transforms_config"):
self._transforms = build_object_from_config(
cfgs.transforms_config, TRANSFORMS
)
else:
self._transforms = None
# TODO: Add support for random splits. Maybe as external parameter?
self._dataset = build_object_from_config(cfgs.dataset_config, DATASETS, transforms=self._transforms)
logging.info(f'Loaded dataset from {self.data_dir}.')
self._dataset = build_object_from_config(
cfgs.dataset_config, DATASETS, transforms=self._transforms
)
logging.info(f"Loaded dataset from {self.data_dir}.")
# Build the loss first, because it is used in the model
loss_registry = LOSSES if cfgs.criterion_config.get('name') not in PYTORCH_LOSSES else PYTORCH_LOSSES
loss_registry = (
LOSSES
if cfgs.criterion_config.get("name") not in PYTORCH_LOSSES
else PYTORCH_LOSSES
)
self._loss = build_object_from_config(cfgs.criterion_config, loss_registry)
logging.info(f'Loaded loss function {cfgs.criterion_config.get("name")}.')
self._model = build_object_from_config(cfgs.model_config, MODELS, loss_fn=self._loss)
self._model = build_object_from_config(
cfgs.model_config, MODELS, loss_fn=self._loss
)
logging.info(f'Loaded model {cfgs.model_config.get("name")}.')
# Build the optimizer
self._optimizer = build_object_from_config(cfgs.optimizer_config, PYTORCH_OPTIMIZERS, params=self._model.parameters())
self._optimizer = build_object_from_config(
cfgs.optimizer_config, PYTORCH_OPTIMIZERS, params=self._model.parameters()
)
logging.info(f'Loaded optimizer {cfgs.optimizer_config.get("name")}.')
# Build the logger
cfgs.logger_config.get('args').update({'log_dir': os.path.join(cfgs.logger_config.get('args').get('log_dir'), self._experiment_name)})
cfgs.logger_config.get("args").update(
{
"log_dir": os.path.join(
cfgs.logger_config.get("args").get("log_dir"), self._experiment_name
)
}
)
self._logger = build_object_from_config(cfgs.logger_config, LOGGERS)
logging.info(f'Created runs directory at {cfgs.logger_config.get("args").get("log_dir")}')
logging.info(
f'Created runs directory at {cfgs.logger_config.get("args").get("log_dir")}'
)
# Build the collate_fn
self._collate_fn = build_object_from_config(cfgs.collate_fn_config, COLLATE_FUNCS)
self._collate_fn = build_object_from_config(
cfgs.collate_fn_config, COLLATE_FUNCS
)
logging.info(f'Loaded collate function {cfgs.collate_fn_config.get("name")}.')
# Split the dataset into training and validation sets
train_size = int(self.train_set_size * len(self._dataset))
val_size = len(self._dataset) - train_size
self._train_dataset, self._val_dataset = random_split(self._dataset, [train_size, val_size])
self._train_dataset, self._val_dataset = random_split(
self._dataset, [train_size, val_size]
)
# Build the runner
self._runner = Runner(model=self._model,
optimizer=self._optimizer,
logger=self._logger,
collate_fn=self._collate_fn,
train_set=self._train_dataset,
val_set=self._val_dataset,
checkpoint_dir=self.checkpoint_dir,
**cfgs.runner_config.get('args'))
logging.info(f'Finished building environment from {config_path}.')
self._runner = Runner(
model=self._model,
optimizer=self._optimizer,
logger=self._logger,
collate_fn=self._collate_fn,
train_set=self._train_dataset,
val_set=self._val_dataset,
checkpoint_dir=self.checkpoint_dir,
**cfgs.runner_config.get("args"),
)
logging.info(f"Finished building environment from {config_path}.")
self._runner.setup()
logging.info(f'Set up runner.')
logging.info(f"Set up runner.")


def run(self) -> None:
# Must check if there is a checkpoint directory
# If there is, load the latest checkpoint and continue training
Expand Down
63 changes: 58 additions & 5 deletions src/skelcast/data/human36m/human36m.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import copy

from dataclasses import dataclass

import numpy as np
import torch

from skelcast.data import DATASETS
from skelcast.data import DATASETS, COLLATE_FUNCS
from skelcast.data.human36m.camera import normalize_screen_coordinates
from skelcast.data.human36m.skeleton import Skeleton

Expand Down Expand Up @@ -153,13 +155,64 @@ def skeleton(self):
return self._skeleton


@dataclass
class Human36MSample:
x: torch.tensor
y: torch.tensor
mask: torch.tensor = None


@COLLATE_FUNCS.register_module()
class Human36MCollateFnWithRandomSampledContextWindow:
"""
Custom collate function for batched variable-length sequences.
During the __call__ function, we creata `block_size`-long context windows, for each sequence in the batch.
If is_packed is True, we pack the padded sequences, otherwise we return the padded sequences as is.
Args:
- block_size (int): Sequence's context length.
- is_packed (bool): Whether to pack the padded sequence or not.
Returns:
The batched padded sequences ready to be fed to a transformer or an lstm model.
"""

def __init__(self, block_size: int) -> None:
self.block_size = block_size

def __call__(self, batch) -> Human36MSample:
# Pick a random index for each element of the batch and create a context window of size `block_size`
# around that index
# If the batch element's sequence length is less than `block_size`, then we sample the entire sequence
# Pick the random index using pytorch
seq_lens = [sample.shape[0] for sample in batch]
pre_batch = []
pre_mask = []
for sample in batch:
if sample.shape[0] <= self.block_size:
# Sample the entire sequence
pre_batch.append(sample)
pre_mask.append(torch.ones_like(sample))
else:
# Sample a random index
idx = torch.randint(
low=0, high=sample.shape[0] - self.block_size, size=(1,)
).item()
pre_batch.append(torch.from_numpy(sample[idx : idx + self.block_size, ...]))
# Pad the sequences to the maximum sequence length in the batch
batch_x = torch.nn.utils.rnn.pad_sequence(pre_batch, batch_first=True)
# Generate masks
return Human36MSample(x=batch_x, y=batch_x, mask=None)


@DATASETS.register_module()
class Human36MDataset(MocapDataset):
"""
TODO: Possibly add a flatten_data method for easy accessing with a single index.
"""

def __init__(self, path, seq_len=27):
def __init__(self, path, seq_len=27, **kwargs):
skeleton = Skeleton(
offsets=[
[0.0, 0.0, 0.0],
Expand Down Expand Up @@ -235,12 +288,12 @@ def __init__(self, path, seq_len=27):
super().__init__(path, skeleton, fps=50)
self.compute_positions()
self._dataset_flat = []
for subject in ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11']:
for subject in ["S1", "S5", "S6", "S7", "S8", "S9", "S11"]:
for action in list(self._data[subject].keys()):
self._dataset_flat.append(self._data[subject][action]['rotations'])
self._dataset_flat.append(self._data[subject][action]["rotations"])

def __getitem__(self, index):
return self._dataset_flat[index]

def __len__(self):
return len(self._dataset_flat)
return len(self._dataset_flat)

0 comments on commit 7be36c4

Please sign in to comment.