Merge branch 'main' into transfqmix
mttga committed Jan 22, 2024
2 parents 04145fa + 0cec9ed commit 7ccb5c1
Showing 6 changed files with 93 additions and 6 deletions.
50 changes: 50 additions & 0 deletions Dockerfile
@@ -0,0 +1,50 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

# install python
ARG DEBIAN_FRONTEND=noninteractive
ARG PYTHON_VERSION=3.10

# ENV TZ=Europe/London

RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \
software-properties-common \
build-essential \
curl \
ffmpeg \
git \
htop \
vim \
nano \
rsync \
wget \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

RUN add-apt-repository ppa:deadsnakes/ppa
RUN apt-get update && apt-get install -y -qq python${PYTHON_VERSION} \
python${PYTHON_VERSION}-dev \
python${PYTHON_VERSION}-distutils

# Set python aliases
RUN update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python get-pip.py

# default workdir
WORKDIR /home/workdir

# dev: install from source
COPY . .

# install jaxmarl from source, with the qlearning and dev requirements
RUN pip install --ignore-installed -e '.[qlearning, dev]'

# install jax with CUDA support
RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# disable XLA GPU memory preallocation (needed at runtime, hence ENV rather than ARG)
ENV XLA_PYTHON_CLIENT_PREALLOCATE=false

# expose a port for jupyter
EXPOSE 9999
CMD ["/bin/bash"]
28 changes: 28 additions & 0 deletions Makefile
@@ -0,0 +1,28 @@
# Detect nvcc to decide whether to pass GPUs through to docker run
NVCC_RESULT := $(shell which nvcc 2> /dev/null)
NVCC_TEST := $(notdir $(NVCC_RESULT))
ifeq ($(NVCC_TEST),nvcc)
GPUS=--gpus all
else
GPUS=
endif


# Set flag for docker run command
BASE_FLAGS=-it --rm -v ${PWD}:/home/workdir --shm-size 20G
RUN_FLAGS=$(GPUS) $(BASE_FLAGS)

DOCKER_IMAGE_NAME = jaxmarl
IMAGE = $(DOCKER_IMAGE_NAME):latest
DOCKER_RUN=docker run $(RUN_FLAGS) $(IMAGE)
USE_CUDA = $(if $(GPUS),true,false)

# Makefile targets
build:
DOCKER_BUILDKIT=1 docker build --build-arg USE_CUDA=$(USE_CUDA) --tag $(IMAGE) --progress=plain ${PWD}/.

run:
$(DOCKER_RUN) /bin/bash

test:
$(DOCKER_RUN) /bin/bash -c "pytest ./tests/"
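
A typical workflow is `make build` followed by `make run` (or `make test`); GPUs are passed through to the container automatically whenever `nvcc` is found on the host PATH.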

14 changes: 12 additions & 2 deletions README.md
@@ -108,7 +108,7 @@ import jax
from jaxmarl import make

key = jax.random.PRNGKey(0)
-key, key_reset, key_act, key_step = jax.random.split(rng, 4)
+key, key_reset, key_act, key_step = jax.random.split(key, 4)

# Initialise environment.
env = make('MPE_simple_world_comm_v3')
@@ -124,6 +124,16 @@ actions = {agent: env.action_space(agent).sample(key_act[i]) for i, agent in enumerate(env.agents)}
obs, state, reward, done, infos = env.step(key_step, state, actions)
```
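
For context, the README example above can be extended into a short rollout loop. This is a sketch, assuming `env.num_agents` and an `obs, state = env.reset(key_reset)` signature, neither of which appears in the excerpt above:

```python
import jax
from jaxmarl import make

key = jax.random.PRNGKey(0)
env = make('MPE_simple_world_comm_v3')

key, key_reset = jax.random.split(key)
obs, state = env.reset(key_reset)  # assumed reset signature

for _ in range(10):
    key, key_act, key_step = jax.random.split(key, 3)
    act_keys = jax.random.split(key_act, env.num_agents)  # one key per agent
    actions = {agent: env.action_space(agent).sample(act_keys[i])
               for i, agent in enumerate(env.agents)}
    obs, state, reward, done, infos = env.step(key_step, state, actions)
```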

### Dockerfile 🐋
To help get experiments up and running, we include a [Dockerfile](https://github.com/FLAIROx/JaxMARL/blob/main/Dockerfile) and its corresponding [Makefile](https://github.com/FLAIROx/JaxMARL/blob/main/Makefile). With Docker and the [Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/index.html) installed, the container can be built with:
```
make build
```
The built container can then be run with:
```
make run
```
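The test suite can also be run inside the container with `make test` (see the `test` target in the Makefile above).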

## Contributing 🔨
Please contribute! Please take a look at our [contributing guide](https://github.com/FLAIROx/JaxMARL/blob/main/CONTRIBUTING.md) for how to add an environment/algorithm or submit a bug report. Our roadmap also lives there.

@@ -133,7 +143,7 @@ If you use JaxMARL in your work, please cite us as follows:
```
@article{flair2023jaxmarl,
title={JaxMARL: Multi-Agent RL Environments in JAX},
-author={Alexander Rutherford and Benjamin Ellis and Matteo Gallici and Jonathan Cook and Andrei Lupu and Gardar Ingvarsson and Timon Willi and Akbir Khan and Christian Schroeder de Witt and Alexandra Souly and Saptarashmi Bandyopadhyay and Mikayel Samvelyan and Minqi Jiang and Robert Tjarko Lange and Shimon Whiteson and Bruno Lacerda and Nick Hawes and Tim Rocktaschel and Chris Lu and Jakob Nicolaus Foerster}
+author={Alexander Rutherford and Benjamin Ellis and Matteo Gallici and Jonathan Cook and Andrei Lupu and Gardar Ingvarsson and Timon Willi and Akbir Khan and Christian Schroeder de Witt and Alexandra Souly and Saptarashmi Bandyopadhyay and Mikayel Samvelyan and Minqi Jiang and Robert Tjarko Lange and Shimon Whiteson and Bruno Lacerda and Nick Hawes and Tim Rocktaschel and Chris Lu and Jakob Nicolaus Foerster},
journal={arXiv preprint arXiv:2311.10090},
year={2023}
}
3 changes: 1 addition & 2 deletions jaxmarl/environments/mpe/simple.py
@@ -329,8 +329,7 @@ def set_actions(self, actions: Dict):
def _decode_continuous_action(
self, a_idx: int, action: chex.Array
) -> Tuple[chex.Array, chex.Array]:
-u = jnp.array([action[1] - action[2], action[3] - action[4]])
-
+u = jnp.array([action[2] - action[1], action[4] - action[3]])
u = u * self.accel[a_idx] * self.moveable[a_idx]
c = action[5:]
return u, c
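
The same sign flip is applied in `simple_speaker_listener.py` and `simple_world_comm.py` below. As an illustrative sketch of the corrected decoding, here it is applied to a hand-picked action vector; the [no-op, left, right, down, up] index labels are an assumption based on the usual MPE action layout, not something this diff confirms:

```python
import jax.numpy as jnp

# Hypothetical 5-component continuous action; indices assumed to mean
# [no-op, left, right, down, up] as in the usual MPE layout.
action = jnp.array([0.0, 0.2, 0.9, 0.1, 0.4])

# Corrected decoding: x = action[2] - action[1], y = action[4] - action[3]
u = jnp.array([action[2] - action[1], action[4] - action[3]])
print(u)  # roughly [0.7 0.3] -> net force right and up, under the assumed labels
```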
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple_speaker_listener.py
@@ -126,7 +126,7 @@ def _decode_continuous_action(

u_act = action[LISTENER]

-u_act = jnp.array([u_act[1] - u_act[2], u_act[3] - u_act[4]]) * self.accel[1]
+u_act = jnp.array([u_act[2] - u_act[1], u_act[4] - u_act[3]]) * self.accel[1]
u = u.at[1].set(u_act)

return u, c
2 changes: 1 addition & 1 deletion jaxmarl/environments/mpe/simple_world_comm.py
@@ -173,7 +173,7 @@ def _decode_continuous_action(
) -> Tuple[chex.Array, chex.Array]:
@partial(jax.vmap, in_axes=[0, 0])
def _set_u(a_idx, action):
-u = jnp.array([action[1] - action[2], action[3] - action[4]])
+u = jnp.array([action[2] - action[1], action[4] - action[3]])
return u * self.accel[a_idx] * self.moveable[a_idx]

lact = actions[self.leader]
