Skip to content

Commit

Permalink
fix: scaling pdfs with action max/min (#23)
Browse files Browse the repository at this point in the history
* fix: scaling pdfs with action max/min

* test: additional test of TreeParameters.construct

* fix: scaling the probabilities inside the update function
  • Loading branch information
Cyprien authored Sep 19, 2022
1 parent fffc1bb commit 45ba963
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ and much more on CATX and its inner workings.
author = {Wissam Bejjani and Cyprien Courtot},
title = {CATX: contextual bandits library for Continuous Action Trees with smoothing in JAX},
url = {https://github.com/instadeepai/catx/},
version = {0.2.0},
version = {0.2.1},
year = {2022},
}
```
Expand Down
2 changes: 1 addition & 1 deletion catx/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.0
0.2.1
2 changes: 2 additions & 0 deletions catx/catx.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def _forward(

# Scale sampled actions to the environment action range.
actions = actions * (self._action_max - self._action_min) + self._action_min
probabilities /= self._action_max - self._action_min

return actions, probabilities

Expand Down Expand Up @@ -553,6 +554,7 @@ def _update(

# Scale actions from the environment action range to the tree action range.
actions = (actions - self._action_min) / (self._action_max - self._action_min)
probabilities *= self._action_max - self._action_min

smooth_costs = self._compute_smooth_costs(
costs=costs, actions=actions, probabilities=probabilities
Expand Down
12 changes: 11 additions & 1 deletion tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _forward(
tree_params=tree_parameters,
)

output_logits = tree(obs=x, network_extras=network_extras)
output_logits: Dict[int, Logits] = tree(obs=x, network_extras=network_extras)

# Validate the tree has as many neural networks as depth.
assert jnp.shape(jax.tree_leaves(tree.networks))[0] == tree_parameters.depth
Expand Down Expand Up @@ -116,3 +116,13 @@ def _forward(
assert logits_shape[d] == (jnp.shape(jax_observations)[0], 2**d, 2)

chex.assert_tree_all_finite(logits)


def test_tree_parameters__probabilities_and_volumes() -> None:
tree_param = TreeParameters.construct(bandwidth=1 / 4, discretization_parameter=4)

expected_volumes = jnp.full_like(tree_param.volumes, 1 / 2)
expected_probabilities = jnp.full_like(tree_param.probabilities, 2)

assert jnp.allclose(tree_param.volumes, expected_volumes)
assert jnp.allclose(tree_param.probabilities, expected_probabilities)

0 comments on commit 45ba963

Please sign in to comment.