Skip to content

Commit

Permalink
packaging updates for conda-forge
Browse files Browse the repository at this point in the history
* fix: adding examples in MANIFEST.in

* Add docs copyright (#6)

* docs: added copyrights to docs and slides

* docs: added citation reference

* docs: updated slides font size

* docs: updated slides footer format

* fix test imports (#7)

* fix: removed conftest imports for type hint

* fix: removed example imports in integration test

* chore: removed leftover comments

* update copyright (#8)

* docs: updated copyright in slides and footer

* ci: update version

* docs: updated copyright on all slides

Co-authored-by: Cyprien Courtot <c.courtot@instadeep.com>
  • Loading branch information
WissBe and Cyprien Courtot authored Jul 19, 2022
1 parent f4c97ba commit 63402ba
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 16 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include README.md
include Dockerfile
graft tests
graft catx
graft examples

include requirements.txt
include requirements-test.txt
Expand Down
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,22 @@ run `pip install --upgrade pip` then run one of the following depending on your
```

- \[Optional\] follow step 2 above for JAX to use GPU.


## Citing CATX

To cite this repository:

```
@software{catx2022github,
author = {Wissam Bejjani and Cyprien Courtot},
title = {{CATX}: contextual bandits library for {C}ontinuous {A}ction {T}rees with {S}moothing in {JAX}},
url = {https://github.com/instadeepai/catx/},
version = {0.1.2},
year = {2022},
}
```

In this bibtex entry, the version number is intended to be from
[`catx/VERSION`](https://github.com/instadeepai/catx/blob/main/catx/VERSION),
and the year corresponds to the project's open-source release.
2 changes: 1 addition & 1 deletion catx/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.1
0.1.3
Binary file modified docs/artifacts/algo_catx.pdf
Binary file not shown.
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,5 @@ nav:
- tree: api/tree.md
- catx: api/catx.md
- network builder: api/network_builder.md

copyright: InstaDeep © 2022 Copyright, all rights reserved.
59 changes: 49 additions & 10 deletions tests/integration/test_convergence.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,62 @@
from typing import List, Tuple

import haiku as hk
from typing import List, Tuple, Optional
import jax
import optax
import pytest
from jax import numpy as jnp

import haiku as hk
import tensorflow as tf
import numpy as np
from sklearn.datasets import fetch_openml
from catx.catx import CATX
from catx.network_builder import NetworkBuilder

import numpy as np

from examples.openml_environment import OpenMLEnvironment
from catx.type_defs import Actions, Costs, Observations


def moving_average(x: List[float], w: int) -> np.ndarray:
return np.convolve(x, np.ones(w), "valid") / w


class OpenMLEnvironment:
def __init__(self, dataset_id: int, batch_size: int = 5) -> None:
self.x, self.y = fetch_openml(
data_id=dataset_id, as_frame=False, return_X_y=True
)
rows_with_nan_idx = np.argwhere(np.isnan(self.x))[:, 0]
self.x = np.delete(self.x, rows_with_nan_idx, axis=0)
self.y = np.delete(self.y, rows_with_nan_idx, axis=0)
self.x = self._normalize_data(self.x)
self.y = self._normalize_data(self.y)
self._y_mean = np.mean(self.y)
physical_devices = tf.config.list_physical_devices("GPU")

try:
tf.config.experimental.set_memory_growth(physical_devices[0], True)
except Exception:
pass

self.dataset = tf.data.Dataset.from_tensor_slices((self.x, self.y))
self.dataset = self.dataset.batch(batch_size)
self.iterator = iter(self.dataset)

def get_new_observations(self) -> Optional[Observations]:
try:
x, y = self.iterator.get_next()
self.x = x.numpy()
self.y = y.numpy()
return self.x
except tf.errors.OutOfRangeError:
return None

def get_costs(self, actions: Actions) -> Costs:
costs = np.abs(actions - self.y)
return costs

def _normalize_data(self, data: np.ndarray) -> np.ndarray:
return (data - np.min(data, axis=0)) / (
np.max(data, axis=0) - np.min(data, axis=0)
)


class MLPBuilder(NetworkBuilder):
def create_network(self, depth: int) -> hk.Module:
return hk.nets.MLP([5, 5] + [2 ** (depth + 1)], name=f"mlp_depth_{depth}")
Expand Down Expand Up @@ -46,13 +85,13 @@ def test_catx_convergence(dataset_id_loss: Tuple[int, float]) -> None:
no_iterations = 1000
for _ in range(no_iterations):
env_key, subkey = jax.random.split(env_key)
obs = environment.get_new_observations(env_key)
obs = environment.get_new_observations()
if obs is None:
break

actions, probabilities = catx.sample(obs=obs, epsilon=epsilon)

costs = environment.get_costs(key=subkey, obs=obs, actions=actions)
costs = environment.get_costs(actions=actions)

catx.learn(obs, actions, probabilities, costs)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_catx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from jax import numpy as jnp

from catx.catx import CATX
from catx.network_builder import NetworkBuilder

from catx.type_defs import Observations, Actions, Probabilities, JaxObservations
from tests.conftest import MLPBuilder


@pytest.fixture
def catx(mlp_builder: MLPBuilder, request: pytest.FixtureRequest = None) -> CATX:
def catx(mlp_builder: NetworkBuilder, request: pytest.FixtureRequest = None) -> CATX:
if not request:
action_min = 0.0
action_max = 1.0
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_catx__sample_action_range(
"action_min, action_max", [(0.0, 0.0), (1.0, 0.0), (-5.0, -10.0)]
)
def test_catx__init_action_range_sad(
mlp_builder: MLPBuilder, action_min: float, action_max: float
mlp_builder: NetworkBuilder, action_min: float, action_max: float
) -> None:
rng_key = jax.random.PRNGKey(42)
rng_key, catx_key = jax.random.split(rng_key, num=2)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import numpy as np
from chex import assert_type

from catx.network_builder import NetworkBuilder
from catx.tree import TreeParameters, Tree
from catx.type_defs import JaxObservations, Logits
from tests.conftest import MLPBuilder


@pytest.mark.parametrize("bandwidth", [1.5 / 4, 1 / 8])
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_tree_parameters__bandwidth(bandwidth: float) -> None:


def test_tree(
mlp_builder: MLPBuilder,
mlp_builder: NetworkBuilder,
tree_parameters: TreeParameters,
jax_observations: JaxObservations,
) -> None:
Expand Down

0 comments on commit 63402ba

Please sign in to comment.