double jax.vmap working as two nested loops, why single jax.vmap also working sometimes. #14057

Answered by soraros
jecampagne asked this question in Q&A

Ah, I see the problem.

  1. outer-like
# postulate
a.shape == (m,)
b.shape == (n,)

vmap(vmap(f, in_axes=(None, 0)), in_axes=(0, None))(a, b)

which gives

[ f(a0, b0), f(a0, b1), ..., f(a0, bn);
  f(a1, b0), f(a1, b1), ..., f(a1, bn);
  ...
  f(am, b0), f(am, b1), ..., f(am, bn);
]
  2. vectorize-like
# postulate
a.shape == b.shape == (m, n)

vmap(vmap(f, in_axes=(0, 0)), in_axes=(0, 0))(a, b)

which gives

[ f(a00, b00), f(a01, b01), ..., f(a0n, b0n);
  f(a10, b10), f(a11, b11), ..., f(a1n, b1n);
  ...
  f(am0, bm0), f(am1, bm1), ..., f(amn, bmn);
]
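
Both patterns can be checked against plain broadcasting. This is only a minimal sketch: the scalar function `f` and the array sizes are made up for illustration and are not taken from the original question.

```python
import jax.numpy as jnp
from jax import vmap


def f(x, y):
    # Hypothetical scalar function, chosen so each output encodes its inputs.
    return 10.0 * x + y


# --- outer-like: a.shape == (m,), b.shape == (n,) ---
a = jnp.arange(3.0)  # m = 3
b = jnp.arange(4.0)  # n = 4

# Inner vmap maps over b with a held fixed; outer vmap maps over a with b held fixed.
outer = vmap(vmap(f, in_axes=(None, 0)), in_axes=(0, None))(a, b)

# Equivalent result using explicit broadcasting.
outer_ref = f(a[:, None], b[None, :])

# --- vectorize-like: a2.shape == b2.shape == (m, n) ---
a2 = jnp.arange(6.0).reshape(2, 3)
b2 = jnp.ones((2, 3))

# Both vmaps map over the leading axis of both arguments, so f sees matching elements.
elementwise = vmap(vmap(f, in_axes=(0, 0)), in_axes=(0, 0))(a2, b2)

# Equivalent result: f applied elementwise to the full arrays.
elementwise_ref = f(a2, b2)

print(outer.shape, elementwise.shape)
```

In the outer-like case the output has shape `(m, n)` even though the inputs are 1-D, while in the vectorize-like case the output has the same shape as the inputs.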

Now, going back to your code:

X, Y = coords[..., 0], coords[..., 1]

# scenario 1, correct
f1 = vmap(vmap(lambda i, j: f(X[i, j], Y[i, j]),
               i…

Answer selected by jecampagne