Skip to content

Commit

Permalink
Merge pull request #42 from kaseris/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
kaseris committed Dec 6, 2023
2 parents ad8de28 + 5e45dd2 commit 58209ea
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 0 deletions.
43 changes: 43 additions & 0 deletions configs/lstm_regressor_1024x1024_lr_reduced_.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
dataset:
name: 'NTURGBDDataset'
args:
missing_files_dir: 'data/missing'
label_file: 'data/labels.txt'
max_context_window: 10
max_number_of_bodies: 1
transforms:
name: 'MinMaxScaleTransform'
args:
feature_scale: [-1.0, 1.0]
max_duration: 300
n_joints: 25
cache: /home/kaseris/Documents/mount/dataset_cache.pkl

# Set the train data percentage
train_data_percentage: 0.8

model:
name: 'SimpleLSTMRegressor'
args:
hidden_size: 1024
num_layers: 3
linear_out: 1024
reduction: 'mean'
batch_first: true
n_joints: 25
n_dims: 3

runner:
args:
val_batch_size: 32
train_batch_size: 32
block_size: 8
device: 'cuda'
logger:
name: 'TensorboardLogger'
args:
save_dir: 'runs'
checkpoint_dir: '/home/kaseris/Documents/checkpoints_forecasting'
n_epochs: 30
lr: 0.0001
log_gradient_info: true
5 changes: 5 additions & 0 deletions src/skelcast/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from skelcast.core.registry import Registry

LOSSES = Registry()

from .logloss import LogLoss
54 changes: 54 additions & 0 deletions src/skelcast/losses/logloss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn

from skelcast.losses import LOSSES


@LOSSES.register_module()
class LogLoss(nn.Module):
"""
A custom loss function module in PyTorch that computes a modified logarithmic loss.
Parameters:
- alpha (float): A scaling factor for the difference between predictions and ground truth.
- beta (float): An exponent factor for scaling the difference.
- use_abs (bool): If True, uses the absolute value of the difference. Default is True.
- reduction (str): Specifies the reduction to apply to the output: 'mean', 'sum'. Default is 'mean'.
- batch_first (bool): If True, expects the batch size to be the first dimension of input tensors. Default is True.
The forward method computes the loss given predictions and ground truth values.
Usage:
>>> loss_fn = LogLoss(alpha=1.0, beta=2.0, use_abs=True, reduction='mean', batch_first=True)
>>> y_pred = torch.tensor([[0.2, 0.4], [0.6, 0.8]], dtype=torch.float32)
>>> y_true = torch.tensor([[0.1, 0.3], [0.5, 0.7]], dtype=torch.float32)
>>> loss = loss_fn(y_pred, y_true)
"""
def __init__(self, alpha, beta, use_abs=True, reduction='mean', batch_first=True) -> None:
super().__init__()
self.alpha = alpha
self.beta = beta
self.use_abs = use_abs
if reduction not in ['mean', 'sum']:
self.reduction = 'mean'
else:
self.reduction = reduction
self.batch_first = batch_first

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
assert y_pred.shape[-1] == y_true.shape[-1], f'Predictions and ground truth labels must have the same dimensionality.'
if self.batch_first:
assert y_pred.shape[1] == y_true.shape[1], f'Predictions and ground truth labels must have the same context length.'
else:
assert y_pred.shape[0] == y_true.shape[0], f'Predictions and ground truth labels must have the same context length.'
diff_ = y_pred - y_true
if self.use_abs:
diff_ = torch.abs(diff_)

result = (torch.log(1.0 + self.alpha * diff_ ** self.beta))**2
if self.reduction == 'mean':
result = result.mean()
elif self.reduction == 'sum':
result = result.sum()
return result

73 changes: 73 additions & 0 deletions tools/infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging
from argparse import ArgumentParser

import matplotlib.pyplot as plt
import torch

from skelcast.core.environment import Environment

args = ArgumentParser()
args.add_argument('--config', type=str, default='configs/lstm_regressor_1024x1024.yaml')
args.add_argument('--data_dir', type=str, default='data')
args.add_argument('--checkpoint_dir', type=str, default='/home/kaseris/Documents/mount/checkpoints_forecasting')

args = args.parse_args()


if __name__ == '__main__':
log_format = '[%(asctime)s] %(levelname)s: %(message)s'
date_format = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format)

CONTEXT_SIZE = 8

# Maybe we won't need the Environment interface
env = Environment(data_dir=args.data_dir, checkpoint_dir=args.checkpoint_dir)
env.build_from_file(args.config)
dataset = env.dataset
# hard code the checkpoint path
checkpoint_path = '/home/kaseris/Documents/mount/checkpoints_forecast/acidic-plan/checkpoint_epoch_9_2023-12-01_192123.pt'
checkpoint = torch.load(checkpoint_path)
model_state_dict = checkpoint['model_state_dict']
model = env.model.to('cpu')
model.load_state_dict(model_state_dict)
model.eval()
sample, label = dataset[4]

seq_len, n_bodies, n_joints, n_dims = sample.shape
sample = sample.view(seq_len, n_bodies * n_joints * n_dims)

sample = sample.unsqueeze(0)
context = sample[:, :CONTEXT_SIZE, :]

print(f'context shape: {context.shape}')
# Make a forecast
# The forecast method should be implemented in the model
# For now let's implement it here
# The forecast routine takes a historical record as input and returns a prediction
# The prediction is a tensor of shape (1, CONTEXT_SIZE, n_bodies * n_joints * n_dims)
# The CONTEXT_SIZE-th element of the prediction is the forecast fore the next time step
# Then the prediction is appended to the historical record and the oldest element is removed
# The process is repeated until the desired number of predictions is made
def forecast(model, sample, n_preds):
preds = []
for i in range(n_preds):
pred = model(sample.to(torch.float32))
# print(f'pred shape: {pred.shape}')
preds.append(pred[:, -1, :].detach().unsqueeze(1))
sample = torch.cat([sample[:, 1:, :], pred[:, -1, :].detach().unsqueeze(1)], dim=1)
return torch.cat(preds, dim=1)
preds = forecast(model, context, 8)
print(f'preds shape: {preds.shape}')
preds = preds.view(CONTEXT_SIZE, n_bodies, n_joints, n_dims).detach()
context = context.view(CONTEXT_SIZE, n_bodies, n_joints, n_dims)
sample = sample.view(seq_len, n_bodies, n_joints, n_dims)
print(f'context shape: {context.shape}')
plt.figure(figsize=(12, 9))
plt.plot(preds[:, 0, 0, 0])
plt.plot(sample[CONTEXT_SIZE:CONTEXT_SIZE+8, 0, 0, 0])
print(f'forecasted: {preds[:, 0, 0, 0]}')
print(f'actual: {sample[CONTEXT_SIZE:CONTEXT_SIZE+8, 0, 0, 0]}')
# print(f'abs difference between forecast and actual: {torch.abs(preds[:, 0, 0, 0] - sample[CONTEXT_SIZE:CONTEXT_SIZE+8, 0, 0, 0])}')
# plt.legend(['Actual', 'Forecast'])
plt.show()

0 comments on commit 58209ea

Please sign in to comment.