I'd like to test out the following script, which uses `batching.pile_axis`:

```python
import jax
import jax.numpy as jnp
from jax.interpreters import batching

# Experimental config flags used for this test.
jax.config.update("jax_array", False)
jax.config.update("jax_dynamic_shapes", True)
jax.config.update("jax_numpy_rank_promotion", "allow")

# Ragged reduction: arange(n).sum() for each n.
xs = jax.vmap(lambda n: jnp.arange(n).sum())(jnp.array([3, 1, 4]))

# Build two piles whose elements have shapes (7, n) and (n, 7).
sizes = jnp.array([3, 1, 4])
p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.pile_axis)(sizes)
p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.pile_axis)(sizes)

# Map jnp.dot over corresponding pile elements; each (7, 7) result is stacked along axis 0.
y = jax.vmap(jnp.dot, in_axes=batching.pile_axis, out_axes=0,
             axis_size=3)(p1, p2)

def pile_map(f):
  # Apply f to each element of the input piles, producing a pile.
  def mapped(*piles):
    return jax.vmap(f, in_axes=batching.pile_axis, out_axes=batching.pile_axis,
                    axis_size=piles[0].aval.length)(*piles)
  return mapped

p = jax.vmap(jnp.arange, out_axes=batching.pile_axis)(jnp.array([3, 1, 4]))
y = pile_map(jnp.dot)(p, p)
```

When I run it, I keep getting an error. I'm running this on JAX 0.4.6 on a CPU. What's the right way to test this feature out?
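Since the feature is a prototype, it may help to have a plain per-element loop to compare its outputs against. The sketch below is written in plain NumPy so it touches none of the experimental flags. It assumes the prototype is meant to behave like applying each function separately to every `n` in `[3, 1, 4]`; the `*_ref` names are hypothetical and not part of JAX.

```python
import numpy as np

sizes = [3, 1, 4]

# Reference for xs: arange(n).sum() for each n.
xs_ref = np.array([np.arange(n).sum() for n in sizes])

# Reference for the p1/p2 example: each (7, n) @ (n, 7) product is a (7, 7)
# matrix filled with n, and the per-element results are stacked along axis 0.
y_ref = np.stack([np.ones((7, n)) @ np.ones((n, 7)) for n in sizes])

# Reference for pile_map(jnp.dot)(p, p), where p holds arange(3), arange(1), arange(4):
# each element is the dot product of arange(n) with itself.
dots_ref = [np.dot(np.arange(n), np.arange(n)) for n in sizes]

print(xs_ref, y_ref.shape, dots_ref)
```

For these sizes the loop gives `xs_ref` equal to `[3, 0, 6]`, a `(3, 7, 7)` stack whose slices are filled with 3, 1 and 4, and per-element dot products 5, 0 and 14.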
Answered by mattjj, Mar 10, 2023
Answer selected by KeAWang
Haha, thanks for finding our hidden features! It's really not ready yet; it's just a proof-of-concept prototype.
That said, if you run your script with

```shell
JAX_JIT_PJIT_API_MERGE=0 python filename.py
```

it should work.
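If changing the command line is inconvenient, one option is to set the same environment variable from inside the script. This is a minimal sketch, assuming (as with JAX's boolean config flags around 0.4.6) that the flag's default is read from `JAX_JIT_PJIT_API_MERGE` when `jax` is first imported, so the variable must be set before that import. The helper name `disable_jit_pjit_merge` is hypothetical.

```python
import os
import sys


def disable_jit_pjit_merge():
    """Hypothetical helper: set JAX_JIT_PJIT_API_MERGE=0 for this process.

    Assumes JAX reads the flag's default from the environment at import time,
    so this only has an effect if it runs before the first `import jax`.
    """
    if "jax" in sys.modules:
        raise RuntimeError("set JAX_JIT_PJIT_API_MERGE before importing jax")
    os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"


disable_jit_pjit_merge()
# Only now `import jax` and run the rest of the pile_map script as before.
```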