diff --git a/README.md b/README.md
index 66d8977..061c73e 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ and much more on CATX and its inner workings.
   author = {Wissam Bejjani and Cyprien Courtot},
   title = {CATX: contextual bandits library for Continuous Action Trees with smoothing in JAX},
   url = {https://github.com/instadeepai/catx/},
-  version = {0.2.0},
+  version = {0.2.1},
   year = {2022},
 }
 ```
diff --git a/catx/VERSION b/catx/VERSION
index 0ea3a94..0c62199 100644
--- a/catx/VERSION
+++ b/catx/VERSION
@@ -1 +1 @@
-0.2.0
+0.2.1
diff --git a/catx/catx.py b/catx/catx.py
index 84d9560..46521cc 100644
--- a/catx/catx.py
+++ b/catx/catx.py
@@ -348,6 +348,7 @@ def _forward(
 
         # Scale sampled actions to the environment action range.
         actions = actions * (self._action_max - self._action_min) + self._action_min
+        probabilities /= self._action_max - self._action_min
 
         return actions, probabilities
 
@@ -553,6 +554,7 @@ def _update(
 
         # Scale actions from the environment action range to the tree action range.
         actions = (actions - self._action_min) / (self._action_max - self._action_min)
+        probabilities *= self._action_max - self._action_min
 
         smooth_costs = self._compute_smooth_costs(
             costs=costs, actions=actions, probabilities=probabilities
diff --git a/tests/test_tree.py b/tests/test_tree.py
index 00f5408..eafe0e4 100644
--- a/tests/test_tree.py
+++ b/tests/test_tree.py
@@ -81,7 +81,7 @@ def _forward(
         tree_params=tree_parameters,
     )
 
-    output_logits = tree(obs=x, network_extras=network_extras)
+    output_logits: Dict[int, Logits] = tree(obs=x, network_extras=network_extras)
 
     # Validate the tree has as many neural networks as depth.
     assert jnp.shape(jax.tree_leaves(tree.networks))[0] == tree_parameters.depth
@@ -116,3 +116,13 @@ def _forward(
         assert logits_shape[d] == (jnp.shape(jax_observations)[0], 2**d, 2)
 
     chex.assert_tree_all_finite(logits)
+
+
+def test_tree_parameters__probabilities_and_volumes() -> None:
+    tree_param = TreeParameters.construct(bandwidth=1 / 4, discretization_parameter=4)
+
+    expected_volumes = jnp.full_like(tree_param.volumes, 1 / 2)
+    expected_probabilities = jnp.full_like(tree_param.probabilities, 2)
+
+    assert jnp.allclose(tree_param.volumes, expected_volumes)
+    assert jnp.allclose(tree_param.probabilities, expected_probabilities)
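
For reviewers, a minimal sketch of the change-of-variables reasoning behind the new `probabilities /= self._action_max - self._action_min` line in `_forward` (and the inverse rescaling in `_update`): stretching actions from the tree's unit action range onto the environment range widens each sampling interval by `action_max - action_min`, so the reported density must shrink by the same factor for probability mass to be preserved. The range and values below are illustrative only, not taken from the library.

```python
import jax.numpy as jnp

# Illustrative environment range (not a CATX default): width = 8.
action_min, action_max = -2.0, 6.0
width = action_max - action_min

# Suppose the tree reports density 2 over an interval of width 1/2 in its
# own [0, 1] action range, i.e. the probability mass over that interval is 1.
tree_interval_width = 0.5
tree_probability = 2.0

# Scaling the action to the environment range stretches the interval by
# `width`, so the density is divided by `width` (as in `_forward`).
env_interval_width = tree_interval_width * width
env_probability = tree_probability / width

# Probability mass is unchanged by the change of variables.
assert jnp.isclose(tree_probability * tree_interval_width,
                   env_probability * env_interval_width)

# Mapping back to the tree range in `_update` multiplies the density by
# `width` again, recovering the original value.
assert jnp.isclose(env_probability * width, tree_probability)
```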