How should the Tracers be handled while using pytrees to register a custom python container? #5803
trickarcher asked this question in Q&A
I'm having trouble understanding from your example what your question is. In your implementation,

```python
import jax
import numpy as np  # assumption: `np` is NumPy here; jax.numpy would work the same way

def func(x):
    return x[0] ** 2 + x[1] ** 2

tree = np.array([1., 2., 3., 4.])
basis = np.array([0., 1., 0., 0.])
print(jax.jvp(func, (tree,), (basis,)))
```

so it seems to be working as expected. Was that your understanding?
Hi,

We are trying to use `pytrees` to register Awkward Arrays (a tree-like Python container that fits the pytree use case perfectly) so that JAX can handle element-wise ops on them. For some context, an Awkward Array can be broken down into `(form, length, children)`, where `(form, length)` corresponds to the `aux_data` of the pytree and `children` is a list of linear buffers. We wrote our own versions of `special_unflatten` and `special_flatten` for one type of Awkward Array, `ak.layout.NumpyArray`, which is essentially the same as a `numpy.ndarray`.

While calling the `jvp` function, we receive `JVPTracer`s. From my limited knowledge of JAX, a `JVPTracer` has a `primal` and a `tangent`. But to form an Awkward Array we need a concrete linear buffer, which the `JVPTracer` does provide, except that it gives two of them. How should we handle this tracer? In a simple example, if we catch the `JVPTracer` object using `isinstance` and return it as-is, we do get the correct gradients. However, the one example in the documentation suggests that these tracers need to be put into the form of the Python container somehow. If so, how could we convert something like a tracer to a NumPy-array-like object without losing the `primal` and `tangent` information?

EDIT: Here's a basic code example, taking in a wrapper `NumpyArray` class and attempting a `jvp` pass on it.
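A minimal sketch of the "return the tracer as-is" approach described above, for illustration only. It assumes a hypothetical `NumpyArrayWrapper` class standing in for `ak.layout.NumpyArray` (the original code example is not reproduced here), and the `form` string is a placeholder for a real Awkward form object. In line with the pytree notes in the JAX documentation, the unflatten function stores its children without converting or validating them, because JAX may pass tracers (such as `JVPTracer`s) or other placeholder objects as leaves.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class NumpyArrayWrapper:
    """Hypothetical stand-in for ak.layout.NumpyArray: static metadata
    (form, length) plus a list of linear buffers (children)."""

    def __init__(self, form, length, children):
        self.form = form          # placeholder for an Awkward form; goes into aux_data
        self.length = length      # static length; goes into aux_data
        self.children = children  # buffers; may be tracers under jax.jvp / jax.jit

    def __getitem__(self, i):
        # Index into the single buffer of this NumpyArray-like node.
        return self.children[0][i]


def special_flatten(arr):
    # Return (children, aux_data); aux_data must be hashable and comparable.
    return arr.children, (arr.form, arr.length)


def special_unflatten(aux_data, children):
    form, length = aux_data
    # No conversion to concrete arrays here: under jax.jvp each child is a
    # JVPTracer, and keeping it unchanged preserves both primal and tangent.
    return NumpyArrayWrapper(form, length, list(children))


tree_util.register_pytree_node(NumpyArrayWrapper, special_flatten, special_unflatten)


def func(x):
    return x[0] ** 2 + x[1] ** 2


# Primal and tangent must have the same tree structure (equal aux_data).
primal = NumpyArrayWrapper("placeholder-form", 4, [jnp.array([1.0, 2.0, 3.0, 4.0])])
tangent = NumpyArrayWrapper("placeholder-form", 4, [jnp.array([0.0, 1.0, 0.0, 0.0])])
out, out_tangent = jax.jvp(func, (primal,), (tangent,))
print(out, out_tangent)
```

The design choice is that the container never needs a concrete buffer during tracing: the tracer simply occupies the buffer slot, and JAX's own operations on it carry the primal and tangent together.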