Is there a way to initialize the nnx layers dynamically? #4365
yCobanoglu asked this question in Q&A · Unanswered · Replies: 1 comment
-
Hey @yCobanoglu, great question! We hear this a lot; it's a downside of having explicit initialization. The nice thing is that you can infer the hard-to-compute constants (such as the number of flattened features feeding `linear1`) by using `nnx.eval_shape`, for example:

```python
from functools import partial

import jax.numpy as jnp
from flax import nnx


class CNN(nnx.Module):
  """A simple CNN model."""

  def __init__(self, x, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    # use `eval_shape` to compute the number of flat features without running the model
    flat_features = nnx.eval_shape(CNN._get_flat_features, self, x).shape[-1]
    self.linear1 = nnx.Linear(flat_features, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def _get_flat_features(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    return x

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x


# a sample input only fixes the shapes used to size `linear1`
sample_x = jnp.ones((1, 64, 64, 1))
model = CNN(sample_x, rngs=nnx.Rngs(0))
```

Here I'm duplicating some of the forward pass, but you could probably refactor the model to avoid that (for example, by having `__call__` reuse `_get_flat_features`).
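For intuition, here is a minimal pure-JAX sketch of the same shape-inference idea. It assumes `nnx.eval_shape` behaves like a Module-aware wrapper around `jax.eval_shape`, and it replaces `nnx.Conv`/`nnx.avg_pool` with hypothetical stand-ins built from `jax.lax` primitives (a 3×3 `"SAME"` convolution and a 2×2, stride-2 `"VALID"` average pool). The names `get_flat_features` and `flat_features_for` are illustrative only. Because every argument is a `jax.ShapeDtypeStruct`, only shapes and dtypes are propagated; no weights are created and no convolution is actually computed.

```python
import jax
import jax.numpy as jnp


def avg_pool_2x2(x):
  # 2x2 window, stride 2, "VALID" padding: sum each window, then divide by 4
  summed = jax.lax.reduce_window(
      x, 0.0, jax.lax.add, (1, 2, 2, 1), (1, 2, 2, 1), "VALID")
  return summed / 4.0


def get_flat_features(x, k1, k2):
  # Hypothetical stand-in for CNN._get_flat_features using raw jax.lax ops
  dn = ("NHWC", "HWIO", "NHWC")
  x = jax.lax.conv_general_dilated(x, k1, (1, 1), "SAME", dimension_numbers=dn)
  x = avg_pool_2x2(jax.nn.relu(x))
  x = jax.lax.conv_general_dilated(x, k2, (1, 1), "SAME", dimension_numbers=dn)
  x = avg_pool_2x2(jax.nn.relu(x))
  return x.reshape(x.shape[0], -1)  # flatten


def flat_features_for(height, width, channels=1):
  # Abstract inputs: shapes and dtypes only, no arrays are allocated
  x = jax.ShapeDtypeStruct((1, height, width, channels), jnp.float32)
  k1 = jax.ShapeDtypeStruct((3, 3, channels, 32), jnp.float32)
  k2 = jax.ShapeDtypeStruct((3, 3, 32, 64), jnp.float32)
  return jax.eval_shape(get_flat_features, x, k1, k2).shape[-1]


print(flat_features_for(28, 28))  # 3136 for a 28x28 MNIST-sized input
print(flat_features_for(64, 64))  # 16384 for the sample_x shape above
```

Changing `height`/`width` (or `channels`) is all it takes to size the first dense layer for a different dataset, which is exactly what passing a different `sample_x` does in the model above.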
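The number of flat features that the model infers with `eval_shape` can also be checked by hand. This is a minimal sketch, assuming both `nnx.Conv` layers keep the spatial size (their default `"SAME"` padding) and each `avg_pool` uses a 2×2 window with stride 2 and `"VALID"` padding; `pooled_size` and `flat_features` are hypothetical helper names. It shows why a hard-coded input size for `linear1` only fits a single image resolution.

```python
def pooled_size(n: int, window: int = 2, stride: int = 2) -> int:
  """Output length along one axis for a "VALID" pooling window."""
  return (n - window) // stride + 1


def flat_features(height: int, width: int, out_channels: int = 64) -> int:
  """Number of features after two (SAME conv -> 2x2 avg_pool) stages."""
  for _ in range(2):
    # the SAME-padded conv keeps height/width; only the pool shrinks them
    height, width = pooled_size(height), pooled_size(width)
  return height * width * out_channels


print(flat_features(28, 28))  # 7 * 7 * 64 = 3136 for 28x28 images
print(flat_features(32, 32))  # 8 * 8 * 64 = 4096 for 32x32 images
```

Note that `pooled_size` floors odd sizes (for example, 7 pools down to 3), which is one reason computing the constant by hand gets error-prone and letting `eval_shape` infer it is convenient.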
-
https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/mnist_tutorial.html
This model is from the tutorial, and the input size of the `linear1` layer is fixed, which makes it annoying to train this model on a different dataset. Is there a way to initialize it lazily somehow?
Setting `self.linear1 = None` and then initializing it on the first pass in `__call__` with an if-else causes an error.