diff --git a/acme/wrappers/multiagent_dict_key_wrapper.py b/acme/wrappers/multiagent_dict_key_wrapper.py index 1c15c12199..8c6d731890 100644 --- a/acme/wrappers/multiagent_dict_key_wrapper.py +++ b/acme/wrappers/multiagent_dict_key_wrapper.py @@ -24,7 +24,8 @@ class MultiagentDictKeyWrapper(base.EnvironmentWrapper): - """Wrapper that converts list-indexed multiagent environments to dict-indexed. + """Wrapper that converts list or tuple indexed multiagent environments + to dict-indexed. Specifically, if the underlying environment observation and actions are: observation = [observation_agent_0, observation_agent_1, ...] @@ -49,8 +50,10 @@ def __init__(self, environment: dm_env.Environment): self._reward_spec = self._list_to_dict(self._environment.reward_spec()) def _list_to_dict(self, data: Union[List[V], V]) -> Union[Dict[str, V], V]: - """Convert list-indexed data to dict-indexed, otherwise passthrough.""" - if isinstance(data, list): + """Convert list or tuple indexed data to dict-indexed, otherwise + passthrough. + """ + if isinstance(data, list) or isinstance(data, tuple): return {str(k): v for k, v in enumerate(data)} return data