Skip to content

Commit

Permalink
Merge pull request #165 from nasa/feature/calc_error_extension
Browse files Browse the repository at this point in the history
Calc Error Extension
  • Loading branch information
teubert authored Nov 4, 2024
2 parents 9767293 + 04d762b commit a9edf31
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/progpy/utils/calc_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,10 @@ def MSE(m, times: List[float], inputs: List[dict], outputs: List[dict], **kwargs
stability_tol represents the fraction of the provided argument `times` that are required to be met in simulation,
before the model goes unstable in order to produce a valid estimate of error.
If the model goes unstable before stability_tol is met, NaN is returned.
If the model goes unstable before stability_tol is met and short_sim_penalty is None, then exception is raised
Else if the model goes unstable before stability_tol is met and short_sim_penalty is not None- the penalty is added to the score
Else, model goes unstable after stability_tol is met, the error calculated from data up to the instability is returned.
short_sim_penalty (float, optional): penalty added for simulation becoming unstable before stability_tol, added for each % below tol. Default is 100
Returns:
float: Total error
Expand Down Expand Up @@ -180,8 +182,15 @@ def MSE(m, times: List[float], inputs: List[dict], outputs: List[dict], **kwargs
# This is true for any window-based model
if any(np.isnan(z_obs.matrix)):
if t <= cutoffThreshold:
raise ValueError(f"Model unstable- NAN reached in simulation (t={t}) before cutoff threshold. "
f"Cutoff threshold is {cutoffThreshold}, or roughly {stability_tol * 100}% of the data")
short_sim_penalty = kwargs.get('short_sim_penalty', 100)
if short_sim_penalty is None:
raise ValueError(f"Model unstable- NAN reached in simulation (t={t}) before cutoff threshold. "
f"Cutoff threshold is {cutoffThreshold}, or roughly {stability_tol * 100}% of the data")

warn(f"Model unstable- NAN reached in simulation (t={t}) before cutoff threshold. "
f"Cutoff threshold is {cutoffThreshold}, or roughly {stability_tol * 100}% of the data. Penalty added to score.")
# Return value with Penalty added
return err_total/counter + (100-(t/cutoffThreshold)*100)*short_sim_penalty
else:
warn("Model unstable- NaN reached in simulation (t={})".format(t))
break
Expand Down

0 comments on commit a9edf31

Please sign in to comment.