Skip to content

Commit

Permalink
Clean handling of events
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert committed Jul 25, 2024
1 parent 0c655e7 commit 697f8ce
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 11 deletions.
23 changes: 18 additions & 5 deletions src/progpy/predictors/monte_carlo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.

from collections import abc
from copy import deepcopy
from typing import Callable
from progpy.sim_result import SimResult, LazySimResult
Expand Down Expand Up @@ -99,16 +100,28 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable=None, event
params['n_samples'] = len(state) # number of samples is from provided state

if events is None:
# Predict to all events
# change to list because of limits of jsonify
if 'events' in params and params['events'] is not None:
# Set at a model level
events = list(params['events'])
# Set at a predictor construction
events = params['events']
else:
# Otherwise, all events
events = list(self.model.events)
events = self.model.events

if not isinstance(events, (abc.Iterable)) or isinstance(events, (dict, bytes)):
# must be string or list-like (list, tuple, set)
# using abc.Iterable adds support for custom data structures
# that implement that abstract base class
raise TypeError(f'`events` must be a single event string or list of events. Was unsupported type {type(events)}.')
if len(events) == 0 and 'horizon' not in params:
raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon")
if isinstance(events, str):
# A single event
events = [events]
if not all([key in self.model.events for key in events]):
raise ValueError("`events` must be event names")
if not isinstance(events, list):
# Change to list because of the limits of jsonify
events = list(events)

if 'events' in params:
# Params is provided as a argument in construction
Expand Down
24 changes: 21 additions & 3 deletions src/progpy/predictors/unscented_transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.

from collections import abc
from copy import deepcopy
from filterpy import kalman
from numpy import diag, array, transpose, isnan
Expand Down Expand Up @@ -180,11 +181,28 @@ def predict(self, state, future_loading_eqn: Callable = None, events=None, **kwa
raise ValueError(f"`event_strategy` {params['event_strategy']} not supported. Currently, only 'all' event strategy is supported")

if events is None:
# Predict to all events
# change to list because of limits of jsonify
events = list(self.model.events)
if 'events' in params and params['events'] is not None:
# Set at a predictor construction
events = params['events']
else:
# Otherwise, all events
events = self.model.events

if not isinstance(events, (abc.Iterable)) or isinstance(events, (dict, bytes)):
# must be string or list-like (list, tuple, set)
# using abc.Iterable adds support for custom data structures
# that implement that abstract base class
raise TypeError(f'`events` must be a single event string or list of events. Was unsupported type {type(events)}.')
if len(events) == 0 and 'horizon' not in params:
raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon")
if isinstance(events, str):
# A single event
events = [events]
if not all([key in self.model.events for key in events]):
raise ValueError("`events` must be event names")
if not isinstance(events, list):
# Change to list because of the limits of jsonify
events = list(events)

# Optimizations
dt = params['dt']
Expand Down
13 changes: 10 additions & 3 deletions src/progpy/prognostics_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,13 +871,20 @@ def simulate_to_threshold(self, future_loading_eqn: abc.Callable = None, first_o
events = kwargs['threshold_keys']
else:
warn('Both `events` and `threshold_keys` were set. `events` will be used.')


if events is None:
events = self.events.copy()
if not isinstance(events, abc.Iterable):
# must be string or list-like
raise TypeError(f'`events` must be a single event string or list of events. Was unsupported type {type(events)}.')
if isinstance(events, str):
# A single threshold key
# A single event
events = [events]

if (events is not None) and not all([key in self.events for key in events]):
raise ValueError("`events` must be event names")
if not isinstance(events, list):
# Change to list because of the limits of jsonify
events = list(events)

# Configure
config = { # Defaults
Expand Down
55 changes: 55 additions & 0 deletions tests/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,39 @@ def test_UTP_ThrownObject(self):
self.assertAlmostEqual(results.time_of_event.mean['falling'], 4.15, 0)
# self.assertAlmostEqual(mc_results.times[-1], 9, 1) # Saving every second, last time should be around the 1s after impact event (because one of the sigma points fails afterwards)

# Setting event manually
results = pred.predict(samples, dt=0.01, save_freq=1, events=['falling'])
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5)
self.assertNotIn('impact', results.time_of_event.mean)

# Setting event in construction
pred = UnscentedTransformPredictor(m, events=['falling'])
results = pred.predict(samples, dt=0.01, save_freq=1)
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5)
self.assertNotIn('impact', results.time_of_event.mean)

# Override event set in construction
results = pred.predict(samples, dt=0.01, save_freq=1, events=['falling', 'impact'])
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.21, 0)
self.assertAlmostEqual(results.time_of_event.mean['falling'], 4.15, 0)

# String event
results = pred.predict(samples, dt=0.01, save_freq=1, events='impact')
self.assertAlmostEqual(results.time_of_event.mean['impact'], 7.785, 5)
self.assertNotIn('falling', results.time_of_event.mean)

# Invalid event
with self.assertRaises(ValueError):
results = pred.predict(samples, dt=0.01, save_freq=1, events='invalid')
with self.assertRaises(ValueError):
# Mix valid, invalid
results = pred.predict(samples, dt=0.01, save_freq=1, events=['falling', 'invalid'])
with self.assertRaises(ValueError):
# Empty
results = pred.predict(samples, dt=0.01, save_freq=1, events=[])
with self.assertRaises(TypeError):
results = pred.predict(samples, dt=0.01, save_freq=1, events=45)

def test_UTP_ThrownObject_One_Event(self):
# Test thrown object, similar to test_UKP_ThrownObject, but with only the 'falling' event
m = ThrownObject()
Expand Down Expand Up @@ -168,6 +201,28 @@ def test_MC_ThrownObject(self):
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=['falling', 'impact'])
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5)
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.0, 5)

# String event
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events='impact')
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.0, 5)
self.assertNotIn('falling', results.time_of_event.mean)

# Invalid event
with self.assertRaises(ValueError):
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events='invalid')
with self.assertRaises(ValueError):
# Mix valid, invalid
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=['falling', 'invalid'])
with self.assertRaises(ValueError):
# Empty
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=[])
with self.assertRaises(TypeError):
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=45)

# Empty with horizon
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, horizon=3, events=[])

# TODO(CT): Events in other predictor

def test_MC_ThrownObject_First(self):
# Test thrown object, similar to test_UKP_ThrownObject, but with only the first event (i.e., 'falling')
Expand Down

0 comments on commit 697f8ce

Please sign in to comment.