diff --git a/flax/configurations.py b/flax/configurations.py index 5dd3b0f82c..d56e6a8022 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -104,11 +104,6 @@ def temp_flip_flag(var_name: str, var_value: bool): default=True, help=('Whether to use Orbax to save checkpoints.')) -flax_relaxed_naming = define_bool_state( - name='relaxed_naming', - default=True, - help=('Whether to relax naming constraints.')) - flax_preserve_adopted_names = define_bool_state( name='preserve_adopted_names', default=False, diff --git a/flax/core/scope.py b/flax/core/scope.py index aca24d525d..42bfb5a873 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -14,6 +14,7 @@ """Flax functional core: Scopes.""" +import collections import contextlib import dataclasses import functools @@ -408,7 +409,7 @@ class Scope: `_ for a number of examples using ``Scopes``. """ - reservations: Dict[str, Optional[str]] + reservations: Dict[str, Set[Optional[str]]] def __init__(self, variables: MutableVariableDict, @@ -442,7 +443,7 @@ def __init__(self, self.trace_level = tracers.trace_level(tracers.current_trace()) self.rng_counters = {key: 0 for key in self.rngs} - self.reservations = dict() + self.reservations = collections.defaultdict(set) self._invalid = False @@ -534,14 +535,11 @@ def name_reserved(self, name: str, col: Optional[str] = None) -> bool: col: if a variable, the collection used. """ if name in self.reservations: - # with relaxed naming, allow the same name for two variables in + # allow the same name for two variables in # different collections, otherwise raise error. - if config.flax_relaxed_naming: - if (self.reservations[name] is None or col is None - or self.reservations[name] == col): - return True - else: - return True + if (None in self.reservations[name] or col is None + or col in self.reservations[name]): + return True return False def reserve(self, name: str, col: Optional[str] = None): @@ -558,7 +556,7 @@ def reserve(self, name: str, col: Optional[str] = None): f'it is {type(name)}') if self.name_reserved(name, col): raise ValueError(f'Duplicate use of scope name: "{name}"') - self.reservations[name] = col + self.reservations[name].add(col) def default_name(self, prefix: str) -> str: """Generates an unreserved name with the given prefix. diff --git a/flax/linen/module.py b/flax/linen/module.py index 87b4b8abc2..a4f2b114ab 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -279,20 +279,6 @@ def _map_over_modules_in_tree(fn, tree_or_leaf): tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict)) -def _all_names_on_object(obj: Any) -> Set[str]: - """Gets all names of attributes on `obj` and its classes throughout MRO. - - Args: - obj: The object to get names for. - Returns: - A set of names of attributes of `obj` and its classes. - """ - nameset = set(obj.__dict__.keys()) - for cls in obj.__class__.__mro__: - nameset = nameset.union(set(cls.__dict__.keys())) - return nameset - - def _freeze_attr(val: Any) -> Any: """Recursively wrap the given attribute `var` in ``FrozenDict``.""" if isinstance(val, (dict, FrozenDict)): @@ -1129,22 +1115,9 @@ def _name_taken(self, reuse_scopes: bool = False, collection: Optional[str] = None) -> bool: assert self.scope is not None - # with relaxed naming don't force non-overlap with python attribute names. - if config.flax_relaxed_naming: - if reuse_scopes: - return False - return self.scope.name_reserved(name, collection) - if name in _all_names_on_object(self): - val = getattr(self, name, None) - if module is not None and val is module: - # name is taken by the value itself because - # field assignment happened before naming - return False - return True - # Check for the existence of name in the scope object. if reuse_scopes: return False - return name in self.scope.reservations + return self.scope.name_reserved(name, collection) @property def _initialization_allowed(self): diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index a2372309f4..e48122c558 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -460,13 +460,8 @@ def __call__(self, x): x = jnp.array([1.]) scope = Scope({}, {'params': rngkey}, mutable=['params']) - if config.flax_relaxed_naming: - with self.assertRaises(errors.NameInUseError): - unused_y = Dummy(x.shape, parent=scope)(x) - else: - msg = 'Duplicate use of scope name: "bias"' - with self.assertRaisesWithLiteralMatch(ValueError, msg): - unused_y = Dummy(x.shape, parent=scope)(x) + with self.assertRaises(errors.NameInUseError): + unused_y = Dummy(x.shape, parent=scope)(x) def test_submodule_var_collision_with_submodule(self): rngkey = jax.random.PRNGKey(0) @@ -521,44 +516,6 @@ def __call__(self): Foo({'a': ()}).apply({}) - @absltest.skipIf(config.flax_relaxed_naming, "relaxed naming") - def test_attr_param_name_collision(self): - rngkey = jax.random.PRNGKey(0) - - class Dummy(nn.Module): - bias: bool - - def setup(self): - self.bias = self.param('bias', initializers.ones, (3, 3)) - - def __call__(self, x): - return x + self.bias - - x = jnp.array([1.]) - scope = Scope({}, {'params': rngkey}, mutable=['params']) - msg = 'Could not create param "bias" in Module Dummy: Name in use' - with self.assertRaisesRegex(errors.NameInUseError, msg): - unused_y = Dummy(True, parent=scope)(x) - - @absltest.skipIf(config.flax_relaxed_naming, "relaxed naming") - def test_attr_submodule_name_collision(self): - rngkey = jax.random.PRNGKey(0) - - class Dummy(nn.Module): - bias: bool - - def setup(self): - self.bias = DummyModule(name='bias') - - def __call__(self, x): - return self.bias(x) - - x = jnp.array([1.]) - scope = Scope({}, {'params': rngkey}, mutable=['params']) - msg = 'Could not create submodule "bias" in Module Dummy: Name in use' - with self.assertRaisesRegex(errors.NameInUseError, msg): - unused_y = Dummy(True, parent=scope)(x) - def test_only_one_compact_method(self): msg = 'Only one method per class can be @compact' with self.assertRaisesRegex(errors.MultipleMethodsCompactError, msg): @@ -2425,18 +2382,10 @@ def __call__(self, x): p = self.param('dummy', nn.initializers.zeros, x.shape) return x + p - with set_config('flax_relaxed_naming', True): - foo = Foo(name='foo') - k = random.PRNGKey(0) - x = jnp.zeros((1,)) - vs = foo.init(k, x) - - with set_config('flax_relaxed_naming', False): - foo = Foo(name='foo') - k = random.PRNGKey(0) - x = jnp.zeros((1,)) - with self.assertRaises(errors.NameInUseError): - vs = foo.init(k, x) + foo = Foo(name='foo') + k = random.PRNGKey(0) + x = jnp.zeros((1,)) + vs = foo.init(k, x) def test_relaxed_intercollection_conflict(self): @@ -2447,18 +2396,26 @@ def __call__(self, x): v2 = self.variable('col2', 'v', lambda x: jnp.zeros(x), x.shape) return x + v1.value + v2.value - with set_config('flax_relaxed_naming', True): - foo = Foo(name='foo') - k = random.PRNGKey(0) - x = jnp.zeros((1,)) - vs = foo.init(k, x) + foo = Foo(name='foo') + k = random.PRNGKey(0) + x = jnp.zeros((1,)) + vs = foo.init(k, x) - with set_config('flax_relaxed_naming', False): - foo = Foo(name='foo') - k = random.PRNGKey(0) - x = jnp.zeros((1,)) - with self.assertRaises(errors.NameInUseError): - vs = foo.init(k, x) + def test_relaxed_intercollection_conflict_set(self): + + class Foo(nn.Module): + @nn.compact + def __call__(self, x): + v1 = self.variable('col1', 'v', lambda x: jnp.zeros(x), x.shape) + v2 = self.variable('col2', 'v', lambda x: jnp.zeros(x), x.shape) + v3 = self.variable('col1', 'v', lambda x: jnp.zeros(x), x.shape) + return x + v1.value + v2.value + v3.value + + foo = Foo(name='foo') + k = random.PRNGKey(0) + x = jnp.zeros((1,)) + with self.assertRaises(errors.NameInUseError): + vs = foo.init(k, x) class FrozenDictTests(absltest.TestCase): diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index e8e856fb8d..2cc69daa4c 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -1296,13 +1296,8 @@ def __call__(self, x): k = random.PRNGKey(0) x = jnp.array([1.]) - if config.flax_relaxed_naming: - with self.assertRaises(errors.NameInUseError): - y = Test().init(k, x) - else: - msg = 'Duplicate use of scope name: "sub"' - with self.assertRaisesWithLiteralMatch(ValueError, msg): - y = Test().init(k, x) + with self.assertRaises(errors.NameInUseError): + y = Test().init(k, x) def test_transform_with_setup_and_methods_on_submodule_pytrees(self): class Foo(nn.Module):