Skip to content

Commit

Permalink
Remove flax_relaxed_naming feature flag.
Browse files Browse the repository at this point in the history
The relaxed naming convention has been the default for a while, we are
removing the temporary compatibility flag and all code associated with
the old behavior.

PiperOrigin-RevId: 542621262
  • Loading branch information
levskaya authored and Flax Authors committed Jun 22, 2023
1 parent 4f9f64e commit befa4c1
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 118 deletions.
5 changes: 0 additions & 5 deletions flax/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,6 @@ def temp_flip_flag(var_name: str, var_value: bool):
default=True,
help=('Whether to use Orbax to save checkpoints.'))

flax_relaxed_naming = define_bool_state(
name='relaxed_naming',
default=True,
help=('Whether to relax naming constraints.'))

flax_preserve_adopted_names = define_bool_state(
name='preserve_adopted_names',
default=False,
Expand Down
18 changes: 8 additions & 10 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Flax functional core: Scopes."""

import collections
import contextlib
import dataclasses
import functools
Expand Down Expand Up @@ -408,7 +409,7 @@ class Scope:
<https://github.com/google/flax/tree/main/tests/core/design>`_
for a number of examples using ``Scopes``.
"""
reservations: Dict[str, Optional[str]]
reservations: Dict[str, Set[Optional[str]]]

def __init__(self,
variables: MutableVariableDict,
Expand Down Expand Up @@ -442,7 +443,7 @@ def __init__(self,
self.trace_level = tracers.trace_level(tracers.current_trace())

self.rng_counters = {key: 0 for key in self.rngs}
self.reservations = dict()
self.reservations = collections.defaultdict(set)

self._invalid = False

Expand Down Expand Up @@ -534,14 +535,11 @@ def name_reserved(self, name: str, col: Optional[str] = None) -> bool:
col: if a variable, the collection used.
"""
if name in self.reservations:
# with relaxed naming, allow the same name for two variables in
# allow the same name for two variables in
# different collections, otherwise raise error.
if config.flax_relaxed_naming:
if (self.reservations[name] is None or col is None
or self.reservations[name] == col):
return True
else:
return True
if (None in self.reservations[name] or col is None
or col in self.reservations[name]):
return True
return False

def reserve(self, name: str, col: Optional[str] = None):
Expand All @@ -558,7 +556,7 @@ def reserve(self, name: str, col: Optional[str] = None):
f'it is {type(name)}')
if self.name_reserved(name, col):
raise ValueError(f'Duplicate use of scope name: "{name}"')
self.reservations[name] = col
self.reservations[name].add(col)

def default_name(self, prefix: str) -> str:
"""Generates an unreserved name with the given prefix.
Expand Down
29 changes: 1 addition & 28 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,20 +279,6 @@ def _map_over_modules_in_tree(fn, tree_or_leaf):
tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict))


def _all_names_on_object(obj: Any) -> Set[str]:
"""Gets all names of attributes on `obj` and its classes throughout MRO.
Args:
obj: The object to get names for.
Returns:
A set of names of attributes of `obj` and its classes.
"""
nameset = set(obj.__dict__.keys())
for cls in obj.__class__.__mro__:
nameset = nameset.union(set(cls.__dict__.keys()))
return nameset


def _freeze_attr(val: Any) -> Any:
"""Recursively wrap the given attribute `var` in ``FrozenDict``."""
if isinstance(val, (dict, FrozenDict)):
Expand Down Expand Up @@ -1129,22 +1115,9 @@ def _name_taken(self,
reuse_scopes: bool = False,
collection: Optional[str] = None) -> bool:
assert self.scope is not None
# with relaxed naming don't force non-overlap with python attribute names.
if config.flax_relaxed_naming:
if reuse_scopes:
return False
return self.scope.name_reserved(name, collection)
if name in _all_names_on_object(self):
val = getattr(self, name, None)
if module is not None and val is module:
# name is taken by the value itself because
# field assignment happened before naming
return False
return True
# Check for the existence of name in the scope object.
if reuse_scopes:
return False
return name in self.scope.reservations
return self.scope.name_reserved(name, collection)

@property
def _initialization_allowed(self):
Expand Down
93 changes: 25 additions & 68 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,13 +460,8 @@ def __call__(self, x):
x = jnp.array([1.])
scope = Scope({}, {'params': rngkey}, mutable=['params'])

if config.flax_relaxed_naming:
with self.assertRaises(errors.NameInUseError):
unused_y = Dummy(x.shape, parent=scope)(x)
else:
msg = 'Duplicate use of scope name: "bias"'
with self.assertRaisesWithLiteralMatch(ValueError, msg):
unused_y = Dummy(x.shape, parent=scope)(x)
with self.assertRaises(errors.NameInUseError):
unused_y = Dummy(x.shape, parent=scope)(x)

def test_submodule_var_collision_with_submodule(self):
rngkey = jax.random.PRNGKey(0)
Expand Down Expand Up @@ -521,44 +516,6 @@ def __call__(self):

Foo({'a': ()}).apply({})

@absltest.skipIf(config.flax_relaxed_naming, "relaxed naming")
def test_attr_param_name_collision(self):
rngkey = jax.random.PRNGKey(0)

class Dummy(nn.Module):
bias: bool

def setup(self):
self.bias = self.param('bias', initializers.ones, (3, 3))

def __call__(self, x):
return x + self.bias

x = jnp.array([1.])
scope = Scope({}, {'params': rngkey}, mutable=['params'])
msg = 'Could not create param "bias" in Module Dummy: Name in use'
with self.assertRaisesRegex(errors.NameInUseError, msg):
unused_y = Dummy(True, parent=scope)(x)

@absltest.skipIf(config.flax_relaxed_naming, "relaxed naming")
def test_attr_submodule_name_collision(self):
rngkey = jax.random.PRNGKey(0)

class Dummy(nn.Module):
bias: bool

def setup(self):
self.bias = DummyModule(name='bias')

def __call__(self, x):
return self.bias(x)

x = jnp.array([1.])
scope = Scope({}, {'params': rngkey}, mutable=['params'])
msg = 'Could not create submodule "bias" in Module Dummy: Name in use'
with self.assertRaisesRegex(errors.NameInUseError, msg):
unused_y = Dummy(True, parent=scope)(x)

def test_only_one_compact_method(self):
msg = 'Only one method per class can be @compact'
with self.assertRaisesRegex(errors.MultipleMethodsCompactError, msg):
Expand Down Expand Up @@ -2425,18 +2382,10 @@ def __call__(self, x):
p = self.param('dummy', nn.initializers.zeros, x.shape)
return x + p

with set_config('flax_relaxed_naming', True):
foo = Foo(name='foo')
k = random.PRNGKey(0)
x = jnp.zeros((1,))
vs = foo.init(k, x)

with set_config('flax_relaxed_naming', False):
foo = Foo(name='foo')
k = random.PRNGKey(0)
x = jnp.zeros((1,))
with self.assertRaises(errors.NameInUseError):
vs = foo.init(k, x)
foo = Foo(name='foo')
k = random.PRNGKey(0)
x = jnp.zeros((1,))
vs = foo.init(k, x)

def test_relaxed_intercollection_conflict(self):

Expand All @@ -2447,18 +2396,26 @@ def __call__(self, x):
v2 = self.variable('col2', 'v', lambda x: jnp.zeros(x), x.shape)
return x + v1.value + v2.value

with set_config('flax_relaxed_naming', True):
foo = Foo(name='foo')
k = random.PRNGKey(0)
x = jnp.zeros((1,))
vs = foo.init(k, x)
foo = Foo(name='foo')
k = random.PRNGKey(0)
x = jnp.zeros((1,))
vs = foo.init(k, x)

with set_config('flax_relaxed_naming', False):
foo = Foo(name='foo')
k = random.PRNGKey(0)
x = jnp.zeros((1,))
with self.assertRaises(errors.NameInUseError):
vs = foo.init(k, x)
def test_relaxed_intercollection_conflict_set(self):

class Foo(nn.Module):
@nn.compact
def __call__(self, x):
v1 = self.variable('col1', 'v', lambda x: jnp.zeros(x), x.shape)
v2 = self.variable('col2', 'v', lambda x: jnp.zeros(x), x.shape)
v3 = self.variable('col1', 'v', lambda x: jnp.zeros(x), x.shape)
return x + v1.value + v2.value + v3.value

foo = Foo(name='foo')
k = random.PRNGKey(0)
x = jnp.zeros((1,))
with self.assertRaises(errors.NameInUseError):
vs = foo.init(k, x)


class FrozenDictTests(absltest.TestCase):
Expand Down
9 changes: 2 additions & 7 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,13 +1296,8 @@ def __call__(self, x):
k = random.PRNGKey(0)
x = jnp.array([1.])

if config.flax_relaxed_naming:
with self.assertRaises(errors.NameInUseError):
y = Test().init(k, x)
else:
msg = 'Duplicate use of scope name: "sub"'
with self.assertRaisesWithLiteralMatch(ValueError, msg):
y = Test().init(k, x)
with self.assertRaises(errors.NameInUseError):
y = Test().init(k, x)

def test_transform_with_setup_and_methods_on_submodule_pytrees(self):
class Foo(nn.Module):
Expand Down

0 comments on commit befa4c1

Please sign in to comment.