-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add normal-gamma and log-normal Thompson sampling
- Loading branch information
Showing
6 changed files
with
407 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from reinforced_lib.agents.mab.e_greedy import EGreedy | ||
from reinforced_lib.agents.mab.exp3 import Exp3 | ||
from reinforced_lib.agents.mab.normal_thompson_sampling import NormalThompsonSampling | ||
from reinforced_lib.agents.mab.lognormal_thompson_sampling import LogNormalThompsonSampling | ||
from reinforced_lib.agents.mab.softmax import Softmax | ||
from reinforced_lib.agents.mab.thompson_sampling import ThompsonSampling | ||
from reinforced_lib.agents.mab.ucb import UCB |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
from chex import PRNGKey, Scalar | ||
|
||
from reinforced_lib.agents.mab.normal_thompson_sampling import NormalThompsonSampling, NormalThompsonSamplingState | ||
|
||
|
||
class LogNormalThompsonSampling(NormalThompsonSampling): | ||
r""" | ||
Log-normal Thompson sampling agent. This algorithm is designed to handle positive rewards by transforming | ||
them into the log-space. For more details, refer to the documentation on ``NormalThompsonSampling``. | ||
""" | ||
|
||
@staticmethod | ||
def update( | ||
state: NormalThompsonSamplingState, | ||
key: PRNGKey, | ||
action: jnp.int32, | ||
reward: Scalar | ||
) -> NormalThompsonSamplingState: | ||
r""" | ||
Log-normal Thompson sampling update. The update is analogous to the one in ``NormalThompsonSampling`` except | ||
that the reward is transformed into the log-space. | ||
Parameters | ||
---------- | ||
state : NormalThompsonSamplingState | ||
Current state of the agent. | ||
key : PRNGKey | ||
A PRNG key used as the random key. | ||
action : int | ||
Previously selected action. | ||
reward : Float | ||
Reward obtained upon execution of action. | ||
Returns | ||
------- | ||
NormalThompsonSamplingState | ||
Updated agent state. | ||
""" | ||
|
||
return NormalThompsonSampling.update(state, key, action, jnp.log(reward)) | ||
|
||
@staticmethod | ||
def sample(state: NormalThompsonSamplingState, key: PRNGKey) -> jnp.int32: | ||
r""" | ||
Sampling actions is analogous to the one in ``NormalThompsonSampling`` except that the mean of the log-normal | ||
distribution is computed instead of the mean of the normal distribution. | ||
Parameters | ||
---------- | ||
state : NormalThompsonSamplingState | ||
Current state of the agent. | ||
key : PRNGKey | ||
A PRNG key used as the random key. | ||
Returns | ||
------- | ||
int | ||
Selected action. | ||
""" | ||
|
||
loc_key, scale_key = jax.random.split(key) | ||
|
||
scale = jnp.sqrt(NormalThompsonSampling.inverse_gamma(scale_key, state.alpha, state.beta)) | ||
loc = state.mu + jax.random.normal(loc_key, shape=state.mu.shape) * scale / jnp.sqrt(state.lam) | ||
log_normal_mean = jnp.exp(loc + 0.5 * jnp.square(scale)) | ||
|
||
return jnp.argmax(log_normal_mean) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
from functools import partial | ||
|
||
import gymnasium as gym | ||
import jax | ||
import jax.numpy as jnp | ||
from chex import dataclass, Array, PRNGKey, Scalar | ||
|
||
from reinforced_lib.agents import BaseAgent, AgentState | ||
|
||
|
||
@dataclass | ||
class NormalThompsonSamplingState(AgentState): | ||
""" | ||
Container for the state of the normal-gamma Thompson sampling agent. | ||
Attributes | ||
---------- | ||
alpha : array_like | ||
The concentration parameter of the gamma distribution. | ||
beta : array_like | ||
The scale parameter of the gamma distribution. | ||
lam : array_like | ||
The number of observations. | ||
mu : array_like | ||
The mean of the normal distribution. | ||
""" | ||
|
||
alpha: Array | ||
beta: Array | ||
lam: Array | ||
mu: Array | ||
|
||
|
||
class NormalThompsonSampling(BaseAgent): | ||
r""" | ||
Normal-gamma [10]_ Thompson sampling agent. The implementation is based on the work of Murphy [11]_. | ||
The normal-gamma distribution is a conjugate prior for the normal distribution with unknown mean and variance. | ||
The parameters of the normal-gamma distribution are updated after each observation. The mean of the normal | ||
distribution is sampled from the normal-gamma distribution and the action with the highest mean is selected. | ||
Parameters | ||
---------- | ||
n_arms : int | ||
Number of bandit arms. :math:`N \in \mathbb{N}_{+}` . | ||
alpha : float | ||
See also ``NormalThompsonSamplingState`` for interpretation. :math:`\alpha > 0`. | ||
beta : float | ||
See also ``NormalThompsonSamplingState`` for interpretation. :math:`\beta > 0`. | ||
lam : float | ||
See also ``NormalThompsonSamplingState`` for interpretation. :math:`\lambda > 0`. | ||
mu : float | ||
See also ``NormalThompsonSamplingState`` for interpretation. :math:`\mu \in \mathbb{R}`. | ||
References | ||
---------- | ||
.. [10] Normal-gamma distribution. Wikipedia. https://en.wikipedia.org/wiki/Normal-gamma_distribution | ||
.. [11] Kevin P. Murphy. 2007. Conjugate Bayesian analysis of the Gaussian distribution. | ||
https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf | ||
""" | ||
|
||
def __init__( | ||
self, | ||
n_arms: jnp.int32, | ||
alpha: Scalar, | ||
beta: Scalar, | ||
lam: Scalar, | ||
mu: Scalar | ||
) -> None: | ||
assert alpha > 0 | ||
assert beta > 0 | ||
assert lam > 0 | ||
|
||
self.n_arms = n_arms | ||
|
||
self.init = jax.jit(partial(self.init, n_arms=self.n_arms, alpha=alpha, beta=beta, lam=lam, mu=mu)) | ||
self.update = jax.jit(self.update) | ||
self.sample = jax.jit(self.sample) | ||
|
||
@staticmethod | ||
def parameter_space() -> gym.spaces.Dict: | ||
return gym.spaces.Dict({ | ||
'n_arms': gym.spaces.Box(1, jnp.inf, (1,), jnp.int32), | ||
'alpha': gym.spaces.Box(0.0, jnp.inf, (1,), jnp.float32), | ||
'beta': gym.spaces.Box(0.0, jnp.inf, (1,), jnp.float32), | ||
'lam': gym.spaces.Box(0.0, jnp.inf, (1,), jnp.float32), | ||
'mu': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), jnp.float32) | ||
}) | ||
|
||
@property | ||
def update_observation_space(self) -> gym.spaces.Dict: | ||
return gym.spaces.Dict({ | ||
'action': gym.spaces.Discrete(self.n_arms), | ||
'reward': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), jnp.float32) | ||
}) | ||
|
||
@property | ||
def sample_observation_space(self) -> gym.spaces.Dict: | ||
return gym.spaces.Dict({}) | ||
|
||
@property | ||
def action_space(self) -> gym.spaces.Space: | ||
return gym.spaces.Discrete(self.n_arms) | ||
|
||
@staticmethod | ||
def init( | ||
key: PRNGKey, | ||
n_arms: jnp.int32, | ||
alpha: Scalar, | ||
beta: Scalar, | ||
lam: Scalar, | ||
mu: Scalar | ||
) -> NormalThompsonSamplingState: | ||
r""" | ||
Creates and initializes an instance of the normal-gamma Thompson sampling agent for ``n_arms`` arms | ||
and the given initial parameters for the prior distribution. | ||
Parameters | ||
---------- | ||
key : PRNGKey | ||
A PRNG key used as the random key. | ||
n_arms : int | ||
Number of bandit arms. | ||
alpha : float | ||
See also ``NormalThompsonSamplingState`` for interpretation. | ||
beta : float | ||
See also ``NormalThompsonSamplingState`` for interpretation. | ||
lam : float | ||
See also ``NormalThompsonSamplingState`` for interpretation. | ||
mu : float | ||
See also ``NormalThompsonSamplingState`` for interpretation. | ||
Returns | ||
------- | ||
NormalThompsonSamplingState | ||
Initial state of the normal-gamma Thompson sampling agent. | ||
""" | ||
|
||
return NormalThompsonSamplingState( | ||
alpha=jnp.full((n_arms, 1), alpha), | ||
beta=jnp.full((n_arms, 1), beta), | ||
lam=jnp.full((n_arms, 1), lam), | ||
mu=jnp.full((n_arms, 1), mu) | ||
) | ||
|
||
@staticmethod | ||
def update( | ||
state: NormalThompsonSamplingState, | ||
key: PRNGKey, | ||
action: jnp.int32, | ||
reward: Scalar | ||
) -> NormalThompsonSamplingState: | ||
r""" | ||
Normal-gamma Thompson sampling update according to [11]_. | ||
.. math:: | ||
\begin{align} | ||
\alpha_{t + 1}(a) &= \alpha_t(a) + \frac{1}{2} \\ | ||
\beta_{t + 1}(a) &= \beta_t(a) + \frac{\lambda_t(a) (r_t(a) - \mu_t(a))^2}{2 (\lambda_t(a) + 1)} \\ | ||
\lambda_{t + 1}(a) &= \lambda_t(a) + 1 \\ | ||
\mu_{t + 1}(a) &= \frac{\mu_t(a) \lambda_t(a) + r_t(a)}{\lambda_t(a) + 1} | ||
\end{align} | ||
Parameters | ||
---------- | ||
state : NormalThompsonSamplingState | ||
Current state of the agent. | ||
key : PRNGKey | ||
A PRNG key used as the random key. | ||
action : int | ||
Previously selected action. | ||
reward : Float | ||
Reward obtained upon execution of action. | ||
Returns | ||
------- | ||
NormalThompsonSamplingState | ||
Updated agent state. | ||
""" | ||
|
||
lam = state.lam[action] | ||
mu = state.mu[action] | ||
|
||
return NormalThompsonSamplingState( | ||
alpha=state.alpha.at[action].add(1 / 2), | ||
beta=state.beta.at[action].add((lam * jnp.square(reward - mu)) / (2 * (lam + 1))), | ||
lam=state.lam.at[action].add(1), | ||
mu=state.mu.at[action].set((mu * lam + reward) / (lam + 1)) | ||
) | ||
|
||
@staticmethod | ||
def inverse_gamma(key: PRNGKey, concentration: Array, scale: Array) -> Array: | ||
r""" | ||
Samples from the inverse gamma distribution. Implementation is based on the gamma distribution and the | ||
following dependence: | ||
.. math:: | ||
\begin{gather} | ||
X \sim \operatorname{Gamma}(\alpha, \beta) \\ | ||
\frac{1}{X} \sim \operatorname{Inverse-gamma}(\alpha, \frac{1}{\beta}) | ||
\end{gather} | ||
Parameters | ||
---------- | ||
key : PRNGKey | ||
A PRNG key used as the random key. | ||
concentration : array_like | ||
The concentration parameter of the inverse-gamma distribution. | ||
scale : array_like | ||
The scale parameter of the inverse-gamma distribution. | ||
Returns | ||
------- | ||
array_like | ||
Sampled values from the inverse gamma distribution. | ||
""" | ||
|
||
gamma = jax.random.gamma(key, concentration) / scale | ||
return 1 / gamma | ||
|
||
@staticmethod | ||
def sample(state: NormalThompsonSamplingState, key: PRNGKey) -> jnp.int32: | ||
r""" | ||
The normal-gamma Thompson sampling policy is stochastic. The algorithm draws :math:`q_a` from the distribution | ||
:math:`\operatorname{Normal}(\mu(a), \operatorname{scale}(a)/\sqrt{\lambda(a)})` for each arm :math:`a` where | ||
:math:`\text{scale}(a)` is sampled from the inverse gamma distribution with parameters :math:`\alpha(a)` and | ||
:math:`\beta(a)`. The next action is selected as :math:`A = \operatorname*{argmax}_{a \in \mathscr{A}} q_a`, | ||
where :math:`\mathscr{A}` is a set of all actions. | ||
Parameters | ||
---------- | ||
state : NormalThompsonSamplingState | ||
Current state of the agent. | ||
key : PRNGKey | ||
A PRNG key used as the random key. | ||
Returns | ||
------- | ||
int | ||
Selected action. | ||
""" | ||
|
||
loc_key, scale_key = jax.random.split(key) | ||
|
||
scale = jnp.sqrt(NormalThompsonSampling.inverse_gamma(scale_key, state.alpha, state.beta)) | ||
loc = state.mu + jax.random.normal(loc_key, shape=state.mu.shape) * scale / jnp.sqrt(state.lam) | ||
|
||
return jnp.argmax(loc) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.