Add normal-gamma and log-normal Thompson sampling
m-wojnar committed Dec 15, 2023
1 parent a99f3a2 commit 708295b
Showing 6 changed files with 407 additions and 1 deletion.
24 changes: 24 additions & 0 deletions docs/source/agents.rst
@@ -131,6 +131,30 @@ Thompson sampling
:members:


Normal-gamma Thompson sampling
------------------------------

.. currentmodule:: reinforced_lib.agents.mab.normal_thompson_sampling

.. autoclass:: NormalThompsonSamplingState
:show-inheritance:
:members:

.. autoclass:: NormalThompsonSampling
:show-inheritance:
:members:


Log-normal Thompson sampling
----------------------------

.. currentmodule:: reinforced_lib.agents.mab.lognormal_thompson_sampling

.. autoclass:: LogNormalThompsonSampling
:show-inheritance:
:members:


Upper confidence bound (UCB)
----------------------------

2 changes: 2 additions & 0 deletions reinforced_lib/agents/mab/__init__.py
@@ -1,5 +1,7 @@
from reinforced_lib.agents.mab.e_greedy import EGreedy
from reinforced_lib.agents.mab.exp3 import Exp3
from reinforced_lib.agents.mab.normal_thompson_sampling import NormalThompsonSampling
from reinforced_lib.agents.mab.lognormal_thompson_sampling import LogNormalThompsonSampling
from reinforced_lib.agents.mab.softmax import Softmax
from reinforced_lib.agents.mab.thompson_sampling import ThompsonSampling
from reinforced_lib.agents.mab.ucb import UCB
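
With these re-exports in place, the new agents can be imported directly from the MAB subpackage, for example:

from reinforced_lib.agents.mab import LogNormalThompsonSampling, NormalThompsonSampling
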
69 changes: 69 additions & 0 deletions reinforced_lib/agents/mab/lognormal_thompson_sampling.py
@@ -0,0 +1,69 @@
import jax
import jax.numpy as jnp
from chex import PRNGKey, Scalar

from reinforced_lib.agents.mab.normal_thompson_sampling import NormalThompsonSampling, NormalThompsonSamplingState


class LogNormalThompsonSampling(NormalThompsonSampling):
r"""
Log-normal Thompson sampling agent. This algorithm is designed for strictly positive rewards, which it transforms
into log-space. For more details, refer to the documentation of ``NormalThompsonSampling``.
"""

@staticmethod
def update(
state: NormalThompsonSamplingState,
key: PRNGKey,
action: jnp.int32,
reward: Scalar
) -> NormalThompsonSamplingState:
r"""
Log-normal Thompson sampling update. The update is analogous to the one in ``NormalThompsonSampling``, except
that the reward is first transformed into log-space.
Parameters
----------
state : NormalThompsonSamplingState
Current state of the agent.
key : PRNGKey
A PRNG key used as the random key.
action : int
Previously selected action.
reward : float
Reward obtained upon execution of action.
Returns
-------
NormalThompsonSamplingState
Updated agent state.
"""

return NormalThompsonSampling.update(state, key, action, jnp.log(reward))

@staticmethod
def sample(state: NormalThompsonSamplingState, key: PRNGKey) -> jnp.int32:
r"""
Action sampling is analogous to that in ``NormalThompsonSampling``, except that the mean of the log-normal
distribution is computed instead of the mean of the normal distribution.
Parameters
----------
state : NormalThompsonSamplingState
Current state of the agent.
key : PRNGKey
A PRNG key used as the random key.
Returns
-------
int
Selected action.
"""

loc_key, scale_key = jax.random.split(key)

scale = jnp.sqrt(NormalThompsonSampling.inverse_gamma(scale_key, state.alpha, state.beta))
loc = state.mu + jax.random.normal(loc_key, shape=state.mu.shape) * scale / jnp.sqrt(state.lam)
log_normal_mean = jnp.exp(loc + 0.5 * jnp.square(scale))

return jnp.argmax(log_normal_mean)
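
For context, a minimal usage sketch of the new agent (the prior parameters, reward model, and key handling below are illustrative assumptions):

import jax
import jax.numpy as jnp

from reinforced_lib.agents.mab.lognormal_thompson_sampling import LogNormalThompsonSampling

# Hypothetical prior parameters; any alpha, beta, lam > 0 are valid.
agent = LogNormalThompsonSampling(n_arms=4, alpha=1.0, beta=1.0, lam=1.0, mu=0.0)

key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
state = agent.init(init_key)

# Simulated interaction loop with strictly positive (illustrative) rewards.
for _ in range(100):
    key, sample_key, update_key, reward_key = jax.random.split(key, 4)
    action = agent.sample(state, sample_key)
    reward = jnp.exp(jax.random.normal(reward_key)) * (1.0 + 0.1 * action)
    state = agent.update(state, update_key, action, reward)

print(state.mu)  # posterior mean estimates of the log-rewards, one per arm
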
246 changes: 246 additions & 0 deletions reinforced_lib/agents/mab/normal_thompson_sampling.py
@@ -0,0 +1,246 @@
from functools import partial

import gymnasium as gym
import jax
import jax.numpy as jnp
from chex import dataclass, Array, PRNGKey, Scalar

from reinforced_lib.agents import BaseAgent, AgentState


@dataclass
class NormalThompsonSamplingState(AgentState):
"""
Container for the state of the normal-gamma Thompson sampling agent.
Attributes
----------
alpha : array_like
The concentration parameter of the gamma distribution.
beta : array_like
The rate parameter of the gamma distribution (equivalently, the scale parameter of the corresponding inverse-gamma distribution over the variance).
lam : array_like
The number of observations (the prior value acts as a pseudo-count).
mu : array_like
The mean of the normal distribution.
"""

alpha: Array
beta: Array
lam: Array
mu: Array


class NormalThompsonSampling(BaseAgent):
r"""
Normal-gamma [10]_ Thompson sampling agent. The implementation is based on the work of Murphy [11]_.
The normal-gamma distribution is a conjugate prior for the normal distribution with unknown mean and variance.
The parameters of the normal-gamma distribution are updated after each observation. For each arm, a mean is
sampled from the posterior normal-gamma distribution and the arm with the highest sampled mean is selected.
Parameters
----------
n_arms : int
Number of bandit arms. :math:`N \in \mathbb{N}_{+}` .
alpha : float
See also ``NormalThompsonSamplingState`` for interpretation. :math:`\alpha > 0`.
beta : float
See also ``NormalThompsonSamplingState`` for interpretation. :math:`\beta > 0`.
lam : float
See also ``NormalThompsonSamplingState`` for interpretation. :math:`\lambda > 0`.
mu : float
See also ``NormalThompsonSamplingState`` for interpretation. :math:`\mu \in \mathbb{R}`.
References
----------
.. [10] Normal-gamma distribution. Wikipedia. https://en.wikipedia.org/wiki/Normal-gamma_distribution
.. [11] Kevin P. Murphy. 2007. Conjugate Bayesian analysis of the Gaussian distribution.
https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf
"""

def __init__(
self,
n_arms: jnp.int32,
alpha: Scalar,
beta: Scalar,
lam: Scalar,
mu: Scalar
) -> None:
assert alpha > 0
assert beta > 0
assert lam > 0

self.n_arms = n_arms

self.init = jax.jit(partial(self.init, n_arms=self.n_arms, alpha=alpha, beta=beta, lam=lam, mu=mu))
self.update = jax.jit(self.update)
self.sample = jax.jit(self.sample)

@staticmethod
def parameter_space() -> gym.spaces.Dict:
return gym.spaces.Dict({
'n_arms': gym.spaces.Box(1, jnp.inf, (1,), jnp.int32),
'alpha': gym.spaces.Box(0.0, jnp.inf, (1,), jnp.float32),
'beta': gym.spaces.Box(0.0, jnp.inf, (1,), jnp.float32),
'lam': gym.spaces.Box(0.0, jnp.inf, (1,), jnp.float32),
'mu': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), jnp.float32)
})

@property
def update_observation_space(self) -> gym.spaces.Dict:
return gym.spaces.Dict({
'action': gym.spaces.Discrete(self.n_arms),
'reward': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), jnp.float32)
})

@property
def sample_observation_space(self) -> gym.spaces.Dict:
return gym.spaces.Dict({})

@property
def action_space(self) -> gym.spaces.Space:
return gym.spaces.Discrete(self.n_arms)

@staticmethod
def init(
key: PRNGKey,
n_arms: jnp.int32,
alpha: Scalar,
beta: Scalar,
lam: Scalar,
mu: Scalar
) -> NormalThompsonSamplingState:
r"""
Creates and initializes an instance of the normal-gamma Thompson sampling agent for ``n_arms`` arms
and the given initial parameters for the prior distribution.
Parameters
----------
key : PRNGKey
A PRNG key used as the random key.
n_arms : int
Number of bandit arms.
alpha : float
See also ``NormalThompsonSamplingState`` for interpretation.
beta : float
See also ``NormalThompsonSamplingState`` for interpretation.
lam : float
See also ``NormalThompsonSamplingState`` for interpretation.
mu : float
See also ``NormalThompsonSamplingState`` for interpretation.
Returns
-------
NormalThompsonSamplingState
Initial state of the normal-gamma Thompson sampling agent.
"""

return NormalThompsonSamplingState(
alpha=jnp.full((n_arms, 1), alpha),
beta=jnp.full((n_arms, 1), beta),
lam=jnp.full((n_arms, 1), lam),
mu=jnp.full((n_arms, 1), mu)
)

@staticmethod
def update(
state: NormalThompsonSamplingState,
key: PRNGKey,
action: jnp.int32,
reward: Scalar
) -> NormalThompsonSamplingState:
r"""
Normal-gamma Thompson sampling update according to [11]_.
.. math::
\begin{align}
\alpha_{t + 1}(a) &= \alpha_t(a) + \frac{1}{2} \\
\beta_{t + 1}(a) &= \beta_t(a) + \frac{\lambda_t(a) (r_t(a) - \mu_t(a))^2}{2 (\lambda_t(a) + 1)} \\
\lambda_{t + 1}(a) &= \lambda_t(a) + 1 \\
\mu_{t + 1}(a) &= \frac{\mu_t(a) \lambda_t(a) + r_t(a)}{\lambda_t(a) + 1}
\end{align}
Parameters
----------
state : NormalThompsonSamplingState
Current state of the agent.
key : PRNGKey
A PRNG key used as the random key.
action : int
Previously selected action.
reward : float
Reward obtained upon execution of action.
Returns
-------
NormalThompsonSamplingState
Updated agent state.
"""

lam = state.lam[action]
mu = state.mu[action]

return NormalThompsonSamplingState(
alpha=state.alpha.at[action].add(1 / 2),
beta=state.beta.at[action].add((lam * jnp.square(reward - mu)) / (2 * (lam + 1))),
lam=state.lam.at[action].add(1),
mu=state.mu.at[action].set((mu * lam + reward) / (lam + 1))
)

@staticmethod
def inverse_gamma(key: PRNGKey, concentration: Array, scale: Array) -> Array:
r"""
Samples from the inverse gamma distribution. Implementation is based on the gamma distribution and the
following dependence:
.. math::
\begin{gather}
X \sim \operatorname{Gamma}(\alpha, \beta) \\
\frac{1}{X} \sim \operatorname{Inverse-gamma}(\alpha, \frac{1}{\beta})
\end{gather}
Parameters
----------
key : PRNGKey
A PRNG key used as the random key.
concentration : array_like
The concentration parameter of the inverse-gamma distribution.
scale : array_like
The scale parameter of the inverse-gamma distribution.
Returns
-------
array_like
Sampled values from the inverse gamma distribution.
"""

gamma = jax.random.gamma(key, concentration) / scale
return 1 / gamma

@staticmethod
def sample(state: NormalThompsonSamplingState, key: PRNGKey) -> jnp.int32:
r"""
The normal-gamma Thompson sampling policy is stochastic. The algorithm draws :math:`q_a` from the distribution
:math:`\operatorname{Normal}(\mu(a), \operatorname{scale}(a)/\sqrt{\lambda(a)})` for each arm :math:`a` where
:math:`\operatorname{scale}(a)` is sampled from the inverse-gamma distribution with parameters :math:`\alpha(a)` and
:math:`\beta(a)`. The next action is selected as :math:`A = \operatorname*{argmax}_{a \in \mathscr{A}} q_a`,
where :math:`\mathscr{A}` is a set of all actions.
Parameters
----------
state : NormalThompsonSamplingState
Current state of the agent.
key : PRNGKey
A PRNG key used as the random key.
Returns
-------
int
Selected action.
"""

loc_key, scale_key = jax.random.split(key)

scale = jnp.sqrt(NormalThompsonSampling.inverse_gamma(scale_key, state.alpha, state.beta))
loc = state.mu + jax.random.normal(loc_key, shape=state.mu.shape) * scale / jnp.sqrt(state.lam)

return jnp.argmax(loc)
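
As a sanity check of the posterior update, a minimal sketch comparing one update step against the closed-form equations from the docstring (the prior values and the reward are illustrative):

import jax
import jax.numpy as jnp

from reinforced_lib.agents.mab.normal_thompson_sampling import NormalThompsonSampling

# Illustrative prior; alpha, beta, lam must be positive.
agent = NormalThompsonSampling(n_arms=2, alpha=1.0, beta=1.0, lam=1.0, mu=0.0)
key = jax.random.PRNGKey(42)
state = agent.init(key)

# One update of arm 0 with reward r = 2.0.
a, r = 0, 2.0
new_state = agent.update(state, key, a, r)

alpha, beta, lam, mu = 1.0, 1.0, 1.0, 0.0
assert jnp.isclose(new_state.alpha[a, 0], alpha + 0.5)
assert jnp.isclose(new_state.beta[a, 0], beta + lam * (r - mu) ** 2 / (2 * (lam + 1)))
assert jnp.isclose(new_state.lam[a, 0], lam + 1)
assert jnp.isclose(new_state.mu[a, 0], (mu * lam + r) / (lam + 1))
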
2 changes: 1 addition & 1 deletion reinforced_lib/agents/mab/ucb.py
@@ -169,7 +169,7 @@ def sample(
where :math:`\mathscr{A}` is a set of all actions and :math:`Q` is calculated as :math:`Q(a) = \frac{R(a)}{N(a)}`.
The second component of the sum represents a sort of upper bound on the value of :math:`Q`, where :math:`c`
behaves like a confidence interval and the square root - like an approximation of the :math:`Q` function
- estimation uncertainty. Note that the UCB policy is deterministic.
+ estimation uncertainty. Note that the UCB policy is deterministic (apart from choosing between several optimal actions).
Parameters
----------
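
To make the sampling rule the docstring describes concrete, a minimal sketch of a UCB score computation (assuming the common Q(a) + c * sqrt(ln t / N(a)) form of the exploration term; the array names are hypothetical, not the agent's state fields):

import jax.numpy as jnp

# Illustrative statistics for 3 arms after t = 10 steps.
R = jnp.array([4.0, 7.0, 2.0])   # cumulative rewards per arm
N = jnp.array([3.0, 5.0, 2.0])   # number of pulls per arm
c, t = 1.0, 10

Q = R / N
scores = Q + c * jnp.sqrt(jnp.log(t) / N)
action = jnp.argmax(scores)  # deterministic apart from ties
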