diff --git a/interfere/metrics.py b/interfere/metrics.py index e27941d..9cd2516 100644 --- a/interfere/metrics.py +++ b/interfere/metrics.py @@ -246,9 +246,53 @@ def __call__(self, vpt = idxs.min() return vpt - +class RootMeanSquaredScaledError(CounterfactualForecastingMetric): + + def __init__(self): + super().__init__("RMSSE") + + + @copy_doc(CounterfactualForecastingMetric.__call__) + def __call__(self, + X: np.ndarray, + X_do: np.ndarray, + X_do_pred:np.ndarray, + intervention_idxs: Iterable[int], + **kwargs + ): + X_resp, X_do_resp, pred_X_do_resp = self.drop_intervention_cols( + intervention_idxs, *[X, X_do, X_do_pred]) + return rmsse( + X_do_resp, pred_X_do_resp) + + +class RootMeanSquaredScaledErrorOverAvgMethod(CounterfactualForecastingMetric): + def __init__(self): + """Computes RMSSE(actual, predicted) / RMSSE(actual, mean(training)).""" + super().__init__("RMSSE/RMSSE(AVG)") + + @copy_doc(CounterfactualForecastingMetric.__call__) + def __call__(self, + X: np.ndarray, + X_do: np.ndarray, + X_do_pred:np.ndarray, + intervention_idxs: Iterable[int], + **kwargs + ): + rmsse_cntr_metric = RootMeanSquaredScaledError() + + err = rmsse_cntr_metric( + X, X_do, X_do_pred, intervention_idxs, **kwargs + ) + X_means = np.vstack([np.mean(X, axis=0) for i in range(X_do.shape[0])]) + avg_err = rmsse_cntr_metric( + X, X_do, X_means, intervention_idxs, **kwargs + ) + return err / avg_err + + def _error(actual: np.ndarray, predicted: np.ndarray): """ Simple error """ return actual - predicted diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 9a4757b..910ddfe 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -5,6 +5,7 @@ RootMeanStandardizedSquaredError, TTestDirectionalChangeAccuracy, ValidPredictionTime, + RootMeanSquaredScaledErrorOverAvgMethod ) import numpy as np @@ -98,4 +99,18 @@ def test_vpt_first_idx(): assert vpt(None, X_true, X_pred, [0]) == 0 - \ No newline at end of file +def test_rmsse_over_avg(): + X_train = np.random.rand(20, 4) + X_true = np.random.rand(10, 4) + X_false = np.zeros((10, 4)) + X_pred_good = X_true + 0.01 * np.random.randn(10, 4) + intervention_idxs = np.array([0]) + + rmsse_over_avg = RootMeanSquaredScaledErrorOverAvgMethod() + x_false_err = rmsse_over_avg(X_train, X_true, X_false, intervention_idxs) + x_true_err = rmsse_over_avg(X_train, X_true, X_pred_good, intervention_idxs) + assert x_false_err > x_true_err + + X_mean_pred = np.vstack([ + np.mean(X_train, axis=0) for i in range(X_true.shape[0])]) + assert 1.0 == rmsse_over_avg(X_train, X_true, X_mean_pred, intervention_idxs) \ No newline at end of file