Commit

Update of the API docs

m-wojnar committed Jul 20, 2023
1 parent 6a38384 commit 3a8fd2b
Showing 17 changed files with 65 additions and 57 deletions.
13 changes: 7 additions & 6 deletions reinforced_lib/agents/base_agent.py
@@ -44,25 -44,25 @@ def update(state: AgentState, key: PRNGKey, *args, **kwargs) -> AgentState:
@abstractmethod
def sample(state: AgentState, key: PRNGKey, *args, **kwargs) -> any:
"""
Selects the next action based on the current agent state.
Selects the next action based on the current environment and agent state.
"""

pass

@staticmethod
def parameter_space() -> gym.spaces.Dict:
"""
Parameter space of the agent constructor in Gymnasium format.
Type of returned value is required to be ``gym.spaces.Dict`` or ``None``.
If ``None``, the user must provide all parameters manually.
Parameters of the agent constructor in Gymnasium format. The type of the returned value is required to
be ``gym.spaces.Dict`` or ``None``. If ``None``, the user must provide all parameters manually.
"""

return None

@property
def update_observation_space(self) -> gym.spaces.Space:
"""
Observation space of the ``update`` method in Gymnasium format.
Observation space of the ``update`` method in Gymnasium format. Allows inferring missing
observations using an extension and easily exporting the agent to the TensorFlow Lite format.
If ``None``, the user must provide all parameters manually.
"""

@@ -71,7 +71,7 @@ def update_observation_space(self) -> gym.spaces.Space:
@property
def sample_observation_space(self) -> gym.spaces.Space:
"""
Observation space of the ``sample`` method in Gymnasium format.
Observation space of the ``sample`` method in Gymnasium format. Allows inferring missing
observations using an extension and easily exporting the agent to the TensorFlow Lite format.
If ``None``, the user must provide all parameters manually.
"""

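For illustration, here is a minimal sketch of the kind of spaces these methods describe; the keys and bounds are invented for this example, and only the ``gymnasium`` calls themselves are real:

```python
# Illustrative only: the keys and bounds are made up, not taken from any
# reinforced-lib agent; gymnasium provides the Dict/Box/Discrete spaces.
import gymnasium as gym
import numpy as np

# Constructor parameters an agent could advertise (cf. ``parameter_space``).
parameter_space = gym.spaces.Dict({
    'n_arms': gym.spaces.Box(1, 100, (1,), np.int32),
    'alpha': gym.spaces.Box(0.0, 1.0, (1,), np.float32),
})

# Observations the ``update`` method could expect (cf. ``update_observation_space``).
update_observation_space = gym.spaces.Dict({
    'action': gym.spaces.Discrete(4),
    'reward': gym.spaces.Box(-np.inf, np.inf, (1,), np.float32),
})

print(update_observation_space.sample())  # random observation matching the space
```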
4 changes: 2 additions & 2 deletions reinforced_lib/agents/core/particle_filter.py
@@ -51,7 +51,7 @@ def simple_resample(operands: tuple[ParticleFilterState, PRNGKey]) -> ParticleFi

def effective_sample_size(state: ParticleFilterState, threshold: Scalar = 0.5) -> bool:
r"""
Calculates the effective sample size [1]_ (ESS). If ESS is smaller than the number of sample times threshold,
Calculates the effective sample size [9]_ (ESS). If the ESS is smaller than the number of samples times the threshold,
then resampling is necessary.
Parameters
@@ -68,7 +68,7 @@ def effective_sample_size(state: ParticleFilterState, threshold: Scalar = 0.5) -
References
----------
.. [1] https://en.wikipedia.org/wiki/Effective_sample_size#Weighted_samples
.. [9] https://en.wikipedia.org/wiki/Effective_sample_size#Weighted_samples
"""

weights = jax.nn.softmax(state.logit_weights)
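A minimal JAX sketch of the ESS test described above (weights stored as logits and normalized with a softmax, as in the snippet); this is an illustration, not the library function:

```python
import jax
import jax.numpy as jnp

def needs_resampling(logit_weights: jnp.ndarray, threshold: float = 0.5) -> jnp.ndarray:
    weights = jax.nn.softmax(logit_weights)
    ess = 1.0 / jnp.sum(jnp.square(weights))      # effective sample size of weighted samples
    return ess < threshold * logit_weights.size   # resample when ESS < N * threshold

print(needs_resampling(jnp.zeros(100)))           # uniform weights -> ESS = 100 -> no resampling
```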
13 changes: 7 additions & 6 deletions reinforced_lib/agents/deep/ddpg.py
@@ -68,7 +68,7 @@ class DDPGState(AgentState):

class DDPG(BaseAgent):
r"""
Deep deterministic policy gradient [9]_ [10]_ agent with white Gaussian noise exploration and experience replay
Deep deterministic policy gradient [3]_ [4]_ agent with white Gaussian noise exploration and experience replay
buffer. The agent simultaneously learns a Q-function and a policy. The Q-function is updated using the Bellman
equation. The policy is learned using the gradient of the Q-function with respect to the policy parameters;
it is trained to maximize the Q-value. The agent uses two Q-networks (critics) and two policy networks (actors)
@@ -113,10 +113,10 @@ class DDPG(BaseAgent):
References
----------
.. [9] David Silver, Guy Lever, Nicolas Heess, Thomas Degris, Daan Wierstra, and Martin Riedmiller. 2014.
.. [3] David Silver, Guy Lever, Nicolas Heess, Thomas Degris, Daan Wierstra, and Martin Riedmiller. 2014.
Deterministic policy gradient algorithms. In Proceedings of the 31st International Conference on International
Conference on Machine Learning - Volume 32 (ICML'14). JMLR.org, I–387–I–395.
.. [10] Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, Tom Erez, Yuval Tassa, David Silver,
.. [4] Timothy P. Lillicrap, Jonathan J. Hunt, Alexander Pritzel, Nicolas Heess, Tom Erez, Yuval Tassa, David Silver,
and Daan Wierstra. 2015. Continuous control with deep reinforcement learning. CoRR abs/1509.02971.
"""

@@ -427,8 +427,9 @@ def update(
Appends the transition to the experience replay buffer and performs ``experience_replay_steps`` steps.
Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the
Q-network loss and the policy network loss using ``q_loss_fn`` and ``a_loss_fn`` respectively, performing
a gradient step on both networks, and soft updating the target networks. The noise parameter is decayed by
``noise_decay``.
a gradient step on both networks, and soft updating the target networks. The soft update of the parameters
is defined as :math:`\theta_{target} = \tau \theta + (1 - \tau) \theta_{target}`. The noise parameter is
decayed by ``noise_decay``.
Parameters
----------
@@ -519,7 +520,7 @@ def sample(
) -> Numeric:
r"""
Calculates a deterministic action using the policy network, then adds white Gaussian noise with standard
deviation ``state.noise`` to the action and clips it to the range ``[min_action, max_action]``.
deviation ``state.noise`` to the action and clips it to the range :math:`[min\_action, max\_action]`.
Parameters
----------
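The soft (Polyak) update mentioned in the ``update`` docstring can be written in a few lines of JAX; this is a sketch of the formula, not the library's implementation:

```python
# theta_target <- tau * theta + (1 - tau) * theta_target, applied leaf-wise
# to the network parameter pytrees.
import jax

def soft_update(params, target_params, tau: float = 0.005):
    return jax.tree_util.tree_map(
        lambda p, tp: tau * p + (1.0 - tau) * tp, params, target_params)

# Plain dicts stand in for the actor/critic parameter pytrees.
print(soft_update({'w': 1.0}, {'w': 0.0}))  # {'w': 0.005}
```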
7 changes: 4 additions & 3 deletions reinforced_lib/agents/deep/dqn.py
@@ -54,7 +54,7 @@ class DQNState(AgentState):

class DQN(BaseAgent):
r"""
Double Q-learning agent [8]_ with :math:`\epsilon`-greedy exploration and experience replay buffer. The agent
Double Q-learning agent [2]_ with :math:`\epsilon`-greedy exploration and experience replay buffer. The agent
uses two Q-networks to stabilize the learning process and avoid overestimation of the Q-values. The main Q-network
is trained to minimize the Bellman error. The target Q-network is updated with a soft update. This agent follows
the off-policy learning paradigm and is suitable for environments with discrete action spaces.
Expand Down Expand Up @@ -88,7 +88,7 @@ class DQN(BaseAgent):
References
----------
.. [8] van Hasselt, H., Guez, A., & Silver, D. (2016). Deep Reinforcement Learning with Double Q-Learning.
.. [2] van Hasselt, H., Guez, A., & Silver, D. (2016). Deep Reinforcement Learning with Double Q-Learning.
Proceedings of the Thirtieth AAAI Conference on Artificial Intelligence, 2094–2100. Phoenix, Arizona: AAAI Press.
"""

@@ -312,7 +312,8 @@ def update(
Appends the transition to the experience replay buffer and performs ``experience_replay_steps`` steps.
Each step consists of sampling a batch of transitions from the experience replay buffer, calculating the loss
using the ``loss_fn`` function, performing a gradient step on the main Q-network, and soft updating the target
Q-network. The :math:`\epsilon`-greedy parameter is decayed by ``epsilon_decay``.
Q-network. The soft update of the parameters is defined as :math:`\theta_{target} = \tau \theta + (1 - \tau) \theta_{target}`.
The :math:`\epsilon`-greedy parameter is decayed by ``epsilon_decay``.
Parameters
----------
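A sketch of the double Q-learning target the docstring refers to (the main network selects the greedy action, the target network evaluates it); variable names are illustrative:

```python
import jax.numpy as jnp

def double_q_target(reward, terminal, q_online_next, q_target_next, gamma=0.99):
    best_action = jnp.argmax(q_online_next, axis=-1)                   # chosen by the main network
    next_value = jnp.take_along_axis(
        q_target_next, best_action[..., None], axis=-1).squeeze(-1)    # evaluated by the target network
    return reward + (1.0 - terminal) * gamma * next_value

r, done = jnp.array([1.0]), jnp.array([0.0])
q_online, q_target = jnp.array([[0.2, 0.7]]), jnp.array([[0.5, 0.1]])
print(double_q_target(r, done, q_online, q_target))  # 1 + 0.99 * 0.1 = 1.099
```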
6 changes: 3 additions & 3 deletions reinforced_lib/agents/deep/q_learning.py
@@ -46,9 +46,9 @@ class QLearningState(AgentState):

class QLearning(BaseAgent):
r"""
Deep Q-learning agent [6]_ with :math:`\epsilon`-greedy exploration and experience replay buffer. The agent uses
Deep Q-learning agent [1]_ with :math:`\epsilon`-greedy exploration and experience replay buffer. The agent uses
a deep neural network to approximate the Q-value function. The Q-network is trained to minimize the Bellman
error. This agent follows the on-policy learning paradigm and is suitable for environments with discrete action
error. This agent follows the off-policy learning paradigm and is suitable for environments with discrete action
spaces.
Parameters
@@ -78,7 +78,7 @@ class QLearning(BaseAgent):
References
----------
.. [6] Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D. & Riedmiller, M. (2013).
.. [1] Mnih, V., Kavukcuoglu, K., Silver, D., Graves, A., Antonoglou, I., Wierstra, D. & Riedmiller, M. (2013).
Playing Atari with Deep Reinforcement Learning.
"""

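A sketch of the Bellman-error loss a single Q-network is trained to minimize; this is an illustration, not the library's ``loss_fn``:

```python
import jax.numpy as jnp

def bellman_loss(q_values, action, reward, terminal, q_next, gamma=0.99):
    target = reward + (1.0 - terminal) * gamma * jnp.max(q_next, axis=-1)
    predicted = jnp.take_along_axis(q_values, action[..., None], axis=-1).squeeze(-1)
    return jnp.mean(jnp.square(predicted - target))    # mean squared TD error

q, q_next = jnp.array([[0.5, 0.2]]), jnp.array([[0.4, 0.9]])
print(bellman_loss(q, jnp.array([0]), jnp.array([1.0]), jnp.array([0.0]), q_next))
```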
6 changes: 5 additions & 1 deletion reinforced_lib/agents/mab/e_greedy.py
@@ -27,7 +27,7 @@ class EGreedyState(AgentState):

class EGreedy(BaseAgent):
r"""
Epsilon-greedy agent with an optimistic start behavior and optional exponential recency-weighted average update.
Epsilon-greedy [5]_ agent with an optimistic start behavior and optional exponential recency-weighted average update.
It selects a random action from the set of all actions :math:`\mathscr{A}` with probability
:math:`\epsilon` (exploration), otherwise it chooses the currently best action (exploitation).
@@ -41,6 +41,10 @@ class EGreedy(BaseAgent):
Interpreted as the optimistic start to encourage exploration in the early stages.
alpha : float, default=0.0
If non-zero, exponential recency-weighted average is used to update :math:`Q` values. :math:`\alpha \in [0, 1]`.
References
----------
.. [5] Richard Sutton and Andrew Barto. 2018. Reinforcement Learning: An Introduction. The MIT Press.
"""

def __init__(
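The selection and update rules described above fit in a few lines; a sketch under the assumption that :math:`Q` is stored as a dense array:

```python
import jax
import jax.numpy as jnp

def sample(key, q, epsilon=0.1):
    explore_key, action_key = jax.random.split(key)
    random_action = jax.random.randint(action_key, (), 0, q.size)
    return jnp.where(jax.random.uniform(explore_key) < epsilon,
                     random_action, jnp.argmax(q))       # explore vs. exploit

def update(q, action, reward, alpha=0.1):
    return q.at[action].set(q[action] + alpha * (reward - q[action]))  # recency-weighted average

q = jnp.full(4, 5.0)                                     # optimistic start
q = update(q, action=2, reward=1.0)
print(sample(jax.random.PRNGKey(0), q))
```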
6 changes: 3 additions & 3 deletions reinforced_lib/agents/mab/exp3.py
@@ -25,7 +25,7 @@ class Exp3State(AgentState):
class Exp3(BaseAgent):
r"""
Basic Exp3 agent for stationary multi-armed bandit problems with exploration factor :math:`\gamma`. The higher
the value, the more the agent explores. The implementation is inspired by the work of Auer et al. [7]_. There
the value, the more the agent explores. The implementation is inspired by the work of Auer et al. [6]_. There
are many variants of the Exp3 algorithm; more information can be found in the original paper.
Parameters
@@ -41,8 +41,8 @@ class Exp3(BaseAgent):
References
----------
.. [7] Auer, P., Cesa-Bianchi, N., Freund, Y., & Schapire, R. E. (2002). The Nonstochastic Multiarmed
Bandit Problem. SIAM Journal on Computing, 32(1), 48–77. doi:10.1137/S0097539701398375
.. [6] Peter Auer, Nicolò Cesa-Bianchi, Yoav Freund, and Robert E. Schapire. 2002. The Nonstochastic Multiarmed
Bandit Problem. SIAM Journal on Computing, 32(1), 48–77.
"""

def __init__(
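A sketch of one common Exp3 formulation (importance-weighted reward estimate and mixed exploration probabilities); the library may differ in details:

```python
import jax.numpy as jnp

def exp3_probs(log_weights, gamma):
    w = jnp.exp(log_weights - jnp.max(log_weights))        # numerically stable weights
    return (1.0 - gamma) * w / jnp.sum(w) + gamma / w.size

def exp3_update(log_weights, probs, action, reward, gamma):
    estimate = reward / probs[action]                      # importance-weighted reward estimate
    return log_weights.at[action].add(gamma * estimate / log_weights.size)

probs = exp3_probs(jnp.zeros(3), gamma=0.1)
print(probs)                                               # uniform at the start
```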
8 changes: 2 additions & 6 deletions reinforced_lib/agents/mab/softmax.py
@@ -32,7 +32,7 @@ class Softmax(BaseAgent):
r"""
Softmax agent with baseline and optional exponential recency-weighted average update. It learns a preference
function :math:`H`, which indicates a preference for selecting one arm over the others. The algorithm's policy can be
controlled by the temperature parameter :math:`\tau`. The implementation is inspired by the work of Sutton and Barto [3]_.
controlled by the temperature parameter :math:`\tau`. The implementation is inspired by the work of Sutton and Barto [5]_.
Parameters
----------
@@ -46,10 +46,6 @@
Temperature parameter. :math:`\tau > 0`.
multiplier : float, default=1.0
Multiplier for the reward. :math:`multiplier > 0`.
References
----------
.. [3] Sutton, R. S., Barto, A. G. (2018). Reinforcement Learning: An Introduction. The MIT Press. 37-40.
"""

def __init__(
@@ -144,7 +140,7 @@ def update(
H_{t + 1}(a) = H_t(a) + \alpha (R_t - \bar{R}_t)(\mathbb{1}_{A_t = a} - \pi_t(a)),
where :math:`\bar{R}_t` is the average of all rewards up to but not including step :math:`t`
(by definition :math:`\bar{R}_1 = R_1`). The derivation of given formula can be found in [3]_.
(by definition :math:`\bar{R}_1 = R_1`). The derivation of the given formula can be found in [5]_.
In the stationary case, :math:`\bar{R}_t` can be calculated as
:math:`\bar{R}_{t + 1} = \bar{R}_t + \frac{1}{t} \lbrack R_t - \bar{R}_t \rbrack`. To improve the
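The preference update quoted above translates directly into JAX; a sketch with illustrative names:

```python
import jax
import jax.numpy as jnp

# H_{t+1}(a) = H_t(a) + alpha * (R_t - baseline) * (1[A_t = a] - pi_t(a))
def update_preferences(h, action, reward, baseline, alpha=0.1, tau=1.0):
    pi = jax.nn.softmax(h / tau)                 # current policy
    indicator = jax.nn.one_hot(action, h.size)
    return h + alpha * (reward - baseline) * (indicator - pi)

print(update_preferences(jnp.zeros(3), action=1, reward=1.0, baseline=0.0))
```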
6 changes: 3 additions & 3 deletions reinforced_lib/agents/mab/thompson_sampling.py
@@ -28,7 +28,7 @@ class ThompsonSamplingState(AgentState):
class ThompsonSampling(BaseAgent):
r"""
Contextual Bernoulli Thompson sampling agent with exponential smoothing. The implementation is inspired by the
work of Krotov et al. [4]_. Thompson sampling is based on a beta distribution with parameters related to the number
work of Krotov et al. [7]_. Thompson sampling is based on a beta distribution with parameters related to the number
of successful and failed attempts. Higher values of the parameters decrease the entropy of the distribution while
changing the ratio of the parameters shifts the expected value.
@@ -41,8 +41,8 @@
References
----------
.. [4] Krotov, Alexander & Kiryanov, Anton & Khorov, Evgeny. (2020). Rate Control With Spatial Reuse
for Wi-Fi 6 Dense Deployments. IEEE Access. 8. 168898-168909. 10.1109/ACCESS.2020.3023552.
.. [7] Alexander Krotov, Anton Kiryanov and Evgeny Khorov. 2020. Rate Control With Spatial Reuse
for Wi-Fi 6 Dense Deployments. IEEE Access. 8. 168898-168909.
"""

def __init__(self, n_arms: jnp.int32, decay: Scalar = 1.0) -> None:
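A sketch of Bernoulli Thompson sampling with exponentially smoothed success/failure counts (``decay = 1.0`` recovers the stationary case); the library's time-based smoothing may differ:

```python
import jax
import jax.numpy as jnp

def sample_action(key, alpha, beta):
    return jnp.argmax(jax.random.beta(key, alpha, beta))    # one draw per arm, pick the best

def update_counts(alpha, beta, action, success, decay=0.9):
    alpha, beta = decay * alpha, decay * beta                # forget old observations
    return alpha.at[action].add(success), beta.at[action].add(1.0 - success)

alpha, beta = jnp.ones(4), jnp.ones(4)
print(sample_action(jax.random.PRNGKey(42), alpha, beta))
```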
4 changes: 2 additions & 2 deletions reinforced_lib/agents/mab/ucb.py
@@ -38,11 +38,11 @@ class UCB(BaseAgent):
c : float
Degree of exploration. :math:`c \geq 0`.
gamma : float, default=1.0
If less than one, a discounted UCB algorithm [5]_ is used. :math:`\gamma \in (0, 1]`.
If less than one, a discounted UCB algorithm [8]_ is used. :math:`\gamma \in (0, 1]`.
References
----------
.. [5] Garivier, A., & Moulines, E. (2008). On Upper-Confidence Bound Policies for Non-Stationary
.. [8] Aurélien Garivier, Eric Moulines. 2008. On Upper-Confidence Bound Policies for Non-Stationary
Bandit Problems. 10.48550/ARXIV.0805.3415.
"""

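A sketch of the (optionally discounted) UCB rule: past counts and reward sums are multiplied by :math:`\gamma` before each update, and the chosen arm maximizes the mean plus an exploration bonus; names are illustrative:

```python
import jax.numpy as jnp

def ucb_action(reward_sums, counts, c=1.0):
    t = jnp.sum(counts)
    means = reward_sums / jnp.maximum(counts, 1.0)
    bonus = c * jnp.sqrt(jnp.log(jnp.maximum(t, 1.0)) / jnp.maximum(counts, 1.0))
    return jnp.argmax(jnp.where(counts == 0, jnp.inf, means + bonus))  # try unplayed arms first

def ucb_update(reward_sums, counts, action, reward, gamma=1.0):
    reward_sums, counts = gamma * reward_sums, gamma * counts          # gamma < 1 discounts the past
    return reward_sums.at[action].add(reward), counts.at[action].add(1.0)
```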
20 changes: 9 additions & 11 deletions reinforced_lib/exts/base_ext.py
@@ -12,8 +12,8 @@
class BaseExt(ABC):
"""
Container for domain-specific knowledge and functions for a given environment. Provides the transformation
from the observation functions and the observation space to the agent update and sample spaces. Stores the
default argument values for agent initialization.
from the raw observations to the agent update and sample spaces. Stores the default argument values for
agent initialization.
"""

def __init__(self) -> None:
@@ -47,9 +47,8 @@ def get_agent_params(
user_parameters: dict[str, any] = None
) -> dict[str, any]:
"""
Composes agent initialization parameters from parameters passed by the user and default values
defined in the parameter functions. Returns a dictionary with the parameters fitting the agent
parameters space.
Composes agent initialization arguments from values passed by the user and default values stored in the
parameter functions. Returns a dictionary with the parameters matching the agent's parameter space.
Parameters
----------
@@ -104,8 +103,8 @@ def setup_transformations(
agent_sample_space: gym.spaces.Space = None
) -> None:
"""
Create functions that transform environment observations and values provided by the observation functions
to the agent update and sample space.
Creates functions that transform raw observations and values provided by the observation functions
to the agent update and sample spaces.
Parameters
----------
@@ -372,16 +371,15 @@ def _tuple_transform(self, observations: list[Callable], accessor: Union[str, in

def transform(self, *args, action: any = None, **kwargs) -> tuple[any, any]:
"""
Transforms environment observations and values provided by the observation functions to
the agent observation and sample spaces. Supplies action selected by the agent if it is
required by the agent and the extension is not capable of providing this value.
Transforms raw observations and values provided by the observation functions to the agent observation
and sample spaces. Provides the last action selected by the agent if it is required by the agent.
Parameters
----------
*args : tuple
Environment observations.
action : any
Action selected by the agent.
The last action selected by the agent.
**kwargs : dict
Environment observations.
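A schematic illustration of the idea behind an extension (this is not the ``BaseExt`` API): observation functions fill in whatever the raw observations are missing before they reach the agent:

```python
def transform(raw_observations: dict, observation_functions: dict, required: list) -> dict:
    filled = dict(raw_observations)
    for name in required:
        if name not in filled:
            filled[name] = observation_functions[name](**raw_observations)  # infer the missing value
    return filled

obs_fns = {'n_successful': lambda n_tx, n_failed, **_: n_tx - n_failed}
print(transform({'n_tx': 10, 'n_failed': 2}, obs_fns, ['n_tx', 'n_failed', 'n_successful']))
```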
2 changes: 1 addition & 1 deletion reinforced_lib/exts/gymnasium.py
@@ -6,7 +6,7 @@

class Gymnasium(BaseExt):
"""
Gymnasium [1]_ extension. Simplifies interaction of deep RL agents with the Gymnasium environments by providing
Gymnasium [1]_ extension. Simplifies interaction of RL agents with the Gymnasium environments by providing
the environment state, reward, terminal flag, and shapes of the observation and action spaces.
Parameters
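The raw quantities the extension forwards come from the standard Gymnasium step API; a minimal interaction loop (with a random policy standing in for the agent) looks like this:

```python
import gymnasium as gym

env = gym.make('CartPole-v1')
env_state, info = env.reset(seed=42)

for _ in range(10):
    action = env.action_space.sample()       # stand-in for the agent's choice
    env_state, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        env_state, info = env.reset()
env.close()
```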
10 changes: 5 additions & 5 deletions reinforced_lib/logs/base_logger.py
@@ -18,15 +18,15 @@ class SourceType(Enum):

class BaseLogger(ABC):
"""
Container for functions of a logger. Provides a simple interface for defining custom loggers.
Base interface for loggers.
"""

def __init__(self, **kwargs):
pass

def init(self, sources: list[Source]) -> None:
"""
Initializes the logger given the list of all sources.
Initializes the logger given the list of all sources defined by the user.
Parameters
----------
@@ -38,7 +38,7 @@ def init(self, sources: list[Source]) -> None:

def finish(self) -> None:
"""
Finalizes the logger's work, for example, saves data or shows plots.
Finalizes the logger's work (e.g., closes files or shows plots).
"""

pass
@@ -61,7 +61,7 @@ def log_scalar(self, source: Source, value: Scalar, custom: bool) -> None:

def log_array(self, source: Source, value: Array, custom: bool) -> None:
"""
Method of the logger interface used for logging arrays.
Method of the logger interface used for logging one-dimensional arrays.
Parameters
----------
@@ -110,7 +110,7 @@ def log_other(self, source: Source, value: any, custom: bool) -> None:
@staticmethod
def source_to_name(source: Source) -> str:
"""
Converts the source to a string. If source is a string itself, it returns that string.
Returns the full name of the source. If the source is a string itself, returns that string.
Otherwise, it returns a string in the format "name-sourcetype" (e.g., "action-metric").
Parameters
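For orientation, a minimal custom logger built only from the methods visible in this diff; the import path and the location of the ``Source`` type are assumptions based on the file layout:

```python
from reinforced_lib.logs import BaseLogger, Source  # assumed import path

class PrintLogger(BaseLogger):
    """Prints every logged scalar to standard output."""

    def log_scalar(self, source: Source, value: float, custom: bool) -> None:
        print(f'{self.source_to_name(source)}: {value}')

    def finish(self) -> None:
        print('logging finished')
```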
4 changes: 3 additions & 1 deletion reinforced_lib/logs/csv_logger.py
@@ -12,7 +12,9 @@

class CsvLogger(BaseLogger):
"""
Logger that saves values in CSV format.
Logger that saves values in CSV format. The logged values are written to the CSV file when the experiment is finished.
``CsvLogger`` synchronizes the logged values in time: if the same source is logged twice in a row,
the step number is incremented for all columns and the logger moves to the next row.
Parameters
----------
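A sketch of the row-synchronization idea described above (not the ``CsvLogger`` implementation): when a source is logged again before the other columns have caught up, the writer moves on to a new row:

```python
import csv

class RowSyncWriter:
    def __init__(self, path: str, columns: list[str]) -> None:
        self.file = open(path, 'w', newline='')
        self.writer = csv.DictWriter(self.file, fieldnames=columns)
        self.writer.writeheader()
        self.row = {}

    def log(self, column: str, value) -> None:
        if column in self.row:                 # same source logged twice in a row -> next row
            self.writer.writerow(self.row)
            self.row = {}
        self.row[column] = value

    def finish(self) -> None:
        if self.row:
            self.writer.writerow(self.row)
        self.file.close()
```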
4 changes: 2 additions & 2 deletions reinforced_lib/logs/logs_observer.py
@@ -10,7 +10,7 @@
class LogsObserver:
"""
Class responsible for managing singleton instances of the loggers, initialization and finalization
of the loggers, and passing the logged values to the appropriate loggers and methods.
of the loggers, and passing the logged values to the appropriate loggers and their methods.
"""

def __init__(self) -> None:
@@ -123,7 +123,7 @@ def update_metrics(self, metric: any, metric_name: str) -> None:

def update_custom(self, value: any, name: str) -> None:
"""
Passes custom values to loggers.
Passes values provided by the user to the loggers.
Parameters
----------