Migrating Flax from returning FrozenDicts to returning regular dicts. More details can be found in this [announcement](#3191).

PiperOrigin-RevId: 547394685
chiamp authored and Flax Authors committed Jul 19, 2023
1 parent 290e50f commit cb0541c
Showing 6 changed files with 4 additions and 24 deletions.
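For context, the user-visible effect of this commit: `Module.init` and `Module.apply` now return plain dicts rather than `flax.core.FrozenDict`s. A minimal sketch of the new behavior, assuming the standard Linen API (illustrative only, not part of the diff):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import freeze

model = nn.Dense(features=3)
x = jnp.ones((1, 4))

# With flax_return_frozendict now defaulting to False, init returns a
# plain, mutable dict instead of a FrozenDict.
variables = model.init(jax.random.PRNGKey(0), x)
assert isinstance(variables, dict)

# Code that depended on immutability can opt back in explicitly.
frozen_variables = freeze(variables)
```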
flax/configurations.py (4 additions, 3 deletions)

@@ -109,11 +109,12 @@ def temp_flip_flag(var_name: str, var_value: bool):
     default=False,
     help=("When adopting outside modules, don't clobber existing names."))
 
-#TODO(marcuschiam): remove this feature flag once regular dict migration is complete
+# TODO(marcuschiam): remove this feature flag once regular dict migration is complete
 flax_return_frozendict = define_bool_state(
     name='return_frozendict',
-    default=True,
-    help=('Whether to return FrozenDicts when calling init or apply.'))
+    default=False,
+    help='Whether to return FrozenDicts when calling init or apply.',
+)
 
 flax_fix_rng = define_bool_state(
     name ='fix_rng_separator',
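Per the linked announcement, projects that still need the old behavior can temporarily flip the flag back while they migrate. A hedged sketch: the flag name comes from this diff, but the exact `flax.config.update` call is an assumption about Flax's JAX-style config API and may differ by version (an environment variable such as `FLAX_RETURN_FROZENDICT` may also apply):

```python
import flax

# Assumed escape hatch: temporarily restore FrozenDict returns while
# downstream code migrates. The TODO above marks this flag for removal
# once the regular-dict migration is complete.
flax.config.update('flax_return_frozendict', True)
```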
tests/core/core_lift_test.py (0 additions, 1 deletion)

@@ -164,7 +164,6 @@ def false_fn(scope, x):
     self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 1})
     np.testing.assert_allclose(y1, -y2)
 
-  @temp_flip_flag('return_frozendict', False)
   def test_switch(self):
     def f(scope, x, index):
       scope.variable('state', 'a_count', lambda: 0)
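The remaining five test files all make the same one-line change: now that `return_frozendict` defaults to False, decorating a test with `@temp_flip_flag('return_frozendict', False)` is a no-op, so the decorator is simply deleted. For illustration, a hypothetical reconstruction of a `temp_flip_flag`-style helper (the real one lives in flax/configurations.py, per the hunk header above; the flag store here is a toy stand-in, not Flax's actual config object):

```python
from contextlib import contextmanager

# Toy flag store standing in for Flax's real config state.
_flags = {'flax_return_frozendict': False}

@contextmanager
def temp_flip_flag(var_name: str, var_value: bool):
  """Temporarily flips a boolean config flag, restoring it on exit.

  Matches the signature shown in the configurations.py hunk header;
  the body is a sketch, not the actual Flax implementation.
  """
  key = f'flax_{var_name}'
  old_value = _flags[key]
  _flags[key] = var_value
  try:
    yield
  finally:
    _flags[key] = old_value

# @contextmanager helpers also work as decorators, so tests could force
# the new behavior explicitly; with default=False this is now redundant.
@temp_flip_flag('return_frozendict', False)
def some_test():
  assert _flags['flax_return_frozendict'] is False
```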
tests/linen/linen_attention_test.py (0 additions, 1 deletion)

@@ -144,7 +144,6 @@ def test_causal_mask_1d(self):
     np.testing.assert_allclose(mask_1d, mask_1d_simple,)
 
   @parameterized.parameters([((5,), (1,)), ((6, 5), (2,))])
-  @temp_flip_flag('return_frozendict', False)
   def test_decoding(self, spatial_shape, attn_dims):
     bs = 2
     num_heads = 3
tests/linen/linen_module_test.py (0 additions, 7 deletions)

@@ -1098,7 +1098,6 @@ def test(self):
     A().test()
     self.assertFalse(setup_called)
 
-  @temp_flip_flag('return_frozendict', False)
   def test_module_pass_as_attr(self):
 
     class A(nn.Module):
@@ -1129,7 +1128,6 @@ def __call__(self, x):
     }
     self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
 
-  @temp_flip_flag('return_frozendict', False)
   def test_module_pass_in_closure(self):
     a = nn.Dense(2)
 
@@ -1154,7 +1152,6 @@ def __call__(self, x):
     self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
     self.assertIsNone(a.name)
 
-  @temp_flip_flag('return_frozendict', False)
   def test_toplevel_submodule_adoption(self):
 
     class Encoder(nn.Module):
@@ -1210,7 +1207,6 @@ def __call__(self, x):
     }
     self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
 
-  @temp_flip_flag('return_frozendict', False)
   def test_toplevel_submodule_adoption_pytree(self):
 
     class A(nn.Module):
@@ -1254,7 +1250,6 @@ def __call__(self, c, x):
         lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7),
         counters, ref_counters)))
 
-  @temp_flip_flag('return_frozendict', False)
   def test_toplevel_submodule_adoption_sharing(self):
     dense = functools.partial(nn.Dense, use_bias=False)
 
@@ -1305,7 +1300,6 @@ def __call__(self, x):
     }
     self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
 
-  @temp_flip_flag('return_frozendict', False)
   def test_toplevel_named_submodule_adoption(self):
     dense = functools.partial(nn.Dense, use_bias=False)
 
@@ -1360,7 +1354,6 @@ def __call__(self, x):
     }
     self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
 
-  @temp_flip_flag('return_frozendict', False)
   def test_toplevel_submodule_pytree_adoption_sharing(self):
 
     class A(nn.Module):
tests/linen/linen_transforms_test.py (0 additions, 5 deletions)

@@ -807,7 +807,6 @@ def __call__(self, x):
     y3 = Ctrafo(a2, b).apply(p2, x)
     np.testing.assert_allclose(y1, y3, atol=1e-7)
 
-  @temp_flip_flag('return_frozendict', False)
   def test_toplevel_submodule_adoption_pytree_transform(self):
     class A(nn.Module):
       @nn.compact
@@ -852,7 +851,6 @@ def __call__(self, c, x):
         cntrs, ref_cntrs)
       ))
 
-  @temp_flip_flag('return_frozendict', False)
   def test_partially_applied_module_constructor_transform(self):
     k = random.PRNGKey(0)
     x = jnp.ones((3,4,4))
@@ -870,7 +868,6 @@ def test_partially_applied_module_constructor_transform(self):
     }
     self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes))
 
-  @temp_flip_flag('return_frozendict', False)
   def test_partial_module_method(self):
     k = random.PRNGKey(0)
     x = jnp.ones((3,4,4))
@@ -1505,7 +1502,6 @@ def false_fn(mdl, x):
 
       return nn.cond(pred, true_fn, false_fn, self, x)
 
-  @temp_flip_flag('return_frozendict', False)
   def test_switch(self):
     class Foo(nn.Module):
       @nn.compact
@@ -1540,7 +1536,6 @@ def c_fn(mdl, x):
     self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 1})
     np.testing.assert_allclose(y1, y3)
 
-  @temp_flip_flag('return_frozendict', False)
   def test_switch_multihead(self):
     class Foo(nn.Module):
       def setup(self) -> None:
tests/linen/summary_test.py (0 additions, 7 deletions)

@@ -181,7 +181,6 @@ def test_module_summary(self):
         row.counted_variables,
       )
 
-  @temp_flip_flag('return_frozendict', False)
   def test_module_summary_with_depth(self):
     """
     This test creates a Table using `module_summary` set the `depth` argument to `1`,
@@ -240,7 +239,6 @@ def test_module_summary_with_depth(self):
     self.assertEqual(table[3].module_variables, table[3].counted_variables)
 
 
-  @temp_flip_flag('return_frozendict', False)
   def test_tabulate(self):
     """
     This test creates a string representation of a Module using `Module.tabulate`
@@ -323,7 +321,6 @@ def test_tabulate_with_method(self):
     self.assertIn("(block_method)", module_repr)
     self.assertIn("(cnn_method)", module_repr)
 
-  @temp_flip_flag('return_frozendict', False)
   def test_tabulate_function(self):
     """
     This test creates a string representation of a Module using `Module.tabulate`
@@ -370,7 +367,6 @@ def test_tabulate_function(self):
     self.assertIn("79.4 KB", lines[-3])
 
 
-  @temp_flip_flag('return_frozendict', False)
   def test_lifted_transform(self):
     class LSTM(nn.Module):
       features: int
@@ -406,7 +402,6 @@ def __call__(self, x):
     self.assertIn("ScanLSTM/ii", lines[13])
     self.assertIn("Dense", lines[13])
 
-  @temp_flip_flag('return_frozendict', False)
   def test_lifted_transform_no_rename(self):
     class LSTM(nn.Module):
       features: int
@@ -442,7 +437,6 @@ def __call__(self, x):
     self.assertIn("ScanLSTMCell_0/ii", lines[13])
     self.assertIn("Dense", lines[13])
 
-  @temp_flip_flag('return_frozendict', False)
   def test_module_reuse(self):
     class ConvBlock(nn.Module):
       @nn.compact
@@ -524,7 +518,6 @@ def __call__(self):
     self.assertIn('x: 3.141592', lines[7])
     self.assertIn('4.141592', lines[7])
 
-  @temp_flip_flag('return_frozendict', False)
   def test_partitioned_params(self):
 
     class Classifier(nn.Module):
