-
Context: What is the best way to inherit from `jax.Tensor`? I want to create a wrapper around `jax.Tensor` that adds a couple of methods to it. I mainly want two things:
With PyTorch and NumPy I can do this easily by subclassing the tensor types, but with JAX it does not work. My guess is that it is because the real class of a … So my question is: what is the best way to achieve these two points in JAX?
-
cc @agaraman0
-
Can you give an example use case? I'm working on a convenience class that is adjacent to this, so you might be able to get what you want, although not in exactly the same way.
-
… instead of

```python
import jax

class WrapperClass(jax.Array):
    pass
```

but it does not work as expected as …

cc: @samsja
-
No, there is no way to subclass JAX arrays in such a way that they will be compatible with JAX transformations. This may be an XY problem – if you can tell us more about your actual end-goal, we may be able to recommend a good approach that is possible within JAX's API & computational model.
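One commonly used alternative, sketched here under assumptions rather than as an official recommendation, is composition: keep the `jax.Array` as a field of your own class and register that class as a pytree with `jax.tree_util.register_pytree_node_class`, so `jax.jit` and `jax.grad` can look through it. `WrappedArray`, `normalize` and `double` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class WrappedArray:
    """Hypothetical wrapper: holds a jax.Array instead of subclassing it."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    # Pytree protocol: the wrapped array is the only child; no static aux data.
    def tree_flatten(self):
        return (self.value,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: JAX may rebuild the object from tracers or other
        # placeholder leaves, which should not be converted with jnp.asarray.
        obj = object.__new__(cls)
        obj.value = children[0]
        return obj

    # Example helper method (hypothetical).
    def normalize(self):
        return WrappedArray(self.value / jnp.linalg.norm(self.value))


@jax.jit
def double(x: WrappedArray) -> WrappedArray:
    # Inside jit, x.value is a tracer; the wrapper is rebuilt around the result.
    return WrappedArray(x.value * 2)


if __name__ == "__main__":
    w = WrappedArray([3.0, 4.0])
    print(double(w).value)
    # grad returns a WrappedArray holding the gradient of the wrapped value.
    print(jax.grad(lambda a: jnp.sum(a.value ** 2))(w).value)
```

Because the wrapper is a pytree rather than an array subclass, JAX transformations treat it as a container of arrays, which is why the helper methods survive `jit` and `grad`.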
-
FYI: we are the maintainers of docarray and we are trying to add JAX support; we currently support PyTorch, NumPy and TensorFlow. What we want is to add a couple of helper functions to our tensor class, which is why we would like to extend the `jax.Array` object. We understand why this does not work, but we were wondering whether there are solutions. Here is an example of what it looks like for a PyTorch tensor: https://github.com/docarray/docarray/blob/main/docarray/typing/tensor/torch_tensor.py
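For a docarray-style tensor type, a minimal sketch (assuming composition is acceptable instead of subclassing) is a wrapper that adds its own helper methods and forwards every other attribute to the wrapped `jax.Array`. `JaxArrayWrapper`, `unwrap` and `to_numpy` are hypothetical names, not docarray's or JAX's API.

```python
import jax
import jax.numpy as jnp
import numpy as np


class JaxArrayWrapper:
    """Hypothetical docarray-style tensor type that wraps (not subclasses) a jax.Array."""

    def __init__(self, value):
        self._value = jnp.asarray(value)

    def unwrap(self) -> jax.Array:
        """Return the plain jax.Array, e.g. before passing it to jitted code."""
        return self._value

    def to_numpy(self) -> np.ndarray:
        """Example helper method (hypothetical): copy the data to a NumPy array."""
        return np.asarray(self._value)

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails, so the helpers above
        # take precedence; anything else (shape, dtype, sum, ...) is forwarded
        # to the wrapped jax.Array.
        if name == "_value":
            # Avoid infinite recursion if _value is not set yet (e.g. copy/pickle).
            raise AttributeError(name)
        return getattr(self._value, name)

    def __repr__(self):
        return f"{type(self).__name__}({self._value!r})"


if __name__ == "__main__":
    x = JaxArrayWrapper([1.0, 2.0, 3.0])
    print(x.shape, x.sum(), x.to_numpy())
    print(jax.jit(lambda a: a * 2)(x.unwrap()))
```

Note that Python looks up operator methods such as `__add__` on the type, not through `__getattr__`, so `x + 1` would need explicit definitions, and `jnp` functions or transformations expect `x.unwrap()`. If the wrapper must also pass through `jax.jit` or `jax.grad` directly, it can additionally be registered as a pytree via `jax.tree_util`.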