diff --git a/pychastic/vectorized_I_generation.py b/pychastic/vectorized_I_generation.py index 0416484..120d698 100644 --- a/pychastic/vectorized_I_generation.py +++ b/pychastic/vectorized_I_generation.py @@ -140,8 +140,9 @@ def make_D_mat(eta, zeta): def take(tensor, idx, fill=0): # Non jit-friendly implementation # illegal = jnp.logical_or(idx > p,idx < 1) - # return tensor[..., idx-1].at[..., illegal].set(fill) + # return tensor[..., idx-1].at[..., illegal].set(fill) legalized_idx = jnp.clip(idx, a_min=1, a_max=p) + # legalized_idx = jnp.clip(idx, min=1, max=p) # python 3.9+ illegal_mask = jnp.logical_or(idx > p, idx < 1) return ( tensor[..., legalized_idx - 1] * (1 - 1 * illegal_mask)