diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index 3fcdba6d05..3b34c6d80f 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -227,7 +227,7 @@ def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]: def copy( x: Union[FrozenDict, Dict[str, Any]], - add_or_replace: Union[FrozenDict, Dict[str, Any]], + add_or_replace: Optional[Union[FrozenDict, Dict[str, Any]]] = None, ) -> Union[FrozenDict, Dict[str, Any]]: """Create a new dict with additional and/or replaced entries. This is a utility function that can act on either a FrozenDict or regular dict and mimics the @@ -248,7 +248,8 @@ def copy( return x.copy(add_or_replace) elif isinstance(x, dict): new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x - new_dict.update(add_or_replace) + if add_or_replace is not None: + new_dict.update(add_or_replace) return new_dict raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')