Skip to content

Commit

Permalink
Update metric
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Oct 29, 2024
1 parent f58f443 commit 780cd3a
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 0 deletions.
78 changes: 78 additions & 0 deletions sdmetrics/timeseries/inter_row.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""InterRowMSAS module."""

import numpy as np
import pandas as pd

from sdmetrics.goal import Goal
from sdmetrics.single_column.statistical.kscomplement import KSComplement


class InterRowMSAS:
"""Inter-Row Multi-Sequence Aggregate Similarity (MSAS) metric.
Attributes:
name (str):
Name to use when reports about this metric are printed.
goal (sdmetrics.goal.Goal):
The goal of this metric.
min_value (Union[float, tuple[float]]):
Minimum value or values that this metric can take.
max_value (Union[float, tuple[float]]):
Maximum value or values that this metric can take.
"""

name = 'Inter-Row Multi-Sequence Aggregate Similarity'
goal = Goal.MAXIMIZE
min_value = 0.0
max_value = 1.0

@staticmethod
def compute(real_data, synthetic_data, n_rows_diff=1, apply_log=False):
"""Compute this metric.
This metric compares the inter-row differences of sequences in the real data
vs. the synthetic data.
It works as follows:
- Calculate the difference between row r and row r+x for each row in the real data
- Take the average over each sequence to form a distribution D_r
- Do the same for the synthetic data to form a new distribution D_s
- Apply the KSComplement metric to compare the similarities of (D_r, D_s)
- Return this score
Args:
real_data (tuple[pd.Series, pd.Series]):
A tuple of 2 pandas.Series objects. The first represents the sequence key
of the real data and the second represents a continuous column of data.
synthetic_data (tuple[pd.Series, pd.Series]):
A tuple of 2 pandas.Series objects. The first represents the sequence key
of the synthetic data and the second represents a continuous column of data.
n_rows_diff (int):
An integer representing the number of rows to consider when taking the difference.
apply_log (bool):
Whether to apply a natural log before taking the difference.
Returns:
float:
The similarity score between the real and synthetic data distributions.
"""
real_keys, real_values = real_data
synthetic_keys, synthetic_values = synthetic_data

if apply_log:
real_values = np.log(real_values)
synthetic_values = np.log(synthetic_values)

def calculate_differences(keys, values):
differences = []
for key in keys.unique():
group_values = values[keys == key].to_numpy()
if len(group_values) > n_rows_diff:
diff = group_values[n_rows_diff:] - group_values[:-n_rows_diff]
differences.append(np.mean(diff))
return pd.Series(differences)

real_diff = calculate_differences(real_keys, real_values)
synthetic_diff = calculate_differences(synthetic_keys, synthetic_values)

return KSComplement.compute(real_diff, synthetic_diff)
74 changes: 74 additions & 0 deletions tests/unit/timeseries/test_inter_row.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
import pandas as pd

from sdmetrics.timeseries.inter_row import InterRowMSAS


class TestInterRowMSAS:
def test_compute_identical_sequences(self):
"""Test it returns 1 when real and synthetic data are identical."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5, 6])
synthetic_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, 3, 4, 5, 6])

# Run
score = InterRowMSAS.compute(
real_data=(real_keys, real_values), synthetic_data=(synthetic_keys, synthetic_values)
)

# Assert
assert score == 1

def test_compute_different_sequences(self):
"""Test it for distinct distributions."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5, 6])
synthetic_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
synthetic_values = pd.Series([1, 3, 5, 2, 4, 6])

# Run
score = InterRowMSAS.compute(
real_data=(real_keys, real_values), synthetic_data=(synthetic_keys, synthetic_values)
)

# Assert
assert score == 0

def test_compute_with_log(self):
"""Test it with logarithmic transformation."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 2, 4, 8, 16, 32])
synthetic_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, 4, 8, 16, 32])

# Run
score = InterRowMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
apply_log=True,
)

# Assert
assert score == 1

def test_compute_different_n_rows_diff(self):
"""Test it with different n_rows_diff."""
# Setup
real_keys = pd.Series(['id1'] * 10 + ['id2'] * 10)
real_values = pd.Series(list(range(10)) + list(range(10)))
synthetic_keys = pd.Series(['id1'] * 10 + ['id2'] * 10)
synthetic_values = pd.Series(list(range(10)) + list(range(10)))

# Run
score = InterRowMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
n_rows_diff=3,
)

# Assert
assert score == 1

0 comments on commit 780cd3a

Please sign in to comment.