diff --git a/src/progpy/composite_model.py b/src/progpy/composite_model.py index 1909f3f..c74655d 100644 --- a/src/progpy/composite_model.py +++ b/src/progpy/composite_model.py @@ -118,6 +118,7 @@ def __setstate__(self, params: dict) -> None: for (name, fcn) in params['functions']: self.inputs |= set([name + DIVIDER + u for u in signature(fcn).parameters.keys()]) self.states.add(name + DIVIDER + 'return') + self.outputs.add(name + DIVIDER + 'return') # Handle outputs if 'outputs' in params: @@ -307,6 +308,9 @@ def output(self, x): # Save to super outputs for key, value in z_i.items(): z[name + '.' + key] = value + + for (name, _) in self.parameters['functions']: + z[name + DIVIDER + 'return'] = x[name + DIVIDER + 'return'] return self.OutputContainer(z) def performance_metrics(self, x) -> dict: diff --git a/tests/test_composite.py b/tests/test_composite.py index 8608735..ce17c0f 100644 --- a/tests/test_composite.py +++ b/tests/test_composite.py @@ -84,7 +84,7 @@ def fcn(u0, u1): m_composite = CompositeModel([m1, m1, fcn]) self.assertSetEqual(m_composite.states, {'OneInputOneOutputNoEventLM_2.x1', 'OneInputOneOutputNoEventLM.x1', 'function.return'}) self.assertSetEqual(m_composite.inputs, {'OneInputOneOutputNoEventLM.u1', 'OneInputOneOutputNoEventLM_2.u1', 'function.u0', 'function.u1'}) - self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1'}) + self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1', 'function.return'}) self.assertSetEqual(m_composite.events, set()) self.assertSetEqual(m_composite.performance_metric_keys, set(), "Shouldn't have any performance metrics") @@ -120,7 +120,7 @@ def fcn(u0, u1): self.assertSetEqual(m_composite.states, {'OneInputOneOutputNoEventLM_2.x1', 'OneInputOneOutputNoEventLM.x1', 'OneInputOneOutputNoEventLM.z1', 'function.return'}) # One less input - since it's internally connected self.assertSetEqual(m_composite.inputs, {'OneInputOneOutputNoEventLM.u1', 'function.u1'}) - self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1'}) + self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1', 'function.return'}) self.assertSetEqual(m_composite.events, set()) with self.assertRaises(TypeError): @@ -167,7 +167,7 @@ def fcn(u0, u1): self.assertSetEqual(m_composite.states, {'OneInputOneOutputNoEventLM_2.x1', 'OneInputOneOutputNoEventLM.x1', 'OneInputOneOutputNoEventLM.z1', 'function.return'}) # One less input - since it's internally connected self.assertSetEqual(m_composite.inputs, {'OneInputOneOutputNoEventLM.u1'}) - self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1'}) + self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1', 'function.return'}) self.assertSetEqual(m_composite.events, set()) # Empty initialization should work now @@ -216,7 +216,7 @@ def fcn(u0, u1) -> float: self.assertSetEqual(m_composite.states, {'OneInputOneOutputNoEventLM_2.x1', 'OneInputOneOutputNoEventLM.x1', 'OneInputOneOutputNoEventLM.z1', 'function.return'}) # Two less input - since it's fully internally connected self.assertSetEqual(m_composite.inputs, set()) - self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1'}) + self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1', 'function.return'}) self.assertSetEqual(m_composite.events, set()) # Empty initialization should work @@ -250,6 +250,10 @@ def fcn(u0, u1) -> float: self.assertEqual(x['OneInputOneOutputNoEventLM.x1'], 4) self.assertEqual(x['function.return'], 9) # 4 + 4 + 1 + # Function return in outputs + z = m_composite.output(x) + self.assertEqual(x['function.return'], z['function.return']) + def test_composite(self): m1 = OneInputOneOutputNoEventLM() m2 = OneInputNoOutputOneEventLM()