From 68308e977435885d4a031567b50d8e1bb4fd88f8 Mon Sep 17 00:00:00 2001 From: julia-bel Date: Sun, 26 May 2024 20:15:19 +0000 Subject: [PATCH] partially implemented pipelines and tests --- .gitignore | 161 ++++++++ .gitmodules | 3 + README.md | 30 ++ data/default_statistics.csv | 19 + data/empty-48-48-even-10_60_agents.png | Bin 0 -> 496 bytes data/empty-48-48-random-10_60_agents.png | Bin 0 -> 478 bytes g2rl/__init__.py | 5 + g2rl/agent.py | 235 +++++++++++ g2rl/environment.py | 221 +++++++++++ g2rl/metrics.py | 13 + g2rl/network.py | 83 ++++ g2rl/train.py | 149 +++++++ g2rl/utils.py | 96 +++++ notebooks/train&test.ipynb | 1 + pogema | 1 + renders/test.svg | 476 +++++++++++++++++++++++ requirements.txt | 5 + setup.py | 19 + 18 files changed, 1517 insertions(+) create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 README.md create mode 100644 data/default_statistics.csv create mode 100644 data/empty-48-48-even-10_60_agents.png create mode 100644 data/empty-48-48-random-10_60_agents.png create mode 100644 g2rl/__init__.py create mode 100644 g2rl/agent.py create mode 100644 g2rl/environment.py create mode 100644 g2rl/metrics.py create mode 100644 g2rl/network.py create mode 100644 g2rl/train.py create mode 100644 g2rl/utils.py create mode 100644 notebooks/train&test.ipynb create mode 160000 pogema create mode 100644 renders/test.svg create mode 100644 requirements.txt create mode 100644 setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8c61bc0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,161 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. 
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ +*.pt diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..22d7219 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "pogema"] + path = pogema + url = https://github.com/AIRI-Institute/pogema.git diff --git a/README.md b/README.md new file mode 100644 index 0000000..db0910e --- /dev/null +++ b/README.md @@ -0,0 +1,30 @@ +# Multi-Agent Pathfinding in POGEMA with G2RL +Implementation of the [G2RL](https://ieeexplore.ieee.org/abstract/document/9205217) [1] approach in the [POGEMA](https://github.com/AIRI-Institute/pogema) environment. + +## Basic Concepts +**Problem**: MAPF +**Environment**: 2D grid with static obstacles and dynamic agents +**Agent actions**: wait, up, down, left, right +**Local observations**: free cells, static obstacles, dynamic agents +**Global guidance**: the shortest traversable path considering all static obstacles +**Objective**: minimize the overall number of steps and avoid conflicts + +## Code Implementation +Partially based on the [repo](https://github.com/Tushar-ml/G2RL-Path-Planning.git). +Installation: + +``` +pip install -r requirements.txt +pip install . +``` + +## Train & Test +[Notebook](notebooks/train&test.ipynb) with simple examples of training and testing implementations. + +## Demonstration +![Demo](renders/test.svg) + +## References +[1] B. Wang, Z. Liu, Q. Li, and A. Prorok, "Mobile robot path planning in +dynamic environments through globally guided reinforcement learning," +IEEE Robot. Autom. Lett., vol. 5, no. 4, pp. 6932–6939, Oct. 2020. 
diff --git a/data/default_statistics.csv b/data/default_statistics.csv new file mode 100644 index 0000000..ff3ac32 --- /dev/null +++ b/data/default_statistics.csv @@ -0,0 +1,19 @@ +map,num_agents,max_steps,density,size,start,final,done,detour_percentage,moving_cost +random,3,60,,48,"[[25, 22], [17, 13], [22, 39]]","[[28, 26], [32, 16], [46, 32]]",2,0.0,1.1200716845878138 +random,3,120,,48,"[[25, 22], [17, 13], [22, 39]]","[[28, 26], [32, 16], [46, 32]]",2,0.0,1.1200716845878138 +random,3,180,,48,"[[25, 22], [17, 13], [22, 39]]","[[28, 26], [32, 16], [46, 32]]",2,0.0,1.1200716845878138 +random,6,60,,48,"[[25, 22], [17, 13], [22, 39], [32, 36], [38, 29], [3, 19]]","[[28, 26], [32, 16], [46, 32], [28, 19], [33, 30], [46, 21]]",5,0.0,1.1419969278033792 +random,6,120,,48,"[[25, 22], [17, 13], [22, 39], [32, 36], [38, 29], [3, 19]]","[[28, 26], [32, 16], [46, 32], [28, 19], [33, 30], [46, 21]]",5,0.0,1.1419969278033792 +random,6,180,,48,"[[25, 22], [17, 13], [22, 39], [32, 36], [38, 29], [3, 19]]","[[28, 26], [32, 16], [46, 32], [28, 19], [33, 30], [46, 21]]",5,0.0,1.1419969278033792 +random,12,60,,48,"[[25, 22], [17, 13], [22, 39], [32, 36], [38, 29], [3, 19], [38, 39], [15, 2], [6, 42], [11, 41], [39, 47], [35, 35]]","[[28, 26], [32, 16], [46, 32], [28, 19], [33, 30], [46, 21], [39, 41], [29, 22], [45, 44], [1, 0], [8, 7], [6, 6]]",9,1.9157088122605366,1.137862533771267 +random,12,120,,48,"[[25, 22], [17, 13], [22, 39], [32, 36], [38, 29], [3, 19], [38, 39], [15, 2], [6, 42], [11, 41], [39, 47], [35, 35]]","[[28, 26], [32, 16], [46, 32], [28, 19], [33, 30], [46, 21], [39, 41], [29, 22], [45, 44], [1, 0], [8, 7], [6, 6]]",9,1.9157088122605366,1.137862533771267 +random,12,180,,48,"[[25, 22], [17, 13], [22, 39], [32, 36], [38, 29], [3, 19], [38, 39], [15, 2], [6, 42], [11, 41], [39, 47], [35, 35]]","[[28, 26], [32, 16], [46, 32], [28, 19], [33, 30], [46, 21], [39, 41], [29, 22], [45, 44], [1, 0], [8, 7], [6, 6]]",9,1.9157088122605366,1.137862533771267 +even,3,60,,48,"[[32, 25], [46, 10], [12, 4]]","[[6, 41], [45, 11], [28, 32]]",2,0.0,1.0 +even,3,120,,48,"[[32, 25], [46, 10], [12, 4]]","[[6, 41], [45, 11], [28, 32]]",2,0.0,1.0 +even,3,180,,48,"[[32, 25], [46, 10], [12, 4]]","[[6, 41], [45, 11], [28, 32]]",2,0.0,1.0 +even,6,60,,48,"[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9], [39, 11]]","[[6, 41], [45, 11], [28, 32], [33, 24], [23, 38], [28, 26]]",5,0.0,1.0210526315789474 +even,6,120,,48,"[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9], [39, 11]]","[[6, 41], [45, 11], [28, 32], [33, 24], [23, 38], [28, 26]]",5,0.0,1.0210526315789474 +even,6,180,,48,"[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9], [39, 11]]","[[6, 41], [45, 11], [28, 32], [33, 24], [23, 38], [28, 26]]",5,0.0,1.0210526315789474 +even,12,60,,48,"[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9], [39, 11], [42, 4], [24, 11], [39, 10], [8, 0], [8, 24], [27, 1]]","[[6, 41], [45, 11], [28, 32], [33, 24], [23, 38], [28, 26], [18, 0], [34, 27], [39, 43], [16, 8], [34, 18], [26, 43]]",10,3.5812104562104556,1.0468396734844103 +even,12,120,,48,"[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9], [39, 11], [42, 4], [24, 11], [39, 10], [8, 0], [8, 24], [27, 1]]","[[6, 41], [45, 11], [28, 32], [33, 24], [23, 38], [28, 26], [18, 0], [34, 27], [39, 43], [16, 8], [34, 18], [26, 43]]",10,3.5812104562104556,1.0468396734844103 +even,12,180,,48,"[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9], [39, 11], [42, 4], [24, 11], [39, 10], [8, 0], [8, 24], [27, 1]]","[[6, 41], [45, 11], [28, 32], [33, 24], [23, 38], [28, 26], [18, 0], [34, 27], [39, 43], [16, 8], [34, 
18], [26, 43]]",10,3.5812104562104556,1.0468396734844103 diff --git a/data/empty-48-48-even-10_60_agents.png b/data/empty-48-48-even-10_60_agents.png new file mode 100644 index 0000000000000000000000000000000000000000..6d1317905fb0ab81d6d34a3d6c7dbacf39817771 GIT binary patch literal 496 zcmVl9K$7n~CLH$|F;pJQCcc88c1m(m0MLcZZ)Q*< z;vF1BMiqJXaj;Ib9WS73bhZvj)lZ@0g)!=ZSC+w>)*Y9F`UMCP6;;YHNy<0VWpvE! z>g?IIY7i5Nn+QnKKo_$vDxX*-xh1YSPk`7@I*mFIhn3roBPOQo9uHe~TR80t3+Zed z07NVp@`;_2JJOML(VJyNiiOM!@{rx_O%~nMlzPr10YGd@JL7bklsB#h44}{U|3C~y zE6|=|-WZw+^a(3=p5C;lJ}}1bi9sM$hm>x$O%I@r%0>tKj}m16EW}PlpBc~^r{6dX zDaNaHnV88@NBGOMPGV5AtQAv@?#_5QHh!EKo>G9KwBbRc4(-K6CX%vOY0tFJo?2hf z`tcYby}wFd**?h5{B4S%s??QefuF@@ie=X7E#&ZmF`I22xcI#N>DWtff29S;* mmN!S9<>K#D&)IK#()$DXz)DJ*lhDQh000053(sX;g*;4y%Rf4?OE>kt5pzfD3$6@3_l)rS)i0uasi zS0ji~`6)*u;aCR|v5&V0M2rEFmZf+l%BxO;)fHlZ^!-)d_5hd}!X)(}AyC@i6`vR( zeRo*>lKK!MgkMP*jkbwtuAZAD<1PrlWiPhxdr9q-d}C}|8O?oD> z_BqC+!zW~`>`yRmAB#J4J7dirN$Q=*CHs(ho}>>PFCfA@b*j3vzb-)KsjW=;08c?& UNky3CtN;K207*qoM6N<$f-~&fqyPW_ literal 0 HcmV?d00001 diff --git a/g2rl/__init__.py b/g2rl/__init__.py new file mode 100644 index 0000000..c5326e1 --- /dev/null +++ b/g2rl/__init__.py @@ -0,0 +1,5 @@ +from g2rl.environment import G2RLEnv +from g2rl.agent import G2RLAgent, DDQNAgent +from g2rl.network import CRNNModel +from g2rl.metrics import moving_cost, detour_percentage +from g2rl.train import train diff --git a/g2rl/agent.py b/g2rl/agent.py new file mode 100644 index 0000000..792b49d --- /dev/null +++ b/g2rl/agent.py @@ -0,0 +1,235 @@ +import random +import copy +from typing import Any +from collections import deque + +import numpy as np +import torch +import torch.nn.functional as F +from torch.optim import Adam + +from g2rl.network import CRNNModel +from g2rl.environment import G2RLEnv +from g2rl.utils import PrioritizedReplayBuffer + + +class G2RLAgent: + '''Inference implementation of G2RL agent''' + def __init__( + self, + model: torch.nn.Module, + action_space: list[int], + epsilon: float = 0.1, + device: str = 'cpu', + lifelong: bool = True, + ): + self.device = device + self.epsilon = epsilon + self.action_space = action_space + self.q_network = model.to(self.device) + self.q_network.eval() + self.lifelong = lifelong + + def act(self, state: dict[str, Any]) -> int: + state = state['view_cache'] + # check not lifelong status + local_guidance = state[-1,:,:,-1] + agent_coord = local_guidance.shape[0] // 2 + if not self.lifelong and \ + local_guidance[agent_coord,agent_coord] == 1 == local_guidance.sum(): + return 0 + # lifelong strategy + state = torch.from_numpy(state).float().to(self.device).unsqueeze(0) + if random.random() <= self.epsilon: + return random.choice(self.action_space) + with torch.no_grad(): + q_values = self.q_network(state) + return torch.argmax(q_values).item() + + +class DDQNAgent: + '''Implementation of DDQN agent with a prioritized sumtree reply buffer''' + def __init__( + self, + model: torch.nn.Module, + action_space: list[int], + gamma: float = 0.95, + tau: float = 0.01, + initial_epsilon: float = 1.0, + final_epsilon: float = 0.1, + decay_range: int = 5_000, + lr: float = 0.001, + replay_buffer_size: int = 1000, + device: str = 'cpu', + alpha: float = 0.6, + beta: float = 0.4 + ): + self.device = device + self.action_space = action_space + self.replay_buffer = PrioritizedReplayBuffer(replay_buffer_size, alpha) + + self.tau = tau + self.q_network = model + self.target_network = copy.deepcopy(model) + self.q_network.to(self.device) + self.target_network.to(self.device) + 
self.target_network.eval() + + self.gamma = gamma + self.final_epsilon = final_epsilon + self.epsilon = initial_epsilon + self.epsilon_decay = (initial_epsilon - final_epsilon) / decay_range + self.optimizer = Adam(self.q_network.parameters(), lr=lr) + self.beta = beta + + def save_weights(self, path: str): + torch.save(self.target_network.state_dict(), path) + + def store(self, state: dict[str, Any], action: int, reward: float, next_state: dict[str, Any], terminated: bool): + state_cache = state['view_cache'] + next_state_cache = next_state['view_cache'] + transition = (state_cache, action, reward, next_state_cache, terminated) + state_tensor = torch.tensor(np.array(state_cache)).float().to(self.device) + next_state_tensor = torch.tensor(np.array(next_state_cache)).float().to(self.device) + with torch.no_grad(): + curr_Q = self.q_network(state_tensor.unsqueeze(0)).squeeze(0)[action] + next_Q = self.target_network(next_state_tensor.unsqueeze(0)).max(1)[0].item() + target = reward + self.gamma * next_Q * (1 - terminated) + error = abs(curr_Q.item() - target) + self.replay_buffer.add(error, transition) + + def align_target_model(self): + for target_param, param in zip(self.target_network.parameters(), self.q_network.parameters()): + target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param) + + def act(self, state: dict[str, Any]) -> int: + state = state['view_cache'] + state = torch.from_numpy(state).float().to(self.device).unsqueeze(0) + if random.random() <= self.epsilon: + return random.choice(self.action_space) + with torch.no_grad(): + q_values = self.q_network(state) + return torch.argmax(q_values).item() + + def retrain(self, batch_size: int) -> float: + if len(self.replay_buffer) < batch_size: + return + + samples, indices, weights = self.replay_buffer.sample(batch_size, self.beta) + states, actions, rewards, next_states, dones = zip(*samples) + + states = torch.tensor(np.array(states)).float().to(self.device) + next_states = torch.tensor(np.array(next_states)).float().to(self.device) + actions = torch.tensor(actions).long().to(self.device) + rewards = torch.tensor(rewards).float().to(self.device) + dones = torch.tensor(dones).float().to(self.device) + weights = torch.tensor(weights).float().to(self.device) + + curr_Q = self.q_network(states).gather(1, actions.unsqueeze(-1)).squeeze(-1) + next_Q = self.target_network(next_states).max(1)[0] + expected_Q = rewards + self.gamma * next_Q * (1 - dones) + loss = (weights * F.mse_loss(curr_Q, expected_Q.detach(), reduction='none')).mean() + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + self.align_target_model() + + errors = torch.abs(curr_Q - expected_Q).detach().cpu().numpy() + self.replay_buffer.update_priorities(indices, errors) + + if self.epsilon > self.final_epsilon: + self.epsilon -= self.epsilon_decay + + return loss.item() + + +# class DDQNAgent: +# """Without prioritized tree reply buffer" +# def __init__( +# self, +# model: torch.nn.Module, +# action_space: list[int], +# gamma: float = 0.95, +# tau: float = 0.01, +# initial_epsilon: float = 1.0, +# final_epsilon: float = 0.1, +# decay_range: int = 5_000, +# lr: float = 0.001, +# replay_buffer_size: int = 1000, +# device: str = 'cpu' +# ): +# self.device = device +# self.action_space = action_space +# self.replay_buffer = deque(maxlen=replay_buffer_size) + +# self.tau = tau +# self.q_network = model +# self.target_network = copy.deepcopy(model) +# self.q_network.to(self.device) +# self.target_network.to(self.device) +# 
self.target_network.eval() + +# self.gamma = gamma +# self.final_epsilon = final_epsilon +# self.epsilon = initial_epsilon +# self.epsilon_decay = (initial_epsilon - final_epsilon) / decay_range +# self.optimizer = Adam(self.q_network.parameters(), lr=lr) + +# def save_weights(self, path: str): +# torch.save(self.target_network.state_dict(), path) + +# def store(self, state, action, reward, next_state, terminated): +# self.replay_buffer.append( +# ( +# state['view_cache'], +# action, +# reward, +# next_state['view_cache'], +# terminated +# ) +# ) + +# def align_target_model(self): +# # self.target_network.load_state_dict(self.q_network.state_dict()) +# for target_param, param in zip(self.target_network.parameters(), self.q_network.parameters()): +# target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param) + +# def act(self, state: dict[str, Any]) -> int: +# state = state['view_cache'] +# state = torch.from_numpy(state).float().to(self.device).unsqueeze(0) +# if random.random() <= self.epsilon: +# return random.choice(self.action_space) +# with torch.no_grad(): +# q_values = self.q_network(state) +# return torch.argmax(q_values).item() + +# def retrain(self, batch_size: int) -> float: +# if len(self.replay_buffer) < batch_size: +# return + +# minibatch = random.sample(self.replay_buffer, batch_size) +# states, actions, rewards, next_states, dones = zip(*minibatch) + +# states = torch.tensor(np.array(states)).float().to(self.device) +# next_states = torch.tensor(np.array(next_states)).float().to(self.device) +# actions = torch.tensor(actions).long().to(self.device) +# rewards = torch.tensor(rewards).float().to(self.device) +# dones = torch.tensor(dones).float().to(self.device) + +# curr_Q = self.q_network(states).gather(1, actions.unsqueeze(-1)).squeeze(-1) +# next_Q = self.target_network(next_states).max(1)[0] +# expected_Q = rewards + self.gamma * next_Q * (1 - dones) +# loss = F.mse_loss(curr_Q, expected_Q.detach()) + +# self.optimizer.zero_grad() +# loss.backward() +# self.optimizer.step() + +# self.align_target_model() + +# if self.epsilon > self.final_epsilon: +# self.epsilon -= self.epsilon_decay + +# return loss.item() diff --git a/g2rl/environment.py b/g2rl/environment.py new file mode 100644 index 0000000..6ba8ba7 --- /dev/null +++ b/g2rl/environment.py @@ -0,0 +1,221 @@ +from typing import Any +from copy import deepcopy +from collections import deque + +import torch +import numpy as np + +from pogema import pogema_v0, a_star, GridConfig +from pogema.animation import AnimationMonitor + + +class Grid: + '''Basic grid container''' + def __init__(self, obstacles: np.ndarray): + assert obstacles.ndim == 2 and obstacles.shape[0] == obstacles.shape[1] + self.obstacles = obstacles.copy().astype(bool) + self.size = obstacles.shape[0] + + def is_obstacle(self, h: int, w: int) -> bool|None: + if 0 <= h <= self.size and 0 <= w <= self.size: + return self.obstacles[h, w] + else: + return False + + +class G2RLEnv: + '''Environment for MAPF G2RL implementation''' + def __init__( + self, + size: int = 50, + num_agents: int = 3, + density: float|None = None, + map: str|list|None = None, + obs_radius: int = 7, + cache_size: int = 4, + r1: float = -0.01, + r2: float = -0.1, + r3: float = 0.1, + seed: int = 42, + animation: bool = True, + collission_system: str = 'soft', + on_target: str = 'restart', + max_episode_steps: int = 64, + ): + self.time_idx = 1 + self.num_agents = num_agents + self.obs_radius = obs_radius + self.cache_size = cache_size + self.r1, self.r2, self.r3 = r1, r2, 
r3 + self.collission_system = collission_system + self.on_target = on_target + self.obs, self.info = None, None + + self._set_env( + map, + seed=seed, + size=size, + density=density, + max_episode_steps=max_episode_steps, + animation=animation) + + self.actions = [ + ('idle', 0, 0), + ('up', -1, 0), + ('down', 1, 0), + ('left', 0, -1), + ('right', 0, 1), + ] + + def _get_reward(self, case: int, N: int = 0) -> float: + rewards = [self.r1, self.r1 + self.r2, self.r1 + N * self.r3] + return rewards[case] + + def _set_env( + self, + map: str|list|None, + size: int = 48, + density: float = 0.392, + seed: int = 42, + max_episode_steps: int = 64, + animation: bool = True, + ): + if map is not None: + self.grid_config = GridConfig( + map=map, + seed=seed, + observation_type='MAPF', + on_target=self.on_target, + num_agents=self.num_agents, + obs_radius=self.obs_radius, + collission_system=self.collission_system, + max_episode_steps=max_episode_steps, + ) + self.size = self.grid_config.size + else: + self.grid_config = GridConfig( + size=size, + density=density, + seed=seed, + observation_type='MAPF', + on_target=self.on_target, + num_agents=self.num_agents, + obs_radius=self.obs_radius, + collission_system=self.collission_system, + max_episode_steps=max_episode_steps, + ) + self.size = size + + self.env = pogema_v0(grid_config=self.grid_config) + if animation: + self.env = AnimationMonitor(self.env) + + def _set_global_guidance(self, obs: list[dict]): + grid = Grid(obs[0]['global_obstacles']) + coords = [[ob['global_xy'], ob['global_target_xy']] for ob in obs] + self.global_guidance = [a_star(st, tg, grid) for st, tg in coords] + + def save_animation(self, path): + self.env.save_animation(path) + + def get_action_space(self) -> list[int]: + return list(range(len(self.actions))) + + def reset(self) -> tuple[list, list]: + self.time_idx = 1 + self.obs, self.info = self.env.reset() + self._set_global_guidance(self.obs) + self.view_cache = [] + for i, (ob, guidance) in enumerate(zip(self.obs, self.global_guidance)): + guidance.remove(ob['global_xy']) + view = self._get_local_view(ob, guidance) + view_cache = [np.zeros_like(view) for _ in range(self.cache_size - 1)] + [view] + self.view_cache.append(deque(view_cache, self.cache_size)) + self.obs[i]['view_cache'] = np.array(self.view_cache[-1]) + return self.obs, self.info + + def _reset_agent(self, i: int, ob: dict) -> dict[str, Any]: + grid = Grid(ob['global_obstacles']) + self.global_guidance[i] = a_star(ob['global_xy'], ob['global_target_xy'], grid) + self.global_guidance[i].remove(ob['global_xy']) + + view = self._get_local_view(ob, self.global_guidance[i]) + view_cache = [np.zeros_like(view) for _ in range(self.cache_size - 1)] + [view] + self.view_cache[i] = deque(view_cache, self.cache_size) + ob['view_cache'] = np.array(self.view_cache[i]) + return ob + + def _get_local_view( + self, + obs: dict[str, Any], + global_guidance: np.ndarray, + ) -> np.ndarray: + local_coord = self.obs_radius + + local_guidance = np.zeros_like(obs['agents']) + local_size = local_guidance.shape[0] + delta = [global_coord - local_coord for global_coord in obs['global_xy']] + for global_cell in global_guidance: + h = global_cell[0] - delta[0] + w = global_cell[1] - delta[1] + if 0 <= h < local_size and 0 <= w < local_size: + local_guidance[h, w] = 1 + + curr_agent = np.zeros_like(obs['agents']) + curr_agent[local_coord, local_coord] = 1 + return np.dstack( + ( + curr_agent, + obs['obstacles'], + obs['agents'], + local_guidance, + ) + ) + + def step(self, actions: 
list[int]) -> tuple[list, ...]: + conflict_points = set() + obs, reward, terminated, truncated, info = self.env.step(actions) + # calculate reward + for i, (action, ob, status) in enumerate(zip(actions, obs, info)): + if status['is_active']: + new_point = ob['global_xy'] + # conflict + if self.actions[action] != 'idle' and new_point == self.obs[i]['global_xy']: + reward[i] = self._get_reward(1) + if self.collission_system != 'block_both': + # another agent (block strategy is considered) + if ob['global_obstacles'][new_point] == 0: + conflict_points.add(new_point) + # global guidance cell + elif new_point in self.global_guidance[i]: + new_point_idx = self.global_guidance[i].index(new_point) + reward[i] = self._get_reward(2, new_point_idx + 1) + # update global guidance + if self.on_target == 'nothing': + if new_point == self.global_guidance[i][-1]: + self.global_guidance[i] = self.global_guidance[i][-1:] + else: + self.global_guidance[i] = self.global_guidance[i][new_point_idx + 1:] + else: + self.global_guidance[i] = self.global_guidance[i][new_point_idx + 1:] + if len(self.global_guidance[i]) == 0: + ob = self._reset_agent(i, ob) + # free cell + else: + reward[i] = self._get_reward(0) + + # update history of observations + view = self._get_local_view(ob, self.global_guidance[i]) + self.view_cache[i].append(view) + + obs[i]['view_cache'] = np.array(self.view_cache[i]) + + # recalculate reward if strategy is not blocking + if self.collission_system != 'block_both': + for i, (ob, status) in enumerate(zip(obs, info)): + if status['is_active'] and ob['global_xy'] in conflict_points: + reward[i] = self._get_reward(1) + + self.obs, self.info = obs, info + self.time_idx += 1 + return obs, reward, terminated, truncated, info diff --git a/g2rl/metrics.py b/g2rl/metrics.py new file mode 100644 index 0000000..2e6381a --- /dev/null +++ b/g2rl/metrics.py @@ -0,0 +1,13 @@ +import numpy as np + + +def manhattan_distance(x_st: int, y_st: int, x_end: int, y_end: int) -> int: + return abs(x_end - x_st) + abs(y_end - y_st) + + +def moving_cost(num_steps: int, c_start: list[int], c_goal: list[int]) -> float: + return num_steps / (manhattan_distance(*c_start, *c_goal)) + + +def detour_percentage(num_steps: int, opt_path_len: int) -> float: + return (num_steps - opt_path_len) / opt_path_len * 100 diff --git a/g2rl/network.py b/g2rl/network.py new file mode 100644 index 0000000..17224a6 --- /dev/null +++ b/g2rl/network.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CRNNModel(nn.Module): + def __init__( + self, + num_actions: int = 5, + num_timesteps: int = 4, + initial_channels: int = 4, + hidden_size: int = 128, + lstm_input_size: int|None = None, + num_kernels: list[int] = [32, 64], + ): + super(CRNNModel, self).__init__() + self.num_timesteps = num_timesteps + self.num_actions = num_actions + self.hidden_size = hidden_size + self.initial_channels = initial_channels + + # CNN blocks + self.conv_blocks = nn.ModuleList() + in_channels = initial_channels + for out_channels in num_kernels: + block = nn.Sequential( + nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 3, 3), + stride=(1, 1, 1), + padding=(0, 1, 1)), + nn.ReLU(), + nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(1, 3, 3), + stride=(1, 2, 2), + padding=(0, 1, 1)), + nn.ReLU() + ) + self.conv_blocks.append(block) + in_channels = out_channels + + # LSTM Layer + if lstm_input_size is None: + self.lstm = None + else: + self.lstm = 
nn.LSTM( + input_size=lstm_input_size, + hidden_size=hidden_size, + batch_first=True) + + # FC layers + self.fc1 = nn.Linear(hidden_size, hidden_size // 2) + self.fc2 = nn.Linear(hidden_size // 2, num_actions) + + def forward(self, x): + # reshape to (batch_size, channels, depth, height, width) + x = x.permute(0, 4, 1, 2, 3) + + for block in self.conv_blocks: + x = block(x) + + # determine the LSTM input size if it hasn't been set + if self.lstm is None: + batch_size, _, depth, height, width = x.size() + lstm_input_size = height * width * self.conv_blocks[-1][0].out_channels + self.lstm = nn.LSTM( + input_size=lstm_input_size, + hidden_size=self.hidden_size, + batch_first=True, + device=x.device) + + batch_size, _, depth, height, width = x.size() + x = x.reshape(batch_size, self.num_timesteps, -1) + + lstm_out, _ = self.lstm(x) + x = lstm_out[:, -1, :] + + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x diff --git a/g2rl/train.py b/g2rl/train.py new file mode 100644 index 0000000..2039aa6 --- /dev/null +++ b/g2rl/train.py @@ -0,0 +1,149 @@ +from pathlib import Path +from datetime import datetime +from tqdm.notebook import tqdm + +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter +from pogema import AStarAgent + +from g2rl import DDQNAgent, G2RLEnv, CRNNModel +from g2rl import moving_cost, detour_percentage + + +def get_timestamp() -> str: + now = datetime.now() + timestamp = now.strftime('%H-%M-%d-%m-%Y') + return timestamp + + +def get_normalized_probs(x: list[float]|None, size: int) -> np.ndarray: + x = [1] * size if x is None else x + [0] * (size - len(x)) + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum(axis=0) + + +def train( + model: torch.nn.Module, + map_settings: dict[str, dict], + map_probs: list[float]|None, + num_episodes: int = 300, + batch_size: int = 32, + decay_range: int = 1000, + log_dir = 'logs', + lr: float = 0.001, + replay_buffer_size: int = 1000, + device: str = 'cuda' + ) -> DDQNAgent: + timestamp = get_timestamp() + writer = SummaryWriter(log_dir=Path(log_dir) / timestamp) + maps = [G2RLEnv(**args) for _, args in map_settings.items()] + map_probs = get_normalized_probs(map_probs, len(maps)) + agent = DDQNAgent( + model, + maps[0].get_action_space(), + lr=lr, + decay_range=decay_range, + device=device, + replay_buffer_size=replay_buffer_size, + ) + + pbar = tqdm(total=num_episodes, desc='Episodes') + for episode in range(num_episodes): + torch.save(model.state_dict(), f'models/{timestamp}.pt') + env = np.random.choice(maps, p=map_probs) + target_idx = np.random.randint(env.num_agents) + agents = [agent if i == target_idx else AStarAgent() for i in range(env.num_agents)] + obs, info = env.reset() + + state = obs[target_idx] + opt_path = [state['global_xy']] + env.global_guidance[target_idx] + retrain_count = 0 + scalars = { + 'Reward': 0, + 'Moving Cost': 0, + 'Detour Percentage': 0, + 'Average Loss': 0, + 'Average Epsilon': 0, + } + + timesteps_per_episode = 50 + 10 * episode + for timestep in range(timesteps_per_episode): + actions = [agent.act(ob) for agent, ob in zip(agents, obs)] + obs, reward, terminated, truncated, info = env.step(actions) + terminated[target_idx] = obs[target_idx]['global_xy'] == opt_path[-1] + + # if the target agent has finished or FOV does not contain the global guidance + if terminated[target_idx] or (obs[target_idx]['view_cache'][-1][:,:,-1] == 0).all(): + if terminated[target_idx]: + scalars['Moving Cost'] = moving_cost(timestep + 1, opt_path[0], opt_path[-1]) + scalars['Detour Percentage'] = 
detour_percentage(timestep + 1, len(opt_path) - 1) + break + + agent.store( + state, + actions[target_idx], + reward[target_idx], + obs[target_idx], + terminated[target_idx], + ) + state = obs[target_idx] + scalars['Reward'] += reward[target_idx] + + if len(agent.replay_buffer) >= batch_size: + retrain_count += 1 + scalars['Average Loss'] += agent.retrain(batch_size) + scalars['Average Epsilon'] += round(agent.epsilon, 4) + + for name in scalars.keys(): + if 'Average' in name and retrain_count > 0: + scalars[name] /= retrain_count + + # logging + for name, value in scalars.items(): + writer.add_scalar(name, value, episode) + pbar.update(1) + pbar.set_postfix(scalars) + + writer.close() + return agent + + +if __name__ == '__main__': + # basic elements + map_settings = { + 'regular': { + 'size': 48, + 'density': 0.392, + 'num_agents': 4, + }, + 'random': { + 'size': 48, + 'density': 0.15, + 'num_agents': 6, + }, + 'free': { + 'size': 48, + 'density': 0, + 'num_agents': 11, + }, + } + map_probs = [0.3, 0.35, 0.35] + model = CRNNModel() + device = 'cpu' + + # train loop + trained_agent = train( + model, + map_settings=map_settings, + map_probs=map_probs, + num_episodes=300, + batch_size=32, + replay_buffer_size=500, + decay_range=10_000, + log_dir='logs', + device=device, + ) + + # save model + torch.save(model.state_dict(), 'models/best_model.pt') diff --git a/g2rl/utils.py b/g2rl/utils.py new file mode 100644 index 0000000..e823fd1 --- /dev/null +++ b/g2rl/utils.py @@ -0,0 +1,96 @@ +import numpy as np +from typing import Any + + +class SumTree: + def __init__(self, capacity: int): + self.capacity = capacity + self.tree = np.zeros(2 * capacity - 1) + self.data = np.zeros(capacity, dtype=object) + self.data_pointer = 0 + + def add(self, priority: float, data: Any): + tree_idx = self.data_pointer + self.capacity - 1 + self.data[self.data_pointer] = data + self.update(tree_idx, priority) + + self.data_pointer += 1 + if self.data_pointer >= self.capacity: + self.data_pointer = 0 + + def update(self, tree_idx: int, priority: float): + change = priority - self.tree[tree_idx] + self.tree[tree_idx] = priority + self._propagate(tree_idx, change) + + def _propagate(self, tree_idx: int, change: float): + parent = (tree_idx - 1) // 2 + self.tree[parent] += change + if parent != 0: + self._propagate(parent, change) + + def get_leaf(self, value: float) -> tuple[int, float, Any]: + parent_idx = 0 + while True: + left_child_idx = 2 * parent_idx + 1 + right_child_idx = left_child_idx + 1 + if left_child_idx >= len(self.tree): + leaf_idx = parent_idx + break + else: + if value <= self.tree[left_child_idx]: + parent_idx = left_child_idx + else: + value -= self.tree[left_child_idx] + parent_idx = right_child_idx + + data_idx = leaf_idx - self.capacity + 1 + return leaf_idx, self.tree[leaf_idx], self.data[data_idx] + + @property + def total_priority(self) -> float: + return self.tree[0] + + +class PrioritizedReplayBuffer: + def __init__(self, capacity: int, alpha: float = 0.6): + self.tree = SumTree(capacity) + self.alpha = alpha + self.max_priority = 1.0 + self.capacity = capacity + + def add(self, error: float, transition: tuple): + priority = (error + 1e-5) ** self.alpha + self.tree.add(priority, transition) + self.max_priority = max(self.max_priority, priority) + + def sample(self, batch_size: int, beta: float = 0.4) -> tuple[list, list[int], np.ndarray]: + batch = [] + idxs = [] + segment = self.tree.total_priority / batch_size + priorities = [] + + for i in range(batch_size): + a = segment * i + b = segment 
* (i + 1) + value = np.random.uniform(a, b) + idx, priority, data = self.tree.get_leaf(value) + batch.append(data) + idxs.append(idx) + priorities.append(priority) + + sampling_probabilities = priorities / self.tree.total_priority + is_weight = np.power(self.tree.capacity * sampling_probabilities, -beta) + is_weight /= is_weight.max() + is_weight = np.array(is_weight, dtype=np.float32) + + return batch, idxs, is_weight + + def update_priorities(self, idxs: list[int], errors: list[float], eps: float = 1e-5): + for idx, error in zip(idxs, errors): + priority = (error + eps) ** self.alpha + self.tree.update(idx, priority) + self.max_priority = max(self.max_priority, priority) + + def __len__(self) -> int: + return len(self.tree.data) diff --git a/notebooks/train&test.ipynb b/notebooks/train&test.ipynb new file mode 100644 index 0000000..d8e83ec --- /dev/null +++ b/notebooks/train&test.ipynb @@ -0,0 +1 @@ +{"cells":[{"cell_type":"markdown","metadata":{"id":"EpM2x94GLy1q"},"source":["## Install Libs"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"3ED97AGv3toA"},"outputs":[],"source":["%pip install -r requirements.txt\n","%pip install ."]},{"cell_type":"markdown","metadata":{"id":"yLQOxT2nMIZF"},"source":["## Train"]},{"cell_type":"code","execution_count":2,"metadata":{"executionInfo":{"elapsed":22875,"status":"ok","timestamp":1716753851069,"user":{"displayName":"Julia Bel","userId":"09344534935963523050"},"user_tz":-180},"id":"zAgcMy0NGCF3"},"outputs":[],"source":["from pathlib import Path\n","from datetime import datetime\n","\n","import numpy as np\n","import torch\n","from torch.utils.tensorboard import SummaryWriter\n","from tqdm.notebook import tqdm\n","\n","from g2rl import G2RLAgent, G2RLEnv\n","from g2rl import DDQNAgent, CRNNModel, moving_cost, detour_percentage, train"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2407,"status":"ok","timestamp":1716739724915,"user":{"displayName":"Julia Bel","userId":"09344534935963523050"},"user_tz":-180},"id":"59tlFtQxBBOg","outputId":"f26695bf-5792-421f-daad-78288cabd1a2"},"outputs":[{"data":{"text/plain":[""]},"execution_count":5,"metadata":{},"output_type":"execute_result"}],"source":["device = 'cuda'\n","map_settings = {\n"," 'regular_1': {\n"," 'size': 50,\n"," 'density': 0.392,\n"," 'num_agents': 4,\n"," },\n"," 'regular_2': {\n"," 'size': 50,\n"," 'density': 0.452,\n"," 'num_agents': 6,\n"," },\n"," 'random_1': {\n"," 'size': 50,\n"," 'density': 0.15,\n"," 'num_agents': 6,\n"," },\n"," 'random_2': {\n"," 'size': 50,\n"," 'density': 0.25,\n"," 'num_agents': 7,\n"," },\n"," 'free_1': {\n"," 'size': 50,\n"," 'density': 0,\n"," 'num_agents': 11,\n"," },\n"," 'free_2': {\n"," 'size': 50,\n"," 'density': 0.1,\n"," 'num_agents': 15,\n"," },\n","}\n","map_probs = [1, 1, 1, 1, 1, 1]\n","model = CRNNModel(lstm_input_size=1024)\n","# model.load_state_dict(torch.load('models/16-07-24-05-2024.pt', 
map_location=device))"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":49,"referenced_widgets":["717d0f90d3b0408f954dffb08adb11b9","c30538ac699f457d97bd2472e94be38a","2900bef506144d6bb6a6cf63e1195a28","f49926abe67a4a53962a84a2386a624c","869ad1f1a6694f2bafa25a4531963d74","f48d953c02404f468484926972a291fc","0e517e802aaf42b7bd30d3a60bf5c8e2","8885c099470743db91c6e63033fbc5cb","0b8179875d3744e78e5907fc62e80c9c","cdaa0afab0a14e528a42b9db36c5f0bb","0ead522381ec424f8b71eb1ccdc03c7a"]},"executionInfo":{"elapsed":1725855,"status":"ok","timestamp":1716412276554,"user":{"displayName":"Julia Bel","userId":"09344534935963523050"},"user_tz":-180},"id":"nuP0fAIkemA-","outputId":"0a705739-a0d2-4e1c-d32b-71af19a498b3"},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"717d0f90d3b0408f954dffb08adb11b9","version_major":2,"version_minor":0},"text/plain":["Episodes: 0%| | 0/500 [00:00"],"image/svg+xml":"\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n"},"metadata":{}}]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gvOMSh3FFG-z"},"outputs":[],"source":["%load_ext tensorboard\n","%tensorboard --logdir logs"]},{"cell_type":"markdown","metadata":{"id":"4Jc4MY0bHtD-"},"source":["### Default Maps"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"7Ba-qzKqTff_"},"outputs":[],"source":["from PIL import Image\n","import pandas as pd"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"PJ3A3fWsHVf8"},"outputs":[],"source":["def image2grid(image_path: str) -> np.ndarray:\n"," image = Image.open(image_path).convert('L')\n"," threshold_value = 128\n"," binary_image = np.array(image) < threshold_value\n"," return list(binary_image.astype(int))\n","\n","\n","def test(env, agent, name: str = \"\") -> tuple[list, list, dict]:\n"," obs, info = env.reset()\n"," opt_paths = [[(ob['global_xy'])] + path for path, ob in zip(env.global_guidance, obs)]\n"," start = [[path[0][0] - env.obs_radius, path[0][1] - env.obs_radius] for path in opt_paths]\n"," finish = [[path[-1][0] - env.obs_radius, path[-1][1] - env.obs_radius] for path in opt_paths]\n"," terminated = truncated = [False, ...]\n"," timesteps = [0] * len(obs)\n"," scalars = {\n"," 'moving_cost': [],\n"," 'detour_percentage': [],\n"," 'done': 0,\n"," }\n","\n"," latest_obs = obs\n"," latest_info = info\n"," while not all(terminated) and not all(truncated):\n"," timesteps = [t + 1 for t in timesteps]\n"," actions = [agent.act(ob) if status['is_active'] else 0 for ob, status in zip(obs, info)]\n"," obs, reward, terminated, truncated, info = 
env.step(actions)\n"," for i in range(env.num_agents):\n"," if obs[i]['global_xy'] == latest_obs[i]['global_target_xy']:\n"," if (latest_info[i]['is_active'] and not info[i]['is_active']) or info[i]['is_active']:\n"," if obs[i]['global_xy'] != latest_obs[i]['global_xy']:\n"," scalars['moving_cost'].append(moving_cost(timesteps[i], opt_paths[i][0], opt_paths[i][-1]))\n"," scalars['detour_percentage'].append(detour_percentage(timesteps[i], len(opt_paths[i]) - 1))\n"," scalars['done'] += 1\n"," if info[i]['is_active'] and obs[i]['global_xy'] != obs[i]['global_target_xy']:\n"," opt_paths[i] = [obs[i]['global_xy']] + env.global_guidance[i]\n"," timesteps[i] = 0\n"," latest_obs = obs\n"," latest_info = info\n","\n"," env.save_animation(f'renders/{name}.svg')\n","\n"," n = len(scalars['moving_cost'])\n"," scalars['moving_cost'] = sum(scalars['moving_cost']) / n if n > 0 else 0\n"," scalars['detour_percentage'] = sum(scalars['detour_percentage']) / n if n > 0 else 0\n"," return start, finish, scalars"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"JHOp4o4kIt6c"},"outputs":[],"source":["default_map_settings = {\n"," 'random': {\n"," 'size': 48,\n"," 'num_agents': None,\n"," 'density': None,\n"," 'map': image2grid('data/empty-48-48-random-10_60_agents.png'),\n","\n"," },\n"," 'even': {\n"," 'size': 48,\n"," 'num_agents': None,\n"," 'density': None,\n"," 'map': image2grid('data/empty-48-48-even-10_60_agents.png'),\n"," }\n","}"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"5H1Tk1YfHVnb"},"outputs":[],"source":["agent = G2RLAgent(\n"," model,\n"," action_space=5,\n"," epsilon=0,\n"," device=device)\n","\n","results = {\n"," 'map': [],\n"," 'num_agents': [],\n"," 'max_steps': [],\n"," 'density': [],\n"," 'size': [],\n"," 'start': [],\n"," 'final': [],\n"," 'done': [],\n"," 'detour_percentage': [],\n"," 'moving_cost': [],\n","}\n","\n","for map_name, map_value in default_map_settings.items():\n"," for num_agents in [3, 6, 12]:\n"," map_value['num_agents'] = num_agents\n"," for steps in [60, 120, 180]:\n"," env = G2RLEnv(**map_value, max_episode_steps=steps, on_target='finish')\n"," start, finish, scalars = test(env, agent, name=f'{map_name}_{num_agents}_{steps}')\n"," results['map'].append(map_name)\n"," results['num_agents'].append(num_agents)\n"," results['max_steps'].append(steps)\n"," results['done'].append(scalars['done'])\n"," results['detour_percentage'].append(scalars['detour_percentage'])\n"," results['moving_cost'].append(scalars['moving_cost'])\n"," results['start'].append(start)\n"," results['final'].append(finish)\n"," results['density'].append(map_value['density'])\n"," results['size'].append(map_value['size'])\n","\n","stat = pd.DataFrame(data=results)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":613},"executionInfo":{"elapsed":266,"status":"ok","timestamp":1716752812602,"user":{"displayName":"Julia Bel","userId":"09344534935963523050"},"user_tz":-180},"id":"a-cprsUJljWr","outputId":"c16f797c-d854-4350-86e5-fa1ca74a39b9"},"outputs":[{"output_type":"execute_result","data":{"text/plain":[" map num_agents max_steps density size \\\n","0 random 3 60 None 48 \n","1 random 3 120 None 48 \n","2 random 3 180 None 48 \n","3 random 6 60 None 48 \n","4 random 6 120 None 48 \n","5 random 6 180 None 48 \n","6 random 12 60 None 48 \n","7 random 12 120 None 48 \n","8 random 12 180 None 48 \n","9 even 3 60 None 48 \n","10 even 3 120 None 48 \n","11 even 3 180 None 48 \n","12 even 6 60 None 48 \n","13 even 
6 120 None 48 \n","14 even 6 180 None 48 \n","15 even 12 60 None 48 \n","16 even 12 120 None 48 \n","17 even 12 180 None 48 \n","\n"," start \\\n","0 [[25, 22], [17, 13], [22, 39]] \n","1 [[25, 22], [17, 13], [22, 39]] \n","2 [[25, 22], [17, 13], [22, 39]] \n","3 [[25, 22], [17, 13], [22, 39], [32, 36], [38, ... \n","4 [[25, 22], [17, 13], [22, 39], [32, 36], [38, ... \n","5 [[25, 22], [17, 13], [22, 39], [32, 36], [38, ... \n","6 [[25, 22], [17, 13], [22, 39], [32, 36], [38, ... \n","7 [[25, 22], [17, 13], [22, 39], [32, 36], [38, ... \n","8 [[25, 22], [17, 13], [22, 39], [32, 36], [38, ... \n","9 [[32, 25], [46, 10], [12, 4]] \n","10 [[32, 25], [46, 10], [12, 4]] \n","11 [[32, 25], [46, 10], [12, 4]] \n","12 [[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],... \n","13 [[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],... \n","14 [[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],... \n","15 [[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],... \n","16 [[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],... \n","17 [[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],... \n","\n"," final done \\\n","0 [[28, 26], [32, 16], [46, 32]] 2 \n","1 [[28, 26], [32, 16], [46, 32]] 2 \n","2 [[28, 26], [32, 16], [46, 32]] 2 \n","3 [[28, 26], [32, 16], [46, 32], [28, 19], [33, ... 5 \n","4 [[28, 26], [32, 16], [46, 32], [28, 19], [33, ... 5 \n","5 [[28, 26], [32, 16], [46, 32], [28, 19], [33, ... 5 \n","6 [[28, 26], [32, 16], [46, 32], [28, 19], [33, ... 9 \n","7 [[28, 26], [32, 16], [46, 32], [28, 19], [33, ... 9 \n","8 [[28, 26], [32, 16], [46, 32], [28, 19], [33, ... 9 \n","9 [[6, 41], [45, 11], [28, 32]] 2 \n","10 [[6, 41], [45, 11], [28, 32]] 2 \n","11 [[6, 41], [45, 11], [28, 32]] 2 \n","12 [[6, 41], [45, 11], [28, 32], [33, 24], [23, 3... 5 \n","13 [[6, 41], [45, 11], [28, 32], [33, 24], [23, 3... 5 \n","14 [[6, 41], [45, 11], [28, 32], [33, 24], [23, 3... 5 \n","15 [[6, 41], [45, 11], [28, 32], [33, 24], [23, 3... 10 \n","16 [[6, 41], [45, 11], [28, 32], [33, 24], [23, 3... 10 \n","17 [[6, 41], [45, 11], [28, 32], [33, 24], [23, 3... 10 \n","\n"," detour_percentage moving_cost \n","0 0.000000 1.120072 \n","1 0.000000 1.120072 \n","2 0.000000 1.120072 \n","3 0.000000 1.141997 \n","4 0.000000 1.141997 \n","5 0.000000 1.141997 \n","6 1.915709 1.137863 \n","7 1.915709 1.137863 \n","8 1.915709 1.137863 \n","9 0.000000 1.000000 \n","10 0.000000 1.000000 \n","11 0.000000 1.000000 \n","12 0.000000 1.021053 \n","13 0.000000 1.021053 \n","14 0.000000 1.021053 \n","15 3.581210 1.046840 \n","16 3.581210 1.046840 \n","17 3.581210 1.046840 "],"text/html":["\n","
\n","
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
mapnum_agentsmax_stepsdensitysizestartfinaldonedetour_percentagemoving_cost
0random360None48[[25, 22], [17, 13], [22, 39]][[28, 26], [32, 16], [46, 32]]20.0000001.120072
1random3120None48[[25, 22], [17, 13], [22, 39]][[28, 26], [32, 16], [46, 32]]20.0000001.120072
2random3180None48[[25, 22], [17, 13], [22, 39]][[28, 26], [32, 16], [46, 32]]20.0000001.120072
3random660None48[[25, 22], [17, 13], [22, 39], [32, 36], [38, ...[[28, 26], [32, 16], [46, 32], [28, 19], [33, ...50.0000001.141997
4random6120None48[[25, 22], [17, 13], [22, 39], [32, 36], [38, ...[[28, 26], [32, 16], [46, 32], [28, 19], [33, ...50.0000001.141997
5random6180None48[[25, 22], [17, 13], [22, 39], [32, 36], [38, ...[[28, 26], [32, 16], [46, 32], [28, 19], [33, ...50.0000001.141997
6random1260None48[[25, 22], [17, 13], [22, 39], [32, 36], [38, ...[[28, 26], [32, 16], [46, 32], [28, 19], [33, ...91.9157091.137863
7random12120None48[[25, 22], [17, 13], [22, 39], [32, 36], [38, ...[[28, 26], [32, 16], [46, 32], [28, 19], [33, ...91.9157091.137863
8random12180None48[[25, 22], [17, 13], [22, 39], [32, 36], [38, ...[[28, 26], [32, 16], [46, 32], [28, 19], [33, ...91.9157091.137863
9even360None48[[32, 25], [46, 10], [12, 4]][[6, 41], [45, 11], [28, 32]]20.0000001.000000
10even3120None48[[32, 25], [46, 10], [12, 4]][[6, 41], [45, 11], [28, 32]]20.0000001.000000
11even3180None48[[32, 25], [46, 10], [12, 4]][[6, 41], [45, 11], [28, 32]]20.0000001.000000
12even660None48[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],...[[6, 41], [45, 11], [28, 32], [33, 24], [23, 3...50.0000001.021053
13even6120None48[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],...[[6, 41], [45, 11], [28, 32], [33, 24], [23, 3...50.0000001.021053
14even6180None48[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],...[[6, 41], [45, 11], [28, 32], [33, 24], [23, 3...50.0000001.021053
15even1260None48[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],...[[6, 41], [45, 11], [28, 32], [33, 24], [23, 3...103.5812101.046840
16even12120None48[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],...[[6, 41], [45, 11], [28, 32], [33, 24], [23, 3...103.5812101.046840
17even12180None48[[32, 25], [46, 10], [12, 4], [0, 1], [32, 9],...[[6, 41], [45, 11], [28, 32], [33, 24], [23, 3...103.5812101.046840
\n","
\n","
\n","\n","
\n"," \n","\n"," \n","\n"," \n","
\n","\n","\n","
\n"," \n","\n","\n","\n"," \n","
\n","
\n","
\n"],"application/vnd.google.colaboratory.intrinsic+json":{"type":"dataframe","variable_name":"stat","repr_error":"Out of range float values are not JSON compliant: nan"}},"metadata":{},"execution_count":16}],"source":["stat"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"RBgt7cEL74Yz"},"outputs":[],"source":["stat.to_csv('data/default_statistics.csv', index=False)"]},{"cell_type":"markdown","metadata":{"id":"r9H99rOfNbtf"},"source":["### Different env modes"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"jAFoVqokR7AR"},"outputs":[],"source":["on_target_ops = ['restart', 'finish', 'nothing']\n","collission_system_ops = ['priority', 'block_both', 'soft']"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"U1mbzQvLNi45"},"outputs":[],"source":["for on_target in on_target_ops:\n"," for collission_system in collission_system_ops:\n"," env = G2RLEnv(\n"," **map_settings['free'],\n"," max_episode_steps=60,\n"," on_target=on_target,\n"," collission_system=collission_system)\n"," agent = G2RLAgent(\n"," model,\n"," action_space=5,\n"," epsilon=0,\n"," device=device,\n"," lifelong=on_target == 'restart')\n"," result = test(env, agent, name=f'free_{on_target}_{collission_system}')"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":651},"executionInfo":{"elapsed":321,"status":"ok","timestamp":1716543834073,"user":{"displayName":"Julia Bel","userId":"09344534935963523050"},"user_tz":-180},"id":"cCy4lyh0Ni8a","outputId":"f0591c3f-0b43-4061-e7a7-cebfeeb50484"},"outputs":[{"data":{"image/svg+xml":"\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n","text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["display(SVG(f'renders/free_nothing_soft.svg'))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"jEpJj6VxNdxl"},"outputs":[],"source":[]}],"metadata":{"colab":{"machine_shape":"hm","provenance":[],"mount_file_id":"1kU_IOgZpwR1KJWgk_seN_ZLlTJMlBQkN","authorship_tag":"ABX9TyNJGFIkWMgYt4pQe1HTOktj"},"kernelspec":{"display_name":"Python 
3","name":"python3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"0b8179875d3744e78e5907fc62e80c9c":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"ProgressStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"ProgressStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","bar_color":null,"description_width":""}},"0e517e802aaf42b7bd30d3a60bf5c8e2":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"DescriptionStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"0ead522381ec424f8b71eb1ccdc03c7a":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"DescriptionStyleModel","state":{"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"DescriptionStyleModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"StyleView","description_width":""}},"2900bef506144d6bb6a6cf63e1195a28":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"FloatProgressModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"FloatProgressModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"ProgressView","bar_style":"success","description":"","description_tooltip":null,"layout":"IPY_MODEL_8885c099470743db91c6e63033fbc5cb","max":500,"min":0,"orientation":"horizontal","style":"IPY_MODEL_0b8179875d3744e78e5907fc62e80c9c","value":500}},"717d0f90d3b0408f954dffb08adb11b9":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HBoxModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HBoxModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HBoxView","box_style":"","children":["IPY_MODEL_c30538ac699f457d97bd2472e94be38a","IPY_MODEL_2900bef506144d6bb6a6cf63e1195a28","IPY_MODEL_f49926abe67a4a53962a84a2386a624c"],"layout":"IPY_MODEL_869ad1f1a6694f2bafa25a4531963d74"}},"869ad1f1a6694f2bafa25a4531963d74":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padd
ing":null,"right":null,"top":null,"visibility":null,"width":null}},"8885c099470743db91c6e63033fbc5cb":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"c30538ac699f457d97bd2472e94be38a":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HTMLModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_f48d953c02404f468484926972a291fc","placeholder":"​","style":"IPY_MODEL_0e517e802aaf42b7bd30d3a60bf5c8e2","value":"Episodes: 100%"}},"cdaa0afab0a14e528a42b9db36c5f0bb":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"f48d953c02404f468484926972a291fc":{"model_module":"@jupyter-widgets/base","model_module_version":"1.2.0","model_name":"LayoutModel","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_heigh
t":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"f49926abe67a4a53962a84a2386a624c":{"model_module":"@jupyter-widgets/controls","model_module_version":"1.5.0","model_name":"HTMLModel","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/controls","_model_module_version":"1.5.0","_model_name":"HTMLModel","_view_count":null,"_view_module":"@jupyter-widgets/controls","_view_module_version":"1.5.0","_view_name":"HTMLView","description":"","description_tooltip":null,"layout":"IPY_MODEL_cdaa0afab0a14e528a42b9db36c5f0bb","placeholder":"​","style":"IPY_MODEL_0ead522381ec424f8b71eb1ccdc03c7a","value":" 500/500 [31:11<00:00,  5.87s/it, Reward=0.78, Moving Cost=0, Detour Percentage=0, Average Loss=0.000353, Average Epsilon=0.0998]"}}}}},"nbformat":4,"nbformat_minor":0} \ No newline at end of file diff --git a/pogema b/pogema new file mode 160000 index 0000000..4a3d06e --- /dev/null +++ b/pogema @@ -0,0 +1 @@ +Subproject commit 4a3d06ec726bfb4c50a0f588d81d320b68386370 diff --git a/renders/test.svg b/renders/test.svg new file mode 100644 index 0000000..5fdbe81 --- /dev/null +++ b/renders/test.svg @@ -0,0 +1,476 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..82f2e87 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +numpy +pandas +torch +-e ./pogema + diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..6f5be1f --- /dev/null +++ b/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup, find_packages + + +setup( + name='g2rl', + version='0.1.0', + author='julia-bel', + license='MIT', + long_description=open('README.md').read(), + long_description_content_type='text/markdown', + description='Implementation of G2RL in the POGEMA environment', + install_requires=[ + "numpy", + "pandas", + "torch", + ], + packages=find_packages(), + python_requires='>=3.9', +)