Skip to content

Commit

Permalink
Merge pull request #78 from kaseris/metrics
Browse files Browse the repository at this point in the history
Metrics
  • Loading branch information
kaseris committed Feb 1, 2024
2 parents dd74009 + 0bae394 commit 1aa42b9
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/skelcast/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from skelcast.core.registry import Registry

METRICS = Registry()

42 changes: 42 additions & 0 deletions src/skelcast/metrics/mae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch

from typing import Tuple

from skelcast.metrics import METRICS
from skelcast.metrics.metric import Metric


@METRICS.register_module()
class MeanPerJointPositionError(Metric): # Inherits from our abstract Metric class
def __init__(self, keep_time_dim: bool = True):
self.keep_time_dim = keep_time_dim
self.reset() # Initialize/reset the state

def reset(self):
# Reset the state of the metric
self.y = None
self.y_pred = None

def update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output # Unpack the output tuple, assuming output is already in the desired format

# Initialize or update the stored tensors
if self.y is None:
self.y = y
self.y_pred = y_pred
else:
self.y = torch.cat([self.y, y], dim=0)
self.y_pred = torch.cat([self.y_pred, y_pred], dim=0)

def result(self):
# Compute the Mean Per Joint Position Error
if self.y is None:
raise ValueError('MeanPerJointPositionError must have at least one example before it can be computed.')

error = (self.y - self.y_pred).norm(dim=-1) # Calculate the L2 norm over the last dimension (joints)
mean_error = error.mean(dim=[0, 2]) # Take the mean over the batch and time dimensions

if not self.keep_time_dim:
mean_error = mean_error.mean() # Further reduce mean over all joints if time dimension is not kept

return mean_error
30 changes: 30 additions & 0 deletions src/skelcast/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod

class Metric(ABC):
@abstractmethod
def update(self, predictions, targets):
"""
Update the metric's state using the predictions and the targets.
Args:
- predictions: The predicted values.
- targets: The ground truth values.
"""
pass

@abstractmethod
def result(self):
"""
Calculates and returns the final metric result based on the state.
Returns:
- The calculated metric.
"""
pass

@abstractmethod
def reset(self):
"""
Resets the metric state.
"""
pass
42 changes: 42 additions & 0 deletions src/skelcast/metrics/mpjpe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch

from typing import Tuple

from skelcast.metrics import METRICS
from skelcast.metrics.metric import Metric


@METRICS.register_module()
class MeanPerJointPositionError(Metric):
def __init__(self, keep_time_dim: bool = True):
self.keep_time_dim = keep_time_dim
self.reset()

def reset(self):
# Reset the state of the metric
self.y = torch.tensor([])
self.y_pred = torch.tensor([])

def update(self, output: Tuple[torch.Tensor, torch.Tensor]):
y_pred, y = output

# Concatenate new predictions and targets to the stored tensors
if self.y.numel() == 0:
self.y = y
self.y_pred = y_pred
else:
self.y = torch.cat([self.y, y], dim=0)
self.y_pred = torch.cat([self.y_pred, y_pred], dim=0)

def result(self):
# Compute the Mean Per Joint Position Error
if self.y.numel() == 0:
raise ValueError('MeanPerJointPositionError must have at least one example before it can be computed.')

error = (self.y - self.y_pred).norm(dim=-1)
mean_error = error.mean(dim=[0, 2])

if not self.keep_time_dim:
mean_error = mean_error.mean()

return mean_error

0 comments on commit 1aa42b9

Please sign in to comment.