Workaround for serialization of complex modules
m-wojnar committed Feb 8, 2024
1 parent d0063ca commit 585cf91
Showing 1 changed file with 59 additions and 3 deletions.

docs/source/getting_started.rst
@@ -91,7 +91,7 @@ significantly simplified. Below, we present the basic training loop with the sim
After the necessary imports, we create an instance of the ``RLib`` class. We provide the chosen
agent type and the appropriate extension for the problem. This extension helps the agent infer the necessary
information from the environment. Next, we create a gymnasium environment and define the training loop. Inside the loop,
we call the ``sample`` method, which passes the observations to the agent, updates the agent's internal state,
and returns an action proposed by the agent's policy. We apply the returned action in the environment to obtain its
altered state. We encourage you to see the :ref:`API <api_page>` section for more details about the ``RLib`` class.
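
For reference, below is a minimal sketch of such a loop. It only illustrates the flow described above:
the import paths, the ``ext_params`` contents, and the exact ``sample`` signature are assumptions and may
differ from the actual API.

.. code-block:: python

    import gymnasium as gym
    import optax

    from reinforced_lib import RLib
    from reinforced_lib.agents.deep import DQN  # assumed import path
    from reinforced_lib.exts import Gymnasium   # assumed import path

    # QNetwork is a Flax module, defined as in the later examples
    rl = RLib(
        agent_type=DQN,
        agent_params={'q_network': QNetwork(), 'optimizer': optax.adam(1e-3)},
        ext_type=Gymnasium,
        ext_params={'env_id': 'CartPole-v1'},   # assumed parameter name
    )

    env = gym.make('CartPole-v1')
    env_state = env.reset()
    terminated, truncated = False, False

    while not (terminated or truncated):
        # sample() passes the observations to the agent, updates its
        # internal state, and returns an action from the agent's policy
        action = rl.sample(*env_state)          # assumed signature
        env_state = env.step(action)
        terminated, truncated = env_state[2], env_state[3]
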
@@ -258,6 +258,62 @@ with the training, we load the whole experiment to a new RLib instance.
    # Continue the training
    # ...
Reinforced-lib can even save the architecture of your agent's neural network. This is possible thanks to the
`cloudpickle <https://github.com/cloudpipe/cloudpickle>`_ library, which can serialize Flax modules.
However, if you use your own implementation of agents or extensions, you have to ensure that they are available
when you restore the experiment, as Reinforced-lib does not save the source code of custom classes.
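
To illustrate what this serialization enables, here is a short sketch (not part of the original guide)
that round-trips a Flax module with cloudpickle. When the class is defined in the running script,
cloudpickle serializes it by value, so the architecture can travel with the checkpoint:

.. code-block:: python

    import cloudpickle
    import flax.linen as nn

    class QNetwork(nn.Module):
        @nn.compact
        def __call__(self, x):
            return nn.Dense(2)(x)

    payload = cloudpickle.dumps(QNetwork())  # class and architecture serialized by value
    restored = cloudpickle.loads(payload)    # usable even where QNetwork is not importable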

.. note::

    Remember that the ``RLib`` class does not save the state of the environment. If you want to resume
    training from the exact point where you stopped, you have to save the environment state separately.
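
One possible way to do that, assuming your environment instance is picklable (not every Gymnasium
environment guarantees this), is a plain pickle round-trip:

.. code-block:: python

    import pickle

    # Save the environment alongside the RLib checkpoint
    with open('env.pkl', 'wb') as f:
        pickle.dump(env, f)

    # Restore it later to resume from the same point
    with open('env.pkl', 'rb') as f:
        env = pickle.load(f)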

.. warning::

    As of today (2024-02-08), cloudpickle does not support the serialization of custom modules defined
    outside of the main module definition. This means that if you implement part of your model in a
    separate class, you will not be able to restore the experiment. We are working on a solution to this
    problem.

A temporary workaround is to define the whole model in a single class, as follows:

.. code-block:: python

    class QNetwork(nn.Module):
        @nn.compact
        def __call__(self, x):
            class MyModule(nn.Module):
                @nn.compact
                def __call__(self, x):
                    ...
                    return x

            x = MyModule()(x)
            ...
            return x
To improve code readability, you can also define modules in external functions and then call these
functions inside the main class to include the custom module definitions. For example:

.. code-block:: python

    def my_module_fn():
        class MyModule(nn.Module):
            @nn.compact
            def __call__(self, x):
                ...
                return x

        return MyModule


    class QNetwork(nn.Module):
        @nn.compact
        def __call__(self, x):
            MyModule = my_module_fn()
            x = MyModule()(x)
            ...
            return x
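
In both patterns, the custom module class is created from inside ``QNetwork`` at call time rather than
referenced as a separately defined top-level class, so the complete architecture ends up in a single
serializable object.
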
Dynamic parameter change
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -276,7 +332,7 @@ optimizer:
    rl = RLib(
        agent_type=DQN,
        agent_params={
            'q_network': QNetwork(),
            'optimizer': optax.adam(1e-3),
        },
        ext_type=Gymnasium,
@@ -293,7 +349,7 @@
    rl = RLib.load(
        "<checkpoint-path>",
        agent_params={
            'q_network': QNetwork(),
            'optimizer': optax.adam(1e-4),
        }
    )
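
Note that only the optimizer changes here: the agent was created with ``optax.adam(1e-3)`` and restored
with ``optax.adam(1e-4)``, while the rest of the experiment state is loaded from the checkpoint unchanged.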
