Compatibility between submodule declaration and @nn.compact #565
-
Wouldn't it be handy to be able to use smaller blocks defined with `@nn.compact` inside a module that declares its submodules in `setup`? For example:

```python
import functools
from typing import Any, Optional, Sequence, Tuple, Union

import flax.linen as nn
import jax.numpy as jnp


class ConvBlock(nn.Module):
    out_channels: int = 32
    kernel_size: Sequence[int] = (3, 3)
    strides: Optional[Sequence[int]] = None
    padding: Union[str, Sequence[Tuple[int, int]]] = 'VALID'
    train: bool = False
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.out_channels, kernel_size=self.kernel_size, strides=self.strides,
                    padding=self.padding, use_bias=False, name='conv', dtype=self.dtype)(x)
        x = nn.BatchNorm(use_running_average=not self.train, name='bn', dtype=self.dtype)(x)
        return nn.relu(x)


class Model(nn.Module):
    train: bool = False
    dtype: Any = jnp.float32

    def setup(self):
        # Store a factory for ConvBlock with the shared arguments pre-bound.
        self.conv_block = functools.partial(ConvBlock, train=self.train, dtype=self.dtype)

    def __call__(self, x):
        x = self.conv_block(32, kernel_size=(3, 3), strides=(2, 2))(x)
        return x
```

Is there a technical reason why this is not possible?
-
Hi @rolandgvc -- to clarify, are you saying that the code you wrote above doesn't work? (I guess our …) What currently happens when you run this code?
-
Can you please file a bug with a repro?
-
I think the problem here is not related to `__setattr__`. The problem is that in `Model.__call__()` you call `self.conv_block(...)`, which constructs the Module `ConvBlock` at that point (because the `partial` inside `Model.setup()` does not create it yet). The `__post_init__` call in `ConvBlock` then checks where in the `parent` (which is `Model`) the submodule was created, and finds that this happened outside of `setup`, in a method that is not wrapped in `@nn.compact`.

I don't think this is a bug, because you aren't allowed to create submodules in this way. However, I agree this error message is confusing. I am adding this example to the error messages in #1072.