diff --git a/reinforced_lib/agents/base_agent.py b/reinforced_lib/agents/base_agent.py index 577fc97..397fba3 100644 --- a/reinforced_lib/agents/base_agent.py +++ b/reinforced_lib/agents/base_agent.py @@ -44,7 +44,7 @@ def update(state: AgentState, key: PRNGKey, *args, **kwargs) -> AgentState: @abstractmethod def sample(state: AgentState, key: PRNGKey, *args, **kwargs) -> any: """ - Selects the next action based on the current agent state. + Selects the next action based on the current environment and agent state. """ pass @@ -52,9 +52,8 @@ def sample(state: AgentState, key: PRNGKey, *args, **kwargs) -> any: @staticmethod def parameter_space() -> gym.spaces.Dict: """ - Parameter space of the agent constructor in Gymnasium format. - Type of returned value is required to be ``gym.spaces.Dict`` or ``None``. - If ``None``, the user must provide all parameters manually. + Parameters of the agent constructor in Gymnasium format. The type of the returned value is required to + be ``gym.spaces.Dict`` or ``None``. If ``None``, the user must provide all parameters manually. """ return None @@ -62,7 +61,8 @@ def parameter_space() -> gym.spaces.Dict: @property def update_observation_space(self) -> gym.spaces.Space: """ - Observation space of the ``update`` method in Gymnasium format. + Observation space of the ``update`` method in Gymnasium format. Allows inferring missing + observations using extensions and easily exporting the agent to the TensorFlow Lite format. If ``None``, the user must provide all parameters manually. """ @@ -71,7 +71,8 @@ def update_observation_space(self) -> gym.spaces.Space: @property def sample_observation_space(self) -> gym.spaces.Space: """ - Observation space of the ``sample`` method in Gymnasium format. + Observation space of the ``sample`` method in Gymnasium format. Allows inferring missing + observations using extensions and easily exporting the agent to the TensorFlow Lite format. If ``None``, the user must provide all parameters manually. """ diff --git a/reinforced_lib/agents/core/particle_filter.py b/reinforced_lib/agents/core/particle_filter.py index 5c21997..3fb3553 100644 --- a/reinforced_lib/agents/core/particle_filter.py +++ b/reinforced_lib/agents/core/particle_filter.py @@ -51,7 +51,7 @@ def simple_resample(operands: tuple[ParticleFilterState, PRNGKey]) -> ParticleFi def effective_sample_size(state: ParticleFilterState, threshold: Scalar = 0.5) -> bool: r""" - Calculates the effective sample size [1]_ (ESS). If ESS is smaller than the number of sample times threshold, + Calculates the effective sample size [9]_ (ESS). If ESS is smaller than the number of samples times the threshold, then a resampling is necessary. Parameters ---------- @@ -68,7 +68,7 @@ def effective_sample_size(state: ParticleFilterState, threshold: Scalar = 0.5) - References ---------- - .. [1] https://en.wikipedia.org/wiki/Effective_sample_size#Weighted_samples + .. [9] https://en.wikipedia.org/wiki/Effective_sample_size#Weighted_samples """ weights = jax.nn.softmax(state.logit_weights) diff --git a/reinforced_lib/agents/deep/ddpg.py b/reinforced_lib/agents/deep/ddpg.py index fb99f69..1d61456 100644 --- a/reinforced_lib/agents/deep/ddpg.py +++ b/reinforced_lib/agents/deep/ddpg.py @@ -68,7 +68,7 @@ class DDPGState(AgentState): class DDPG(BaseAgent): r""" - Deep deterministic policy gradient [9]_ [10]_ agent with white Gaussian noise exploration and experience replay + Deep deterministic policy gradient [3]_ [4]_ agent with white Gaussian noise exploration and experience replay buffer. 
The agent simultaneously learns a Q-function and a policy. The Q-function is updated using the Bellman equation. The policy is learned using the gradient of the Q-function with respect to the policy parameters, it is trained to maximize the Q-value. The agent uses two Q-networks (critics) and two policy networks (actors) @@ -113,10 +113,10 @@ class DDPG(BaseAgent): References ---------- - .. [9] David Silver, Guy Lever, Nicolas Heess, Thomas Degris, Daan Wierstra, and Martin Riedmiller. 2014. + .. [3] David Silver, Guy Lever, Nicolas Heess, Thomas Degris, Daan Wierstra, and Martin Riedmiller. 2014. Deterministic policy gradient algorithms. In Proceedings of the 31st International Conference on International Conference on Machine Learning - Volume 32 (ICML'14). JMLR.org, I–387–I–395. - .. [10] Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, Tom Erez, Yuval Tassa, David Silver, + .. [4] Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, Tom Erez, Yuval Tassa, David Silver, and Daan Wierstra. 2015. Continuous control with deep reinforcement learning. CoRR abs/1509.02971. """ @@ -427,8 +427,9 @@ def update( Appends the transition to the experience replay buffer and performs ``experience_replay_steps`` steps. Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the Q-network loss and the policy network loss using ``q_loss_fn`` and ``a_loss_fn`` respectively, performing - a gradient step on both networks, and soft updating the target networks. The noise parameter is decayed by - ``noise_decay``. + a gradient step on both networks, and soft updating the target networks. Soft update of the parameters + is defined as :math:`\theta_{target} = \tau \theta + (1 - \tau) \theta_{target}`. The noise parameter is + decayed by ``noise_decay``. Parameters ---------- @@ -519,7 +520,7 @@ def sample( ) -> Numeric: r""" Calculates deterministic action using the policy network. Then adds white Gaussian noise with standard - deviation ``state.noise`` to the action and clips it to the range ``[min_action, max_action]``. + deviation ``state.noise`` to the action and clips it to the range :math:`[min\_action, max\_action]`. Parameters ---------- diff --git a/reinforced_lib/agents/deep/dqn.py b/reinforced_lib/agents/deep/dqn.py index 5b52302..cdd7af4 100644 --- a/reinforced_lib/agents/deep/dqn.py +++ b/reinforced_lib/agents/deep/dqn.py @@ -54,7 +54,7 @@ class DQNState(AgentState): class DQN(BaseAgent): r""" - Double Q-learning agent [8]_ with :math:`\epsilon`-greedy exploration and experience replay buffer. The agent + Double Q-learning agent [2]_ with :math:`\epsilon`-greedy exploration and experience replay buffer. The agent uses two Q-networks to stabilize the learning process and avoid overestimation of the Q-values. The main Q-network is trained to minimize the Bellman error. The target Q-network is updated with a soft update. This agent follows the off-policy learning paradigm and is suitable for environments with discrete action spaces. @@ -88,7 +88,7 @@ class DQN(BaseAgent): References ---------- - .. [8] van Hasselt, H., Guez, A., & Silver, D. (2016). Deep Reinforcement Learning with Double Q-Learning. + .. [2] van Hasselt, H., Guez, A., & Silver, D. (2016). Deep Reinforcement Learning with Double Q-Learning. Proceedings of the Thirtieth AAAI Conference on Artificial Intelligence, 2094–2100. Phoenix, Arizona: AAAI Press. 
""" @@ -312,7 +312,8 @@ def update( Appends the transition to the experience replay buffer and performs ``experience_replay_steps`` steps. Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the loss using the ``loss_fn`` function, performing a gradient step on the main Q-network, and soft updating the target - Q-network. The :math:`\epsilon`-greedy parameter is decayed by ``epsilon_decay``. + Q-network. Soft update of the parameters is defined as :math:`\theta_{target} = \tau \theta + (1 - \tau) \theta_{target}`. + The :math:`\epsilon`-greedy parameter is decayed by ``epsilon_decay``. Parameters ---------- diff --git a/reinforced_lib/agents/deep/q_learning.py b/reinforced_lib/agents/deep/q_learning.py index 2676dc6..df5ab5b 100644 --- a/reinforced_lib/agents/deep/q_learning.py +++ b/reinforced_lib/agents/deep/q_learning.py @@ -46,9 +46,9 @@ class QLearningState(AgentState): class QLearning(BaseAgent): r""" - Deep Q-learning agent [6]_ with :math:`\epsilon`-greedy exploration and experience replay buffer. The agent uses + Deep Q-learning agent [1]_ with :math:`\epsilon`-greedy exploration and experience replay buffer. The agent uses a deep neural network to approximate the Q-value function. The Q-network is trained to minimize the Bellman - error. This agent follows the on-policy learning paradigm and is suitable for environments with discrete action + error. This agent follows the off-policy learning paradigm and is suitable for environments with discrete action spaces. Parameters @@ -78,7 +78,7 @@ class QLearning(BaseAgent): References ---------- - .. [6] Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D. & Riedmiller, M. (2013). + .. [1] Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D. & Riedmiller, M. (2013). Playing Atari with Deep Reinforcement Learning. """ diff --git a/reinforced_lib/agents/mab/e_greedy.py b/reinforced_lib/agents/mab/e_greedy.py index 03220f0..f4f3b61 100644 --- a/reinforced_lib/agents/mab/e_greedy.py +++ b/reinforced_lib/agents/mab/e_greedy.py @@ -27,7 +27,7 @@ class EGreedyState(AgentState): class EGreedy(BaseAgent): r""" - Epsilon-greedy agent with an optimistic start behavior and optional exponential recency-weighted average update. + Epsilon-greedy [5]_ agent with an optimistic start behavior and optional exponential recency-weighted average update. It selects a random action from a set of all actions :math:`\mathscr{A}` with probability :math:`\epsilon` (exploration), otherwise it chooses the currently best action (exploitation). @@ -41,6 +41,10 @@ class EGreedy(BaseAgent): Interpreted as the optimistic start to encourage exploration in the early stages. alpha : float, default=0.0 If non-zero, exponential recency-weighted average is used to update :math:`Q` values. :math:`\alpha \in [0, 1]`. + + References + ---------- + .. [5] Richard Sutton and Andrew Barto. 2018. Reinforcement Learning: An Introduction. The MIT Press. """ def __init__( diff --git a/reinforced_lib/agents/mab/exp3.py b/reinforced_lib/agents/mab/exp3.py index 4cdc4de..eef4a16 100644 --- a/reinforced_lib/agents/mab/exp3.py +++ b/reinforced_lib/agents/mab/exp3.py @@ -25,7 +25,7 @@ class Exp3State(AgentState): class Exp3(BaseAgent): r""" Basic Exp3 agent for stationary multi-armed bandit problems with exploration factor :math:`\gamma`. The higher - the value, the more the agent explores. The implementation is inspired by the work of Auer et al. [7]_. 
There + the value, the more the agent explores. The implementation is inspired by the work of Auer et al. [6]_. There are many variants of the Exp3 algorithm, you can find more information in the original paper. Parameters ---------- @@ -41,8 +41,8 @@ class Exp3(BaseAgent): References ---------- - .. [7] Auer, P., Cesa-Bianchi, N., Freund, Y., & Schapire, R. E. (2002). The Nonstochastic Multiarmed - Bandit Problem. SIAM Journal on Computing, 32(1), 48–77. doi:10.1137/S0097539701398375 + .. [6] Peter Auer, Nicolò Cesa-Bianchi, Yoav Freund, and Robert E. Schapire. 2002. The Nonstochastic Multiarmed + Bandit Problem. SIAM Journal on Computing, 32(1), 48–77. """ def __init__( diff --git a/reinforced_lib/agents/mab/softmax.py b/reinforced_lib/agents/mab/softmax.py index 529f969..64b4f3d 100644 --- a/reinforced_lib/agents/mab/softmax.py +++ b/reinforced_lib/agents/mab/softmax.py @@ -32,7 +32,7 @@ class Softmax(BaseAgent): r""" Softmax agent with baseline and optional exponential recency-weighted average update. It learns a preference function :math:`H`, which indicates a preference of selecting one arm over others. Algorithm policy can be - controlled by the temperature parameter :math:`\tau`. The implementation is inspired by the work of Sutton and Barto [3]_. + controlled by the temperature parameter :math:`\tau`. The implementation is inspired by the work of Sutton and Barto [5]_. Parameters ---------- @@ -46,10 +46,6 @@ class Softmax(BaseAgent): Temperature parameter. :math:`\tau > 0`. multiplier : float, default=1.0 Multiplier for the reward. :math:`multiplier > 0`. - - References - ---------- - .. [3] Sutton, R. S., Barto, A. G. (2018). Reinforcement Learning: An Introduction. The MIT Press. 37-40. """ def __init__( @@ -144,7 +140,7 @@ def update( H_{t + 1}(a) = H_t(a) + \alpha (R_t - \bar{R}_t)(\mathbb{1}_{A_t = a} - \pi_t(a)), where :math:`\bar{R_t}` is the average of all rewards up to but not including step :math:`t` - (by definition :math:`\bar{R}_1 = R_1`). The derivation of given formula can be found in [3]_. + (by definition :math:`\bar{R}_1 = R_1`). The derivation of the given formula can be found in [5]_. In the stationary case, :math:`\bar{R_t}` can be calculated as :math:`\bar{R}_{t + 1} = \bar{R}_t + \frac{1}{t} \lbrack R_t - \bar{R}_t \rbrack`. To improve the diff --git a/reinforced_lib/agents/mab/thompson_sampling.py b/reinforced_lib/agents/mab/thompson_sampling.py index 064299b..8c79c7a 100644 --- a/reinforced_lib/agents/mab/thompson_sampling.py +++ b/reinforced_lib/agents/mab/thompson_sampling.py @@ -28,7 +28,7 @@ class ThompsonSamplingState(AgentState): class ThompsonSampling(BaseAgent): r""" Contextual Bernoulli Thompson sampling agent with the exponential smoothing. The implementation is inspired by the - work of Krotov et al. [4]_. Thompson sampling is based on a beta distribution with parameters related to the number + work of Krotov et al. [7]_. Thompson sampling is based on a beta distribution with parameters related to the number of successful and failed attempts. Higher values of the parameters decrease the entropy of the distribution while changing the ratio of the parameters shifts the expected value. @@ -41,8 +41,8 @@ class ThompsonSampling(BaseAgent): References ---------- - .. [4] Krotov, Alexander & Kiryanov, Anton & Khorov, Evgeny. (2020). Rate Control With Spatial Reuse - for Wi-Fi 6 Dense Deployments. IEEE Access. 8. 168898-168909. 10.1109/ACCESS.2020.3023552. + .. [7] Alexander Krotov, Anton Kiryanov, and Evgeny Khorov. 2020. 
Rate Control With Spatial Reuse + for Wi-Fi 6 Dense Deployments. IEEE Access. 8. 168898-168909. """ def __init__(self, n_arms: jnp.int32, decay: Scalar = 1.0) -> None: diff --git a/reinforced_lib/agents/mab/ucb.py b/reinforced_lib/agents/mab/ucb.py index 1c5abfd..148545d 100644 --- a/reinforced_lib/agents/mab/ucb.py +++ b/reinforced_lib/agents/mab/ucb.py @@ -38,11 +38,11 @@ class UCB(BaseAgent): c : float Degree of exploration. :math:`c \geq 0`. gamma : float, default=1.0 - If less than one, a discounted UCB algorithm [5]_ is used. :math:`\gamma \in (0, 1]`. + If less than one, a discounted UCB algorithm [8]_ is used. :math:`\gamma \in (0, 1]`. References ---------- - .. [5] Garivier, A., & Moulines, E. (2008). On Upper-Confidence Bound Policies for Non-Stationary + .. [8] Aurélien Garivier and Eric Moulines. 2008. On Upper-Confidence Bound Policies for Non-Stationary Bandit Problems. 10.48550/ARXIV.0805.3415. """ def __init__( diff --git a/reinforced_lib/exts/base_ext.py b/reinforced_lib/exts/base_ext.py index 45152fc..ed8d255 100644 --- a/reinforced_lib/exts/base_ext.py +++ b/reinforced_lib/exts/base_ext.py @@ -12,8 +12,8 @@ class BaseExt(ABC): """ Container for domain-specific knowledge and functions for a given environment. Provides the transformation - from the observation functions and the observation space to the agent update and sample spaces. Stores the - default argument values for agent initialization. + from the raw observations to the agent update and sample spaces. Stores the default argument values for + agent initialization. """ def __init__(self) -> None: @@ -47,9 +47,8 @@ def get_agent_params( user_parameters: dict[str, any] = None ) -> dict[str, any]: """ - Composes agent initialization parameters from parameters passed by the user and default values - defined in the parameter functions. Returns a dictionary with the parameters fitting the agent - parameters space. + Composes agent initialization arguments from values passed by the user and default values stored in the + parameter functions. Returns a dictionary with the parameters matching the agent parameter space. Parameters ---------- @@ -104,8 +103,8 @@ def setup_transformations( agent_sample_space: gym.spaces.Space = None ) -> None: """ - Create functions that transform environment observations and values provided by the observation functions - to the agent update and sample space. + Creates functions that transform raw observations and values provided by the observation functions + to the agent update and sample spaces. Parameters ---------- @@ -372,16 +371,15 @@ def _tuple_transform(self, observations: list[Callable], accessor: Union[str, in def transform(self, *args, action: any = None, **kwargs) -> tuple[any, any]: """ - Transforms environment observations and values provided by the observation functions to - the agent observation and sample spaces. Supplies action selected by the agent if it is - required by the agent and the extension is not capable of providing this value. + Transforms raw observations and values provided by the observation functions to the agent observation + and sample spaces. Provides the last action selected by the agent if the agent requires it. Parameters ---------- *args : tuple Environment observations. action : any - Action selected by the agent. + The last action selected by the agent. **kwargs : dict Environment observations. 
diff --git a/reinforced_lib/exts/gymnasium.py b/reinforced_lib/exts/gymnasium.py index 2adcd47..e1c8899 100644 --- a/reinforced_lib/exts/gymnasium.py +++ b/reinforced_lib/exts/gymnasium.py @@ -6,7 +6,7 @@ class Gymnasium(BaseExt): """ - Gymnasium [1]_ extension. Simplifies interaction of deep RL agents with the Gymnasium environments by providing + Gymnasium [1]_ extension. Simplifies interaction of RL agents with the Gymnasium environments by providing the environment state, reward, terminal flag, and shapes of the observation and action spaces. Parameters diff --git a/reinforced_lib/logs/base_logger.py b/reinforced_lib/logs/base_logger.py index 14e695b..ae75057 100644 --- a/reinforced_lib/logs/base_logger.py +++ b/reinforced_lib/logs/base_logger.py @@ -18,7 +18,7 @@ class SourceType(Enum): class BaseLogger(ABC): """ - Container for functions of a logger. Provides a simple interface for defining custom loggers. + Base interface for loggers. """ def __init__(self, **kwargs): @@ -26,7 +26,7 @@ def __init__(self, **kwargs): def init(self, sources: list[Source]) -> None: """ - Initializes the logger given the list of all sources. + Initializes the logger given the list of all sources defined by the user. Parameters ---------- @@ -38,7 +38,7 @@ def init(self, sources: list[Source]) -> None: def finish(self) -> None: """ - Finalizes the loggers work, for example, saves data or shows plots. + Finalizes the logger's work (e.g., closes files or shows plots). """ pass @@ -61,7 +61,7 @@ def log_scalar(self, source: Source, value: Scalar, custom: bool) -> None: def log_array(self, source: Source, value: Array, custom: bool) -> None: """ - Method of the logger interface used for logging arrays. + Method of the logger interface used for logging one-dimensional arrays. Parameters ---------- @@ -110,7 +110,7 @@ def log_other(self, source: Source, value: any, custom: bool) -> None: @staticmethod def source_to_name(source: Source) -> str: """ - Converts the source to a string. If source is a string itself, it returns that string. + Returns the full name of the source. If source is a string itself, returns that string. Otherwise, it returns a string in the format "name-sourcetype" (e.g., "action-metric"). Parameters diff --git a/reinforced_lib/logs/csv_logger.py b/reinforced_lib/logs/csv_logger.py index e7b76be..d32eb19 100644 --- a/reinforced_lib/logs/csv_logger.py +++ b/reinforced_lib/logs/csv_logger.py @@ -12,7 +12,9 @@ class CsvLogger(BaseLogger): """ - Logger that saves values in CSV format. + Logger that saves values in CSV format. It saves the logged values to the CSV file when the experiment is finished. + ``CsvLogger`` synchronizes the logged values in time. This means that if the same source is logged twice in a row, + the step number will be incremented for all columns and the logger will move to the next row. Parameters ---------- diff --git a/reinforced_lib/logs/logs_observer.py b/reinforced_lib/logs/logs_observer.py index 12bf303..cd44588 100644 --- a/reinforced_lib/logs/logs_observer.py +++ b/reinforced_lib/logs/logs_observer.py @@ -10,7 +10,7 @@ class LogsObserver: """ Class responsible for managing singleton instances of the loggers, initialization and finalization - of the loggers, and passing the logged values to the appropriate loggers and methods. + of the loggers, and passing the logged values to the appropriate loggers and their methods. 
""" def __init__(self) -> None: @@ -123,7 +123,7 @@ def update_metrics(self, metric: any, metric_name: str) -> None: def update_custom(self, value: any, name: str) -> None: """ - Passes custom values to loggers. + Passes values provided by the user to the loggers. Parameters ---------- diff --git a/reinforced_lib/logs/plots_logger.py b/reinforced_lib/logs/plots_logger.py index f15817d..419f33c 100644 --- a/reinforced_lib/logs/plots_logger.py +++ b/reinforced_lib/logs/plots_logger.py @@ -11,8 +11,10 @@ class PlotsLogger(BaseLogger): r""" - Logger that presents and saves values as line plots. Offers smoothing of the curve and plotting - multiple curves in a single chart (while logging arrays). + Logger that presents and saves values as matplotlib plots. Offers smoothing of the curve, scatter plots, and + multiple curves in a single chart (while logging arrays). ``PlotsLogger`` is able to synchronizes the logged + values in time. This means that if the same source is logged less often than other sources, the step will be + increased accordingly to maintain the appropriate spacing between the values on the x-axis. Parameters ---------- diff --git a/reinforced_lib/logs/tb_logger.py b/reinforced_lib/logs/tb_logger.py index bc3fc86..c235201 100644 --- a/reinforced_lib/logs/tb_logger.py +++ b/reinforced_lib/logs/tb_logger.py @@ -10,6 +10,9 @@ class TensorboardLogger(BaseLogger): """ Logger that saves values in TensorBoard [2]_ format. Offers a possibility to log to Comet [3]_. + ``TensorboardLogger`` synchronizes the logged values in time. This means that if the same source + is logged less often than other sources, the step will be increased accordingly to maintain the + appropriate spacing between the values on the x-axis. Parameters ----------