Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to Flax #17

Merged
merged 1 commit into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions .github/workflows/unit-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install tensorflow-cpu==2.12.0
pip install jax==0.4.8
pip install jaxlib==0.4.7
pip install jax==0.4.14
pip install jaxlib==0.4.14
pip install -r docker/requirements.txt
pip install -e imgx
pip install -e imgx_datasets
pip install -e .
- name: Test with pytest
run: |
pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow" imgx/tests/unit
pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow" imgx_datasets/tests
pytest --splits 4 --group ${{ matrix.group }} --randomly-seed=0 -k "not slow"
7 changes: 0 additions & 7 deletions .isort.cfg

This file was deleted.

16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-added-large-files
- id: check-ast
Expand All @@ -27,13 +27,13 @@ repos:
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 23.10.0
hooks:
- id: black
args:
- --line-length=80
- --line-length=100
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
rev: v1.6.1
hooks: # https://github.com/python/mypy/issues/4008#issuecomment-582458665
- id: mypy
name: mypy-imgx
Expand Down Expand Up @@ -72,23 +72,23 @@ repos:
--warn-unreachable,
]
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.0
rev: v3.0.3
hooks:
- id: prettier
args:
- --print-width=80
- --print-width=100
- --prose-wrap=always
- --tab-width=2
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.0.280"
rev: "v0.1.1"
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-pylint
rev: v3.0.0a5
hooks:
- id: pylint
- repo: https://github.com/asottile/pyupgrade
rev: v3.9.0
rev: v3.15.0
hooks:
- id: pyupgrade
args:
Expand Down
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ generated-members=
[FORMAT]

# Maximum number of characters on a single line.
max-line-length=80
max-line-length=100

# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
Expand Down
21 changes: 7 additions & 14 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
pip:
pip install -e imgx
pip install -e imgx_datasets
pip install -e .

test:
pytest --cov=imgx -n 4 imgx/tests -x
pytest --cov=imgx_datasets -n 4 imgx_datasets/tests -x
pytest --cov=imgx -n 4 imgx
pytest --cov=imgx_datasets -n 4 imgx_datasets

build_dataset:
tfds build imgx_datasets/imgx_datasets/male_pelvic_mr &
tfds build imgx_datasets/imgx_datasets/amos_ct &
tfds build imgx_datasets/imgx_datasets/muscle_us &
tfds build imgx_datasets/imgx_datasets/brats2021_mr &

rebuild_dataset:
tfds build imgx_datasets/imgx_datasets/male_pelvic_mr --overwrite &
tfds build imgx_datasets/imgx_datasets/amos_ct --overwrite &
tfds build imgx_datasets/imgx_datasets/muscle_us --overwrite &
tfds build imgx_datasets/imgx_datasets/brats2021_mr --overwrite &
tfds build imgx_datasets/male_pelvic_mr
tfds build imgx_datasets/amos_ct
tfds build imgx_datasets/muscle_us
tfds build imgx_datasets/brats2021_mr
150 changes: 74 additions & 76 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,68 +1,80 @@
# A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models
# ImgX-DiffSeg

:tada: This is a follow-up work of Importance of Aligning Training Strategy with
Evaluation for Diffusion Models in 3D Multiclass Segmentation
([paper](https://arxiv.org/abs/2303.06040),
[code](https://github.com/mathpluscode/ImgX-DiffSeg/tree/v0.1.0)), with better
recycling method, better network, more baseline training methods (including
self-conditioning) on four data sets (muscle ultrasound, male pelvic MR,
abdominal CT, brain MR).
ImgX-DiffSeg is a Jax-based deep learning toolkit using Flax for biomedical image segmentations.

:bookmark_tabs: The preprint is available on
[arXiv](https://arxiv.org/abs/2308.16355).
This repository includes the implementation of the following work

- [A Recycling Training Strategy for Medical Image Segmentation with Diffusion Denoising Models](https://arxiv.org/abs/2308.16355)
- [Importance of Aligning Training Strategy with Evaluation for Diffusion Models in 3D Multiclass Segmentation](https://arxiv.org/abs/2303.06040)

<div>
<img src="images/diffusion_training_strategy_diagram.png" width="600" alt="diffusion_training_strategy_diagram"></img>
</div>

---

ImgX is a Jax-based deep learning toolkit for biomedical image segmentations.
## Features

Current supported functionalities are summarized as follows.

**Data sets**

See the [readme](imgx_datasets/README.md) for details on training, validation,
and test splits.
See the [readme](imgx_datasets/README.md) for further details.

- [x] Muscle ultrasound from
[Marzola et al. 2021](https://data.mendeley.com/datasets/3jykz7wz8d/1).
- [x] Male pelvic MR from
[Li et al. 2022](https://zenodo.org/record/7013610#.Y1U95-zMKrM).
- [x] AMOS CT from
[Ji et al. 2022](https://zenodo.org/record/7155725#.ZAN4BuzP2rO).
- [x] Brain MR from [Baid et al. 2021](https://arxiv.org/abs/2107.02314).
- Muscle ultrasound from [Marzola et al. 2021](https://data.mendeley.com/datasets/3jykz7wz8d/1).
- Male pelvic MR from [Li et al. 2022](https://zenodo.org/record/7013610#.Y1U95-zMKrM).
- AMOS CT from [Ji et al. 2022](https://zenodo.org/record/7155725#.ZAN4BuzP2rO).
- Brain MR from [Baid et al. 2021](https://arxiv.org/abs/2107.02314).

**Algorithms**

- [x] Supervised segmentation.
- [x] Diffusion-based segmentation.
- [x] Gaussian noise based diffusion.
- [x] Prediction of noise or ground truth.
- [x] Training with recycling or self-conditioning.
- Supervised segmentation.
- Diffusion-based segmentation.
- [Gaussian noise based diffusion](https://arxiv.org/abs/2211.00611).
- Noise prediction ([epsilon-parameterization](https://arxiv.org/abs/2006.11239)) or ground truth
prediction ([x0-parameterization](https://arxiv.org/abs/2102.09672)).
- [Importance sampling](https://arxiv.org/abs/2102.09672) for timestep.
- Recycling training strategies, including [xt-recycling](https://arxiv.org/abs/2303.06040) and
[xT-recycling](https://arxiv.org/abs/2308.16355).
- Self-conditioning training strategies, including
[Chen et al. 2022](https://arxiv.org/abs/2208.04202) and
[Watson et al. 2023.](https://www.nature.com/articles/s41586-023-06415-8).

**Models**

- [x] U-Net with Transformers supporting 2D and 3D images.
- [U-Net](https://arxiv.org/abs/1505.04597) with [Transformers](https://arxiv.org/abs/1706.03762)
supporting 2D and 3D images.
- [Efficient attention](https://arxiv.org/abs/2112.05682).

**Training**

- [x] Patch-based training.
- [x] Multi-device training (one model per device).
- [x] Mixed precision training.
- [x] Gradient clipping and accumulation.
- Patch-based training.
- Multi-device training (one model per device) with
[`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html).
- Mixed precision training.
- Gradient clipping and accumulation.
- [Early stopping](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html).

**Changelog**

---
- October 2023: Migrated from [Haiku](https://github.com/google-deepmind/dm-haiku) to
[Flax](https://github.com/google/flax) following Google DeepMind's recommendation.

## Installation

### TPU with Docker

The following instructions have been tested only for TPU-v3-8. The docker
container uses root user.
The following instructions have been tested only for TPU-v3-8. The docker container uses root user.

1. Build the docker image inside the repository.
1. TPU often has limited disk space.
[RAM disk](https://www.linuxbabe.com/command-line/create-ramdisk-linux) can be used to help.

```bash
sudo mkdir /tmp/ramdisk
sudo chmod 777 /tmp/ramdisk
sudo mount -t tmpfs -o size=256G imgxramdisk /tmp/ramdisk
cd /tmp/ramdisk/
```

2. Build the docker image inside the repository.

```bash
sudo docker build --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) -f docker/Dockerfile.tpu -t imgx .
Expand All @@ -74,7 +86,7 @@ container uses root user.
- `-f` provides the docker file.
- `-t` tag the docker image.

2. Run the Docker container.
3. Run the Docker container.

```bash
mkdir -p $(cd ../ && pwd)/tensorflow_datasets
Expand All @@ -84,27 +96,16 @@ container uses root user.
imgx bash
```

3. Install the package inside container.
4. Install the package inside container.

```bash
make pip
```

TPU often has limited disk space.
[RAM disk](https://www.linuxbabe.com/command-line/create-ramdisk-linux) can be
used to help.

```bash
sudo mkdir /tmp/ramdisk
sudo chmod 777 /tmp/ramdisk
sudo mount -t tmpfs -o size=256G imgxramdisk /tmp/ramdisk
cd /tmp/ramdisk/
```

### GPU with Docker

The following instructions have been tested only for CUDA == 11.4.1 and CUDNN ==
8.2.0. The docker container uses non-root user.
The following instructions have been tested only for CUDA == 11.4.1 and CUDNN == 8.2.0. The docker
container uses non-root user.
[Docker image used may be removed.](https://gitlab.com/nvidia/container-images/cuda/blob/master/doc/support-policy.md)

1. Build the docker image inside the repository.
Expand Down Expand Up @@ -155,8 +156,8 @@ conda env update -f docker/environment_mac_m1.yml

#### Install Conda for Linux / Mac Intel

[Install Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html)
and then create the environment.
[Install Conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) and
then create the environment.

```bash
conda install -y -n base conda-libmamba-solver
Expand All @@ -175,9 +176,8 @@ make pip

## Build Data Sets

Use the following commands to (re)build all data sets. Check the
[README](imgx_datasets/README.md) of imgx_datasets for details. Especially,
manual downloading is required for the BraTS 2021 dataset.
Use the following commands to (re)build all data sets. Check the [README](imgx_datasets/README.md)
of imgx_datasets for details. Especially, manual downloading is required for the BraTS 2021 dataset.

```bash
make build_dataset
Expand All @@ -188,15 +188,16 @@ make rebuild_dataset

### Training and Testing

Example command to use two GPUs for training, validation and testing. The
outputs are stored under `wandb/latest-run/files/`, where
Example command to use two GPUs for training, validation and testing. The outputs are stored under
`wandb/latest-run/files/`, where

- `ckpt` stores the model checkpoints and corresponding validation metrics.
- `test_evaluation` stores the prediction on test set and corresponding metrics.

```bash
# limit to two GPUs if using NVIDIA GPUs
export CUDA_VISIBLE_DEVICES="0,1"

# select data set to use
export DATASET_NAME="male_pelvic_mr"
export DATASET_NAME="amos_ct"
Expand All @@ -216,16 +217,14 @@ imgx_valid --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDIM
```

```bash
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --num_seeds 3
```

Optionally, for debug purposes, use flag `debug=True` to run the experiment with
a small dataset and smaller models.
Optionally, for debug purposes, use flag `debug=True` to run the experiment with a small dataset and
smaller models.

```bash
imgx_train --config-name config_${DATASET_NAME}_seg debug=True
imgx_test --log_dir wandb/latest-run/
imgx_train --config-name config_${DATASET_NAME}_diff_seg debug=True
imgx_test --log_dir wandb/latest-run/ --num_timesteps 5 --sampler DDPM
```

## Code Quality
Expand All @@ -248,11 +247,12 @@ pre-commit run --all-files

### Code Test

Run the command below to test and get coverage report. As JAX tests requires two
CPUs, `-n 4` uses 4 threads, therefore requires 8 CPUs in total.
Run the command below to test and get coverage report. As JAX tests requires two CPUs, `-n 4` uses 4
threads, therefore requires 8 CPUs in total.

```bash
pytest --cov=imgx -n 4 tests
pytest --cov=imgx -n 4 imgx
pytest --cov=imgx_datasets -n 4 imgx_datasets
```

## References
Expand All @@ -266,21 +266,19 @@ pytest --cov=imgx -n 4 tests
- [Scenic (JAX)](https://github.com/google-research/scenic/)
- [DeepMind Research (JAX)](https://github.com/deepmind/deepmind-research/tree/master/ogb_lsc/)
- [Haiku (JAX)](https://github.com/deepmind/dm-haiku/)
- [Flax (JAX)](https://github.com/google/flax)

## Acknowledgement

This work was supported by the EPSRC grant (EP/T029404/1), the Wellcome/EPSRC
Centre for Interventional and Surgical Sciences (203145Z/16/Z), the
International Alliance for Cancer Early Detection, an alliance between Cancer
Research UK (C28070/A30912, C73666/A31378), Canary Center at Stanford
University, the University of Cambridge, OHSU Knight Cancer Institute,
University College London and the University of Manchester, and Cloud TPUs from
Google's TPU Research Cloud (TRC).
This work was supported by the EPSRC grant (EP/T029404/1), the Wellcome/EPSRC Centre for
Interventional and Surgical Sciences (203145Z/16/Z), the International Alliance for Cancer Early
Detection, an alliance between Cancer Research UK (C28070/A30912, C73666/A31378), Canary Center at
Stanford University, the University of Cambridge, OHSU Knight Cancer Institute, University College
London and the University of Manchester, and Cloud TPUs from Google's TPU Research Cloud (TRC).

## Citation

If you find the code base and method useful in your research, please cite the
relevant paper:
If you find the code base and method useful in your research, please cite the relevant paper:

```bibtex
@article{fu2023recycling,
Expand Down
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ COPY docker/requirements.txt /${USER}/requirements.txt

RUN /${USER}/conda/bin/pip3 install --upgrade pip \
&& /${USER}/conda/bin/pip3 install \
jax==0.4.8 \
jaxlib==0.4.7+cuda11.cudnn86 \
jax==0.4.14 \
jaxlib==0.4.14+cuda11.cudnn86 \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
&& /${USER}/conda/bin/pip3 install tensorflow-cpu==2.12.0 \
&& /${USER}/conda/bin/pip3 install -r /${USER}/requirements.txt
Expand Down
Loading