I'm trying to load and use a model trained with Flax NNX from C++. jax2tf seems to provide a way to do this, but the examples all use Linen. How can we convert a Flax NNX module to TF?
Answered by cgarciae, Nov 1, 2024
Here's an example that uses `nnx.split` and `nnx.merge` with `jax2tf` to export a Flax NNX Module as a SavedModel:

```python
from flax import nnx
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

model = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
# Split into a static graph definition and a State holding the arrays.
graphdef, statedef = nnx.split(model)
# Plain nested dict of arrays, which jax2tf can take as a pytree argument.
pure_state = statedef.to_pure_dict()
tf_state = tf.nest.map_structure(tf.Variable, pure_state)

def forward_jax(pure_state, x):
    state = jax.tree.map(lambda leaf: leaf, statedef)  # copy, leaves statedef untouched
    state.replace_by_pure_dict(pure_state)
    model = nnx.merge(graphdef, state)
    return model(x)

# test forward
y = forward_jax(pure_state, jnp.ones((3,)))

def predict_tf(x):
    return jax2tf.convert(forward_jax)(tf_state, x)

tf_model = tf.Module()
# Tell the model saver what the variables are.
tf_model._variables = tf.nest.flatten(tf_state)
tf_model.f = tf.function(predict_tf, jit_compile=True, autograph=False)
tf.saved_model.save(tf_model, 'some/path')
```

After #4352 is merged, this can be slightly simplified to:

```python
from flax import nnx
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

model = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)
state = state.to_pure_dict()  # pure state
tf_state = tf.nest.map_structure(tf.Variable, state)

def forward_jax(state, x):
    # With #4352, nnx.merge accepts the pure dict directly.
    model = nnx.merge(graphdef, state)
    return model(x)

# test forward
y = forward_jax(state, jnp.ones((3,)))

def predict_tf(x):
    return jax2tf.convert(forward_jax)(tf_state, x)

tf_model = tf.Module()
# Tell the model saver what the variables are.
tf_model._variables = tf.nest.flatten(tf_state)
tf_model.f = tf.function(predict_tf, jit_compile=True, autograph=False)
tf.saved_model.save(tf_model, 'some/path')
```
Answer selected by 5c4lar