Skip to content

Commit

Permalink
clean predictors test
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert committed Jul 15, 2024
1 parent 380efd2 commit 059704a
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions tests/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def test_UTP_ThrownObject(self):
samples = MultivariateNormalDist(['x', 'v'], [1.83, 40], [[0.1, 0.01], [0.01, 0.1]])

# No future loading (i.e., no load)
mc_results = pred.predict(samples, dt=0.01, save_freq=1)
self.assertAlmostEqual(mc_results.time_of_event.mean['impact'], 8.21, 0)
self.assertAlmostEqual(mc_results.time_of_event.mean['falling'], 4.15, 0)
results = pred.predict(samples, dt=0.01, save_freq=1)
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.21, 0)
self.assertAlmostEqual(results.time_of_event.mean['falling'], 4.15, 0)
# self.assertAlmostEqual(mc_results.times[-1], 9, 1) # Saving every second, last time should be around the 1s after impact event (because one of the sigma points fails afterwards)

def test_UTP_ThrownObject_One_Event(self):
Expand All @@ -81,10 +81,10 @@ def test_UTP_ThrownObject_One_Event(self):
def future_loading(t, x={}):
return {}

mc_results = pred.predict(samples, future_loading, dt=0.01, events=['falling'], save_freq=1)
self.assertAlmostEqual(mc_results.time_of_event.mean['falling'], 3.8, 0)
self.assertTrue('impact' not in mc_results.time_of_event.mean)
self.assertAlmostEqual(mc_results.times[-1], 3, 1) # Saving every second, last time should be around the nearest 1s before falling event
results = pred.predict(samples, future_loading, dt=0.01, events=['falling'], save_freq=1)
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 0)
self.assertTrue('impact' not in results.time_of_event.mean)
self.assertAlmostEqual(results.times[-1], 3, 1) # Saving every second, last time should be around the nearest 1s before falling event

def test_UKP_Battery(self):
def future_loading(t, x=None):
Expand Down Expand Up @@ -115,19 +115,21 @@ def future_loading(t, x=None):
ut = UnscentedTransformPredictor(batt)

# Predict with a step size of 0.1
mc_results = ut.predict(filt.x, future_loading, dt=0.1)
self.assertAlmostEqual(mc_results.time_of_event.mean['EOD'], 3004, -2)
results = ut.predict(filt.x, future_loading, dt=0.1)
self.assertAlmostEqual(results.time_of_event.mean['EOD'], 3004, -2)

# Test Metrics
s = mc_results.time_of_event.sample(100).key('EOD')
s = results.time_of_event.sample(100).key('EOD')
samples.eol_metrics(s) # Kept for backwards compatibility

def test_MC(self):
m = ThrownObject()
mc = MonteCarlo(m)

# Test with empty future loading (i.e., no load)
mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1)
results = mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1)
self.assertAlmostEqual(results.time_of_event.mean['impact'], 8.0, 5)
self.assertAlmostEqual(results.time_of_event.mean['falling'], 3.8, 5)

def test_prediction_mvnormaldist(self):
times = list(range(10))
Expand Down

0 comments on commit 059704a

Please sign in to comment.