Skip to content

Commit

Permalink
Merge pull request #130 from nasa/bugs/129-cannot-deep-copy-composite…
Browse files Browse the repository at this point in the history
…-model

[Closes #129] Fixes issue with deepcopy in composite model
  • Loading branch information
teubert committed Dec 6, 2023
2 parents 28dd429 + 51ce763 commit 0980f8c
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 27 deletions.
66 changes: 39 additions & 27 deletions src/progpy/composite_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,11 @@ def __init__(self, models: list, connections: list = [], **kwargs):
raise ValueError('The connections argument must be a list')

# Initialize
self.inputs = set()
self.states = set()
self.outputs = set()
self.events = set()
self.performance_metric_keys = set()
self.model_names = set()
duplicate_names = {}
kwargs['model_names'] = set()
kwargs['models'] = []
kwargs['functions'] = []
kwargs['connections'] = connections
duplicate_names = {}

# Handle models
for m in models:
Expand All @@ -78,10 +74,10 @@ def __init__(self, models: list, connections: list = [], **kwargs):
raise ValueError(f'Each model must be a PrognosticsModel or tuple (name: str, PrognosticsModel), was {type(m)}')

# Check for duplicate names
if m[0] in self.model_names:
if m[0] in kwargs['model_names']:
duplicate_names[m[0]] = duplicate_names.get(m[0], 1) + 1
m = (m[0] + '_' + str(duplicate_names[m[0]]), m[1])
self.model_names.add(m[0])
kwargs['model_names'].add(m[0])

# Handle model/function
if isinstance(m[1], PrognosticsModel):
Expand All @@ -93,40 +89,57 @@ def __init__(self, models: list, connections: list = [], **kwargs):
'The second element of each model tuple must be a'
' PrognosticsModel')

self.__setstate__(kwargs)

# Finish initialization
super().__init__(**kwargs)

def __setstate__(self, params: dict) -> None:
"""
Setup inputs, outputs, connections from models, functions. Needed to fix copying/pickling
Args:
params (dict): kwargs (either parameters or kwargs into constructor)
"""
self.inputs = set()
self.states = set()
self.outputs = set()
self.events = set()
self.performance_metric_keys = set()

# update inputs, states, outputs, etc.
for (name, m) in kwargs['models']:
for (name, m) in params['models']:
self.inputs |= set([name + DIVIDER + u for u in m.inputs])
self.states |= set([name + DIVIDER + x for x in m.states])
self.outputs |= set([name + DIVIDER + z for z in m.outputs])
self.events |= set([name + DIVIDER + e for e in m.events])
self.performance_metric_keys |= set([name + DIVIDER + p for p in m.performance_metric_keys])

for (name, fcn) in kwargs['functions']:
for (name, fcn) in params['functions']:
self.inputs |= set([name + DIVIDER + u for u in signature(fcn).parameters.keys()])
self.states.add(name + DIVIDER + 'return')

# Handle outputs
if 'outputs' in kwargs:
if isinstance(kwargs['outputs'], str):
kwargs['outputs'] = [kwargs['outputs']]
if not isinstance(kwargs['outputs'], Iterable):
if 'outputs' in params:
if isinstance(params['outputs'], str):
params['outputs'] = [params['outputs']]
if not isinstance(params['outputs'], Iterable):
raise ValueError('The outputs argument must be a list[str]')
if not set(kwargs['outputs']).issubset(self.outputs):
if not set(params['outputs']).issubset(self.outputs):
raise ValueError(
'The outputs of the composite model must be a '
'subset of the outputs of the models')
self.outputs = kwargs['outputs']
self.outputs = params['outputs']

# Handle Connections
kwargs['connections'] = []
self.__to_input_connections = {
m_name: [] for m_name in self.model_names}
m_name: [] for m_name in params['model_names']}
self.__to_state_connections = {
m_name: [] for m_name in self.model_names}
m_name: [] for m_name in params['model_names']}
self.__to_state_from_pm_connections = {
m_name: [] for m_name in self.model_names}
m_name: [] for m_name in params['model_names']}

for connection in connections:
for connection in params['connections']:
# Input validation
if not isinstance(connection, Iterable) or len(connection) != 2:
raise ValueError(
Expand Down Expand Up @@ -158,11 +171,11 @@ def __init__(self, models: list, connections: list = [], **kwargs):
if in_model == out_model:
raise ValueError(
'The input and output models must be different')
if in_model not in self.model_names:
if in_model not in params['model_names']:
raise ValueError(
'The input model must be one of the models'
' in the composite model')
if out_model not in self.model_names:
if out_model not in params['model_names']:
raise ValueError(
'The output model must be one of the models'
' in the composite model')
Expand All @@ -187,9 +200,8 @@ def __init__(self, models: list, connections: list = [], **kwargs):
else:
raise ValueError(
f'The input key {in_key} must be an output or state')

# Finish initialization
super().__init__(**kwargs)

return super().__setstate__(params)

def initialize(self, u=None, z=None):
if u is None:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_composite.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright © 2021 United States Government as represented by the Administrator of the National Aeronautics and Space Administration. All Rights Reserved.

from copy import deepcopy
import unittest

from progpy import CompositeModel
Expand Down Expand Up @@ -409,6 +410,45 @@ def test_composite_pm(self):
self.assertAlmostEqual(x['OneInputOneOutputOneEventLM.x1'], 3) # extra 1 from pm
self.assertAlmostEqual(x['OneInputOneOutputOneEventLM_2.x1'], 2)

def test_composite_copy(self):
m = OneInputOneOutputOneEventLM()
m_composite = CompositeModel([m, m], connections=[('OneInputOneOutputOneEventLM_2.pm1', 'OneInputOneOutputOneEventLM.u1')])
m_composite_copy = deepcopy(m_composite)
self.assertSetEqual(m_composite.states, m_composite_copy.states)
self.assertSetEqual(m_composite.inputs, m_composite_copy.inputs)
self.assertSetEqual(m_composite.outputs, m_composite_copy.outputs)
self.assertSetEqual(m_composite.events, m_composite_copy.events)
self.assertSetEqual(m_composite.performance_metric_keys, m_composite_copy.performance_metric_keys)

# Initial State
x0 = m_composite.initialize()
x0_copy = m_composite_copy.initialize()
self.assertSetEqual(set(x0.keys()), set(x0_copy.keys()))
for key in x0.keys():
self.assertEqual(x0[key], x0_copy[key])

# State transition
u = m_composite.InputContainer({'OneInputOneOutputOneEventLM_2.u1': 1})
x = m_composite.next_state(x0, u, 1)
x_copy = m_composite_copy.next_state(x0_copy, u, 1)
self.assertSetEqual(set(x.keys()), set(x_copy.keys()))
for key in x.keys():
self.assertEqual(x[key], x_copy[key])

# Outputs
z = m_composite.output(x)
z_copy = m_composite_copy.output(x_copy)
self.assertSetEqual(set(z.keys()), set(z_copy.keys()))
for key in z.keys():
self.assertEqual(z[key], z_copy[key])

# Event states
es = m_composite.event_state(x)
es_copy = m_composite_copy.event_state(x_copy)
self.assertSetEqual(set(es.keys()), set(es_copy.keys()))
for key in es.keys():
self.assertEqual(es[key], es_copy[key])

def main():
load_test = unittest.TestLoader()
runner = unittest.TextTestRunner()
Expand Down

0 comments on commit 0980f8c

Please sign in to comment.