From 62efe1271aab98b3a6f71258054436f3a9d41135 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Mon, 31 Jul 2023 23:17:06 +0200 Subject: [PATCH] fix --- flax/core/frozen_dict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index 3fcdba6d05..3b34c6d80f 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -227,7 +227,7 @@ def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]: def copy( x: Union[FrozenDict, Dict[str, Any]], - add_or_replace: Union[FrozenDict, Dict[str, Any]], + add_or_replace: Optional[Union[FrozenDict, Dict[str, Any]]] = None, ) -> Union[FrozenDict, Dict[str, Any]]: """Create a new dict with additional and/or replaced entries. This is a utility function that can act on either a FrozenDict or regular dict and mimics the @@ -248,7 +248,8 @@ def copy( return x.copy(add_or_replace) elif isinstance(x, dict): new_dict = jax.tree_map(lambda x: x, x) # make a deep copy of dict x - new_dict.update(add_or_replace) + if add_or_replace is not None: + new_dict.update(add_or_replace) return new_dict raise TypeError(f'Expected FrozenDict or dict, got {type(x)}')