JAX-RL

JAX implementations of various deep reinforcement learning algorithms.

Main libraries used:

  • JAX - main framework
  • Haiku - neural networks
  • Optax - gradient-based optimisation

Algorithms implemented

| Algorithm | Paper |
| --- | --- |
| Proximal Policy Optimization (PPO) | https://arxiv.org/abs/1707.06347 |
| Deep Q-Network (DQN) | https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf |
| Double Deep Q-Network (DDQN) | https://arxiv.org/abs/1509.06461 |
| Deep Recurrent Q-Network (DRQN) | https://arxiv.org/abs/1507.06527 |
| Deep Deterministic Policy Gradient (DDPG) | https://arxiv.org/abs/1509.02971 |
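
To make the difference between two of the listed algorithms concrete, here is a minimal sketch of the bootstrap targets used by DQN and Double DQN, following the standard formulation in the linked papers. It is not this repository's code: the function names, argument layout, and the default `gamma` are assumptions for illustration.

```python
import jax.numpy as jnp


def dqn_target(reward, done, q_next_target, gamma=0.99):
    # DQN: the target network both selects and evaluates the next action.
    return reward + gamma * (1.0 - done) * jnp.max(q_next_target, axis=-1)


def ddqn_target(reward, done, q_next_online, q_next_target, gamma=0.99):
    # Double DQN: the online network selects the next action,
    # the target network evaluates it.
    best_action = jnp.argmax(q_next_online, axis=-1)
    q_eval = jnp.take_along_axis(
        q_next_target, best_action[:, None], axis=-1
    )[:, 0]
    return reward + gamma * (1.0 - done) * q_eval


# Made-up batch of two transitions; the second one is terminal.
reward = jnp.array([1.0, 0.0])
done = jnp.array([0.0, 1.0])
q_next_online = jnp.array([[1.0, 2.0], [0.0, 3.0]])
q_next_target = jnp.array([[5.0, 4.0], [2.0, 1.0]])

y_dqn = dqn_target(reward, done, q_next_target)
y_ddqn = ddqn_target(reward, done, q_next_online, q_next_target)
```

In this batch the Double DQN target for the first transition is smaller than the DQN one, because the action preferred by the online network is not the one the target network values most.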

Tabular algorithms

  • Q-learning
  • Double Q-learning
  • SARSA
  • Expected SARSA
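
For reference, the textbook update rules for two of these methods can be sketched with a JAX array as the Q-table. This is an illustrative sketch rather than the repository's implementation; the function names, the `alpha` and `gamma` defaults, and the example table are assumptions.

```python
import jax.numpy as jnp


def q_learning_update(q, s, a, r, s_next, done, alpha=0.1, gamma=0.99):
    # Off-policy: bootstrap from the greedy action in the next state.
    target = r + gamma * (1.0 - done) * jnp.max(q[s_next])
    return q.at[s, a].add(alpha * (target - q[s, a]))


def sarsa_update(q, s, a, r, s_next, a_next, done, alpha=0.1, gamma=0.99):
    # On-policy: bootstrap from the action actually taken in the next state.
    target = r + gamma * (1.0 - done) * q[s_next, a_next]
    return q.at[s, a].add(alpha * (target - q[s, a]))


# Made-up Q-table with 3 states and 2 actions.
q = jnp.array([[0.0, 0.0], [0.5, 1.0], [0.0, 0.0]])

q_ql = q_learning_update(q, s=0, a=1, r=1.0, s_next=1, done=0.0)
q_sarsa = sarsa_update(q, s=0, a=1, r=1.0, s_next=1, a_next=0, done=0.0)
```

JAX arrays are immutable, so each update returns a new table through the `.at[...].add(...)` syntax instead of modifying `q` in place.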

Installation

$ pip install git+https://github.com/hamishs/JAX-RL