Skip to content

Commit

Permalink
allow method argument to accept submodules
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 11, 2023
1 parent 426c0b4 commit 37cbb7d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
6 changes: 6 additions & 0 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1910,6 +1910,12 @@ def other_fn(instance, ...):
f"'{class_name}.{attribute_name}' must be a callable, got"
f' {type(method)}.'
)
# if the `method` string is a submodule, we create a lambda function
# that calls the submodule, forwarding all arguments.
if isinstance(method, Module):
method = lambda self, *args, **kwargs: getattr(self, attribute_name)(
*args, **kwargs
)
elif method is None:
method = self.__call__
method = _get_unbound_fn(method)
Expand Down
13 changes: 13 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,19 @@ def test(self):
# test same for init.
Foo().init({}, method='not_callable')

def test_module_apply_method_submodule(self):
class Foo(nn.Module):
bar: nn.Module

@nn.compact
def __call__(self, x):
return self.bar(x)

foo = Foo(nn.Dense(3))
variables = foo.init(jax.random.PRNGKey(0), jnp.zeros(3))

foo.apply(variables, jnp.zeros(3), method='bar')

def test_call_unbound_compact_module_methods(self):
dense = Dense(3)
msg = r'Can\'t call compact methods on unbound modules'
Expand Down

0 comments on commit 37cbb7d

Please sign in to comment.