Skip to content

Commit

Permalink
Merge pull request #144 from nasa/feature/143-improved-composite-para…
Browse files Browse the repository at this point in the history
…meters

Composite Model Parameter
  • Loading branch information
teubert committed Jul 5, 2024
2 parents 31af475 + 8894cce commit 8ee90ba
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 1 deletion.
34 changes: 34 additions & 0 deletions examples/06_Combining_Models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,40 @@
"fig = simulated_results.states.plot(keys=['DCMotor.i_b', 'DCMotor.i_c', 'DCMotor.i_a'], ylabel='ESC Currents')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Parameters in composed models can be updated directly using the model_name.parameter name parameter of the composite model. Like so:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"m_powertrain.parameters['PropellerLoad.D'] = 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we updated the propeller diameter to 1, greatly increasing the load on the motor. You can see this in the updated simulation outputs (below). When compared to the original results above you will find that the maximum velocity is lower. This is expected given the larger propeller load."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"simulated_results = m_powertrain.simulate_to(1, future_loading, dt=2.5e-5, save_freq=2e-2)\n",
"fig = simulated_results.outputs.plot(compact=False, keys=['DCMotor.v_rot'], ylabel='Velocity')\n",
"fig = simulated_results.states.plot(keys=['DCMotor.i_b', 'DCMotor.i_c', 'DCMotor.i_a'], ylabel='ESC Currents')"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
29 changes: 29 additions & 0 deletions src/progpy/composite_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,36 @@ def __setstate__(self, params: dict) -> None:
else:
raise ValueError(
f'The input key {in_key} must be an output or state')

# Setup callbacks
# These callbacks will enable setting of parameters in composed models.
# E.g., composite.parameters['abc.def'] will set parameter 'def' for composed model 'abc'
class PassthroughParams():
def __init__(self, models, model_name, key):
self.models = models
self.model_name = model_name
self.key = key
i = 0
for (name, m) in models:
if name == model_name:
break
i+= 1
self.model_index = i
self.combined_key = self.model_name + '.' + self.key

def __call__(self, params: dict) -> dict:
params['models'][self.model_index][1].parameters[self.key] = params[self.combined_key]
return {}

for (name, m) in params['models']:
# TODO(CT): TRY JUST SAVING NAME
for key in m.parameters.keys():
combined_key = name + '.' + key
if combined_key in self.param_callbacks:
self.param_callbacks[combined_key].append(PassthroughParams(params['models'], name, key))
else:
self.param_callbacks[combined_key] = [PassthroughParams(params['models'], name, key)]

return super().__setstate__(params)

def initialize(self, u=None, z=None):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_integration_type_error(self):
def test_size(self):
m = MockProgModel()
size = sys.getsizeof(m)
self.assertLess(size, 7500)
self.assertLess(size, 20000)

# Adding a parameter
m.parameters['test'] = 8675309
Expand Down
11 changes: 11 additions & 0 deletions tests/test_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,17 @@ def fcn(u0, u1) -> float:
z = m_composite.output(x)
self.assertEqual(x['function.return'], z['function.return'])

def test_parameter_passthrough(self):
# This tests a feature where parameters of the composed models are settable in the composite model.
m1 = OneInputOneOutputNoEventLM()
m2 = OneInputNoOutputOneEventLM()
m_composite = CompositeModel([m1, m1], connections=[('OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.u1')])

# At the beginning process noise is 0, lets set it to something else.
model_name = m_composite.parameters['models'][0][0]
m_composite.parameters[model_name + "." + "process_noise"] = 2.5
self.assertEqual(m_composite.parameters['models'][0][1].parameters['process_noise']['x1'], 2.5)

def test_composite(self):
m1 = OneInputOneOutputNoEventLM()
m2 = OneInputNoOutputOneEventLM()
Expand Down

0 comments on commit 8ee90ba

Please sign in to comment.