Tricky NonConcreteBooleanIndexError
#22639
I'm writing a small RL environment and getting a tricky `NonConcreteBooleanIndexError` in the following code:

```python
import jax.numpy as jnp

DECK = list(range(0, 13))  # all possible card values
DTYPE = jnp.int8
...
# Compute mask of valid actions
card_vals = jnp.array(DECK, dtype=DTYPE)
available_mask = jnp.where(agent.hand.held, True, False)
legal_mask = jnp.where(
    (card_vals >= state.to_beat) | (card_vals == TWO) | (card_vals == THREE) | (card_vals == TEN),
    True,
    False
)
valid_mask = available_mask & legal_mask
# Perform action selection
if jnp.any(agent.hand.held):
    action, state_update = _select_from_hand(key, agent, state, state_update, valid_mask)
else:
    action, state_update = _select_from_undercarriage(key, agent, state, state_update, valid_mask)
```
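For context, here is a minimal sketch of why the Python `if` above has to be replaced once the step function is traced. It assumes the code runs under `jax.jit`, and `held` is a hypothetical stand-in for `agent.hand.held`. Under tracing, `jnp.any(...)` is an abstract tracer, so `if` cannot turn it into a Python `bool`, while `jax.lax.cond` accepts the traced predicate directly.

```python
import jax
import jax.numpy as jnp


def branch_with_if(held):
    # Python control flow needs a concrete bool; under jit this is a tracer.
    if jnp.any(held):
        return jnp.int32(1)
    return jnp.int32(0)


def branch_with_cond(held):
    # lax.cond accepts a traced boolean scalar and compiles both branches.
    return jax.lax.cond(
        jnp.any(held),
        lambda h: jnp.int32(1),  # hypothetical stand-in for _select_from_hand
        lambda h: jnp.int32(0),  # hypothetical stand-in for _select_from_undercarriage
        held,
    )


held = jnp.array([False, True, False])

try:
    jax.jit(branch_with_if)(held)
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX versions raise TracerBoolConversionError, a subclass of this.
    print("if failed:", type(err).__name__)

print(jax.jit(branch_with_cond)(held))
```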
I've attempted to move to this:

```python
action, state_update = jax.lax.cond(
    jnp.any(agent.hand.held),
    _select_from_hand,
    _select_from_undercarriage,
    key, agent, state, state_update, valid_mask
)
```

But I keep receiving this error, which corresponds to the …
I understand that shapes are static and need to be known at compile time in JAX. I've tried a few other things like this too, to no avail:

```python
action, state_update = jax.lax.cond(
    jnp.where(jnp.sum(agent.hand.held) > 0, True, False),
    _select_from_hand,
    _select_from_undercarriage,
    key, agent, state, state_update, valid_mask
)
```
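A minimal sketch of why rewriting the predicate cannot help: `jax.lax.cond` traces *both* branches while compiling, so a data-dependent shape inside a branch (presumably in `_select_from_undercarriage` here) raises even when the predicate would select the other branch at run time. The branch functions below are hypothetical stand-ins, and the fixed-shape rewrite (masking with `jnp.where` instead of dropping elements) is one common remedy, not necessarily what the original code needs.

```python
import jax
import jax.numpy as jnp


def sum_positive_bool_index(x):
    # x[x > 0] has a length that depends on the values of x: not traceable.
    return x[x > 0].sum()


def sum_positive_fixed_shape(x):
    # Same value with a fixed shape: zero out entries instead of dropping them.
    return jnp.where(x > 0, x, 0).sum()


def zero(x):
    return jnp.zeros((), x.dtype)


def make_step(true_branch):
    @jax.jit
    def step(x):
        # Both true_branch and zero are traced here, whatever the predicate.
        return jax.lax.cond(jnp.any(x > 0), true_branch, zero, x)
    return step


all_negative = jnp.array([-1.0, -2.0])  # predicate is False at run time

try:
    make_step(sum_positive_bool_index)(all_negative)
except jax.errors.NonConcreteBooleanIndexError as err:
    print("traced branch failed:", type(err).__name__)

print(make_step(sum_positive_fixed_shape)(jnp.array([1.0, -2.0, 3.0])))
```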
From the snippet of the traceback you shared, it may be as simple as changing …
The problem is in this line: …

Since `agent.hand.under` is traced, there's no way to know statically how many of its values are larger than zero, so `indices` would have to have a dynamic size, and dynamically-sized arrays are not supported in JAX transforms like `jit` (see JAX sharp bits: Dynamic Shapes). You could probably work around this by using `jax.random.randint` rather than `jax.random.choice`, and avoiding constructing the explicit `indices` array.
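A minimal sketch of that suggestion, assuming the goal is to pick one valid position uniformly at random from a boolean mask. Here `mask` is a hypothetical stand-in for something like `agent.hand.under > 0`, and `naive_pick` is only a guess at the failing pattern, since the offending line isn't shown. The idea is to draw an integer `k` in `[0, count)` with `jax.random.randint`, then locate the `k`-th `True` entry with a cumulative sum, so every intermediate array keeps a fixed shape.

```python
import jax
import jax.numpy as jnp


def naive_pick(key, mask):
    # Hypothetical reconstruction of the failing pattern: the length of
    # `indices` depends on the values in `mask`, so it cannot be traced.
    indices = jnp.arange(mask.shape[0])[mask]
    return jax.random.choice(key, indices)


def randint_pick(key, mask):
    # Number of valid positions (a traced scalar, but with a fixed shape).
    count = jnp.sum(mask)
    # Draw k uniformly from [0, count); maxval is clamped to 1 so the call
    # stays well-defined when nothing is valid (the result is then index 0).
    k = jax.random.randint(key, (), 0, jnp.maximum(count, 1))
    # cumsum(mask)[i] counts the True entries in mask[: i + 1], so the first
    # position where it exceeds k is the k-th True entry (0-based).
    return jnp.argmax(jnp.cumsum(mask.astype(jnp.int32)) > k)


mask = jnp.array([False, True, False, True, False])
print(jax.jit(randint_pick)(jax.random.PRNGKey(0), mask))  # 1 or 3
```

The empty-mask case still has to be handled separately, for example with the `jax.lax.cond` on `jnp.any(...)` that the question already uses. Another fixed-shape option is `jax.random.choice(key, mask.shape[0], p=mask / mask.sum())`, which has the same caveat because the probabilities become NaN when the mask is empty.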