
NNX: Create a Custom Primitive layer that works with jax.grad or flax.nnx.grad #4434

Answered by cgarciae
DiagRisker asked this question in Q&A

Hey @DiagRisker, sorry for the slow reply. I looked into this, and the issue is a bug in JAX: it doesn't call Variable.__jax_array__ on your parameters and raises an error instead. The __jax_array__ protocol is still somewhat experimental, but I've been discussing this with some of the JAX team, since NNX relies on it very broadly. To work around it, simply call .value on your params when passing them to JAX:

  def __call__(self, x: jax.Array) -> jax.Array:
    # Unwrap the nnx.Param objects with .value so JAX receives plain arrays.
    out = self.conv(x, self.kern.value)
    if self.use_bias:
      out += self.bias.value
    return out

Here's the full example using nnx.grad:

import jax.numpy as jnp
from flax import nnx
import jax
from jax import lax

from functools import partial
from typing import Any, Ca…

Answer selected by DiagRisker
Category
Q&A
2 participants