Creating module attributes with @nn.compact #936
Hi all,

I have noticed that it is no longer possible to create module attributes in the __call__ method when using the @nn.compact decorator. This worked in version 0.2.2 but fails in version 0.3.0. Why is this no longer possible? I think there are many cases where it would be helpful. Thanks!

import flax.linen as nn  # nn.compact and setup belong to the Linen API
from jax import random
import jax.numpy as jnp

class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        self.n = 1  # attribute assignment inside a compact method
        return x

class Net2(nn.Module):
    def setup(self):
        self.n = 1  # attribute assignment inside setup
    def __call__(self, x):
        return x

key = random.PRNGKey(0)
x = random.normal(key, shape=(1,1))
# This does not work in version 0.3.0
model = Net()
params = model.init(key, x)
y = model.apply(params, x)
# This still works
model = Net2()
params = model.init(key, x)
y = model.apply(params, x)

Error message:

Edit: I did not want to use the setup method for this, because I absolutely love using the compact method decorator. It makes the model look so tidy. I also know that I can just declare the dictionary as a normal variable inside the call method and then pass it to the resblock. However, for the larger resnets, where there are many blocks, it just looks cleaner when you don't have to do that. I guess a good compromise is to use the setup method to initialize the dictionary and still use the compact decorator to define the submodules in the call method (see Net2 for a simplified example).

# does not work
class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        self.param_dict = {}  # dict with params
        x = self.resblock(x)
        return x
    def resblock(self, x):
        # resblock code here
        w = lambda *_ : jnp.array(self.param_dict['dense1']['weight'])
        b = lambda *_ : jnp.array(self.param_dict['dense1']['bias'])
        x = nn.Conv(features=1, kernel_init=w, bias_init=b, name='dense1')(x)
        return x

# works
class Net2(nn.Module):
    def setup(self):
        self.param_dict = {}  # dict with params
    @nn.compact
    def __call__(self, x):
        x = self.resblock(x)
        return x
    def resblock(self, x):
        # resblock code here
        w = lambda *_ : jnp.array(self.param_dict['dense1']['weight'])
        b = lambda *_ : jnp.array(self.param_dict['dense1']['bias'])
        x = nn.Conv(features=1, kernel_init=w, bias_init=b, name='dense1')(x)
        return x
Replies: 1 comment 3 replies
TL;DR: it was never safe to do this, but now we actually raise an error.
The practical argument is that, like all JAX/Flax APIs, we want to avoid internal state, so we typically freeze things after construction.
The theoretical argument is that a Module instance must be frozen after setup; otherwise we cannot guarantee correct behaviour in all cases. Basically, we need to clone Module instances when using transformations (vmap, jit) or Module.apply. Instead of trying to do a magical deep clone, we reconstruct the Module from the construction arguments. This is clean and simple as long as there is no internal state, such as attribute assignment in methods.
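
To make the cloning argument concrete, here is a minimal sketch in plain Python. It is not Flax's actual implementation: UnfrozenToy, FrozenToy, and their clone method are hypothetical names made up for illustration, and dataclasses.replace stands in for "reconstruct from the construction arguments". It shows that an attribute assigned inside a method is silently lost when the instance is rebuilt from its fields, and how freezing after construction turns that silent loss into an immediate error.

```python
import dataclasses


@dataclasses.dataclass
class UnfrozenToy:
    """Hypothetical toy module that allows attribute assignment in methods."""
    features: int = 1  # construction argument (a dataclass field)

    def __call__(self, x):
        self.n = 1  # internal state created outside construction
        return x

    def clone(self):
        # Rebuild the instance from its dataclass fields only.
        return dataclasses.replace(self)


@dataclasses.dataclass
class FrozenToy:
    """Hypothetical toy module that rejects assignment after construction."""
    features: int = 1

    def __post_init__(self):
        # Construction is finished: mark the instance as frozen.
        object.__setattr__(self, "_frozen", True)

    def __setattr__(self, name, value):
        if getattr(self, "_frozen", False):
            raise AttributeError(f"cannot set {name!r}: instance is frozen")
        object.__setattr__(self, name, value)

    def __call__(self, x):
        self.n = 1  # raises, because the instance is already frozen
        return x


# Without freezing: the clone keeps the field but loses the method-assigned state.
original = UnfrozenToy(features=4)
original(0.0)
copy = original.clone()
clone_kept_field = copy.features == 4
clone_kept_state = hasattr(copy, "n")

# With freezing: the same assignment fails loudly instead of being dropped later.
try:
    FrozenToy(features=4)(0.0)
    frozen_raised = False
except AttributeError:
    frozen_raised = True
```

In this toy, anything that has to survive a clone must be a field, which roughly mirrors why values passed as construction arguments or created in setup are safe while assignments inside a compact __call__ are not.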