This repository is based on discrete Walk-Jump Sampling.
PyTorch implementation of our paper, "Antibody sequence optimization with gradient-guided discrete walk-jump sampling".
Anonymous authors.
We train discriminators on the smoothed data manifold 💊 and use their gradients to guide both the walk and the jump steps in dWJS. We call our method gradient-guided discrete Walk-Jump Sampling (gg-dWJS) 🐾 🚶 🏃
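To make the idea concrete, here is a minimal toy sketch of gradient-guided walk-jump sampling in one continuous dimension. It is not the code in this repository: the analytic Gaussian score stands in for a trained denoiser, the quadratic `guide_grad` stands in for a trained discriminator, the walk uses plain overdamped Langevin dynamics, and every name and constant below (`MU`, `S`, `SIGMA`, `TARGET`, `TAU`, `scale`) is a hypothetical chosen for illustration.

```python
import math
import random

# All constants are illustrative assumptions, not values from the paper.
MU, S = 0.0, 1.0         # toy data distribution N(MU, S^2)
SIGMA = 0.5              # smoothing noise level
TARGET, TAU = 2.0, 1.0   # property target and width of the toy discriminator


def score(y):
    """Score of the smoothed density N(MU, S^2 + SIGMA^2); stands in for the denoiser."""
    return -(y - MU) / (S**2 + SIGMA**2)


def guide_grad(y):
    """Gradient of a toy log-discriminator log f(y) = -(y - TARGET)^2 / (2 * TAU^2)."""
    return -(y - TARGET) / TAU**2


def guided_score(y, scale):
    """Score of the smoothed density plus the scaled discriminator gradient."""
    return score(y) + scale * guide_grad(y)


def walk(y, scale, steps=200, delta=0.05, rng=random):
    """Walk: overdamped Langevin updates on the noisy variable y."""
    for _ in range(steps):
        noise = rng.gauss(0.0, 1.0)
        y = y + delta * guided_score(y, scale) + math.sqrt(2 * delta) * noise
    return y


def jump(y, scale):
    """Jump: single-step denoising, x_hat = y + SIGMA^2 * guided score."""
    return y + SIGMA**2 * guided_score(y, scale)


def sample(n, scale, seed=0):
    """Draw n samples: noisy initialization, guided walk, then guided jump."""
    rng = random.Random(seed)
    out = []
    for _ in range(n):
        y0 = rng.gauss(MU, math.sqrt(S**2 + SIGMA**2))
        out.append(jump(walk(y0, scale, rng=rng), scale))
    return out
```

With `scale=0.0` the sketch reduces to unguided walk-jump sampling; raising `scale` pulls the sample mean toward `TARGET`, which is the effect the guidance term is meant to have.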
First, clone our repository.
cd ./gg-dWJS
Install Anaconda if it is not already available, then run the following commands.
./scripts/install.sh
conda activate wj
Change data in train.yaml:
- data: mnist.yaml
Change model_cfg/arch in denoise.yaml:
- model_cfg/arch: unet
To train the denoise model, run
walkjump_train
First, download the binarized MNIST dataset with labels from Kaggle and save it in data/mnist_with_labels.
Next, change data and model in train.yaml:
- data: mnist_labels.yaml
- model: mnist_classifier.yaml
Change model_cfg/arch in mnist_classifier.yaml:
- model_cfg/arch: convnet
To train the classifier, run
walkjump_train
To sample, first make the following changes to the sample.yaml file:
guide_path: provide the checkpoint directory here
guidance: true
Then run
walkjump_sample
We also provide a script, plot.py, to easily visualize the generated MNIST images.
Finally, use PyTorch-FID to calculate the FID between the generated samples and the reference MNIST images.
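For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of the two image sets. Here $\mu_1, \Sigma_1$ and $\mu_2, \Sigma_2$ denote the feature mean and covariance of each set (the symbols are introduced here for illustration):

$$
\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \mathrm{Tr}\left(\Sigma_1 + \Sigma_2 - 2\,(\Sigma_1 \Sigma_2)^{1/2}\right)
$$

Lower values indicate that the two feature distributions are closer.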
QVQLQES-GPGLVKPSETLSLTCTVSG-GSIST-----YNWNWIRQSPGKGLEWIGEIYH----SGSTYYNPSLKSRVTLSVDTSKKQFSLKLTSVTAADTAIYYCARLGPYYSY--S------------SYSRGFDVWGRGTLVTVSSSYVLTQP-PSVSVSPGQTATLTCGLST--NLDN-----YHVHWYQQKPGQAPRTLIYR--------ADTRLSGVPERFSGSKSG--DTATLTITGVQAGDEGDYYCQLSDSG----------------------GSVVFGTGTKVTVL
Change data in train.yaml:
- data: poas.yaml
Change model_cfg/arch in denoise.yaml:
- model_cfg/arch: bytenet
To train the denoise model, run
walkjump_train
Change data and model in train.yaml:
- data: poas_labels.yaml
- model: ab_discriminator.yaml
Change model_cfg/arch in ab_discriminator.yaml:
- model_cfg/arch: bytenet_mlp
To train the discriminator, run
walkjump_train
To sample, first make the following changes to the sample.yaml file:
guide_path: provide the checkpoint directory here
guidance: true
output_csv: provide the path here
Then run
walkjump_sample
To evaluate % beta sheet and instability index, run the following script
python samples/eval_criteria.py --dir [directory of the CSV file]
Make sure the sampled CSV file is stored in the directory passed to --dir.
To evaluate DCS, run the following script
python samples/eval_dcs.py --dir [directory of the CSV file]
Below are the paths to the datasets used in this work.
Experiment | Data |
---|---|
POAS experiment | data/poas_train_2.csv.gz |
hu4D5 experiment | data/HER2.csv |
MNIST experiment | data/binarized_mnist_[train/test/valid].amat |
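
The MNIST files above use the `.amat` extension. As a minimal sketch for inspecting them, the helper below assumes the common binarized MNIST layout of one image per line, written as 784 space-separated 0/1 values; that layout is an assumption and has not been checked against the files in this repository. The function names are hypothetical and are not part of the `walkjump` package.

```python
from typing import Iterable, List

SIDE = 28  # MNIST images are 28 x 28 pixels


def parse_amat_line(line: str) -> List[List[int]]:
    """Parse one line of 784 space-separated 0/1 values into a 28 x 28 grid."""
    pixels = [int(tok) for tok in line.split()]
    if len(pixels) != SIDE * SIDE:
        raise ValueError(f"expected {SIDE * SIDE} pixels, got {len(pixels)}")
    return [pixels[r * SIDE:(r + 1) * SIDE] for r in range(SIDE)]


def load_amat(lines: Iterable[str]) -> List[List[List[int]]]:
    """Parse every non-empty line; accepts an open file or any iterable of strings."""
    return [parse_amat_line(ln) for ln in lines if ln.strip()]


def to_ascii(image: List[List[int]]) -> str:
    """Render a binary image as text, '#' for 1 and '.' for 0."""
    return "\n".join("".join("#" if p else "." for p in row) for row in image)
```

For example, `with open("data/binarized_mnist_train.amat") as f: images = load_amat(f)` would load the training split under the stated assumption, and `print(to_ascii(images[0]))` gives a quick look at the first digit.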
Below are the directories of all the checkpoints trained and used in this work.
Checkpoint | Directory |
---|---|
POAS % beta sheet discriminator | checkpoints/ab_beta_perc_check |
POAS instability index discriminator | checkpoints/ab_instability_index_check |
POAS score model | checkpoints/poas_check |
POAS preference conditional discriminator | checkpoints/ab_singlevalue_ar_beta |
hu4D5 classifier | checkpoints/her_classifier |
hu4D5 score model | checkpoints/her_score |
hu4D5 discriminator | checkpoints/her_noised_classifier |
If you use the code and/or model checkpoints, please cite:
@inproceedings{ikram2024antibody,
title={Antibody sequence optimization with gradient-guided discrete walk-jump sampling},
author={Ikram, Zarif and Liu, Dianbo and Rahman, M Saifur},
booktitle={ICLR 2024 Workshop on Generative and Experimental Perspectives for Biomolecular Design}
}