Skip to content

Commit

Permalink
Fixed pickling of the policy subroutine in SBIBase
Browse files Browse the repository at this point in the history
  • Loading branch information
famura committed May 26, 2021
1 parent 1c8f8d9 commit 0283569
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions Pyrado/pyrado/algorithms/meta/sbi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,18 +753,19 @@ def __getstate__(self):
# subsequent deepcopying
tmp_subrtn_policy = self.__dict__.pop("_subrtn_policy", None)

# Call Algorithm's __getstate__() without the unpickleable sbi-related members
state_dict = super().__getstate__()
# Call Algorithm's __getstate__() without the unpickleable sbi-related members.
# Make a deepcopy of the state dict such that we can return the pickleable version and insert the sbi variables.
state_dict_copy = deepcopy(super().__getstate__())

# Make a deep copy of the state dict such that we can return the pickleable version and insert the sbi variables
state_dict_copy = deepcopy(state_dict)

# Inset them back
# Inset them back to the current state dict
self.__dict__["_sbi_simulator"] = tmp_sbi_simulator
self.__dict__["_subrtn_sbi"]._summary_writer = tmp_subrtn_sbi_summary_writer
self.__dict__["_subrtn_sbi"]._build_neural_net = tmp_subrtn_sbi_build_neural_net
self.__dict__["_subrtn_policy"] = tmp_subrtn_policy

# Inset the subroutine from Pyrado back to the pickled state dict
state_dict_copy["_subrtn_policy"] = tmp_subrtn_policy

return state_dict_copy

def __setstate__(self, state):
Expand Down

0 comments on commit 0283569

Please sign in to comment.