Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipVinc committed Jul 31, 2023
1 parent 940ff5d commit 62efe12
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions flax/core/frozen_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]:

def copy(
x: Union[FrozenDict, Dict[str, Any]],
add_or_replace: Union[FrozenDict, Dict[str, Any]],
add_or_replace: Optional[Union[FrozenDict, Dict[str, Any]]] = None,
) -> Union[FrozenDict, Dict[str, Any]]:
"""Create a new dict with additional and/or replaced entries. This is a utility
function that can act on either a FrozenDict or regular dict and mimics the
Expand All @@ -248,7 +248,8 @@ def copy(
return x.copy(add_or_replace)
elif isinstance(x, dict):
new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x
new_dict.update(add_or_replace)
if add_or_replace is not None:
new_dict.update(add_or_replace)
return new_dict
raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')

Expand Down

0 comments on commit 62efe12

Please sign in to comment.