
How to export the nnx Module to SavedModel? #4341

Answered by cgarciae
5c4lar asked this question in Q&A

Here is an example that uses nnx.split and nnx.merge together with jax2tf to export a Flax NNX Module as a SavedModel:

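```python
from flax import nnx
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

model = nnx.Linear(3, 4, rngs=nnx.Rngs(0))

# Split the module into a static graph definition and its State,
# then turn the State into a plain nested dict of arrays.
graphdef, statedef = nnx.split(model)
pure_state = statedef.to_pure_dict()

# Wrap every array in a tf.Variable so the weights are stored as SavedModel variables.
tf_state = tf.nest.map_structure(tf.Variable, pure_state)

def forward_jax(pure_state, x):
  state = jax.tree.map(lambda x: x, statedef)  # copy, so the global statedef is not mutated
  state.replace_by_pure_dict(pure_state)       # fill the copy with the incoming values
  model = nnx.merge(graphdef, state)           # rebuild the module from graphdef + state
  return model(x)

# test forward
y = forward_jax(pure_state, jnp.ones((3,)))

def predict_tf(x):
  # Convert the pure JAX function to TF and call it with the tf.Variable weights.
  return jax2tf.convert(forward_jax)(tf_state, x)

# Track the variables on a tf.Module and export a traced tf.function.
module = tf.Module()
module._variables = tf.nest.flatten(tf_state)
module.f = tf.function(
    predict_tf,
    autograph=False,
    input_signature=[tf.TensorSpec([3], tf.float32)],
)
tf.saved_model.save(module, "/tmp/nnx_linear_saved_model")  # example output path
```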


Answer selected by 5c4lar