From 70a92ae5b828f229a32783b21f6b2d5d130a1fe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E8=B6=85?= Date: Wed, 7 Jul 2021 20:36:06 +0800 Subject: [PATCH] curiosity --- 11.curiosity.ipynb | 723 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 723 insertions(+) create mode 100644 11.curiosity.ipynb diff --git a/11.curiosity.ipynb b/11.curiosity.ipynb new file mode 100644 index 0000000..9cd109b --- /dev/null +++ b/11.curiosity.ipynb @@ -0,0 +1,723 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Configurations for Colab" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "IN_COLAB = \"google.colab\" in sys.modules\n", + "\n", + "if IN_COLAB:\n", + " !apt install python-opengl\n", + " !apt install ffmpeg\n", + " !apt install xvfb\n", + " !pip install pyvirtualdisplay\n", + " !pip install gym\n", + " from pyvirtualdisplay import Display\n", + " \n", + " # Start virtual display\n", + " dis = Display(visible=0, size=(400, 400))\n", + " dis.start()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 11. Curiosity\n", + "\n", + "[Deepak Pathak, Pulkit Agrawal, Alexei A. Efros and Trevor Darrell. Curiosity-driven Exploration by Self-supervised Prediction.\n", + "In ICML 2017.](https://pathak22.github.io/noreward-rl/resources/icml17.pdf)\n", + "\n", + "\n", + "\n", + "In many real-world scenarios, rewards extrinsic to the agent are extremely sparse, or absent altogether. In such cases, curiosity can serve as an intrinsic reward signal to enable the agent to explore its environment and learn skills that might be useful later in its life. We formulate curiosity as the error in an agent's ability to predict the consequence of its own actions in a visual feature space learned by a self-supervised inverse dynamics model. 
Our formulation scales to high-dimensional continuous state spaces like images, bypasses the difficulties of directly predicting pixels, and, critically, ignores the aspects of the environment that cannot affect the agent. The proposed approach is evaluated in two environments: VizDoom and Super Mario Bros. Three broad settings are investigated: 1) sparse extrinsic reward, where curiosity allows for far fewer interactions with the environment to reach the goal; 2) exploration with no extrinsic reward, where curiosity pushes the agent to explore more efficiently; and 3) generalization to unseen scenarios (e.g. new levels of the same game) where the knowledge gained from earlier experience helps the agent explore new places much faster than starting from scratch." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from typing import Dict, List, Tuple\n", + "\n", + "import gym\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "from IPython.display import clear_output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Replay buffer\n", + "\n", + "Please see *01.dqn.ipynb* for detailed description." 
class ReplayBuffer:
    """Fixed-capacity circular store of transitions backed by numpy arrays."""

    def __init__(self, obs_dim: int, size: int, batch_size: int = 32):
        """Allocate zero-filled float32 storage.

        Args:
            obs_dim: length of one observation vector.
            size: maximum number of transitions kept.
            batch_size: number of transitions returned by `sample_batch`.
        """
        self.max_size = size
        self.batch_size = batch_size
        self.ptr = 0   # index of the slot the next `store` writes to
        self.size = 0  # number of valid transitions currently held

        obs_shape = (size, obs_dim)
        self.obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.next_obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.acts_buf = np.zeros(size, dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)

    def store(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        rew: float,
        next_obs: np.ndarray,
        done: bool,
    ):
        """Write one transition at the cursor, overwriting the oldest when full."""
        slot = self.ptr
        self.obs_buf[slot] = obs
        self.next_obs_buf[slot] = next_obs
        self.acts_buf[slot] = act
        self.rews_buf[slot] = rew
        self.done_buf[slot] = done

        # advance the cursor with wrap-around; the fill count saturates
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample_batch(self) -> Dict[str, np.ndarray]:
        """Return `batch_size` distinct stored transitions chosen uniformly."""
        picked = np.random.choice(
            self.size, size=self.batch_size, replace=False
        )
        fields = (
            ("obs", self.obs_buf),
            ("next_obs", self.next_obs_buf),
            ("acts", self.acts_buf),
            ("rews", self.rews_buf),
            ("done", self.done_buf),
        )
        return {key: storage[picked] for key, storage in fields}

    def __len__(self) -> int:
        """Number of transitions currently stored."""
        return self.size
class Network(nn.Module):
    """Multi-layer perceptron with two hidden ReLU layers of width 128."""

    def __init__(self, in_dim: int, out_dim: int):
        """Build the `in_dim -> 128 -> 128 -> out_dim` stack."""
        super().__init__()

        hidden = 128
        self.layers = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to `x`."""
        out = self.layers(x)
        return out


class ICM(nn.Module):
    """Intrinsic Curiosity Module: feature encoder, forward model, inverse model."""

    def __init__(self, in_dim: int, out_dim: int):
        """Create the sub-networks.

        Args:
            in_dim: size of a raw state vector.
            out_dim: number of discrete actions (width of the one-hot action).
        """
        super().__init__()

        feat_dim = 128

        # state -> feature encoder
        self.fea_layers = nn.Sequential(
            nn.Linear(in_dim, feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, feat_dim),
            nn.ReLU(),
        )

        # forward model: (feature, one-hot action) -> predicted next feature
        self.pred_module1 = nn.Linear(feat_dim + out_dim, feat_dim)
        self.pred_module2 = nn.Linear(feat_dim, feat_dim)

        # inverse model: (feature, next feature) -> action distribution
        self.invpred_module1 = nn.Linear(feat_dim + feat_dim, feat_dim)
        self.invpred_module2 = nn.Linear(feat_dim, out_dim)

    def pred(self, feature_x, a_vec):
        """Predict the next-state feature from the current feature and action.

        The concatenated input is detached, so no gradient reaches
        `feature_x` (or the encoder that produced it) through this path.
        """
        joint = torch.cat((feature_x, a_vec.float()), dim=-1).detach()
        hidden = F.relu(self.pred_module1(joint))
        return self.pred_module2(hidden)

    def invpred(self, feature_x, feature_x_next):
        """Predict the action taken between two features.

        Returns softmax probabilities over actions (not raw logits).
        """
        pair = torch.cat((feature_x, feature_x_next), dim=-1)
        hidden = F.relu(self.invpred_module1(pair))
        scores = self.invpred_module2(hidden)
        return F.softmax(scores, dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode raw states into features."""
        return self.fea_layers(x)

    def get_full(self, x, x_next, a_vec):
        """Run encoder, forward model and inverse model on a transition batch.

        Returns:
            Tuple `(predicted next feature, predicted action probabilities,
            encoded next feature)`.
        """
        # encode both ends of the transition
        feature_x = self.fea_layers(x)
        feature_x_next = self.fea_layers(x_next)

        predicted_next = self.pred(feature_x, a_vec)
        predicted_action = self.invpred(feature_x, feature_x_next)

        return predicted_next, predicted_action, feature_x_next
class DQNAgent:
    """DQN agent whose reward is augmented by an ICM curiosity bonus.

    Attribute:
        env (gym.Env): openAI Gym environment
        memory (ReplayBuffer): replay memory to store transitions
        batch_size (int): batch size for sampling
        epsilon (float): parameter for epsilon greedy policy
        epsilon_decay (float): step size to decrease epsilon
        max_epsilon (float): max value of epsilon
        min_epsilon (float): min value of epsilon
        target_update (int): period for target model's hard update
        gamma (float): discount factor
        use_extrinsic (bool): whether the environment reward is added
            to the intrinsic reward
        intrinsic_scale (float): multiplier applied to the intrinsic reward
        dqn (Network): model to train and select actions
        dqn_target (Network): target model to update
        icm (ICM): curiosity module providing the intrinsic reward
        optimizer (torch.optim): optimizer for training dqn and icm
        transition (list): transition information including
                           state, action, reward, next_state, done
    """

    def __init__(
        self,
        env: gym.Env,
        memory_size: int,
        batch_size: int,
        target_update: int,
        epsilon_decay: float,
        max_epsilon: float = 1.0,
        min_epsilon: float = 0.1,
        gamma: float = 0.99,
        use_extrinsic: bool = True,
        intrinsic_scale: float = 1.0,
    ):
        """Initialization.

        Args:
            env (gym.Env): openAI Gym environment
            memory_size (int): length of memory
            batch_size (int): batch size for sampling
            target_update (int): period for target model's hard update
            epsilon_decay (float): step size to decrease epsilon
            max_epsilon (float): max value of epsilon
            min_epsilon (float): min value of epsilon
            gamma (float): discount factor
            use_extrinsic (bool): add the environment reward to the
                intrinsic reward when True; use curiosity only when False
            intrinsic_scale (float): multiplier for the intrinsic reward
        """
        obs_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n

        self.env = env
        self.memory = ReplayBuffer(obs_dim, memory_size, batch_size)
        self.batch_size = batch_size
        self.epsilon = max_epsilon
        self.epsilon_decay = epsilon_decay
        self.max_epsilon = max_epsilon
        self.min_epsilon = min_epsilon
        self.target_update = target_update
        self.gamma = gamma
        self.use_extrinsic = use_extrinsic
        self.intrinsic_scale = intrinsic_scale

        # device: cpu / gpu
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        print(self.device)

        # networks: dqn, dqn_target, icm
        self.dqn = Network(obs_dim, action_dim).to(self.device)
        self.dqn_target = Network(obs_dim, action_dim).to(self.device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()
        self.icm = ICM(obs_dim, action_dim).to(self.device)

        # optimizer: the returned loss contains the ICM forward/inverse
        # terms, so the ICM parameters must be optimized together with the
        # Q-network; otherwise the curiosity module is never trained.
        self.optimizer = optim.Adam(
            list(self.dqn.parameters()) + list(self.icm.parameters())
        )

        # transition to store in memory
        self.transition = list()

        # mode: train / test
        self.is_test = False

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action from the input state (epsilon greedy).

        In training mode the (state, action) pair is remembered so that
        `step` can complete and store the transition.
        """
        if self.epsilon > np.random.random():
            selected_action = self.env.action_space.sample()
        else:
            # acting needs no autograd graph
            with torch.no_grad():
                selected_action = self.dqn(
                    torch.FloatTensor(state).to(self.device)
                ).argmax()
            selected_action = selected_action.detach().cpu().numpy()

        if not self.is_test:
            self.transition = [state, selected_action]

        return selected_action

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
        """Take an action and return the response of the env.

        In training mode the completed transition is written to the
        replay memory.
        """
        next_state, reward, done, _ = self.env.step(action)

        if not self.is_test:
            self.transition += [reward, next_state, done]
            self.memory.store(*self.transition)

        return next_state, reward, done

    def update_model(self) -> float:
        """Run one gradient step on a sampled batch and return the loss value."""
        samples = self.memory.sample_batch()

        loss = self._compute_dqn_loss(samples)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, num_frames: int, plotting_interval: int = 200):
        """Train the agent for `num_frames` environment steps.

        Args:
            num_frames (int): number of environment steps to run
            plotting_interval (int): redraw the progress plot every this
                many frames
        """
        self.is_test = False

        state = self.env.reset()
        update_cnt = 0
        epsilons = []
        losses = []
        scores = []
        score = 0

        for frame_idx in range(1, num_frames + 1):
            action = self.select_action(state)
            next_state, reward, done = self.step(action)

            state = next_state
            score += reward

            # if episode ends
            if done:
                state = self.env.reset()
                scores.append(score)
                score = 0

            # if training is ready
            if len(self.memory) >= self.batch_size:
                loss = self.update_model()
                losses.append(loss)
                update_cnt += 1

                # linearly decrease epsilon
                self.epsilon = max(
                    self.min_epsilon, self.epsilon - (
                        self.max_epsilon - self.min_epsilon
                    ) * self.epsilon_decay
                )
                epsilons.append(self.epsilon)

                # if hard update is needed
                if update_cnt % self.target_update == 0:
                    self._target_hard_update()

            # plotting
            if frame_idx % plotting_interval == 0:
                self._plot(frame_idx, scores, losses, epsilons)

        self.env.close()

    def test(self) -> List[np.ndarray]:
        """Run one episode without storing transitions.

        Returns:
            The rendered RGB frames of the episode.
        """
        self.is_test = True

        state = self.env.reset()
        done = False
        score = 0

        frames = []
        while not done:
            frames.append(self.env.render(mode="rgb_array"))
            action = self.select_action(state)
            next_state, reward, done = self.step(action)

            state = next_state
            score += reward

        print("score: ", score)
        self.env.close()

        return frames

    def _compute_dqn_loss(self, samples: Dict[str, np.ndarray]) -> torch.Tensor:
        """Return the TD loss plus the ICM forward and inverse losses.

        Args:
            samples: batch from the replay memory with keys
                "obs", "next_obs", "acts", "rews", "done".
        """
        device = self.device  # for shortening the following lines
        state = torch.FloatTensor(samples["obs"]).to(device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(device)
        action = torch.LongTensor(samples["acts"].reshape(-1, 1)).to(device)
        reward = torch.FloatTensor(samples["rews"].reshape(-1, 1)).to(device)
        done = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)

        # convert action from int to one-hot format
        n_actions = self.env.action_space.n
        a_vec = F.one_hot(action, num_classes=n_actions).reshape(-1, n_actions)
        pred_s_next, pred_a_vec, feature_x_next = self.icm.get_full(
            state, next_state, a_vec
        )
        forward_loss = F.mse_loss(pred_s_next, feature_x_next, reduction='none')

        # ICM.invpred already returns softmax probabilities, so take their
        # log and use NLL; F.cross_entropy would apply softmax a second time.
        log_probs = torch.log(pred_a_vec.clamp(min=1e-8))
        inverse_pred_loss = F.nll_loss(
            log_probs, action.reshape(-1), reduction='none'
        )

        # calculate rewards; the intrinsic reward is a constant of the TD
        # target, so it is detached to keep TD gradients out of the ICM
        intrinsic_reward = self.intrinsic_scale * forward_loss.mean(-1).detach()
        total_reward = intrinsic_reward.reshape(-1, 1)
        if self.use_extrinsic:
            total_reward = total_reward + reward

        # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        curr_q_value = self.dqn(state).gather(1, action)
        next_q_value = self.dqn_target(
            next_state
        ).max(dim=1, keepdim=True)[0].detach()
        mask = 1 - done
        target = total_reward + self.gamma * next_q_value * mask

        # calculate dqn loss
        loss = F.smooth_l1_loss(curr_q_value, target)

        return loss + forward_loss.mean() + inverse_pred_loss.mean()

    def _target_hard_update(self):
        """Hard update: target <- local."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def _plot(
        self,
        frame_idx: int,
        scores: List[float],
        losses: List[float],
        epsilons: List[float],
    ):
        """Plot the training progresses (scores, losses, epsilons)."""
        clear_output(True)
        plt.figure(figsize=(20, 5))
        plt.subplot(131)
        plt.title('frame %s. score: %s' % (frame_idx, np.mean(scores[-10:])))
        plt.plot(scores)
        plt.subplot(132)
        plt.title('loss')
        plt.plot(losses)
        plt.subplot(133)
        plt.title('epsilons')
        plt.plot(epsilons)
        plt.show()
# environment
env_id = "CartPole-v0"
env = gym.make(env_id)
if IN_COLAB:
    env = gym.wrappers.Monitor(env, "videos", force=True)

# random seed
seed = 777


def seed_torch(seed):
    """Seed torch and, when cuDNN is enabled, force deterministic kernels."""
    torch.manual_seed(seed)
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True


np.random.seed(seed)
seed_torch(seed)
env.seed(seed)
# The action space has its own random generator, which `env.seed` does not
# touch; seed it too so the epsilon-greedy `action_space.sample()` calls
# are reproducible.
env.action_space.seed(seed)

# parameters
num_frames = 20000
memory_size = 1000
batch_size = 32
target_update = 100
epsilon_decay = 1 / 2000

agent = DQNAgent(env, memory_size, batch_size, target_update, epsilon_decay)

# train
agent.train(num_frames)

# test: run the trained agent (1 episode)
frames = agent.test()

# render
if IN_COLAB:  # for colab
    import base64
    import glob
    import os

    from IPython.display import HTML, display

    def ipython_show_video(path: str) -> None:
        """Show a video at `path` within IPython Notebook."""
        if not os.path.isfile(path):
            raise NameError("Cannot access: {}".format(path))

        # context manager so the file handle is closed after reading
        with open(path, "rb") as video_file:
            video = video_file.read()
        encoded = base64.b64encode(video)

        # embed the mp4 as a base64 data URI inside a <video> element
        display(HTML(
            data="""
            <video width="320" height="240" alt="test" controls>
            <source src="data:video/mp4;base64,{0}" type="video/mp4"/>
            </video>
            """.format(encoded.decode("ascii"))
        ))

    list_of_files = glob.glob("videos/*.mp4")
    latest_file = max(list_of_files, key=os.path.getctime)
    print(latest_file)
    ipython_show_video(latest_file)

else:  # for jupyter
    from matplotlib import animation
    from IPython.display import HTML, display

    def display_frames_as_gif(frames: List[np.ndarray]) -> None:
        """Displays a list of frames as a looping animation, with controls."""
        patch = plt.imshow(frames[0])
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])

        anim = animation.FuncAnimation(
            plt.gcf(), animate, frames=len(frames), interval=50
        )
        # matplotlib's built-in JS/HTML writer replaces the JSAnimation
        # package, which is not installed (ModuleNotFoundError).
        display(HTML(anim.to_jshtml(default_mode='loop')))

    # display
    display_frames_as_gif(frames)
"3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}