Merge pull request #41 from m-wojnar/flax

Flax

m-wojnar authored Feb 9, 2024
2 parents b0eee91 + 6b4e8cf commit 4e33f09
Showing 31 changed files with 864 additions and 596 deletions.
17 changes: 10 additions & 7 deletions README.md
@@ -90,25 +90,28 @@ effortlessly using Reinforced-lib.

```python
import gymnasium as gym
import haiku as hk
import optax
from chex import Array
from flax import linen as nn

from reinforced_lib import RLib
from reinforced_lib.agents.deep import QLearning
from reinforced_lib.agents.deep import DQN
from reinforced_lib.exts import Gymnasium


@hk.transform_with_state
def q_network(x: Array) -> Array:
    return hk.nets.MLP([256, 2])(x)
class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: Array) -> Array:
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        return nn.Dense(2)(x)


if __name__ == '__main__':
    rl = RLib(
        agent_type=QLearning,
        agent_type=DQN,
        agent_params={
            'q_network': q_network,
            'q_network': QNetwork(),
            'optimizer': optax.rmsprop(3e-4, decay=0.95, eps=1e-2),
        },
        ext_type=Gymnasium,
38 changes: 15 additions & 23 deletions docs/source/agents.rst
@@ -19,44 +19,44 @@ BaseAgent
:members:


Deep Q-Learning
---------------
Deep Q-Learning (DQN)
---------------------

.. currentmodule:: reinforced_lib.agents.deep.q_learning
.. currentmodule:: reinforced_lib.agents.deep.dqn

.. autoclass:: QLearningState
.. autoclass:: DQNState
:show-inheritance:
:members:

.. autoclass:: QLearning
.. autoclass:: DQN
:show-inheritance:
:members:


Deep Expected SARSA
-------------------
Double Deep Q-Learning (DDQN)
-----------------------------

.. currentmodule:: reinforced_lib.agents.deep.expected_sarsa
.. currentmodule:: reinforced_lib.agents.deep.ddqn

.. autoclass:: ExpectedSarsaState
.. autoclass:: DDQNState
:show-inheritance:
:members:

.. autoclass:: ExpectedSarsa
.. autoclass:: DDQN
:show-inheritance:
:members:


Deep Double Q-Learning (DQN)
----------------------------
Deep Expected SARSA
-------------------

.. currentmodule:: reinforced_lib.agents.deep.dqn
.. currentmodule:: reinforced_lib.agents.deep.expected_sarsa

.. autoclass:: DQNState
.. autoclass:: ExpectedSarsaState
:show-inheritance:
:members:

.. autoclass:: DQN
.. autoclass:: ExpectedSarsa
:show-inheritance:
:members:

@@ -167,11 +167,3 @@ Upper confidence bound (UCB)
.. autoclass:: UCB
:show-inheritance:
:members:


Particle filter (Core)
----------------------

.. automodule:: reinforced_lib.agents.core.particle_filter
:show-inheritance:
:members:
40 changes: 28 additions & 12 deletions docs/source/custom_agents.rst
@@ -284,20 +284,20 @@ Deep learning agents
--------------------

Although the above example is a simple one, it is not hard to extend it to deep reinforcement learning (DRL) agents.
This can be achieved by leveraging the JAX ecosystem, along with the `haiku <https://dm-haiku.readthedocs.io/>`_
This can be achieved by leveraging the JAX ecosystem, along with the `flax <https://flax.readthedocs.io/>`_
library, which provides a convenient way to define neural networks, and `optax <https://optax.readthedocs.io/>`_,
which provides a set of optimizers. Below, we provide excerpts of the code for the :ref:`deep Q-learning agent
<Deep Q-Learning>`.
<Deep Q-Learning (DQN)>`.

The state of the DRL agent often contains parameters and state of the neural network as well as an experience
replay buffer:

.. code-block:: python
@dataclass
class QLearningState(AgentState):
    params: hk.Params
    state: hk.State
class DQNState(AgentState):
    params: dict
    state: dict
    opt_state: optax.OptState
    replay_buffer: ReplayBuffer
@@ -311,7 +311,7 @@ users to have full control over their choice and enhancing the agent's flexibili
def __init__(
    self,
    q_network: hk.TransformedWithState,
    q_network: nn.Module,
    optimizer: optax.GradientTransformation = None,
    ...
) -> None:
@@ -327,21 +327,33 @@ By implementing the constructor in this manner, users gain the flexibility to de

.. code-block:: python
@hk.transform_with_state
def q_network(x: Array) -> Array:
    return hk.nets.MLP([64, 64, 2])(x)
class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        return nn.Dense(2)(x)
rl = RLib(
    agent_type=QLearning,
    agent_type=DQN,
    agent_params={
        'q_network': q_network,
        'q_network': QNetwork(),
        'optimizer': optax.rmsprop(3e-4, decay=0.95, eps=1e-2)
    },
    ...
)
.. note::

In some cases, it is necessary to use a PRNG key in the definition of a neural network to allow for stochastic
behavior in the model. The flax library provides a ``make_rng(stream_name)`` method that can be used to generate
a PRNG key from a given stream. The DRL algorithms implemented in Reinforced-lib offer a stream called ``rlib``
by default, so you can use it in your custom model as follows: ``key = self.make_rng('rlib')``.

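For illustration, here is a minimal sketch of a module that draws a key from this stream (the module name and the
additive noise layer are purely illustrative, not part of the library):

.. code-block:: python
import jax
from chex import Array
from flax import linen as nn
class StochasticQNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: Array) -> Array:
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        # Draw a PRNG key from the 'rlib' stream supplied by the agent
        key = self.make_rng('rlib')
        # Illustrative stochastic behavior: additive exploration noise
        x = x + 0.1 * jax.random.normal(key, x.shape)
        return nn.Dense(2)(x)

If you ever call such a network yourself (outside the agent), remember to supply a key for this stream, e.g.
``model.apply(variables, x, rngs={'rlib': key})``.
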
During the development of a DRL agent, our library offers a set of :ref:`utility functions <JAX>` for your convenience.
Among these functions is gradient_step, designed to streamline parameter updates for the agent using JAX and optax.
Among these functions is ``gradient_step``, designed to streamline parameter updates for the agent using JAX and optax.
In the following example code snippet, we showcase the implementation of a step function responsible for performing
a single step, taking into account the network, optimizer, and the implemented loss function:

@@ -355,6 +367,10 @@ a single step, taking into account the network, optimizer, and the implemented l
    loss_fn=partial(self.loss_fn, q_network=q_network, ...)
)
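
For intuition, ``gradient_step`` essentially wraps the standard JAX and optax update. A generic sketch of such a
step (an illustration of the idea only, with an assumed signature, not the exact one used by Reinforced-lib) could
look like this:

.. code-block:: python
import jax
import optax
def generic_gradient_step(params, opt_state, optimizer, loss_fn, *loss_args):
    # Compute the loss and its gradients with respect to the parameters
    loss, grads = jax.value_and_grad(loss_fn)(params, *loss_args)
    # Transform the gradients and apply the resulting updates
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

In the agent, the network- and task-specific parts are injected through ``loss_fn``, as in the snippet above.
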
There are also other utility functions that can make it easier to implement DRL agents with flax. These are the
``init`` and ``forward`` methods, which are used to initialize the network and to perform a forward pass through
the network. You can find more information about these functions in the :ref:`documentation <Utils>`.
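
For reference, a minimal sketch of what these two operations correspond to in plain flax, without the library
helpers (using the illustrative ``QNetwork`` from above and an assumed observation shape):

.. code-block:: python
import jax.numpy as jnp
from jax import random
model = QNetwork()
key = random.PRNGKey(42)
dummy_obs = jnp.zeros((1, 4))  # assumed observation shape, e.g. cart-pole
# Initialize the network parameters (and any mutable state)
variables = model.init(key, dummy_obs)
# Perform a forward pass through the network
q_values = model.apply(variables, dummy_obs)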

Our Python library also includes a pre-built :ref:`experience replay buffer <Experience Replay>`, which is commonly
utilized in DRL agents. The following code provides an illustrative example of how to use this utility:

10 changes: 5 additions & 5 deletions docs/source/custom_extensions.rst
@@ -14,14 +14,14 @@ Key concepts of extensions
There are three main benefits of using extensions:

#. Automatic initialization of agents - an extension can provide default arguments that can be used to
initialize an agent. For example, if we would like to train the :ref:`deep Double Q-learning agent
<Deep Double Q-Learning (DQN)>` on a `cart-pole` environment without using any extension, we would
initialize an agent. For example, if we would like to train the :ref:`deep Q-learning agent
<Deep Q-Learning (DQN)>` on a `cart-pole` environment without using any extension, we would
probably do it in the following way:

.. code-block:: python
rl = RLib(
    agent_type=QLearning,
    agent_type=DQN,
    agent_params={
        'q_network': q_network,
        'obs_space_shape': (4,),
@@ -36,7 +36,7 @@ There are three main benefits of using extensions:
.. code-block:: python
rl = RLib(
    agent_type=QLearning,
    agent_type=DQN,
    agent_params={'q_network': q_network},
    ext_type=Gymnasium,
    ext_params={'env_id': 'CartPole-v1'},
@@ -47,7 +47,7 @@ There are three main benefits of using extensions:
.. code-block:: python
rl = RLib(
    agent_type=QLearning,
    agent_type=DQN,
    agent_params={
        'q_network': q_network,
        'act_space_size': 3
83 changes: 71 additions & 12 deletions docs/source/getting_started.rst
@@ -49,25 +49,28 @@ significantly simplified. Below, we present the basic training loop with the sim
.. code-block:: python
import gymnasium as gym
import haiku as hk
import optax
from chex import Array
from flax import linen as nn
from reinforced_lib import RLib
from reinforced_lib.agents.deep import QLearning
from reinforced_lib.agents.deep import DQN
from reinforced_lib.exts import Gymnasium
@hk.transform_with_state
def q_network(x: Array) -> Array:
    return hk.nets.MLP([256, 2])(x)
class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: Array) -> Array:
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        return nn.Dense(2)(x)
if __name__ == '__main__':
    rl = RLib(
        agent_type=QLearning,
        agent_type=DQN,
        agent_params={
            'q_network': q_network,
            'q_network': QNetwork(),
            'optimizer': optax.rmsprop(3e-4, decay=0.95, eps=1e-2),
        },
        ext_type=Gymnasium,
@@ -88,7 +91,7 @@ significantly simplified. Below, we present the basic training loop with the sim
After the necessary imports, we create an instance of the ``RLib`` class. We provide the chosen
agent type and the appropriate extension for the problem. This extension will help the agent to infer necessary
information from the environment. Next create a Gymnasium environment and define the training loop. Inside the loop,
information from the environment. Next, create a gymnasium environment and define the training loop. Inside the loop,
we call the ``sample`` method which passes the observations to the agent, updates the agent's internal state
and returns an action proposed by the agent's policy. We apply the returned action in the environment to get its
altered state. We encourage you to see the :ref:`API <api_page>` section for more details about the ``RLib`` class.
@@ -265,6 +268,62 @@ with the training, we load the whole experiment to a new RLib instance.
# Continue the training
# ...
Reinforced-lib can even save the architecture of your agent's neural network. This is possible thanks to the
`cloudpickle <https://github.com/cloudpipe/cloudpickle>`_ library, which allows flax modules to be serialized.
However, if you use your own implementation of agents or extensions, you have to ensure that they are available
when you restore the experiment, as Reinforced-lib does not save the source code of custom classes.

.. note::

Remember that the ``RLib`` class will not save the state of the environment. You have to save the environment
state separately if you want to continue the training from the exact point where you ended.

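For example, a minimal sketch of one way to do this for environments whose internals are picklable (a generic
workaround, not a Reinforced-lib feature; ``env`` is the gymnasium environment from your training loop):

.. code-block:: python
import pickle
# Save the gymnasium environment alongside the RLib checkpoint
with open('env.pkl', 'wb') as f:
    pickle.dump(env, f)
# ... later, restore it and continue from the same point
with open('env.pkl', 'rb') as f:
    env = pickle.load(f)
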
.. warning::

As of today (2024-02-08), cloudpickle does not support serialization of custom modules defined outside the
main module definition. This means that if you implement part of your model in a separate class, you will not
be able to restore the experiment. We are working on a solution to this problem.

A temporary workaround is to define the whole model in one class, as follows:

.. code-block:: python
class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x):
        class MyModule(nn.Module):
            @nn.compact
            def __call__(self, x):
                ...
                return x
        x = MyModule()(x)
        ...
        return x
To improve code readability, you can also wrap custom module definitions in external functions and call those
functions from the main class. For example:

.. code-block:: python
def my_module_fn():
    class MyModule(nn.Module):
        @nn.compact
        def __call__(self, x):
            ...
            return x
    return MyModule
class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x):
        MyModule = my_module_fn()
        x = MyModule()(x)
        ...
        return x
Dynamic parameter change
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -276,14 +335,14 @@ optimizer:
.. code-block:: python
from reinforced_lib import RLib
from reinforced_lib.agents.deep import QLearning
from reinforced_lib.agents.deep import DQN
from reinforced_lib.exts import Gymnasium
# Setting up the experiment
rl = RLib(
    agent_type=QLearning,
    agent_type=DQN,
    agent_params={
        'q_network': q_network,
        'q_network': QNetwork(),
        'optimizer': optax.adam(1e-3),
    },
    ext_type=Gymnasium,
@@ -300,7 +359,7 @@ optimizer:
rl = RLib.load(
    "<checkpoint-path>",
    agent_params={
        'q_network': q_network,
        'q_network': QNetwork(),
        'optimizer': optax.adam(1e-4),
    }
)
7 changes: 7 additions & 0 deletions docs/source/utils.rst
@@ -16,3 +16,10 @@ Experience Replay

.. automodule:: reinforced_lib.utils.experience_replay
:members:

Particle filter (Core)
----------------------

.. automodule:: reinforced_lib.utils.particle_filter
:show-inheritance:
:members: