How to perform a strided dynamic update slice? #22246
-
To deal with sliding window functions, I've created some code that batches the sliding window into big chunks, so a (3,3) kernel would operate over 3*3 = 9 total views of the original matrix. This way I don't create a huge temporary array of all sliding windows at once (which causes OOM due to the size of the intermediate representation in JAX), and I don't get the overhead/blowup of doing a … The only part that is missing is combining the stack of …

```python
import jax
import jax.numpy as jnp

kh, kw = 3, 3  # kernel size; not defined in the original snippet, 3x3 assumed from the text above
C_out = 2
H, W = 11, 11
Q = jnp.zeros((C_out, H*3, W*3))
S = jnp.arange(9*C_out*H*W).reshape(9, C_out, H, W)
for z in range(9):
    i, j = jnp.mod(z, kh), jnp.mod(z//kw, kh)
    Q = Q.at[:, i::3, j::3].add(S[z, :, :, :])
```

However, this isn't going to run nicely without a `jit`. The problem is hitting the …

```python
def bod(ij, val):
    # ij in [0, kw*kh) -> offset tuples via modular arithmetic: (0, 0), (1, 0), ...
    i, j = jnp.mod(ij, kh), jnp.mod(ij//kw, kh)
    # ij is traced inside fori_loop, so i and j are not static slice bounds here
    val = val.at[:, i::3, j::3].add(S[ij, :, :, :])
    return val

Q = jax.lax.fori_loop(0, kw*kh, bod, Q)
```

But I get an error about static start/stop/step. The … Any alternative approaches to the interleaved reconstruction of the …

```
Array([[0, 1, 2, 0, 1, 2, 0, 1, 2],
       [3, 4, 5, 3, 4, 5, 3, 4, 5],
       [6, 7, 8, 6, 7, 8, 6, 7, 8],
       [0, 1, 2, 0, 1, 2, 0, 1, 2],
       [3, 4, 5, 3, 4, 5, 3, 4, 5],
       [6, 7, 8, 6, 7, 8, 6, 7, 8],
       [0, 1, 2, 0, 1, 2, 0, 1, 2],
       [3, 4, 5, 3, 4, 5, 3, 4, 5],
       [6, 7, 8, 6, 7, 8, 6, 7, 8]], dtype=int32)
```
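A minimal sketch of one workaround, assuming the stride equals the kernel size and the output height and width are exact multiples of it (here 33 = 11*3). Slice bounds must be static under tracing, but integer indices may be traced values. Reshaping the output from (C, H*3, W*3) to (C, H, 3, W, 3) turns each strided slice `Q[:, i::3, j::3]` into `Q5[:, :, i, :, j]`, so the `fori_loop` body no longer needs static offsets. `reconstruct_fori` is a hypothetical name, not something from the thread.

```python
import jax
import jax.numpy as jnp

kh, kw = 3, 3  # kernel size; assumed to equal the interleave stride, as in the thread


@jax.jit
def reconstruct_fori(S, Q):
    """Strided add with traced offsets, usable inside lax.fori_loop.

    S: (kh*kw, C, H, W) stack of views; Q: (C, H*kh, W*kw) output buffer.
    """
    n, C, H, W = S.shape
    # After this reshape, row h*kh + i of Q sits at [:, h, i, ...] (likewise
    # for columns), so Q[:, i::kh, j::kw] is exactly Q5[:, :, i, :, j].
    # Integer indices, unlike slice bounds, are allowed to be traced.
    Q5 = Q.reshape(C, H, kh, W, kw)

    def body(z, Q5):
        i, j = z % kh, z // kh  # same as i = z % 3, j = z // 3 for a 3x3 kernel
        return Q5.at[:, :, i, :, j].add(S[z])

    Q5 = jax.lax.fori_loop(0, n, body, Q5)
    return Q5.reshape(C, H * kh, W * kw)


C_out, H, W = 2, 11, 11
Q = jnp.zeros((C_out, H * kh, W * kw))
S = jnp.arange(kh * kw * C_out * H * W, dtype=jnp.float32).reshape(kh * kw, C_out, H, W)
Q = reconstruct_fori(S, Q)
```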
Replies: 1 comment 6 replies
-
I'm not sure I totally understand the question or issues that you're encountering, but something like your first example works just fine under `jit`:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def fun(S, Q):
    for z in range(9):
        i, j = np.mod(z, 3), np.mod(z//3, 3)  # <- Using 'np' not 'jnp' here: NumPy ints stay static under jit
        Q = Q.at[:, i::3, j::3].add(S[z, :, :, :])
    return Q

C_out = 2
H, W = 11, 11
Q = jnp.zeros((C_out, H*3, W*3))
S = jnp.arange(9*C_out*H*W).reshape(9, C_out, H, W)
Q = fun(S, Q)
```
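Since `Q` starts as zeros and the nine strided slices tile the output without overlapping (every output position is written by exactly one view), the unrolled loop above can also be written as a single reshape and transpose, with no scatter at all. This is a swapped-in technique rather than something proposed in the thread; `interleave_views` is a hypothetical helper, and it assumes the offset convention i = z % kh, j = z // kh, which matches `np.mod(z, 3), np.mod(z//3, 3)` for a 3x3 kernel.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("kh", "kw"))
def interleave_views(S, kh=3, kw=3):
    """Loop-free version of the strided-add loop (hypothetical helper).

    S: (kh*kw, C, H, W); view z lands at row offset i = z % kh and column
    offset j = z // kh, i.e. out[:, i::kh, j::kw] == S[z].
    """
    _, C, H, W = S.shape
    # Row-major split of the leading axis: z = j*kh + i  ->  axes (j, i).
    S5 = S.reshape(kw, kh, C, H, W)
    # Reorder to (C, H, i, W, j); merging (H, i) and (W, j) then places
    # S[z][c, h, w] at row h*kh + i, column w*kw + j.
    return S5.transpose(2, 3, 1, 4, 0).reshape(C, H * kh, W * kw)
```

Called as `interleave_views(S)`, it should give the same values as `fun(S, jnp.zeros(...))`; if `Q` is not zero-initialized, add the result to `Q` instead.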
Depending on your actual use case, you might also be able to use `scan` with the `unroll` parameter set >1, or (even better!) perhaps you can rewrite your problem as a convolution and use JAX's convolution implementation!
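A sketch of the `scan` route, assuming a 3x3 kernel and the offsets i = z % 3, j = z // 3 used in the thread. Note that `unroll` only replicates the loop body; the scanned index is still traced, so a strided slice like `Q[:, i::3, j::3]` would still hit the static start/stop/step error. The sketch therefore splits each spatial axis into (size, stride) so the offsets become ordinary integer indices. `reconstruct_scan` and `unroll=3` are illustrative choices, not values from the thread.

```python
import jax
import jax.numpy as jnp


@jax.jit
def reconstruct_scan(S, Q):
    """Strided add via lax.scan with a partially unrolled body (sketch)."""
    kh = kw = 3  # assumed kernel size / stride
    n, C, H, W = S.shape
    # View Q as (C, H, kh, W, kw) so that Q[:, i::kh, j::kw] == Q5[:, :, i, :, j];
    # integer indexing accepts the traced offsets that slicing rejects.
    Q5 = Q.reshape(C, H, kh, W, kw)

    def step(Q5, xs):
        z, view = xs  # z: scalar view index, view: S[z] with shape (C, H, W)
        i, j = z % kh, z // kh
        return Q5.at[:, :, i, :, j].add(view), None

    # Scan over (index, view) pairs; unroll=3 emits 3 copies of step per iteration.
    Q5, _ = jax.lax.scan(step, Q5, (jnp.arange(n), S), unroll=3)
    return Q5.reshape(C, H * kh, W * kw)
```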