nnx.Dropout in combination with jax.lax.scan returns an error #4152
maxencefaldor asked this question in Q&A (answered by cgarciae)
System information
Problem you have encountered:
Steps to reproduce:
However,
returns an error:
Answered by cgarciae on Aug 28, 2024
This makes sense: nnx.Dropout updates its RNG state, but you cannot update the state of NNX objects passed as a capture to JAX transforms. To fix this, use nnx.scan and pass model as a carry:

    import jax.numpy as jnp
    from flax import nnx

    class MLP(nnx.Module):
      def __init__(self, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(10, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, 10, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)

      def __call__(self, x):
        x = nnx.relu(self.dropout(self.linear1(x)))
        x = self.linear2(x)
        return x

    model = MLP(rngs=nnx.Rngs(0))

    @nnx.scan  # nnx.scan has similar api to vmap
    def scan_fn(model, x):
      y = model(x)
      # Return the model as the carry so its updated RNG state is propagated.
      return model, y

    model, y = scan_fn(model, jnp.ones((10, 10)))

You will get new RNG state each step.

Answer selected by maxencefaldor
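For comparison, the same "state goes in the carry" idea can be sketched in plain JAX without NNX. This is only an illustration, not code from the thread: the `dropout` and `step` helpers and the input shape are made up for the example. The PRNG key is threaded through `jax.lax.scan` as the carry and split once per step, so every step draws a different dropout mask.

```python
import jax
import jax.numpy as jnp


def dropout(key, x, rate=0.5):
    # Hypothetical helper: inverted dropout. Each element is zeroed with
    # probability `rate`; the survivors are rescaled by 1 / (1 - rate).
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


def step(key, x):
    # The key is the carry: split it, use one half now, pass the other on.
    key, subkey = jax.random.split(key)
    return key, dropout(subkey, x)


xs = jnp.ones((10, 256))  # 10 scan steps, 256 features per step
final_key, ys = jax.lax.scan(step, jax.random.PRNGKey(0), xs)
```

If `step` closed over a single key instead of carrying it, every step would reuse the same key and produce the same mask. Carrying the key is the functional counterpart of passing the model as the carry so that its RNG state can advance.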