Hey, I was just thinking about how to tackle a problem and came up with this simple toy example involving `pmap`:
import jax
import jax.numpy as jnp

@jax.jit
def f(w, x):
    @jax.pmap
    def g(x):
        return w * x
    return g(x)

x = jnp.ones([8, 3])
w = jnp.array([2])
print(f(w, x))  # array with 2's
w = jnp.array([4])
print(f(w, x))  # array with 4's
My initial intuition was that this wouldn't work: since `w` is captured by `g`, it would be treated as a constant in the eyes of `pmap`. However, the `jit` around `f` somehow resolves this issue. This pattern "works" but seems a bit wrong somehow, and I have a lot of doubts about how `w` is actually being treated:
- `w` replicated across devices?
- `pmap` + `jit`?
- `pmap`?
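For comparison, here is a minimal sketch that puts the closure pattern next to a version where `w` is passed to `pmap` explicitly with `in_axes=(None, 0)`, so that `w` is broadcast and only `x` is split along its leading axis. The names `f_closure`, `f_explicit`, and `n_dev` are made up for this illustration, and the leading axis is sized from `jax.local_device_count()` (an assumption, so that it also runs on a single device instead of requiring 8). It only checks that both versions return the same values; it does not by itself prove how `w` is placed on devices.

```python
import jax
import jax.numpy as jnp

# Size the mapped axis by the available devices so this also runs on one device.
n_dev = jax.local_device_count()

@jax.jit
def f_closure(w, x):
    # Pattern from the post: g closes over w, which is a traced value under jit.
    @jax.pmap
    def g(x):
        return w * x
    return g(x)

# Hypothetical alternative: pass w explicitly and mark it as broadcast (None),
# while x is split along its leading axis (0).
f_explicit = jax.pmap(lambda w, x: w * x, in_axes=(None, 0))

x = jnp.ones([n_dev, 3])
w = jnp.array([2])

out_closure = f_closure(w, x)
out_explicit = f_explicit(w, x)
print(jnp.allclose(out_closure, out_explicit))

# Printing the jaxpr is one way to inspect how w reaches the pmapped computation.
print(jax.make_jaxpr(f_closure)(w, x))
```

If the two outputs match for different values of `w`, that at least suggests the closed-over `w` is not being baked in as a stale constant. The printed jaxpr may help with the remaining questions, though its exact form depends on the JAX version.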