Code for ICML 2022 paper "Set Norm and Equivariant Skip Connections: Putting the Deep in Deep Sets."
Lily Zhang, Veronica Tozzo, John M. Higgins, Rajesh Ranganath
Clone the repository on your computer with the following command:
git clone https://github.com/veronicatozzo/deep_permutation_invariant.git
The code was tested on Linux with Python 3.7.9, CUDA 10.1, and GCC 6.2.
The required libraries are listed in requirements.txt; we suggest creating a virtual environment with Anaconda and installing the libraries in it:
cd deep_permutation_invariant
conda create --name deepperminv python=3.7.9
conda activate deepperminv
pip install -r requirements.txt
We provide a novel single-cell benchmark dataset for prediction on sets, called Flow-RBC. The dataset consists of 98,240 train and 23,104 test red blood cell (RBC) distributions. Each distribution consists of volume and hemoglobin mass flow cytometry measurements collected retrospectively at Massachusetts General Hospital under an existing IRB-approved research protocol. The regression task consists of predicting, from a distribution, the hematocrit level measured on the same specimen. Hematocrit is the fraction of blood volume occupied by red blood cells; good prediction outcomes suggest a stronger relationship between single-cell properties and aggregate population properties in human blood than previously known.
Flow-RBC can be downloaded at this link.
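To make the task concrete, below is a hypothetical sketch of the data layout; the set size and tensor names are assumptions for illustration, not a description of the released files.

# Hypothetical illustration of the Flow-RBC task layout (shapes and names are
# assumptions for illustration; consult the downloaded files for exact formats).
import torch

set_size = 1000                             # assumed number of cells per distribution
x_train = torch.randn(98240, set_size, 2)   # (volume, hemoglobin mass) per cell
y_train = torch.randn(98240)                # hematocrit level per distribution
# A permutation-invariant model maps each (set_size, 2) set to a single scalar.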
All other set datasets (with the exception of Anemia, which is not publicly available) can be derived from the following open-source datasets:
- Point Cloud: ModelNet40. We use the HDF5 files downloaded from the PointNet repository at this link. (As of Oct 2022, the issued certificate has expired, so downloading will require an additional --no-check-certificate flag; see the sketch after this list.)
- MNIST Variance: MNIST.
- CelebA Set Anomaly: CelebA.
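Because of the expired certificate, certificate verification has to be skipped when downloading the ModelNet40 files. Here is a minimal Python sketch of the equivalent of wget's --no-check-certificate, assuming the requests library is available and with the URL left as a placeholder:

# Hypothetical download helper: verify=False is the requests equivalent of
# wget's --no-check-certificate; use it only for this known expired certificate.
import requests

url = "<modelnet40-hdf5-link>"  # placeholder for the link referenced above
with requests.get(url, stream=True, verify=False) as r:
    r.raise_for_status()
    with open("modelnet40_hdf5.zip", "wb") as f:  # assumed output filename
        for chunk in r.iter_content(chunk_size=1 << 20):
            f.write(chunk)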
See scripts for the code that generates the h5 files of the datasets used in this repository.
The code is organized as follows:
- deep_permutation_invariant/models: re-implementations of Set Transformer (Lee et al., 2019) and DeepSets (Zaheer et al., 2017) for any depth of the encoder and any length of the input sets. It also contains the implementations of SetTransformer++ and DeepSets++.
- deep_permutation_invariant/datasets: dataloaders for the real and synthetic datasets used in the paper. The downloaded datasets will be stored in a subfolder data.
- deep_permutation_invariant/experiments: experiment scripts.
- deep_permutation_invariant/configs: hyperparameter setups for the experiments.
The scripts for the main results in the paper are experiments/deep_sets_analysis.py and experiments/set_transformer_analysis.py. Each script takes as input a model (deepsets or deepsets2, settransformer or settransformer2), a task (hematocrit, pointcloud, normalvar, mnistvar), and a norm (set_norm, layer_norm, feature_norm).
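For reference, here is a minimal paraphrase of the set norm idea from the paper (each sample is standardized over its set and feature dimensions together, with learned per-feature scale and shift); this is an illustrative sketch, not the repository's implementation:

# Minimal paraphrase of set norm (illustration only, not the repository code):
# each sample is standardized over its set and feature dimensions together,
# then rescaled and shifted per feature with learned parameters.
import torch

def set_norm(x, gamma, beta, eps=1e-5):
    # x: (batch, set_size, features); gamma, beta: (features,)
    mean = x.mean(dim=(1, 2), keepdim=True)
    var = x.var(dim=(1, 2), keepdim=True, unbiased=False)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta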
To run an experiment:
python experiments/set_transformer_analysis.py --model=settransformer2 --norm=set_norm --task=hematocrit --seed=0 --turnoff_wandb
The script will automatically save the trained model in the folder results/settransformer (results/deepsets for the DeepSets script).
The code requires a GPU with 24GB of memory to run with a batch size of 64. If using GPUs that do not meet this memory requirement, you can use gradient accumulation by setting accum_steps to a value greater than 1.
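For illustration, here is a generic gradient-accumulation loop in the standard PyTorch pattern (a self-contained toy sketch, not the repository's training code): summing gradients over accum_steps smaller batches reproduces the update of one large batch while using less memory per step.

# Toy setup so the sketch runs standalone; the repository's model/data differ.
import torch
from torch import nn

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
loader = [(torch.randn(16, 2), torch.randn(16, 1)) for _ in range(8)]

accum_steps = 4  # 4 mini-batches of 16 match an effective batch size of 64
optimizer.zero_grad()
for i, (x, y) in enumerate(loader):
    loss = criterion(model(x), y) / accum_steps  # scale so the sum averages correctly
    loss.backward()                              # .grad buffers accumulate across calls
    if (i + 1) % accum_steps == 0:
        optimizer.step()                         # one update per accum_steps batches
        optimizer.zero_grad()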
To log all the results to Weights & Biases (wandb), first create an account and install wandb in your conda environment:
conda install -c conda-forge wandb
Then log in with your API key (you will find it in the settings of your wandb account):
wandb login <api-key>
You can then create the sweeps using the provided configuration files:
wandb sweep configs/set_transformer.yaml
You will see the following lines in the prompt:
wandb: Creating sweep from: configs/set_transformer.yaml
wandb: Created sweep with ID: <sweepid>
wandb: View sweep at: https://wandb.ai/<username>/results/sweeps/<sweepid>
wandb: Run sweep agent with: wandb agent <username>/results/<sweepid>
You can kick off the runs by executing wandb agent <username>/results/<sweepid> directly or with any scheduler. All the train and test loss curves will be logged at https://wandb.ai/<username>/results/sweeps/<sweepid>.