From 8f1c570ad913d7319c70a4597e97837f47edf7e3 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Wed, 31 Jan 2024 16:12:58 +0200 Subject: [PATCH 1/4] Metrics. --- src/skelcast/metrics/__init__.py | 4 ++++ src/skelcast/metrics/metric.py | 12 ++++++++++++ src/skelcast/metrics/mjmpe.py | 14 ++++++++++++++ 3 files changed, 30 insertions(+) create mode 100644 src/skelcast/metrics/__init__.py create mode 100644 src/skelcast/metrics/metric.py create mode 100644 src/skelcast/metrics/mjmpe.py diff --git a/src/skelcast/metrics/__init__.py b/src/skelcast/metrics/__init__.py new file mode 100644 index 0000000..8b572f4 --- /dev/null +++ b/src/skelcast/metrics/__init__.py @@ -0,0 +1,4 @@ +from skelcast.core.registry import Registry + +METRICS = Registry() + diff --git a/src/skelcast/metrics/metric.py b/src/skelcast/metrics/metric.py new file mode 100644 index 0000000..3f89eb7 --- /dev/null +++ b/src/skelcast/metrics/metric.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod + +# Create abstract class Metric +class Metric(ABC): + + @abstractmethod + def update(self): + pass + + @abstractmethod + def compute(self): + pass \ No newline at end of file diff --git a/src/skelcast/metrics/mjmpe.py b/src/skelcast/metrics/mjmpe.py new file mode 100644 index 0000000..93cefb0 --- /dev/null +++ b/src/skelcast/metrics/mjmpe.py @@ -0,0 +1,14 @@ +from skelcast.metrics import METRICS +from skelcast.metrics.metric import Metric + +class MeanPerJointPositionError(Metric): + """Mean Per Joint Position Error (MPJPE) metric. + """ + def __init__(self, name='MPJPE', **kwargs): + super().__init__(name=name, **kwargs) + + def update(self): + pass + + def compute(self): + pass \ No newline at end of file From 9c63a238edad9fc74c4eb82069e9316bd0143f59 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 1 Feb 2024 11:19:59 +0200 Subject: [PATCH 2/4] Abstract metric class --- src/skelcast/metrics/metric.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/skelcast/metrics/metric.py b/src/skelcast/metrics/metric.py index 3f89eb7..2fee27f 100644 --- a/src/skelcast/metrics/metric.py +++ b/src/skelcast/metrics/metric.py @@ -1,12 +1,30 @@ from abc import ABC, abstractmethod -# Create abstract class Metric class Metric(ABC): + @abstractmethod + def update(self, predictions, targets): + """ + Update the metric's state using the predictions and the targets. + + Args: + - predictions: The predicted values. + - targets: The ground truth values. + """ + pass @abstractmethod - def update(self): + def result(self): + """ + Calculates and returns the final metric result based on the state. + + Returns: + - The calculated metric. + """ pass @abstractmethod - def compute(self): - pass \ No newline at end of file + def reset(self): + """ + Resets the metric state. + """ + pass From b18d87a2ac2eb4b9a83e4955400b5b00d62efb6f Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 1 Feb 2024 11:21:13 +0200 Subject: [PATCH 3/4] Mean per joint positional error implementation --- src/skelcast/metrics/mjmpe.py | 14 ------------ src/skelcast/metrics/mpjpe.py | 42 +++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 14 deletions(-) delete mode 100644 src/skelcast/metrics/mjmpe.py create mode 100644 src/skelcast/metrics/mpjpe.py diff --git a/src/skelcast/metrics/mjmpe.py b/src/skelcast/metrics/mjmpe.py deleted file mode 100644 index 93cefb0..0000000 --- a/src/skelcast/metrics/mjmpe.py +++ /dev/null @@ -1,14 +0,0 @@ -from skelcast.metrics import METRICS -from skelcast.metrics.metric import Metric - -class MeanPerJointPositionError(Metric): - """Mean Per Joint Position Error (MPJPE) metric. - """ - def __init__(self, name='MPJPE', **kwargs): - super().__init__(name=name, **kwargs) - - def update(self): - pass - - def compute(self): - pass \ No newline at end of file diff --git a/src/skelcast/metrics/mpjpe.py b/src/skelcast/metrics/mpjpe.py new file mode 100644 index 0000000..085383c --- /dev/null +++ b/src/skelcast/metrics/mpjpe.py @@ -0,0 +1,42 @@ +import torch + +from typing import Tuple + +from skelcast.metrics import METRICS +from skelcast.metrics.metric import Metric + + +@METRICS.register_module() +class MeanPerJointPositionError(Metric): + def __init__(self, keep_time_dim: bool = True): + self.keep_time_dim = keep_time_dim + self.reset() + + def reset(self): + # Reset the state of the metric + self.y = torch.tensor([]) + self.y_pred = torch.tensor([]) + + def update(self, output: Tuple[torch.Tensor, torch.Tensor]): + y_pred, y = output + + # Concatenate new predictions and targets to the stored tensors + if self.y.numel() == 0: + self.y = y + self.y_pred = y_pred + else: + self.y = torch.cat([self.y, y], dim=0) + self.y_pred = torch.cat([self.y_pred, y_pred], dim=0) + + def result(self): + # Compute the Mean Per Joint Position Error + if self.y.numel() == 0: + raise ValueError('MeanPerJointPositionError must have at least one example before it can be computed.') + + error = (self.y - self.y_pred).norm(dim=-1) + mean_error = error.mean(dim=[0, 2]) + + if not self.keep_time_dim: + mean_error = mean_error.mean() + + return mean_error \ No newline at end of file From 0bae39425efd50ff6a86a2ca7d086b5393fd1a20 Mon Sep 17 00:00:00 2001 From: Michail Kaseris Date: Thu, 1 Feb 2024 11:59:23 +0200 Subject: [PATCH 4/4] Added mean angle error metric. --- src/skelcast/metrics/mae.py | 42 +++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 src/skelcast/metrics/mae.py diff --git a/src/skelcast/metrics/mae.py b/src/skelcast/metrics/mae.py new file mode 100644 index 0000000..0e3101c --- /dev/null +++ b/src/skelcast/metrics/mae.py @@ -0,0 +1,42 @@ +import torch + +from typing import Tuple + +from skelcast.metrics import METRICS +from skelcast.metrics.metric import Metric + + +@METRICS.register_module() +class MeanPerJointPositionError(Metric): # Inherits from our abstract Metric class + def __init__(self, keep_time_dim: bool = True): + self.keep_time_dim = keep_time_dim + self.reset() # Initialize/reset the state + + def reset(self): + # Reset the state of the metric + self.y = None + self.y_pred = None + + def update(self, output: Tuple[torch.Tensor, torch.Tensor]): + y_pred, y = output # Unpack the output tuple, assuming output is already in the desired format + + # Initialize or update the stored tensors + if self.y is None: + self.y = y + self.y_pred = y_pred + else: + self.y = torch.cat([self.y, y], dim=0) + self.y_pred = torch.cat([self.y_pred, y_pred], dim=0) + + def result(self): + # Compute the Mean Per Joint Position Error + if self.y is None: + raise ValueError('MeanPerJointPositionError must have at least one example before it can be computed.') + + error = (self.y - self.y_pred).norm(dim=-1) # Calculate the L2 norm over the last dimension (joints) + mean_error = error.mean(dim=[0, 2]) # Take the mean over the batch and time dimensions + + if not self.keep_time_dim: + mean_error = mean_error.mean() # Further reduce mean over all joints if time dimension is not kept + + return mean_error \ No newline at end of file