Releases: takuseno/d3rlpy
Release v2.7.0
Breaking changes
Dependency
- Python 3.9 or later
- PyTorch v2.5.0 or later
OptimizerFactory
The import path of OptimizerFactory has been changed from d3rlpy.models.OptimizerFactory to d3rlpy.optimizers.OptimizerFactory.
# before
optim = d3rlpy.models.AdamFactory()
# after
optim = d3rlpy.optimizers.AdamFactory()
2-3x speed-up with CUDA Graphs and torch.compile
In this release, d3rlpy supports CUDA Graphs and torch.compile to dramatically speed up training. You can turn on this feature by setting the compile_graph option:
import d3rlpy
# enable CudaGraph and torch.compile
sac = d3rlpy.algos.SACConfig(compile_graph=True).create(device="cuda:0")
Here are some benchmark results on an NVIDIA RTX 4070:
Algorithm | v2.6.2 | v2.7.0
---|---|---
Soft Actor-Critic | 7.4 msec | 3.0 msec
Conservative Q-Learning | 12.5 msec | 3.8 msec
Decision Transformer | 8.9 msec | 3.4 msec
Note that this feature can only be enabled when you use a CUDA device.
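The speed-ups implied by the benchmark table can be checked with a few lines of arithmetic. The timings below are copied from the table; the speedup helper is just for illustration.

```python
# Timings (msec) copied from the v2.6.2 vs v2.7.0 benchmark table.
timings = {
    "Soft Actor-Critic": (7.4, 3.0),
    "Conservative Q-Learning": (12.5, 3.8),
    "Decision Transformer": (8.9, 3.4),
}


def speedup(before_ms: float, after_ms: float) -> float:
    """Ratio of old time to new time; a value above 1 means the new version is faster."""
    return before_ms / after_ms


for name, (v262, v270) in timings.items():
    print(f"{name}: {speedup(v262, v270):.2f}x")
```

This gives about 2.47x for SAC, 3.29x for CQL and 2.62x for Decision Transformer.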
Enhanced optimizer
Learning rate scheduler
This release adds LRSchedulerFactory, which attaches a learning rate scheduler to an individual optimizer.
import d3rlpy
optim = d3rlpy.optimizers.AdamFactory(
lr_scheduler=d3rlpy.optimizers.CosineAnnealingLRFactory(T_max=1000000)
)
See an example here and docs here.
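For reference, cosine annealing decays the learning rate from its initial value to a floor eta_min over T_max steps. The sketch below computes the closed-form value given in the PyTorch CosineAnnealingLR documentation. We assume, without having checked d3rlpy's source, that CosineAnnealingLRFactory follows the same schedule. The base learning rate of 3e-4 is an illustrative value: in d3rlpy the learning rate itself is set on the algorithm config, not on AdamFactory.

```python
import math


def cosine_annealing_lr(step: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


# Learning rate at a few points of a 1,000,000-step schedule (illustrative base_lr=3e-4).
for step in (0, 250_000, 500_000, 1_000_000):
    print(step, cosine_annealing_lr(step, base_lr=3e-4, t_max=1_000_000))
```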
Gradient clipping
A clip_grad_norm option has been added to clip gradients by their global norm.
import d3rlpy
optim = d3rlpy.optimizers.AdamFactory(clip_grad_norm=0.1)
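Clipping by global norm rescales all gradients by one shared factor so that their combined L2 norm does not exceed the threshold, which preserves the overall gradient direction. The sketch below applies the rule from PyTorch's torch.nn.utils.clip_grad_norm_ to plain Python lists. That rule takes the total norm, computes a coefficient of max_norm / (total_norm + 1e-6), and caps that coefficient at 1. We assume, but have not verified, that the clip_grad_norm option relies on this PyTorch utility.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> List[List[float]]:
    """Scale every gradient tensor by the same factor so that the global L2 norm is <= max_norm."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    clip_coef = min(1.0, max_norm / (total_norm + eps))  # never scales gradients up
    return [[g * clip_coef for g in tensor] for tensor in grads]


# Two "parameter tensors" whose global norm is 5.0 (a 3-4-5 triangle).
grads = [[3.0], [4.0]]
print(clip_by_global_norm(grads, max_norm=0.1))
```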
SimBa encoder
This release adds the SimBa architecture, which allows models to be scaled up effectively. See the paper here.
See docs here.
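To give a rough picture of the architecture, here is a pure-Python forward pass of a SimBa-style encoder, based on our reading of the paper. It has three parts:

- a linear input projection,
- pre-LayerNorm residual MLP blocks with a 4x wider hidden layer,
- a final LayerNorm.

The paper's running-statistics observation normalization, learnable LayerNorm parameters, and all training code are omitted. The class and function names are hypothetical, and this is not d3rlpy's SimBa encoder.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def linear(x: Vector, w: Matrix, b: Vector) -> Vector:
    """y = W x + b, with W stored as rows (out_features x in_features)."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize to zero mean and unit variance (no learnable scale/shift for brevity)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def init_linear(n_in: int, n_out: int, rng: random.Random) -> Tuple[Matrix, Vector]:
    bound = 1.0 / math.sqrt(n_in)
    w = [[rng.uniform(-bound, bound) for _ in range(n_in)] for _ in range(n_out)]
    return w, [0.0] * n_out


class ResidualBlock:
    """Pre-LayerNorm residual MLP: x + W2 relu(W1 LN(x)), hidden width 4x."""

    def __init__(self, dim: int, rng: random.Random) -> None:
        self.fc1 = init_linear(dim, 4 * dim, rng)
        self.fc2 = init_linear(4 * dim, dim, rng)

    def __call__(self, x: Vector) -> Vector:
        h = relu(linear(layer_norm(x), *self.fc1))
        h = linear(h, *self.fc2)
        return [a + b for a, b in zip(x, h)]  # residual connection


class SimBaLikeEncoder:
    """Input projection -> residual blocks -> final LayerNorm."""

    def __init__(self, obs_dim: int, hidden_dim: int, n_blocks: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.embed = init_linear(obs_dim, hidden_dim, rng)
        self.blocks = [ResidualBlock(hidden_dim, rng) for _ in range(n_blocks)]

    def __call__(self, obs: Vector) -> Vector:
        x = linear(obs, *self.embed)
        for block in self.blocks:
            x = block(x)
        return layer_norm(x)


encoder = SimBaLikeEncoder(obs_dim=3, hidden_dim=8, n_blocks=2)
print(encoder([0.1, -0.2, 0.3]))
```

Because every block adds its output to its input, more blocks can be stacked without cutting the direct path from input to output. The paper credits this property for the architecture's scalability.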
Enhancement
- Gradients are now being tracked by loggers (thanks, @hasan-yaman)
Development
- Replaced black, isort and pylint with Ruff.
- scripts/format has been removed. scripts/lint now formats code as well.
Release v2.6.2
This is an emergency update to resolve an issue caused by the new Gymnasium version v1.0.0. Additionally, d3rlpy internally checks versions of both Gym and Gymnasium to make sure that dependencies are correct.
Release v2.6.1
Bugfix
There was an issue in the data-parallel distributed training feature of d3rlpy: processes did not correctly synchronize parameters. This release fixes the issue, and data-parallel distributed training now works properly. Please check the latest example script to see how to use it.
Release v2.6.0
New Algorithm
ReBRAC has been added to d3rlpy! Please check a reproduction script here.
Enhancement
- DeepMind Control support has been added. You can install the dependencies with d3rlpy install dm_control. Please check an example script here.
- A use_layer_norm option has been added to VectorEncoderFactory.
Bugfix
- Fix return-to-go calculation for Decision Transformer.
- Fix custom model documentation.
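Return-to-go at step t is the sum of the rewards from t to the end of the episode, optionally discounted. Decision Transformer conditions on the undiscounted version. The backward recursion below is a minimal sketch of the calculation, not d3rlpy's internal code.

```python
from typing import List


def returns_to_go(rewards: List[float], gamma: float = 1.0) -> List[float]:
    """R_t = r_t + gamma * R_{t+1}, computed backwards from the last step of the episode."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg


print(returns_to_go([1.0, 0.0, 2.0]))             # [3.0, 2.0, 2.0] (undiscounted)
print(returns_to_go([1.0, 0.0, 2.0], gamma=0.5))  # [1.5, 1.0, 2.0] (discounted)
```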
Release v2.5.0
New Algorithm
Cal-QL has been added to d3rlpy in v2.5.0! Please check a reproduction script here. To support faithful reproduction, SparseRewardTransitionPicker has also been added and is used in the reproduction script.
Custom Algorithm Example
One of the most frequent questions is "How can I implement a custom algorithm on top of d3rlpy?". A new example script has been added to answer it. Based on this example, you can build your own algorithm while reusing the whole training pipeline provided by d3rlpy. Please check the script here.
Enhancement
- Exporting Decision Transformer models as TorchScript and ONNX has been implemented. You can use this feature via the save_policy method in the same way as with Q-learning algorithms.
- Tuple observation support has been added to PyTorch/ONNX export.
- Return-to-go calculation for Q-learning algorithms has been modified, and the calculation is skipped when return-to-go is not needed.
- An n_updates option has been added to the fit_online method to control the update-to-data (UTD) ratio.
- A write_at_termination option has been added to ReplayBuffer.
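The update-to-data (UTD) ratio is the number of gradient updates performed per environment step. The toy loop below shows how an n_updates-style setting controls it. The function name and loop structure are hypothetical and do not reproduce the internals of fit_online.

```python
def run_online_loop(n_env_steps: int, n_updates: int) -> dict:
    """Toy online loop: after each environment step, perform n_updates gradient updates."""
    env_steps = 0
    grad_updates = 0
    for _ in range(n_env_steps):
        env_steps += 1  # stand-in for env.step(...) plus appending the transition to a buffer
        for _ in range(n_updates):
            grad_updates += 1  # stand-in for one gradient update on a sampled mini-batch
    return {
        "env_steps": env_steps,
        "grad_updates": grad_updates,
        "utd_ratio": grad_updates / env_steps,
    }


print(run_online_loop(n_env_steps=1000, n_updates=4))
```

Raising n_updates gets more learning out of each environment interaction at the cost of more compute per step.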
Bugfix
- Action scaling has been fixed for D4RL datasets.
- Default replay buffer creation in the fit_online method has been fixed.
Release v2.4.0
Tuple observations
In v2.4.0, d3rlpy supports tuple observations.
import numpy as np
import d3rlpy
observations = [np.random.random((1000, 100)), np.random.random((1000, 32))]
actions = np.random.random((1000, 4))
rewards = np.random.random((1000, 1))
terminals = np.random.randint(2, size=(1000, 1))
dataset = d3rlpy.dataset.MDPDataset(
observations=observations,
actions=actions,
rewards=rewards,
terminals=terminals,
)
You can find an example script here.
Enhancements
- logging_steps and logging_strategy options have been added to the fit and fit_online methods (thanks, @claudius-kienle!)
- Logging with WandB is now supported (thanks, @claudius-kienle!)
- Goal-conditioned environments in Minari are now supported.
Bugfix
- Fix errors for distributed training.
- OPE documentation has been fixed.
Release v2.3.0
Distributed data parallel training
Distributed data parallel training with multiple nodes and GPUs has been one of the most requested features. It is finally available, and it is extremely easy to use.
Example:
# train.py
from typing import Dict

import d3rlpy


def main() -> None:
    # GPU version:
    # rank = d3rlpy.distributed.init_process_group("nccl")
    rank = d3rlpy.distributed.init_process_group("gloo")
    print(f"Start running on rank={rank}.")

    # GPU version:
    # device = f"cuda:{rank}"
    device = "cpu:0"

    # setup algorithm
    cql = d3rlpy.algos.CQLConfig(
        actor_learning_rate=1e-3,
        critic_learning_rate=1e-3,
        alpha_learning_rate=1e-3,
    ).create(device=device)

    # prepare dataset
    dataset, env = d3rlpy.datasets.get_pendulum()

    # disable logging on rank != 0 workers
    logger_adapter: d3rlpy.logging.LoggerAdapterFactory
    evaluators: Dict[str, d3rlpy.metrics.EvaluatorProtocol]
    if rank == 0:
        evaluators = {"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}
        logger_adapter = d3rlpy.logging.FileAdapterFactory()
    else:
        evaluators = {}
        logger_adapter = d3rlpy.logging.NoopAdapterFactory()

    # start training
    cql.fit(
        dataset,
        n_steps=10000,
        n_steps_per_epoch=1000,
        evaluators=evaluators,
        logger_adapter=logger_adapter,
        show_progress=rank == 0,
        enable_ddp=True,
    )

    d3rlpy.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
You need to use the torchrun command to start training. It is installed along with PyTorch.
$ torchrun \
--nnodes=1 \
--nproc_per_node=3 \
--rdzv_id=100 \
--rdzv_backend=c10d \
--rdzv_endpoint=localhost:29400 \
train.py
In this case, 3 processes are launched and each runs the training loop. DecisionTransformer-based algorithms also support this distributed training feature.
The example is also available here.
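Conceptually, data-parallel training keeps every replica identical by starting all replicas from the same parameters and averaging gradients across processes (an all-reduce) before each update. PyTorch's DistributedDataParallel does both: it broadcasts the initial parameters from rank 0 and all-reduces gradients during the backward pass. The single-process simulation below illustrates the idea with plain Python in place of torch.distributed. All names are hypothetical, and this is not d3rlpy's implementation.

```python
from typing import List, Tuple

Shard = List[Tuple[float, float]]


def local_gradient(w: float, shard: Shard) -> float:
    """Gradient of the mean squared error for y ~ w * x on one worker's data shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values: List[float]) -> float:
    """Stand-in for an all-reduce: every worker receives the average of all values."""
    return sum(values) / len(values)


def train_data_parallel(shards: List[Shard], lr: float = 0.01, n_steps: int = 200) -> List[float]:
    weights = [0.0 for _ in shards]  # one replica per worker, identical initialization
    for _ in range(n_steps):
        grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
        g = all_reduce_mean(grads)  # synchronize gradients across workers
        weights = [w - lr * g for w in weights]  # identical update on every replica
    return weights


# Three workers, each holding a different shard of data drawn from y = 2x.
shards = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)], [(0.5, 1.0), (1.5, 3.0)]]
print(train_data_parallel(shards))
```

Without the all-reduce step, each replica would follow its own shard's gradient and the copies would drift apart.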
Minari support (thanks, @grahamannett !)
Minari is an open-source library that provides a standard format for offline reinforcement learning datasets. d3rlpy now provides easy access to this library.
You can install Minari via d3rlpy CLI.
$ d3rlpy install minari
Example:
import d3rlpy
dataset, env = d3rlpy.datasets.get_minari("antmaze-umaze-v0")
iql = d3rlpy.algos.IQLConfig(
actor_learning_rate=3e-4,
critic_learning_rate=3e-4,
batch_size=256,
weight_temp=10.0,
max_weight=100.0,
expectile=0.9,
reward_scaler=d3rlpy.preprocessing.ConstantShiftRewardScaler(shift=-1),
).create(device="cpu:0")
iql.fit(
dataset,
n_steps=1000000,
n_steps_per_epoch=100000,
evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
)
Minimize redundant computation
From this version, the computation in some algorithms has been optimized to remove redundant inference. As a result, algorithms with dual optimization such as SAC and CQL are significantly faster than in the previous version.
Enhancements
- GoalConcatWrapper has been added to support goal-conditioned environments.
- return_to_go has been added to Transition and TransitionMiniBatch.
- MixedReplayBuffer has been added to sample experiences from two buffers with an arbitrary ratio.
- initial_temperature now supports 0 in DiscreteSAC.
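Mixing two buffers at a given ratio can be sketched as drawing a fixed share of each mini-batch from a secondary buffer and the rest from a primary one. The sampler below is a simplified, hypothetical version that works on plain lists; sample_mixed and mix_ratio are our names, and d3rlpy's MixedReplayBuffer may differ in its details.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_mixed(primary: Sequence[T], secondary: Sequence[T], batch_size: int,
                 mix_ratio: float, rng: random.Random) -> List[T]:
    """Draw round(batch_size * mix_ratio) items from secondary and the rest from primary."""
    n_secondary = round(batch_size * mix_ratio)
    n_primary = batch_size - n_secondary
    batch = [rng.choice(primary) for _ in range(n_primary)]  # sampled with replacement
    batch += [rng.choice(secondary) for _ in range(n_secondary)]
    return batch


# For example, online transitions as the primary buffer and an offline dataset as the secondary one.
online = [("online", i) for i in range(100)]
offline = [("offline", i) for i in range(1000)]
batch = sample_mixed(online, offline, batch_size=256, mix_ratio=0.5, rng=random.Random(0))
print(sum(1 for source, _ in batch if source == "offline"))  # 128
```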
Bugfix
- Getting started page has been fixed.
Release v2.2.0
Algorithm
DiscreteDecisionTransformer, a Decision Transformer implementation for discrete action spaces, has finally been implemented in v2.2.0! The reproduction results with Atari 2600 are available here.
import d3rlpy
dataset, env = d3rlpy.datasets.get_cartpole()
dt = d3rlpy.algos.DiscreteDecisionTransformerConfig(
batch_size=64,
num_heads=1,
learning_rate=1e-4,
max_timestep=1000,
num_layers=3,
position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE,
encoder_factory=d3rlpy.models.VectorEncoderFactory([128], exclude_last_activation=True),
observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
context_size=20,
warmup_tokens=100000,
).create()
dt.fit(
dataset,
n_steps=100000,
n_steps_per_epoch=1000,
eval_env=env,
eval_target_return=500,
)
Enhancement
- action_size and action_space options have been exposed for manual dataset creation #338
- FrameStackTrajectorySlicer has been added.
Refactoring
- Type checking of numpy is enabled. Some parts of the code distinguish between numpy array data types, and this is now checked by mypy.
Bugfix
Release v2.1.0
Upgrade PyTorch to v2
From this version, d3rlpy requires PyTorch v2 (v1 may still partially work). Accordingly, the minimum Python version has been bumped to 3.8. This change allows d3rlpy to use more advanced features such as torch.compile in upcoming releases.
Healthcheck
From this version, d3rlpy automatically checks the health of its dependencies. In this version, the installed version of Gym is checked to make sure you have the correct version of Gym.
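A dependency health check of this kind comes down to comparing an installed version against a supported range. The sketch below uses the standard library's importlib.metadata for this. The function names, the package name, and the version bounds are illustrative placeholders, not d3rlpy's actual requirements or implementation.

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric components of a version string, e.g. '0.26.2' -> (0, 26, 2)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # stop at suffixes such as 'a1' or 'rc1'
            break
    while len(parts) < 3:  # pad so that '2.5' compares equal to '2.5.0'
        parts.append(0)
    return tuple(parts)


def check_version(package: str, minimum: str, below: Optional[str] = None) -> str:
    """Return a short health message for an installed package."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return f"{package}: not installed"
    version = parse_version(installed)
    if version < parse_version(minimum):
        return f"{package}: {installed} is older than {minimum}"
    if below is not None and version >= parse_version(below):
        return f"{package}: {installed} is not below {below}"
    return f"{package}: {installed} OK"


# Illustrative bounds only.
print(check_version("gym", minimum="0.26.0"))
```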
Gymnasium support
d3rlpy now supports Gymnasium as well as Gym. You can use it in the same way as Gym. Please check the example for further details.
d3rlpy install command
To make your life easier, d3rlpy provides d3rlpy install commands to install additional dependencies. This is part of the d3rlpy CLI. Please check the docs for further details.
$ d3rlpy install atari # Atari 2600 dependencies
$ d3rlpy install d4rl_atari # Atari 2600 + d4rl-atari dependencies
$ d3rlpy install d4rl # D4RL dependencies
Refactoring
In this version, the internal design has been refactored, mainly the algorithm implementations and the way models are assigned.
Enhancement
- Added a Jupyter Notebook for TPU on Google Colaboratory.
- Added d3rlpy.notebook_utils to provide utilities for Jupyter Notebook.
- Updated notebook link #313 (thanks @asmith26!)
Bugfix
- Fixed typos in docstrings #316 (thanks @asmith26!)
- Fixed docker build #311 (thanks @HassamSheikh !)
Release v2.0.4
Bugfix
- Fix DiscreteCQL loss metrics #298
- Fix dump in ReplayBuffer #299
- Fix InitialStateValueEstimationEvaluator #301
- Fix rendering interface to match the latest Gym version #302
Due to the rendering fix, I recommend reinstalling d4rl-atari if you use it.
$ pip install -U git+https://github.com/takuseno/d4rl-atari