Skip to content

Commit

Permalink
Add direct access
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert committed Nov 8, 2024
1 parent 8c3d540 commit 1f676d3
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/progpy/predictors/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,9 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs)
* time_of_event (UncertainData): Distribution of predicted Time of Event (ToE) for each predicted event, represented by some subclass of UncertaintData (e.g., MultivariateNormalDist)
"""
pass

def __getitem__(self, arg):
return self.parameters[arg]

def __setitem__(self, key, value):
self.parameters[key] = value
6 changes: 6 additions & 0 deletions src/progpy/state_estimators/state_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,9 @@ def x(self) -> UncertainData:
-------
state = filt.x
"""

def __getitem__(self, arg):
return self.parameters[arg]

def __setitem__(self, key, value):
self.parameters[key] = value
42 changes: 41 additions & 1 deletion tests/test_state_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,48 @@ def __test_state_est(self, filt, m):
# should be close to right
self.assertAlmostEqual(x_est[key], x[key], delta=0.4)

def __test_state_est_no_dt(self, filt, m):
x = m.initialize()
filt['dt'] = 0.2

self.assertTrue(all(key in filt.x.mean for key in m.states))

# run for a while
dt = 0.2
u = m.InputContainer({})
last_time = 0
for i in range(500):
# Get simulated output (would be measured in a real application)
x = m.next_state(x, u, dt)
z = m.output(x)

# Estimate New State every few steps
if i % 8 == 0:
# This is to test dt setting at the estimator lvl
# Without dt, this would fail
last_time = (i+1)*dt
filt.estimate((i+1)*dt, u, z)

if last_time != (i+1)*dt:
# Final estimate
filt.estimate((i+1)*dt, u, z)

# Check results - make sure it converged
x_est = filt.x.mean
for key in m.states:
# should be close to right
self.assertAlmostEqual(x_est[key], x[key], delta=0.4)

def test_UKF(self):
m = ThrownObject(process_noise=5e-2, measurement_noise=5e-2)
x_guess = {'x': 1.75, 'v': 35} # Guess of initial state, actual is {'x': 1.83, 'v': 40}

filt = UnscentedKalmanFilter(m, x_guess)
self.__test_state_est(filt, m)

filt = UnscentedKalmanFilter(m, x_guess)
self.__test_state_est_no_dt(filt, m)

m = ThrownObject(process_noise=5e-2, measurement_noise=5e-2)

# Test UnscentedKalmanFilter ScalarData
Expand Down Expand Up @@ -322,6 +357,9 @@ def test_PF(self):
filt = ParticleFilter(m, x_guess, num_particles = 1000, measurement_noise = {'x': 1})
self.__test_state_est(filt, m)

filt = ParticleFilter(m, x_guess, num_particles = 1000, measurement_noise = {'x': 1})
self.__test_state_est_no_dt(filt, m)

# Test ParticleFilter ScalarData
x_scalar = ScalarData({'x': 1.75, 'v': 38.5})
filt_scalar = ParticleFilter(m, x_scalar, num_particles = 20) # Sample count does not affect ScalarData testing
Expand Down Expand Up @@ -438,9 +476,11 @@ def event_state(self, x):
x_guess = {'x': 1.75, 'v': 35} # Guess of initial state, actual is {'x': 1.83, 'v': 40}

filt = KalmanFilter(m, x_guess)

self.__test_state_est(filt, m)

filt = KalmanFilter(m, x_guess)
self.__test_state_est_no_dt(filt, m)

m = ThrownObject(process_noise=5e-2, measurement_noise=5e-2)

# Test KalmanFilter ScalarData
Expand Down

0 comments on commit 1f676d3

Please sign in to comment.