Merge branch 'main' into flax
Wotaker authored Feb 9, 2024
2 parents 8f4a629 + b0eee91 commit 30bae91
Showing 11 changed files with 210 additions and 79 deletions.
29 changes: 3 additions & 26 deletions docs/source/custom_loggers.rst
@@ -48,22 +48,11 @@ parameter. For instance:
rl = RLib(
...
logger_types=[PlotsLogger, CsvLogger, TensorboardLogger],
logger_sources='cumulative'
logger_sources=[('action', SourceType.METRIC), ('cumulative', SourceType.METRIC)]
)
In this example, three loggers (``PlotsLogger``, ``CsvLogger``, and ``TensorboardLogger``) are used, each logging
the cumulative reward.

It is also possible to mix different types of loggers. To do this, the user should specify a list of sources
of the same length as the list of loggers. Each source will be logged by the corresponding logger. For example:

.. code-block:: python
rl = RLib(
...
logger_types=[PlotsLogger, CsvLogger, TensorboardLogger],
logger_sources=[('action', SourceType.METRIC), 'cumulative', 'Q']
)
actions and cumulative rewards.
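For reference, a fuller sketch of this pattern is given below (assuming ``SourceType`` and the loggers can be imported from ``reinforced_lib.logs``, as elsewhere in this commit; the remaining constructor arguments are elided just like in the snippets above):

.. code-block:: python

    from reinforced_lib import RLib
    from reinforced_lib.logs import PlotsLogger, CsvLogger, TensorboardLogger, SourceType

    rl = RLib(
        ...,  # agent and extension configuration as in your setup
        logger_types=[PlotsLogger, CsvLogger, TensorboardLogger],
        logger_sources=[('action', SourceType.METRIC), 'cumulative', ('reward', SourceType.METRIC)],
    )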

Users are not restricted to predefined sources and can log any value using the ``log`` method of the ``RLib`` class.
The ``log`` method takes two parameters: ``name`` and ``value``. The example below shows how to log a value from a
@@ -81,19 +70,7 @@ custom source:
rl.log('Epoch len', epoch_len)
Note that the ``log`` method does not take the ``SourceType`` parameter. In the provided example, all loggers
will log all the custom values passed to the ``log`` method. To log both predefined and custom values, set the
``logger_source`` as ``None`` for the desired logger, like this:

.. code-block:: python
rl = RLib(
...
logger_types=[StdoutLogger, PlotsLogger],
logger_sources=[None, 'cumulative']
)
In this example, the ``StdoutLogger`` will log all values passed to the ``log`` method, while the ``PlotsLogger``
will log only the cumulative reward.
will log all the custom values passed to the ``log`` method.
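As a rough end-to-end illustration of the ``log`` method described above (``run_one_epoch`` is a placeholder for the user's own training code, not part of the library):

.. code-block:: python

    rl = RLib(
        ...,  # constructor arguments as in the examples above
        logger_types=[StdoutLogger, CsvLogger],
        logger_sources=['cumulative', ('reward', SourceType.METRIC)],
    )

    for epoch in range(100):
        epoch_len = run_one_epoch()      # placeholder for the user's training loop
        rl.log('Epoch len', epoch_len)   # custom value, logged by every configured logger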

Loggers can be used to log values of different types. The base interface of loggers provides the following methods:

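The method list above is truncated in this diff view; as a rough sketch based on the logger implementations appearing later in this commit (``log_scalar``, ``log_array``, ``log_dict``, ``log_other``, and ``finish``; the exact base-class requirements are not shown here), a minimal custom logger could look like:

.. code-block:: python

    from reinforced_lib.logs import BaseLogger, Source

    class PrintLogger(BaseLogger):
        def log_scalar(self, source: Source, value, *_) -> None:
            print(f'{self.source_to_name(source)} = {value}')

        def log_array(self, source: Source, value, *_) -> None:
            print(f'{self.source_to_name(source)} = {list(value)}')

        def log_dict(self, source: Source, value: dict, *_) -> None:
            print(f'{self.source_to_name(source)} = {value}')

        def log_other(self, source: Source, value, *_) -> None:
            print(f'{self.source_to_name(source)} = {value}')

        def finish(self) -> None:
            pass  # nothing to flush for stdout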
38 changes: 24 additions & 14 deletions docs/source/getting_started.rst
@@ -146,6 +146,28 @@ allows you to monitor the training process in real time, create interactive visu
analysis. You can use the ``TensorboardLogger`` along with other loggers built into Reinforced-lib. To learn more about
available loggers, check out the :ref:`Logging module <logging_page>` section.

.. warning::

Some loggers perform actions upon completion of the training, such as saving the logs, closing the file, or
uploading the logs to the cloud. Therefore, it is important to gracefully close the Reinforced-lib instance
to ensure that the logs are saved properly. If you create a Reinforced-lib instance in a function, the destructor
will be called automatically when the function ends and you do not have to worry about anything. However, if
you create an instance in the main script, you have to close it manually by calling the ``finish`` method:

.. code-block:: python
rl = RLib(...)
# ...
rl.finish()
or by using the ``del`` statement:

.. code-block:: python
rl = RLib(...)
# ...
del rl
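One way to make the manual cleanup robust is to wrap the training code in ``try``/``finally`` so that ``finish`` runs even if an exception is raised (a plain-Python sketch, not a pattern prescribed by the library):

.. code-block:: python

    rl = RLib(...)

    try:
        ...  # training loop
    finally:
        rl.finish()  # loggers can save, close, or upload their logs even on error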
Advanced logging
~~~~~~~~~~~~~~~~

@@ -166,17 +188,7 @@ In the example above, it can be seen that ``action`` is both the name of the obs
you have to write the source name as a tuple containing a name and the type of the source ``(str, SourceType)``
as in the code above.

You can also plug multiple loggers to one source:

.. code-block:: python
rl = RLib(
...
logger_types=[StdoutLogger, CsvLogger, PlotsLogger],
logger_sources='cumulative'
)
Or mix different loggers and sources:
You can also plug multiple loggers to output the logs to different destinations simultaneously:

.. code-block:: python
@@ -186,8 +198,6 @@ Or mix different loggers and sources:
logger_sources=['terminal', 'epsilon', ('action', SourceType.METRIC)],
)
In this case remember to provide a list of loggers that has the same length as the list of sources, because given
loggers will be used to log values for consecutive sources.
Custom logging
~~~~~~~~~~~~~~
@@ -213,7 +223,7 @@ It is possible to mix predefined sources with custom ones:
rl = RLib(
...
logger_types=[TensorboardLogger, PlotsLogger, StdoutLogger],
logger_sources=[None, None, ('reward', SourceType.METRIC)]
logger_sources=('reward', SourceType.METRIC)
)
rl.log('my_value', 42)
10 changes: 10 additions & 0 deletions docs/source/logging.rst
@@ -63,3 +63,13 @@ TensorboardLogger
.. autoclass:: TensorboardLogger
:show-inheritance:
:members:


WeightsAndBiasesLogger
----------------------

.. currentmodule:: reinforced_lib.logs.wandb_logger

.. autoclass:: WeightsAndBiasesLogger
:show-inheritance:
:members:
4 changes: 2 additions & 2 deletions examples/cart-pole/main.py
@@ -10,7 +10,7 @@
from reinforced_lib import RLib
from reinforced_lib.agents.deep import DQN
from reinforced_lib.exts import Gymnasium
from reinforced_lib.logs import StdoutLogger, TensorboardLogger
from reinforced_lib.logs import StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger


class QNetwork(nn.Module):
@@ -45,7 +45,7 @@ def run(num_epochs: int, render_every: int, seed: int) -> None:
},
ext_type=Gymnasium,
ext_params={'env_id': 'CartPole-v1'},
logger_types=[StdoutLogger, TensorboardLogger]
logger_types=[StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger]
)

for epoch in range(num_epochs):
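As a hedged sketch of how such a script might be closed out with the new logger (the ``wandb login`` prerequisite comes from the W&B documentation referenced in ``wandb_logger.py``; agent, network, and extension parameters are elided as in the diff above):

.. code-block:: python

    # prerequisite: authenticate with Weights & Biases once, e.g. `wandb login`

    rl = RLib(
        ...,  # agent, extension, and network parameters as in examples/cart-pole/main.py
        logger_types=[StdoutLogger, TensorboardLogger, WeightsAndBiasesLogger],
    )

    # ... training loop ...

    rl.finish()  # lets the loggers run their completion actions (WeightsAndBiasesLogger.finish calls wandb.finish)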
1 change: 1 addition & 0 deletions pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
"matplotlib~=3.8.2",
"optax~=0.1.8",
"tensorboardX~=2.6.2.2"
"wandb~=0.16.2"
]

[project.optional-dependencies]
1 change: 1 addition & 0 deletions reinforced_lib/logs/__init__.py
@@ -3,3 +3,4 @@
from reinforced_lib.logs.plots_logger import PlotsLogger
from reinforced_lib.logs.stdout_logger import StdoutLogger
from reinforced_lib.logs.tb_logger import TensorboardLogger
from reinforced_lib.logs.wandb_logger import WeightsAndBiasesLogger
9 changes: 3 additions & 6 deletions reinforced_lib/logs/logs_observer.py
@@ -58,9 +58,8 @@ def add_logger(self, source: Source, logger_type: type, logger_params: dict[str,
self._observations_loggers[logger].append((source, source))
self._agent_state_loggers[logger].append((source, source))
self._metrics_loggers[logger].append((source, source))
elif source is None:
self._custom_loggers[logger] = [(None, None)]

self._custom_loggers[logger] = [(None, None)]
self._logger_sources[logger].append(source)
self._logger_instances[logger_type] = logger

@@ -143,19 +142,17 @@ def _update(loggers: dict[BaseLogger, list[Source]], get_value: Callable) -> Non
Parameters
----------
loggers : dict
Dictionary with the loggers instances and the connected sources.
Dictionary with the logger instances and the connected sources.
get_value : callable
Function that gets the selected value from the observations, state, or metrics.
"""

for logger, sources in loggers.items():
for source, name in sources:
if (value := get_value(name)) is not None:
if name is None:
if (custom := name is None):
source, value = value

custom = name is None

if is_scalar(value):
logger.log_scalar(source, value, custom)
elif is_dict(value):
6 changes: 3 additions & 3 deletions reinforced_lib/logs/plots_logger.py
@@ -13,7 +13,7 @@
class PlotsLogger(BaseLogger):
r"""
Logger that presents and saves values as matplotlib plots. Offers smoothing of the curve, scatter plots, and
multiple curves in a single chart (while logging arrays). ``PlotsLogger`` is able to synchronize the logged
multiple curves in a single chart (while logging arrays). ``PlotsLogger`` is able to synchronize the logged
values in time. This means that if the same source is logged less often than other sources, the step will be
increased accordingly to maintain the appropriate spacing between the values on the x-axis.
@@ -127,7 +127,7 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None:

def log_array(self, source: Source, value: Array, *_) -> None:
"""
Adds a given array to the plot values.
Log values from an array to the same plot. Creates multiple line plots for each value in the array.
Parameters
----------
@@ -141,7 +141,7 @@ def _log(self, source: Source, value: Numeric) -> None:

def _log(self, source: Source, value: Numeric) -> None:
"""
Adds a given scalar to the plot values.
Adds a given value to the plot.
Parameters
----------
8 changes: 4 additions & 4 deletions reinforced_lib/logs/tb_logger.py
@@ -73,7 +73,7 @@ def log_scalar(self, source: Source, value: Scalar, *_) -> None:

def log_array(self, source: Source, value: Array, *_) -> None:
"""
Adds a given array to the summary writer as a histogram.
Log values from an array to the same plot. Creates multiple line plots for each value in the array.
Parameters
----------
@@ -85,7 +85,7 @@ def log_array(self, source: Source, value: Array, *_) -> None:

name = self.source_to_name(source)
step = self._get_step(name)
self._writer.add_histogram(name, value, step)
self._writer.add_scalars(name, {str(i): float(v) for i, v in enumerate(value)}, step)

def log_dict(self, source: Source, value: dict, *_) -> None:
"""
Expand All @@ -109,8 +109,8 @@ def log_other(self, source: Source, value: any, *_) -> None:
----------
source : Source
Source of the logged value.
value : dict
Dictionary to log.
value : any
Value of any type to log.
"""

name = self.source_to_name(source)
Expand Down
151 changes: 151 additions & 0 deletions reinforced_lib/logs/wandb_logger.py
@@ -0,0 +1,151 @@
from collections import defaultdict

import wandb
from chex import Array, Scalar

from reinforced_lib.logs import BaseLogger, Source


class WeightsAndBiasesLogger(BaseLogger):
r"""
Logger that saves values to Weights & Biases [4]_ platform. ``WeightsAndBiasesLogger`` synchronizes
the logged values in time. This means that if the same source is logged less often than other sources,
the step will be increased accordingly to maintain the appropriate spacing between the values on the x-axis.
**Note**: to use this logger, you need to log into W&B before running the script. The necessary steps are
described in the official documentation [4]_.
Parameters
----------
wandb_sync_steps : bool, default=False
Set to ``True`` if you want to synchronize the logged values in time.
wandb_kwargs : dict, optional
Additional keyword arguments passed to ``wandb.init`` function.
References
----------
.. [4] Weights & Biases. https://docs.wandb.ai/
"""

def __init__(
self,
wandb_sync_steps: bool = False,
wandb_kwargs: dict = None,
**kwargs,
) -> None:
super().__init__(**kwargs)

self._sync_steps = wandb_sync_steps
self._current_values = set()
self._step = 0
self._steps = defaultdict(int)

wandb.init(**(wandb_kwargs or {}))
wandb.define_metric('*', step_metric='global_step')

def finish(self) -> None:
"""
Finishes the W&B run.
"""

wandb.finish()

def log_scalar(self, source: Source, value: Scalar, *_) -> None:
"""
Logs a scalar value to the W&B logger.
Parameters
----------
source : Source
Source of the logged value.
value : float
Scalar to log.
"""

self._log(source, value)

def log_array(self, source: Source, value: Array, *_) -> None:
"""
Logs an array to the W&B logger.
Parameters
----------
source : Source
Source of the logged value.
value : Array
Array to log.
"""

self._log(source, value)

def log_dict(self, source: Source, value: dict, *_) -> None:
"""
Logs a dictionary to the W&B logger.
Parameters
----------
source : Source
Source of the logged value.
value : dict
Dictionary to log.
"""

self._log(source, value)

def log_other(self, source: Source, value: any, *_) -> None:
"""
Logs an object to the W&B logger.
Parameters
----------
source : Source
Source of the logged value.
value : any
Value of any type to log.
"""

self._log(source, value)

def _log(self, source: Source, value: any) -> None:
"""
Adds a given value to the logger.
Parameters
----------
source : Source
Source of the logged value.
value : Numeric
Value to log.
"""

name = self.source_to_name(source)
step = self._get_step(name)
wandb.log({'global_step': step, name: value})

def _get_step(self, name: str) -> int:
"""
Returns the current step for a given source.
Parameters
----------
name : str
Name of the source.
Returns
-------
int
Current step for the given source.
"""

if self._sync_steps:
if name in self._current_values:
self._step += 1
self._current_values.clear()

self._current_values.add(name)
step = self._step
else:
step = self._steps[name] + 1

self._steps[name] = step
return step
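A minimal standalone sketch of the new logger, purely to illustrate its constructor arguments (this bypasses ``RLib``; it assumes ``BaseLogger`` needs no extra constructor arguments, ``'my-project'`` is a placeholder, and ``project`` is a standard ``wandb.init`` keyword; logging into W&B beforehand is required, as the docstring notes):

.. code-block:: python

    from reinforced_lib.logs import WeightsAndBiasesLogger

    logger = WeightsAndBiasesLogger(
        wandb_sync_steps=True,                   # keep steps of different sources aligned
        wandb_kwargs={'project': 'my-project'},  # forwarded to wandb.init
    )

    for reward in [0.1, 0.4, 0.9]:
        logger.log_scalar('cumulative', reward)  # logged against the shared 'global_step'

    logger.finish()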