Replies: 1 comment
-
Hey @jlperla, the easiest way is to filter the model state down to nnx.Param and sum the sizes of the leaves:
params = nnx.state(model, nnx.Param)
total_params = sum((np.prod(x.shape) for x in jax.tree.leaves(params)), 0)
-
I want to get a count of all of the trainable parameters for a module. In PyTorch I can use model.parameters() and then check whether each parameter requires a gradient, etc. How would I do this with a recursive function and filters in nnx?