Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix noise transfer #137

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/progpy/models/test_models/linear_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,26 @@ class OneInputNoOutputNoEventLM(LinearModel):
}
}

class OneInputTwoStatesNoOutputNoEventLM(LinearModel):
"""
Simple model that increases state by u1 every step.
"""
inputs = ['u1']
states = ['x1', 'x2']

A = np.array([[0, 0], [0, 0]])
B = np.array([[1], [0]])
C = np.empty((0,2))
F = np.empty((0,2))

default_parameters = {
'process_noise': 0,
'x0': {
'x1': 0,
'x2': 0
}
}


class OneInputOneOutputNoEventLM(LinearModel):
"""
Expand Down
25 changes: 19 additions & 6 deletions src/progpy/utils/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from scipy.integrate import OdeSolver
import types

from progpy.utils.containers import DictLikeMatrixWrapper
from progpy.utils.next_state import next_state_functions, SciPyIntegrateNextState
from progpy.utils.noise_functions import measurement_noise_functions, process_noise_functions
from progpy.utils.serialization import CustomEncoder, custom_decoder
Expand Down Expand Up @@ -120,17 +121,23 @@ def __setitem__(self, key: str, value: float, _copy: bool = False) -> None:
if callable(self['process_noise']): # Provided a function
self._m.apply_process_noise = types.MethodType(self['process_noise'], self._m)
else: # Not a function
# Process noise is single number - convert to dict
if key == 'process_noise' and isinstance(self['process_noise'], DictLikeMatrixWrapper):
# If it's already a DictLikeMatrixWrapper- convert to dict,
# so then it will be treated as a dictionary and missing keys will be filled in
# this way we're also sure that the final result is the "right" kind of container
self.data['process_noise'] = dict(self['process_noise'])

if isinstance(self['process_noise'], Number):
self['process_noise'] = self._m.StateContainer({key: self['process_noise'] for key in self._m.states})
# Process noise is single number - convert to dict
self.data['process_noise'] = self._m.StateContainer({key: self['process_noise'] for key in self._m.states})
elif isinstance(self['process_noise'], dict):
noise = self['process_noise']
for key in self._m.states:
# Set any missing keys to 0
if key not in noise.keys():
noise[key] = 0

self['process_noise'] = self._m.StateContainer(noise)
self.data['process_noise'] = self._m.StateContainer(noise)

# Process distribution type
if 'process_noise_dist' in self and self['process_noise_dist'].lower() not in process_noise_functions:
Expand Down Expand Up @@ -158,16 +165,22 @@ def __setitem__(self, key: str, value: float, _copy: bool = False) -> None:
if callable(self['measurement_noise']):
self._m.apply_measurement_noise = types.MethodType(self['measurement_noise'], self._m)
else:
# Process noise is single number - convert to dict
if key == 'measurement_noise' and isinstance(self['measurement_noise'], DictLikeMatrixWrapper):
# If it's already a DictLikeMatrixWrapper- convert to dict,
# so then it will be treated as a dictionary and missing keys will be filled in
# this way we're also sure that the final result is the "right" kind of container
self.data['measurement_noise'] = dict(self['measurement_noise'])

if isinstance(self['measurement_noise'], Number):
self['measurement_noise'] = self._m.OutputContainer({key: self['measurement_noise'] for key in self._m.outputs})
# Process noise is single number - convert to dict
self.data['measurement_noise'] = self._m.OutputContainer({key: self['measurement_noise'] for key in self._m.outputs})
elif isinstance(self['measurement_noise'], dict):
noise = self['measurement_noise']
for key in self._m.outputs:
# Set any missing keys to 0
if key not in noise.keys():
noise[key] = 0
self['measurement_noise'] = self._m.OutputContainer(noise)
self.data['measurement_noise'] = self._m.OutputContainer(noise)

# Process distribution type
if 'measurement_noise_dist' in self and self['measurement_noise_dist'].lower() not in measurement_noise_functions:
Expand Down
10 changes: 9 additions & 1 deletion tests/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from progpy import PrognosticsModel, CompositeModel
from progpy.models import ThrownObject, BatteryElectroChemEOD
from progpy.models.test_models.linear_models import (OneInputNoOutputNoEventLM, OneInputOneOutputNoEventLM, OneInputNoOutputOneEventLM, OneInputOneOutputNoEventLMPM)
from progpy.models.test_models.linear_models import (OneInputNoOutputNoEventLM, OneInputOneOutputNoEventLM, OneInputTwoStatesNoOutputNoEventLM, OneInputNoOutputOneEventLM, OneInputOneOutputNoEventLMPM)
from progpy.models.test_models.linear_thrown_object import (LinearThrownObject, LinearThrownDiffThrowingSpeed, LinearThrownObjectUpdatedInitializedMethod, LinearThrownObjectDiffDefaultParams)


Expand Down Expand Up @@ -160,6 +160,14 @@ def test_integration_type(self):
self.assertEqual(x_default['v'], x_rk4['v'])
self.assertEqual(x_default['x'], x_rk4['x'])

def test_parameters_statelikematrixwrapper(self):
"""
This is testing a very specific case where a state container from one model is used to define the noise from another.
"""
m0 = OneInputNoOutputNoEventLM()
m1 = OneInputTwoStatesNoOutputNoEventLM(process_noise=m0.parameters['process_noise'])
self.assertSetEqual(set(m1.parameters['process_noise'].keys()), set(m1.states))

def test_integration_type_scipy(self):
# SciPy Integrator test.
# Here we will set the integrator to various scipy integration methods and make sure that it works
Expand Down
Loading