
Commit 8d4cc04

Acme: update deprecated tree_map API
PiperOrigin-RevId: 630425642
Change-Id: I205ace49ab167b826267e92445b0077a4e465079
galmacky authored and copybara-github committed May 3, 2024
1 parent 3a1420d commit 8d4cc04
Showing 8 changed files with 73 additions and 46 deletions.
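Background: jax.tree_map is a deprecated alias of jax.tree_util.tree_map, which applies a function to every leaf of a pytree (a nested structure of dicts, lists, and tuples of arrays). The alias was removed in later JAX releases, so this commit swaps in the fully qualified name everywhere Acme used it. A minimal sketch of the substitution, on an illustrative nest that is not taken from the diff:

    import jax
    import jax.numpy as jnp

    nest = {'obs': jnp.zeros((4, 84, 84)), 'reward': jnp.zeros((4,))}

    # Deprecated spelling removed by this commit:
    #   shapes = jax.tree_map(lambda x: x.shape, nest)
    # Replacement used throughout the diff:
    shapes = jax.tree_util.tree_map(lambda x: x.shape, nest)
    print(shapes)  # {'obs': (4, 84, 84), 'reward': (4,)}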
33 changes: 20 additions & 13 deletions acme/datasets/tfds.py
@@ -138,7 +138,8 @@ def __init__(self,
     size = _dataset_size_upperbound(dataset)
     data = next(dataset.batch(size).as_numpy_iterator())
     self._dataset_size = jax.tree_flatten(
-        jax.tree_map(lambda x: x.shape[0], data))[0][0]
+        jax.tree_util.tree_map(lambda x: x.shape[0], data)
+    )[0][0]
     device = jax_utils._pmap_device_order()
     if not shard_dataset_across_devices:
       device = device[:1]
@@ -148,21 +149,25 @@ def __init__(self,
     # len(device) needs to divide self._dataset_size evenly.
     assert self._dataset_size % len(device) == 0
     logging.info('Trying to load %s elements to %s', self._dataset_size, device)
-    logging.info('Dataset %s %s',
-                 ('before padding' if should_pmap else ''),
-                 jax.tree_map(lambda x: x.shape, data))
+    logging.info(
+        'Dataset %s %s',
+        ('before padding' if should_pmap else ''),
+        jax.tree_util.tree_map(lambda x: x.shape, data),
+    )
     if should_pmap:
-      shapes = jax.tree_map(lambda x: x.shape, data)
+      shapes = jax.tree_util.tree_map(lambda x: x.shape, data)
       # Padding to a multiple of 128 is needed to avoid excessive copying on TPU
-      data = jax.tree_map(_pad, data)
-      logging.info('Dataset after padding %s',
-                   jax.tree_map(lambda x: x.shape, data))
+      data = jax.tree_util.tree_map(_pad, data)
+      logging.info(
+          'Dataset after padding %s',
+          jax.tree_util.tree_map(lambda x: x.shape, data),
+      )
       def split_and_put(x: jnp.ndarray) -> jnp.ndarray:
         return jax.device_put_sharded(
             np.split(x[:self._dataset_size], len(device)), devices=device)
-      self._jax_dataset = jax.tree_map(split_and_put, data)
+      self._jax_dataset = jax.tree_util.tree_map(split_and_put, data)
     else:
-      self._jax_dataset = jax.tree_map(jax.device_put, data)
+      self._jax_dataset = jax.tree_util.tree_map(jax.device_put, data)

     self._key = (jnp.stack(jax.random.split(key, len(device)))
                  if should_pmap else key)
@@ -174,7 +179,9 @@ def sample_per_shard(data: Any,
           key1, (batch_size // len(device),),
           minval=0,
           maxval=self._dataset_size // len(device))
-      data_sample = jax.tree_map(lambda d: jnp.take(d, indices, axis=0), data)
+      data_sample = jax.tree_util.tree_map(
+          lambda d: jnp.take(d, indices, axis=0), data
+      )
       return data_sample, key2

     if should_pmap:
@@ -184,7 +191,7 @@ def sample(data, key):
         # since it avoids Host - Device communications.
         data_sample = jax.lax.all_gather(
             data_sample, axis_name=_PMAP_AXIS_NAME, axis=0, tiled=True)
-        data_sample = jax.tree_map(_unpad, data_sample, shapes)
+        data_sample = jax.tree_util.tree_map(_unpad, data_sample, shapes)
         return data_sample, key

       pmapped_sample = jax.pmap(sample, axis_name=_PMAP_AXIS_NAME)
@@ -193,7 +200,7 @@ def sample_and_postprocess(key: jnp.ndarray) -> Tuple[Any, jnp.ndarray]:
         data, key = pmapped_sample(self._jax_dataset, key)
         # All pmapped devices return the same data, so we just take the one from
         # the first device.
-        return jax.tree_map(lambda x: x[0], data), key
+        return jax.tree_util.tree_map(lambda x: x[0], data), key
       self._sample = sample_and_postprocess
     else:
       self._sample = jax.jit(
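The split_and_put helper in the hunks above shards each leaf of the in-memory dataset across local devices in one transfer; padding to a multiple of 128 beforehand avoids excessive copying on TPU. A hedged sketch of the sharding step alone, with illustrative names and shapes not taken from acme/datasets/tfds.py:

    import jax
    import jax.numpy as jnp
    import numpy as np

    devices = jax.local_devices()
    n = len(devices)
    # A toy dataset whose leading dimension divides evenly by the device count.
    data = {'x': np.arange(8 * n, dtype=np.float32).reshape(8 * n, 1)}

    def split_and_put(x: np.ndarray) -> jnp.ndarray:
      # One equal shard per device, transferred in a single call.
      return jax.device_put_sharded(np.split(x, n), devices=devices)

    sharded = jax.tree_util.tree_map(split_and_put, data)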
4 changes: 3 additions & 1 deletion acme/jax/inference_server.py
@@ -74,7 +74,9 @@ def __init__(
     self._device_params = [None] * len(self._devices)
     self._device_params_ids = [None] * len(self._devices)
     self._mutex = threading.Lock()
-    self._handler = jax.tree_map(self._build_handler, handler, is_leaf=callable)
+    self._handler = jax.tree_util.tree_map(
+        self._build_handler, handler, is_leaf=callable
+    )

   @property
   def handler(self) -> InferenceServerHandler:
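is_leaf=callable above makes tree_map stop recursing at anything callable, so each handler reaches _build_handler whole; this matters when a handler is callable but is also a pytree container (a plain function is already a leaf). A sketch under that assumption, with a made-up Handler type:

    import collections
    import jax

    # A callable that is also a pytree container (a namedtuple): without
    # is_leaf=callable, tree_map would recurse into its fields.
    class Handler(collections.namedtuple('Handler', ['fn'])):

      def __call__(self, x):
        return self.fn(x)

    handlers = {'act': Handler(fn=lambda x: x + 1)}

    def build_handler(h):
      return h  # stand-in for the real _build_handler

    wrapped = jax.tree_util.tree_map(build_handler, handlers, is_leaf=callable)
    assert callable(wrapped['act'])  # the Handler arrived whole, not field-by-field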
17 changes: 11 additions & 6 deletions acme/jax/observation_stacking.py
@@ -53,7 +53,7 @@ def _tile_array(array: jnp.ndarray) -> jnp.ndarray:
     reps[axis] = num
     return jnp.tile(array, reps)

-  return jax.tree_map(_tile_array, nest)
+  return jax.tree_util.tree_map(_tile_array, nest)


 class ObservationStacker:
@@ -94,14 +94,18 @@ def __call__(self, inputs: Observation,
                                    inputs)

     # Concatenate frames along the final axis (assumed to be for channels).
-    output = jax.tree_map(lambda *x: jnp.concatenate(x, axis=-1),
-                          state.stack, inputs)
+    output = jax.tree_util.tree_map(
+        lambda *x: jnp.concatenate(x, axis=-1), state.stack, inputs
+    )

     # Update the frame stack by adding the input and dropping the first
     # observation in the stack. Note that we use the final dimension as each
     # leaf in the nested observation may have a different last dim.
     new_state = state._replace(
-        stack=jax.tree_map(lambda x, y: y[..., x.shape[-1]:], inputs, output))
+        stack=jax.tree_util.tree_map(
+            lambda x, y: y[..., x.shape[-1] :], inputs, output
+        )
+    )

     return output, new_state

@@ -118,8 +122,9 @@ def stack_observation_spec(obs_spec: specs.Array) -> specs.Array:
     new_shape = obs_spec.shape[:-1] + (obs_spec.shape[-1] * stack_size,)
     return obs_spec.replace(shape=new_shape)

-  adjusted_observation_spec = jax.tree_map(stack_observation_spec,
-                                           environment_spec.observations)
+  adjusted_observation_spec = jax.tree_util.tree_map(
+      stack_observation_spec, environment_spec.observations
+  )

   return environment_spec._replace(observations=adjusted_observation_spec)

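The two tree_map calls in __call__ implement the stack update: concatenate the running stack with the new observation along the channel axis, then slice off the oldest frame so the stack keeps a fixed depth. A hedged sketch on a single array leaf with illustrative shapes:

    import jax
    import jax.numpy as jnp

    stack = jnp.zeros((84, 84, 4))   # four stacked single-channel frames
    inputs = jnp.ones((84, 84, 1))   # the incoming observation

    # Concatenate frames along the final (channel) axis.
    output = jax.tree_util.tree_map(
        lambda *x: jnp.concatenate(x, axis=-1), stack, inputs)

    # Drop the oldest frame: keep everything after the first inputs.shape[-1] channels.
    new_stack = jax.tree_util.tree_map(
        lambda x, y: y[..., x.shape[-1]:], inputs, output)
    assert new_stack.shape == (84, 84, 4)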
8 changes: 6 additions & 2 deletions acme/jax/running_statistics.py
@@ -43,11 +43,15 @@ def _is_prefix(a: Path, b: Path) -> bool:


 def _zeros_like(nest: types.Nest, dtype=None) -> types.NestedArray:
-  return jax.tree_map(lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest)
+  return jax.tree_util.tree_map(
+      lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest
+  )


 def _ones_like(nest: types.Nest, dtype=None) -> types.NestedArray:
-  return jax.tree_map(lambda x: jnp.ones(x.shape, dtype or x.dtype), nest)
+  return jax.tree_util.tree_map(
+      lambda x: jnp.ones(x.shape, dtype or x.dtype), nest
+  )


 @chex.dataclass(frozen=True)
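The dtype or x.dtype idiom in _zeros_like and _ones_like keeps each leaf's own dtype unless an explicit override is passed. A brief sketch on an illustrative nest:

    import jax
    import jax.numpy as jnp
    import numpy as np

    nest = {'mean': np.zeros((3,), np.float32), 'count': np.zeros((), np.int32)}

    dtype = None  # keep each leaf's dtype; pass e.g. jnp.float32 to override all
    zeros = jax.tree_util.tree_map(
        lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest)
    assert zeros['count'].dtype == np.int32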
34 changes: 20 additions & 14 deletions acme/jax/utils.py
@@ -41,7 +41,7 @@


 def add_batch_dim(values: types.Nest) -> types.NestedArray:
-  return jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), values)
+  return jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), values)


 def _flatten(x: jnp.ndarray, num_batch_dims: int) -> jnp.ndarray:
@@ -76,24 +76,28 @@ def batch_concat(


 def zeros_like(nest: types.Nest, dtype=None) -> types.NestedArray:
-  return jax.tree_map(lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest)
+  return jax.tree_util.tree_map(
+      lambda x: jnp.zeros(x.shape, dtype or x.dtype), nest
+  )


 def ones_like(nest: types.Nest, dtype=None) -> types.NestedArray:
-  return jax.tree_map(lambda x: jnp.ones(x.shape, dtype or x.dtype), nest)
+  return jax.tree_util.tree_map(
+      lambda x: jnp.ones(x.shape, dtype or x.dtype), nest
+  )


 def squeeze_batch_dim(nest: types.Nest) -> types.NestedArray:
-  return jax.tree_map(lambda x: jnp.squeeze(x, axis=0), nest)
+  return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), nest)


 def to_numpy_squeeze(values: types.Nest) -> types.NestedArray:
   """Converts to numpy and squeezes out dummy batch dimension."""
-  return jax.tree_map(lambda x: np.asarray(x).squeeze(axis=0), values)
+  return jax.tree_util.tree_map(lambda x: np.asarray(x).squeeze(axis=0), values)


 def to_numpy(values: types.Nest) -> types.NestedArray:
-  return jax.tree_map(np.asarray, values)
+  return jax.tree_util.tree_map(np.asarray, values)


 def fetch_devicearray(values: types.Nest) -> types.Nest:
@@ -108,8 +112,9 @@ def _fetch_devicearray(x):


 def batch_to_sequence(values: types.Nest) -> types.NestedArray:
-  return jax.tree_map(
-      lambda x: jnp.transpose(x, axes=(1, 0, *range(2, len(x.shape)))), values)
+  return jax.tree_util.tree_map(
+      lambda x: jnp.transpose(x, axes=(1, 0, *range(2, len(x.shape)))), values
+  )


 def tile_array(array: jnp.ndarray, multiple: int) -> jnp.ndarray:
@@ -120,7 +125,7 @@ def tile_array(array: jnp.ndarray, multiple: int) -> jnp.ndarray:
 def tile_nested(inputs: types.Nest, multiple: int) -> types.Nest:
   """Tiles tensors in a nested structure along a new leading axis."""
   tile = functools.partial(tile_array, multiple=multiple)
-  return jax.tree_map(tile, inputs)
+  return jax.tree_util.tree_map(tile, inputs)


 def maybe_recover_lstm_type(state: types.NestedArray) -> types.NestedArray:
@@ -379,7 +384,7 @@ def get_from_first_device(nest: N, as_numpy: bool = True) -> N:
     the same device as the sharded device array). If `as_numpy=True` then the
     array will be copied to the host machine and converted into a `np.ndarray`.
   """
-  zeroth_nest = jax.tree_map(lambda x: x[0], nest)
+  zeroth_nest = jax.tree_util.tree_map(lambda x: x[0], nest)
   return jax.device_get(zeroth_nest) if as_numpy else zeroth_nest


@@ -411,7 +416,7 @@ def mapreduce(
   vmapped_f = jax.vmap(f, **vmap_kwargs)

   def g(*args, **kwargs):
-    return jax.tree_map(reduce_fn, vmapped_f(*args, **kwargs))
+    return jax.tree_util.tree_map(reduce_fn, vmapped_f(*args, **kwargs))

   return g

@@ -453,11 +458,12 @@ def _process_one_batch(state, data):
     return _process_one_batch

   if postprocess_aux is None:
-    postprocess_aux = lambda x: jax.tree_map(jnp.mean, x)
+    postprocess_aux = lambda x: jax.tree_util.tree_map(jnp.mean, x)

   def _process_multiple_batches(state, data):
-    data = jax.tree_map(
-        lambda a: jnp.reshape(a, (num_batches, -1, *a.shape[1:])), data)
+    data = jax.tree_util.tree_map(
+        lambda a: jnp.reshape(a, (num_batches, -1, *a.shape[1:])), data
+    )

     state, aux = jax.lax.scan(
         process_one_batch, state, data, length=num_batches)
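The reshape in _process_multiple_batches folds one large batch into num_batches minibatches so jax.lax.scan can apply the update to them sequentially, and the default postprocess_aux then averages the per-minibatch auxiliaries. A hedged, self-contained sketch of that pattern with illustrative names:

    import jax
    import jax.numpy as jnp

    num_batches = 4
    data = jnp.arange(32.0).reshape(32, 1)  # one big batch of 32 examples

    # Fold the leading axis: (32, 1) -> (num_batches, 8, 1).
    folded = jax.tree_util.tree_map(
        lambda a: jnp.reshape(a, (num_batches, -1, *a.shape[1:])), data)

    def process_one_batch(state, batch):
      return state + batch.sum(), batch.mean()  # (new carry, per-minibatch aux)

    state, aux = jax.lax.scan(
        process_one_batch, jnp.float32(0.0), folded, length=num_batches)
    mean_aux = jax.tree_util.tree_map(jnp.mean, aux)  # as the default postprocess_aux does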
14 changes: 8 additions & 6 deletions acme/jax/utils_test.py
@@ -72,15 +72,17 @@ def test_get_from_first_device(self):

     # Get zeroth device content as DeviceArray.
     device_arrays = utils.get_from_first_device(sharded, as_numpy=False)
-    jax.tree_map(
-        lambda x: self.assertIsInstance(x, jax.Array),
-        device_arrays)
-    jax.tree_map(np.testing.assert_array_equal, want, device_arrays)
+    jax.tree_util.tree_map(
+        lambda x: self.assertIsInstance(x, jax.Array), device_arrays
+    )
+    jax.tree_util.tree_map(np.testing.assert_array_equal, want, device_arrays)

     # Get the zeroth device content as numpy arrays.
     numpy_arrays = utils.get_from_first_device(sharded, as_numpy=True)
-    jax.tree_map(lambda x: self.assertIsInstance(x, np.ndarray), numpy_arrays)
-    jax.tree_map(np.testing.assert_array_equal, want, numpy_arrays)
+    jax.tree_util.tree_map(
+        lambda x: self.assertIsInstance(x, np.ndarray), numpy_arrays
+    )
+    jax.tree_util.tree_map(np.testing.assert_array_equal, want, numpy_arrays)


 if __name__ == '__main__':
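As the test above shows, tree_map doubles as a structural assertion helper: mapping a check over two pytrees verifies every pair of corresponding leaves and fails on any structure mismatch. A minimal standalone sketch:

    import jax
    import numpy as np

    want = {'a': np.arange(3), 'b': np.ones((2, 2))}
    got = {'a': np.arange(3), 'b': np.ones((2, 2))}

    # Raises if any pair of corresponding leaves (or the structures) differ.
    jax.tree_util.tree_map(np.testing.assert_array_equal, want, got)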
7 changes: 4 additions & 3 deletions acme/utils/reverb_utils.py
@@ -124,11 +124,12 @@ def roll(observation):
   # We remove the last transition as its next_observation field is incorrect.
   # It has been obtained by rolling the observation field, such that
   # transitions.next_observations[:, -1] is transitions.observations[:, 0]
-  transitions = jax.tree_map(lambda x: x[:, :-1, ...], transitions)
+  transitions = jax.tree_util.tree_map(lambda x: x[:, :-1, ...], transitions)
   if flatten_batch:
     # Merge the 2 leading batch dimensions into 1.
-    transitions = jax.tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]),
-                               transitions)
+    transitions = jax.tree_util.tree_map(
+        lambda x: np.reshape(x, (-1,) + x.shape[2:]), transitions
+    )
   return transitions


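The flatten_batch branch merges the two leading axes of every leaf, batch and time, into a single transition axis. A short sketch with illustrative shapes:

    import jax
    import numpy as np

    transitions = {'obs': np.zeros((16, 10, 84)), 'action': np.zeros((16, 10))}

    # (batch=16, time=10, ...) -> (batch * time = 160, ...)
    flat = jax.tree_util.tree_map(
        lambda x: np.reshape(x, (-1,) + x.shape[2:]), transitions)
    assert flat['obs'].shape == (160, 84)
    assert flat['action'].shape == (160,)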
2 changes: 1 addition & 1 deletion acme/wrappers/multigrid_wrapper.py
@@ -195,7 +195,7 @@ def make_single_agent_spec(spec):
   else:
     raise ValueError(f'Unexpected spec type {type(spec)}.')

-  single_agent_spec = jax.tree_map(make_single_agent_spec, spec)
+  single_agent_spec = jax.tree_util.tree_map(make_single_agent_spec, spec)
   return single_agent_spec


