Skip to content

Commit

Permalink
no ignore counters
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiges committed Oct 2, 2024
1 parent b3eb488 commit 6d7069d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jaxmarl/environments/overcooked_v2/overcooked.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def get_obs_featurized(self, state: State) -> chex.Array:

# TODO: maybe pass as argument
num_pots = 2
reproduce_overcooked_ai = True
ignore_counters = False

onion = DynamicObject.ingredient(0)
recipe = 3 * onion
Expand Down Expand Up @@ -841,7 +841,7 @@ def _closest_features(
dynamic_mask = state.grid[:, :, 1] == dynamic_locator
if not_in_pot:
dynamic_mask &= state.grid[:, :, 0] != StaticObject.POT
if reproduce_overcooked_ai:
if ignore_counters:
dynamic_mask &= state.grid[:, :, 0] != StaticObject.WALL
mask |= dynamic_mask
mask = mask.at[pos.y, pos.x].set(inv == dynamic_locator)
Expand Down Expand Up @@ -878,13 +878,13 @@ def _closest_features(
empty_counter_features = _closest_features(
static_locator=StaticObject.WALL, no_ingredients=True
)
if reproduce_overcooked_ai:
if ignore_counters:
empty_counter_features = jnp.array([0, 0])

# pi_closest_soup_n_{onions|tomatoes}
# we assume that recipe is always 3 onions
soup_on_grid_mask = state.grid[:, :, 1] == soup
if reproduce_overcooked_ai:
if ignore_counters:
soup_on_grid_mask &= state.grid[:, :, 0] != StaticObject.WALL
soup_onions = jax.lax.select(
jnp.any(soup_on_grid_mask) | (inv == soup), 3, 0
Expand Down

0 comments on commit 6d7069d

Please sign in to comment.