From d92e23b9a1eb29ea5b799c448eae92ab53ec8338 Mon Sep 17 00:00:00 2001 From: Bobak Shahriari Date: Mon, 9 Oct 2023 05:06:41 -0700 Subject: [PATCH] Acme: Make D4PG use the tested n-step transition adder. Fixes Issue 292. PiperOrigin-RevId: 571906732 Change-Id: I7ee0c4952fab2f3eec353e787caeffab799617a9 --- acme/adders/reverb/__init__.py | 1 + acme/adders/reverb/structured.py | 183 +++++++++++++++++++++----- acme/adders/reverb/structured_test.py | 85 +++--------- acme/agents/jax/d4pg/builder.py | 83 ++---------- 4 files changed, 172 insertions(+), 180 deletions(-) diff --git a/acme/adders/reverb/__init__.py b/acme/adders/reverb/__init__.py index 189f8ce780..9fc3e80267 100644 --- a/acme/adders/reverb/__init__.py +++ b/acme/adders/reverb/__init__.py @@ -28,5 +28,6 @@ from acme.adders.reverb.sequence import SequenceAdder from acme.adders.reverb.structured import create_n_step_transition_config from acme.adders.reverb.structured import create_step_spec +from acme.adders.reverb.structured import n_step_from_trajectory from acme.adders.reverb.structured import StructuredAdder from acme.adders.reverb.transition import NStepTransitionAdder diff --git a/acme/adders/reverb/structured.py b/acme/adders/reverb/structured.py index de2b712057..9e899f17f9 100644 --- a/acme/adders/reverb/structured.py +++ b/acme/adders/reverb/structured.py @@ -16,7 +16,6 @@ import itertools import time - from typing import Callable, List, Optional, Sequence, Sized from absl import logging @@ -25,6 +24,7 @@ from acme.adders import base as adders_base from acme.adders.reverb import base as reverb_base from acme.adders.reverb import sequence as sequence_adder +from acme.utils import tree_utils import dm_env import numpy as np import reverb @@ -63,8 +63,13 @@ class StructuredAdder(adders_base.Adder): expected to perform preprocessing in the dataset pipeline on the learner. """ - def __init__(self, client: reverb.Client, max_in_flight_items: int, - configs: Sequence[sw.Config], step_spec: Step): + def __init__( + self, + client: reverb.Client, + max_in_flight_items: int, + configs: Sequence[sw.Config], + step_spec: Step, + ): """Initialize a StructuredAdder instance. Args: @@ -86,7 +91,8 @@ def __init__(self, client: reverb.Client, max_in_flight_items: int, sw.infer_signature(list(table_configs), step_spec) except ValueError as e: raise ValueError( - f'Received invalid configs for table {table}: {str(e)}') from e + f'Received invalid configs for table {table}: {str(e)}' + ) from e self._client = client self._configs = tuple(configs) @@ -106,7 +112,9 @@ def __del__(self): except reverb.DeadlineExceededError as e: logging.error( 'Timeout (10 s) exceeded when flushing the writer before ' - 'deleting it. Caught Reverb exception: %s', str(e)) + 'deleting it. Caught Reverb exception: %s', + str(e), + ) def _make_step(self, **kwargs) -> Step: """Complete the step with None in the missing positions.""" @@ -132,7 +140,8 @@ def add_first(self, timestep: dm_env.TimeStep): if not timestep.first(): raise ValueError( 'adder.add_first called with a timestep that was not the first of its' - 'episode (i.e. one for which timestep.first() is not True)') + 'episode (i.e. one for which timestep.first() is not True)' + ) if self._writer is None: self._writer = self._client.structured_writer(self._configs) @@ -142,15 +151,18 @@ def add_first(self, timestep: dm_env.TimeStep): # passing `partial_step=True`. 
self._writer.append( data=self._make_step( - observation=timestep.observation, - start_of_episode=timestep.first()), - partial_step=True) + observation=timestep.observation, start_of_episode=timestep.first() + ), + partial_step=True, + ) self._writer.flush(self._max_in_flight_items) - def add(self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = ()): + def add( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = (), + ): """Record an action and the following timestep.""" if self._writer is None or not self._writer.step_is_open: @@ -158,22 +170,27 @@ def add(self, # Add the timestep to the buffer. has_extras = ( - len(extras) > 0 if isinstance(extras, Sized) # pylint: disable=g-explicit-length-test - else extras is not None) + len(extras) > 0 # pylint: disable=g-explicit-length-test + if isinstance(extras, Sized) + else extras is not None + ) current_step = self._make_step( action=action, reward=next_timestep.reward, discount=next_timestep.discount, - extras=extras if has_extras else self._none_step.extras) + extras=extras if has_extras else self._none_step.extras, + ) self._writer.append(current_step) # Record the next observation and write. self._writer.append( data=self._make_step( observation=next_timestep.observation, - start_of_episode=next_timestep.first()), - partial_step=True) + start_of_episode=next_timestep.first(), + ), + partial_step=True, + ) self._writer.flush(self._max_in_flight_items) if next_timestep.last(): @@ -181,7 +198,8 @@ def add(self, # TODO(b/183945808): remove this when fields are no longer expected to be # of equal length on the learner side. dummy_step = tree.map_structure( - lambda x: None if x is None else np.zeros_like(x), current_step) + lambda x: None if x is None else np.zeros_like(x), current_step + ) self._writer.append(dummy_step) self.reset() @@ -192,7 +210,8 @@ def create_step_spec( return Step( *environment_spec, start_of_episode=tf.TensorSpec([], tf.bool, 'start_of_episode'), - extras=extras_spec) + extras=extras_spec, + ) def _last_n(n: int, step_spec: Step) -> Trajectory: @@ -227,8 +246,8 @@ def create_sequence_config( end_of_episode_behavior: Determines how sequences at the end of the episode are handled (default `EndOfEpisodeBehavior.TRUNCATE`). See the docstring of `EndOfEpisodeBehavior` for more information. - sequence_pattern: Transformation to obtain a sequence given the length - and the shape of the step. + sequence_pattern: Transformation to obtain a sequence given the length and + the shape of the step. Returns: A list of configs for `StructuredAdder` to produce the described behaviour. @@ -242,14 +261,16 @@ def create_sequence_config( if end_of_episode_behavior == EndBehavior.ZERO_PAD: raise NotImplementedError( - 'Zero-padding is not supported. Please use TRUNCATE instead.') + 'Zero-padding is not supported. Please use TRUNCATE instead.' + ) if end_of_episode_behavior == EndBehavior.CONTINUE: raise NotImplementedError('Merging episodes is not supported.') def _sequence_pattern(n: int) -> sw.Pattern: - return sw.pattern_from_transform(step_spec, - lambda step: sequence_pattern(n, step)) + return sw.pattern_from_transform( + step_spec, lambda step: sequence_pattern(n, step) + ) # The base config is considered for all but the last step in the episode. 
No # trajectories are created for the first `sequence_step-1` steps and then a @@ -260,7 +281,8 @@ def _sequence_pattern(n: int) -> sw.Pattern: conditions=[ sw.Condition.step_index() >= sequence_length - 1, sw.Condition.step_index() % period == (sequence_length - 1) % period, - ]) + ], + ) end_of_episode_configs = [] if end_of_episode_behavior == EndBehavior.WRITE: @@ -275,7 +297,8 @@ def _sequence_pattern(n: int) -> sw.Pattern: conditions=[ sw.Condition.is_end_episode(), sw.Condition.step_index() >= sequence_length - 1, - ]) + ], + ) end_of_episode_configs.append(config) elif end_of_episode_behavior == EndBehavior.TRUNCATE: # The first trajectory is written at step index `sequence_length - 1` and @@ -315,7 +338,8 @@ def _sequence_pattern(n: int) -> sw.Pattern: sw.Condition.is_end_episode(), sw.Condition.step_index() % period == x, sw.Condition.step_index() >= sequence_length, - ]) + ], + ) end_of_episode_configs.append(config) # The above configs will capture the "remainder" of any episode that is at @@ -330,11 +354,13 @@ def _sequence_pattern(n: int) -> sw.Pattern: conditions=[ sw.Condition.is_end_episode(), sw.Condition.step_index() == x - 1, - ]) + ], + ) end_of_episode_configs.append(config) else: raise ValueError( - f'Unexpected `end_of_episod_behavior`: {end_of_episode_behavior}') + f'Unexpected `end_of_episod_behavior`: {end_of_episode_behavior}' + ) return [base_config] + end_of_episode_configs @@ -342,7 +368,8 @@ def _sequence_pattern(n: int) -> sw.Pattern: def create_n_step_transition_config( step_spec: Step, n_step: int, - table: str = reverb_base.DEFAULT_PRIORITY_TABLE) -> List[sw.Config]: + table: str = reverb_base.DEFAULT_PRIORITY_TABLE, +) -> List[sw.Config]: """Generates configs that replicates the behaviour of NStepTransitionAdder. Please see the docstring of NStepTransitionAdder for more details. @@ -370,9 +397,9 @@ def create_n_step_transition_config( def _make_pattern(n: int): ref_step = sw.create_reference_step(step_spec) - get_first = lambda x: x[-(n + 1):-n] - get_all = lambda x: x[-(n + 1):-1] - get_first_and_last = lambda x: x[-(n + 1)::n] + get_first = lambda x: x[-(n + 1) : -n] + get_all = lambda x: x[-(n + 1) : -1] + get_first_and_last = lambda x: x[-(n + 1) :: n] tmap = tree.map_structure @@ -388,7 +415,8 @@ def _make_pattern(n: int): reward=tmap(get_all, ref_step.reward), discount=tmap(get_all, ref_step.discount), start_of_episode=tmap(get_first, ref_step.start_of_episode), - extras=tmap(get_first, ref_step.extras)) + extras=tmap(get_first, ref_step.extras), + ) # At the start of the episodes we'll add shorter transitions. 
start_of_episode_configs = [] @@ -422,3 +450,88 @@ def _make_pattern(n: int): end_of_episode_configs.append(config) return start_of_episode_configs + [base_config] + end_of_episode_configs + + +def n_step_from_trajectory( + trajectory: reverb_base.Trajectory, + agent_discount: float, +) -> types.Transition: + """Converts an (n+1)-step trajectory into an n-step transition.""" + + rewards, discount = _compute_cumulative_quantities( + rewards=trajectory.reward, + discounts=trajectory.discount, + additional_discount=agent_discount, + ) + + tmap = tree.map_structure + return types.Transition( + observation=tmap(lambda x: x[0], trajectory.observation), + action=tmap(lambda x: x[0], trajectory.action), + reward=rewards, + discount=discount, + next_observation=tmap(lambda x: x[-1], trajectory.observation), + extras=tmap(lambda x: x[0], trajectory.extras), + ) + + +def _compute_cumulative_quantities( + rewards: types.NestedArray, + discounts: types.NestedArray, + additional_discount: float, +): + """Stolen from TransitionAdder.""" + + # Give the same tree structure to the n-step return accumulator, + # n-step discount accumulator, and self.discount, so that they can be + # iterated in parallel using tree.map_structure. + rewards, discounts = tree_utils.broadcast_structures(rewards, discounts) + flat_rewards = tree.flatten(rewards) + flat_discounts = tree.flatten(discounts) + n_step = tf.shape(flat_rewards[0])[0] + # Initialize flat output containers. + flat_total_discounts = [] + flat_n_step_returns = [] + + def scan_body( + state: types.NestedTensor, + discount_and_reward: types.NestedTensor, + ) -> types.NestedTensor: + compound_discount, discounted_return = state + discount, reward = discount_and_reward + return ( + additional_discount * discount * compound_discount, + discounted_return + additional_discount * compound_discount * reward, + ) + + for reward, discount in zip(flat_rewards, flat_discounts): + shape = tf.broadcast_static_shape( + tf.TensorShape(reward[0].shape), + tf.TensorShape(discount[0].shape), + ) + total_discount = discount[0] + n_step_return = tf.broadcast_to(reward[0], shape) + + if n_step > 1: + # NOTE: total_discount will have one less additional_discount applied to + # it (compared to flat_discount). This is so that when the learner/update + # uses an additional discount we don't apply it twice. Inside the + # following loop we will apply this right before summing up the + # n_step_return. + total_discount, n_step_return = tf.scan( + scan_body, + (discount[1:], reward[1:]), + (total_discount, n_step_return), + ) + + # Add the last return and discount of the scan, which correspond to the + # n-step return and environment discount. 
+ n_step_return = n_step_return[-1] + total_discount = total_discount[-1] + + flat_n_step_returns.append(n_step_return) + flat_total_discounts.append(total_discount) + + n_step_return = tree.unflatten_as(rewards, flat_n_step_returns) + total_discount = tree.unflatten_as(rewards, flat_total_discounts) + return n_step_return, total_discount diff --git a/acme/adders/reverb/structured_test.py b/acme/adders/reverb/structured_test.py index 761536e138..779c819428 100644 --- a/acme/adders/reverb/structured_test.py +++ b/acme/adders/reverb/structured_test.py @@ -21,7 +21,6 @@ from acme.adders.reverb import structured from acme.adders.reverb import test_cases from acme.adders.reverb import test_utils -from acme.utils import tree_utils import dm_env import numpy as np from reverb import structured_writer as sw @@ -93,32 +92,17 @@ def _maybe_zero_pad(flat_trajectory): signature=sw.infer_signature(configs, step_spec)) @parameterized.named_parameters(*test_cases.TEST_CASES_FOR_TRANSITION_ADDER) - def test_transition_adder(self, n_step: int, additional_discount: float, - first: dm_env.TimeStep, - steps: Sequence[dm_env.TimeStep], - expected_transitions: Sequence[types.Transition]): - + def test_transition_adder( + self, + n_step: int, + additional_discount: float, + first: dm_env.TimeStep, + steps: Sequence[dm_env.TimeStep], + expected_transitions: Sequence[types.Transition], + ): env_spec, extras_spec = test_utils.get_specs(steps[0]) step_spec = structured.create_step_spec(env_spec, extras_spec) - def _as_n_step_transition(flat_trajectory): - trajectory = tree.unflatten_as(step_spec, flat_trajectory) - - rewards, discount = _compute_cumulative_quantities( - rewards=trajectory.reward, - discounts=trajectory.discount, - additional_discount=additional_discount, - n_step=tree.flatten(trajectory.reward)[0].shape[0]) - - tmap = tree.map_structure - return types.Transition( - observation=tmap(lambda x: x[0], trajectory.observation), - action=tmap(lambda x: x[0], trajectory.action), - reward=rewards, - discount=discount, - next_observation=tmap(lambda x: x[-1], trajectory.observation), - extras=tmap(lambda x: x[0], trajectory.extras)) - configs = structured.create_n_step_transition_config( step_spec=step_spec, n_step=n_step) @@ -128,58 +112,19 @@ def _as_n_step_transition(flat_trajectory): configs=configs, step_spec=step_spec) + def n_step_from_trajectory(trajectory: Sequence[types.Transition]): + trajectory = tree.unflatten_as(step_spec, trajectory) + return structured.n_step_from_trajectory(trajectory, additional_discount) + super().run_test_adder( adder=adder, first=first, steps=steps, expected_items=expected_transitions, stack_sequence_fields=False, - item_transform=_as_n_step_transition, - signature=sw.infer_signature(configs, step_spec)) - - -def _compute_cumulative_quantities(rewards: types.NestedArray, - discounts: types.NestedArray, - additional_discount: float, n_step: int): - """Stolen from TransitionAdder.""" - - # Give the same tree structure to the n-step return accumulator, - # n-step discount accumulator, and self.discount, so that they can be - # iterated in parallel using tree.map_structure. - rewards, discounts, self_discount = tree_utils.broadcast_structures( - rewards, discounts, additional_discount) - flat_rewards = tree.flatten(rewards) - flat_discounts = tree.flatten(discounts) - flat_self_discount = tree.flatten(self_discount) - - # Copy total_discount as it is otherwise read-only. 
- total_discount = [np.copy(a[0]) for a in flat_discounts] - - # Broadcast n_step_return to have the broadcasted shape of - # reward * discount. - n_step_return = [ - np.copy(np.broadcast_to(r[0], - np.broadcast(r[0], d).shape)) - for r, d in zip(flat_rewards, total_discount) - ] - - # NOTE: total_discount will have one less self_discount applied to it than - # the value of self._n_step. This is so that when the learner/update uses - # an additional discount we don't apply it twice. Inside the following loop - # we will apply this right before summing up the n_step_return. - for i in range(1, n_step): - for nsr, td, r, d, sd in zip(n_step_return, total_discount, flat_rewards, - flat_discounts, flat_self_discount): - # Equivalent to: `total_discount *= self._discount`. - td *= sd - # Equivalent to: `n_step_return += reward[i] * total_discount`. - nsr += r[i] * td - # Equivalent to: `total_discount *= discount[i]`. - td *= d[i] - - n_step_return = tree.unflatten_as(rewards, n_step_return) - total_discount = tree.unflatten_as(rewards, total_discount) - return n_step_return, total_discount + item_transform=n_step_from_trajectory, + signature=sw.infer_signature(configs, step_spec), + ) if __name__ == '__main__': diff --git a/acme/agents/jax/d4pg/builder.py b/acme/agents/jax/d4pg/builder.py index feed96283d..8fe07d19b7 100644 --- a/acme/agents/jax/d4pg/builder.py +++ b/acme/agents/jax/d4pg/builder.py @@ -19,7 +19,6 @@ from acme import adders from acme import core from acme import specs -from acme import types from acme.adders import reverb as adders_reverb from acme.adders.reverb import base as reverb_base from acme.agents.jax import actor_core as actor_core_lib @@ -39,8 +38,6 @@ import reverb from reverb import rate_limiters from reverb import structured_writer as sw -import tensorflow as tf -import tree def _make_adder_config(step_spec: reverb_base.Step, n_step: int, @@ -49,76 +46,6 @@ def _make_adder_config(step_spec: reverb_base.Step, n_step: int, step_spec=step_spec, n_step=n_step, table=table) -def _as_n_step_transition(flat_trajectory: reverb.ReplaySample, - agent_discount: float) -> reverb.ReplaySample: - """Compute discounted return and total discount for N-step transitions. - - For N greater than 1, transitions are of the form: - - (s_t, a_t, r_{t:t+n}, r_{t:t+n}, s_{t+N}, e_t), - - where: - - s_t = State (observation) at time t. - a_t = Action taken from state s_t. - g = the additional discount, used by the agent to discount future returns. - r_{t:t+n} = A vector of N-step rewards: [r_t r_{t+1} ... r_{t+n}] - d_{t:t+n} = A vector of N-step environment: [d_t d_{t+1} ... d_{t+n}] - For most environments d_i is 1 for all steps except the last, - i.e. it is the episode termination signal. - s_{t+n}: The "arrival" state, i.e. the state at time t+n. - e_t [Optional]: A nested structure of any 'extras' the user wishes to add. - - As such postprocessing is necessary to calculate the N-Step discounted return - and the total discount as follows: - - (s_t, a_t, R_{t:t+n}, D_{t:t+n}, s_{t+N}, e_t), - - where: - - R_{t:t+n} = N-step discounted return, i.e. accumulated over N rewards: - R_{t:t+n} := r_t + g * d_t * r_{t+1} + ... - + g^{n-1} * d_t * ... * d_{t+n-2} * r_{t+n-1}. - D_{t:t+n}: N-step product of agent discounts g_i and environment - "discounts" d_i. - D_{t:t+n} := g^{n-1} * d_{t} * ... * d_{t+n-1}, - - Args: - flat_trajectory: An trajectory with n-step rewards and discounts to be - process. - agent_discount: An additional discount factor used by the agent to discount - futrue returns. 
- - Returns: - A reverb.ReplaySample with computed discounted return and total discount. - """ - trajectory = flat_trajectory.data - - def compute_discount_and_reward( - state: types.NestedTensor, - discount_and_reward: types.NestedTensor) -> types.NestedTensor: - compounded_discount, discounted_reward = state - return (agent_discount * discount_and_reward[0] * compounded_discount, - discounted_reward + discount_and_reward[1] * compounded_discount) - - initializer = (tf.constant(1, dtype=tf.float32), - tf.constant(0, dtype=tf.float32)) - elems = tf.stack((trajectory.discount, trajectory.reward), axis=-1) - total_discount, n_step_return = tf.scan( - compute_discount_and_reward, elems, initializer, reverse=True) - return reverb.ReplaySample( - info=flat_trajectory.info, - data=types.Transition( - observation=tree.map_structure(lambda x: x[0], - trajectory.observation), - action=tree.map_structure(lambda x: x[0], trajectory.action), - reward=n_step_return[0], - discount=total_discount[0], - next_observation=tree.map_structure(lambda x: x[-1], - trajectory.observation), - extras=tree.map_structure(lambda x: x[0], trajectory.extras))) - - class D4PGBuilder(builders.ActorLearnerBuilder[d4pg_networks.D4PGNetworks, actor_core_lib.ActorCore, reverb.ReplaySample]): @@ -214,8 +141,14 @@ def make_dataset_iterator( """Create a dataset iterator to use for learning/updating the agent.""" def postprocess( - flat_trajectory: reverb.ReplaySample) -> reverb.ReplaySample: - return _as_n_step_transition(flat_trajectory, self._config.discount) + flat_trajectory: reverb.ReplaySample, + ) -> reverb.ReplaySample: + return reverb.ReplaySample( + info=flat_trajectory.info, + data=adders_reverb.n_step_from_trajectory( + flat_trajectory.data, self._config.discount + ), + ) batch_size_per_device = self._config.batch_size // jax.device_count()
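Reviewer note (not part of the patch): the new `n_step_from_trajectory` helper reproduces the cumulative quantities D4PG previously computed in its own dataset pipeline. Given the n rewards r_t..r_{t+n-1} and environment discounts d_t..d_{t+n-1} taken from an (n+1)-step trajectory, it returns the discounted n-step return R = r_t + g*d_t*r_{t+1} + ... + g^{n-1}*d_t*...*d_{t+n-2}*r_{t+n-1} and the total discount D = g^{n-1}*d_t*...*d_{t+n-1}, where g is `agent_discount`. Below is a minimal NumPy sketch of that recurrence for a single unbatched, scalar-reward leaf of the reward/discount tree; the helper name is illustrative only, and the actual implementation above runs the same recurrence with `tf.scan` over the flattened reward/discount structures.

import numpy as np

def n_step_return_and_discount(rewards, discounts, agent_discount):
  """Reference recurrence for the n-step return and total discount."""
  # Initialize with the first reward/discount, matching the tf.scan
  # initializer in _compute_cumulative_quantities.
  total_discount = discounts[0]
  n_step_return = rewards[0]
  for d, r in zip(discounts[1:], rewards[1:]):
    # Accumulate the reward using the *previous* compound discount, then
    # fold the current environment discount into the compound discount.
    n_step_return = n_step_return + agent_discount * total_discount * r
    total_discount = total_discount * agent_discount * d
  return n_step_return, total_discount

# Worked example: n = 3, agent_discount = 0.99, no terminations.
rewards = np.array([1.0, 2.0, 3.0])
discounts = np.array([1.0, 1.0, 1.0])
ret, disc = n_step_return_and_discount(rewards, discounts, 0.99)
# ret == 1.0 + 0.99 * 2.0 + 0.99**2 * 3.0  (= 5.9203)
# disc == 0.99**2                           (= 0.9801)

Note that `total_discount` deliberately carries one fewer factor of `agent_discount` than environment discounts, so the learner can apply its own discount once when bootstrapping without double counting.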