-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #42 from kaseris/dev
Dev
- Loading branch information
Showing
4 changed files
with
175 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
dataset: | ||
name: 'NTURGBDDataset' | ||
args: | ||
missing_files_dir: 'data/missing' | ||
label_file: 'data/labels.txt' | ||
max_context_window: 10 | ||
max_number_of_bodies: 1 | ||
transforms: | ||
name: 'MinMaxScaleTransform' | ||
args: | ||
feature_scale: [-1.0, 1.0] | ||
max_duration: 300 | ||
n_joints: 25 | ||
cache: /home/kaseris/Documents/mount/dataset_cache.pkl | ||
|
||
# Set the train data percentage | ||
train_data_percentage: 0.8 | ||
|
||
model: | ||
name: 'SimpleLSTMRegressor' | ||
args: | ||
hidden_size: 1024 | ||
num_layers: 3 | ||
linear_out: 1024 | ||
reduction: 'mean' | ||
batch_first: true | ||
n_joints: 25 | ||
n_dims: 3 | ||
|
||
runner: | ||
args: | ||
val_batch_size: 32 | ||
train_batch_size: 32 | ||
block_size: 8 | ||
device: 'cuda' | ||
logger: | ||
name: 'TensorboardLogger' | ||
args: | ||
save_dir: 'runs' | ||
checkpoint_dir: '/home/kaseris/Documents/checkpoints_forecasting' | ||
n_epochs: 30 | ||
lr: 0.0001 | ||
log_gradient_info: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from skelcast.core.registry import Registry | ||
|
||
LOSSES = Registry() | ||
|
||
from .logloss import LogLoss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
from skelcast.losses import LOSSES | ||
|
||
|
||
@LOSSES.register_module() | ||
class LogLoss(nn.Module): | ||
""" | ||
A custom loss function module in PyTorch that computes a modified logarithmic loss. | ||
Parameters: | ||
- alpha (float): A scaling factor for the difference between predictions and ground truth. | ||
- beta (float): An exponent factor for scaling the difference. | ||
- use_abs (bool): If True, uses the absolute value of the difference. Default is True. | ||
- reduction (str): Specifies the reduction to apply to the output: 'mean', 'sum'. Default is 'mean'. | ||
- batch_first (bool): If True, expects the batch size to be the first dimension of input tensors. Default is True. | ||
The forward method computes the loss given predictions and ground truth values. | ||
Usage: | ||
>>> loss_fn = LogLoss(alpha=1.0, beta=2.0, use_abs=True, reduction='mean', batch_first=True) | ||
>>> y_pred = torch.tensor([[0.2, 0.4], [0.6, 0.8]], dtype=torch.float32) | ||
>>> y_true = torch.tensor([[0.1, 0.3], [0.5, 0.7]], dtype=torch.float32) | ||
>>> loss = loss_fn(y_pred, y_true) | ||
""" | ||
def __init__(self, alpha, beta, use_abs=True, reduction='mean', batch_first=True) -> None: | ||
super().__init__() | ||
self.alpha = alpha | ||
self.beta = beta | ||
self.use_abs = use_abs | ||
if reduction not in ['mean', 'sum']: | ||
self.reduction = 'mean' | ||
else: | ||
self.reduction = reduction | ||
self.batch_first = batch_first | ||
|
||
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | ||
assert y_pred.shape[-1] == y_true.shape[-1], f'Predictions and ground truth labels must have the same dimensionality.' | ||
if self.batch_first: | ||
assert y_pred.shape[1] == y_true.shape[1], f'Predictions and ground truth labels must have the same context length.' | ||
else: | ||
assert y_pred.shape[0] == y_true.shape[0], f'Predictions and ground truth labels must have the same context length.' | ||
diff_ = y_pred - y_true | ||
if self.use_abs: | ||
diff_ = torch.abs(diff_) | ||
|
||
result = (torch.log(1.0 + self.alpha * diff_ ** self.beta))**2 | ||
if self.reduction == 'mean': | ||
result = result.mean() | ||
elif self.reduction == 'sum': | ||
result = result.sum() | ||
return result | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import logging | ||
from argparse import ArgumentParser | ||
|
||
import matplotlib.pyplot as plt | ||
import torch | ||
|
||
from skelcast.core.environment import Environment | ||
|
||
args = ArgumentParser() | ||
args.add_argument('--config', type=str, default='configs/lstm_regressor_1024x1024.yaml') | ||
args.add_argument('--data_dir', type=str, default='data') | ||
args.add_argument('--checkpoint_dir', type=str, default='/home/kaseris/Documents/mount/checkpoints_forecasting') | ||
|
||
args = args.parse_args() | ||
|
||
|
||
if __name__ == '__main__': | ||
log_format = '[%(asctime)s] %(levelname)s: %(message)s' | ||
date_format = '%Y-%m-%d %H:%M:%S' | ||
logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format) | ||
|
||
CONTEXT_SIZE = 8 | ||
|
||
# Maybe we won't need the Environment interface | ||
env = Environment(data_dir=args.data_dir, checkpoint_dir=args.checkpoint_dir) | ||
env.build_from_file(args.config) | ||
dataset = env.dataset | ||
# hard code the checkpoint path | ||
checkpoint_path = '/home/kaseris/Documents/mount/checkpoints_forecast/acidic-plan/checkpoint_epoch_9_2023-12-01_192123.pt' | ||
checkpoint = torch.load(checkpoint_path) | ||
model_state_dict = checkpoint['model_state_dict'] | ||
model = env.model.to('cpu') | ||
model.load_state_dict(model_state_dict) | ||
model.eval() | ||
sample, label = dataset[4] | ||
|
||
seq_len, n_bodies, n_joints, n_dims = sample.shape | ||
sample = sample.view(seq_len, n_bodies * n_joints * n_dims) | ||
|
||
sample = sample.unsqueeze(0) | ||
context = sample[:, :CONTEXT_SIZE, :] | ||
|
||
print(f'context shape: {context.shape}') | ||
# Make a forecast | ||
# The forecast method should be implemented in the model | ||
# For now let's implement it here | ||
# The forecast routine takes a historical record as input and returns a prediction | ||
# The prediction is a tensor of shape (1, CONTEXT_SIZE, n_bodies * n_joints * n_dims) | ||
# The CONTEXT_SIZE-th element of the prediction is the forecast fore the next time step | ||
# Then the prediction is appended to the historical record and the oldest element is removed | ||
# The process is repeated until the desired number of predictions is made | ||
def forecast(model, sample, n_preds): | ||
preds = [] | ||
for i in range(n_preds): | ||
pred = model(sample.to(torch.float32)) | ||
# print(f'pred shape: {pred.shape}') | ||
preds.append(pred[:, -1, :].detach().unsqueeze(1)) | ||
sample = torch.cat([sample[:, 1:, :], pred[:, -1, :].detach().unsqueeze(1)], dim=1) | ||
return torch.cat(preds, dim=1) | ||
preds = forecast(model, context, 8) | ||
print(f'preds shape: {preds.shape}') | ||
preds = preds.view(CONTEXT_SIZE, n_bodies, n_joints, n_dims).detach() | ||
context = context.view(CONTEXT_SIZE, n_bodies, n_joints, n_dims) | ||
sample = sample.view(seq_len, n_bodies, n_joints, n_dims) | ||
print(f'context shape: {context.shape}') | ||
plt.figure(figsize=(12, 9)) | ||
plt.plot(preds[:, 0, 0, 0]) | ||
plt.plot(sample[CONTEXT_SIZE:CONTEXT_SIZE+8, 0, 0, 0]) | ||
print(f'forecasted: {preds[:, 0, 0, 0]}') | ||
print(f'actual: {sample[CONTEXT_SIZE:CONTEXT_SIZE+8, 0, 0, 0]}') | ||
# print(f'abs difference between forecast and actual: {torch.abs(preds[:, 0, 0, 0] - sample[CONTEXT_SIZE:CONTEXT_SIZE+8, 0, 0, 0])}') | ||
# plt.legend(['Actual', 'Forecast']) | ||
plt.show() |