diff --git a/src/progpy/composite_model.py b/src/progpy/composite_model.py index 1ffe6e4..f8542ef 100644 --- a/src/progpy/composite_model.py +++ b/src/progpy/composite_model.py @@ -206,32 +206,29 @@ def __setstate__(self, params: dict) -> None: # These callbacks will enable setting of parameters in composed models. # E.g., composite.parameters['abc.def'] will set parameter 'def' for composed model 'abc' class PassthroughParams(): - def __init__(self, models, model_name, key): - self.models = models + def __init__(self, model_name, model_index, key): self.model_name = model_name self.key = key - i = 0 - for (name, m) in models: - if name == model_name: - break - i+= 1 - self.model_index = i + self.model_index = model_index self.combined_key = self.model_name + '.' + self.key def __call__(self, params: dict) -> dict: params['models'][self.model_index][1].parameters[self.key] = params[self.combined_key] return {} - for (name, m) in params['models']: - # TODO(CT): TRY JUST SAVING NAME + for (model_index, (name, m)) in enumerate(params['models']): for key in m.parameters.keys(): combined_key = name + '.' + key if combined_key in self.param_callbacks: - self.param_callbacks[combined_key].append(PassthroughParams(params['models'], name, key)) + self.param_callbacks[combined_key].append(PassthroughParams(name, model_index, key)) else: - self.param_callbacks[combined_key] = [PassthroughParams(params['models'], name, key)] + self.param_callbacks[combined_key] = [PassthroughParams(name, model_index, key)] - return super().__setstate__(params) + ms = params.pop('models') # remove from parameters to avoid copying + result = super().__setstate__(params) + params['models'] = ms + self.parameters.__setitem__('models', ms, _copy=False) + return result def initialize(self, u=None, z=None): if u is None: