
NumPy loops to JAX: conversion help #14229

Answered by jakevdp
alonfnt asked this question in Q&A

This is not an easy function to express in terms of JAX transformations or higher-order primitives. You can't use vmap, because your function is not purely batch-wise: the result at z[t1] depends on every x[t2] with t2 > t1.
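To make the distinction concrete, here is a toy sketch. The functions `per_row` and `suffix_sum` are hypothetical and are not the code from the question. A row-wise function maps cleanly over the time axis with `jax.vmap`, while a function where `out[t1]` depends on every `x[t2]` with `t2 > t1` needs the whole time axis at once, so it cannot be written as a per-row function mapped over time.

```python
import jax
import jax.numpy as jnp

# Purely batch-wise (hypothetical): row t of the output depends only on row t
# of the input, so jax.vmap can map it over the leading (time) axis.
def per_row(x_t):
    return jnp.sum(x_t ** 2)

# Hypothetical function with the dependency structure described above:
# out[t1] depends on every x[t2] with t2 > t1 (a strict suffix sum over time).
def suffix_sum(x):
    rev = jnp.cumsum(x[::-1], axis=0)[::-1]  # rev[t] = sum of x[t2] for t2 >= t
    return rev - x                           # drop x[t] itself: t2 > t only

x = jnp.arange(12.0).reshape(4, 3)
row_out = jax.vmap(per_row)(x)  # each row is processed independently
suffix_out = suffix_sum(x)      # every output row reads later rows of x
```

Here `suffix_out[0]` combines rows 1 to 3 of `x`, which a function that only ever sees `x[0]` has no way to do.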

Further, you won't be able to express this directly with lax.fori_loop or similar primitives, because each iteration constructs dynamically shaped intermediate arrays: if you iterate over t1 and t2 using fori_loop, the loop indices are traced values, so the size of z[t1:t2] is not known at trace time.
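One common workaround, which is not part of the answer above, is to keep every intermediate at a static shape and mask out the entries you don't want. The sketch below assumes a hypothetical quantity `out[t1, t2] = sum(z[t1:t2])` (the names `masked_slice_sum` and `pairwise_window_sums` are illustrative) and computes it inside nested `lax.fori_loop`s without ever materializing the dynamically sized slice `z[t1:t2]`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def masked_slice_sum(z, t1, t2):
    # z[t1:t2] would have a traced length, which jit/fori_loop reject.
    # Keep z at its full static shape and zero out entries outside [t1, t2).
    idx = jnp.arange(z.shape[0])
    mask = (idx >= t1) & (idx < t2)
    return jnp.sum(jnp.where(mask, z, 0.0))

@jax.jit
def pairwise_window_sums(z):
    t = z.shape[0]  # static under jit
    out = jnp.zeros((t, t), dtype=z.dtype)

    def outer(t1, out):
        def inner(t2, out):
            # Functional update instead of out[t1, t2] = ...
            return out.at[t1, t2].set(masked_slice_sum(z, t1, t2))
        # Traced lower bound is allowed; fori_loop lowers to a while_loop here.
        return lax.fori_loop(t1 + 1, t, inner, out)

    return lax.fori_loop(0, t, outer, out)

res = pairwise_window_sums(jnp.array([1.0, 2.0, 3.0, 4.0]))
```

The trade-off is extra arithmetic: each masked sum touches all `t` entries regardless of the window length, in exchange for shapes that never change between iterations.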

That said, you can run it in JAX nearly as written if you modify it to use JAX's functional array updates (z = z.at[idx].set(value) in place of z[idx] = value):

import jax

@jax.jit
def f(x, r, z):
  t, n = x.shape
  assert r.shape == (t, t)
  assert z.shape ==
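A nearly as-written port can work because, under `jax.jit`, array shapes are static, so plain Python `for` loops over `range(t)` are unrolled at trace time and the loop indices are ordinary Python integers. Slices such as `x[t1:t2 + 1]` then have static shapes. A minimal sketch of the pattern, using a hypothetical window-max rule (not the function from the question; `numpy_version` and `jax_version` are illustrative names): the NumPy version assigns in place, and the JAX version replaces each assignment with `.at[...].set(...)`.

```python
import numpy as np
import jax
import jax.numpy as jnp

def numpy_version(x):
    # NumPy style: in-place assignment inside Python loops.
    t = x.shape[0]
    out = np.zeros((t, t))
    for t1 in range(t):
        for t2 in range(t1 + 1, t):
            out[t1, t2] = x[t1:t2 + 1].max()  # hypothetical update rule
    return out

@jax.jit
def jax_version(x):
    t = x.shape[0]  # a plain Python int during tracing
    out = jnp.zeros((t, t))
    for t1 in range(t):              # unrolled at trace time, so t1 and t2
        for t2 in range(t1 + 1, t):  # are Python ints and slices are static
            # Functional update replaces out[t1, t2] = ...
            out = out.at[t1, t2].set(jnp.max(x[t1:t2 + 1]))
    return out

x_np = np.array([3.0, 1.0, 4.0, 1.0, 5.0])
same = np.allclose(np.asarray(jax_version(jnp.asarray(x_np))), numpy_version(x_np))
```

Because the loops are unrolled, the traced program contains on the order of t²/2 update operations, so compile time grows quickly with `t`. This is fine for small `t`; for large `t`, a `lax.fori_loop`-based formulation keeps the traced program size constant.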

Replies: 2 comments 1 reply

Answer selected by alonfnt
Category
Q&A
3 participants