Commit

Implementation of StochasticQ in ClippedDoubleQLearning (#18)
* wip dsac

* implement dsac with quantile network

* small changes

* make dsac and td4 work

* add some documentation for td3 and dsac
frederikschubert authored May 3, 2022
1 parent b2affe5 commit 3362952
Showing 20 changed files with 1,011 additions and 72 deletions.
2 changes: 1 addition & 1 deletion coax/_base/test_case.py
@@ -239,7 +239,7 @@ def func(S, is_training):
hk.Linear(self.env_discrete.action_space.n * 51),
hk.Reshape((self.env_discrete.action_space.n, 51))
))
return seq(flatten(S))
return {'logits': seq(flatten(S))}
return func

@property
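The one-line change above swaps a raw (action, bin) tensor for a dict of distribution parameters, which is the format StochasticQ consumes. A minimal sketch of such a forward pass as it would be handed to coax.StochasticQ; the layer sizes and the names num_actions/num_bins below are illustrative, not taken from the repository:

import haiku as hk
import jax

num_actions = 3   # e.g. env.action_space.n (placeholder value)
num_bins = 51     # atoms of the discretized value distribution (placeholder value)

def func_q_stochastic_type2(S, is_training):
    """Type-2 stochastic q-function: distribution params for every action."""
    seq = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        hk.Linear(num_actions * num_bins),
        hk.Reshape((num_actions, num_bins)),
    ))
    return {'logits': seq(S)}  # dict of dist params instead of raw q-values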
12 changes: 8 additions & 4 deletions coax/policy_objectives/_deterministic_pg.py
@@ -4,8 +4,7 @@
import haiku as hk
import chex

from ..utils import check_preprocessors
from .._core.q import Q
from ..utils import check_preprocessors, is_qfunction, is_stochastic
from ._base import PolicyObjective


@@ -61,7 +60,7 @@ class DeterministicPG(PolicyObjective):
REQUIRES_PROPENSITIES = False

def __init__(self, pi, q_targ, optimizer=None, regularizer=None):
if not isinstance(q_targ, Q):
if not is_qfunction(q_targ):
raise TypeError(f"q must be a q-function, got: {type(q_targ)}")
if q_targ.modeltype != 1:
raise TypeError("q must be a type-1 q-function")
@@ -97,7 +96,12 @@ def objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
A = self.pi.proba_dist.mode(dist_params)
log_pi = self.pi.proba_dist.log_proba(dist_params, A)
params_q, state_q = hyperparams['q']['params'], hyperparams['q']['function_state']
Q, _ = self.q_targ.function_type1(params_q, state_q, next(rngs), S, A, True)
if is_stochastic(self.q_targ):
dist_params_q, _ = self.q_targ.function_type1(params_q, state_q, rng, S, A, True)
Q = self.q_targ.proba_dist.mean(dist_params_q)
Q = self.q_targ.proba_dist.postprocess_variate(next(rngs), Q, batch_mode=True)
else:
Q, _ = self.q_targ.function_type1(params_q, state_q, next(rngs), S, A, True)

# clip importance weights to reduce variance
W = jnp.clip(transition_batch.W, 0.1, 10.)
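The upshot is that DeterministicPG now also accepts a distributional critic and uses the mean of its value distribution (mapped back to the raw value scale) as the scalar q-value in the objective. A minimal usage sketch; func_pi, func_q and env are placeholders assumed to be defined by the caller:

import coax
import optax

pi = coax.Policy(func_pi, env)
q = coax.StochasticQ(func_q, env, value_range=(-10., 10.))  # distributional critic
q_targ = q.copy()

# the objective internally takes the mean of q_targ's value distribution
determ_pg = coax.policy_objectives.DeterministicPG(pi, q_targ, optimizer=optax.adam(1e-3))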
9 changes: 7 additions & 2 deletions coax/policy_objectives/_soft_pg.py
@@ -4,7 +4,7 @@


from ._base import PolicyObjective
from ..utils import is_qfunction
from ..utils import is_qfunction, is_stochastic


class SoftPG(PolicyObjective):
@@ -37,7 +37,12 @@ def objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
for q_targ, params_q, state_q in qs:
# compute objective: q(s, a)
S = q_targ.observation_preprocessor(next(rngs), transition_batch.S)
Q, _ = q_targ.function_type1(params_q, state_q, next(rngs), S, A, True)
if is_stochastic(q_targ):
dist_params_q, _ = q_targ.function_type1(params_q, state_q, rng, S, A, True)
Q = q_targ.proba_dist.mean(dist_params_q)
Q = q_targ.proba_dist.postprocess_variate(next(rngs), Q, batch_mode=True)
else:
Q, _ = q_targ.function_type1(params_q, state_q, next(rngs), S, A, True)
Q_sa_list.append(Q)
# take the min to mitigate over-estimation
Q_sa_next_list = jnp.stack(Q_sa_list, axis=-1)
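As in the deterministic objective, each critic contributes one scalar per transition (the distribution mean for stochastic critics) and the element-wise minimum over critics is used to curb over-estimation. A tiny numeric illustration with made-up values:

import jax.numpy as jnp

Q_sa_list = [jnp.array([1.0, 2.0, 3.0, 4.0]),   # values from critic 1
             jnp.array([1.5, 1.8, 3.2, 3.9])]   # values from critic 2

Q_sa_min = jnp.min(jnp.stack(Q_sa_list, axis=-1), axis=-1)
# -> [1.0, 1.8, 3.0, 3.9]  (pessimistic estimate per transition)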
127 changes: 100 additions & 27 deletions coax/td_learning/_clippeddoubleqlearning.py
@@ -6,12 +6,14 @@
import chex
from gym.spaces import Discrete

from .._core.q import Q
from ..utils import get_grads_diagnostics, is_policy, is_stochastic, jit
from ._base import BaseTDLearning
from ..proba_dists import DiscretizedIntervalDist, EmpiricalQuantileDist
from ..utils import (get_grads_diagnostics, is_policy, is_qfunction,
is_stochastic, jit, single_to_batch, batch_to_single, stack_trees)
from ..value_losses import quantile_huber
from ._base import BaseTDLearningQ


class ClippedDoubleQLearning(BaseTDLearning): # TODO(krholshe): make this less ugly
class ClippedDoubleQLearning(BaseTDLearningQ): # TODO(krholshe): make this less ugly
r"""
TD-learning with `TD3 <https://arxiv.org/abs/1802.09477>`_ style double q-learning updates, in
@@ -97,18 +99,14 @@ def __init__(
self, q, pi_targ_list=None, q_targ_list=None,
optimizer=None, loss_function=None, policy_regularizer=None):

if is_stochastic(q):
raise NotImplementedError(f"{type(self).__name__} is not yet implement for StochasticQ")

super().__init__(
f=q,
f_targ=None,
q=q,
q_targ=None,
optimizer=optimizer,
loss_function=loss_function,
policy_regularizer=policy_regularizer)

self._check_input_lists(pi_targ_list, q_targ_list)
del self._f_targ # no need for this (only potential source of confusion)
self.q_targ_list = q_targ_list
self.pi_targ_list = [] if pi_targ_list is None else pi_targ_list

@@ -136,11 +134,34 @@ def loss_func(params, target_params, state, target_state, rng, transition_batch)
metrics.update({f'{self.__class__.__name__}/{k}': v for k,
v in regularizer_metrics.items()})

Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
G = self.target_func(target_params, target_state, next(rngs), transition_batch)
# flip sign (typical example: regularizer = -beta * entropy)
G -= regularizer
loss = self.loss_function(G, Q, W)
if is_stochastic(self.q):
dist_params, state_new = \
self.q.function_type1(params, state, next(rngs), S, A, True)
dist_params_target = \
self.target_func(target_params, target_state, rng, transition_batch)

if self.policy_regularizer is not None:
dist_params_target = self.q.proba_dist.affine_transform(
dist_params_target, 1., -regularizer, self.q.value_transform)

if isinstance(self.q.proba_dist, DiscretizedIntervalDist):
loss = jnp.mean(self.q.proba_dist.cross_entropy(dist_params_target,
dist_params))
elif isinstance(self.q.proba_dist, EmpiricalQuantileDist):
loss = quantile_huber(dist_params_target['values'],
dist_params['values'],
dist_params['quantile_fractions'], W)
# the rest here is only needed for metrics dict
Q = self.q.proba_dist.mean(dist_params)
Q = self.q.proba_dist.postprocess_variate(next(rngs), Q, batch_mode=True)
G = self.q.proba_dist.mean(dist_params_target)
G = self.q.proba_dist.postprocess_variate(next(rngs), G, batch_mode=True)
else:
Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
G = self.target_func(target_params, target_state, next(rngs), transition_batch)
# flip sign (typical example: regularizer = -beta * entropy)
G -= regularizer
loss = self.loss_function(G, Q, W)

dLoss_dQ = jax.grad(self.loss_function, argnums=1)
td_error = -Q.shape[0] * dLoss_dQ(G, Q) # e.g. (G - Q) if loss function is MSE
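In the stochastic branch the loss is distributional: cross-entropy between target and predicted bin probabilities for DiscretizedIntervalDist (categorical, C51-style), and the quantile Huber loss for EmpiricalQuantileDist (QR-DQN-style). Note also that the policy regularizer is folded in by an affine shift of the target distribution rather than subtracted from a scalar. A plain-JAX sketch of the quantile Huber idea follows; this is the standard QR-DQN form, not necessarily the exact coax.value_losses.quantile_huber implementation:

import jax.numpy as jnp

def quantile_huber_sketch(y_true, y_pred, quantile_fractions, kappa=1.0):
    # y_true: (batch, num_targets), y_pred: (batch, num_quantiles)
    td = y_true[:, None, :] - y_pred[:, :, None]          # pairwise TD errors
    huber = jnp.where(jnp.abs(td) <= kappa,
                      0.5 * td ** 2,
                      kappa * (jnp.abs(td) - 0.5 * kappa))
    tau = quantile_fractions[:, :, None]                  # (batch, num_quantiles, 1)
    weight = jnp.abs(tau - (td < 0.).astype(td.dtype))    # asymmetric quantile weighting
    per_sample = jnp.sum(jnp.mean(weight * huber / kappa, axis=2), axis=1)
    return jnp.mean(per_sample)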
@@ -149,7 +170,11 @@ def loss_func(params, target_params, state, target_state, rng, transition_batch)
Q_targ_list = []
qs = list(zip(self.q_targ_list, target_params['q_targ'], target_state['q_targ']))
for q, pm, st in qs:
Q_targ, _ = q.function_type1(pm, st, next(rngs), S, A, False)
if is_stochastic(q):
Q_targ = q.mean_func_type1(pm, st, next(rngs), S, A)
Q_targ = q.proba_dist.postprocess_variate(next(rngs), Q_targ, batch_mode=True)
else:
Q_targ, _ = q.function_type1(pm, st, next(rngs), S, A, False)
assert Q_targ.ndim == 1, f"bad shape: {Q_targ.shape}"
Q_targ_list.append(Q_targ)
Q_targ_list = jnp.stack(Q_targ_list, axis=-1)
@@ -184,10 +209,6 @@ def td_error_func(params, target_params, state, target_state, rng, transition_batch)
self._grads_and_metrics_func = jit(grads_and_metrics_func)
self._td_error_func = jit(td_error_func)

@property
def q(self):
return self._f

@property
def target_params(self):
return hk.data_structures.to_immutable_dict({
@@ -211,27 +232,41 @@ def target_func(self, target_params, target_state, rng, transition_batch):
# collect list of q-values
if isinstance(self.q.action_space, Discrete):
Q_sa_next_list = []
A_next_list = []
qs = list(zip(self.q_targ_list, target_params['q_targ'], target_state['q_targ']))

# compute A_next from q_i
for q_i, params_i, state_i in qs:
S_next = q_i.observation_preprocessor(next(rngs), transition_batch.S_next)
Q_s_next, _ = q_i.function_type2(params_i, state_i, next(rngs), S_next, False)
if is_stochastic(q_i):
Q_s_next = q_i.mean_func_type2(params_i, state_i, next(rngs), S_next)
Q_s_next = q_i.proba_dist.postprocess_variate(
next(rngs), Q_s_next, batch_mode=True)
else:
Q_s_next, _ = q_i.function_type2(params_i, state_i, next(rngs), S_next, False)
assert Q_s_next.ndim == 2, f"bad shape: {Q_s_next.shape}"
A_next = (Q_s_next == Q_s_next.max(axis=1, keepdims=True)).astype(Q_s_next.dtype)
A_next /= A_next.sum(axis=1, keepdims=True) # there may be ties

# evaluate on q_j
for q_j, params_j, state_j in qs:
S_next = q_j.observation_preprocessor(next(rngs), transition_batch.S_next)
Q_sa_next, _ = q_j.function_type1(
params_j, state_j, next(rngs), S_next, A_next, False)
if is_stochastic(q_j):
Q_sa_next = q_j.mean_func_type1(
params_j, state_j, next(rngs), S_next, A_next)
Q_sa_next = q_j.proba_dist.postprocess_variate(
next(rngs), Q_sa_next, batch_mode=True)
else:
Q_sa_next, _ = q_j.function_type1(
params_j, state_j, next(rngs), S_next, A_next, False)
assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}"
f_inv = q_j.value_transform.inverse_func
Q_sa_next_list.append(f_inv(Q_sa_next))
A_next_list.append(A_next)

else:
Q_sa_next_list = []
A_next_list = []
qs = list(zip(self.q_targ_list, target_params['q_targ'], target_state['q_targ']))
pis = list(zip(self.pi_targ_list, target_params['pi_targ'], target_state['pi_targ']))

@@ -244,17 +279,55 @@
# evaluate on q_j
for q_j, params_j, state_j in qs:
S_next = q_j.observation_preprocessor(next(rngs), transition_batch.S_next)
Q_sa_next, _ = q_j.function_type1(
params_j, state_j, next(rngs), S_next, A_next, False)
if is_stochastic(q_j):
Q_sa_next = q_j.mean_func_type1(
params_j, state_j, next(rngs), S_next, A_next)
Q_sa_next = q_j.proba_dist.postprocess_variate(
next(rngs), Q_sa_next, batch_mode=True)
else:
Q_sa_next, _ = q_j.function_type1(
params_j, state_j, next(rngs), S_next, A_next, False)
assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}"
f_inv = q_j.value_transform.inverse_func
Q_sa_next_list.append(f_inv(Q_sa_next))
A_next_list.append(A_next)

# take the min to mitigate over-estimation
A_next_list = jnp.stack(A_next_list, axis=1)
Q_sa_next_list = jnp.stack(Q_sa_next_list, axis=-1)
assert Q_sa_next_list.ndim == 2, f"bad shape: {Q_sa_next_list.shape}"
Q_sa_next = jnp.min(Q_sa_next_list, axis=-1)

if is_stochastic(self.q):
Q_sa_next_argmin = jnp.argmin(Q_sa_next_list, axis=-1)
Q_sa_next_argmin_q = Q_sa_next_argmin % len(self.q_targ_list)

def target_dist_params(A_next_idx, q_targ_idx, p, s, t, A_next_list):
return self._get_target_dist_params(batch_to_single(p, q_targ_idx),
batch_to_single(s, q_targ_idx),
next(rngs),
single_to_batch(t),
single_to_batch(batch_to_single(A_next_list,
A_next_idx)))

def tile_parameters(params, state, reps):
return jax.tree_util.tree_map(lambda t: jnp.tile(t, [reps, *([1] * (t.ndim - 1))]),
stack_trees(params, state))
# stack and tile q-function params to select the argmin for the target dist params
tiled_target_params, tiled_target_state = tile_parameters(
target_params['q_targ'], target_state['q_targ'], reps=len(self.q_targ_list))

vtarget_dist_params = jax.vmap(target_dist_params, in_axes=(0, 0, None, None, 0, 0))
dist_params = vtarget_dist_params(
Q_sa_next_argmin,
Q_sa_next_argmin_q,
tiled_target_params,
tiled_target_state,
transition_batch,
A_next_list)
# unwrap dist params computed for single batches
return jax.tree_util.tree_map(lambda t: jnp.squeeze(t, axis=1), dist_params)

Q_sa_next = jnp.min(Q_sa_next_list, axis=-1)
assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}"
f = self.q.value_transform.transform_func
return f(transition_batch.Rn + transition_batch.In * Q_sa_next)
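The scalar branch keeps the familiar clipped target f(Rn + In * min over critics). The distributional branch cannot take a minimum over distributions directly, so per transition it picks the target critic with the smallest mean value and returns that critic's distribution parameters, selected via jax.vmap over tiled parameter pytrees. A toy sketch of the per-transition selection idea; shapes and values are illustrative:

import jax
import jax.numpy as jnp

# two critics, batch of 3 transitions, 5 atoms each
logits = jnp.arange(2 * 3 * 5, dtype=jnp.float32).reshape(2, 3, 5)
means = jnp.array([[3., 1., 2.],    # mean value per transition, critic 0
                   [2., 4., 1.]])   # mean value per transition, critic 1

argmin_critic = jnp.argmin(means, axis=0)        # most pessimistic critic: [1, 0, 1]

# pick, for every transition, the atoms of the selected critic
select = jax.vmap(lambda idx, per_critic: per_critic[idx], in_axes=(0, 1))
target_logits = select(argmin_critic, logits)    # shape (3, 5)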
@@ -283,5 +356,5 @@ def _check_input_lists(self, pi_targ_list, q_targ_list):
if not q_targ_list:
raise ValueError("q_targ_list cannot be empty")
for q_targ in q_targ_list:
if not isinstance(q_targ, Q):
if not is_qfunction(q_targ):
raise TypeError(f"all q_targ in q_targ_list must be a coax.Q, got: {type(q_targ)}")
81 changes: 81 additions & 0 deletions coax/td_learning/_clippeddoubleqlearning_test.py
@@ -4,6 +4,7 @@

from .._base.test_case import TestCase
from .._core.q import Q
from .._core.stochastic_q import StochasticQ
from .._core.policy import Policy
from ..utils import get_transition_batch
from ._clippeddoubleqlearning import ClippedDoubleQLearning
@@ -40,6 +41,31 @@ def test_update_discrete_type1(self):
self.assertPytreeNotEqual(function_state1, q1.function_state)
self.assertPytreeNotEqual(function_state2, q2.function_state)

def test_update_discrete_stochastic_type1(self):
env = self.env_discrete
func_q = self.func_q_stochastic_type1
transition_batch = self.transition_discrete

q1 = StochasticQ(func_q, env, value_range=(0, 1))
q2 = StochasticQ(func_q, env, value_range=(0, 1))
q_targ1 = q1.copy()
q_targ2 = q2.copy()
updater1 = ClippedDoubleQLearning(q1, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))
updater2 = ClippedDoubleQLearning(q2, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))

params1 = deepcopy(q1.params)
params2 = deepcopy(q2.params)
function_state1 = deepcopy(q1.function_state)
function_state2 = deepcopy(q2.function_state)

updater1.update(transition_batch)
updater2.update(transition_batch)

self.assertPytreeNotEqual(params1, q1.params)
self.assertPytreeNotEqual(params2, q2.params)
self.assertPytreeNotEqual(function_state1, q1.function_state)
self.assertPytreeNotEqual(function_state2, q2.function_state)

def test_update_discrete_type2(self):
env = self.env_discrete
func_q = self.func_q_type2
@@ -65,6 +91,31 @@ def test_update_discrete_type2(self):
self.assertPytreeNotEqual(function_state1, q1.function_state)
self.assertPytreeNotEqual(function_state2, q2.function_state)

def test_update_discrete_stochastic_type2(self):
env = self.env_discrete
func_q = self.func_q_stochastic_type2
transition_batch = self.transition_discrete

q1 = StochasticQ(func_q, env, value_range=(0, 1))
q2 = StochasticQ(func_q, env, value_range=(0, 1))
q_targ1 = q1.copy()
q_targ2 = q2.copy()
updater1 = ClippedDoubleQLearning(q1, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))
updater2 = ClippedDoubleQLearning(q2, q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))

params1 = deepcopy(q1.params)
params2 = deepcopy(q2.params)
function_state1 = deepcopy(q1.function_state)
function_state2 = deepcopy(q2.function_state)

updater1.update(transition_batch)
updater2.update(transition_batch)

self.assertPytreeNotEqual(params1, q1.params)
self.assertPytreeNotEqual(params2, q2.params)
self.assertPytreeNotEqual(function_state1, q1.function_state)
self.assertPytreeNotEqual(function_state2, q2.function_state)

def test_update_boxspace(self):
env = self.env_boxspace
func_q = self.func_q_type1
Expand Down Expand Up @@ -95,6 +146,36 @@ def test_update_boxspace(self):
self.assertPytreeNotEqual(function_state1, q1.function_state)
self.assertPytreeNotEqual(function_state2, q2.function_state)

def test_update_boxspace_stochastic(self):
env = self.env_boxspace
func_q = self.func_q_stochastic_type1
func_pi = self.func_pi_boxspace
transition_batch = self.transition_boxspace

q1 = StochasticQ(func_q, env, value_range=(0, 1))
q2 = StochasticQ(func_q, env, value_range=(0, 1))
pi1 = Policy(func_pi, env)
pi2 = Policy(func_pi, env)
q_targ1 = q1.copy()
q_targ2 = q2.copy()
updater1 = ClippedDoubleQLearning(
q1, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))
updater2 = ClippedDoubleQLearning(
q2, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))

params1 = deepcopy(q1.params)
params2 = deepcopy(q2.params)
function_state1 = deepcopy(q1.function_state)
function_state2 = deepcopy(q2.function_state)

updater1.update(transition_batch)
updater2.update(transition_batch)

self.assertPytreeNotEqual(params1, q1.params)
self.assertPytreeNotEqual(params2, q2.params)
self.assertPytreeNotEqual(function_state1, q1.function_state)
self.assertPytreeNotEqual(function_state2, q2.function_state)

def test_discrete_with_pi(self):
env = self.env_discrete
func_q = self.func_q_type1
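The new tests mirror the intended usage: build two StochasticQ critics, give each updater the pair of target copies, and call update on sampled batches. A condensed training-step sketch for the continuous-action case; env, func_q, func_pi, and transition_batch (e.g. sampled from a replay buffer) are placeholders:

import coax
import optax

q1 = coax.StochasticQ(func_q, env, value_range=(-10., 10.))
q2 = coax.StochasticQ(func_q, env, value_range=(-10., 10.))
q_targ1, q_targ2 = q1.copy(), q2.copy()
pi = coax.Policy(func_pi, env)
pi_targ = pi.copy()

qlearning1 = coax.td_learning.ClippedDoubleQLearning(
    q1, pi_targ_list=[pi_targ], q_targ_list=[q_targ1, q_targ2], optimizer=optax.adam(1e-3))
qlearning2 = coax.td_learning.ClippedDoubleQLearning(
    q2, pi_targ_list=[pi_targ], q_targ_list=[q_targ1, q_targ2], optimizer=optax.adam(1e-3))

# one gradient step on each critic from the same batch
metrics1 = qlearning1.update(transition_batch)
metrics2 = qlearning2.update(transition_batch)

# periodically sync the target networks (Polyak averaging)
q_targ1.soft_update(q1, tau=0.001)
q_targ2.soft_update(q2, tau=0.001)
pi_targ.soft_update(pi, tau=0.001)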
(diffs for the remaining 15 changed files are not shown here)
