Skip to content

Commit

Permalink
fix events conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert committed Jul 12, 2024
1 parent 8c60e9a commit 1ace59a
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 14 deletions.
12 changes: 8 additions & 4 deletions src/progpy/predictors/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class MonteCarlo(Predictor):
'n_samples': None
}

def predict(self, state: UncertainData, future_loading_eqn: Callable=None, **kwargs) -> PredictionResults:
def predict(self, state: UncertainData, future_loading_eqn: Callable=None, events=None, **kwargs) -> PredictionResults:
"""
Perform a single prediction
Expand Down Expand Up @@ -91,7 +91,11 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable=None, **kwa
elif isinstance(state, UnweightedSamples) and params['n_samples'] is None:
params['n_samples'] = len(state) # number of samples is from provided state

if len(params['events']) == 0 and 'horizon' not in params:
if events is None:
# Predict to all events
# change to list because of limits of jsonify
events = list(self.model.events)
if len(events) == 0 and 'horizon' not in params:
raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon")

# Sample from state if n_samples specified or state is not UnweightedSamples (Case 2)
Expand Down Expand Up @@ -129,15 +133,15 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable=None, **kwa
if 'save_freq' in params and not isinstance(params['save_freq'], tuple):
params['save_freq'] = (params['t0'], params['save_freq'])

if len(params['events']) == 0: # Predict to time
if len(events) == 0: # Predict to time
(times, inputs, states, outputs, event_states) = simulate_to_threshold(
future_loading_eqn,
first_output,
events=[],
**params
)
else:
events_remaining = params['events'].copy()
events_remaining = events.copy()

times = []
inputs = SimResult(_copy=False)
Expand Down
3 changes: 0 additions & 3 deletions src/progpy/predictors/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ def __init__(self, model, **kwargs):
self.model = model

self.parameters = deepcopy(self.default_parameters)
# Events to predict to - must be a list
# This is because of limitations with jsonify for sets
self.parameters['events'] = list(self.model.events.copy())
self.parameters.update(kwargs)

@abstractmethod
Expand Down
15 changes: 9 additions & 6 deletions src/progpy/predictors/unscented_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def state_transition(x, dt):
self.filter = kalman.UnscentedKalmanFilter(num_states, num_measurements, self.parameters['dt'], measure, state_transition, self.sigma_points)
self.filter.Q = self.parameters['Q']

def predict(self, state, future_loading_eqn: Callable = None, **kwargs) -> PredictionResults:
def predict(self, state, future_loading_eqn: Callable = None, events=None, **kwargs) -> PredictionResults:
"""
Perform a single prediction
Expand Down Expand Up @@ -175,11 +175,14 @@ def predict(self, state, future_loading_eqn: Callable = None, **kwargs) -> Predi
params = deepcopy(self.parameters) # copy parameters
params.update(kwargs) # update for specific run

if len(params['events']) == 0 and 'horizon' not in params:
if events is None:
# Predict to all events
# change to list because of limits of jsonify
events = list(self.model.events)
if len(events) == 0 and 'horizon' not in params:
raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon")

# Optimizations
events_to_predict = params['events']
dt = params['dt']
model = self.model
filt = self.filter
Expand All @@ -196,8 +199,8 @@ def predict(self, state, future_loading_eqn: Callable = None, **kwargs) -> Predi
# Setup first states
t = params['t0']
save_pt_index = 0
ToE = {key: [float('nan') for i in range(n_points)] for key in events_to_predict} # Keep track of final ToE values
last_state = {key: [None for i in range(n_points)] for key in events_to_predict} # Keep track of final state values
ToE = {key: [float('nan') for i in range(n_points)] for key in events} # Keep track of final ToE values
last_state = {key: [None for i in range(n_points)] for key in events} # Keep track of final state values

times = []
inputs = []
Expand Down Expand Up @@ -239,7 +242,7 @@ def update_all():
t_met = threshold_met(x)

# Check Thresholds
for key in events_to_predict:
for key in events:
if t_met[key]:
if isnan(ToE[key][i]):
# First time event has been reached
Expand Down
2 changes: 1 addition & 1 deletion tests/test_horizon.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def future_loading(t, x=None):
# With this horizon, all samples will reach 'falling', but only some will reach 'impact'
PREDICTION_HORIZON = 2.127
STEP_SIZE = 0.001
mc_results = mc.predict(initial_state, future_loading, dt=STEP_SIZE, horizon = PREDICTION_HORIZON)
mc_results = mc.predict(initial_state, future_loading, dt=STEP_SIZE, horizon=PREDICTION_HORIZON)

# 'falling' happens before the horizon is met
falling_res = [mc_results.time_of_event[iter]['falling'] for iter in range(mc.parameters['n_samples']) if mc_results.time_of_event[iter]['falling'] is not None]
Expand Down

0 comments on commit 1ace59a

Please sign in to comment.