Skip to content

Commit

Permalink
Docs for external state (#20)
Browse files Browse the repository at this point in the history
* docs: updated API references

* docs: updated docs to account for state + fixed typos

* docs: fixed typo in catx shopping image example

* docs: update typos in images

* docs: fixed docstring typos

* chore: updated catx VERSION

* docs: fixed image type

* docs: removed red highlights from images
  • Loading branch information
WissBe authored Aug 3, 2022
1 parent dd9f54a commit 5dbc206
Show file tree
Hide file tree
Showing 13 changed files with 62 additions and 42 deletions.
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
![Logo_CatX_Final_PNG](docs/img/Logo_CatX_Final_PNG.png)

CATX is library for training and using a contextual bandits in a continuous action space.
CATX is a library for training and using contextual bandits in a continuous action space.

CATX builds on the work presented in
["Efficient Contextual Bandits with Continuous Actions (CATS)"](https://arxiv.org/pdf/2006.06040.pdf) by Majzoubi et al.
CATX brings forth the freedom to implement custom neural network architectures
as decision agents within the learning algorithm.
It allows for grater scalability and context modalities while
It allows for greater scalability and context modalities while
also leveraging the computational speed of [JAX](https://github.com/google/jax).

## Target users
Expand All @@ -20,22 +20,21 @@ and allowing custom neural networks in the tree structure of the CATS algorithm.


## Documentation
Go to [documentations](https://catx.readthedocs.io/en/main/)
Go to [documentation](https://catx.readthedocs.io/en/main/)
to find everything you need to know about CATX
from the [installation](https://catx.readthedocs.io/en/main/installation/) with `pip install catx`
to a quick [getting started](https://catx.readthedocs.io/en/main/getting_started/) example
and much more on CATX and its inner workings.


## Citing CATX
To cite this repository:

```
@software{catx2022github,
author = {Wissam Bejjani and Cyprien Courtot},
title = {CATX: contextual bandits library for Continuous Action Trees with smoothing in JAX},
url = {https://github.com/instadeepai/catx/},
version = {0.1.4},
version = {0.2.0},
year = {2022},
}
```
Expand Down
2 changes: 1 addition & 1 deletion catx/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.4
0.2.0
12 changes: 6 additions & 6 deletions catx/catx.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

@dataclass
class CATXState:
"""Holds the CATX's training state."""
"""Holds the CATX's training state and extra parameterization for the networks."""

params: hk.Params
depth_params: Dict[int, hk.Params]
Expand Down Expand Up @@ -196,7 +196,7 @@ def init(
"""Initializes the parameters of tree's neural networks,
the forward functions, and the optimizer states.
This functions can only be called once. It is called the first time a CATX instance is used.
This function can only be called once. It is called the first time a CATX instance is used.
Args:
obs: the observations, i.e., batched contexts.
Expand Down Expand Up @@ -251,7 +251,7 @@ def _create_forward_fn(
network_extras: NetworkExtras,
) -> Tuple[hk.Params, Wrapped]:
"""Creates a jitted forward function of the tree
and initializes the parameters of tree's neural networks.
and initializes the parameters of the tree's neural networks.
Args:
obs: the observations, i.e., batched contexts.
Expand All @@ -269,7 +269,7 @@ def _forward(
epsilon: float,
network_extras: NetworkExtras,
) -> Tuple[JaxActions, JaxProbabilities]:
"""This forward function defines how the tree is traversed and how actions sampled:
"""This forward function defines how the tree is traversed and how actions are sampled:
- All the tree logits are queried (one set of pairwise logits per tree depth).
- The tree is traversed by following the max of the logits at each
tree depth until an action centroid is reached.
Expand Down Expand Up @@ -369,7 +369,7 @@ def _create_forward_single_depth_fns(
key: chex.PRNGKey,
network_extras: NetworkExtras,
) -> Tuple[Dict[int, Wrapped], Dict[int, hk.Params]]:
"""Creates a dictionary of jitted forward functions, one per neural networks at each tree depth
"""Creates a dictionary of jitted forward functions, one per neural network at each tree depth
and initializes the parameters of these neural networks.
Args:
Expand Down Expand Up @@ -493,7 +493,7 @@ def _loss(
rng_key: chex.PRNGKey,
network_extras: NetworkExtras,
) -> JaxLoss:
"""Computes the loss function a given depth.
"""Computes the loss function at a given depth.
Args:
layer_params: a dictionary of neural network parameters with tree depth as key.
Expand Down
4 changes: 2 additions & 2 deletions catx/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TreeParameters:
bandwidth: the bucket half width covered by action centroid.
discretization_parameter: the number of action centroids.
action_space: the range from which actions can be generated.
depth: number layers in the tree.
depth: number of layers in the tree.
spaces: an array indicating the start and end range of each action centroid.
volumes: an array indicating the bandwidth of around each action centroid.
probabilities: h-smoothing of policy π_t (one over volumes).
Expand Down Expand Up @@ -55,7 +55,7 @@ def construct(
discretization_parameter & (discretization_parameter - 1)
):
raise ValueError(
"discretization_parameter must be power of 2 number and larger than 1."
"discretization_parameter must be a power of 2 number and larger than 1."
)

action_space = jnp.linspace(
Expand Down
1 change: 0 additions & 1 deletion docs/api/network_builder.md

This file was deleted.

1 change: 1 addition & 0 deletions docs/api/network_module.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: catx.network_module
44 changes: 22 additions & 22 deletions docs/bandits.md
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
# From bandits to contextual Bandits with continuous actions

## What is bandits?
The bandits problem consist of finding the action to execute among several actions
that entails the lowest cost (or highest reward).
## What are bandits?
The bandits' problem consists of finding the action to execute among several actions
that entail the lowest cost (or highest reward).

The classic bandits example goes as follows:
The classic bandits' example goes as follows:

In casino far far away, there exists multiple slot machines (also known as multi-arm bandit).
In a casino far far away, there exist multiple slot machines (also known as multi-armed bandit).
Each machine represents an action. The illustrated example shows 4 machines, i.e., 4 actions.
Some machines have higher winning probabilities.
For example, the blue machine could have a win probability of 2% while the green machine 2.5%, the yellow 1.5%, and the red 3%.
It is most advantagous to always play the red machine but these probabilities are not known to the player.
The goal is to find the best machine to play while minimizing the amount of time or money waisted on suboptimal machines.
Most approaches to this problem starts by randomly trying machines and keeping track of each machine win rate,
then gradually converge to the machine which appears to have the highest winning probability.
It is most advantageous to always play the red machine but these probabilities are not known to the player.
The goal is to find the best machine to play while minimizing the amount of time or money wasted on suboptimal machines.
Most approaches to this problem start by randomly trying machines and keeping track of each machine's win rate,
then gradually converging to the machine which appears to have the highest winning probability.

![bandits](img/bandits.png)


---
A/B testing is one of the simplest approaches to tackle the bandits problem:
A/B testing is one of the simplest approaches to tackle the bandits' problem:

* step 1: try each machine a fixed number of times

* step 2: always play the machine with the highest win rate as unraveled in step 1
* step 2: always play the machine with the highest win rate as unravelled in step 1

Note: A/B testing focuses on simplicity
at the expense of not minimizing the amount of time or money waisted on suboptimal actions.
at the expense of not minimizing the amount of time or money wasted on suboptimal actions.
---

The same bandits problem can be found under different domains.
For example, in an online shopping, the pricing algorithm has to select the profit margin to add on certain product
The same bandits' problem can be found under different domains.
For example, in an online shopping store, the pricing algorithm has to select the profit margin to add on certain product
from 4 available profit margin options, say 2%, 4%, 6%, and 8%.
A too high of a margin will drive customers away and a low margin might be missing a profit opportunity.


## What is contextual bandits?
## What are contextual bandits?
Contextual bandits are similar the previous bandits problem with a small difference.
The cost (or reward) associated to an action is also conditioned on the current context of the environment.
The contextexual information is also available to the dacision maker.
Expand All @@ -52,23 +52,23 @@ Different machine are affected differently by the context.
In the online shopping example, the contextual information could be the location and age
of the customer.

Solving a contextual bandits problem often involves learning a condtional probabilities over the actions
Solving a contextual bandits' problem often involves learning conditional probabilities over the actions
conditioned by the context.

## What is contextual bandits with continuous actions?
## What are contextual bandits with continuous actions?
From an application perspective, contextual bandits with continuous actions
are the more general case of contextual bandits where the actions
are defined over a continuous space.

Also following on the slot machine example, now instead of having 4 machines, i.e., actions, to chose from,
there is only one slot machine with an action selection knob that can be rotated between 0 and 180 degree.
Also following the slot machine example, now instead of having 4 machines, i.e., actions, to choose from,
there is only one slot machine with an action selection knob that can be rotated anywhere between 0 and 180 degrees.
The winning probability over the continuous action space is also conditioned by the temperature and humidity context.

![contextual_bandits_cont_act](img/contextual_bandits_cont_act.png)

In the online shopping example, the action range is now defined over a continuous margin value between 2% and 8%.
The continuous action space give more freedom to the pricing algorithm to finetune the margin
to minimize the regret of missing profit opportunities.
In the online shopping store example, the action range is now defined over a continuous margin value between 2% and 8%.
The continuous action space gives more freedom to the pricing algorithm to finetune the margin
to **minimize the regret** of missing profit opportunities or driving customers away.

## CATX and the online shopping example:
Below is a visual example of using CATX to learn the best margin to apply in the online shopping example.
Expand Down
6 changes: 5 additions & 1 deletion docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ class BlackFridayEnvironment:
## Training loop
One of the main advantages of CATX is the flexibility
of defining a custom neural network architecture within the tree.
In this example, we use a multilayer perceptron (MLP) network with dropouts.

The custom neural network must be a [JAX/Haiku](https://github.com/deepmind/dm-haiku)
network which inherits from CATXHaikuNetwork.
In this example, we use a multilayer perceptron (MLP)
network with dropouts that are activated during the learning step.
> **_IMPORTANT:_** The number of neurons at the output layer should be 2**(depth+1)
```python
Expand Down
Binary file removed docs/img/catx_example.png
Binary file not shown.
Binary file modified docs/img/catx_shopping.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 19 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

# CATX: contextual bandits for continuous actions using trees with smoothing in JAX

CATX is library for training and using a contextual bandits in a continuous action space.
CATX is a library for training and using contextual bandits in a continuous action space.


CATX builds on the work presented in
["Efficient Contextual Bandits with Continuous Actions (CATS)"](https://arxiv.org/pdf/2006.06040.pdf) by Majzoubi et al.
CATX brings forth the freedom to implement custom neural network architectures
as decision agents within the learning algorithm.
It allows for grater scalability and context modalities while
It allows for greater scalability and context modalities while
also leveraging the computational speed of [JAX](https://github.com/google/jax).

## Target users
Expand All @@ -19,3 +19,20 @@ Contextual bandits settings, where the exploration-exploitation trade-off needs
can be found in many industries and use cases.
CATX offers a valuable boost to this type of problem, by implementing contextual bandits with continuous actions in JAX,
and allowing custom neural networks in the tree structure of the CATS algorithm.


## Citing CATX

```
@software{catx2022github,
author = {Wissam Bejjani and Cyprien Courtot},
title = {CATX: contextual bandits library for Continuous Action Trees with smoothing in JAX},
url = {https://github.com/instadeepai/catx/},
version = {0.2.0},
year = {2022},
}
```

In this bibtex entry, the version number is intended to be from
[`catx/VERSION`](https://github.com/instadeepai/catx/blob/main/catx/VERSION),
and the year corresponds to the project's open-source release.
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@ nav:
- API Reference:
- tree: api/tree.md
- catx: api/catx.md
- network builder: api/network_builder.md
- network module: api/network_module.md

copyright: InstaDeep © 2022 Copyright, all rights reserved.
2 changes: 1 addition & 1 deletion tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_tree_parameters__discretization_parameter(
) -> None:
with pytest.raises(
ValueError,
match="discretization_parameter must be power of 2 number and larger than 1.",
match="discretization_parameter must be a power of 2 number and larger than 1.",
):
TreeParameters.construct(
bandwidth=1 / 8, discretization_parameter=discretization_parameter
Expand Down

0 comments on commit 5dbc206

Please sign in to comment.