Skip to content

Commit

Permalink
Added a metric for computing RMSSE/RMSSE(Predicting the Mean)
Browse files Browse the repository at this point in the history
  • Loading branch information
djpasseyjr committed Nov 8, 2024
1 parent 8c6bbdc commit da7036b
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 2 deletions.
46 changes: 45 additions & 1 deletion interfere/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,53 @@ def __call__(self,

vpt = idxs.min()
return vpt



class RootMeanSquaredScaledError(CounterfactualForecastingMetric):

def __init__(self):
super().__init__("RMSSE")


@copy_doc(CounterfactualForecastingMetric.__call__)
def __call__(self,
X: np.ndarray,
X_do: np.ndarray,
X_do_pred:np.ndarray,
intervention_idxs: Iterable[int],
**kwargs
):
X_resp, X_do_resp, pred_X_do_resp = self.drop_intervention_cols(
intervention_idxs, *[X, X_do, X_do_pred])
return rmsse(
X_do_resp, pred_X_do_resp)


class RootMeanSquaredScaledErrorOverAvgMethod(CounterfactualForecastingMetric):
def __init__(self):
"""Computes RMSSE(actual, predicted) / RMSSE(actual, mean(training))."""
super().__init__("RMSSE/RMSSE(AVG)")

@copy_doc(CounterfactualForecastingMetric.__call__)
def __call__(self,
X: np.ndarray,
X_do: np.ndarray,
X_do_pred:np.ndarray,
intervention_idxs: Iterable[int],
**kwargs
):
rmsse_cntr_metric = RootMeanSquaredScaledError()

err = rmsse_cntr_metric(
X, X_do, X_do_pred, intervention_idxs, **kwargs
)
X_means = np.vstack([np.mean(X, axis=0) for i in range(X_do.shape[0])])
avg_err = rmsse_cntr_metric(
X, X_do, X_means, intervention_idxs, **kwargs
)
return err / avg_err


def _error(actual: np.ndarray, predicted: np.ndarray):
""" Simple error """
return actual - predicted
Expand Down
17 changes: 16 additions & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
RootMeanStandardizedSquaredError,
TTestDirectionalChangeAccuracy,
ValidPredictionTime,
RootMeanSquaredScaledErrorOverAvgMethod
)
import numpy as np

Expand Down Expand Up @@ -98,4 +99,18 @@ def test_vpt_first_idx():
assert vpt(None, X_true, X_pred, [0]) == 0



def test_rmsse_over_avg():
X_train = np.random.rand(20, 4)
X_true = np.random.rand(10, 4)
X_false = np.zeros((10, 4))
X_pred_good = X_true + 0.01 * np.random.randn(10, 4)
intervention_idxs = np.array([0])

rmsse_over_avg = RootMeanSquaredScaledErrorOverAvgMethod()
x_false_err = rmsse_over_avg(X_train, X_true, X_false, intervention_idxs)
x_true_err = rmsse_over_avg(X_train, X_true, X_pred_good, intervention_idxs)
assert x_false_err > x_true_err

X_mean_pred = np.vstack([
np.mean(X_train, axis=0) for i in range(X_true.shape[0])])
assert 1.0 == rmsse_over_avg(X_train, X_true, X_mean_pred, intervention_idxs)

0 comments on commit da7036b

Please sign in to comment.