This repo provides an official PyTorch implementation of "S2P: State-conditioned Image Synthesis for Data Augmentation in Offline Reinforcement Learning" (NeurIPS 2022). [paper]
```bash
conda create -n s2p python=3.8.5
conda activate s2p
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt
```
Our experiments were run with PyTorch 1.10.1, CUDA 11.4, Python 3.8.5, and Ubuntu 18.04. We train on a single NVIDIA RTX A6000, but the code also runs on GPUs with less memory if you reduce `batchSize`. A simple visualization of the generation results needs only about 4 GB of GPU memory.
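To compare a local setup against the versions listed above, a small check script can help. The sketch below is not part of this repository: the helper names (`installed_version`, `gpu_memory_gb`, `check_environment`) are hypothetical, and it only reports values. It reads the installed torch version through the standard-library `importlib.metadata`, and it reports total GPU memory through `torch.cuda.get_device_properties` when torch and a CUDA device are available.

```python
import sys
from importlib import metadata

# Tested setup reported in this README (Python 3.8.5, PyTorch 1.10.1).
EXPECTED_PYTHON = (3, 8)
EXPECTED_TORCH = "1.10.1"


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def gpu_memory_gb(device_index=0):
    """Total memory of one CUDA device in GiB, or None if torch/CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    props = torch.cuda.get_device_properties(device_index)
    return props.total_memory / 1024 ** 3


def check_environment():
    """Compare the running interpreter and installed torch with the tested setup."""
    torch_ver = installed_version("torch")
    return {
        "python": "%d.%d.%d" % sys.version_info[:3],
        "python_matches": sys.version_info[:2] == EXPECTED_PYTHON,
        "torch": torch_ver,
        # CUDA builds may carry a suffix such as "1.10.1+cu113"; compare the base only.
        "torch_matches": torch_ver is not None and torch_ver.split("+")[0] == EXPECTED_TORCH,
        "gpu_memory_gb": gpu_memory_gb(),
    }


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

A mismatch does not necessarily mean the code will fail. The script only shows how your setup differs from the one the authors tested.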
We provide pre-trained weights of S2P for some environments so that you can quickly test generation performance.
Create a folder `./checkpoints` and download the model weights into it.
Here are the model weights of S2P trained on the cheetah and walker environments of the DeepMind Control Suite.
| env_type | Model weights |
|---|---|
| cheetah | `cheetah_30.pth` |
| walker | `walker_30.pth` |
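After downloading, you can confirm that the weights are in place with a short script such as the one below. This is a minimal sketch. The helpers `checkpoint_path` and `missing_checkpoints` are hypothetical, and the file names come from the table above. It assumes the weights sit directly in `./checkpoints` under those exact names. This README does not state how `simple_test.py` locates them.

```python
from pathlib import Path

# File names from the table above; the folder follows the ./checkpoints instruction.
CHECKPOINTS = {
    "cheetah": "cheetah_30.pth",
    "walker": "walker_30.pth",
}


def checkpoint_path(env_type, root="./checkpoints"):
    """Return the expected path of the pre-trained S2P weights for `env_type`."""
    if env_type not in CHECKPOINTS:
        raise ValueError(
            f"no pre-trained weights listed for env_type={env_type!r}; "
            f"choose from {sorted(CHECKPOINTS)}"
        )
    return Path(root) / CHECKPOINTS[env_type]


def missing_checkpoints(root="./checkpoints"):
    """Create the checkpoint folder if needed and list weight files not yet downloaded."""
    Path(root).mkdir(parents=True, exist_ok=True)
    return [
        name
        for env, name in CHECKPOINTS.items()
        if not checkpoint_path(env, root).is_file()
    ]


if __name__ == "__main__":
    missing = missing_checkpoints()
    print("all weights present" if not missing else f"missing: {missing}")
```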
We also provide a tiny dataset that, together with the pre-trained models, allows a simple visualization of S2P.
You can visualize N-step generation results by setting `--seq_len` to N, for example:
```bash
python simple_test.py --env_type=cheetah --dataroot=./datasets --netG=s2p --start_idx=0 --seq_len=5 --gpu_ids=0
```
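To visualize several starting points in a row, you can build the command above programmatically. This is a hypothetical convenience wrapper, not part of the repository. It uses only the flags shown in the command above and assumes it is run from the repository root. The valid range of `--start_idx` depends on the provided dataset and is not documented here. The README also does not say whether repeated runs overwrite earlier outputs. By default the sketch only prints the commands (`dry_run=True`).

```python
import subprocess
import sys


def simple_test_command(env_type, start_idx=0, seq_len=5,
                        dataroot="./datasets", gpu_ids="0"):
    """Build the simple_test.py invocation shown above as an argument list."""
    return [
        sys.executable, "simple_test.py",
        f"--env_type={env_type}",
        f"--dataroot={dataroot}",
        "--netG=s2p",
        f"--start_idx={start_idx}",
        f"--seq_len={seq_len}",
        f"--gpu_ids={gpu_ids}",
    ]


def run_sweep(env_type, start_indices, seq_len=5, dry_run=True):
    """Print (and optionally run) one N-step visualization per start index."""
    commands = [simple_test_command(env_type, i, seq_len) for i in start_indices]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep at the first failing run.
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    run_sweep("cheetah", start_indices=[0, 10, 20], seq_len=5)
```

Passing an argument list to `subprocess.run`, rather than a single shell string, avoids shell quoting issues. Using `sys.executable` runs the script with the same interpreter as the active `s2p` environment.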