
nnx.Dropout in combination with jax.lax.scan returns an error #4152

Answered by cgarciae
maxencefaldor asked this question in Q&A

This makes sense: nnx.Dropout updates its RNG state, but you cannot update the state of NNX objects that are captured (closed over) by a JAX transform such as jax.lax.scan. To fix this, use nnx.scan and pass the model as a carry:

import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(10, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.linear1(x)))
    x = self.linear2(x)
    return x

model = MLP(rngs=nnx.Rngs(0))

@nnx.scan  # nnx.scan has similar api to vmap
def scan_fn(model, x):
  y = model(x)
  return model, y

model, y = scan_fn(model, jnp.ones((10, 10)))

You will g…
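
For intuition, here is a minimal pure-JAX sketch of the same idea: any state that changes from one step to the next (here, a PRNG key) has to travel through the scan carry instead of being captured from the enclosing scope. This is only an analogy, not how nnx.scan is implemented; the `dropout` and `step` helpers and the `RATE` constant are hypothetical names written for this illustration.

```python
import jax
import jax.numpy as jnp

RATE = 0.5  # assumed dropout rate, chosen to match the MLP above


def dropout(key, x, rate=RATE):
    # Hand-written inverted dropout (illustrative, not nnx.Dropout itself):
    # zero each element with probability `rate` and rescale the survivors.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


def step(key, x):
    # The key is the carry: split it, consume one half for this step,
    # and hand the other half to the next step.
    key, subkey = jax.random.split(key)
    return key, dropout(subkey, x)


xs = jnp.ones((10, 10))  # 10 scan steps, each over a vector of size 10
final_key, ys = jax.lax.scan(step, jax.random.PRNGKey(0), xs)
```

Because the key is threaded through the carry, every step draws its own dropout mask even though all inputs are identical. Passing the model as the carry of nnx.scan plays the analogous role for the RNG state held by nnx.Dropout.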

Answer selected by maxencefaldor

This discussion was converted from issue #4151 on August 28, 2024 14:04.