[Performance, WIP] Faster SAC #1958

Draft: wants to merge 1 commit into main
16 changes: 10 additions & 6 deletions torchrl/objectives/common.py
@@ -7,6 +7,7 @@

import abc
import warnings
import torchrl._utils
from copy import deepcopy
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple
@@ -289,11 +290,7 @@ def _compare_and_expand(param):

# set the functional module: we need to convert the params to non-differentiable params
# otherwise they will appear twice in parameters
with params.apply(
self._make_meta_params, device=torch.device("meta"), filter_empty=False
).to_module(module):
# avoid buffers and params being exposed
self.__dict__[module_name] = deepcopy(module)
self._make_meta_module(module_name=module_name, module=module, params=params)

name_params_target = "target_" + module_name
if create_target_params:
@@ -308,14 +305,21 @@ def _compare_and_expand(param):
setattr(self, name_params_target + "_params", target_params)
self._has_update_associated[module_name] = not create_target_params

def _make_meta_module(self, *, module_name, module, params):
with params.apply(
self._make_meta_params, device=torch.device("meta"), filter_empty=False
).to_module(module):
# avoid buffers and params being exposed
self.__dict__[module_name] = deepcopy(module)

def __getattr__(self, item):
if item.startswith("target_") and item.endswith("_params"):
params = self._modules.get(item, None)
if params is None:
# no target param, take detached data
params = getattr(self, item[7:])
params = params.data
elif not self._has_update_associated[item[7:-7]] and RL_WARNINGS:
elif not self._has_update_associated[item[7:-7]] and torchrl._utils.RL_WARNINGS:
# no updater associated
warnings.warn(
self.TARGET_NET_WARNING,
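Note on the common.py change above: the meta-device deepcopy is factored out into a `_make_meta_module` helper, and the warning check in `__getattr__` now reads `torchrl._utils.RL_WARNINGS` through the module rather than through a name imported at load time. The distinction matters because the sac.py diff below toggles the flag at runtime around binding the target parameters, and only a module-attribute lookup observes such a toggle. A minimal sketch of that Python behaviour, using a stand-in module rather than TorchRL itself:

import types

# Stand-in for torchrl._utils; the names here are illustrative only.
_utils = types.ModuleType("_utils")
_utils.RL_WARNINGS = True

# `from torchrl._utils import RL_WARNINGS` behaves like taking this snapshot:
rl_warnings_snapshot = _utils.RL_WARNINGS

_utils.RL_WARNINGS = False      # runtime toggle, as done in sac.py below

print(rl_warnings_snapshot)     # True  -- the from-import copy is stale
print(_utils.RL_WARNINGS)       # False -- the attribute lookup sees the toggle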
90 changes: 52 additions & 38 deletions torchrl/objectives/sac.py
@@ -4,16 +4,18 @@
# LICENSE file in the root directory of this source tree.
import math
import warnings
from copy import deepcopy
from dataclasses import dataclass
from functools import wraps
from numbers import Number
from typing import Dict, Optional, Tuple, Union

import torchrl._utils
import numpy as np
import torch
from tensordict import TensorDict, TensorDictBase

from tensordict.nn import dispatch, TensorDictModule
from tensordict.nn.utils import _set_dispatch_td_nn_modules
from tensordict.utils import NestedKey
from torch import Tensor
from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
@@ -374,13 +376,33 @@ def __init__(
)
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self._vmap_qnetworkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
from functorch import dim

d0 = dim.dims(1)
self.d0 = d0
self.__dict__["_qvalue_net_vmapped_detached"] = deepcopy(self.qvalue_network)
self._cached_detached_qvalue_params[d0].to_module(
self._qvalue_net_vmapped_detached
)
if self._version == 1:
self._vmap_qnetwork00 = _vmap_func(
qvalue_network, randomness=self.vmap_randomness
)
self.__dict__["_qvalue_net_vmapped"] = deepcopy(self.qvalue_network)
self.qvalue_network_params[d0].to_module(self._qvalue_net_vmapped)
self.__dict__["_qvalue_net_vmapped_target"] = deepcopy(self.qvalue_network)
torchrl._utils.RL_WARNINGS = False
self.target_qvalue_network_params[d0].to_module(self._qvalue_net_vmapped_target)
torchrl._utils.RL_WARNINGS = True

self.actor_network_params.to_module(self.actor_network)

# self._vmap_qnetworkN0 = _vmap_func(
# self.qvalue_network, (None, 0), randomness=self.vmap_randomness
# )
# if self._version == 1:
# self.__dict__["_qvalue_net_vmapped00"] = deepcopy(self.qvalue_network)
# self.qvalue_network_params[d0].to_module(self._qvalue_net_vmapped00)

# self._vmap_qnetwork00 = _vmap_func(
# qvalue_network, randomness=self.vmap_randomness
# )

@property
def target_entropy_buffer(self):
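
Note on the __init__ change above: instead of wrapping the Q-network with `_vmap_func` and re-dispatching vmap on every loss call, the diff deep-copies the Q-network once and binds the stacked (detached, live, and target) parameter tensordicts to those copies, indexing the stack dimension with a functorch first-class dim (`d0 = dim.dims(1)`). Later loss methods then simply call the pre-bound copies and let the ensemble dimension broadcast. A rough sketch of the broadcasting idea with plain tensors (not the PR's TensorDict mechanics; names and shapes are illustrative):

import torch
from functorch.dim import dims

# Toy stand-in for two stacked linear Q-heads.
num_nets, feat, batch_size = 2, 8, 32
stacked_w = torch.randn(num_nets, feat)   # per-head weights stacked on dim 0
stacked_b = torch.randn(num_nets)
obs = torch.randn(batch_size, feat)

ensemble, batch, feature = dims(3)        # first-class dims (cf. `d0 = dim.dims(1)` above)
w = stacked_w[ensemble, feature]          # carries (ensemble, feature) implicitly
b = stacked_b[ensemble]
x = obs[batch, feature]                   # carries (batch, feature) implicitly

q = (x * w).sum(feature) + b              # broadcasts over ensemble and batch, no vmap call
q = q.order(ensemble, batch)              # back to a plain (num_nets, batch_size) tensor
min_q = q.min(0)[0]                       # pessimistic Q over the ensemble, as in the losses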
@@ -546,7 +568,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
else:
loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape)
loss_value = None
loss_actor, metadata_actor = self._actor_loss(tensordict_reshape)
loss_actor, metadata_actor = self.actor_loss(tensordict_reshape)
loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"])
tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"])
if (loss_actor.shape != loss_qvalue.shape) or (
@@ -574,24 +596,20 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def _cached_detached_qvalue_params(self):
return self.qvalue_network_params.detach()

def _actor_loss(
@_set_dispatch_td_nn_modules(False)
def actor_loss(
self, tensordict: TensorDictBase
) -> Tuple[Tensor, Dict[str, Tensor]]:
with set_exploration_type(
ExplorationType.RANDOM
), self.actor_network_params.to_module(self.actor_network):
with set_exploration_type(ExplorationType.RANDOM):
dist = self.actor_network.get_dist(tensordict)
a_reparm = dist.rsample()
log_prob = dist.log_prob(a_reparm)

td_q = tensordict.select(*self.qvalue_network.in_keys)
td_q.set(self.tensor_keys.action, a_reparm)
td_q = self._vmap_qnetworkN0(
td_q,
self._cached_detached_qvalue_params, # should we clone?
)
td_q = self._qvalue_net_vmapped_detached(td_q)
min_q_logprob = (
td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
td_q.get(self.tensor_keys.state_action_value)._tensor.min(0)[0].squeeze(-1)
)

if log_prob.shape != min_q_logprob.shape:
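
Note on actor_loss above: the vmapped call over `_cached_detached_qvalue_params` is replaced by `_qvalue_net_vmapped_detached`, the deep copy created in `__init__` with the detached, dim-indexed parameters already bound, and the ensemble minimum is then taken on the underlying tensor. The objective itself is unchanged; for reference (standard SAC actor term, not taken from the diff):

\mathcal{L}_{\pi} = \mathbb{E}_{s \sim \mathcal{D},\; a \sim \pi_\phi(\cdot \mid s)}\big[\alpha \log \pi_\phi(a \mid s) - \min_{i} Q_{\theta_i}(s, a)\big]

with the action drawn via rsample() so the gradient flows through the reparameterized sample.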
@@ -615,6 +633,7 @@ def _cached_target_params_actor_value(self):
_run_checks=False,
)

@_set_dispatch_td_nn_modules(False)
def _qvalue_v1_loss(
self, tensordict: TensorDictBase
) -> Tuple[Tensor, Dict[str, Tensor]]:
@@ -640,10 +659,8 @@ def _qvalue_v1_loss(
)

# if vmap=True, it is assumed that the input tensordict must be cast to the param shape
tensordict_chunks = self._vmap_qnetwork00(
tensordict_chunks, self.qvalue_network_params
)
pred_val = tensordict_chunks.get(self.tensor_keys.state_action_value)
tensordict_chunks = self._qvalue_net_vmapped(tensordict_chunks[self.d0])
pred_val = tensordict_chunks.get(self.tensor_keys.state_action_value)._tensor
pred_val = pred_val.squeeze(-1)
loss_value = distance_loss(
pred_val, target_chunks, loss_function=self.loss_function
@@ -652,6 +669,7 @@

return loss_value, metadata

@_set_dispatch_td_nn_modules(False)
def _compute_target_v2(self, tensordict) -> Tensor:
r"""Value network for SAC v2.

@@ -667,22 +685,18 @@
tensordict = tensordict.clone(False)
# get actions and log-probs
with torch.no_grad():
with set_exploration_type(
ExplorationType.RANDOM
), self.actor_network_params.to_module(self.actor_network):
with set_exploration_type(ExplorationType.RANDOM):
next_tensordict = tensordict.get("next").clone(False)
next_dist = self.actor_network.get_dist(next_tensordict)
next_action = next_dist.rsample()
next_tensordict.set(self.tensor_keys.action, next_action)
next_sample_log_prob = next_dist.log_prob(next_action)

# get q-values
next_tensordict_expand = self._vmap_qnetworkN0(
next_tensordict, self.target_qvalue_network_params
)
next_tensordict_expand = self._qvalue_net_vmapped_target(next_tensordict)
state_action_value = next_tensordict_expand.get(
self.tensor_keys.state_action_value
)
)._tensor
if (
state_action_value.shape[-len(next_sample_log_prob.shape) :]
!= next_sample_log_prob.shape
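
Note on _compute_target_v2 above: the target Q-values now come from `_qvalue_net_vmapped_target`, the pre-bound copy carrying the target parameters, rather than from a per-call vmap over `target_qvalue_network_params`. The quantity fed to the value estimator is the usual SAC (v2) soft target (standard formulation, stated here for reference):

y = r + \gamma\,(1 - d)\,\Big(\min_{i} Q_{\bar{\theta}_i}(s', a') - \alpha \log \pi_\phi(a' \mid s')\Big), \qquad a' \sim \pi_\phi(\cdot \mid s')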
@@ -696,19 +710,19 @@
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
return target_value

@_set_dispatch_td_nn_modules(False)
def _qvalue_v2_loss(
self, tensordict: TensorDictBase
) -> Tuple[Tensor, Dict[str, Tensor]]:
# we pass the alpha value to the tensordict. Since it's a scalar, we must erase the batch-size first.
target_value = self._compute_target_v2(tensordict)

tensordict_expand = self._vmap_qnetworkN0(
tensordict.select(*self.qvalue_network.in_keys),
self.qvalue_network_params,
)
pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze(
-1
tensordict_expand = self._qvalue_net_vmapped(
tensordict.select(*self.qvalue_network.in_keys)
)
pred_val = tensordict_expand.get(
self.tensor_keys.state_action_value
)._tensor.squeeze(-1)
td_error = abs(pred_val - target_value)
loss_qval = distance_loss(
pred_val,
@@ -718,6 +732,7 @@ def _qvalue_v2_loss(
metadata = {"td_error": td_error.detach().max(0)[0]}
return loss_qval, metadata

@_set_dispatch_td_nn_modules(False)
def _value_loss(
self, tensordict: TensorDictBase
) -> Tuple[Tensor, Dict[str, Tensor]]:
@@ -732,13 +747,12 @@

td_copy.set(self.tensor_keys.action, action, inplace=False)

td_copy = self._vmap_qnetworkN0(
td_copy,
self.target_qvalue_network_params,
td_copy = self._qvalue_net_vmapped_target(
td_copy
)

min_qval = (
td_copy.get(self.tensor_keys.state_action_value).squeeze(-1).min(0)[0]
td_copy.get(self.tensor_keys.state_action_value)._tensor.squeeze(-1).min(0)[0]
)

log_p = action_dist.log_prob(action)