Non-hashable static arguments not allowed for pmap but allowed for vmap? #14159
somearthling asked this question in Q&A (Unanswered)
This sounds like a bug: a traced value should never be valid as a static argument, so you should get an error regardless of whether you use vmap or pmap. For example, both calls below raise the same error:
import os
os.environ['XLA_FLAGS'] = " --xla_force_host_platform_device_count=8"
import jax
import jax.numpy as jnp
def f(x, y):
    return x + y
@jax.jit
def g_vmap(x, y):
    jax.vmap(jax.jit(f, static_argnums=1))(x, y)
@jax.jit
def g_pmap(x, y):
    jax.pmap(jax.jit(f, static_argnums=1))(x, y)
x = jnp.ones((8, 2))
y = jnp.ones((8, 2))
g_vmap(x, y)
# ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses.
g_pmap(x, y)
# ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses.
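One way to avoid this, sketched here under the assumption that the static value is already known when the outer function is called, is to declare the argument static on the outer jit as well. It then reaches the inner jit as a concrete, hashable Python value rather than a tracer. The names `repeat_rows`, `inner`, and `outer` are hypothetical and not from the thread.

```python
from functools import partial

import jax
import jax.numpy as jnp


def repeat_rows(x, n):
    # `n` determines the output shape, so it must be a concrete Python int.
    return jnp.tile(x, (n, 1))


# Inner jit with a static second argument, as in the thread.
inner = jax.jit(repeat_rows, static_argnums=1)


@partial(jax.jit, static_argnums=1)
def outer(x, n):
    # `n` is static here too, so it arrives at `inner` as a plain int,
    # not as a tracer. Only `x` is traced.
    return inner(x, n)


x = jnp.ones((2, 3))
out = outer(x, 4)
```

The trade-off is that each distinct value of `n` triggers a fresh compilation of `outer`, which is the usual cost of a static argument.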
I have this piece of code
qnode = jit(qml.QNode(circuit, dev, interface='jax'), static_argnums=2)
probs = pmap(qnode, in_axes=(0, None, None))(x, p, filt)
inside a function that is already jitted, so x and p are JAX tracers at this point. With vmap in place of pmap this runs completely fine, but with pmap I get the following error:
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function circuit is non-hashable.
Is this intended, or is there something I'm doing wrong that I can fix?
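A possible restructuring of this call pattern is sketched below, assuming that `filt` can be supplied as a hashable Python value (here a tuple of ints). The static value is bound with `functools.partial` before the inner jit, so only array arguments are mapped. `circuit` is a hypothetical stand-in rather than a PennyLane QNode, and `run` is an illustrative name. The sketch uses vmap so that it runs on a single device; whether the same binding resolves the pmap case above is not confirmed here.

```python
from functools import partial

import jax
import jax.numpy as jnp


def circuit(x, p, filt):
    # Hypothetical stand-in for the QNode: `filt` is a tuple of Python
    # ints used for static indexing into `x`.
    return jnp.stack([x[i] * p for i in filt])


@partial(jax.jit, static_argnums=2)
def run(x, p, filt):
    # `filt` is static in the outer jit, so it is a concrete tuple here.
    # Bind it first, then jit and map over the array arguments only.
    bound = jax.jit(partial(circuit, filt=filt))
    return jax.vmap(bound, in_axes=(0, None))(x, p)


x = jnp.arange(12.0).reshape(4, 3)
p = jnp.asarray(2.0)
probs = run(x, p, (0, 2))
```

Binding by closure means the inner jit no longer needs `static_argnums` at all, so no value has to pass the hashability check at the inner call.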