From 1f676d32d4641ce171012681eef195948cc04577 Mon Sep 17 00:00:00 2001 From: Christopher Teubert Date: Fri, 8 Nov 2024 12:15:41 -0800 Subject: [PATCH] Add direct access --- src/progpy/predictors/predictor.py | 6 +++ .../state_estimators/state_estimator.py | 6 +++ tests/test_state_estimators.py | 42 ++++++++++++++++++- 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/progpy/predictors/predictor.py b/src/progpy/predictors/predictor.py index 669cdb3..79a535c 100644 --- a/src/progpy/predictors/predictor.py +++ b/src/progpy/predictors/predictor.py @@ -74,3 +74,9 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs) * time_of_event (UncertainData): Distribution of predicted Time of Event (ToE) for each predicted event, represented by some subclass of UncertaintData (e.g., MultivariateNormalDist) """ pass + + def __getitem__(self, arg): + return self.parameters[arg] + + def __setitem__(self, key, value): + self.parameters[key] = value diff --git a/src/progpy/state_estimators/state_estimator.py b/src/progpy/state_estimators/state_estimator.py index 75cac51..806fa69 100644 --- a/src/progpy/state_estimators/state_estimator.py +++ b/src/progpy/state_estimators/state_estimator.py @@ -111,3 +111,9 @@ def x(self) -> UncertainData: ------- state = filt.x """ + + def __getitem__(self, arg): + return self.parameters[arg] + + def __setitem__(self, key, value): + self.parameters[key] = value diff --git a/tests/test_state_estimators.py b/tests/test_state_estimators.py index 99bbf6e..1232027 100644 --- a/tests/test_state_estimators.py +++ b/tests/test_state_estimators.py @@ -115,6 +115,38 @@ def __test_state_est(self, filt, m): # should be close to right self.assertAlmostEqual(x_est[key], x[key], delta=0.4) + def __test_state_est_no_dt(self, filt, m): + x = m.initialize() + filt['dt'] = 0.2 + + self.assertTrue(all(key in filt.x.mean for key in m.states)) + + # run for a while + dt = 0.2 + u = m.InputContainer({}) + last_time = 0 + for i in range(500): + # Get simulated output (would be measured in a real application) + x = m.next_state(x, u, dt) + z = m.output(x) + + # Estimate New State every few steps + if i % 8 == 0: + # This is to test dt setting at the estimator lvl + # Without dt, this would fail + last_time = (i+1)*dt + filt.estimate((i+1)*dt, u, z) + + if last_time != (i+1)*dt: + # Final estimate + filt.estimate((i+1)*dt, u, z) + + # Check results - make sure it converged + x_est = filt.x.mean + for key in m.states: + # should be close to right + self.assertAlmostEqual(x_est[key], x[key], delta=0.4) + def test_UKF(self): m = ThrownObject(process_noise=5e-2, measurement_noise=5e-2) x_guess = {'x': 1.75, 'v': 35} # Guess of initial state, actual is {'x': 1.83, 'v': 40} @@ -122,6 +154,9 @@ def test_UKF(self): filt = UnscentedKalmanFilter(m, x_guess) self.__test_state_est(filt, m) + filt = UnscentedKalmanFilter(m, x_guess) + self.__test_state_est_no_dt(filt, m) + m = ThrownObject(process_noise=5e-2, measurement_noise=5e-2) # Test UnscentedKalmanFilter ScalarData @@ -322,6 +357,9 @@ def test_PF(self): filt = ParticleFilter(m, x_guess, num_particles = 1000, measurement_noise = {'x': 1}) self.__test_state_est(filt, m) + filt = ParticleFilter(m, x_guess, num_particles = 1000, measurement_noise = {'x': 1}) + self.__test_state_est_no_dt(filt, m) + # Test ParticleFilter ScalarData x_scalar = ScalarData({'x': 1.75, 'v': 38.5}) filt_scalar = ParticleFilter(m, x_scalar, num_particles = 20) # Sample count does not affect ScalarData testing @@ -438,9 +476,11 @@ def event_state(self, x): x_guess = {'x': 1.75, 'v': 35} # Guess of initial state, actual is {'x': 1.83, 'v': 40} filt = KalmanFilter(m, x_guess) - self.__test_state_est(filt, m) + filt = KalmanFilter(m, x_guess) + self.__test_state_est_no_dt(filt, m) + m = ThrownObject(process_noise=5e-2, measurement_noise=5e-2) # Test KalmanFilter ScalarData