Skip to content

Commit

Permalink
Merge pull request #139 from nasa/feature/138-composite-model-functio…
Browse files Browse the repository at this point in the history
…n-output

Composite Model - Add function return in output
  • Loading branch information
teubert committed Dec 21, 2023
2 parents ec48258 + a367143 commit 31af475
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/progpy/composite_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __setstate__(self, params: dict) -> None:
for (name, fcn) in params['functions']:
self.inputs |= set([name + DIVIDER + u for u in signature(fcn).parameters.keys()])
self.states.add(name + DIVIDER + 'return')
self.outputs.add(name + DIVIDER + 'return')

# Handle outputs
if 'outputs' in params:
Expand Down Expand Up @@ -307,6 +308,9 @@ def output(self, x):
# Save to super outputs
for key, value in z_i.items():
z[name + '.' + key] = value

for (name, _) in self.parameters['functions']:
z[name + DIVIDER + 'return'] = x[name + DIVIDER + 'return']
return self.OutputContainer(z)

def performance_metrics(self, x) -> dict:
Expand Down
12 changes: 8 additions & 4 deletions tests/test_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def fcn(u0, u1):
m_composite = CompositeModel([m1, m1, fcn])
self.assertSetEqual(m_composite.states, {'OneInputOneOutputNoEventLM_2.x1', 'OneInputOneOutputNoEventLM.x1', 'function.return'})
self.assertSetEqual(m_composite.inputs, {'OneInputOneOutputNoEventLM.u1', 'OneInputOneOutputNoEventLM_2.u1', 'function.u0', 'function.u1'})
self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1'})
self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1', 'function.return'})
self.assertSetEqual(m_composite.events, set())
self.assertSetEqual(m_composite.performance_metric_keys, set(), "Shouldn't have any performance metrics")

Expand Down Expand Up @@ -120,7 +120,7 @@ def fcn(u0, u1):
self.assertSetEqual(m_composite.states, {'OneInputOneOutputNoEventLM_2.x1', 'OneInputOneOutputNoEventLM.x1', 'OneInputOneOutputNoEventLM.z1', 'function.return'})
# One less input - since it's internally connected
self.assertSetEqual(m_composite.inputs, {'OneInputOneOutputNoEventLM.u1', 'function.u1'})
self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1'})
self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1', 'function.return'})
self.assertSetEqual(m_composite.events, set())

with self.assertRaises(TypeError):
Expand Down Expand Up @@ -167,7 +167,7 @@ def fcn(u0, u1):
self.assertSetEqual(m_composite.states, {'OneInputOneOutputNoEventLM_2.x1', 'OneInputOneOutputNoEventLM.x1', 'OneInputOneOutputNoEventLM.z1', 'function.return'})
# One less input - since it's internally connected
self.assertSetEqual(m_composite.inputs, {'OneInputOneOutputNoEventLM.u1'})
self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1'})
self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1', 'function.return'})
self.assertSetEqual(m_composite.events, set())

# Empty initialization should work now
Expand Down Expand Up @@ -216,7 +216,7 @@ def fcn(u0, u1) -> float:
self.assertSetEqual(m_composite.states, {'OneInputOneOutputNoEventLM_2.x1', 'OneInputOneOutputNoEventLM.x1', 'OneInputOneOutputNoEventLM.z1', 'function.return'})
# Two less input - since it's fully internally connected
self.assertSetEqual(m_composite.inputs, set())
self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1'})
self.assertSetEqual(m_composite.outputs, {'OneInputOneOutputNoEventLM.z1', 'OneInputOneOutputNoEventLM_2.z1', 'function.return'})
self.assertSetEqual(m_composite.events, set())

# Empty initialization should work
Expand Down Expand Up @@ -250,6 +250,10 @@ def fcn(u0, u1) -> float:
self.assertEqual(x['OneInputOneOutputNoEventLM.x1'], 4)
self.assertEqual(x['function.return'], 9) # 4 + 4 + 1

# Function return in outputs
z = m_composite.output(x)
self.assertEqual(x['function.return'], z['function.return'])

def test_composite(self):
m1 = OneInputOneOutputNoEventLM()
m2 = OneInputNoOutputOneEventLM()
Expand Down

0 comments on commit 31af475

Please sign in to comment.