Add note about 64-bit mode
m-wojnar committed Dec 19, 2023
1 parent b45fe5b commit b90e50f
Showing 3 changed files with 23 additions and 3 deletions.
19 changes: 19 additions & 0 deletions docs/source/getting_started.rst
@@ -393,6 +393,25 @@ mandatory, we strongly encourage their implementation as they allow easy samplin
methods. To learn more about the agent's methods, check out the :ref:`Custom agents <custom_agents>` section.


64-bit floating-point precision
-------------------------------

By default, JAX uses 32-bit floating-point precision. However, in some cases, you might want to use 64-bit
floating-point precision. The easiest way to achieve this is to set the ``JAX_ENABLE_X64`` environment variable to
``True``:

.. code-block:: bash

    export JAX_ENABLE_X64=True

Alternatively, you can set the environment variable in your Python script, before JAX is imported:

.. code-block:: python

    import os
    os.environ['JAX_ENABLE_X64'] = 'True'  # must run before `import jax`


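
JAX also exposes the same switch through its configuration API. The snippet below is a minimal sketch and not part of
this commit: it assumes JAX is installed and that the call runs at startup, before any arrays are created. The
variable name ``x`` is purely illustrative.

```python
import jax

# Enable 64-bit mode programmatically; do this at startup,
# before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# With 64-bit mode active, Python floats are converted to float64 arrays.
x = jnp.asarray(1.0)
print(x.dtype)

# A difference of 1e-10 is representable in float64 but would be
# rounded away in float32.
print(float(x + 1e-10) > 1.0)
```

If the printed dtype is still ``float32``, the flag most likely did not take effect, for example because it was set
too late in the program.
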
Real-world examples
-------------------

4 changes: 2 additions & 2 deletions reinforced_lib/agents/mab/normal_thompson_sampling.py
@@ -33,7 +33,7 @@ class NormalThompsonSamplingState(AgentState):

class NormalThompsonSampling(BaseAgent):
r"""
- Normal Thompson sampling agent [11]_. The normal-inverse-gamma distribution is a conjugate prior for the normal
+ Normal Thompson sampling agent [10]_. The normal-inverse-gamma distribution is a conjugate prior for the normal
distribution with unknown mean and variance. The parameters of the distribution are updated after each observation.
The mean of the normal distribution is sampled from the normal-inverse-gamma distribution and the action with
the highest expected value is selected.
@@ -148,7 +148,7 @@ def update(
reward: Scalar
) -> NormalThompsonSamplingState:
r"""
- Normal Thompson sampling update according to [11]_.
+ Normal Thompson sampling update according to [10]_.
.. math::
\begin{align}
3 changes: 2 additions & 1 deletion reinforced_lib/agents/mab/softmax.py
@@ -32,7 +32,8 @@ class Softmax(BaseAgent):
r"""
Softmax agent with baseline and optional exponential recency-weighted average update. It learns a preference
function :math:`H`, which indicates a preference of selecting one arm over others. Algorithm policy can be
- controlled by the temperature parameter :math:`\tau`. The implementation is inspired by the work of Sutton and Barto [5]_.
+ controlled by the temperature parameter :math:`\tau`. The implementation is inspired by the work of Sutton
+ and Barto [5]_. **Note:** In some environments, this agent benefits greatly from 64-bit JAX mode.
Parameters
----------