diff --git a/docs/source/custom_agents.rst b/docs/source/custom_agents.rst
index f7394f1..987c017 100644
--- a/docs/source/custom_agents.rst
+++ b/docs/source/custom_agents.rst
@@ -7,12 +7,12 @@ Although our library provides a palette of already implemented :ref:`agents `
 ``BaseAgent``. We present adding a
-custom agent on a simple example of the epsilon-greedy agent:
+custom agent on an example of a simple epsilon-greedy agent:
 
 .. code-block:: python
@@ -20,11 +20,10 @@ custom agent on a simple example of the epsilon-greedy agent:
 Firstly, we need to define the state of our agent, which in our case will hold
 
- * constant experiment rate (e),
- * quality values of each arm (Q),
- * number of each arms' tries (N),
+ * quality values of each arm (``Q``),
+ * number of each arm's tries (``N``),
 
-and will inherit from AgentState:
+and will inherit from ``AgentState``:
 
 .. code-block:: python
@@ -34,7 +33,13 @@ and will inherit from AgentState:
         Q: Array
         N: Array
 
-Next, we can define the Epsilon-greedy agent, which will have 3 static methods:
+The ``BaseAgent`` interface breaks the agent's behaviour into three methods:
+
+ * ``init(PRNGKey, ...) -> AgentState`` - initializes the agent's state,
+ * ``update(AgentState, PRNGKey, ...) -> AgentState`` - updates the agent's state after performing some action and receiving a reward,
+ * ``sample(AgentState, PRNGKey, ...) -> Action`` - samples a new action according to the agent's and the environment's state.
+
+We define the epsilon-greedy agent, which will have 3 static methods:
 
 .. code-block:: python
@@ -54,7 +59,7 @@ Next, we can define the Epsilon-greedy agent, which will have 3 static methods:
             N=jnp.ones(n_arms, dtype=jnp.int32)
         )
 
-    # This method updates the agents state after performing some action and receiving a reward
+    # This method updates the agent's state
     @staticmethod
     def update(
         state: EGreedyState,
@@ -78,25 +83,25 @@ Next, we can define the Epsilon-greedy agent, which will have 3 static methods:
         state: EGreedyState,
         key: PRNGKey,
         e: Scalar
-    ) -> Tuple[EGreedyState, jnp.int32]:
+    ) -> jnp.int32:
+
+        # Split the PRNGKey to use it twice
+        epsilon_key, choice_key = jax.random.split(key)
 
         # We further want to jax.jit this function, so basic 'if' is not allowed here
         return jax.lax.cond(
-            # Split PRNGkey to use it twice
-            epsilon_key, choice_key = jax.random.split(key)
-
             # The agent experiments with probability e
             jax.random.uniform(epsilon_key) < e,
 
             # On exploration, agent chooses a random arm
-            lambda: (state, jax.random.choice(choice_key, state.Q.size)),
+            lambda: jax.random.choice(choice_key, state.Q.size),
 
             # On exploitation, agent chooses the best known arm
-            lambda: (state, jnp.argmax(state.Q))
+            lambda: jnp.argmax(state.Q)
         )
 
-Having defined these static methods, we can define the class constructor:
+Having defined these static methods, we can implement the class constructor:
 
 .. code-block:: python
@@ -112,15 +117,20 @@ Having defined these static methods, we can define the class constructor:
 
         # We specify the features of our agent
         self.n_arms = n_arms
 
-        # Here, we can use the jax.jit() functionality with the previously
-        # defined behaviour functions, to make the agent super fast
+        # Here we can use the jax.jit() functionality with the previously
+        # defined behaviour functions, to make the agent super fast.
+        # Note that we use partial() to specify the parameters that are
+        # constant during the agent's lifetime to avoid passing them
+        # every time the function is called.
         self.init = jax.jit(partial(self.init, n_arms=self.n_arms))
-        self.update = jax.jit(partial(self.update))
+        self.update = jax.jit(self.update)
         self.sample = jax.jit(partial(self.sample, e=e))
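+
+To see the effect of combining ``partial`` with ``jax.jit``, consider the following self-contained sketch
+(an illustration only, independent of the agent above): binding the constant argument once means that the
+jitted function no longer has to receive it at every call.
+
+.. code-block:: python
+
+    from functools import partial
+
+    import jax
+    import jax.numpy as jnp
+
+    def toy_init(key, n_arms):
+        # A stand-in for the agent's init() - n_arms only determines the shape of the state
+        return jnp.zeros(n_arms)
+
+    # n_arms is constant for the agent's lifetime, so we bind it once with partial()
+    # and jit the resulting single-argument function
+    jitted_init = jax.jit(partial(toy_init, n_arms=4))
+
+    state = jitted_init(jax.random.PRNGKey(0))  # no need to pass n_arms at call time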
 
-Lastly, we must specify the parameter spaces that each of the implemented methods take.
-This enables the library to automatically infer the necessary parameters from the environment.
-Reinforced-lib uses the `Gymnasium `_ (former OpenAI Gym) format.
+Now we specify the initialization arguments of our agent (i.e., the parameters that are required by the
+agent's constructor). This is done by implementing the static method ``parameter_space()`` which returns
+a dictionary in the format of a `Gymnasium `_ space. It is not required
+to implement this method, but it is a good practice to do so. This enables the library to automatically
+provide initialization arguments specified by :ref:`extensions `.
 
 .. code-block:: python
@@ -132,6 +142,26 @@ Reinforced-lib uses the `Gymnasium `_ (former Ope
             'n_arms': gym.spaces.Box(1, jnp.inf, (1,), jnp.int32),
             'e': gym.spaces.Box(0.0, 1.0, (1,), jnp.float32)
         })
+
+Specifying the action space of the agent is accomplished by implementing the ``action_space`` property.
+While not mandatory, adhering to this practice is recommended as it allows users to conveniently inspect
+the agent's action space through the ``action_space`` method of the ``RLib`` class.
+
+.. code-block:: python
+
+    # Action returned by the agent in Gymnasium format.
+    @property
+    def action_space(self) -> gym.spaces.Space:
+        return gym.spaces.Discrete(self.n_arms)
+
+Finally, we define the observation spaces for our agent by implementing the properties called
+``update_observation_space`` and ``sample_observation_space``. Although not mandatory, we strongly
+encourage their implementation as it allows the library to deduce absent values from raw observations
+and functions defined in the :ref:`extensions `. Moreover, having these properties
+implemented facilitates a seamless export of the agent to the TensorFlow Lite format, where
+the library can automatically generate an example set of parameters during the export procedure.
+
+.. code-block:: python
 
     # Parameters required by the 'update' method in Gymnasium format.
     @property
     def update_observation_space(self) -> gym.spaces.Dict:
@@ -145,16 +175,11 @@ Reinforced-lib uses the `Gymnasium `_ (former Ope
 
     @property
     def sample_observation_space(self) -> gym.spaces.Dict:
         return gym.spaces.Dict({})
-
-    # Action returned by the agent in Gymnasium format.
-    @property
-    def action_space(self) -> gym.spaces.Space:
-        return gym.spaces.Discrete(self.n_arms)
 
 Now we have a ready to operate epsilon-greedy agent!
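+
+As a minimal sketch of how the finished agent could be driven by hand (assuming the class defined above is
+named ``EGreedy`` and that its constructor takes ``n_arms`` and ``e``, as suggested by ``parameter_space()``):
+
+.. code-block:: python
+
+    import jax
+
+    agent = EGreedy(n_arms=4, e=0.1)
+
+    key = jax.random.PRNGKey(42)
+    init_key, update_key, sample_key = jax.random.split(key, 3)
+
+    state = agent.init(init_key)                      # fresh agent state
+    state = agent.update(state, update_key, 1, 1.0)   # arm 1 yielded a reward of 1.0
+    action = agent.sample(state, sample_key)          # choose the next arm to play
+
+In practice you would typically pass the agent to the ``RLib`` class instead of calling it by hand,
+as shown for the DRL agent below.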
 
 
-Template Agent
+Template agent
 --------------
 
 Here is all of the above code in one piece. You can copy-paste it and use as an inspiration
 to create your own agent.
 
 .. code-block:: python
@@ -163,7 +188,6 @@ to create your own agent.
 
     from functools import partial
-    from typing import Tuple
 
     import gymnasium as gym
     import jax
@@ -191,7 +215,7 @@ to create your own agent.
             self.n_arms = n_arms
 
             self.init = jax.jit(partial(self.init, n_arms=n_arms))
-            self.update = jax.jit(partial(self.update))
+            self.update = jax.jit(self.update)
             self.sample = jax.jit(partial(self.sample, e=e))
 
         @staticmethod
@@ -219,12 +243,11 @@ to create your own agent.
 
         @staticmethod
        def init(
             key: PRNGKey,
-            n_arms: jnp.int32,
-            optimistic_start: Scalar
+            n_arms: jnp.int32
         ) -> EGreedyState:
 
             return EGreedyState(
-                Q=(optimistic_start * jnp.ones(n_arms)),
+                Q=jnp.zeros(n_arms),
                 N=jnp.ones(n_arms, dtype=jnp.int32)
             )
@@ -233,8 +256,7 @@ to create your own agent.
             state: EGreedyState,
             key: PRNGKey,
             action: jnp.int32,
-            reward: Scalar,
-            alpha: Scalar
+            reward: Scalar
         ) -> EGreedyState:
 
             return EGreedyState(
@@ -247,28 +269,134 @@ to create your own agent.
             state: EGreedyState,
             key: PRNGKey,
             e: Scalar
-        ) -> Tuple[EGreedyState, jnp.int32]:
+        ) -> jnp.int32:
 
             epsilon_key, choice_key = jax.random.split(key)
 
             return jax.lax.cond(
                 jax.random.uniform(epsilon_key) < e,
-                lambda: (state, jax.random.choice(choice_key, state.Q.size)),
+                lambda: jax.random.choice(choice_key, state.Q.size),
 
-                lambda: (state, jnp.argmax(state.Q))
+                lambda: jnp.argmax(state.Q)
             )
 
 
+Deep reinforcement learning agents
+----------------------------------
-Sum up
-------
 
+Although the above example is simple, it is not hard to extend this approach to deep reinforcement learning (DRL) agents.
+This can be achieved by leveraging the JAX ecosystem, along with the `haiku `_
+library, which provides a convenient way to define neural networks, and `optax `_,
+which provides a set of optimizers. Below, we provide excerpts of the code for the :ref:`deep Q-learning agent
+`.
 
-To sum everything up one more time:
+The state of the DRL agent often contains the parameters and state of the neural network, as well as an experience
+replay buffer:
+
+.. code-block:: python
+
+    @dataclass
+    class QLearningState(AgentState):
+        params: hk.Params
+        state: hk.State
+        opt_state: optax.OptState
+
+        replay_buffer: ReplayBuffer
+        prev_env_state: Array
+        epsilon: Scalar
+
+The agent's constructor allows users to specify the neural network architecture and the optimizer, giving them
+full control over these choices and enhancing the agent's flexibility:
+
+.. code-block:: python
-1. Custom agent inherits from the `BaseAgent`.
-2. We implement the abstract methods *init()*, *update()* and *sample()*.
-3. We use *jax.jit()* to optimize the agent's performance.
-4. We provide the required parameters in the format of *Gymnasium* spaces.
 
+    def __init__(
+        self,
+        q_network: hk.TransformedWithState,
+        optimizer: optax.GradientTransformation = None,
+        ...
+    ) -> None:
+
+        if optimizer is None:
+            optimizer = optax.adam(1e-3)
+
+        self.init = jax.jit(partial(self.init, q_network=q_network, optimizer=optimizer, ...))
+
+        ...
+
+By implementing the constructor in this manner, users gain the flexibility to define their own architecture as follows:
+
+.. code-block:: python
+
+    @hk.transform_with_state
+    def q_network(x: Array) -> Array:
+        return hk.nets.MLP([64, 64, 2])(x)
+
+    rl = RLib(
+        agent_type=QLearning,
+        agent_params={
+            'q_network': q_network,
+            'optimizer': optax.rmsprop(3e-4, decay=0.95, eps=1e-2)
+        },
+        ...
+    )
+
+During the development of a DRL agent, our library offers a set of :ref:`utility functions ` for your convenience.
+Among these functions is ``gradient_step``, designed to streamline parameter updates for the agent using JAX and optax.
+The following snippet shows how to build a step function that performs a single update step,
+taking into account the network, the optimizer, and the implemented loss function:
+
+.. code-block:: python
+
+    from reinforced_lib.utils.jax_utils import gradient_step
+
+    step_fn = partial(
+        gradient_step,
+        optimizer=optimizer,
+        loss_fn=partial(self.loss_fn, q_network=q_network, ...)
+    )
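+
+The ``loss_fn`` passed to ``gradient_step`` above is agent-specific and is not part of this excerpt. The
+following is a purely hypothetical sketch of a DQN-style loss - the argument names, batch layout, and
+discount handling are assumptions made for illustration and do not reflect the actual implementation of the
+deep Q-learning agent:
+
+.. code-block:: python
+
+    import jax
+    import jax.numpy as jnp
+
+    def loss_fn(params, net_state, key, batch, q_network, discount):
+        # Hypothetical signature - the real agent defines its own arguments
+        obs, actions, rewards, terminals, next_obs = batch
+
+        # Q-values of the actions actually taken (haiku's TransformedWithState.apply)
+        q_values, _ = q_network.apply(params, net_state, key, obs)
+        q_taken = jnp.take_along_axis(q_values, actions[:, None], axis=-1).squeeze(-1)
+
+        # Bootstrapped targets computed from the next observations
+        next_q, _ = q_network.apply(params, net_state, key, next_obs)
+        target = rewards + (1.0 - terminals) * discount * jnp.max(next_q, axis=-1)
+
+        # Mean squared TD error
+        return jnp.mean((q_taken - jax.lax.stop_gradient(target)) ** 2)
+
+Treat this sketch only as a starting point and consult the sources of the built-in DRL agents for the
+actual loss and step definitions.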
+
+Our Python library also includes a pre-built :ref:`experience replay buffer `, which is commonly
+utilized in DRL agents. The following code provides an illustrative example of how to use this utility:
+
+.. code-block:: python
+
+    from reinforced_lib.utils.experience_replay import experience_replay, ExperienceReplay, ReplayBuffer
+
+    er = experience_replay(
+        experience_replay_buffer_size,
+        experience_replay_batch_size,
+        obs_space_shape,
+        act_space_shape
+    )
+
+    ...
+
+    replay_buffer = er.init()
+
+    ...
+
+    replay_buffer = er.append(replay_buffer, prev_env_state, action, reward, terminal, env_state)
+    perform_update = er.is_ready(replay_buffer)
+
+    for _ in range(experience_replay_steps):
+        batch = er.sample(replay_buffer, key)
+        ...
+
+Developing a DRL agent may pose challenges, so we strongly recommend thoroughly studying the example code of one of our
+`DRL agents `_ prior to building
+your custom agent.
+
+
+Summary
+-------
+
+To sum everything up one more time:
 
-The built-in implementation of the epsilon-greedy agent with an addition of optional optimistic start and an exponential
-recency-weighted average update can be found
-`here `_.
+1. All agents inherit from the ``BaseAgent`` class.
+2. The agent's state is defined as a dataclass that inherits from the ``AgentState`` class.
+3. The agent's behaviour is determined by implementing the static methods ``init``, ``update``, and ``sample``.
+4. Utilizing ``jax.jit`` can significantly increase the agent's performance.
+5. Although not mandatory, it is highly recommended to implement the ``parameter_space``, ``update_observation_space``,
+   and ``sample_observation_space`` properties.
+6. Implementing a custom deep reinforcement learning agent is possible using the JAX ecosystem and the utility functions
+   provided by the library.
\ No newline at end of file