Updating slices in a pytree with slices from another pytree of the same shape #14686
StoneT2000 asked this question in Q&A
I think you want this:
jux_action: JuxAction = jax.tree_map(lambda x, y: x.at[:, 1].set(y[:, 1]), jux_action_0, jux_action_1)
or, if you want to use .at[...].get() for the read as well:
jux_action: JuxAction = jax.tree_map(lambda x, y: x.at[:, 1].set(y.at[:, 1].get()), jux_action_0, jux_action_1)
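A self-contained sketch of the suggestion above may help in checking it locally. The real JuxAction definition is not shown in this thread, so the dict below is a hypothetical stand-in whose leaves merely share the leading (batch_size, team_id) = (10, 2) dimensions, and make_action and take_team_1 are illustrative names only. It uses jax.tree_util.tree_map, the long-form spelling of tree_map.

```python
import jax
import jax.numpy as jnp


def make_action(fill_value):
    # Hypothetical stand-in for JuxAction (assumption): any pytree whose
    # leaves all start with (batch_size, team_id) = (10, 2) works the same way.
    return {
        "move": jnp.full((10, 2, 4), fill_value),
        "dig": jnp.full((10, 2), fill_value),
    }


jux_action_0 = make_action(0.0)
jux_action_1 = make_action(1.0)


def take_team_1(x, y):
    # Return a copy of x whose team index 1 slice is replaced by y's slice.
    return x.at[:, 1].set(y[:, 1])


# tree_map walks both pytrees in lockstep and rebuilds the same structure
# from whatever the function returns for each pair of leaves.
jux_action = jax.tree_util.tree_map(take_team_1, jux_action_0, jux_action_1)
```

After this runs, team 0 entries of every leaf come from jux_action_0 and team 1 entries from jux_action_1, while both inputs are left unchanged.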
I currently have a PyTree with the following shape
Notably, each leaf shares the leading dimensions (10, 2) = (batch_size, team_id) (this is for a game environment with two players).
Suppose I have two of these, jux_action_0 and jux_action_1; however, the following doesn't work
I was hoping this would set the 2nd element along the 2nd dimension to what is stored in jux_action_1 and leave the rest as whatever was in jux_action_0. However, I get the following error
This seems to be because the return value of the lambda passed to tree_map is the output of .set. Is there another way to set the value of a slice with another slice and return it in tree_map?
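As background on the point about .set, here is a minimal sketch on a single leaf, not a reconstruction of the failing code or of the error from the question. It assumes a plain (10, 2) array. JAX arrays are immutable, so .at[...].set(...) returns a new array instead of modifying its input, and that returned array is what a function passed to tree_map should hand back.

```python
import jax.numpy as jnp

x = jnp.zeros((10, 2))
y = jnp.ones((10, 2))

# Functional update: x is left untouched and a new array is returned.
updated = x.at[:, 1].set(y[:, 1])

# NumPy-style in-place assignment is rejected by JAX arrays.
in_place_failed = False
try:
    x[:, 1] = y[:, 1]
except TypeError:
    in_place_failed = True
```

Here x still holds all zeros afterwards, and only updated carries the new second column.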