diff --git a/docs/conf.py b/docs/conf.py index 33922f0..a26f4a7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -174,6 +174,7 @@ def notebook_modification_function(notebook_content, notebook_filename): dummy_notebook_content["cells"] + notebook_content["cells"] ) + jupyterlite_contents = ["tutorials"] # numpydoc_validation_checks = {"all"} # can be uncommented to get the warnings from numpy. diff --git a/tutorials/Tutorial_Deep_Q_Learning.ipynb b/tutorials/Tutorial_Deep_Q_Learning.ipynb new file mode 100644 index 0000000..c315bab --- /dev/null +++ b/tutorials/Tutorial_Deep_Q_Learning.ipynb @@ -0,0 +1,1059 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "2j_no2BuvPUE" + }, + "source": [ + "# Tutorial - Deep Q-Learning \n", + "\n", + "Deep Q-Learning uses a neural network to approximate $Q$ functions. Hence, we usually refer to this algorithm as DQN (for *deep Q network*) [DQN Article](https://arxiv.org/abs/1312.5602).\n", + "\n", + "The parameters of the neural network are denoted by $\\theta$. \n", + "* As input, the network takes a state $s$,\n", + "* As output, the network returns $Q(s, a, \\theta)$, the value of each action $a$ in state $s$, according to the parameters $\\theta$.\n", + "\n", + "\n", + "The goal of Deep Q-Learning is to learn the parameters $\\theta$ so that $Q(s, a, \\theta)$ closely approximates the optimal $Q$-function $Q^*(s, a)$. \n", + "\n", + "In addition to the network with parameters $\\theta$, the algorithm keeps another network with the same architecture and parameters $\\theta^-$, called the **target network**.\n", + "\n", + "The algorithm works as follows:\n", + "\n", + "1. At each time $t$, the agent is in state $s_t$ and has observed the transitions $(s_i, a_i, r_i, s_i')_{i=1}^{t-1}$, which are stored in a **replay buffer**.\n", + "\n", + "2. Choose action $a_t = \\arg\\max_a Q(s_t, a, \\theta)$ with probability $1-\\varepsilon_t$, and $a_t$ = random action with probability $\\varepsilon_t$. \n", + "\n", + "3. 
Take action $a_t$, observe reward $r_t$ and next state $s_t'$.\n", + "\n", + "4. Add transition $(s_t, a_t, r_t, s_t')$ to the **replay buffer**.\n", + "\n", + "5. Sample a minibatch $\\mathcal{B}$ containing $B$ transitions from the replay buffer. Using this minibatch, we define the loss:\n", + "\n", + "$$\n", + "L(\\theta) = \\sum_{(s_i, a_i, r_i, s_i') \\in \\mathcal{B}}\n", + "\\left[\n", + "Q(s_i, a_i, \\theta) - y_i\n", + "\\right]^2\n", + "$$\n", + "where the $y_i$ are the **targets** computed with the **target network** $\\theta^-$:\n", + "\n", + "$$\n", + "y_i = r_i + \\gamma \\max_{a'} Q(s_i', a', \\theta^-).\n", + "$$\n", + "\n", + "6. Update the parameters $\\theta$ to minimize the loss, e.g., with gradient descent (**keeping $\\theta^-$ fixed**): \n", + "$$\n", + "\\theta \\gets \\theta - \\eta \\nabla_\\theta L(\\theta)\n", + "$$\n", + "where $\\eta$ is the optimization learning rate. \n", + "\n", + "7. Every $N$ transitions ($t\\mod N$ = 0), update target parameters: $\\theta^- \\gets \\theta$.\n", + "\n", + "8. $t \\gets t+1$. Stop if $t = T$, otherwise go to step 2." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HhKHif__t9OD" + }, + "source": [ + "# Notebook setup" + ] + }, + { + "cell_type": "code", + "execution_count": 213, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aylqy_sDqebM", + "outputId": "e1a78b7f-f832-4119-e8c5-3e02264944d9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "Requirement already satisfied: rlberry in /home/hkohler/.local/lib/python3.10/site-packages (0.5.0)\n", + "Requirement already satisfied: dill in /home/hkohler/.local/lib/python3.10/site-packages (from rlberry) (0.3.7)\n", + "Requirement already satisfied: matplotlib in /home/hkohler/.local/lib/python3.10/site-packages (from rlberry) (3.7.2)\n", + "Requirement already satisfied: numpy>=1.17 in /home/hkohler/.local/lib/python3.10/site-packages (from rlberry) (1.25.1)\n", + "Requirement already satisfied: pandas in /home/hkohler/.local/lib/python3.10/site-packages (from rlberry) (2.0.3)\n", + "Requirement already satisfied: docopt in /home/hkohler/.local/lib/python3.10/site-packages (from rlberry) (0.6.2)\n", + "Requirement already satisfied: tqdm in /home/hkohler/.local/lib/python3.10/site-packages (from rlberry) (4.66.1)\n", + "Requirement already satisfied: gymnasium in /home/hkohler/.local/lib/python3.10/site-packages (from rlberry) (0.29.1)\n", + "Requirement already satisfied: seaborn in /home/hkohler/.local/lib/python3.10/site-packages (from rlberry) (0.12.2)\n", + "Requirement already satisfied: pyyaml in /usr/lib/python3/dist-packages (from rlberry) (5.4.1)\n", + "Requirement already satisfied: pygame-ce in /home/hkohler/.local/lib/python3.10/site-packages (from rlberry) (2.3.2)\n", + "Requirement already satisfied: scipy>=1.6 in /home/hkohler/.local/lib/python3.10/site-packages (from rlberry) (1.11.1)\n", + "Requirement already satisfied: 
farama-notifications>=0.0.1 in /home/hkohler/.local/lib/python3.10/site-packages (from gymnasium->rlberry) (0.0.4)\n", + "Requirement already satisfied: cloudpickle>=1.2.0 in /home/hkohler/.local/lib/python3.10/site-packages (from gymnasium->rlberry) (2.2.1)\n", + "Requirement already satisfied: typing-extensions>=4.3.0 in /home/hkohler/.local/lib/python3.10/site-packages (from gymnasium->rlberry) (4.7.1)\n", + "Requirement already satisfied: pillow>=6.2.0 in /usr/lib/python3/dist-packages (from matplotlib->rlberry) (9.0.1)\n", + "Requirement already satisfied: pyparsing<3.1,>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib->rlberry) (2.4.7)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->rlberry) (2.8.2)\n", + "Requirement already satisfied: cycler>=0.10 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->rlberry) (0.11.0)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->rlberry) (1.1.0)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->rlberry) (4.41.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->rlberry) (1.4.4)\n", + "Requirement already satisfied: packaging>=20.0 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->rlberry) (23.1)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/lib/python3/dist-packages (from pandas->rlberry) (2022.1)\n", + "Requirement already satisfied: tzdata>=2022.1 in /home/hkohler/.local/lib/python3.10/site-packages (from pandas->rlberry) (2023.3)\n", + "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.7->matplotlib->rlberry) (1.16.0)\n", + "Defaulting to user installation because normal site-packages is not 
writeable\n", + "Requirement already satisfied: stable-baselines3 in /home/hkohler/.local/lib/python3.10/site-packages (2.1.0)\n", + "Requirement already satisfied: matplotlib in /home/hkohler/.local/lib/python3.10/site-packages (from stable-baselines3) (3.7.2)\n", + "Requirement already satisfied: gymnasium<0.30,>=0.28.1 in /home/hkohler/.local/lib/python3.10/site-packages (from stable-baselines3) (0.29.1)\n", + "Requirement already satisfied: torch>=1.13 in /home/hkohler/.local/lib/python3.10/site-packages (from stable-baselines3) (2.0.1)\n", + "Requirement already satisfied: pandas in /home/hkohler/.local/lib/python3.10/site-packages (from stable-baselines3) (2.0.3)\n", + "Requirement already satisfied: cloudpickle in /home/hkohler/.local/lib/python3.10/site-packages (from stable-baselines3) (2.2.1)\n", + "Requirement already satisfied: numpy>=1.20 in /home/hkohler/.local/lib/python3.10/site-packages (from stable-baselines3) (1.25.1)\n", + "Requirement already satisfied: typing-extensions>=4.3.0 in /home/hkohler/.local/lib/python3.10/site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (4.7.1)\n", + "Requirement already satisfied: farama-notifications>=0.0.1 in /home/hkohler/.local/lib/python3.10/site-packages (from gymnasium<0.30,>=0.28.1->stable-baselines3) (0.0.4)\n", + "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (11.7.91)\n", + "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (10.2.10.91)\n", + "Requirement already satisfied: jinja2 in /usr/lib/python3/dist-packages (from torch>=1.13->stable-baselines3) (3.0.3)\n", + "Requirement already satisfied: networkx in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (3.1)\n", + "Requirement already satisfied: filelock in 
/home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (3.12.3)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (11.7.101)\n", + "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (11.7.4.91)\n", + "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (8.5.0.96)\n", + "Requirement already satisfied: sympy in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (1.12)\n", + "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (2.14.3)\n", + "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (11.4.0.1)\n", + "Requirement already satisfied: triton==2.0.0 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (2.0.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (11.7.99)\n", + "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (11.10.3.66)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (11.7.99)\n", + "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /home/hkohler/.local/lib/python3.10/site-packages (from torch>=1.13->stable-baselines3) (10.9.0.58)\n", + "Requirement already satisfied: wheel in 
/usr/lib/python3/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.13->stable-baselines3) (0.37.1)\n", + "Requirement already satisfied: setuptools in /home/hkohler/.local/lib/python3.10/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.13->stable-baselines3) (68.0.0)\n", + "Requirement already satisfied: cmake in /home/hkohler/.local/lib/python3.10/site-packages (from triton==2.0.0->torch>=1.13->stable-baselines3) (3.27.2)\n", + "Requirement already satisfied: lit in /home/hkohler/.local/lib/python3.10/site-packages (from triton==2.0.0->torch>=1.13->stable-baselines3) (16.0.6)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->stable-baselines3) (1.4.4)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->stable-baselines3) (1.1.0)\n", + "Requirement already satisfied: packaging>=20.0 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->stable-baselines3) (23.1)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->stable-baselines3) (2.8.2)\n", + "Requirement already satisfied: cycler>=0.10 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->stable-baselines3) (0.11.0)\n", + "Requirement already satisfied: pillow>=6.2.0 in /usr/lib/python3/dist-packages (from matplotlib->stable-baselines3) (9.0.1)\n", + "Requirement already satisfied: pyparsing<3.1,>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib->stable-baselines3) (2.4.7)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /home/hkohler/.local/lib/python3.10/site-packages (from matplotlib->stable-baselines3) (4.41.0)\n", + "Requirement already satisfied: tzdata>=2022.1 in /home/hkohler/.local/lib/python3.10/site-packages (from pandas->stable-baselines3) (2023.3)\n", + "Requirement already satisfied: pytz>=2020.1 in 
/usr/lib/python3/dist-packages (from pandas->stable-baselines3) (2022.1)\n", + "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.7->matplotlib->stable-baselines3) (1.16.0)\n", + "Requirement already satisfied: mpmath>=0.19 in /home/hkohler/.local/lib/python3.10/site-packages (from sympy->torch>=1.13->stable-baselines3) (1.3.0)\n" + ] + } + ], + "source": [ + "# install rlberry library\n", + "!pip install rlberry" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 1. Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "VWBRfwosfA9f" + }, + "outputs": [], + "source": [ + "# Imports\n", + "from rlberry.agents import AgentWithSimplePolicy\n", + "from rlberry.manager import (\n", + " AgentManager,\n", + " evaluate_agents,\n", + " plot_writer_data,\n", + " read_writer_data,\n", + ")\n", + "from rlberry.wrappers import WriterWrapper\n", + "from rlberry.envs import gym_make, atari_make\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import numpy as np\n", + "from copy import deepcopy\n", + "import gymnasium as gym\n", + "\n", + "\n", + "rng = np.random.default_rng(seed=42)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6g16Je-dhM2Q" + }, + "source": [ + "# 2. 
Define the replay buffer class" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "Jvh82br9hMNt" + }, + "outputs": [], + "source": [ + "class ReplayBuffer:\n", + " def __init__(self, capacity, env):\n", + " self.capacity = capacity\n", + " self.position = 0\n", + " self.observations = torch.zeros((capacity, *env.observation_space.shape), dtype=torch.float32)\n", + " self.next_observations = torch.zeros((capacity, *env.observation_space.shape), dtype=torch.float32)\n", + " self.actions = torch.zeros((capacity, 1), dtype=torch.int64)\n", + " self.rewards = torch.zeros((capacity,1), dtype=torch.float32)\n", + " self.terminateds = torch.zeros((capacity,1),dtype=torch.bool)\n", + "\n", + " def push(self, state, action, reward, next_state, done):\n", + " \"\"\"Saves a transition.\"\"\"\n", + " self.observations[self.position] = torch.tensor(state)\n", + " self.next_observations[self.position] = torch.tensor(next_state)\n", + " self.actions[self.position] = torch.tensor(action)\n", + " self.rewards[self.position] = torch.tensor(reward)\n", + " self.terminateds[self.position] = torch.tensor(done)\n", + " self.position = (self.position + 1) % self.capacity\n", + "\n", + " def sample(self, batch_size):\n", + " indices = rng.choice(np.arange(self.position), size=batch_size, replace=False)\n", + " return self.observations[indices], self.actions[indices], self.rewards[indices], self.next_observations[indices], self.terminateds[indices]\n", + "\n", + " def __len__(self):\n", + " return self.position\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UCc9WZppi92W" + }, + "source": [ + "# 3. 
Define the neural network class for the $Q$-functions" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "sdNz3Jrwi9iS" + }, + "outputs": [], + "source": [ + "class Net(nn.Module):\n", + " \"\"\"\n", + " Basic neural net.\n", + " \"\"\"\n", + " def __init__(self, obs_size, hidden_size, n_actions):\n", + " super(Net, self).__init__()\n", + " self.net = nn.Sequential(\n", + " nn.Linear(obs_size, hidden_size),\n", + " nn.Tanh(),\n", + " nn.Linear(hidden_size, n_actions)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " return self.net(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xnR8nfoSjZjL" + }, + "source": [ + "# 4. Implement Deep Q-Learning" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a DQN class that is interfacable with rlberry.AgentManager\n", + "class MyDQN(AgentWithSimplePolicy):\n", + " name = \"MyDQN\"\n", + " def __init__(self, env: gym.Env, \n", + " gamma: float, \n", + " batch_size: int, \n", + " buffer_capacity: float, \n", + " update_target_every: int,\n", + " epsilon_start: float,\n", + " decrease_epsilon: int,\n", + " epsilon_min: float,\n", + " replay_buffer: ReplayBuffer,\n", + " qnetworks: torch.nn,\n", + " optimizer: optim,\n", + " loss_fn: torch.nn,\n", + " **kwargs):\n", + " AgentWithSimplePolicy.__init__(self, env, **kwargs) # Interface with rlberry API\n", + " self.env = WriterWrapper(self.env, self.writer, write_scalar=\"action_and_reward\") # \"action\", \"reward\" \n", + " self.replay_buffer = ReplayBuffer(buffer_capacity, self.env)\n", + "\n", + " # create network and target network\n", + " hidden_size = 64\n", + " obs_size = self.env.observation_space.shape[0]\n", + " n_actions = self.env.action_space.n\n", + "\n", + " self.qnet = qnetworks(obs_size, hidden_size, n_actions)\n", + " self.qtarget = qnetworks(obs_size, hidden_size, n_actions)\n", + "\n", + " # objective and optimizer\n", + " 
self.optimizer = optimizer(params=self.qnet.parameters(), lr=1e-3)\n", + " self.loss = loss_fn\n", + "\n", + " \n", + " self.gamma = gamma\n", + " self.batch_size = batch_size\n", + " self.buffer_capacity = buffer_capacity\n", + " self.update_target_every = update_target_every\n", + " self.epsilon_start = epsilon_start\n", + " self.decrease_epsilon = decrease_epsilon\n", + " self.epsilon_min = epsilon_min\n", + " self.replay = replay_buffer\n", + " \n", + "\n", + " def choose_action(self, state, epsilon):\n", + " \"\"\"\n", + " Return action according to an epsilon-greedy exploration policy\n", + " \"\"\"\n", + " if rng.random() < epsilon:\n", + " return self.env.action_space.sample()\n", + " else:\n", + " with torch.no_grad():\n", + " return self.qnet(torch.FloatTensor(state)).argmax().item()\n", + " \n", + "\n", + " def update(self, state, action, reward, next_state, done):\n", + " \"\"\"\n", + " Updates qnetwork weights\n", + " \"\"\"\n", + " \n", + " # add data to replay buffer\n", + " self.replay_buffer.push(state, action, reward, next_state, done)\n", + " \n", + " # if there is not enough samples in the buffer, we dont update the networks\n", + " if len(self.replay_buffer) < self.batch_size:\n", + " return np.inf\n", + " \n", + " # get batch\n", + " states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)\n", + " \n", + "\n", + " # compute targets y = r + gamma max_ a' Q'(s', a')\n", + " with torch.no_grad():\n", + " qtarget_snext = self.qtarget(torch.FloatTensor(next_states))\n", + " qtarget_snext_maxa = torch.max(qtarget_snext, dim=1).values.reshape(self.batch_size, 1)\n", + " targets = rewards + self.gamma * (~dones * qtarget_snext_maxa)\n", + "\n", + " predictions = self.qnet(torch.FloatTensor(states)).gather(dim=1, index=actions)\n", + " loss = self.loss(predictions, targets)\n", + " \n", + " # Optimize the model\n", + " self.optimizer.zero_grad()\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " \n", + " 
return loss.item()\n", + "\n", + " def fit(self, budget=10000, **kwargs):\n", + " state, _ = self.env.reset()\n", + " epsilon = self.epsilon_start\n", + " total_time_steps = 0\n", + " episode_nb = 0\n", + " episode_reward = 0\n", + " while total_time_steps < budget:\n", + " action = self.choose_action(state, epsilon)\n", + "\n", + " # take action and update replay buffer and networks\n", + " next_state, reward, terminated, truncated, _ = self.env.step(action)\n", + " episode_reward += reward\n", + " done = terminated or truncated\n", + " loss = self.update(state, action, reward, next_state, terminated)\n", + "\n", + " # update state\n", + " state = next_state\n", + "\n", + " # end episode if done\n", + " if done:\n", + " episode_nb += 1\n", + " state, _ = self.env.reset()\n", + " self.writer.add_scalar('episode_reward', episode_reward, global_step=total_time_steps)\n", + " episode_reward = 0\n", + " \n", + " # update target network\n", + " if total_time_steps % self.update_target_every == 0:\n", + " self.qtarget.load_state_dict(self.qnet.state_dict())\n", + "\n", + " total_time_steps += 1\n", + " # decrease epsilon\n", + " epsilon = self.epsilon_min + (self.epsilon_start - self.epsilon_min) * \\\n", + " np.exp(-1. * total_time_steps / self.decrease_epsilon)\n", + " self.writer.add_scalar('epsilon', epsilon)\n", + "\n", + " def policy(self, state):\n", + " return self.qnet(torch.FloatTensor(state)).argmax().item() " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5. Instantiate two DQN agents to compare, one with a small batch size and one with default batch size." + ] + }, + { + "cell_type": "code", + "execution_count": 258, + "metadata": {}, + "outputs": [], + "source": [ + "# Environment\n", + "env_ctor = gym_make # environment constructor\n", + "env_kwargs = {\"id\": \"CartPole-v1\"} # parameters passed to the environment constructor. 
Here, the name of the environment\n", + "\n", + "names = [\"DQN\", \"DQN-small-batch-size\"]\n", + "agent_list = [MyDQN, MyDQN] # list of agents on which we run experiments\n", + "\n", + "agent_params = [dict(gamma=0.99, \n", + " batch_size=256, \n", + " buffer_capacity=10000, \n", + " update_target_every=1000, \n", + " epsilon_start=1, \n", + " decrease_epsilon=10_000, \n", + " epsilon_min=0.05,\n", + " qnetworks=Net,\n", + " replay_buffer=ReplayBuffer,\n", + " optimizer=optim.Adam,\n", + " loss_fn=F.mse_loss,\n", + " ), \n", + " dict(gamma=0.99, \n", + " batch_size=8, \n", + " buffer_capacity=10000, \n", + " update_target_every=1000, \n", + " epsilon_start=1, \n", + " decrease_epsilon=10_000, \n", + " epsilon_min=0.05,\n", + " qnetworks=Net,\n", + " replay_buffer=ReplayBuffer,\n", + " optimizer=optim.Adam,\n", + " loss_fn=F.mse_loss,\n", + " )]\n", + "\n", + "\n", + "fit_budget = 50_000 # budget used by an agent in one fit. This is used differently by different agents.\n", + "n_fit = 3 # number of parallel fits of the same agent.\n", + "\n", + "agent_managers = [\n", + " AgentManager( # an AgentManager is used to repeatedly fit and evaluate an agent on an environment\n", + " Agent,\n", + " agent_name=names[e],\n", + " init_kwargs=agent_params[e],\n", + " train_env=(env_ctor, env_kwargs),\n", + " fit_budget=fit_budget,\n", + " n_fit=n_fit,\n", + " eval_kwargs=dict(eval_horizon=500, n_simulations=10)\n", + " )\n", + " for e, Agent in enumerate(agent_list)\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 6. Fit the DQN agents." + ] + }, + { + "cell_type": "code", + "execution_count": 259, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;21m[INFO] 16:00: Running AgentManager fit() for DQN with n_fit = 3 and max_workers = None. 
\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 1302 | reward = 1.0 | action = 1 | epsilon = 0.8840238379253593 | episode_reward = 14.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 1316 | reward = 1.0 | action = 1 | epsilon = 0.8828570215143315 | episode_reward = 36.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 1304 | reward = 1.0 | action = 0 | epsilon = 0.883857049837139 | episode_reward = 13.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 2415 | reward = 1.0 | action = 1 | epsilon = 0.7961763635995357 | episode_reward = 13.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 2428 | reward = 1.0 | action = 1 | epsilon = 0.7952069645727475 | episode_reward = 22.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 2408 | reward = 1.0 | action = 1 | epsilon = 0.7966988699099283 | episode_reward = 57.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 3527 | reward = 1.0 | action = 1 | epsilon = 0.7176485982466707 | episode_reward = 31.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 3545 | reward = 1.0 | action = 0 | epsilon = 0.7164479117118934 | episode_reward = 80.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 3509 | reward = 1.0 | action = 0 | epsilon = 0.7189183364526556 | episode_reward = 18.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 4083 | reward = 1.0 | action = 0 | epsilon = 0.681540444325822 | episode_reward = 25.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 4097 | reward = 1.0 | action = 0 | epsilon = 0.6806569063246778 | episode_reward = 18.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: 
[DQN[worker: 1]] | max_global_step = 4064 | reward = 1.0 | action = 1 | epsilon = 0.682741511822842 | episode_reward = 49.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 4771 | reward = 1.0 | action = 0 | epsilon = 0.6395514447364363 | episode_reward = 20.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 4787 | reward = 1.0 | action = 0 | epsilon = 0.6386089166484011 | episode_reward = 47.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 4749 | reward = 1.0 | action = 1 | epsilon = 0.6408498856761857 | episode_reward = 19.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 5589 | reward = 1.0 | action = 1 | epsilon = 0.5932458525509163 | episode_reward = 116.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 5602 | reward = 1.0 | action = 1 | epsilon = 0.5925400917864916 | episode_reward = 98.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 5574 | reward = 1.0 | action = 0 | epsilon = 0.5940613327870172 | episode_reward = 12.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 6710 | reward = 1.0 | action = 1 | epsilon = 0.5356858348840324 | episode_reward = 111.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 6702 | reward = 1.0 | action = 0 | epsilon = 0.5360259339892508 | episode_reward = 111.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 6685 | reward = 1.0 | action = 1 | epsilon = 0.5368528807826506 | episode_reward = 110.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 7767 | reward = 1.0 | action = 1 | epsilon = 0.48692518741559415 | episode_reward = 188.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 7787 | reward = 1.0 | action = 0 | epsilon = 0.48605221030886203 | 
episode_reward = 131.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 7752 | reward = 1.0 | action = 1 | epsilon = 0.487581066983416 | episode_reward = 21.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 8732 | reward = 1.0 | action = 1 | epsilon = 0.4467323946839852 | episode_reward = 219.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 8752 | reward = 1.0 | action = 0 | epsilon = 0.44593972283069444 | episode_reward = 56.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 8719 | reward = 1.0 | action = 1 | epsilon = 0.44724848218126523 | episode_reward = 199.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 9749 | reward = 1.0 | action = 1 | epsilon = 0.4084044095993085 | episode_reward = 103.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 9769 | reward = 1.0 | action = 0 | epsilon = 0.4076525500679663 | episode_reward = 201.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 9751 | reward = 1.0 | action = 1 | epsilon = 0.40829690440301447 | episode_reward = 137.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 11011 | reward = 1.0 | action = 1 | epsilon = 0.3658798704782359 | episode_reward = 29.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step = 11025 | reward = 1.0 | action = 0 | epsilon = 0.3654379480774276 | episode_reward = 228.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 11009 | reward = 1.0 | action = 1 | epsilon = 0.36594305277035016 | episode_reward = 199.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 2]] | max_global_step = 12053 | reward = 1.0 | action = 0 | epsilon = 0.3346220001282746 | episode_reward = 206.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 0]] | max_global_step 
= 12079 | reward = 1.0 | action = 1 | epsilon = 0.3338829441170905 | episode_reward = 298.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:00: [DQN[worker: 1]] | max_global_step = 12066 | reward = 1.0 | action = 0 | epsilon = 0.33425223192951276 | episode_reward = 428.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 13085 | reward = 1.0 | action = 1 | epsilon = 0.3067138356602251 | episode_reward = 249.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 13112 | reward = 1.0 | action = 1 | epsilon = 0.3060216431842919 | episode_reward = 168.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 13102 | reward = 1.0 | action = 0 | epsilon = 0.3062777928809787 | episode_reward = 376.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 14124 | reward = 1.0 | action = 0 | epsilon = 0.2813801397521377 | episode_reward = 316.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 14159 | reward = 1.0 | action = 1 | epsilon = 0.28059478313978153 | episode_reward = 227.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 14134 | reward = 1.0 | action = 0 | epsilon = 0.28114887526390175 | episode_reward = 396.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 15160 | reward = 1.0 | action = 0 | epsilon = 0.25860906220386176 | episode_reward = 290.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 15200 | reward = 1.0 | action = 1 | epsilon = 0.25779707127278056 | episode_reward = 366.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 15163 | reward = 1.0 | action = 0 | epsilon = 0.25854648887166976 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 16233 | reward = 1.0 | action = 0 | epsilon = 0.23738437426318232 | episode_reward = 
316.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 16251 | reward = 1.0 | action = 0 | epsilon = 0.23704738577013923 | episode_reward = 266.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 16210 | reward = 1.0 | action = 1 | epsilon = 0.23781585433586044 | episode_reward = 450.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 17251 | reward = 1.0 | action = 1 | epsilon = 0.21924747359062885 | episode_reward = 255.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 17274 | reward = 1.0 | action = 0 | epsilon = 0.21885865171792968 | episode_reward = 333.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 17227 | reward = 1.0 | action = 0 | epsilon = 0.2196541553501506 | episode_reward = 365.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 18252 | reward = 1.0 | action = 1 | epsilon = 0.2031261336338343 | episode_reward = 251.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 18280 | reward = 1.0 | action = 0 | epsilon = 0.20269798015425783 | episode_reward = 381.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 18243 | reward = 1.0 | action = 1 | epsilon = 0.2032640091887979 | episode_reward = 336.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 19263 | reward = 1.0 | action = 1 | epsilon = 0.18840192950473478 | episode_reward = 251.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 19303 | reward = 1.0 | action = 0 | epsilon = 0.18784942752733974 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 19262 | reward = 1.0 | action = 0 | epsilon = 0.18841577038971796 | episode_reward = 264.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 
20466 | reward = 1.0 | action = 1 | epsilon = 0.1727146798321741 | episode_reward = 428.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 20510 | reward = 1.0 | action = 0 | epsilon = 0.17217592137870685 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 20481 | reward = 1.0 | action = 0 | epsilon = 0.17253074579743954 | episode_reward = 356.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 21486 | reward = 1.0 | action = 0 | epsilon = 0.16081498231203795 | episode_reward = 243.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 21520 | reward = 1.0 | action = 0 | epsilon = 0.16043885115747936 | episode_reward = 326.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 21486 | reward = 1.0 | action = 0 | epsilon = 0.16081498231203795 | episode_reward = 345.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 22493 | reward = 1.0 | action = 1 | epsilon = 0.15019937835549935 | episode_reward = 225.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 22552 | reward = 1.0 | action = 0 | epsilon = 0.14960994256862714 | episode_reward = 304.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 22504 | reward = 1.0 | action = 1 | epsilon = 0.15008921963771074 | episode_reward = 311.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 23439 | reward = 1.0 | action = 0 | epsilon = 0.1411550574585841 | episode_reward = 221.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 23497 | reward = 1.0 | action = 0 | epsilon = 0.14062788839344298 | episode_reward = 342.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 23451 | reward = 1.0 | action = 1 | epsilon = 0.14104573699503037 | episode_reward = 
297.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 24350 | reward = 1.0 | action = 0 | epsilon = 0.13321785969565475 | episode_reward = 234.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 24408 | reward = 1.0 | action = 1 | epsilon = 0.13273659313160555 | episode_reward = 273.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 24356 | reward = 1.0 | action = 1 | epsilon = 0.13316794395605672 | episode_reward = 303.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 25147 | reward = 1.0 | action = 1 | epsilon = 0.1268428159835826 | episode_reward = 207.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 25209 | reward = 1.0 | action = 1 | epsilon = 0.1263678643958337 | episode_reward = 258.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 25142 | reward = 1.0 | action = 1 | epsilon = 0.12688124699852749 | episode_reward = 283.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 25926 | reward = 1.0 | action = 0 | epsilon = 0.12108397926276343 | episode_reward = 195.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 25984 | reward = 1.0 | action = 1 | epsilon = 0.12067995314928918 | episode_reward = 335.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 25922 | reward = 1.0 | action = 1 | epsilon = 0.12111241854194518 | episode_reward = 262.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 26687 | reward = 1.0 | action = 1 | epsilon = 0.1158751961584574 | episode_reward = 198.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 26745 | reward = 1.0 | action = 1 | epsilon = 0.11549422590246672 | episode_reward = 282.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 
26673 | reward = 1.0 | action = 0 | epsilon = 0.11596748602090894 | episode_reward = 262.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 27495 | reward = 1.0 | action = 1 | epsilon = 0.1107618414731443 | episode_reward = 207.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 27557 | reward = 1.0 | action = 0 | epsilon = 0.11038628348879892 | episode_reward = 254.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 27480 | reward = 1.0 | action = 1 | epsilon = 0.11085305262661702 | episode_reward = 270.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 28358 | reward = 1.0 | action = 0 | epsilon = 0.10573799130821762 | episode_reward = 306.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 28434 | reward = 1.0 | action = 1 | epsilon = 0.10531598821726039 | episode_reward = 258.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 28342 | reward = 1.0 | action = 0 | epsilon = 0.10582724347700533 | episode_reward = 264.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 29171 | reward = 1.0 | action = 1 | epsilon = 0.101385805903869 | episode_reward = 238.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 29250 | reward = 1.0 | action = 1 | epsilon = 0.10098145731709363 | episode_reward = 268.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 29169 | reward = 1.0 | action = 1 | epsilon = 0.1013960840928344 | episode_reward = 243.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 29985 | reward = 1.0 | action = 0 | epsilon = 0.09736871475843922 | episode_reward = 219.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 30274 | reward = 1.0 | action = 1 | epsilon = 0.09601935112199322 | episode_reward = 
274.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 29979 | reward = 1.0 | action = 1 | epsilon = 0.09739714451536846 | episode_reward = 311.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 2]] | max_global_step = 30980 | reward = 1.0 | action = 0 | epsilon = 0.0928824214090031 | episode_reward = 234.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 0]] | max_global_step = 31043 | reward = 1.0 | action = 0 | epsilon = 0.09261311137148691 | episode_reward = 246.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:01: [DQN[worker: 1]] | max_global_step = 30967 | reward = 1.0 | action = 0 | epsilon = 0.09293820480818811 | episode_reward = 339.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 31753 | reward = 1.0 | action = 0 | epsilon = 0.089692489347529 | episode_reward = 223.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 31816 | reward = 1.0 | action = 0 | epsilon = 0.08944321271052806 | episode_reward = 250.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 31748 | reward = 1.0 | action = 1 | epsilon = 0.08971234055459097 | episode_reward = 299.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 32541 | reward = 1.0 | action = 1 | epsilon = 0.08668478108029745 | episode_reward = 252.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 32610 | reward = 1.0 | action = 1 | epsilon = 0.0864325273669702 | episode_reward = 243.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 32532 | reward = 1.0 | action = 1 | epsilon = 0.08671781224506428 | episode_reward = 378.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 33335 | reward = 1.0 | action = 0 | epsilon = 0.08388464576828217 | episode_reward = 216.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 
33411 | reward = 1.0 | action = 1 | epsilon = 0.0836280985746251 | episode_reward = 219.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 33324 | reward = 1.0 | action = 0 | epsilon = 0.08392533174991076 | episode_reward = 317.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 34110 | reward = 1.0 | action = 0 | epsilon = 0.08135776690607546 | episode_reward = 236.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 34174 | reward = 1.0 | action = 0 | epsilon = 0.08115771803709033 | episode_reward = 240.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 34085 | reward = 1.0 | action = 0 | epsilon = 0.08143625939807414 | episode_reward = 310.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 34936 | reward = 1.0 | action = 1 | epsilon = 0.07887170311808674 | episode_reward = 79.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 35011 | reward = 1.0 | action = 1 | epsilon = 0.07865597533511028 | episode_reward = 231.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 34915 | reward = 1.0 | action = 1 | epsilon = 0.07893239740132697 | episode_reward = 320.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 35759 | reward = 1.0 | action = 1 | epsilon = 0.07659071207217588 | episode_reward = 176.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 35836 | reward = 1.0 | action = 1 | epsilon = 0.07638674985151202 | episode_reward = 227.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 35732 | reward = 1.0 | action = 1 | epsilon = 0.07666527039892398 | episode_reward = 320.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 36606 | reward = 1.0 | action = 1 | epsilon = 0.07443122395716767 | episode_reward = 
153.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 36686 | reward = 1.0 | action = 1 | epsilon = 0.07423655388404213 | episode_reward = 249.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 36575 | reward = 1.0 | action = 0 | epsilon = 0.07450707826486518 | episode_reward = 406.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 37286 | reward = 1.0 | action = 0 | epsilon = 0.07282512686429443 | episode_reward = 223.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 37376 | reward = 1.0 | action = 0 | epsilon = 0.07262062237312948 | episode_reward = 209.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 37266 | reward = 1.0 | action = 0 | epsilon = 0.07287082279872549 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 38064 | reward = 1.0 | action = 1 | epsilon = 0.07111665328916857 | episode_reward = 226.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 38157 | reward = 1.0 | action = 0 | epsilon = 0.07092117877893239 | episode_reward = 252.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 38045 | reward = 1.0 | action = 0 | epsilon = 0.0711568130701285 | episode_reward = 367.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 38804 | reward = 1.0 | action = 0 | epsilon = 0.06961043817831136 | episode_reward = 402.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 38897 | reward = 1.0 | action = 1 | epsilon = 0.06942890653378754 | episode_reward = 265.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 38784 | reward = 1.0 | action = 0 | epsilon = 0.06964969830170467 | episode_reward = 305.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step 
= 39490 | reward = 1.0 | action = 1 | epsilon = 0.06831026780233697 | episode_reward = 189.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 39601 | reward = 1.0 | action = 0 | epsilon = 0.06810814767172013 | episode_reward = 268.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 39482 | reward = 1.0 | action = 0 | epsilon = 0.06832492187742732 | episode_reward = 373.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 40409 | reward = 1.0 | action = 1 | epsilon = 0.06670255972452366 | episode_reward = 182.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 40518 | reward = 1.0 | action = 0 | epsilon = 0.06652149044383929 | episode_reward = 264.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 40406 | reward = 1.0 | action = 1 | epsilon = 0.06670757124413135 | episode_reward = 417.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 41107 | reward = 1.0 | action = 1 | epsilon = 0.06557647844793664 | episode_reward = 211.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 41218 | reward = 1.0 | action = 1 | epsilon = 0.06540607611605473 | episode_reward = 333.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 41112 | reward = 1.0 | action = 0 | epsilon = 0.065568692155448 | episode_reward = 365.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 41859 | reward = 1.0 | action = 1 | epsilon = 0.06444808651372655 | episode_reward = 147.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 41965 | reward = 1.0 | action = 1 | epsilon = 0.06429574562978177 | episode_reward = 320.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 41864 | reward = 1.0 | action = 0 | epsilon = 0.06444086427617954 | episode_reward = 
497.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 42578 | reward = 1.0 | action = 1 | epsilon = 0.06344573539268861 | episode_reward = 153.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 42681 | reward = 1.0 | action = 0 | epsilon = 0.06330795510471739 | episode_reward = 343.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 42583 | reward = 1.0 | action = 1 | epsilon = 0.0634390142054291 | episode_reward = 355.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 43306 | reward = 1.0 | action = 1 | epsilon = 0.06250166686470679 | episode_reward = 141.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 43408 | reward = 1.0 | action = 0 | epsilon = 0.0623747979938792 | episode_reward = 343.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 43306 | reward = 1.0 | action = 0 | epsilon = 0.06250166686470679 | episode_reward = 359.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 44024 | reward = 1.0 | action = 1 | epsilon = 0.061635514136881474 | episode_reward = 142.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 44122 | reward = 1.0 | action = 0 | epsilon = 0.0615220430149831 | episode_reward = 351.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 44033 | reward = 1.0 | action = 0 | epsilon = 0.06162504688512811 | episode_reward = 344.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 44725 | reward = 1.0 | action = 0 | epsilon = 0.06084779663875191 | episode_reward = 120.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 44829 | reward = 1.0 | action = 1 | epsilon = 0.060735564174111814 | episode_reward = 382.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step 
= 44741 | reward = 1.0 | action = 0 | epsilon = 0.06083045404190714 | episode_reward = 289.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 45468 | reward = 1.0 | action = 0 | epsilon = 0.06007101991973842 | episode_reward = 143.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 45569 | reward = 1.0 | action = 0 | epsilon = 0.05996981456591426 | episode_reward = 332.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 45485 | reward = 1.0 | action = 1 | epsilon = 0.06005391373025567 | episode_reward = 294.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 2]] | max_global_step = 46237 | reward = 1.0 | action = 1 | epsilon = 0.05932558767129545 | episode_reward = 314.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 0]] | max_global_step = 46332 | reward = 1.0 | action = 0 | epsilon = 0.059237414076133074 | episode_reward = 354.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:02: [DQN[worker: 1]] | max_global_step = 46251 | reward = 1.0 | action = 1 | epsilon = 0.05931254098336815 | episode_reward = 288.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 2]] | max_global_step = 47003 | reward = 1.0 | action = 0 | epsilon = 0.05863792148142126 | episode_reward = 169.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 0]] | max_global_step = 47091 | reward = 1.0 | action = 0 | epsilon = 0.0585622412537755 | episode_reward = 344.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 1]] | max_global_step = 47011 | reward = 1.0 | action = 0 | epsilon = 0.05863101390763404 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 2]] | max_global_step = 47784 | reward = 1.0 | action = 0 | epsilon = 0.05798897115252949 | episode_reward = 174.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 0]] | max_global_step = 47872 | reward = 1.0 | action = 1 | epsilon = 0.05791897663396629 | episode_reward 
= 359.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 1]] | max_global_step = 47802 | reward = 1.0 | action = 1 | epsilon = 0.057974603938826426 | episode_reward = 313.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 2]] | max_global_step = 48561 | reward = 1.0 | action = 1 | epsilon = 0.05739173130756995 | episode_reward = 183.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 0]] | max_global_step = 48643 | reward = 1.0 | action = 1 | epsilon = 0.05733136694298395 | episode_reward = 371.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 1]] | max_global_step = 48582 | reward = 1.0 | action = 0 | epsilon = 0.057376224959188434 | episode_reward = 382.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 2]] | max_global_step = 49300 | reward = 1.0 | action = 1 | epsilon = 0.05686517811730764 | episode_reward = 165.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 0]] | max_global_step = 49386 | reward = 1.0 | action = 0 | epsilon = 0.056806390733574644 | episode_reward = 358.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 1]] | max_global_step = 49322 | reward = 1.0 | action = 1 | epsilon = 0.0568500913270039 | episode_reward = 371.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 1]] | max_global_step = 49927 | reward = 1.0 | action = 0 | epsilon = 0.05644794828331579 | episode_reward = 356.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN[worker: 2]] | max_global_step = 49924 | reward = 1.0 | action = 0 | epsilon = 0.05644988295798748 | episode_reward = 184.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: ... trained! \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: Running AgentManager fit() for DQN-small-batch-size with n_fit = 3 and max_workers = None. 
\u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 2]] | max_global_step = 896 | reward = 1.0 | action = 0 | epsilon = 0.918581989326102 | episode_reward = 9.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 900 | reward = 1.0 | action = 1 | epsilon = 0.9183214538115854 | episode_reward = 25.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 923 | reward = 1.0 | action = 1 | epsilon = 0.9162399810888118 | episode_reward = 10.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 2]] | max_global_step = 1799 | reward = 1.0 | action = 0 | epsilon = 0.8435860554784582 | episode_reward = 25.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 1803 | reward = 1.0 | action = 1 | epsilon = 0.8432686845346873 | episode_reward = 13.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 1822 | reward = 1.0 | action = 1 | epsilon = 0.8417629049776392 | episode_reward = 32.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 2]] | max_global_step = 2626 | reward = 1.0 | action = 1 | epsilon = 0.7805969827987554 | episode_reward = 18.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 2624 | reward = 1.0 | action = 1 | epsilon = 0.780743116808229 | episode_reward = 58.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 2658 | reward = 1.0 | action = 1 | epsilon = 0.7782628091235076 | episode_reward = 13.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 2]] | max_global_step = 3465 | reward = 1.0 | action = 1 | epsilon = 0.7218008783229077 | episode_reward = 15.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 3453 | reward = 1.0 | 
action = 1 | epsilon = 0.7226075232670643 | episode_reward = 14.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 3498 | reward = 1.0 | action = 0 | epsilon = 0.7195875893597907 | episode_reward = 22.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 2]] | max_global_step = 4350 | reward = 1.0 | action = 1 | epsilon = 0.6649014337241329 | episode_reward = 58.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 4348 | reward = 1.0 | action = 1 | epsilon = 0.6650244263097262 | episode_reward = 35.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 4387 | reward = 1.0 | action = 1 | epsilon = 0.6626917683468433 | episode_reward = 14.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 2]] | max_global_step = 5169 | reward = 1.0 | action = 0 | epsilon = 0.6166047578724597 | episode_reward = 48.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 5159 | reward = 1.0 | action = 1 | epsilon = 0.6171149316983298 | episode_reward = 17.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 5209 | reward = 1.0 | action = 1 | epsilon = 0.6142864341763439 | episode_reward = 23.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 2]] | max_global_step = 5965 | reward = 1.0 | action = 1 | epsilon = 0.5731990501059381 | episode_reward = 38.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 5966 | reward = 1.0 | action = 1 | epsilon = 0.5731467328168356 | episode_reward = 34.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 6019 | reward = 1.0 | action = 1 | epsilon = 0.5703813897651974 | episode_reward = 36.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: 
[DQN-small-batch-size[worker: 2]] | max_global_step = 6806 | reward = 1.0 | action = 0 | epsilon = 0.5309974576758437 | episode_reward = 32.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 6801 | reward = 1.0 | action = 1 | epsilon = 0.5312380165393858 | episode_reward = 38.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 6837 | reward = 1.0 | action = 0 | epsilon = 0.52950867436345 | episode_reward = 22.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 2]] | max_global_step = 7535 | reward = 1.0 | action = 1 | epsilon = 0.4971803516951027 | episode_reward = 100.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 7532 | reward = 1.0 | action = 1 | epsilon = 0.49731452592573955 | episode_reward = 151.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 7583 | reward = 1.0 | action = 1 | epsilon = 0.4950390292920709 | episode_reward = 129.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 2]] | max_global_step = 8374 | reward = 1.0 | action = 0 | epsilon = 0.46119270967516035 | episode_reward = 139.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 8348 | reward = 1.0 | action = 1 | epsilon = 0.4622632017569783 | episode_reward = 161.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 8407 | reward = 1.0 | action = 0 | epsilon = 0.45983801021672827 | episode_reward = 89.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 2]] | max_global_step = 9210 | reward = 1.0 | action = 1 | epsilon = 0.42821468517471956 | episode_reward = 163.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 0]] | max_global_step = 9180 | reward = 1.0 | action = 1 | epsilon = 
0.4293510328995704 | episode_reward = 79.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:03: [DQN-small-batch-size[worker: 1]] | max_global_step = 9239 | reward = 1.0 | action = 0 | epsilon = 0.42711945144419833 | episode_reward = 35.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 10044 | reward = 1.0 | action = 1 | epsilon = 0.39798590796273076 | episode_reward = 172.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 9993 | reward = 1.0 | action = 0 | epsilon = 0.3997301945851715 | episode_reward = 152.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 10084 | reward = 1.0 | action = 0 | epsilon = 0.3965620865684777 | episode_reward = 39.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 10882 | reward = 1.0 | action = 1 | epsilon = 0.369981116973421 | episode_reward = 210.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 10841 | reward = 1.0 | action = 1 | epsilon = 0.3712957326736405 | episode_reward = 238.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 10936 | reward = 1.0 | action = 0 | epsilon = 0.36825787588019016 | episode_reward = 26.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 11763 | reward = 1.0 | action = 1 | epsilon = 0.3430261882144891 | episode_reward = 129.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 11707 | reward = 1.0 | action = 0 | epsilon = 0.34464227240732215 | episode_reward = 193.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 11798 | reward = 1.0 | action = 1 | epsilon = 0.3419731904700941 | episode_reward = 43.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: 
[DQN-small-batch-size[worker: 2]] | max_global_step = 12534 | reward = 1.0 | action = 0 | epsilon = 0.32125571793971863 | episode_reward = 109.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 12480 | reward = 1.0 | action = 0 | epsilon = 0.32272446085341655 | episode_reward = 285.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 12572 | reward = 1.0 | action = 1 | epsilon = 0.32022690219946215 | episode_reward = 106.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 13208 | reward = 1.0 | action = 1 | epsilon = 0.3035755952255441 | episode_reward = 192.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 13147 | reward = 1.0 | action = 1 | epsilon = 0.305127133737823 | episode_reward = 455.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 13241 | reward = 1.0 | action = 0 | epsilon = 0.302740174962877 | episode_reward = 254.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 13818 | reward = 1.0 | action = 1 | epsilon = 0.28856981302822654 | episode_reward = 218.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 13745 | reward = 1.0 | action = 1 | epsilon = 0.2903177448522231 | episode_reward = 299.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 13844 | reward = 1.0 | action = 0 | epsilon = 0.28795033718192464 | episode_reward = 396.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 14551 | reward = 1.0 | action = 0 | epsilon = 0.2717081737801379 | episode_reward = 243.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 14479 | reward = 1.0 | action = 0 | 
epsilon = 0.2733102331241028 | episode_reward = 194.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 14573 | reward = 1.0 | action = 0 | epsilon = 0.2712209519383604 | episode_reward = 190.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 15274 | reward = 1.0 | action = 1 | epsilon = 0.2562444229474867 | episode_reward = 114.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 15205 | reward = 1.0 | action = 0 | epsilon = 0.2576724304260102 | episode_reward = 242.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 15306 | reward = 1.0 | action = 1 | epsilon = 0.2555854956400313 | episode_reward = 203.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 16083 | reward = 1.0 | action = 0 | epsilon = 0.24021632641939789 | episode_reward = 299.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 16014 | reward = 1.0 | action = 0 | epsilon = 0.24153335760396138 | episode_reward = 367.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 16117 | reward = 1.0 | action = 0 | epsilon = 0.23957068911495333 | episode_reward = 234.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 16888 | reward = 1.0 | action = 0 | epsilon = 0.2255040263125459 | episode_reward = 283.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 16807 | reward = 1.0 | action = 0 | epsilon = 0.22693138191179624 | episode_reward = 272.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 16923 | reward = 1.0 | action = 1 | epsilon = 0.22489083592958725 | episode_reward = 251.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 
16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 17543 | reward = 1.0 | action = 1 | epsilon = 0.21437690374680712 | episode_reward = 241.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 17456 | reward = 1.0 | action = 1 | epsilon = 0.21581322173308037 | episode_reward = 197.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 17579 | reward = 1.0 | action = 0 | epsilon = 0.21378621077860965 | episode_reward = 246.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 18196 | reward = 1.0 | action = 1 | epsilon = 0.2039860454881407 | episode_reward = 320.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 18109 | reward = 1.0 | action = 1 | epsilon = 0.20533156862264618 | episode_reward = 148.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 18232 | reward = 1.0 | action = 0 | epsilon = 0.20343269235763956 | episode_reward = 254.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 19006 | reward = 1.0 | action = 1 | epsilon = 0.1920049597196658 | episode_reward = 216.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 18919 | reward = 1.0 | action = 0 | epsilon = 0.1932457926659995 | episode_reward = 248.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 19048 | reward = 1.0 | action = 0 | epsilon = 0.1914097896209503 | episode_reward = 146.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 19827 | reward = 1.0 | action = 1 | epsilon = 0.18081210552089821 | episode_reward = 235.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 19747 | reward = 1.0 | action = 
0 | epsilon = 0.18186279953743612 | episode_reward = 182.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 19869 | reward = 1.0 | action = 1 | epsilon = 0.18026384682690783 | episode_reward = 187.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 20636 | reward = 1.0 | action = 0 | epsilon = 0.1706461624889969 | episode_reward = 313.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 20568 | reward = 1.0 | action = 0 | epsilon = 0.17146935206646408 | episode_reward = 151.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 20670 | reward = 1.0 | action = 0 | epsilon = 0.17023666208171198 | episode_reward = 174.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 21494 | reward = 1.0 | action = 1 | epsilon = 0.16072636577752833 | episode_reward = 295.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 21428 | reward = 1.0 | action = 1 | epsilon = 0.16145957672623695 | episode_reward = 223.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 21517 | reward = 1.0 | action = 0 | epsilon = 0.16047198778307195 | episode_reward = 214.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 22349 | reward = 1.0 | action = 0 | epsilon = 0.15165268812102434 | episode_reward = 286.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 22289 | reward = 1.0 | action = 0 | epsilon = 0.1522644376631293 | episode_reward = 282.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 22383 | reward = 1.0 | action = 0 | epsilon = 0.15130765586862294 | episode_reward = 325.0 | \u001b[0m\n", + 
"\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 23222 | reward = 1.0 | action = 0 | epsilon = 0.143154740295612 | episode_reward = 348.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 23178 | reward = 1.0 | action = 0 | epsilon = 0.14356552421480373 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 23257 | reward = 1.0 | action = 1 | epsilon = 0.1428292686122755 | episode_reward = 171.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 24093 | reward = 1.0 | action = 0 | epsilon = 0.13538427792391416 | episode_reward = 253.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 24055 | reward = 1.0 | action = 1 | epsilon = 0.13570935543612175 | episode_reward = 206.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 24130 | reward = 1.0 | action = 0 | epsilon = 0.13506893983081608 | episode_reward = 202.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 2]] | max_global_step = 24959 | reward = 1.0 | action = 0 | epsilon = 0.1283011260872075 | episode_reward = 331.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 0]] | max_global_step = 24915 | reward = 1.0 | action = 0 | epsilon = 0.1286464101097828 | episode_reward = 179.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:04: [DQN-small-batch-size[worker: 1]] | max_global_step = 24980 | reward = 1.0 | action = 1 | epsilon = 0.128136866255613 | episode_reward = 169.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 25841 | reward = 1.0 | action = 0 | epsilon = 0.12169793772177095 | episode_reward = 295.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 25798 | reward 
= 1.0 | action = 0 | epsilon = 0.1219997023222696 | episode_reward = 354.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 25854 | reward = 1.0 | action = 1 | epsilon = 0.12159763084016106 | episode_reward = 128.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 26701 | reward = 1.0 | action = 0 | epsilon = 0.11578961404387869 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 26644 | reward = 1.0 | action = 1 | epsilon = 0.11615906939198964 | episode_reward = 239.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 26727 | reward = 1.0 | action = 1 | epsilon = 0.11561222167342619 | episode_reward = 255.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 27564 | reward = 1.0 | action = 1 | epsilon = 0.11034402788154471 | episode_reward = 348.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 27516 | reward = 1.0 | action = 1 | epsilon = 0.11064043923290565 | episode_reward = 214.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 27583 | reward = 1.0 | action = 0 | epsilon = 0.11022948308058958 | episode_reward = 86.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 28414 | reward = 1.0 | action = 1 | epsilon = 0.10542673089946288 | episode_reward = 257.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 28379 | reward = 1.0 | action = 1 | epsilon = 0.10562106434275476 | episode_reward = 256.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 28434 | reward = 1.0 | action = 0 | epsilon = 0.10532152009267126 | episode_reward = 123.0 | \u001b[0m\n", 
+ "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 29261 | reward = 1.0 | action = 1 | epsilon = 0.10092540854652021 | episode_reward = 270.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 29213 | reward = 1.0 | action = 1 | epsilon = 0.10117043810803457 | episode_reward = 352.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 29264 | reward = 1.0 | action = 1 | epsilon = 0.10091013321537048 | episode_reward = 187.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 30062 | reward = 1.0 | action = 1 | epsilon = 0.0970100770757105 | episode_reward = 10.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 30024 | reward = 1.0 | action = 0 | epsilon = 0.09718433654210248 | episode_reward = 172.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 30067 | reward = 1.0 | action = 1 | epsilon = 0.09698187948958684 | episode_reward = 319.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 30939 | reward = 1.0 | action = 1 | epsilon = 0.09305860025662048 | episode_reward = 10.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 30908 | reward = 1.0 | action = 0 | epsilon = 0.09319228902794915 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 30952 | reward = 1.0 | action = 1 | epsilon = 0.09300266044504259 | episode_reward = 320.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 31700 | reward = 1.0 | action = 1 | epsilon = 0.08990341800827328 | episode_reward = 48.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 31667 | 
reward = 1.0 | action = 1 | epsilon = 0.09003531680101046 | episode_reward = 421.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 31707 | reward = 1.0 | action = 1 | epsilon = 0.08987549538972414 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 32434 | reward = 1.0 | action = 1 | epsilon = 0.08708312389528094 | episode_reward = 150.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 32394 | reward = 1.0 | action = 0 | epsilon = 0.08722803046260982 | episode_reward = 173.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 32439 | reward = 1.0 | action = 0 | epsilon = 0.0870608806945713 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 33218 | reward = 1.0 | action = 1 | epsilon = 0.08428342442988423 | episode_reward = 78.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 33180 | reward = 1.0 | action = 1 | epsilon = 0.08441394928287362 | episode_reward = 173.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 33226 | reward = 1.0 | action = 1 | epsilon = 0.08425600865811123 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 34033 | reward = 1.0 | action = 0 | epsilon = 0.08160015370282815 | episode_reward = 162.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 33990 | reward = 1.0 | action = 0 | epsilon = 0.08173632692636074 | episode_reward = 345.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 34049 | reward = 1.0 | action = 0 | epsilon = 0.08154963388353663 | episode_reward = 500.0 | 
\u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 34859 | reward = 1.0 | action = 1 | epsilon = 0.07909487333478453 | episode_reward = 140.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 34807 | reward = 1.0 | action = 0 | epsilon = 0.07924656072152886 | episode_reward = 84.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 34875 | reward = 1.0 | action = 0 | epsilon = 0.0790483587590326 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 35655 | reward = 1.0 | action = 0 | epsilon = 0.07686869850157957 | episode_reward = 244.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 35608 | reward = 1.0 | action = 0 | epsilon = 0.0769979782776326 | episode_reward = 102.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 35668 | reward = 1.0 | action = 1 | epsilon = 0.07683379188774253 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 36491 | reward = 1.0 | action = 0 | epsilon = 0.07471380475801093 | episode_reward = 214.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 36436 | reward = 1.0 | action = 1 | epsilon = 0.07485010516671346 | episode_reward = 146.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 36484 | reward = 1.0 | action = 0 | epsilon = 0.07473111047763674 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 37286 | reward = 1.0 | action = 1 | epsilon = 0.07282512686429443 | episode_reward = 460.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | 
max_global_step = 37218 | reward = 1.0 | action = 1 | epsilon = 0.07298086664209937 | episode_reward = 199.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 37260 | reward = 1.0 | action = 1 | epsilon = 0.0728845494099763 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 38071 | reward = 1.0 | action = 1 | epsilon = 0.07110187680423924 | episode_reward = 378.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 38003 | reward = 1.0 | action = 1 | epsilon = 0.07124585854963318 | episode_reward = 79.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 38039 | reward = 1.0 | action = 0 | epsilon = 0.07116951096695869 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 38819 | reward = 1.0 | action = 1 | epsilon = 0.0695810445717601 | episode_reward = 175.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 38749 | reward = 1.0 | action = 0 | epsilon = 0.06971859274069915 | episode_reward = 193.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 38788 | reward = 1.0 | action = 0 | epsilon = 0.06964183999415027 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 39549 | reward = 1.0 | action = 0 | epsilon = 0.06820255528668023 | episode_reward = 176.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 39484 | reward = 1.0 | action = 1 | epsilon = 0.06832125725952584 | episode_reward = 242.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 39517 | reward = 1.0 | action = 1 | epsilon = 0.06826089676017048 | 
episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 40306 | reward = 1.0 | action = 0 | epsilon = 0.06687548512670552 | episode_reward = 217.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 40258 | reward = 1.0 | action = 1 | epsilon = 0.06695668217232492 | episode_reward = 9.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 40283 | reward = 1.0 | action = 0 | epsilon = 0.06691603493131124 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 2]] | max_global_step = 41084 | reward = 1.0 | action = 0 | epsilon = 0.06561234557975706 | episode_reward = 132.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 0]] | max_global_step = 41045 | reward = 1.0 | action = 1 | epsilon = 0.0656733526139083 | episode_reward = 8.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:05: [DQN-small-batch-size[worker: 1]] | max_global_step = 41057 | reward = 1.0 | action = 0 | epsilon = 0.06565455587107294 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 41889 | reward = 1.0 | action = 0 | epsilon = 0.06440480720560703 | episode_reward = 268.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 41830 | reward = 1.0 | action = 0 | epsilon = 0.06449004677759182 | episode_reward = 63.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 41873 | reward = 1.0 | action = 1 | epsilon = 0.06442931620460313 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 42683 | reward = 1.0 | action = 1 | epsilon = 0.06330662437574446 | episode_reward = 220.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: 
[DQN-small-batch-size[worker: 0]] | max_global_step = 42624 | reward = 1.0 | action = 0 | epsilon = 0.06338402704788762 | episode_reward = 87.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 42655 | reward = 1.0 | action = 1 | epsilon = 0.06334260080788667 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 43524 | reward = 1.0 | action = 0 | epsilon = 0.06223207970357811 | episode_reward = 273.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 43438 | reward = 1.0 | action = 0 | epsilon = 0.06233772923084368 | episode_reward = 149.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 43476 | reward = 1.0 | action = 1 | epsilon = 0.06229093482544598 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 44279 | reward = 1.0 | action = 1 | epsilon = 0.06134255957140431 | episode_reward = 174.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 44194 | reward = 1.0 | action = 1 | epsilon = 0.06143938224115523 | episode_reward = 112.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 44232 | reward = 1.0 | action = 1 | epsilon = 0.06139599507646097 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 45029 | reward = 1.0 | action = 1 | epsilon = 0.06052298576066393 | episode_reward = 215.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 44949 | reward = 1.0 | action = 1 | epsilon = 0.06060750728205383 | episode_reward = 268.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 44993 | reward = 1.0 | action = 0 | 
epsilon = 0.060560936780250485 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 45787 | reward = 1.0 | action = 0 | epsilon = 0.059754824514196295 | episode_reward = 279.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 45715 | reward = 1.0 | action = 0 | epsilon = 0.059825312703671904 | episode_reward = 122.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 45754 | reward = 1.0 | action = 0 | epsilon = 0.059787068608587374 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 46511 | reward = 1.0 | action = 1 | epsilon = 0.05907353545351516 | episode_reward = 221.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 46427 | reward = 1.0 | action = 1 | epsilon = 0.05915007416386105 | episode_reward = 127.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 46464 | reward = 1.0 | action = 0 | epsilon = 0.05911628144453736 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 47297 | reward = 1.0 | action = 0 | epsilon = 0.05838766340935608 | episode_reward = 211.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 47211 | reward = 1.0 | action = 0 | epsilon = 0.05846010838155505 | episode_reward = 124.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 47264 | reward = 1.0 | action = 1 | epsilon = 0.05841538841971361 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 48094 | reward = 1.0 | action = 0 | epsilon = 0.05774588693635186 | episode_reward = 210.0 | \u001b[0m\n", + 
"\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 47993 | reward = 1.0 | action = 1 | epsilon = 0.05782373439427728 | episode_reward = 168.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 48055 | reward = 1.0 | action = 1 | epsilon = 0.057776154879528224 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 48850 | reward = 1.0 | action = 1 | epsilon = 0.05718116757393946 | episode_reward = 155.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 48742 | reward = 1.0 | action = 0 | epsilon = 0.05725987045195807 | episode_reward = 140.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 48792 | reward = 1.0 | action = 1 | epsilon = 0.0572229393669679 | episode_reward = 488.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 2]] | max_global_step = 49627 | reward = 1.0 | action = 1 | epsilon = 0.056644317543240556 | episode_reward = 255.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 49524 | reward = 1.0 | action = 1 | epsilon = 0.05671310767495324 | episode_reward = 166.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 49575 | reward = 1.0 | action = 1 | epsilon = 0.056678957981548594 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 49886 | reward = 1.0 | action = 0 | epsilon = 0.0564744391404252 | episode_reward = 174.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 1]] | max_global_step = 49935 | reward = 1.0 | action = 1 | epsilon = 0.05644279198748248 | episode_reward = 500.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: [DQN-small-batch-size[worker: 0]] | max_global_step = 49954 
| reward = 1.0 | action = 1 | epsilon = 0.05643056230458411 | episode_reward = 174.0 | \u001b[0m\n", + "\u001b[38;21m[INFO] 16:06: ... trained! \u001b[0m\n" + ] + } + ], + "source": [ + "for manager in agent_managers:\n", + " manager.fit() " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 7. Evaluate the DQN agents on 10 episodes. Plot the mean episodic rewards." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this [paper](https://arxiv.org/pdf/2310.03882.pdf), the authors claim that a small batch size improves DQN performance." + ] + }, + { + "cell_type": "code", + "execution_count": 260, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;21m[INFO] 16:06: Evaluating DQN... \u001b[0m\n", + "[INFO] Evaluation:.......... Evaluation finished \n", + "\u001b[38;21m[INFO] 16:06: Evaluating DQN-small-batch-size... \u001b[0m\n", + "[INFO] Evaluation:.......... Evaluation finished \n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "output = evaluate_agents(agent_managers, n_simulations=10, plot=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 8. Plot the episodic rewards during learning." + ] + }, + { + "cell_type": "code", + "execution_count": 261, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "output_mydqn = plot_writer_data(\n", + " agent_managers,\n", + " tag=\"episode_reward\",\n", + " title=\"Learning Reward\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 9. Bonus: modify the Net class so that Q networks take images as input and use `rlberry.envs.atari_make(\"AtariGame\")` to train DQN on Pong." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# class CNNNetwork(torch.nn.Module):\n", + "# \"\"\"\n", + "# Basic CNN Qnetwork.\n", + "# \"\"\"\n", + "# def __init__(self, obs_size, hidden_size, n_actions):\n", + "# super(CNNNetwork, self).__init__()\n", + "# n_input_channels = obs_size\n", + "# self.cnn = nn.Sequential(\n", + "# nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),\n", + "# nn.ReLU(),\n", + "# nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),\n", + "# nn.ReLU(),\n", + "# nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),\n", + "# nn.ReLU(),\n", + "# nn.Flatten(),\n", + "# )\n", + "\n", + "# self.linear = nn.Sequential(nn.LazyLinear(hidden_size), nn.ReLU(), nn.Linear(hidden_size, n_actions))\n", + "\n", + "# def forward(self, observations):\n", + "# return self.linear(self.cnn(observations))\n", + " \n", + "\n", + "# # Environment\n", + "# env_ctor = atari_make # environment constructor\n", + "# env_kwargs = {\"id\": \"ALE/Pong-v5\"} # parameters passed to the environment constructor. 
Here, the name of the environment\n", + "# names = [\"DQN\", \"DQN-small-batch-size\"]\n", + "# agent_list = [MyDQN, MyDQN] # list of agents on which we run experiments\n", + "\n", + "# agent_params = [dict(gamma=0.99, \n", + "# batch_size=256, \n", + "# buffer_capacity=10000, \n", + "# update_target_every=1000, \n", + "# epsilon_start=1, \n", + "# decrease_epsilon=10_000, \n", + "# epsilon_min=0.05,\n", + "# qnetworks=CNNNetwork,\n", + "# replay_buffer=ReplayBuffer,\n", + "# optimizer=optim.Adam,\n", + "# loss_fn=F.mse_loss,\n", + "# ), \n", + "# dict(gamma=0.99, \n", + "# batch_size=8, \n", + "# buffer_capacity=10000, \n", + "# update_target_every=1000, \n", + "# epsilon_start=1, \n", + "# decrease_epsilon=10_000, \n", + "# epsilon_min=0.05,\n", + "# qnetworks=CNNNetwork,\n", + "# replay_buffer=ReplayBuffer,\n", + "# optimizer=optim.Adam,\n", + "# loss_fn=F.mse_loss,\n", + "# )]\n", + "\n", + "\n", + "# fit_budget = 50_000 # budget used by an agent in one fit. This is used differently by different agents.\n", + "# n_fit = 3 # number of parallel fits of the same agent.\n", + "\n", + "# agent_managers = [\n", + "# AgentManager( # an AgentManager is used to repeatedly fit and evaluate an agent on an environment\n", + "# Agent,\n", + "# agent_name=names[e],\n", + "# init_kwargs=agent_params[e],\n", + "# train_env=(env_ctor, env_kwargs),\n", + "# fit_budget=fit_budget,\n", + "# n_fit=n_fit,\n", + "# eval_kwargs=dict(eval_horizon=500, n_simulations=10)\n", + "# )\n", + "# for e, Agent in enumerate(agent_list)\n", + "# ]\n", + "# for manager in agent_managers:\n", + "# manager.fit() " + ] + } + ], + "metadata": { + "colab": { + "authorship_tag": "ABX9TyP9EbLl6g2dURBpFFjKPouU", + "collapsed_sections": [], + "include_colab_link": true, + "name": "Tutorial_Deep_Q_Learning.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + 
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tutorials/quick_start.md b/tutorials/quick_start.md index 79659c4..ac904a6 100644 --- a/tutorials/quick_start.md +++ b/tutorials/quick_start.md @@ -2,5 +2,6 @@ ```python import numpy as np + print(np.pi) ```