diff --git a/examples/06_Combining_Models.ipynb b/examples/06_Combining_Models.ipynb index 7839753..e3305d8 100644 --- a/examples/06_Combining_Models.ipynb +++ b/examples/06_Combining_Models.ipynb @@ -244,6 +244,40 @@ "fig = simulated_results.states.plot(keys=['DCMotor.i_b', 'DCMotor.i_c', 'DCMotor.i_a'], ylabel='ESC Currents')" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Parameters in composed models can be updated directly using the model_name.parameter name parameter of the composite model. Like so:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "m_powertrain.parameters['PropellerLoad.D'] = 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we updated the propeller diameter to 1, greatly increasing the load on the motor. You can see this in the updated simulation outputs (below). When compared to the original results above you will find that the maximum velocity is lower. This is expected given the larger propeller load." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "simulated_results = m_powertrain.simulate_to(1, future_loading, dt=2.5e-5, save_freq=2e-2)\n", + "fig = simulated_results.outputs.plot(compact=False, keys=['DCMotor.v_rot'], ylabel='Velocity')\n", + "fig = simulated_results.states.plot(keys=['DCMotor.i_b', 'DCMotor.i_c', 'DCMotor.i_a'], ylabel='ESC Currents')" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/progpy/composite_model.py b/src/progpy/composite_model.py index c74655d..1ffe6e4 100644 --- a/src/progpy/composite_model.py +++ b/src/progpy/composite_model.py @@ -201,7 +201,36 @@ def __setstate__(self, params: dict) -> None: else: raise ValueError( f'The input key {in_key} must be an output or state') + + # Setup callbacks + # These callbacks will enable setting of parameters in composed models. + # E.g., composite.parameters['abc.def'] will set parameter 'def' for composed model 'abc' + class PassthroughParams(): + def __init__(self, models, model_name, key): + self.models = models + self.model_name = model_name + self.key = key + i = 0 + for (name, m) in models: + if name == model_name: + break + i+= 1 + self.model_index = i + self.combined_key = self.model_name + '.' + self.key + + def __call__(self, params: dict) -> dict: + params['models'][self.model_index][1].parameters[self.key] = params[self.combined_key] + return {} + for (name, m) in params['models']: + # TODO(CT): TRY JUST SAVING NAME + for key in m.parameters.keys(): + combined_key = name + '.' + key + if combined_key in self.param_callbacks: + self.param_callbacks[combined_key].append(PassthroughParams(params['models'], name, key)) + else: + self.param_callbacks[combined_key] = [PassthroughParams(params['models'], name, key)] + return super().__setstate__(params) def initialize(self, u=None, z=None): diff --git a/tests/test_base_models.py b/tests/test_base_models.py index 86b63a6..2ef0bd2 100644 --- a/tests/test_base_models.py +++ b/tests/test_base_models.py @@ -261,7 +261,7 @@ def test_integration_type_error(self): def test_size(self): m = MockProgModel() size = sys.getsizeof(m) - self.assertLess(size, 7500) + self.assertLess(size, 20000) # Adding a parameter m.parameters['test'] = 8675309 diff --git a/tests/test_composite.py b/tests/test_composite.py index ce17c0f..b4ace27 100644 --- a/tests/test_composite.py +++ b/tests/test_composite.py @@ -254,6 +254,17 @@ def fcn(u0, u1) -> float: z = m_composite.output(x) self.assertEqual(x['function.return'], z['function.return']) + def test_parameter_passthrough(self): + # This tests a feature where parameters of the composed models are settable in the composite model. + m1 = OneInputOneOutputNoEventLM() + m2 = OneInputNoOutputOneEventLM() + m_composite = CompositeModel([m1, m1], connections=[('OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.u1')]) + + # At the beginning process noise is 0, lets set it to something else. + model_name = m_composite.parameters['models'][0][0] + m_composite.parameters[model_name + "." + "process_noise"] = 2.5 + self.assertEqual(m_composite.parameters['models'][0][1].parameters['process_noise']['x1'], 2.5) + def test_composite(self): m1 = OneInputOneOutputNoEventLM() m2 = OneInputNoOutputOneEventLM()