Replies: 2 comments
-
Thanks for the question! The easiest thing to do is to use tree utilities, basically mapping operations:

```python
# corresponding to #1 of your example, this is the most canonical
from jax.tree_util import tree_map
scaled_gradient = tree_map(lambda x: 0.1 * x, gradient)

# corresponding to #2
from jax.flatten_util import ravel_pytree
vector, unflatten = ravel_pytree(gradient)
scaled_gradient = unflatten(0.1 * vector)
```

There's a way to do the third option too, but the first option seems the most canonical. You might even want to just work with …

That said, keep in mind that Stax is just an example library for inspiration. It's not actively developed, though people sometimes fork it to write their own version for their own needs. If you want better libraries, take a look at these.
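To make that concrete, here is a small self-contained sketch; the nested-tuple `gradient` below is just a made-up stand-in for whatever `grad` returns for your Stax parameters:

```python
import jax.numpy as jnp
from jax.tree_util import tree_map
from jax.flatten_util import ravel_pytree

# A stand-in gradient: a nested tuple of arrays, shaped like Stax parameters.
gradient = ((jnp.ones((2, 3)), jnp.ones((3,))),
            (jnp.ones((3, 1)), jnp.ones((1,))))

# Option 1: map an elementwise operation over every leaf; the tuple structure is preserved.
scaled_gradient = tree_map(lambda x: 0.1 * x, gradient)

# Option 2: flatten to a single 1-D vector, do the arithmetic there, then restore the structure.
vector, unflatten = ravel_pytree(gradient)
scaled_gradient_2 = unflatten(0.1 * vector)
```

Both approaches give back an object with exactly the same nested structure as `gradient`, so the rest of the training code doesn't need to change.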
-
Great, thanks. How would you recommend computing the norm of the gradient? It seems to me that that problem doesn't quite fit the shape of your example code. Right now I'm using this:
Is that idiomatic? Any tips to improve speed?
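For what it's worth, one common pattern for a global norm over a gradient pytree looks roughly like the sketch below; this is only an illustrative guess, not necessarily the exact snippet referred to above:

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves

def global_norm(gradient):
    # Sum the squared entries of every leaf, then take one square root at the end.
    return jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in tree_leaves(gradient)))
```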
-
Let's say I make a neural network in Stax:
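(For concreteness, something roughly like the following; the layer sizes and the `jax.example_libraries.stax` import path are just illustrative assumptions:)

```python
import jax
from jax.example_libraries import stax

# A small example network; init_fun builds the parameters, apply_fun runs the forward pass.
init_fun, apply_fun = stax.serial(
    stax.Dense(64), stax.Relu,
    stax.Dense(10), stax.LogSoftmax,
)
rng = jax.random.PRNGKey(0)
output_shape, params = init_fun(rng, (-1, 784))
```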
and then later I'm training this thing, so I have some code like this:
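(Again, a hypothetical sketch of the kind of training code meant here, reusing `apply_fun` and `params` from the snippet above and some dummy data:)

```python
import jax.numpy as jnp

inputs = jnp.ones((32, 784))
targets = jax.nn.one_hot(jnp.zeros(32, dtype=jnp.int32), 10)

def loss(params, inputs, targets):
    # Mean negative log-likelihood, assuming apply_fun returns log-probabilities.
    log_probs = apply_fun(params, inputs)
    return -jnp.mean(jnp.sum(log_probs * targets, axis=1))

# gradient has exactly the same nested tuple-of-tuples structure as params.
gradient = jax.grad(loss)(params, inputs, targets)
```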
Now I want to get the norm of the gradient. I might try:
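(a hypothetical naive attempt, for illustration:)

```python
# Naive attempt: treat the gradient as if it were a single array.
gradient_norm = jnp.linalg.norm(gradient)
```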
Or maybe I want to do arithmetic with this thing, such as multiplying it by a scalar, maybe because I'm implementing my own optimization algorithm:
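(for example, a hypothetical line like:)

```python
# Another naive attempt: scalar multiplication directly on the nested tuple.
scaled_gradient = 0.1 * gradient
```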
Of course, neither of these snippets will work, because `gradient` isn't a numpy array or anything similar: it's a Python tuple of Python tuples of `DeviceArray`s. So to do either of these operations, I have to write a nested for loop, which possibly might even break if I change the network architecture.

Is there any way of doing arithmetic on the result of a `grad` function? I'm hoping for an answer in one of three forms:

1. `scaled_gradient = jnp.multiply(0.1, gradient)`, or something.
2. `vector = jax.vectorize(gradient)`, `new_gradient = jax.unvectorize_like(vector, gradient)`.
3. `grad` returns nice arithmetically compatible objects in the first place.