From 8d4cc040db23f14da88124a34b35158084ab752a Mon Sep 17 00:00:00 2001 From: Changwan Ryu Date: Fri, 3 May 2024 10:17:39 -0700 Subject: [PATCH] Acme: update deprecated tree_map API PiperOrigin-RevId: 630425642 Change-Id: I205ace49ab167b826267e92445b0077a4e465079 --- acme/datasets/tfds.py | 33 +++++++++++++++++------------ acme/jax/inference_server.py | 4 +++- acme/jax/observation_stacking.py | 17 +++++++++------ acme/jax/running_statistics.py | 8 +++++-- acme/jax/utils.py | 34 ++++++++++++++++++------------ acme/jax/utils_test.py | 14 ++++++------ acme/utils/reverb_utils.py | 7 +++--- acme/wrappers/multigrid_wrapper.py | 2 +- 8 files changed, 73 insertions(+), 46 deletions(-) diff --git a/acme/datasets/tfds.py b/acme/datasets/tfds.py index 80c79e2497..5c0a236901 100644 --- a/acme/datasets/tfds.py +++ b/acme/datasets/tfds.py @@ -138,7 +138,8 @@ def __init__(self, size = _dataset_size_upperbound(dataset) data = next(dataset.batch(size).as_numpy_iterator()) self._dataset_size = jax.tree_flatten( - jax.tree_map(lambda x: x.shape[0], data))[0][0] + jax.tree_util.tree_map(lambda x: x.shape[0], data) + )[0][0] device = jax_utils._pmap_device_order() if not shard_dataset_across_devices: device = device[:1] @@ -148,21 +149,25 @@ def __init__(self, # len(device) needs to divide self._dataset_size evenly. assert self._dataset_size % len(device) == 0 logging.info('Trying to load %s elements to %s', self._dataset_size, device) - logging.info('Dataset %s %s', - ('before padding' if should_pmap else ''), - jax.tree_map(lambda x: x.shape, data)) + logging.info( + 'Dataset %s %s', + ('before padding' if should_pmap else ''), + jax.tree_util.tree_map(lambda x: x.shape, data), + ) if should_pmap: - shapes = jax.tree_map(lambda x: x.shape, data) + shapes = jax.tree_util.tree_map(lambda x: x.shape, data) # Padding to a multiple of 128 is needed to avoid excessive copying on TPU - data = jax.tree_map(_pad, data) - logging.info('Dataset after padding %s', - jax.tree_map(lambda x: x.shape, data)) + data = jax.tree_util.tree_map(_pad, data) + logging.info( + 'Dataset after padding %s', + jax.tree_util.tree_map(lambda x: x.shape, data), + ) def split_and_put(x: jnp.ndarray) -> jnp.ndarray: return jax.device_put_sharded( np.split(x[:self._dataset_size], len(device)), devices=device) - self._jax_dataset = jax.tree_map(split_and_put, data) + self._jax_dataset = jax.tree_util.tree_map(split_and_put, data) else: - self._jax_dataset = jax.tree_map(jax.device_put, data) + self._jax_dataset = jax.tree_util.tree_map(jax.device_put, data) self._key = (jnp.stack(jax.random.split(key, len(device))) if should_pmap else key) @@ -174,7 +179,9 @@ def sample_per_shard(data: Any, key1, (batch_size // len(device),), minval=0, maxval=self._dataset_size // len(device)) - data_sample = jax.tree_map(lambda d: jnp.take(d, indices, axis=0), data) + data_sample = jax.tree_util.tree_map( + lambda d: jnp.take(d, indices, axis=0), data + ) return data_sample, key2 if should_pmap: @@ -184,7 +191,7 @@ def sample(data, key): # since it avoids Host - Device communications. 
data_sample = jax.lax.all_gather( data_sample, axis_name=_PMAP_AXIS_NAME, axis=0, tiled=True) - data_sample = jax.tree_map(_unpad, data_sample, shapes) + data_sample = jax.tree_util.tree_map(_unpad, data_sample, shapes) return data_sample, key pmapped_sample = jax.pmap(sample, axis_name=_PMAP_AXIS_NAME) @@ -193,7 +200,7 @@ def sample_and_postprocess(key: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: data, key = pmapped_sample(self._jax_dataset, key) # All pmapped devices return the same data, so we just take the one from # the first device. - return jax.tree_map(lambda x: x[0], data), key + return jax.tree_util.tree_map(lambda x: x[0], data), key self._sample = sample_and_postprocess else: self._sample = jax.jit( diff --git a/acme/jax/inference_server.py b/acme/jax/inference_server.py index a7d331aa9b..599ca7fa2c 100644 --- a/acme/jax/inference_server.py +++ b/acme/jax/inference_server.py @@ -74,7 +74,9 @@ def __init__( self._device_params = [None] * len(self._devices) self._device_params_ids = [None] * len(self._devices) self._mutex = threading.Lock() - self._handler = jax.tree_map(self._build_handler, handler, is_leaf=callable) + self._handler = jax.tree_util.tree_map( + self._build_handler, handler, is_leaf=callable + ) @property def handler(self) -> InferenceServerHandler: diff --git a/acme/jax/observation_stacking.py b/acme/jax/observation_stacking.py index d495715ede..9fb92fbc83 100644 --- a/acme/jax/observation_stacking.py +++ b/acme/jax/observation_stacking.py @@ -53,7 +53,7 @@ def _tile_array(array: jnp.ndarray) -> jnp.ndarray: reps[axis] = num return jnp.tile(array, reps) - return jax.tree_map(_tile_array, nest) + return jax.tree_util.tree_map(_tile_array, nest) class ObservationStacker: @@ -94,14 +94,18 @@ def __call__(self, inputs: Observation, inputs) # Concatenate frames along the final axis (assumed to be for channels). - output = jax.tree_map(lambda *x: jnp.concatenate(x, axis=-1), - state.stack, inputs) + output = jax.tree_util.tree_map( + lambda *x: jnp.concatenate(x, axis=-1), state.stack, inputs + ) # Update the frame stack by adding the input and dropping the first # observation in the stack. Note that we use the final dimension as each # leaf in the nested observation may have a different last dim. 
new_state = state._replace( - stack=jax.tree_map(lambda x, y: y[..., x.shape[-1]:], inputs, output)) + stack=jax.tree_util.tree_map( + lambda x, y: y[..., x.shape[-1] :], inputs, output + ) + ) return output, new_state @@ -118,8 +122,9 @@ def stack_observation_spec(obs_spec: specs.Array) -> specs.Array: new_shape = obs_spec.shape[:-1] + (obs_spec.shape[-1] * stack_size,) return obs_spec.replace(shape=new_shape) - adjusted_observation_spec = jax.tree_map(stack_observation_spec, - environment_spec.observations) + adjusted_observation_spec = jax.tree_util.tree_map( + stack_observation_spec, environment_spec.observations + ) return environment_spec._replace(observations=adjusted_observation_spec) diff --git a/acme/jax/running_statistics.py b/acme/jax/running_statistics.py index 6f6688e41d..9445460c98 100644 --- a/acme/jax/running_statistics.py +++ b/acme/jax/running_statistics.py @@ -43,11 +43,15 @@ def _is_prefix(a: Path, b: Path) -> bool: def _zeros_like(nest: types.Nest, dtype=None) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest) + return jax.tree_util.tree_map( + lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest + ) def _ones_like(nest: types.Nest, dtype=None) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.ones(x.shape, dtype or x.dtype), nest) + return jax.tree_util.tree_map( + lambda x: jnp.ones(x.shape, dtype or x.dtype), nest + ) @chex.dataclass(frozen=True) diff --git a/acme/jax/utils.py b/acme/jax/utils.py index 60ed8aaf39..8550b604e2 100644 --- a/acme/jax/utils.py +++ b/acme/jax/utils.py @@ -41,7 +41,7 @@ def add_batch_dim(values: types.Nest) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), values) + return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), values) def _flatten(x: jnp.ndarray, num_batch_dims: int) -> jnp.ndarray: @@ -76,24 +76,28 @@ def batch_concat( def zeros_like(nest: types.Nest, dtype=None) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest) + return jax.tree_util.tree_map( + lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest + ) def ones_like(nest: types.Nest, dtype=None) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.ones(x.shape, dtype or x.dtype), nest) + return jax.tree_util.tree_map( + lambda x: jnp.ones(x.shape, dtype or x.dtype), nest + ) def squeeze_batch_dim(nest: types.Nest) -> types.NestedArray: - return jax.tree_map(lambda x: jnp.squeeze(x, axis=0), nest) + return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), nest) def to_numpy_squeeze(values: types.Nest) -> types.NestedArray: """Converts to numpy and squeezes out dummy batch dimension.""" - return jax.tree_map(lambda x: np.asarray(x).squeeze(axis=0), values) + return jax.tree_util.tree_map(lambda x: np.asarray(x).squeeze(axis=0), values) def to_numpy(values: types.Nest) -> types.NestedArray: - return jax.tree_map(np.asarray, values) + return jax.tree_util.tree_map(np.asarray, values) def fetch_devicearray(values: types.Nest) -> types.Nest: @@ -108,8 +112,9 @@ def _fetch_devicearray(x): def batch_to_sequence(values: types.Nest) -> types.NestedArray: - return jax.tree_map( - lambda x: jnp.transpose(x, axes=(1, 0, *range(2, len(x.shape)))), values) + return jax.tree_util.tree_map( + lambda x: jnp.transpose(x, axes=(1, 0, *range(2, len(x.shape)))), values + ) def tile_array(array: jnp.ndarray, multiple: int) -> jnp.ndarray: @@ -120,7 +125,7 @@ def tile_array(array: jnp.ndarray, multiple: int) -> jnp.ndarray: def 
tile_nested(inputs: types.Nest, multiple: int) -> types.Nest: """Tiles tensors in a nested structure along a new leading axis.""" tile = functools.partial(tile_array, multiple=multiple) - return jax.tree_map(tile, inputs) + return jax.tree_util.tree_map(tile, inputs) def maybe_recover_lstm_type(state: types.NestedArray) -> types.NestedArray: @@ -379,7 +384,7 @@ def get_from_first_device(nest: N, as_numpy: bool = True) -> N: the same device as the sharded device array). If `as_numpy=True` then the array will be copied to the host machine and converted into a `np.ndarray`. """ - zeroth_nest = jax.tree_map(lambda x: x[0], nest) + zeroth_nest = jax.tree_util.tree_map(lambda x: x[0], nest) return jax.device_get(zeroth_nest) if as_numpy else zeroth_nest @@ -411,7 +416,7 @@ def mapreduce( vmapped_f = jax.vmap(f, **vmap_kwargs) def g(*args, **kwargs): - return jax.tree_map(reduce_fn, vmapped_f(*args, **kwargs)) + return jax.tree_util.tree_map(reduce_fn, vmapped_f(*args, **kwargs)) return g @@ -453,11 +458,12 @@ def _process_one_batch(state, data): return _process_one_batch if postprocess_aux is None: - postprocess_aux = lambda x: jax.tree_map(jnp.mean, x) + postprocess_aux = lambda x: jax.tree_util.tree_map(jnp.mean, x) def _process_multiple_batches(state, data): - data = jax.tree_map( - lambda a: jnp.reshape(a, (num_batches, -1, *a.shape[1:])), data) + data = jax.tree_util.tree_map( + lambda a: jnp.reshape(a, (num_batches, -1, *a.shape[1:])), data + ) state, aux = jax.lax.scan( process_one_batch, state, data, length=num_batches) diff --git a/acme/jax/utils_test.py b/acme/jax/utils_test.py index 04786d7319..7e8bdf36e4 100644 --- a/acme/jax/utils_test.py +++ b/acme/jax/utils_test.py @@ -72,15 +72,17 @@ def test_get_from_first_device(self): # Get zeroth device content as DeviceArray. device_arrays = utils.get_from_first_device(sharded, as_numpy=False) - jax.tree_map( - lambda x: self.assertIsInstance(x, jax.Array), - device_arrays) - jax.tree_map(np.testing.assert_array_equal, want, device_arrays) + jax.tree_util.tree_map( + lambda x: self.assertIsInstance(x, jax.Array), device_arrays + ) + jax.tree_util.tree_map(np.testing.assert_array_equal, want, device_arrays) # Get the zeroth device content as numpy arrays. numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True) - jax.tree_map(lambda x: self.assertIsInstance(x, np.ndarray), numpy_arrays) - jax.tree_map(np.testing.assert_array_equal, want, numpy_arrays) + jax.tree_util.tree_map( + lambda x: self.assertIsInstance(x, np.ndarray), numpy_arrays + ) + jax.tree_util.tree_map(np.testing.assert_array_equal, want, numpy_arrays) if __name__ == '__main__': diff --git a/acme/utils/reverb_utils.py b/acme/utils/reverb_utils.py index 5df39153a7..adc85004d2 100644 --- a/acme/utils/reverb_utils.py +++ b/acme/utils/reverb_utils.py @@ -124,11 +124,12 @@ def roll(observation): # We remove the last transition as its next_observation field is incorrect. # It has been obtained by rolling the observation field, such that # transitions.next_observations[:, -1] is transitions.observations[:, 0] - transitions = jax.tree_map(lambda x: x[:, :-1, ...], transitions) + transitions = jax.tree_util.tree_map(lambda x: x[:, :-1, ...], transitions) if flatten_batch: # Merge the 2 leading batch dimensions into 1. 
- transitions = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), - transitions) + transitions = jax.tree_util.tree_map( + lambda x: np.reshape(x, (-1,) + x.shape[2:]), transitions + ) return transitions diff --git a/acme/wrappers/multigrid_wrapper.py b/acme/wrappers/multigrid_wrapper.py index 9c4b2287bb..b75e17b402 100644 --- a/acme/wrappers/multigrid_wrapper.py +++ b/acme/wrappers/multigrid_wrapper.py @@ -195,7 +195,7 @@ def make_single_agent_spec(spec): else: raise ValueError(f'Unexpected spec type {type(spec)}.') - single_agent_spec = jax.tree_map(make_single_agent_spec, spec) + single_agent_spec = jax.tree_util.tree_map(make_single_agent_spec, spec) return single_agent_spec
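Note: the patch above is a mechanical rename of the deprecated jax.tree_map alias to jax.tree_util.tree_map; behavior is unchanged. A minimal standalone sketch of the pattern being applied, assuming a recent JAX install (the nested dict below is a hypothetical example, not Acme data):

    import jax
    import jax.numpy as jnp

    # Hypothetical nested structure, used only to illustrate the rename.
    nested = {'obs': jnp.ones((4, 3)), 'reward': jnp.zeros((4,))}

    # Deprecated alias, removed in newer JAX releases:
    #   shapes = jax.tree_map(lambda x: x.shape, nested)

    # Replacement used throughout this patch:
    shapes = jax.tree_util.tree_map(lambda x: x.shape, nested)
    print(shapes)  # {'obs': (4, 3), 'reward': (4,)}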