From 91f5dd38619b8666d24f8535578a3ab0245583e0 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Wed, 31 Jan 2024 14:10:05 +0200 Subject: [PATCH 1/2] Euler Angle Loss --- src/skelcast/losses/__init__.py | 3 ++- src/skelcast/losses/euler_angle_loss.py | 27 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 src/skelcast/losses/euler_angle_loss.py diff --git a/src/skelcast/losses/__init__.py b/src/skelcast/losses/__init__.py index 962aeb8..e056190 100644 --- a/src/skelcast/losses/__init__.py +++ b/src/skelcast/losses/__init__.py @@ -2,4 +2,5 @@ LOSSES = Registry() -from .logloss import LogLoss \ No newline at end of file +from .logloss import LogLoss +from .euler_angle_loss import EulerAngleLoss diff --git a/src/skelcast/losses/euler_angle_loss.py b/src/skelcast/losses/euler_angle_loss.py new file mode 100644 index 0000000..104e293 --- /dev/null +++ b/src/skelcast/losses/euler_angle_loss.py @@ -0,0 +1,27 @@ +import numpy as np +import torch +import torch.nn as nn + +from skelcast.data.human36m.quaternion import qeuler + + +class EulerAngleLoss(nn.Module): + def __init__(self, order="xyz", reduction="mean"): + super(EulerAngleLoss, self).__init__() + self._order = order + self._reduction = reduction + + def forward(self, predictions: torch.Tensor, targets: torch.Tensor): + # Check the shape of predictions and targets + assert ( + predictions.shape == targets.shape + ), f"Predictions and targets must have the same shape." + assert ( + predictions.shape[-1] == 3 + ), f"Predictions and targets must have 3 channels in the last dimension." + + predicted_euler = qeuler(predictions, self._order, epsilon=1e-6) + angle_distance = ( + torch.remainder(predicted_euler - targets + np.pi, 2 * np.pi) - np.pi + ) + return torch.mean(torch.abs(angle_distance)) From 2e69aa1cfa743ce3c40dbf876e17c9cdf9f30b1f Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Wed, 31 Jan 2024 14:11:11 +0200 Subject: [PATCH 2/2] Register module for fast instantiation --- src/skelcast/losses/euler_angle_loss.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/skelcast/losses/euler_angle_loss.py b/src/skelcast/losses/euler_angle_loss.py index 104e293..4d78046 100644 --- a/src/skelcast/losses/euler_angle_loss.py +++ b/src/skelcast/losses/euler_angle_loss.py @@ -3,8 +3,10 @@ import torch.nn as nn from skelcast.data.human36m.quaternion import qeuler +from skelcast.losses import LOSSES +@LOSSES.register_module() class EulerAngleLoss(nn.Module): def __init__(self, order="xyz", reduction="mean"): super(EulerAngleLoss, self).__init__()