diff --git a/src/progpy/predictors/monte_carlo.py b/src/progpy/predictors/monte_carlo.py index 96900c4..405304b 100644 --- a/src/progpy/predictors/monte_carlo.py +++ b/src/progpy/predictors/monte_carlo.py @@ -1,5 +1,6 @@ # Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved. +from collections import abc from copy import deepcopy from typing import Callable from progpy.sim_result import SimResult, LazySimResult @@ -99,16 +100,28 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable=None, event params['n_samples'] = len(state) # number of samples is from provided state if events is None: - # Predict to all events - # change to list because of limits of jsonify if 'events' in params and params['events'] is not None: - # Set at a model level - events = list(params['events']) + # Set at a predictor construction + events = params['events'] else: # Otherwise, all events - events = list(self.model.events) + events = self.model.events + + if not isinstance(events, (abc.Iterable)) or isinstance(events, (dict, bytes)): + # must be string or list-like (list, tuple, set) + # using abc.Iterable adds support for custom data structures + # that implement that abstract base class + raise TypeError(f'`events` must be a single event string or list of events. Was unsupported type {type(events)}.') if len(events) == 0 and 'horizon' not in params: raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon") + if isinstance(events, str): + # A single event + events = [events] + if not all([key in self.model.events for key in events]): + raise ValueError("`events` must be event names") + if not isinstance(events, list): + # Change to list because of the limits of jsonify + events = list(events) if 'events' in params: # Params is provided as a argument in construction diff --git a/src/progpy/predictors/unscented_transform.py b/src/progpy/predictors/unscented_transform.py index 312b95b..a16f81a 100644 --- a/src/progpy/predictors/unscented_transform.py +++ b/src/progpy/predictors/unscented_transform.py @@ -1,5 +1,6 @@ # Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved. +from collections import abc from copy import deepcopy from filterpy import kalman from numpy import diag, array, transpose, isnan @@ -180,11 +181,28 @@ def predict(self, state, future_loading_eqn: Callable = None, events=None, **kwa raise ValueError(f"`event_strategy` {params['event_strategy']} not supported. Currently, only 'all' event strategy is supported") if events is None: - # Predict to all events - # change to list because of limits of jsonify - events = list(self.model.events) + if 'events' in params and params['events'] is not None: + # Set at a predictor construction + events = params['events'] + else: + # Otherwise, all events + events = self.model.events + + if not isinstance(events, (abc.Iterable)) or isinstance(events, (dict, bytes)): + # must be string or list-like (list, tuple, set) + # using abc.Iterable adds support for custom data structures + # that implement that abstract base class + raise TypeError(f'`events` must be a single event string or list of events. Was unsupported type {type(events)}.') if len(events) == 0 and 'horizon' not in params: raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon") + if isinstance(events, str): + # A single event + events = [events] + if not all([key in self.model.events for key in events]): + raise ValueError("`events` must be event names") + if not isinstance(events, list): + # Change to list because of the limits of jsonify + events = list(events) # Optimizations dt = params['dt'] diff --git a/src/progpy/prognostics_model.py b/src/progpy/prognostics_model.py index 2b40d61..708ead5 100644 --- a/src/progpy/prognostics_model.py +++ b/src/progpy/prognostics_model.py @@ -871,13 +871,20 @@ def simulate_to_threshold(self, future_loading_eqn: abc.Callable = None, first_o events = kwargs['threshold_keys'] else: warn('Both `events` and `threshold_keys` were set. `events` will be used.') - + + if events is None: + events = self.events.copy() + if not isinstance(events, abc.Iterable): + # must be string or list-like + raise TypeError(f'`events` must be a single event string or list of events. Was unsupported type {type(events)}.') if isinstance(events, str): - # A single threshold key + # A single event events = [events] - if (events is not None) and not all([key in self.events for key in events]): raise ValueError("`events` must be event names") + if not isinstance(events, list): + # Change to list because of the limits of jsonify + events = list(events) # Configure config = { # Defaults diff --git a/tests/test_predictors.py b/tests/test_predictors.py index c6f04c3..233d935 100644 --- a/tests/test_predictors.py +++ b/tests/test_predictors.py @@ -82,6 +82,39 @@ def test_UTP_ThrownObject(self): self.assertAlmostEqual(results.time_of_event.mean['falling'], 4.15, 0) # self.assertAlmostEqual(mc_results.times[-1], 9, 1) # Saving every second, last time should be around the 1s after impact event (because one of the sigma points fails afterwards) + # Setting event manually + results = pred.predict(samples, dt=0.01, save_freq=1, events=['falling']) + self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5) + self.assertNotIn('impact', results.time_of_event.mean) + + # Setting event in construction + pred = UnscentedTransformPredictor(m, events=['falling']) + results = pred.predict(samples, dt=0.01, save_freq=1) + self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5) + self.assertNotIn('impact', results.time_of_event.mean) + + # Override event set in construction + results = pred.predict(samples, dt=0.01, save_freq=1, events=['falling', 'impact']) + self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.21, 0) + self.assertAlmostEqual(results.time_of_event.mean['falling'], 4.15, 0) + + # String event + results = pred.predict(samples, dt=0.01, save_freq=1, events='impact') + self.assertAlmostEqual(results.time_of_event.mean['impact'], 7.785, 5) + self.assertNotIn('falling', results.time_of_event.mean) + + # Invalid event + with self.assertRaises(ValueError): + results = pred.predict(samples, dt=0.01, save_freq=1, events='invalid') + with self.assertRaises(ValueError): + # Mix valid, invalid + results = pred.predict(samples, dt=0.01, save_freq=1, events=['falling', 'invalid']) + with self.assertRaises(ValueError): + # Empty + results = pred.predict(samples, dt=0.01, save_freq=1, events=[]) + with self.assertRaises(TypeError): + results = pred.predict(samples, dt=0.01, save_freq=1, events=45) + def test_UTP_ThrownObject_One_Event(self): # Test thrown object, similar to test_UKP_ThrownObject, but with only the 'falling' event m = ThrownObject() @@ -168,6 +201,28 @@ def test_MC_ThrownObject(self): results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=['falling', 'impact']) self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5) self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.0, 5) + + # String event + results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events='impact') + self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.0, 5) + self.assertNotIn('falling', results.time_of_event.mean) + + # Invalid event + with self.assertRaises(ValueError): + results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events='invalid') + with self.assertRaises(ValueError): + # Mix valid, invalid + results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=['falling', 'invalid']) + with self.assertRaises(ValueError): + # Empty + results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=[]) + with self.assertRaises(TypeError): + results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, events=45) + + # Empty with horizon + results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1, horizon=3, events=[]) + + # TODO(CT): Events in other predictor def test_MC_ThrownObject_First(self): # Test thrown object, similar to test_UKP_ThrownObject, but with only the first event (i.e., 'falling')