diff --git a/coax/_base/test_case.py b/coax/_base/test_case.py index 411cc92..7d1d659 100644 --- a/coax/_base/test_case.py +++ b/coax/_base/test_case.py @@ -239,7 +239,7 @@ def func(S, is_training): hk.Linear(self.env_discrete.action_space.n * 51), hk.Reshape((self.env_discrete.action_space.n, 51)) )) - return seq(flatten(S)) + return {'logits': seq(flatten(S))} return func @property diff --git a/coax/policy_objectives/_deterministic_pg.py b/coax/policy_objectives/_deterministic_pg.py index e5533b2..a990cb1 100644 --- a/coax/policy_objectives/_deterministic_pg.py +++ b/coax/policy_objectives/_deterministic_pg.py @@ -4,8 +4,7 @@ import haiku as hk import chex -from ..utils import check_preprocessors -from .._core.q import Q +from ..utils import check_preprocessors, is_qfunction, is_stochastic from ._base import PolicyObjective @@ -61,7 +60,7 @@ class DeterministicPG(PolicyObjective): REQUIRES_PROPENSITIES = False def __init__(self, pi, q_targ, optimizer=None, regularizer=None): - if not isinstance(q_targ, Q): + if not is_qfunction(q_targ): raise TypeError(f"q must be a q-function, got: {type(q_targ)}") if q_targ.modeltype != 1: raise TypeError("q must be a type-1 q-function") @@ -97,7 +96,12 @@ def objective_func(self, params, state, hyperparams, rng, transition_batch, Adv) A = self.pi.proba_dist.mode(dist_params) log_pi = self.pi.proba_dist.log_proba(dist_params, A) params_q, state_q = hyperparams['q']['params'], hyperparams['q']['function_state'] - Q, _ = self.q_targ.function_type1(params_q, state_q, next(rngs), S, A, True) + if is_stochastic(self.q_targ): + dist_params_q, _ = self.q_targ.function_type1(params_q, state_q, rng, S, A, True) + Q = self.q_targ.proba_dist.mean(dist_params_q) + Q = self.q_targ.proba_dist.postprocess_variate(next(rngs), Q, batch_mode=True) + else: + Q, _ = self.q_targ.function_type1(params_q, state_q, next(rngs), S, A, True) # clip importance weights to reduce variance W = jnp.clip(transition_batch.W, 0.1, 10.) 
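For reference, DeterministicPG above (and SoftPG below) now branch on is_stochastic(q_targ): a StochasticQ returns distribution parameters, so the objective first reduces the value distribution to its mean and maps it back to raw values before it enters the surrogate loss. A simplified sketch of that shared pattern (not part of the patch; it assumes the coax q-function interface used in the hunks above and an iterator of PRNG keys):

from coax.utils import is_stochastic

def q_values_type1(q_targ, params_q, state_q, rngs, S, A):
    # sketch only: mirrors the branch added to DeterministicPG/SoftPG
    if is_stochastic(q_targ):
        # a StochasticQ returns distribution parameters; take the distribution's mean ...
        dist_params_q, _ = q_targ.function_type1(params_q, state_q, next(rngs), S, A, True)
        Q = q_targ.proba_dist.mean(dist_params_q)
        # ... and map it back from the distribution's internal representation to raw values
        return q_targ.proba_dist.postprocess_variate(next(rngs), Q, batch_mode=True)
    # a plain (deterministic) Q already returns scalar values directly
    Q, _ = q_targ.function_type1(params_q, state_q, next(rngs), S, A, True)
    return Q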
diff --git a/coax/policy_objectives/_soft_pg.py b/coax/policy_objectives/_soft_pg.py index 60b6f7c..0314eaa 100644 --- a/coax/policy_objectives/_soft_pg.py +++ b/coax/policy_objectives/_soft_pg.py @@ -4,7 +4,7 @@ from ._base import PolicyObjective -from ..utils import is_qfunction +from ..utils import is_qfunction, is_stochastic class SoftPG(PolicyObjective): @@ -37,7 +37,12 @@ def objective_func(self, params, state, hyperparams, rng, transition_batch, Adv) for q_targ, params_q, state_q in qs: # compute objective: q(s, a) S = q_targ.observation_preprocessor(next(rngs), transition_batch.S) - Q, _ = q_targ.function_type1(params_q, state_q, next(rngs), S, A, True) + if is_stochastic(q_targ): + dist_params_q, _ = q_targ.function_type1(params_q, state_q, rng, S, A, True) + Q = q_targ.proba_dist.mean(dist_params_q) + Q = q_targ.proba_dist.postprocess_variate(next(rngs), Q, batch_mode=True) + else: + Q, _ = q_targ.function_type1(params_q, state_q, next(rngs), S, A, True) Q_sa_list.append(Q) # take the min to mitigate over-estimation Q_sa_next_list = jnp.stack(Q_sa_list, axis=-1) diff --git a/coax/td_learning/_clippeddoubleqlearning.py b/coax/td_learning/_clippeddoubleqlearning.py index 7d5aeed..2b62dbe 100644 --- a/coax/td_learning/_clippeddoubleqlearning.py +++ b/coax/td_learning/_clippeddoubleqlearning.py @@ -6,12 +6,14 @@ import chex from gym.spaces import Discrete -from .._core.q import Q -from ..utils import get_grads_diagnostics, is_policy, is_stochastic, jit -from ._base import BaseTDLearning +from ..proba_dists import DiscretizedIntervalDist, EmpiricalQuantileDist +from ..utils import (get_grads_diagnostics, is_policy, is_qfunction, + is_stochastic, jit, single_to_batch, batch_to_single, stack_trees) +from ..value_losses import quantile_huber +from ._base import BaseTDLearningQ -class ClippedDoubleQLearning(BaseTDLearning): # TODO(krholshe): make this less ugly +class ClippedDoubleQLearning(BaseTDLearningQ): # TODO(krholshe): make this less ugly r""" TD-learning with `TD3 `_ style double q-learning updates, in @@ -97,18 +99,14 @@ def __init__( self, q, pi_targ_list=None, q_targ_list=None, optimizer=None, loss_function=None, policy_regularizer=None): - if is_stochastic(q): - raise NotImplementedError(f"{type(self).__name__} is not yet implement for StochasticQ") - super().__init__( - f=q, - f_targ=None, + q=q, + q_targ=None, optimizer=optimizer, loss_function=loss_function, policy_regularizer=policy_regularizer) self._check_input_lists(pi_targ_list, q_targ_list) - del self._f_targ # no need for this (only potential source of confusion) self.q_targ_list = q_targ_list self.pi_targ_list = [] if pi_targ_list is None else pi_targ_list @@ -136,11 +134,34 @@ def loss_func(params, target_params, state, target_state, rng, transition_batch) metrics.update({f'{self.__class__.__name__}/{k}': v for k, v in regularizer_metrics.items()}) - Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True) - G = self.target_func(target_params, target_state, next(rngs), transition_batch) - # flip sign (typical example: regularizer = -beta * entropy) - G -= regularizer - loss = self.loss_function(G, Q, W) + if is_stochastic(self.q): + dist_params, state_new = \ + self.q.function_type1(params, state, next(rngs), S, A, True) + dist_params_target = \ + self.target_func(target_params, target_state, rng, transition_batch) + + if self.policy_regularizer is not None: + dist_params_target = self.q.proba_dist.affine_transform( + dist_params_target, 1., -regularizer, self.q.value_transform) + + if 
isinstance(self.q.proba_dist, DiscretizedIntervalDist): + loss = jnp.mean(self.q.proba_dist.cross_entropy(dist_params_target, + dist_params)) + elif isinstance(self.q.proba_dist, EmpiricalQuantileDist): + loss = quantile_huber(dist_params_target['values'], + dist_params['values'], + dist_params['quantile_fractions'], W) + # the rest here is only needed for metrics dict + Q = self.q.proba_dist.mean(dist_params) + Q = self.q.proba_dist.postprocess_variate(next(rngs), Q, batch_mode=True) + G = self.q.proba_dist.mean(dist_params_target) + G = self.q.proba_dist.postprocess_variate(next(rngs), G, batch_mode=True) + else: + Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True) + G = self.target_func(target_params, target_state, next(rngs), transition_batch) + # flip sign (typical example: regularizer = -beta * entropy) + G -= regularizer + loss = self.loss_function(G, Q, W) dLoss_dQ = jax.grad(self.loss_function, argnums=1) td_error = -Q.shape[0] * dLoss_dQ(G, Q) # e.g. (G - Q) if loss function is MSE @@ -149,7 +170,11 @@ def loss_func(params, target_params, state, target_state, rng, transition_batch) Q_targ_list = [] qs = list(zip(self.q_targ_list, target_params['q_targ'], target_state['q_targ'])) for q, pm, st in qs: - Q_targ, _ = q.function_type1(pm, st, next(rngs), S, A, False) + if is_stochastic(q): + Q_targ = q.mean_func_type1(pm, st, next(rngs), S, A) + Q_targ = q.proba_dist.postprocess_variate(next(rngs), Q_targ, batch_mode=True) + else: + Q_targ, _ = q.function_type1(pm, st, next(rngs), S, A, False) assert Q_targ.ndim == 1, f"bad shape: {Q_targ.shape}" Q_targ_list.append(Q_targ) Q_targ_list = jnp.stack(Q_targ_list, axis=-1) @@ -184,10 +209,6 @@ def td_error_func(params, target_params, state, target_state, rng, transition_ba self._grads_and_metrics_func = jit(grads_and_metrics_func) self._td_error_func = jit(td_error_func) - @property - def q(self): - return self._f - @property def target_params(self): return hk.data_structures.to_immutable_dict({ @@ -211,12 +232,18 @@ def target_func(self, target_params, target_state, rng, transition_batch): # collect list of q-values if isinstance(self.q.action_space, Discrete): Q_sa_next_list = [] + A_next_list = [] qs = list(zip(self.q_targ_list, target_params['q_targ'], target_state['q_targ'])) # compute A_next from q_i for q_i, params_i, state_i in qs: S_next = q_i.observation_preprocessor(next(rngs), transition_batch.S_next) - Q_s_next, _ = q_i.function_type2(params_i, state_i, next(rngs), S_next, False) + if is_stochastic(q_i): + Q_s_next = q_i.mean_func_type2(params_i, state_i, next(rngs), S_next) + Q_s_next = q_i.proba_dist.postprocess_variate( + next(rngs), Q_s_next, batch_mode=True) + else: + Q_s_next, _ = q_i.function_type2(params_i, state_i, next(rngs), S_next, False) assert Q_s_next.ndim == 2, f"bad shape: {Q_s_next.shape}" A_next = (Q_s_next == Q_s_next.max(axis=1, keepdims=True)).astype(Q_s_next.dtype) A_next /= A_next.sum(axis=1, keepdims=True) # there may be ties @@ -224,14 +251,22 @@ def target_func(self, target_params, target_state, rng, transition_batch): # evaluate on q_j for q_j, params_j, state_j in qs: S_next = q_j.observation_preprocessor(next(rngs), transition_batch.S_next) - Q_sa_next, _ = q_j.function_type1( - params_j, state_j, next(rngs), S_next, A_next, False) + if is_stochastic(q_j): + Q_sa_next = q_j.mean_func_type1( + params_j, state_j, next(rngs), S_next, A_next) + Q_sa_next = q_j.proba_dist.postprocess_variate( + next(rngs), Q_sa_next, batch_mode=True) + else: + Q_sa_next, _ = q_j.function_type1( 
+ params_j, state_j, next(rngs), S_next, A_next, False) assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}" f_inv = q_j.value_transform.inverse_func Q_sa_next_list.append(f_inv(Q_sa_next)) + A_next_list.append(A_next) else: Q_sa_next_list = [] + A_next_list = [] qs = list(zip(self.q_targ_list, target_params['q_targ'], target_state['q_targ'])) pis = list(zip(self.pi_targ_list, target_params['pi_targ'], target_state['pi_targ'])) @@ -244,17 +279,55 @@ def target_func(self, target_params, target_state, rng, transition_batch): # evaluate on q_j for q_j, params_j, state_j in qs: S_next = q_j.observation_preprocessor(next(rngs), transition_batch.S_next) - Q_sa_next, _ = q_j.function_type1( - params_j, state_j, next(rngs), S_next, A_next, False) + if is_stochastic(q_j): + Q_sa_next = q_j.mean_func_type1( + params_j, state_j, next(rngs), S_next, A_next) + Q_sa_next = q_j.proba_dist.postprocess_variate( + next(rngs), Q_sa_next, batch_mode=True) + else: + Q_sa_next, _ = q_j.function_type1( + params_j, state_j, next(rngs), S_next, A_next, False) assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}" f_inv = q_j.value_transform.inverse_func Q_sa_next_list.append(f_inv(Q_sa_next)) + A_next_list.append(A_next) # take the min to mitigate over-estimation + A_next_list = jnp.stack(A_next_list, axis=1) Q_sa_next_list = jnp.stack(Q_sa_next_list, axis=-1) assert Q_sa_next_list.ndim == 2, f"bad shape: {Q_sa_next_list.shape}" - Q_sa_next = jnp.min(Q_sa_next_list, axis=-1) + if is_stochastic(self.q): + Q_sa_next_argmin = jnp.argmin(Q_sa_next_list, axis=-1) + Q_sa_next_argmin_q = Q_sa_next_argmin % len(self.q_targ_list) + + def target_dist_params(A_next_idx, q_targ_idx, p, s, t, A_next_list): + return self._get_target_dist_params(batch_to_single(p, q_targ_idx), + batch_to_single(s, q_targ_idx), + next(rngs), + single_to_batch(t), + single_to_batch(batch_to_single(A_next_list, + A_next_idx))) + + def tile_parameters(params, state, reps): + return jax.tree_util.tree_map(lambda t: jnp.tile(t, [reps, *([1] * (t.ndim - 1))]), + stack_trees(params, state)) + # stack and tile q-function params to select the argmin for the target dist params + tiled_target_params, tiled_target_state = tile_parameters( + target_params['q_targ'], target_state['q_targ'], reps=len(self.q_targ_list)) + + vtarget_dist_params = jax.vmap(target_dist_params, in_axes=(0, 0, None, None, 0, 0)) + dist_params = vtarget_dist_params( + Q_sa_next_argmin, + Q_sa_next_argmin_q, + tiled_target_params, + tiled_target_state, + transition_batch, + A_next_list) + # unwrap dist params computed for single batches + return jax.tree_util.tree_map(lambda t: jnp.squeeze(t, axis=1), dist_params) + + Q_sa_next = jnp.min(Q_sa_next_list, axis=-1) assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}" f = self.q.value_transform.transform_func return f(transition_batch.Rn + transition_batch.In * Q_sa_next) @@ -283,5 +356,5 @@ def _check_input_lists(self, pi_targ_list, q_targ_list): if not q_targ_list: raise ValueError("q_targ_list cannot be empty") for q_targ in q_targ_list: - if not isinstance(q_targ, Q): + if not is_qfunction(q_targ): raise TypeError(f"all q_targ in q_targ_list must be a coax.Q, got: {type(q_targ)}") diff --git a/coax/td_learning/_clippeddoubleqlearning_test.py b/coax/td_learning/_clippeddoubleqlearning_test.py index 26e7e85..f69918d 100644 --- a/coax/td_learning/_clippeddoubleqlearning_test.py +++ b/coax/td_learning/_clippeddoubleqlearning_test.py @@ -4,6 +4,7 @@ from .._base.test_case import TestCase from .._core.q import Q +from 
.._core.stochastic_q import StochasticQ from .._core.policy import Policy from ..utils import get_transition_batch from ._clippeddoubleqlearning import ClippedDoubleQLearning @@ -40,6 +41,31 @@ def test_update_discrete_type1(self): self.assertPytreeNotEqual(function_state1, q1.function_state) self.assertPytreeNotEqual(function_state2, q2.function_state) + def test_update_discrete_stochastic_type1(self): + env = self.env_discrete + func_q = self.func_q_stochastic_type1 + transition_batch = self.transition_discrete + + q1 = StochasticQ(func_q, env, value_range=(0, 1)) + q2 = StochasticQ(func_q, env, value_range=(0, 1)) + q_targ1 = q1.copy() + q_targ2 = q2.copy() + updater1 = ClippedDoubleQLearning(q1, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0)) + updater2 = ClippedDoubleQLearning(q2, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0)) + + params1 = deepcopy(q1.params) + params2 = deepcopy(q2.params) + function_state1 = deepcopy(q1.function_state) + function_state2 = deepcopy(q2.function_state) + + updater1.update(transition_batch) + updater2.update(transition_batch) + + self.assertPytreeNotEqual(params1, q1.params) + self.assertPytreeNotEqual(params2, q2.params) + self.assertPytreeNotEqual(function_state1, q1.function_state) + self.assertPytreeNotEqual(function_state2, q2.function_state) + def test_update_discrete_type2(self): env = self.env_discrete func_q = self.func_q_type2 @@ -65,6 +91,31 @@ def test_update_discrete_type2(self): self.assertPytreeNotEqual(function_state1, q1.function_state) self.assertPytreeNotEqual(function_state2, q2.function_state) + def test_update_discrete_stochastic_type2(self): + env = self.env_discrete + func_q = self.func_q_stochastic_type2 + transition_batch = self.transition_discrete + + q1 = StochasticQ(func_q, env, value_range=(0, 1)) + q2 = StochasticQ(func_q, env, value_range=(0, 1)) + q_targ1 = q1.copy() + q_targ2 = q2.copy() + updater1 = ClippedDoubleQLearning(q1, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0)) + updater2 = ClippedDoubleQLearning(q2, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0)) + + params1 = deepcopy(q1.params) + params2 = deepcopy(q2.params) + function_state1 = deepcopy(q1.function_state) + function_state2 = deepcopy(q2.function_state) + + updater1.update(transition_batch) + updater2.update(transition_batch) + + self.assertPytreeNotEqual(params1, q1.params) + self.assertPytreeNotEqual(params2, q2.params) + self.assertPytreeNotEqual(function_state1, q1.function_state) + self.assertPytreeNotEqual(function_state2, q2.function_state) + def test_update_boxspace(self): env = self.env_boxspace func_q = self.func_q_type1 @@ -95,6 +146,36 @@ def test_update_boxspace(self): self.assertPytreeNotEqual(function_state1, q1.function_state) self.assertPytreeNotEqual(function_state2, q2.function_state) + def test_update_boxspace_stochastic(self): + env = self.env_boxspace + func_q = self.func_q_stochastic_type1 + func_pi = self.func_pi_boxspace + transition_batch = self.transition_boxspace + + q1 = StochasticQ(func_q, env, value_range=(0, 1)) + q2 = StochasticQ(func_q, env, value_range=(0, 1)) + pi1 = Policy(func_pi, env) + pi2 = Policy(func_pi, env) + q_targ1 = q1.copy() + q_targ2 = q2.copy() + updater1 = ClippedDoubleQLearning( + q1, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0)) + updater2 = ClippedDoubleQLearning( + q2, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0)) + + params1 = deepcopy(q1.params) + params2 = deepcopy(q2.params) + function_state1 = 
deepcopy(q1.function_state) + function_state2 = deepcopy(q2.function_state) + + updater1.update(transition_batch) + updater2.update(transition_batch) + + self.assertPytreeNotEqual(params1, q1.params) + self.assertPytreeNotEqual(params2, q2.params) + self.assertPytreeNotEqual(function_state1, q1.function_state) + self.assertPytreeNotEqual(function_state2, q2.function_state) + def test_discrete_with_pi(self): env = self.env_discrete func_q = self.func_q_type1 diff --git a/coax/td_learning/_softclippeddoubleqlearning.py b/coax/td_learning/_softclippeddoubleqlearning.py index 316ef00..68bd0a8 100644 --- a/coax/td_learning/_softclippeddoubleqlearning.py +++ b/coax/td_learning/_softclippeddoubleqlearning.py @@ -1,7 +1,10 @@ -import jax.numpy as jnp import haiku as hk +import jax +import jax.numpy as jnp from gym.spaces import Discrete +from ..utils import (batch_to_single, is_stochastic, single_to_batch, + stack_trees) from ._clippeddoubleqlearning import ClippedDoubleQLearning @@ -17,12 +20,18 @@ def target_func(self, target_params, target_state, rng, transition_batch): # collect list of q-values if isinstance(self.q.action_space, Discrete): Q_sa_next_list = [] + A_next_list = [] qs = list(zip(self.q_targ_list, target_params['q_targ'], target_state['q_targ'])) # compute A_next from q_i for q_i, params_i, state_i in qs: S_next = q_i.observation_preprocessor(next(rngs), transition_batch.S_next) - Q_s_next, _ = q_i.function_type2(params_i, state_i, next(rngs), S_next, False) + if is_stochastic(q_i): + Q_s_next = q_i.mean_func_type2(params_i, state_i, next(rngs), S_next) + Q_s_next = q_i.proba_dist.postprocess_variate( + next(rngs), Q_s_next, batch_mode=True) + else: + Q_s_next, _ = q_i.function_type2(params_i, state_i, next(rngs), S_next, False) assert Q_s_next.ndim == 2, f"bad shape: {Q_s_next.shape}" A_next = (Q_s_next == Q_s_next.max(axis=1, keepdims=True)).astype(Q_s_next.dtype) A_next /= A_next.sum(axis=1, keepdims=True) # there may be ties @@ -30,14 +39,22 @@ def target_func(self, target_params, target_state, rng, transition_batch): # evaluate on q_j for q_j, params_j, state_j in qs: S_next = q_j.observation_preprocessor(next(rngs), transition_batch.S_next) - Q_sa_next, _ = q_j.function_type1( - params_j, state_j, next(rngs), S_next, A_next, False) + if is_stochastic(q_j): + Q_sa_next = q_j.mean_func_type1( + params_j, state_j, next(rngs), S_next, A_next) + Q_sa_next = q_j.proba_dist.postprocess_variate( + next(rngs), Q_sa_next, batch_mode=True) + else: + Q_sa_next, _ = q_j.function_type1( + params_j, state_j, next(rngs), S_next, A_next, False) assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}" f_inv = q_j.value_transform.inverse_func Q_sa_next_list.append(f_inv(Q_sa_next)) + A_next_list.append(A_next) else: Q_sa_next_list = [] + A_next_list = [] qs = list(zip(self.q_targ_list, target_params['q_targ'], target_state['q_targ'])) pis = list(zip(self.pi_targ_list, target_params['pi_targ'], target_state['pi_targ'])) @@ -45,22 +62,60 @@ def target_func(self, target_params, target_state, rng, transition_batch): for pi_i, params_i, state_i in pis: S_next = pi_i.observation_preprocessor(next(rngs), transition_batch.S_next) dist_params, _ = pi_i.function(params_i, state_i, next(rngs), S_next, False) - A_next = pi_i.proba_dist.sample(dist_params, next(rngs)) + A_next = pi_i.proba_dist.sample(dist_params, next(rngs)) # sample instead of mode # evaluate on q_j for q_j, params_j, state_j in qs: S_next = q_j.observation_preprocessor(next(rngs), transition_batch.S_next) - Q_sa_next, _ = 
q_j.function_type1( - params_j, state_j, next(rngs), S_next, A_next, False) + if is_stochastic(q_j): + Q_sa_next = q_j.mean_func_type1( + params_j, state_j, next(rngs), S_next, A_next) + Q_sa_next = q_j.proba_dist.postprocess_variate( + next(rngs), Q_sa_next, batch_mode=True) + else: + Q_sa_next, _ = q_j.function_type1( + params_j, state_j, next(rngs), S_next, A_next, False) assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}" f_inv = q_j.value_transform.inverse_func Q_sa_next_list.append(f_inv(Q_sa_next)) + A_next_list.append(A_next) # take the min to mitigate over-estimation + A_next_list = jnp.stack(A_next_list, axis=1) Q_sa_next_list = jnp.stack(Q_sa_next_list, axis=-1) assert Q_sa_next_list.ndim == 2, f"bad shape: {Q_sa_next_list.shape}" - Q_sa_next = jnp.min(Q_sa_next_list, axis=-1) + if is_stochastic(self.q): + Q_sa_next_argmin = jnp.argmin(Q_sa_next_list, axis=-1) + Q_sa_next_argmin_q = Q_sa_next_argmin % len(self.q_targ_list) + + def target_dist_params(A_next_idx, q_targ_idx, p, s, t, A_next_list): + return self._get_target_dist_params(batch_to_single(p, q_targ_idx), + batch_to_single(s, q_targ_idx), + next(rngs), + single_to_batch(t), + single_to_batch(batch_to_single(A_next_list, + A_next_idx))) + + def tile_parameters(params, state, reps): + return jax.tree_util.tree_map(lambda t: jnp.tile(t, [reps, *([1] * (t.ndim - 1))]), + stack_trees(params, state)) + # stack and tile q-function params to select the argmin for the target dist params + tiled_target_params, tiled_target_state = tile_parameters( + target_params['q_targ'], target_state['q_targ'], reps=len(self.q_targ_list)) + + vtarget_dist_params = jax.vmap(target_dist_params, in_axes=(0, 0, None, None, 0, 0)) + dist_params = vtarget_dist_params( + Q_sa_next_argmin, + Q_sa_next_argmin_q, + tiled_target_params, + tiled_target_state, + transition_batch, + A_next_list) + # unwrap dist params computed for single batches + return jax.tree_util.tree_map(lambda t: jnp.squeeze(t, axis=1), dist_params) + + Q_sa_next = jnp.min(Q_sa_next_list, axis=-1) assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}" f = self.q.value_transform.transform_func return f(transition_batch.Rn + transition_batch.In * Q_sa_next) diff --git a/coax/utils/__init__.py b/coax/utils/__init__.py index 56c29e8..f0d532a 100644 --- a/coax/utils/__init__.py +++ b/coax/utils/__init__.py @@ -59,6 +59,7 @@ coax.utils.render_episode coax.utils.safe_sample coax.utils.single_to_batch + coax.utils.stack_trees coax.utils.tree_ravel coax.utils.unvectorize @@ -87,6 +88,7 @@ merge_dicts, safe_sample, single_to_batch, + stack_trees, tree_ravel, unvectorize, ) @@ -165,6 +167,7 @@ 'render_episode', 'safe_sample', 'single_to_batch', + 'stack_trees', 'tree_ravel', 'unvectorize', ) diff --git a/coax/utils/_array.py b/coax/utils/_array.py index c57f963..6d5be87 100644 --- a/coax/utils/_array.py +++ b/coax/utils/_array.py @@ -31,6 +31,7 @@ 'merge_dicts', 'single_to_batch', 'safe_sample', + 'stack_trees', 'tree_ravel', 'tree_sample', 'unvectorize', @@ -752,6 +753,7 @@ class StepwiseLinearFunction: """ + def __init__(self, *steps): if len(steps) < 2: raise TypeError("need at least two steps") @@ -1053,3 +1055,20 @@ def _check_leaf_batch_size(pytree): if leaf.shape[0] != batch_size: raise TypeError("all leaves must have the same batch_size") return batch_size + + +def stack_trees(*trees): + """ + Stack + Parameters + ---------- + trees : sequence of pytrees with ndarray leaves + A typical example are pytrees containing the parameters and function states of + a model that 
should be used in a function which is vectorized by `jax.vmap`. The trees + have to have the same pytree structure. + Returns + ------- + pytree : pytree with ndarray leaves + A tuple of pytrees. + """ + return jax.tree_util.tree_multimap(lambda *args: jnp.stack(args), *zip(*trees)) diff --git a/coax/utils/_array_test.py b/coax/utils/_array_test.py index 1266e01..91a8035 100644 --- a/coax/utils/_array_test.py +++ b/coax/utils/_array_test.py @@ -98,7 +98,6 @@ def test_default_preprocessor(self): self.assertArrayShape(default_preprocessor(mds)(next(rngs), mds_batch)[0], (7, 3)) self.assertArrayShape(default_preprocessor(mds)(next(rngs), mds_batch)[1], (7, 5)) - def test_chunks_pow2(self): chunk_sizes = (2048, 1024, 512, 64, 32, 1) tn = get_transition_batch(self.env_discrete, batch_size=sum(chunk_sizes)) diff --git a/doc/_notebooks/cartpole/iqn.ipynb b/doc/_notebooks/cartpole/iqn.ipynb index 1420bfa..3c211ab 100644 --- a/doc/_notebooks/cartpole/iqn.ipynb +++ b/doc/_notebooks/cartpole/iqn.ipynb @@ -40,22 +40,18 @@ "env = gym.make('CartPole-v0')\n", "env = coax.wrappers.TrainMonitor(\n", " env, name=name, tensorboard_dir=f\"./data/tensorboard/{name}\")\n", - "quantile_embedding_dim = 32\n", + "quantile_embedding_dim = 64\n", "layer_size = 256\n", "num_quantiles = 32\n", "\n", "\n", "def quantile_net(x, quantile_fractions):\n", - " x_size = x.shape[-1]\n", - " x_tiled = jnp.tile(x[:, None, :], [num_quantiles, 1])\n", " quantiles_emb = coax.utils.quantile_cos_embedding(\n", " quantile_fractions, quantile_embedding_dim)\n", - " quantiles_emb = hk.Linear(x_size)(quantiles_emb)\n", - " quantiles_emb = hk.LayerNorm(axis=-1, create_scale=True,\n", - " create_offset=True)(quantiles_emb)\n", - " quantiles_emb = jax.nn.sigmoid(quantiles_emb)\n", - " x = x_tiled * quantiles_emb\n", - " x = hk.Linear(x_size)(x)\n", + " quantiles_emb = hk.Linear(x.shape[-1])(quantiles_emb)\n", + " quantiles_emb = jax.nn.relu(quantiles_emb)\n", + " x = x[:, None, :] * quantiles_emb\n", + " x = hk.Linear(layer_size)(x)\n", " x = jax.nn.relu(x)\n", " return x\n", "\n", @@ -66,9 +62,9 @@ " hk.Flatten(), hk.Linear(layer_size), jax.nn.relu\n", " ))\n", " quantile_fractions = coax.utils.quantiles_uniform(rng=hk.next_rng_key(),\n", - " batch_size=jax.tree_leaves(S)[0].shape[0],\n", - " num_quantiles=num_quantiles)\n", - " X = jax.vmap(jnp.kron)(S, A)\n", + " batch_size=S.shape[0],\n", + " num_quantiles=num_quantiles)\n", + " X = jnp.concatenate((S, A), axis=-1)\n", " x = encoder(X)\n", " quantile_x = quantile_net(x, quantile_fractions=quantile_fractions)\n", " quantile_values = hk.Linear(1, w_init=jnp.zeros)(quantile_x)\n", @@ -88,7 +84,7 @@ "buffer = coax.experience_replay.SimpleReplayBuffer(capacity=100000)\n", "\n", "# updater\n", - "qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=adam(0.001))\n", + "qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=adam(1e-3))\n", "\n", "\n", "# train\n", diff --git a/doc/_notebooks/pendulum/dsac.ipynb b/doc/_notebooks/pendulum/dsac.ipynb new file mode 100644 index 0000000..94431aa --- /dev/null +++ b/doc/_notebooks/pendulum/dsac.ipynb @@ -0,0 +1,194 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install git+https://github.com/coax-dev/coax.git@main" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "%tensorboard --logdir ./data/tensorboard" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gym\n", + "import jax\n", + "import coax\n", + "import haiku as hk\n", + "import jax.numpy as jnp\n", + "from numpy import prod\n", + "import optax\n", + "\n", + "\n", + "# the name of this script\n", + "name = 'dsac'\n", + "\n", + "# the Pendulum MDP\n", + "env = gym.make('Pendulum-v1')\n", + "env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f\"./data/tensorboard/{name}\")\n", + "\n", + "quantile_embedding_dim = 64\n", + "layer_size = 256\n", + "num_quantiles = 32\n", + "\n", + "\n", + "def func_pi(S, is_training):\n", + " seq = hk.Sequential((\n", + " hk.Linear(8), jax.nn.relu,\n", + " hk.Linear(8), jax.nn.relu,\n", + " hk.Linear(8), jax.nn.relu,\n", + " hk.Linear(prod(env.action_space.shape) * 2, w_init=jnp.zeros),\n", + " hk.Reshape((*env.action_space.shape, 2)),\n", + " ))\n", + " x = seq(S)\n", + " mu, logvar = x[..., 0], x[..., 1]\n", + " return {'mu': mu, 'logvar': logvar}\n", + "\n", + "\n", + "def quantile_net(x, quantile_fractions):\n", + " quantiles_emb = coax.utils.quantile_cos_embedding(\n", + " quantile_fractions, quantile_embedding_dim)\n", + " quantiles_emb = hk.Linear(x.shape[-1])(quantiles_emb)\n", + " quantiles_emb = jax.nn.relu(quantiles_emb)\n", + " x = x[:, None, :] * quantiles_emb\n", + " x = hk.Linear(layer_size)(x)\n", + " x = jax.nn.relu(x)\n", + " return x\n", + "\n", + "\n", + "def func_q(S, A, is_training):\n", + " encoder = hk.Sequential((\n", + " hk.Flatten(),\n", + " hk.Linear(layer_size),\n", + " jax.nn.relu\n", + " ))\n", + " quantile_fractions = coax.utils.quantiles_uniform(rng=hk.next_rng_key(),\n", + " batch_size=S.shape[0],\n", + " num_quantiles=num_quantiles)\n", + " X = jnp.concatenate((S, A), axis=-1)\n", + " x = encoder(X)\n", + " quantile_x = quantile_net(x, quantile_fractions=quantile_fractions)\n", + " quantile_values = hk.Linear(1)(quantile_x)\n", + " return {'values': quantile_values.squeeze(axis=-1),\n", + " 'quantile_fractions': quantile_fractions}\n", + "\n", + "\n", + "# main function approximators\n", + "pi = coax.Policy(func_pi, env)\n", + "q1 = coax.StochasticQ(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate,\n", + " value_range=None, num_bins=num_quantiles)\n", + "q2 = coax.StochasticQ(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate,\n", + " value_range=None, num_bins=num_quantiles)\n", + "\n", + "# target network\n", + "q1_targ = q1.copy()\n", + "q2_targ = q2.copy()\n", + "\n", + "# experience tracer\n", + "tracer = coax.reward_tracing.NStep(n=5, gamma=0.9, record_extra_info=True)\n", + "buffer = coax.experience_replay.SimpleReplayBuffer(capacity=50000)\n", + "alpha = 0.2\n", + "policy_regularizer = coax.regularizers.NStepEntropyRegularizer(pi,\n", + " beta=alpha / tracer.n,\n", + " gamma=tracer.gamma,\n", + " n=[tracer.n])\n", + "\n", + "# updaters (use current pi to update the q-functions and use sampled action in contrast to TD3)\n", + "qlearning1 = coax.td_learning.SoftClippedDoubleQLearning(\n", + " q1, pi_targ_list=[pi], q_targ_list=[q1_targ, q2_targ],\n", + " loss_function=coax.value_losses.mse, optimizer=optax.adam(3e-4),\n", + " policy_regularizer=policy_regularizer)\n", + "qlearning2 = coax.td_learning.SoftClippedDoubleQLearning(\n", + " q2, pi_targ_list=[pi], q_targ_list=[q1_targ, q2_targ],\n", + " loss_function=coax.value_losses.mse, optimizer=optax.adam(3e-4),\n", + " policy_regularizer=policy_regularizer)\n", + "soft_pg = coax.policy_objectives.SoftPG(pi, [q1_targ, q2_targ], 
optimizer=optax.adam(\n", + " 1e-3), regularizer=coax.regularizers.NStepEntropyRegularizer(pi,\n", + " beta=alpha / tracer.n,\n", + " gamma=tracer.gamma,\n", + " n=jnp.arange(tracer.n)))\n", + "\n", + "\n", + "# train\n", + "while env.T < 1000000:\n", + " s = env.reset()\n", + "\n", + " for t in range(env.spec.max_episode_steps):\n", + " a = pi(s)\n", + " s_next, r, done, info = env.step(a)\n", + "\n", + " # trace rewards and add transition to replay buffer\n", + " tracer.add(s, a, r, done)\n", + " while tracer:\n", + " buffer.add(tracer.pop())\n", + "\n", + " # learn\n", + " if len(buffer) >= 5000:\n", + " transition_batch = buffer.sample(batch_size=256)\n", + "\n", + " # init metrics dict\n", + " metrics = {}\n", + "\n", + " # flip a coin to decide which of the q-functions to update\n", + " qlearning = qlearning1 if jax.random.bernoulli(q1.rng) else qlearning2\n", + " metrics.update(qlearning.update(transition_batch))\n", + "\n", + " # delayed policy updates\n", + " if env.T >= 7500 and env.T % 4 == 0:\n", + " metrics.update(soft_pg.update(transition_batch))\n", + "\n", + " env.record_metrics(metrics)\n", + "\n", + " # sync target networks\n", + " q1_targ.soft_update(q1, tau=0.005)\n", + " q2_targ.soft_update(q2, tau=0.005)\n", + "\n", + " if done:\n", + " break\n", + "\n", + " s = s_next\n", + "\n", + " # generate an animated GIF to see what's going on\n", + " # if env.period(name='generate_gif', T_period=10000) and env.T > 5000:\n", + " # T = env.T - env.T % 10000 # round to 10000s\n", + " # coax.utils.generate_gif(\n", + " # env=env, policy=pi, filepath=f\"./data/gifs/{name}/T{T:08d}.gif\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/doc/_notebooks/pendulum/td4.ipynb b/doc/_notebooks/pendulum/td4.ipynb new file mode 100644 index 0000000..c1b803d --- /dev/null +++ b/doc/_notebooks/pendulum/td4.ipynb @@ -0,0 +1,186 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install git+https://github.com/coax-dev/coax.git@main" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "%tensorboard --logdir ./data/tensorboard" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gym\n", + "import jax\n", + "import coax\n", + "import haiku as hk\n", + "import jax.numpy as jnp\n", + "from numpy import prod\n", + "import optax\n", + "\n", + "\n", + "# the name of this script\n", + "name = 'td3'\n", + "\n", + "# the Pendulum MDP\n", + "env = gym.make('Pendulum-v1')\n", + "env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f\"./data/tensorboard/{name}\")\n", + "quantile_embedding_dim = 64\n", + "layer_size = 256\n", + "num_quantiles = 32\n", + "\n", + "\n", + "def func_pi(S, is_training):\n", + " seq = hk.Sequential((\n", + " hk.Linear(8), jax.nn.relu,\n", + " hk.Linear(8), jax.nn.relu,\n", + " hk.Linear(8), jax.nn.relu,\n", + " hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),\n", + " hk.Reshape(env.action_space.shape),\n", + " 
))\n", + " mu = seq(S)\n", + " return {'mu': mu, 'logvar': jnp.full_like(mu, jnp.log(0.05))} # (almost) deterministic\n", + "\n", + "\n", + "def quantile_net(x, quantile_fractions):\n", + " quantiles_emb = coax.utils.quantile_cos_embedding(\n", + " quantile_fractions, quantile_embedding_dim)\n", + " quantiles_emb = hk.Linear(x.shape[-1])(quantiles_emb)\n", + " quantiles_emb = jax.nn.relu(quantiles_emb)\n", + " x = x[:, None, :] * quantiles_emb\n", + " x = hk.Linear(layer_size)(x)\n", + " x = jax.nn.relu(x)\n", + " return x\n", + "\n", + "\n", + "def func_q(S, A, is_training):\n", + " encoder = hk.Sequential((\n", + " hk.Flatten(),\n", + " hk.Linear(layer_size),\n", + " jax.nn.relu\n", + " ))\n", + " quantile_fractions = coax.utils.quantiles_uniform(rng=hk.next_rng_key(),\n", + " batch_size=S.shape[0],\n", + " num_quantiles=num_quantiles)\n", + " X = jnp.concatenate((S, A), axis=-1)\n", + " x = encoder(X)\n", + " quantile_x = quantile_net(x, quantile_fractions=quantile_fractions)\n", + " quantile_values = hk.Linear(1)(quantile_x)\n", + " return {'values': quantile_values.squeeze(axis=-1),\n", + " 'quantile_fractions': quantile_fractions}\n", + "\n", + "\n", + "# main function approximators\n", + "pi = coax.Policy(func_pi, env)\n", + "q1 = coax.StochasticQ(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate,\n", + " value_range=None, num_bins=num_quantiles)\n", + "q2 = coax.StochasticQ(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate,\n", + " value_range=None, num_bins=num_quantiles)\n", + "\n", + "\n", + "# target network\n", + "q1_targ = q1.copy()\n", + "q2_targ = q2.copy()\n", + "pi_targ = pi.copy()\n", + "\n", + "\n", + "# experience tracer\n", + "tracer = coax.reward_tracing.NStep(n=5, gamma=0.9)\n", + "buffer = coax.experience_replay.SimpleReplayBuffer(capacity=25000)\n", + "\n", + "\n", + "# updaters\n", + "qlearning1 = coax.td_learning.ClippedDoubleQLearning(\n", + " q1, pi_targ_list=[pi_targ], q_targ_list=[q1_targ, q2_targ],\n", + " loss_function=coax.value_losses.mse, optimizer=optax.adam(1e-3))\n", + "qlearning2 = coax.td_learning.ClippedDoubleQLearning(\n", + " q2, pi_targ_list=[pi_targ], q_targ_list=[q1_targ, q2_targ],\n", + " loss_function=coax.value_losses.mse, optimizer=optax.adam(1e-3))\n", + "determ_pg = coax.policy_objectives.DeterministicPG(pi, q1_targ, optimizer=optax.adam(1e-3))\n", + "\n", + "\n", + "# train\n", + "while env.T < 1000000:\n", + " s = env.reset()\n", + "\n", + " for t in range(env.spec.max_episode_steps):\n", + " a = pi.mode(s)\n", + " s_next, r, done, info = env.step(a)\n", + "\n", + " # trace rewards and add transition to replay buffer\n", + " tracer.add(s, a, r, done)\n", + " while tracer:\n", + " buffer.add(tracer.pop())\n", + "\n", + " # learn\n", + " if len(buffer) >= 5000:\n", + " transition_batch = buffer.sample(batch_size=128)\n", + "\n", + " # init metrics dict\n", + " metrics = {}\n", + "\n", + " # flip a coin to decide which of the q-functions to update\n", + " qlearning = qlearning1 if jax.random.bernoulli(q1.rng) else qlearning2\n", + " metrics.update(qlearning.update(transition_batch))\n", + "\n", + " # delayed policy updates\n", + " if env.T >= 7500 and env.T % 4 == 0:\n", + " metrics.update(determ_pg.update(transition_batch))\n", + "\n", + " env.record_metrics(metrics)\n", + "\n", + " # sync target networks\n", + " q1_targ.soft_update(q1, tau=0.001)\n", + " q2_targ.soft_update(q2, tau=0.001)\n", + " pi_targ.soft_update(pi, tau=0.001)\n", + "\n", + " if done:\n", + " break\n", + "\n", + " s = s_next\n", + 
"\n", + " # generate an animated GIF to see what's going on\n", + " # if env.period(name='generate_gif', T_period=10000) and env.T > 5000:\n", + " # T = env.T - env.T % 10000 # round to 10000s\n", + " # coax.utils.generate_gif(\n", + " # env=env, policy=pi, filepath=f\"./data/gifs/{name}/T{T:08d}.gif\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/doc/examples/cartpole/iqn.py b/doc/examples/cartpole/iqn.py index 9266569..831cede 100644 --- a/doc/examples/cartpole/iqn.py +++ b/doc/examples/cartpole/iqn.py @@ -13,22 +13,18 @@ env = gym.make('CartPole-v0') env = coax.wrappers.TrainMonitor( env, name=name, tensorboard_dir=f"./data/tensorboard/{name}") -quantile_embedding_dim = 32 +quantile_embedding_dim = 64 layer_size = 256 num_quantiles = 32 def quantile_net(x, quantile_fractions): - x_size = x.shape[-1] - x_tiled = jnp.tile(x[:, None, :], [num_quantiles, 1]) quantiles_emb = coax.utils.quantile_cos_embedding( quantile_fractions, quantile_embedding_dim) - quantiles_emb = hk.Linear(x_size)(quantiles_emb) - quantiles_emb = hk.LayerNorm(axis=-1, create_scale=True, - create_offset=True)(quantiles_emb) - quantiles_emb = jax.nn.sigmoid(quantiles_emb) - x = x_tiled * quantiles_emb - x = hk.Linear(x_size)(x) + quantiles_emb = hk.Linear(x.shape[-1])(quantiles_emb) + quantiles_emb = jax.nn.relu(quantiles_emb) + x = x[:, None, :] * quantiles_emb + x = hk.Linear(layer_size)(x) x = jax.nn.relu(x) return x @@ -39,9 +35,9 @@ def func(S, A, is_training): hk.Flatten(), hk.Linear(layer_size), jax.nn.relu )) quantile_fractions = coax.utils.quantiles_uniform(rng=hk.next_rng_key(), - batch_size=jax.tree_leaves(S)[0].shape[0], - num_quantiles=num_quantiles) - X = jax.vmap(jnp.kron)(S, A) + batch_size=S.shape[0], + num_quantiles=num_quantiles) + X = jnp.concatenate((S, A), axis=-1) x = encoder(X) quantile_x = quantile_net(x, quantile_fractions=quantile_fractions) quantile_values = hk.Linear(1, w_init=jnp.zeros)(quantile_x) @@ -61,7 +57,7 @@ def func(S, A, is_training): buffer = coax.experience_replay.SimpleReplayBuffer(capacity=100000) # updater -qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=adam(0.001)) +qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=adam(1e-3)) # train diff --git a/doc/examples/cartpole/iqn.rst b/doc/examples/cartpole/iqn.rst index 923a83d..d1d7318 100644 --- a/doc/examples/cartpole/iqn.rst +++ b/doc/examples/cartpole/iqn.rst @@ -1,5 +1,5 @@ Cartpole with IQN -================= +==================== In this notebook we solve the `CartPole `_ environment using a simple :doc:`IQN ` agent. 
Our function approximator is an Implicit Quantile Network that diff --git a/doc/examples/pendulum/dsac.py b/doc/examples/pendulum/dsac.py new file mode 100644 index 0000000..0767622 --- /dev/null +++ b/doc/examples/pendulum/dsac.py @@ -0,0 +1,142 @@ +import gym +import jax +import coax +import haiku as hk +import jax.numpy as jnp +from numpy import prod +import optax + + +# the name of this script +name = 'dsac' + +# the Pendulum MDP +env = gym.make('Pendulum-v1') +env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}") + +quantile_embedding_dim = 64 +layer_size = 256 +num_quantiles = 32 + + +def func_pi(S, is_training): + seq = hk.Sequential(( + hk.Linear(8), jax.nn.relu, + hk.Linear(8), jax.nn.relu, + hk.Linear(8), jax.nn.relu, + hk.Linear(prod(env.action_space.shape) * 2, w_init=jnp.zeros), + hk.Reshape((*env.action_space.shape, 2)), + )) + x = seq(S) + mu, logvar = x[..., 0], x[..., 1] + return {'mu': mu, 'logvar': logvar} + + +def quantile_net(x, quantile_fractions): + quantiles_emb = coax.utils.quantile_cos_embedding( + quantile_fractions, quantile_embedding_dim) + quantiles_emb = hk.Linear(x.shape[-1])(quantiles_emb) + quantiles_emb = jax.nn.relu(quantiles_emb) + x = x[:, None, :] * quantiles_emb + x = hk.Linear(layer_size)(x) + x = jax.nn.relu(x) + return x + + +def func_q(S, A, is_training): + encoder = hk.Sequential(( + hk.Flatten(), + hk.Linear(layer_size), + jax.nn.relu + )) + quantile_fractions = coax.utils.quantiles_uniform(rng=hk.next_rng_key(), + batch_size=S.shape[0], + num_quantiles=num_quantiles) + X = jnp.concatenate((S, A), axis=-1) + x = encoder(X) + quantile_x = quantile_net(x, quantile_fractions=quantile_fractions) + quantile_values = hk.Linear(1)(quantile_x) + return {'values': quantile_values.squeeze(axis=-1), + 'quantile_fractions': quantile_fractions} + + +# main function approximators +pi = coax.Policy(func_pi, env) +q1 = coax.StochasticQ(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate, + value_range=None, num_bins=num_quantiles) +q2 = coax.StochasticQ(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate, + value_range=None, num_bins=num_quantiles) + +# target network +q1_targ = q1.copy() +q2_targ = q2.copy() + +# experience tracer +tracer = coax.reward_tracing.NStep(n=5, gamma=0.9, record_extra_info=True) +buffer = coax.experience_replay.SimpleReplayBuffer(capacity=50000) +alpha = 0.2 +policy_regularizer = coax.regularizers.NStepEntropyRegularizer(pi, + beta=alpha / tracer.n, + gamma=tracer.gamma, + n=[tracer.n]) + +# updaters (use current pi to update the q-functions and use sampled action in contrast to TD3) +qlearning1 = coax.td_learning.SoftClippedDoubleQLearning( + q1, pi_targ_list=[pi], q_targ_list=[q1_targ, q2_targ], + loss_function=coax.value_losses.mse, optimizer=optax.adam(3e-4), + policy_regularizer=policy_regularizer) +qlearning2 = coax.td_learning.SoftClippedDoubleQLearning( + q2, pi_targ_list=[pi], q_targ_list=[q1_targ, q2_targ], + loss_function=coax.value_losses.mse, optimizer=optax.adam(3e-4), + policy_regularizer=policy_regularizer) +soft_pg = coax.policy_objectives.SoftPG(pi, [q1_targ, q2_targ], optimizer=optax.adam( + 1e-3), regularizer=coax.regularizers.NStepEntropyRegularizer(pi, + beta=alpha / tracer.n, + gamma=tracer.gamma, + n=jnp.arange(tracer.n))) + + +# train +while env.T < 1000000: + s = env.reset() + + for t in range(env.spec.max_episode_steps): + a = pi(s) + s_next, r, done, info = env.step(a) + + # trace rewards and add transition to replay buffer + 
tracer.add(s, a, r, done) + while tracer: + buffer.add(tracer.pop()) + + # learn + if len(buffer) >= 5000: + transition_batch = buffer.sample(batch_size=256) + + # init metrics dict + metrics = {} + + # flip a coin to decide which of the q-functions to update + qlearning = qlearning1 if jax.random.bernoulli(q1.rng) else qlearning2 + metrics.update(qlearning.update(transition_batch)) + + # delayed policy updates + if env.T >= 7500 and env.T % 4 == 0: + metrics.update(soft_pg.update(transition_batch)) + + env.record_metrics(metrics) + + # sync target networks + q1_targ.soft_update(q1, tau=0.005) + q2_targ.soft_update(q2, tau=0.005) + + if done: + break + + s = s_next + + # generate an animated GIF to see what's going on + # if env.period(name='generate_gif', T_period=10000) and env.T > 5000: + # T = env.T - env.T % 10000 # round to 10000s + # coax.utils.generate_gif( + # env=env, policy=pi, filepath=f"./data/gifs/{name}/T{T:08d}.gif") diff --git a/doc/examples/pendulum/dsac.rst b/doc/examples/pendulum/dsac.rst new file mode 100644 index 0000000..6e7eac5 --- /dev/null +++ b/doc/examples/pendulum/dsac.rst @@ -0,0 +1,25 @@ +Pendulum with DSAC +================== + +In this notebook we solve the `Pendulum `_ environment +using `DSAC`, the distributional variant of `SAC`. We follow the `implementation <https://arxiv.org/abs/2004.14547>`_ +by using quantile regression to approximate the q function. + +This notebook periodically generates GIFs, so that we can inspect how the training is progressing. + +After a few hundred episodes, this is what you can expect: + +.. image:: /_static/img/pendulum.gif + :alt: Successfully swinging up the pendulum. + :width: 360px + :align: center + +---- + +:download:`dsac.py` + +.. image:: https://colab.research.google.com/assets/colab-badge.svg + :alt: Open in Google Colab + :target: https://colab.research.google.com/github/coax-dev/coax/blob/main/doc/_notebooks/pendulum/dsac.ipynb + +..
literalinclude:: dsac.py diff --git a/doc/examples/pendulum/index.rst b/doc/examples/pendulum/index.rst index b3f696e..d33a7d5 100644 --- a/doc/examples/pendulum/index.rst +++ b/doc/examples/pendulum/index.rst @@ -10,4 +10,6 @@ In these notebooks we solve the `Pendulum PPO TD3 - SAC \ No newline at end of file + SAC + TD4 + DSAC \ No newline at end of file diff --git a/doc/examples/pendulum/td4.py b/doc/examples/pendulum/td4.py new file mode 100644 index 0000000..b129a1e --- /dev/null +++ b/doc/examples/pendulum/td4.py @@ -0,0 +1,134 @@ +import gym +import jax +import coax +import haiku as hk +import jax.numpy as jnp +from numpy import prod +import optax + + +# the name of this script +name = 'td3' + +# the Pendulum MDP +env = gym.make('Pendulum-v1') +env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}") +quantile_embedding_dim = 64 +layer_size = 256 +num_quantiles = 32 + + +def func_pi(S, is_training): + seq = hk.Sequential(( + hk.Linear(8), jax.nn.relu, + hk.Linear(8), jax.nn.relu, + hk.Linear(8), jax.nn.relu, + hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros), + hk.Reshape(env.action_space.shape), + )) + mu = seq(S) + return {'mu': mu, 'logvar': jnp.full_like(mu, jnp.log(0.05))} # (almost) deterministic + + +def quantile_net(x, quantile_fractions): + quantiles_emb = coax.utils.quantile_cos_embedding( + quantile_fractions, quantile_embedding_dim) + quantiles_emb = hk.Linear(x.shape[-1])(quantiles_emb) + quantiles_emb = jax.nn.relu(quantiles_emb) + x = x[:, None, :] * quantiles_emb + x = hk.Linear(layer_size)(x) + x = jax.nn.relu(x) + return x + + +def func_q(S, A, is_training): + encoder = hk.Sequential(( + hk.Flatten(), + hk.Linear(layer_size), + jax.nn.relu + )) + quantile_fractions = coax.utils.quantiles_uniform(rng=hk.next_rng_key(), + batch_size=S.shape[0], + num_quantiles=num_quantiles) + X = jnp.concatenate((S, A), axis=-1) + x = encoder(X) + quantile_x = quantile_net(x, quantile_fractions=quantile_fractions) + quantile_values = hk.Linear(1)(quantile_x) + return {'values': quantile_values.squeeze(axis=-1), + 'quantile_fractions': quantile_fractions} + + +# main function approximators +pi = coax.Policy(func_pi, env) +q1 = coax.StochasticQ(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate, + value_range=None, num_bins=num_quantiles) +q2 = coax.StochasticQ(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate, + value_range=None, num_bins=num_quantiles) + + +# target network +q1_targ = q1.copy() +q2_targ = q2.copy() +pi_targ = pi.copy() + + +# experience tracer +tracer = coax.reward_tracing.NStep(n=5, gamma=0.9) +buffer = coax.experience_replay.SimpleReplayBuffer(capacity=25000) + + +# updaters +qlearning1 = coax.td_learning.ClippedDoubleQLearning( + q1, pi_targ_list=[pi_targ], q_targ_list=[q1_targ, q2_targ], + loss_function=coax.value_losses.mse, optimizer=optax.adam(1e-3)) +qlearning2 = coax.td_learning.ClippedDoubleQLearning( + q2, pi_targ_list=[pi_targ], q_targ_list=[q1_targ, q2_targ], + loss_function=coax.value_losses.mse, optimizer=optax.adam(1e-3)) +determ_pg = coax.policy_objectives.DeterministicPG(pi, q1_targ, optimizer=optax.adam(1e-3)) + + +# train +while env.T < 1000000: + s = env.reset() + + for t in range(env.spec.max_episode_steps): + a = pi.mode(s) + s_next, r, done, info = env.step(a) + + # trace rewards and add transition to replay buffer + tracer.add(s, a, r, done) + while tracer: + buffer.add(tracer.pop()) + + # learn + if len(buffer) >= 5000: + transition_batch = 
buffer.sample(batch_size=128) + + # init metrics dict + metrics = {} + + # flip a coin to decide which of the q-functions to update + qlearning = qlearning1 if jax.random.bernoulli(q1.rng) else qlearning2 + metrics.update(qlearning.update(transition_batch)) + + # delayed policy updates + if env.T >= 7500 and env.T % 4 == 0: + metrics.update(determ_pg.update(transition_batch)) + + env.record_metrics(metrics) + + # sync target networks + q1_targ.soft_update(q1, tau=0.001) + q2_targ.soft_update(q2, tau=0.001) + pi_targ.soft_update(pi, tau=0.001) + + if done: + break + + s = s_next + + # generate an animated GIF to see what's going on + # if env.period(name='generate_gif', T_period=10000) and env.T > 5000: + # T = env.T - env.T % 10000 # round to 10000s + # coax.utils.generate_gif( + # env=env, policy=pi, filepath=f"./data/gifs/{name}/T{T:08d}.gif") diff --git a/doc/examples/pendulum/td4.rst b/doc/examples/pendulum/td4.rst new file mode 100644 index 0000000..8f84304 --- /dev/null +++ b/doc/examples/pendulum/td4.rst @@ -0,0 +1,25 @@ +Pendulum with TD4 +================== + +In this notebook we solve the `Pendulum `_ environment +using TD4 which is the distributional variant of :doc:`TD3 `. We estimate the q function using quantile +regression as in :doc:`IQN `. + +This notebook periodically generates GIFs, so that we can inspect how the training is progressing. + +After a few hundred episodes, this is what you can expect: + +.. image:: /_static/img/pendulum.gif + :alt: Successfully swinging up the pendulum. + :width: 360px + :align: center + +---- + +:download:`td4.py` + +.. image:: https://colab.research.google.com/assets/colab-badge.svg + :alt: Open in Google Colab + :target: https://colab.research.google.com/github/coax-dev/coax/blob/main/doc/_notebooks/pendulum/td4.ipynb + +.. literalinclude:: td4.py diff --git a/doc/versions.html b/doc/versions.html index bf800b8..96ff5d2 100644 --- a/doc/versions.html +++ b/doc/versions.html @@ -109,7 +109,7 @@ function updateCommand() { var codecellName = 'codecell0'; - var jaxlibVersion = '0.1.73'; // this is automatically updated from conf.py + var jaxlibVersion = '0.1.76'; // this is automatically updated from conf.py // get the selected os version var osVersion = null;
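The vmap-based target computation in ClippedDoubleQLearning and SoftClippedDoubleQLearning relies on the new coax.utils.stack_trees helper added in coax/utils/_array.py: it stacks sequences of identically-structured pytrees leaf-wise along a new leading axis, so that the per-target-network parameters can be tiled and then indexed inside the vmapped function. A minimal usage sketch with toy, hypothetical pytrees (not taken from the patch):

import jax
import jax.numpy as jnp
from coax.utils import stack_trees

# two toy params/state pytrees standing in for the per-q-function params above
params_list = [{'w': jnp.ones((3,))}, {'w': 2 * jnp.ones((3,))}]
states_list = [{'count': jnp.zeros(())}, {'count': jnp.ones(())}]

# each input sequence is stacked leaf-wise along a new leading axis
stacked_params, stacked_states = stack_trees(params_list, states_list)
assert stacked_params['w'].shape == (2, 3)
assert stacked_states['count'].shape == (2,)

# the leading axis can then be mapped over (or indexed into) with jax.vmap
def apply_single(p, x):
    return p['w'] @ x  # hypothetical per-model computation

out = jax.vmap(apply_single, in_axes=(0, None))(stacked_params, jnp.ones(3))
assert out.shape == (2,)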