Skip to content

Commit

Permalink
Predictor Event Strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert committed Jul 15, 2024
1 parent 5b87c32 commit 7530cda
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
16 changes: 14 additions & 2 deletions src/progpy/predictors/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class MonteCarlo(Predictor):
__DEFAULT_N_SAMPLES = 100 # Default number of samples to use, if none specified and not UncertainData

default_parameters = {
'n_samples': None
'n_samples': None,
'event_strategy': 'all'
}

def predict(self, state: UncertainData, future_loading_eqn: Callable=None, events=None, **kwargs) -> PredictionResults:
Expand All @@ -50,6 +51,10 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable=None, event
Simulation step size (s), e.g., 0.1
events : list[str], optional
Events to predict (subset of model.events) e.g., ['event1', 'event2']
event_strategy: str, optional
Strategy for stopping evaluation. Default is 'all'. One of:\n
'first': Will stop when first event in `events` list is reached.
'all': Will stop when all events in `events` list have been reached
horizon : float, optional
Prediction horizon (s)
n_samples : int, optional
Expand Down Expand Up @@ -84,6 +89,8 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable=None, event
params.update(kwargs) # update for specific run
params['print'] = False
params['progress'] = False
# Remove event_strategy from params to not confuse simulate_to method call
event_strategy = params.pop('event_strategy')

if not isinstance(state, UnweightedSamples) and params['n_samples'] is None:
# if not unweighted samples, some sample number is required, so set to default.
Expand Down Expand Up @@ -184,7 +191,12 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable=None, event

# An event has occured
time_of_event[event] = times[-1]
events_remaining.remove(event) # No longer an event to predict to
if event_strategy == 'all':
events_remaining.remove(event) # No longer an event to predict to
elif event_strategy in ('first', 'any'):
events_remaining = []
else:
raise ValueError(f"Invalid value for `event_strategy`: {event_strategy}. Should be either 'all' or 'first'")

# Remove last state (event)
params['t0'] = times.pop()
Expand Down
6 changes: 5 additions & 1 deletion src/progpy/predictors/unscented_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class UnscentedTransformPredictor(Predictor):
'kappa': -1,
't0': 0,
'dt': 0.5,
'event_strategy': 'all',
'horizon': 1e99,
'save_pts': [],
'save_freq': 1e99
Expand Down Expand Up @@ -175,6 +176,9 @@ def predict(self, state, future_loading_eqn: Callable = None, events=None, **kwa
params = deepcopy(self.parameters) # copy parameters
params.update(kwargs) # update for specific run

if params['event_strategy'] != 'all':
raise ValueError(f"`event_strategy` {params['event_strategy']} not supported. Currently, only 'all' event strategy is supported")

if events is None:
# Predict to all events
# change to list because of limits of jsonify
Expand Down Expand Up @@ -252,7 +256,7 @@ def update_all():
all_failed = False # This event for this sigma point hasn't been met yet
if all_failed:
# If all events have been reched for every sigma point
break
break

# Prepare Results
pts = array([[e for e in ToE[key]] for key in ToE.keys()])
Expand Down
35 changes: 34 additions & 1 deletion tests/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ def test_pred_template(self):
m = MockProgModel()
pred = TemplatePredictor(m)

def test_UTP_Broken(self):
m = ThrownObject()
pred = UnscentedTransformPredictor(m)
samples = MultivariateNormalDist(['x', 'v'], [1.83, 40], [[0.1, 0.01], [0.01, 0.1]])

with self.assertRaises(ValueError):
# Invalid event strategy - first (not supported)
pred.predict(samples, dt=0.2, num_samples=3, save_freq=1, event_strategy='first')

def test_UTP_ThrownObject(self):
m = ThrownObject()
pred = UnscentedTransformPredictor(m)
Expand Down Expand Up @@ -122,7 +131,15 @@ def future_loading(t, x=None):
s = results.time_of_event.sample(100).key('EOD')
samples.eol_metrics(s) # Kept for backwards compatibility

def test_MC(self):
def test_MC_Broken(self):
m = ThrownObject()
mc = MonteCarlo(m)

with self.assertRaises(ValueError):
# Invalid event strategy
mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, event_strategy='fdksl')

def test_MC_ThrownObject(self):
m = ThrownObject()
mc = MonteCarlo(m)

Expand All @@ -131,6 +148,22 @@ def test_MC(self):
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.0, 5)
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5)

# event_strategy='all' should act the same
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, event_strategy='all')
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.0, 5)
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5)

def test_MC_ThrownObject_First(self):
# Test thrown object, similar to test_UKP_ThrownObject, but with only the first event (i.e., 'falling')

m = ThrownObject()
mc = MonteCarlo(m)
mc_results = mc.predict(m.initialize(), dt=0.2, event_strategy='first', num_samples=3, save_freq=1)

self.assertAlmostEqual(mc_results.time_of_event.mean['falling'], 3.8, 10)
self.assertTrue('impact' not in mc_results.time_of_event.mean)
self.assertAlmostEqual(mc_results.times[-1], 3, 1) # Saving every second, last time should be around the nearest 1s before falling event

def test_prediction_mvnormaldist(self):
times = list(range(10))
covar = [[0.1, 0.01], [0.01, 0.1]]
Expand Down

0 comments on commit 7530cda

Please sign in to comment.