Skip to content

Commit

Permalink
Merge pull request #3241 from PhilipVinc:patch-3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553890291
  • Loading branch information
Flax Authors committed Aug 4, 2023
2 parents 8e51b71 + 81c274a commit d883d34
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,7 @@ build/
.pytype
.vscode/*
/.devcontainer
docs/**/tmp
docs/**/tmp

# used by direnv
.envrc
11 changes: 8 additions & 3 deletions flax/core/frozen_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"""Frozen Dictionary."""

import collections
from typing import Any, TypeVar, Mapping, Dict, Tuple, Union, Hashable
from typing import Any, Dict, Hashable, Optional, Mapping, Tuple, TypeVar, Union
from types import MappingProxyType

from flax import serialization
import jax
Expand Down Expand Up @@ -111,7 +112,9 @@ def __hash__(self):
self._hash = h
return self._hash

def copy(self, add_or_replace: Mapping[K, V]) -> 'FrozenDict[K, V]':
def copy(
self, add_or_replace: Mapping[K, V] = MappingProxyType({})
) -> 'FrozenDict[K, V]':
"""Create a new FrozenDict with additional or replaced entries."""
return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type]

Expand Down Expand Up @@ -223,7 +226,9 @@ def unfreeze(x: Union[FrozenDict, Dict[str, Any]]) -> Dict[Any, Any]:

def copy(
x: Union[FrozenDict, Dict[str, Any]],
add_or_replace: Union[FrozenDict, Dict[str, Any]],
add_or_replace: Union[FrozenDict[str, Any], Dict[str, Any]] = FrozenDict(
{}
),
) -> Union[FrozenDict, Dict[str, Any]]:
"""Create a new dict with additional and/or replaced entries. This is a utility
function that can act on either a FrozenDict or regular dict and mimics the
Expand Down
12 changes: 12 additions & 0 deletions tests/core/core_frozen_dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ def test_utility_copy(self, x, add_or_replace, actual_new_x):
new_x == actual_new_x and isinstance(new_x, type(actual_new_x))
)

@parameterized.parameters(
{
'x': {'a': 1, 'b': {'c': 2}},
},
{
'x': FrozenDict({'a': 1, 'b': {'c': 2}}),
},
)
def test_utility_copy_singlearg(self, x):
new_x = copy(x)
self.assertTrue(new_x == x and isinstance(new_x, type(x)))

@parameterized.parameters(
{
'x': {'a': 1, 'b': {'c': 2}},
Expand Down

0 comments on commit d883d34

Please sign in to comment.