Graph Neural Network for Cell Tracking in Microscopy Videos (ECCV 2022)
Install dependencies in a Linux environment (click to expand):
We provide conda environment setup dependencies; if you are not familiar with conda, please read about it before starting.
# Enter the code folder
cd cell-tracker-gnn
# create the conda environment (python=3.8, pytorch==1.8.0, torchvision==0.9.0, cudatoolkit=11.1, faiss-gpu, pytorch-lightning==1.4.9)
conda create --name cell-tracking-challenge --file requirements-conda.txt
conda activate cell-tracking-challenge
# install other requirements
pip install -r requirements.txt
# install the project as a package
python setup.py install
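After installation, a quick sanity check can confirm that the environment is active and that PyTorch sees the GPU (a minimal check against the pinned versions above):
# verify PyTorch, CUDA visibility, and Lightning
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import pytorch_lightning as pl; print(pl.__version__)"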
The directory structure of our implementation looks like (click to expand):
├── configs <- Hydra configuration files
│ ├── callbacks <- Callbacks configs
│ ├── datamodule <- Datamodule configs
│ ├── feat_extract <- Feature extraction configs
│ ├── logger <- Logger configs
│ ├── metric_learning <- Metric learning configs
│ ├── model <- Model configs
│ ├── trainer <- Trainer configs
│ │
│ ├── config.yaml <- Main project configuration file
│
├── data <- Project data
│
├── logs <- Logs generated by Hydra and PyTorch Lightning loggers
│
├── outputs <- Outputs generated by Hydra and TensorBoard loggers when training the deep metric learning model
│
│
├── src
│ ├── callbacks <- Lightning callbacks
│ ├── datamodules <- Lightning datamodules and dataset files used
│ │ ├── datasets <- Graph Dataset implementation
│ │ │ └── graph_dataset.py <- Graph Dataset implementation
│ │ ├── extract_features <- Feature extraction used to build the graphs
│ │ │ ├── preprocess_seq2graph_2d.py <- Extract features for a 2D dataset with full segmentation
│ │ │ ├── preprocess_seq2graph_3D.py <- Extract features for a 3D dataset
│ │ │ └── preprocess_seq2graph_patch_based.py <- Extract features for a 2D dataset with marker annotations
│ │ ├── celltrack_datamodule.py <- Lightning datamodule implementing the train/valid/test split using a separate sequence for each
│ │ └── celltrack_datamodule_mulSeq.py <- Lightning datamodule implementing the train/valid/test split using combined sequences for each
│ │
│ ├── metrics <- Lightning metrics used to track performance
│ ├── models <- Lightning model + PyTorch models + PyTorch Geometric model
│ │ ├── modules <- Model implementations
│ │ │ ├── celltrack_model.py <- complete model implementation
│ │ │ ├── edge_mpnn.py <- Edge-oriented message passing implementation
│ │ │ ├── mlp.py <- multilayer perceptron implementation
│ │ │ └── pdn_conv.py <- PDN-Conv implementation
│ │ └── celltrack_plmodel.py <- Lightning model implementing training routine
│ ├── utils <- Utility scripts
│ │ └── utils.py <- Utility functions
│ │
│ └── train.py <- Training pipeline
│
├── src_metric_learning
│ ├── Data <- Data modules - datasets and sampler
│ │ ├── dataset_2D.py <- Implementation of the 2D dataset
│ │ ├── dataset_3D.py <- Implementation of the 3D dataset
│ │ └── sampler.py <- Implementation of the sampler used for batch construction
│ ├── modules <- PyTorch models
│ │ ├── resnet_2d <- Implementation of ResNet for the 2D dataset
│ │ │ ├── resnet.py <- Final models
│ │ │ └── utils_resnet.py <- Implementation of multiple ResNet blocks and models
│ │ └── resnet_3d <- Implementation of ResNet for the 3D dataset
│ │ │ ├── resnet.py <- Final models
│ │ │ └── utils_resnet.py <- Implementation of multiple ResNet blocks and models
│
├── LICENSE <- Attribution-NonCommercial 4.0 International
├── README.md <- All information
│
├── requirements.txt <- File for installing Python dependencies with pip (specification of dependencies)
├── requirements-conda.txt <- File for conda environment creation (specification of dependencies)
│
├── run.py <- Run training of the complete model with any pipeline configuration from 'configs/config.yaml'
├── run_feat_extract.py <- Run the feature extraction pipeline using the 'configs/feat_extract/feat_extract.yaml' configuration file
└── run_train_metric_learning.py <- Run metric learning training with any of the 'configs/metric_learning/...' configuration files
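All runs are composed by Hydra from the files under configs/; any key can be overridden from the command line using dotted notation, as the run examples below do. Assuming a standard Hydra setup, the fully composed configuration can be inspected without launching a run:
# print the composed job configuration and exit (standard Hydra flag)
python run.py --cfg job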
Our code consists of three run files located in the root directory of the project - run.py, run_feat_extract.py, and run_train_metric_learning.py - dividing our project into three parts: 'complete model', 'feature extraction', and 'metric learning', respectively. An overview of each is provided below:
- Metric Learning: trains a model for extracting cell features using the PyTorch Metric Learning library; it is built from a separate source tree (see src_metric_learning in the project structure above). Before running this part, we must generate CSV files containing the relevant per-cell information used by the datasets during metric learning training.
- Feature Extraction: after training the discriminative model, we use it to extract the features that are later used to build our graphs.
- Complete Model: once all the required data is ready, we use it to train a graph neural network model as presented in the main paper. A generic sketch of the full pipeline follows this list.
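The three stages chain together as shown in this minimal sketch; DATA_ROOT, CSV_ROOT, and the model paths are hypothetical placeholders to substitute with your own, and only the core overrides are shown. Concrete commands with full arguments follow.
# hypothetical placeholder paths - replace with your own
DATA_ROOT="/path/to/Training/dataset"
CSV_ROOT="/path/to/features"
# 1) extract basic (handcrafted) features for metric learning
python run_feat_extract.py params.input_images="$DATA_ROOT" params.input_seg="$DATA_ROOT" params.output_csv="$CSV_ROOT/basic_features" params.basic=True
# 2) train the deep metric learning model (writes all_params.pth under outputs/)
python run_train_metric_learning.py dataset.kwargs.data_dir_img="$DATA_ROOT" dataset.kwargs.dir_csv="$CSV_ROOT/basic_features"
# 3) re-extract features with the trained embedding model
python run_feat_extract.py params.input_images="$DATA_ROOT" params.input_seg="$DATA_ROOT" params.output_csv="$CSV_ROOT/ct_features" params.basic=False params.input_model="/path/to/outputs/<date>/<time>/all_params.pth"
# 4) train the complete GNN tracking model
python run.py datamodule.dataset_params.main_path="$CSV_ROOT/ct_features"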
Below we summarize all the relevant command lines to reproduce a run; an explanation of each variable is provided in the Training code section below.
export CUDA_VISIBLE_DEVICES=0 # select GPU number
# run feature extraction for metric learning (basic features)
python run_feat_extract.py params.input_images="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Training/goldilocks" params.input_masks="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Training/goldilocks" params.input_seg="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Training/goldilocks" params.output_csv="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/basic_features" params.sequences=['01'] params.seg_dir='_GT/TRA' params.basic=True
# run metric learning training
python run_train_metric_learning.py dataset.kwargs.data_dir_img="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Training/goldilocks" dataset.kwargs.data_dir_mask="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Training/goldilocks" dataset.kwargs.dir_csv="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/basic_features/goldilocks" dataset.kwargs.subdir_mask='GT/TRA'
# run feature extraction for cell tracking training (learned features)
python run_feat_extract.py params.input_images="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Training/goldilocks" params.input_masks="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Training/goldilocks" params.input_seg="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Training/goldilocks" params.output_csv="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/ct_features" params.sequences=['01'] params.seg_dir='_GT/TRA' params.basic=False params.input_model="/allen/aics/modeling/ritvik/projects/cell-tracker-gnn/outputs/2023-01-10/16-20-43/all_params.pth"
# run cell tracking training
python run.py datamodule.dataset_params.main_path="/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/ct_features/goldilocks" datamodule.dataset_params.exp_name="2D_SIM"
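Logs and checkpoints land under the logs/ and outputs/ directories shown in the project structure; the complete-model run writes checkpoints under logs/runs/&lt;date&gt;/&lt;time&gt;/checkpoints/, which is the path format used in the inference example below. Assuming the default TensorBoard logger is enabled in configs/logger, training can be monitored with:
# monitor training curves (assumes the TensorBoard logger is enabled)
tensorboard --logdir logs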
To run evaluation, we provide an example script (the one submitted to the CTC) and all the relevant files to run our code in the src/inference folder; details below:
# our model consumes CSVs, so first create them from the images and segmentations
python src/inference/preprocess_seq2graph_clean.py -cs 40 -ii "/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Test/goldilocks/01" -iseg "/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Test/goldilocks/01_GT/TRA" -im "/allen/aics/modeling/ritvik/projects/cell-tracker-gnn/outputs/2022-11-24/11-26-48/all_params.pth" -oc "/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Test/goldilocks/01_CSV"
# add the project root to PYTHONPATH
export PYTHONPATH=.
# run the prediction
python src/inference/inference_clean.py -mp "/allen/aics/modeling/ritvik/projects/cell-tracker-gnn/logs/runs/2023-01-11/13-36-40/checkpoints/epoch=72.ckpt" -ns "01" -oc "/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Test/goldilocks/"
# postprocess
python src/inference/postprocess_clean.py -modality "2D" -iseg "/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Test/goldilocks/01_GT/TRA" -oi "/allen/aics/assay-dev/users/Filip/Data/gnn-tracking/nucmorph_data/Test/goldilocks/01_RES_inference"
# clean up intermediate outputs (DATASET and SEQUENCE as defined in the wrapper sketch below)
rm -r "${DATASET}/${SEQUENCE}_CSV" "${DATASET}/${SEQUENCE}_RES_inference" "${DATASET}/${SEQUENCE}_SEG_RES"
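The cleanup line above assumes DATASET and SEQUENCE shell variables. A minimal wrapper sketch tying the steps together, with hypothetical placeholder paths (substitute your own data root and model files):
#!/bin/bash
DATASET="/path/to/Test/dataset"   # hypothetical dataset root
SEQUENCE="01"                     # sequence name
METRIC_MODEL="/path/to/outputs/<date>/<time>/all_params.pth"           # metric learning weights
GNN_CKPT="/path/to/logs/runs/<date>/<time>/checkpoints/epoch=NN.ckpt"  # GNN checkpoint
export PYTHONPATH=.
python src/inference/preprocess_seq2graph_clean.py -cs 40 -ii "${DATASET}/${SEQUENCE}" -iseg "${DATASET}/${SEQUENCE}_GT/TRA" -im "$METRIC_MODEL" -oc "${DATASET}/${SEQUENCE}_CSV"
python src/inference/inference_clean.py -mp "$GNN_CKPT" -ns "$SEQUENCE" -oc "$DATASET"
python src/inference/postprocess_clean.py -modality "2D" -iseg "${DATASET}/${SEQUENCE}_GT/TRA" -oi "${DATASET}/${SEQUENCE}_RES_inference"
rm -r "${DATASET}/${SEQUENCE}_CSV" "${DATASET}/${SEQUENCE}_RES_inference" "${DATASET}/${SEQUENCE}_SEG_RES"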
You should create a script like the one above with the parameters pointing to your trained models (how to produce them is elaborated above); the comments explain each variable. Please refer to the main paper for details on the segmentation algorithms used. For the challenge's evaluation methodology, see http://celltrackingchallenge.net/evaluation-methodology/ , which also provides command-line software packages implementing the TRA measure.
The software and pretrained models submitted to the Cell Tracking Challenge are available in the Releases section.
If you find either the code or the paper useful for your research, please cite our paper:
@inproceedings{ben2022graph,
title={Graph Neural Network for Cell Tracking in Microscopy Videos},
author={Ben-Haim, Tal and Riklin-Raviv, Tammy},
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
year={2022},
}