
jakobtroidl/neuron-shape-reasoning



Global Neuron Shape Reasoning with Point Affinity Transformers

This repository contains the official implementation of the paper "Global Neuron Shape Reasoning with Point Affinity Transformers" by Jakob Troidl, Johannes Knittel, Wanhua Li, Fangneng Zhan, Hanspeter Pfister*, and Srinivas Turaga* (*equal advising).


Installation

conda create --name gnsr python=3.9
conda activate gnsr
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu121
pip install torch_geometric # more details here https://pytorch-geometric.readthedocs.io/en/2.5.2/notes/installation.html
pip install -r requirements.txt
Troubleshooting & Versions

All code was tested with PyTorch 2.1.0 and CUDA 12.1. If the versions above cause problems, install PyTorch 2.1.0 instead:

pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
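
To check which versions ended up in the environment, a small script such as the one below can help. It is not part of the repository: it compares installed package versions against the tested ones listed above using only importlib.metadata from the standard library, and it imports torch only if torch is available. Since the main instructions install PyTorch 2.2.2, a reported mismatch is informational rather than an error.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple

# Versions this README reports as tested (PyTorch 2.1.0 built for CUDA 12.1).
TESTED = {"torch": "2.1.0", "torchvision": "0.16.0", "torchaudio": "2.1.0"}


def version_mismatches(expected: Dict[str, str]) -> Dict[str, Tuple[Optional[str], str]]:
    """Return {package: (installed_version_or_None, expected_version)} for each mismatch."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            # CUDA wheels report versions like "2.1.0+cu121"; drop the "+..." suffix.
            have = version(pkg).split("+")[0]
        except PackageNotFoundError:
            have = None  # package is not installed at all
        if have != want:
            mismatches[pkg] = (have, want)
    return mismatches


if __name__ == "__main__":
    print("version mismatches:", version_mismatches(TESTED) or "none")
    try:
        import torch

        print("torch built with CUDA:", torch.version.cuda)
        print("CUDA device available:", torch.cuda.is_available())
    except ImportError:
        print("torch is not installed")
```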

Getting Started

Download instructions for the data and model checkpoints are available here.

Training a Point Affinity Transformer model on the FlyWire dataset:

python train_affinity.py \
    --data_path ./data/flywire_full_v783/train \
    --neuron_id_path ./data/flywire_full_v783/affinity/ol_family_balanced/affinity_train.csv \
    --fam_to_id_mapping ./data/flywire_full_v783/types/visual_neurons_family_to_id.json \
    --output_dir ./ckpt/flywire_affinity_train \
    --log_dir ./logs/flywire_affinity_train \
    --point_cloud_size 1024 \
    --data_global_scale_factor 659.88367 \
    --lr 1e-4
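
The flags --point_cloud_size and --data_global_scale_factor suggest that each neuron is fed to the model as a fixed-size point cloud in rescaled coordinates. The sketch below shows one way such preprocessing could look, assuming raw xyz coordinates are divided by the global scale factor and points are sampled uniformly at random. The helper `sample_point_cloud` is hypothetical and is not taken from train_affinity.py.

```python
import numpy as np


def sample_point_cloud(points, point_cloud_size=1024, global_scale=659.88367, seed=None):
    """Hypothetical helper: return a (point_cloud_size, 3) array of rescaled points.

    Samples without replacement when the neuron has enough points and with
    replacement otherwise, so every sample has the same fixed size.
    """
    points = np.asarray(points, dtype=np.float64)
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError("expected an (N, 3) array of xyz coordinates")
    rng = np.random.default_rng(seed)
    replace = len(points) < point_cloud_size  # only duplicate points if we must
    idx = rng.choice(len(points), size=point_cloud_size, replace=replace)
    # Assumption: the global scale factor divides raw coordinates.
    return points[idx] / global_scale
```

For example, `sample_point_cloud(raw_xyz, point_cloud_size=1024, global_scale=659.88367)` returns a (1024, 3) array. Sampling with replacement only for neurons with fewer than 1024 points keeps batch shapes fixed without duplicating points unnecessarily.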

Testing a Pretrained Affinity Model on the FlyWire dataset:

python eval_affinity.py \
    --pth ./path/to/flywire_final.pth \
    --data_path ./data/flywire_full_v783/train \
    --neuron_id_path ./data/flywire_full_v783/affinity/opticlobe_family/affinity_test_paper.csv \
    --output_dir ./ckpt/flywire_affinity_eval \
    --fam_to_id_mapping ./data/flywire_full_v783/types/visual_neurons_family_to_id.json \
    --point_cloud_size 1024 \
    --batch_size 1 \
    --data_global_scale_factor 659.88367 \
    --thresholds 0.8 \
    --store_tensors \
    --qual_results
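
The --thresholds 0.8 flag implies that predicted affinities are binarized at 0.8 during evaluation. This README does not describe what happens afterwards, so the following is only an illustration: assuming the model yields a matrix of pairwise affinities in [0, 1], thresholding it and taking connected components (here with union-find) groups items that are linked by high-affinity pairs. The function `group_by_affinity` is hypothetical.

```python
import numpy as np


def group_by_affinity(affinity, threshold=0.8):
    """Label items so that any pair with affinity >= threshold ends up in one group.

    Computes connected components of the graph whose edges are the pairs
    (i, j) with affinity >= threshold, using union-find.
    """
    a = np.asarray(affinity, dtype=float)
    n = a.shape[0]
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving keeps trees shallow
            i = parent[i]
        return i

    # Symmetrize so that a single high-affinity direction links both items.
    ii, jj = np.nonzero(np.maximum(a, a.T) >= threshold)
    for i, j in zip(ii.tolist(), jj.tolist()):
        ri, rj = find(i), find(j)
        if ri != rj:
            parent[rj] = ri

    # Relabel component roots as 0, 1, 2, ... in order of first appearance.
    remap = {}
    return np.array([remap.setdefault(find(i), len(remap)) for i in range(n)])
```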

Contrastive Neuron Embeddings

Train a DeepSet model to produce contrastive neuron embeddings on the FlyWire dataset:

python train_contrastive.py \
    --model ae_d1024_m512 \
    --pth ./data/ckpt/flywire_affinity_final.pth \
    --data_path ./data/flywire_full_v783/train \
    --neuron_id_path ./data/flywire_full_v783/affinity/ol_family_balanced/affinity_train_metric_balanced.csv \
    --output_dir ./ckpt/flywire_deepset_train_normed \
    --fam_to_id_mapping ./data/flywire_full_v783/types/visual_neurons_family_to_id.json \
    --point_cloud_size 1024 \
    --batch_size 650 \
    --data_global_scale_factor 659.88367 \
    --depth 24 \
    --norm_emb
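
For readers unfamiliar with DeepSets: the architecture applies a shared function phi to each set element, pools with a permutation-invariant sum, and maps the pooled vector with a second function rho. The NumPy sketch below shows only that structure, with the output L2-normalized as we assume --norm_emb does. `TinyDeepSet`, its layer sizes, and what the set elements are (for example per-point or per-latent features from the pretrained affinity encoder passed via --pth) are assumptions, not the repository's model.

```python
import numpy as np


class TinyDeepSet:
    """Minimal DeepSets sketch: rho(sum_i phi(x_i)), optionally L2-normalized."""

    def __init__(self, in_dim, hidden_dim, out_dim, seed=0):
        rng = np.random.default_rng(seed)
        # Random weights stand in for trained parameters.
        self.w_phi = rng.standard_normal((in_dim, hidden_dim)) / np.sqrt(in_dim)
        self.w_rho = rng.standard_normal((hidden_dim, out_dim)) / np.sqrt(hidden_dim)

    def __call__(self, x, norm_emb=True):
        # x: (set_size, in_dim) array of per-element features
        h = np.maximum(x @ self.w_phi, 0.0)  # phi: shared per-element ReLU layer
        pooled = h.sum(axis=0)               # sum pooling is order-independent
        z = pooled @ self.w_rho              # rho: pooled features -> embedding
        if norm_emb:
            z = z / (np.linalg.norm(z) + 1e-12)  # unit-length embedding
        return z
```

Shuffling the rows of the input leaves the embedding unchanged (up to floating-point error), which is what makes a DeepSet suitable for unordered point sets. Unit-length embeddings also make cosine similarity equal to a plain dot product, a common choice in contrastive training.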

Test the pretrained DeepSet model on the FlyWire dataset:

python eval_contrastive.py \
    --model ae_d1024_m512 \
    --encoder_pth ./data/ckpt/flywire_affinity_final.pth \
    --deep_set_pth ./data/ckpt/flywire_deepset_final.pth \
    --data_path ./data/flywire_full_v783/train \
    --neuron_id_path ./data/flywire_full_v783/affinity/ol_family_balanced/affinity_test.csv \
    --output_dir ./ckpt/flywire_deepset_eval \
    --train_emb_path ./ckpt/flywire_deepset_train_normed/emb_ep_XXXX/ \
    --fam_to_id_mapping ./data/flywire_full_v783/types/visual_neurons_family_to_id.json \
    --point_cloud_size 1024 \
    --batch_size 650 \
    --data_global_scale_factor 659.88367 \
    --depth 24 \
    --norm_emb
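
eval_contrastive.py receives stored training embeddings via --train_emb_path. One plausible way to use them, assumed here rather than documented in this README, is nearest-neighbor classification: each test embedding takes the majority family label of its k most cosine-similar training embeddings. `knn_predict` and the value of k are hypothetical.

```python
from collections import Counter

import numpy as np


def knn_predict(train_emb, train_labels, test_emb, k=5):
    """Hypothetical k-NN classifier: majority label among the k most
    cosine-similar training embeddings, one prediction per test embedding."""
    train = np.asarray(train_emb, dtype=float)
    test = np.asarray(test_emb, dtype=float)
    # Normalize rows so the dot product below is cosine similarity.
    train = train / np.linalg.norm(train, axis=1, keepdims=True)
    test = test / np.linalg.norm(test, axis=1, keepdims=True)
    sims = test @ train.T                       # (n_test, n_train)
    k = min(k, train.shape[0])
    top = np.argsort(-sims, axis=1)[:, :k]      # indices of the k nearest neighbors
    labels = np.asarray(train_labels)
    return [Counter(labels[row].tolist()).most_common(1)[0][0] for row in top]
```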

Citation

@techreport{troidlgnsr2024,
  title = {Global Neuron Shape Reasoning with Point Affinity Transformers},
  author = {Troidl, Jakob and Knittel, Johannes and Li, Wanhua and Zhan, Fangneng and Pfister*, Hanspeter and Turaga*, Srinivas},
  journal = {bioRxiv},
  year = {2024},
  publisher = {Cold Spring Harbor Laboratory},
  keywords = {preprint}
}

Acknowledgements

We acknowledge NSF grants CRCNS-2309041, NCS-FO-2124179, and NIH grant R01HD104969. We also thank the HHMI Janelia Visiting Scientist Program and the Harvard Data Science Initiative Postdoctoral Fellowship for their support. The code is partially based on 3DShape2VecSet by Zhang et al.

Contact

Please open an issue or contact Jakob Troidl (jtroidl@g.harvard.edu) for any questions or feedback.

Known Issues

In all datasets, the train folder contains every data file (train and test). The actual train/test split is defined by the CSV file passed through the --neuron_id_path argument.
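
Because the train folder holds every data file, a split is selected by filtering files against the IDs listed in the --neuron_id_path CSV. The sketch below illustrates that idea with the standard library only. The column name `neuron_id`, the one-file-per-neuron layout, and the `.npy` suffix are hypothetical; check the actual CSV header and file names in the downloaded data before adapting it.

```python
import csv
from pathlib import Path


def read_split_ids(csv_path, id_column="neuron_id"):
    """Return the neuron IDs listed in a split CSV (column name is an assumption)."""
    with open(csv_path, newline="") as f:
        return [row[id_column] for row in csv.DictReader(f)]


def files_for_split(data_dir, ids, suffix=".npy"):
    """Return sorted files in data_dir whose stem is one of ids (naming scheme assumed)."""
    wanted = set(ids)
    return sorted(p for p in Path(data_dir).glob(f"*{suffix}") if p.stem in wanted)
```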

About

PyTorch Implementation of Global Neuron Shape Reasoning with Point Affinity Transformers
