NNX: Create a custom primitive layer that works with jax.grad or flax.nnx.grad #4434
Hi! As a PhD student, I've been investigating recently what I can do with flax.nnx modules:

```python
import jax.numpy as jnp
from flax import nnx
import jax
from jax import lax
from functools import partial
from typing import Any, Callable, Sequence, Union

seed = 0  # seed for the nnx.Rngs streams below (assumed value)


def Custuniform(**args):
    # print('args : ', args)
    if args['dtype'] == (jnp.complex64):
        args['dtype'] = jnp.float32  # print('rewriting args... ')
        return jax.random.uniform(**args) + 1j * jax.random.uniform(**args)
    # if dtype = "quaternion"
    return jax.random.uniform(**args)


class ConvND(nnx.Module):  # (kernel_size, features, c_in)
    def __init__(self,
                 #
                 kernel_size: tuple,
                 features: int = 1,
                 c_in: int = 1,  # channels in
                 stride: tuple[int, ...] | int = (1,),
                 # dilation options
                 input_dilation: Union[None, int, Sequence[int]] = None,
                 kernel_dilation: Union[None, int, Sequence[int]] = None,
                 #
                 precision=jax.lax.Precision('high'),
                 use_bias: bool = False,
                 rngs=nnx.Rngs(seed),
                 dtype=jnp.float32  # dtype for weights initialization (not vector valued)
                 ):
        self.kern = nnx.Param(Custuniform(key=rngs.params(),
                                          shape=tuple(kernel_size) + (c_in, features),
                                          dtype=dtype))
        # kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3))
        self.use_bias = use_bias
        if use_bias:
            self.bias = nnx.Param(Custuniform(key=rngs.params(), shape=(features,), dtype=dtype))
        # stride
        if not hasattr(stride, "__iter__"):
            stride = (stride,) * len(kernel_size)
        elif len(stride) - len(kernel_size):
            stride = (stride[0],) * len(kernel_size)
        # enforcing the (Batch_size, *spatial_dim, channel_size) I/O convention
        incf = (0, len(kernel_size) + 1) + tuple(range(1, len(kernel_size) + 1))
        kercf = (len(kernel_size) + 1, len(kernel_size)) + tuple(range(0, len(kernel_size)))
        dimnum = jax.lax.ConvDimensionNumbers(incf, kercf, incf)
        # compiling the convolution with the given parameters
        self.conv = jax.jit(partial(jax.lax.conv_general_dilated,
                                    window_strides=stride,
                                    dimension_numbers=dimnum,
                                    padding="VALID",
                                    lhs_dilation=input_dilation,
                                    rhs_dilation=kernel_dilation,
                                    precision=precision))
        # the function will be compiled per instance of the class -> not perfect
        # (depending on the jax/flax.nnx jit caching system)

    def __call__(self, x: jax.Array) -> jax.Array:
        out = self.conv(x, self.kern)
        if self.use_bias:
            out += self.bias
        return out


c_in = 5
c_out = 2
input = Custuniform(key=nnx.Rngs(seed).params(), shape=(60, 20, 20, 20, c_in), dtype=jnp.float32)
Kerns = (3, 3, 3)
Layer = (
    # nnx.Conv(in_features=c_in, out_features=c_out, kernel_size=Kerns, padding='VALID', rngs=rngs)
    ConvND(Kerns, c_out, c_in=c_in)
)


def loss(graphdef, state):
    model = nnx.merge(graphdef, state)
    return ((model(input)) ** 2).mean()


grads = jax.grad(loss, 1)(*nnx.split(Layer))
```

I used the flax nnx documentation (and forums), notably: https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html

NB: I have a similar implementation in plain JAX that works for autodifferentiation, but with it I can't use the nnx style of model construction, i.e. building a model in cascade by assigning sub-layers/blocks in a dict (see the sketch after this snippet):

```python
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial
from typing import Any, Sequence

seed = 0  # assumed seed value


class Rng():
    def __init__(self, seed=seed):
        self.key = jax.random.key(seed)
        # print(self.key)

    def __call__(self):
        self.key = jax.random.split(self.key)[1]
        return self.key


class ConvND():
    def __init__(self,
                 kernel_size: Sequence[int] = (3,),
                 outC: int = 2,
                 inC: int = 1,
                 strides: Sequence[int] = (1,),
                 bias: bool = False,
                 vecvalD: int = 5,
                 rng=Rng(),
                 precision=jax.lax.Precision('high')
                 ):
        """Shape convention:
           input : (B_dim, input_channels, vec_valued_dim, *spatial_dim)  (vector valued, compatible with the kernel)
                   or (B_dim, input_channels, *spatial_dim)
           kernel: (output_channels, input_channels, *spatial_dim)
        """
        self.LG = dict(kernel=jax.random.uniform(
            key=rng(),
            shape=(outC, inC) + tuple(kernel_size) + ((vecvalD,) if vecvalD > 1 else ())))
        if bias:
            self.LG.update({'bias': jax.random.uniform(
                key=rng(),
                shape=(outC,) + ((vecvalD,) if vecvalD > 1 else ()))})
        padding = [(0, 0)] * len(kernel_size)  # equivalent of VALID padding
        if not hasattr(strides, "__iter__"):
            strides = (strides,) * len(kernel_size)
        elif len(strides) - len(kernel_size):
            strides = (strides[0],) * len(kernel_size)
        if vecvalD > 1:
            print(" vector valued parameters ")
            self.conv = jax.jit(jax.vmap(partial(lax.conv_general_dilated,
                                                 window_strides=strides,
                                                 padding=padding,
                                                 precision=precision),
                                         in_axes=(-1, -1), out_axes=-1))
        else:
            self.conv = jax.jit(partial(lax.conv_general_dilated,
                                        window_strides=strides,
                                        padding=padding,
                                        precision=precision))

    def __call__(self, x: jax.Array, params: Any = None):
        if params is None:
            params = self.LG
        out = self.conv(x, params['kernel'])
        if 'bias' in params:
            out += params['bias']
        return out


def Gjacob(f, *x):
    """General jacobian wrapper:
       returns the "non-directional" jacobian w.r.t. the respective parameters."""
    y, vjp_fn = jax.vjp(f, *x)
    return vjp_fn(jnp.ones_like(y))  # cotangent set to ones


Layer = ConvND()
LGflat, treedef = jax.tree_util.tree_flatten(Layer.LG)
vecvalD = 5
rng = Rng()  # key generator for the test input
Lil = jax.random.uniform(key=rng(), shape=(8, 1, 3) + ((vecvalD,) if vecvalD else ()))
dW = Gjacob(partial(Layer, Lil), Layer.LG)
print(" update dW: ", dW[0]['kernel'].shape, ' vs Kernel origin shape : ', Layer.LG['kernel'].shape)
```

Apologies for the long question! Thanks in advance
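For concreteness, the "cascade" construction mentioned above could look like the following minimal sketch. It uses the built-in `nnx.Conv`/`nnx.Linear` layers; the `Block` class, its layer names, and the shapes are illustrative only, and it assumes nnx traverses sub-modules stored in plain `dict` attributes:

```python
class Block(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        # sub-layers assigned "in cascade" inside a dict attribute
        self.layers = {
            'conv': nnx.Conv(in_features=5, out_features=2, kernel_size=(3, 3, 3),
                             padding='VALID', rngs=rngs),
            'proj': nnx.Linear(2, 4, rngs=rngs),
        }

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.layers['conv'](x)     # (B, *spatial, 2)
        return self.layers['proj'](x)  # linear over the last (channel) axis


block = Block(nnx.Rngs(0))
y = block(jnp.ones((4, 8, 8, 8, 5)))  # -> (4, 6, 6, 6, 4)
```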
---
Hey @DiagRisker, for NNX Modules you have to use …
---
Hey @DiagRisker, sorry for the slow reply. I looked into this and the issue is a bug in JAX where it doesn't call `Variable.__jax_array__` and errors instead. This `__jax_array__` protocol is a bit experimental, but I've been chatting with some of the team about it since NNX uses it very broadly. To fix it, simply call `.value` on your params when passing them to JAX:

```python
def __call__(self, x: jax.Array) -> jax.Array:
  out = self.conv(x, self.kern.value)
  if self.use_bias:
    out += self.bias.value
  return out
```

Here's the full example using `nnx.grad`:

```python
import jax.numpy as jnp
from flax import nnx
import jax
from jax import lax
from functools import partial
from typing import Any, Callable, Sequence, Union


def Custuniform(**args):
  # print('args : ', args)
  if args['dtype'] == (jnp.complex64):
    args['dtype'] = jnp.float32  # print('rewriting args... ')
    return jax.random.uniform(**args) + 1j * jax.random.uniform(**args)
  # if dtype = "quaternion"
  return jax.random.uniform(**args)


class ConvND(nnx.Module):  # (kernel_size, features, c_in)
  def __init__(
      self,
      #
      kernel_size: tuple,
      features: int = 1,
      c_in: int = 1,  # channels in
      *,
      stride: tuple[int, ...] | int = (1,),
      # dilation options
      input_dilation: Union[None, int, Sequence[int]] = None,
      kernel_dilation: Union[None, int, Sequence[int]] = None,
      #
      precision=jax.lax.Precision('high'),
      use_bias: bool = False,
      rngs: nnx.Rngs,
      dtype=jnp.float32,  # dtype for weights initialization (not vector valued)
  ):
    self.kern = nnx.Param(
        Custuniform(
            key=rngs.params(),
            shape=tuple(kernel_size) + (c_in, features),
            dtype=dtype,
        )
    )
    # kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3))
    self.use_bias = use_bias
    if use_bias:
      self.bias = nnx.Param(
          Custuniform(key=rngs.params(), shape=(features,), dtype=dtype)
      )
    # stride
    if not hasattr(stride, '__iter__'):
      stride = (stride,) * len(kernel_size)
    elif len(stride) - len(kernel_size):
      stride = (stride[0],) * len(kernel_size)
    # enforcing the (Batch_size, *spatial_dim, channel_size) I/O convention
    incf = (0, len(kernel_size) + 1) + tuple(range(1, len(kernel_size) + 1))
    kercf = (len(kernel_size) + 1, len(kernel_size)) + tuple(
        range(0, len(kernel_size))
    )
    dimnum = jax.lax.ConvDimensionNumbers(incf, kercf, incf)
    # compiling the convolution with the given parameters
    self.conv = jax.jit(
        partial(
            jax.lax.conv_general_dilated,
            window_strides=stride,
            dimension_numbers=dimnum,
            padding='VALID',
            lhs_dilation=input_dilation,
            rhs_dilation=kernel_dilation,
            precision=precision,
        )
    )
    # the function will be compiled per instance of the class -> not perfect
    # (depending on the jax/flax.nnx jit caching system)

  def __call__(self, x: jax.Array) -> jax.Array:
    out = self.conv(x, self.kern.value)
    if self.use_bias:
      out += self.bias.value
    return out


c_in = 5
c_out = 2
rngs = nnx.Rngs(0)
input = Custuniform(key=rngs(), shape=(60, 20, 20, 20, c_in), dtype=jnp.float32)
Kerns = (3, 3, 3)
# nnx.Conv(in_features=c_in, out_features=c_out, kernel_size=Kerns, padding='VALID', rngs=rngs)
model = ConvND(Kerns, c_out, c_in=c_in, rngs=nnx.Rngs(0))
y = model(input)


def loss(model):
  return ((model(input)) ** 2).mean()


grads = nnx.grad(loss)(model)
print(jax.tree.map(jnp.shape, grads))
```
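For completeness, with the `.value` fix in place, the `jax.grad` + `nnx.split`/`nnx.merge` formulation from the original question should work as well. A minimal sketch, reusing `model` and `input` from the example above:

```python
# functional-style differentiation: split the module, differentiate w.r.t. its state,
# and rebuild the module inside the loss with nnx.merge
graphdef, state = nnx.split(model)


def loss_fn(state):
  m = nnx.merge(graphdef, state)
  return ((m(input)) ** 2).mean()


state_grads = jax.grad(loss_fn)(state)
print(jax.tree.map(jnp.shape, state_grads))
```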