Refactor code and add new functionalities
mathpluscode committed Dec 29, 2023
1 parent 230ecc6 commit ae49d83
Showing 113 changed files with 4,578 additions and 1,343 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pre-commit.yml
@@ -13,9 +13,9 @@ jobs:
python-version: ["3.9"]
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
check-latest: true
6 changes: 3 additions & 3 deletions .github/workflows/unit-test.yml
@@ -16,9 +16,9 @@ jobs:

steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
check-latest: true
python-version: ${{ matrix.python-version }}
@@ -28,9 +28,9 @@ jobs:
docker/environment_mac_m1.yml
docker/Dockerfile
docker/requirements.txt
pyproject.toml
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install tensorflow-cpu==2.12.0
pip install jax==0.4.20
pip install jaxlib==0.4.20
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -24,17 +24,17 @@ repos:
files: "docker/requirements.txt"
- id: trailing-whitespace
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.11.0
rev: 23.12.1
hooks:
- id: black
args:
- --line-length=100
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.1
rev: v1.8.0
hooks: # https://github.com/python/mypy/issues/4008#issuecomment-582458665
- id: mypy
name: mypy
@@ -54,15 +54,15 @@ repos:
--warn-unreachable,
]
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.1.0
rev: v4.0.0-alpha.8
hooks:
- id: prettier
args:
- --print-width=100
- --prose-wrap=always
- --tab-width=2
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.1.6"
rev: "v0.1.9"
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-pylint
9 changes: 4 additions & 5 deletions Makefile
@@ -3,10 +3,9 @@ pip:

test:
pytest --cov=imgx -n 4 imgx
pytest --cov=imgx_datasets -n 4 imgx_datasets

build_dataset:
tfds build imgx_datasets/male_pelvic_mr
tfds build imgx_datasets/amos_ct
tfds build imgx_datasets/muscle_us
tfds build imgx_datasets/brats2021_mr
tfds build imgx/datasets/male_pelvic_mr
tfds build imgx/datasets/amos_ct
tfds build imgx/datasets/muscle_us
tfds build imgx/datasets/brats2021_mr
38 changes: 23 additions & 15 deletions README.md
@@ -10,19 +10,24 @@ This repository includes the implementation of the following work
:construction: **The codebase is still under active development for more enhancements and
applications.** :construction:

- November 2023:
- :warning: Upgraded JAX to 0.4.20.
- :warning: Removed Haiku-specific modification to convolutional layers. This may impact model
performance.
- :smiley: Added example notebooks for inference on a single image without TFDS.
- Added integration tests for training, validation, and testing.
- Refactored config.
- Added `patch_size` and `scale_factor` to data config.
- Moved loss config from the main config to task config.
- Refactored code, including defining `imgx/task` submodule.
- October 2023:
- :blush: Migrated from [Haiku](https://github.com/google-deepmind/dm-haiku) to
[Flax](https://github.com/google/flax) following Google DeepMind's recommendation.
- December 2023:

:warning: This release changed network architecture and training strategies, which may impact
model performance.

- Network
  - Added dropout in U-net.
  - Aligned the Transformer to the Haiku implementation.
- Training
  - Supported anisotropic volumes for data augmentation.
  - Added data augmentations: random gamma adjustment, random flip, random shearing.
- Functionalities
  - Added registration-related metrics and losses.
- Refactoring
  - Moved the data set iterator out of `Experiment` to facilitate using non-TFDS data sets.
  - Merged `imgx_datasets` into `imgx`.
  - Used `jax.random.fold_in` for random key splitting to avoid passing keys between functions
    (see the sketch below).
  - Used `optax.softmax_cross_entropy` to replace the custom implementation.
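
A minimal, hypothetical sketch (not the repository's exact code) of the last two refactors:
deriving per-step keys with `jax.random.fold_in` instead of threading split keys through
functions, and delegating cross-entropy to `optax.softmax_cross_entropy`.

```python
import jax
import jax.numpy as jnp
import optax

base_key = jax.random.PRNGKey(0)


def step_key(step: int) -> jax.Array:
    # Deterministically derive a distinct key for this step; base_key itself is never consumed,
    # so callers no longer need to receive and return freshly split keys.
    return jax.random.fold_in(base_key, step)


# Dummy shapes for illustration: (batch, num_classes) logits and one-hot labels.
logits = jax.random.normal(step_key(0), (4, 3))
labels = jax.nn.one_hot(jnp.array([0, 1, 2, 1]), num_classes=3)

# Per-example cross-entropy from optax, averaged over the batch.
loss = optax.softmax_cross_entropy(logits=logits, labels=labels).mean()
```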

:mailbox: Please feel free to
[create an issue](https://github.com/mathpluscode/ImgX-DiffSeg/issues/new/choose) to request
@@ -38,7 +43,7 @@ Current supported functionalities are summarized as follows.

**Data sets**

See the [readme](imgx_datasets/README.md) for further details.
See the [readme](imgx/datasets/README.md) for further details.

- Muscle ultrasound from [Marzola et al. 2021](https://data.mendeley.com/datasets/3jykz7wz8d/1).
- Male pelvic MR from [Li et al. 2022](https://zenodo.org/record/7013610#.Y1U95-zMKrM).
@@ -68,6 +73,10 @@ See the [readme](imgx_datasets/README.md) for further details.
**Training**

- Patch-based training.
- Data augmentation with anisotropic support, including
  - Random affine: rotation, scaling, shearing, shifting.
  - Random gamma adjustment.
  - Random flip.
- Multi-device training (one model per device) with
[`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html).
- Mixed precision training.
@@ -284,7 +293,6 @@ Run the command below to test and get a coverage report. As JAX tests require tw

```bash
pytest --cov=imgx -n 4 imgx -k "not integration"
pytest --cov=imgx_datasets -n 4 imgx_datasets
```

`-k "not integration"` excludes integration tests, which require downloading muscle ultrasound and
14 changes: 8 additions & 6 deletions docker/requirements.txt
@@ -1,15 +1,16 @@
SimpleITK==2.3.1
chex==0.1.8
coverage==7.3.2
coverage==7.3.3
flax==0.7.5
hydra-core==1.3.2
kaggle==1.5.16
matplotlib==3.8.2
nbmake==1.4.6
numpy==1.26.2
opencv-python==4.8.1.78
optax==0.1.7
pandas==2.1.3
pre-commit==3.5.0
pandas==2.1.4
pre-commit==3.6.0
protobuf==3.20.3 # https://github.com/tensorflow/datasets/issues/4858
pytest-cov==4.1.0
pytest-mock==3.12.0
@@ -19,8 +20,9 @@ pytest-xdist==3.5.0
pytest==7.4.3
rdkit-pypi==2022.9.5
rich==13.7.0
ruff==0.1.6
ruff==0.1.8
tensorflow-datasets==4.9.3
torch==2.1.1 # for testing only
wandb==0.16.0
tomli==2.0.1
torch==2.1.2 # for testing only
wandb==0.16.1
wily==1.25.0
20 changes: 11 additions & 9 deletions examples/segmentation/config.yaml
@@ -9,14 +9,14 @@ data:
- 0
- 0
data_augmentation:
max_rotation:
- 0.088
max_translation:
- 10
- 10
max_scaling:
- 0.15
- 0.15
max_rotation: 30
max_zoom: 0.2
max_shear: 30
max_shift: 0.3
max_log_gamma: 0.3
v_min: 0.0
v_max: 1.0
p: 0.5
trainer:
max_num_samples: 512000
batch_size: 64
@@ -40,6 +40,7 @@ task:
scale_factor:
- 2
- 2
num_res_blocks: 2
num_channels:
- 8
- 16
@@ -49,10 +50,11 @@ task:
num_heads: 8
widening_factor: 4
num_transform_layers: 1
dropout: 0.1
loss:
dice: 1.0
cross_entropy: 0.0
focal: 20.0
focal: 1.0
early_stopping:
metric: mean_binary_dice_score_without_background
mode: max
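
For orientation, a hypothetical sketch of how one of the ranges above (`max_log_gamma`, together
with `v_min`/`v_max`) could drive random gamma adjustment; the function name and details are
assumptions, not the repository's implementation.

```python
import jax
import jax.numpy as jnp

MAX_LOG_GAMMA = 0.3  # matches max_log_gamma above
V_MIN, V_MAX = 0.0, 1.0  # matches v_min / v_max above


def random_gamma_adjustment(key: jax.Array, image: jnp.ndarray) -> jnp.ndarray:
    # Sample gamma = exp(u) with u ~ U(-MAX_LOG_GAMMA, MAX_LOG_GAMMA), symmetric around 1
    # on a log scale.
    log_gamma = jax.random.uniform(key, minval=-MAX_LOG_GAMMA, maxval=MAX_LOG_GAMMA)
    gamma = jnp.exp(log_gamma)
    # Normalise intensities to [0, 1], apply the power, then map back to [V_MIN, V_MAX].
    scaled = (image - image.min()) / (image.max() - image.min() + 1e-8)
    return V_MIN + (V_MAX - V_MIN) * scaled**gamma
```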
24 changes: 13 additions & 11 deletions examples/segmentation/inference.ipynb

11 changes: 8 additions & 3 deletions imgx/conf/data/amos_ct.yaml
@@ -5,9 +5,14 @@ loader:
patch_shape: [128, 128, 128]
patch_overlap: [64, 0, 0] # image shape is [192, 128, 128]
data_augmentation:
max_rotation: [0.088, 0.088, 0.088] # roughly 5 degrees
max_translation: [10, 10, 10]
max_scaling: [0.15, 0.15, 0.15]
max_rotation: 30 # degrees
max_zoom: 0.2 # as a fraction of the image size
max_shear: 30 # degrees
max_shift: 0.3 # as a fraction of the image size
max_log_gamma: 0.3
v_min: 0.0 # minimum value for intensity
v_max: 1.0 # maximum value for intensity
p: 0.5 # probability of applying each augmentation

trainer:
max_num_samples: 100_000
11 changes: 8 additions & 3 deletions imgx/conf/data/brats2021_mr.yaml
@@ -5,9 +5,14 @@ loader:
patch_shape: [128, 128, 128]
patch_overlap: [0, 0, 32] # image shape is [179, 219, 155]
data_augmentation:
max_rotation: [0.088, 0.088, 0.088] # roughly 5 degrees
max_translation: [20, 20, 4]
max_scaling: [0.15, 0.15, 0.15]
max_rotation: 30 # degrees
max_zoom: 0.2 # as a fraction of the image size
max_shear: 30 # degrees
max_shift: 0.3 # as a fraction of the image size
max_log_gamma: 0.3
v_min: 0.0 # minimum value for intensity
v_max: 1.0 # maximum value for intensity
p: 0.5 # probability of applying each augmentation

trainer:
max_num_samples: 100_000
13 changes: 9 additions & 4 deletions imgx/conf/data/male_pelvic_mr.yaml
@@ -5,15 +5,20 @@ loader:
patch_shape: [256, 256, 32]
patch_overlap: [0, 0, 16] # image shape is [256, 256, 48]
data_augmentation:
max_rotation: [0.088, 0.088, 0.088] # roughly 5 degrees
max_translation: [20, 20, 4]
max_scaling: [0.15, 0.15, 0.15]
max_rotation: 30 # degrees
max_zoom: 0.2 # as a fraction of the image size
max_shear: 30 # degrees
max_shift: 0.3 # as a fraction of the image size
max_log_gamma: 0.3
v_min: 0.0 # minimum value for intensity
v_max: 1.0 # maximum value for intensity
p: 0.5 # probability of applying each augmentation

trainer:
max_num_samples: 100_000
batch_size: 8 # all model replicas are updated every `batch_size` samples
batch_size_per_replica: 1 # each model replica takes `batch_size_per_replica` samples per step
num_devices_per_replica: 1 # model is split into num_devices_per_replica shards/slices

patch_size: [2, 2, 2]
patch_size: [2, 2, 1] # do not downsample z axis
scale_factor: [2, 2, 2]
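
The three trainer fields above relate through some simple arithmetic; the device count below is an
assumption for illustration, not part of the config.

```python
num_devices = 8  # assumed number of accelerators on the host

batch_size = 8  # samples per optimiser update, across all replicas
batch_size_per_replica = 1  # samples each replica processes per step
num_devices_per_replica = 1  # devices over which a single replica is sharded

num_replicas = num_devices // num_devices_per_replica  # 8 model replicas
samples_per_step = num_replicas * batch_size_per_replica  # 8 samples per step
steps_per_update = batch_size // samples_per_step  # replicas are updated after 1 step here
```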
11 changes: 8 additions & 3 deletions imgx/conf/data/muscle_us.yaml
@@ -5,9 +5,14 @@ loader:
patch_shape: [480, 512]
patch_overlap: [0, 0]
data_augmentation:
max_rotation: [0.088] # roughly 5 degrees
max_translation: [10, 10]
max_scaling: [0.15, 0.15]
max_rotation: 30 # degrees
max_zoom: 0.2 # as a fraction of the image size
max_shear: 30 # degrees
max_shift: 0.3 # as a fraction of the image size
max_log_gamma: 0.3
v_min: 0.0 # minimum value for intensity
v_max: 1.0 # maximum value for intensity
p: 0.5 # probability of applying each augmentation

trainer:
max_num_samples: 512_000
6 changes: 4 additions & 2 deletions imgx/conf/task/gaussian_diff_seg.yaml
@@ -33,20 +33,22 @@ model:
num_spatial_dims: 3
patch_size: MISSING # data dependent, will be set after loading config
scale_factor: MISSING # data dependent, will be set after loading config
num_res_blocks: 2
num_channels: [32, 64, 128, 256]
out_channels: MISSING # data dependent, will be set after loading config
num_heads: 8
widening_factor: 4
dropout: 0.1

loss:
dice: 1.0
cross_entropy: 0.0
focal: 20.0
focal: 1.0
mse: 0.0
vlb: 0.0

early_stopping: # used on validation set
metric: "mean_binary_dice_score_without_background"
mode: "max"
min_delta: 0.0001
patience: 5
patience: 10
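
Presumably the loss weights above scale the corresponding loss terms in a weighted sum; a
hypothetical sketch (the dictionary of computed terms is illustrative):

```python
import jax.numpy as jnp

loss_weights = {"dice": 1.0, "cross_entropy": 0.0, "focal": 1.0, "mse": 0.0, "vlb": 0.0}


def total_loss(loss_terms: dict[str, jnp.ndarray]) -> jnp.ndarray:
    # Weighted sum over the computed terms; zero-weight terms contribute nothing.
    return sum(loss_weights[name] * value for name, value in loss_terms.items())


# With the weights above, only the dice and focal terms contribute: 0.4 + 0.2 = 0.6.
loss = total_loss(
    {
        "dice": jnp.array(0.4),
        "cross_entropy": jnp.array(0.7),
        "focal": jnp.array(0.2),
        "mse": jnp.array(0.0),
        "vlb": jnp.array(0.0),
    }
)
```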
6 changes: 4 additions & 2 deletions imgx/conf/task/seg.yaml
@@ -6,19 +6,21 @@ model:
num_spatial_dims: 3
patch_size: MISSING # data dependent, will be set after loading config
scale_factor: MISSING # data dependent, will be set after loading config
num_res_blocks: 2
num_channels: [32, 64, 128, 256]
out_channels: MISSING # data dependent, will be set after loading config
num_heads: 8
widening_factor: 4
num_transform_layers: 1
dropout: 0.1

loss:
dice: 1.0
cross_entropy: 0.0
focal: 20.0
focal: 1.0

early_stopping: # used on validation set
metric: "mean_binary_dice_score_without_background"
mode: "max"
min_delta: 0.0001
patience: 5
patience: 10
8 changes: 0 additions & 8 deletions imgx/data/__init__.py
@@ -1,9 +1 @@
"""Module to handle data."""
from __future__ import annotations

from typing import Callable

import jax
import jax.numpy as jnp

AugmentationFn = Callable[[jax.Array, dict[str, jnp.ndarray]], dict[str, jnp.ndarray]]
@@ -1,12 +1,13 @@
"""Image augmentation functions."""
"""Data augmentation module."""
from __future__ import annotations

from collections.abc import Sequence
from typing import Callable

import jax
import jax.numpy as jnp
from jax import numpy as jnp

from imgx.data import AugmentationFn
AugmentationFn = Callable[[jax.Array, dict[str, jnp.ndarray]], dict[str, jnp.ndarray]]


def chain_aug_fns(
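
`chain_aug_fns` is truncated above, so the following is only a guess at how functions of type
`AugmentationFn` might be composed, reusing `jax.random.fold_in` to give each augmentation its own
key.

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import Callable

import jax
from jax import numpy as jnp

AugmentationFn = Callable[[jax.Array, dict[str, jnp.ndarray]], dict[str, jnp.ndarray]]


def chain(fns: Sequence[AugmentationFn]) -> AugmentationFn:
    """Compose augmentation functions into a single AugmentationFn."""

    def apply(key: jax.Array, batch: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]:
        for i, fn in enumerate(fns):
            # Each augmentation gets an independent key derived from the input key.
            batch = fn(jax.random.fold_in(key, i), batch)
        return batch

    return apply
```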