Skip to content

Commit

Permalink
start adding documentation for toothfairy2
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Aug 27, 2024
1 parent 52fa355 commit 6f184c8
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 0 deletions.
177 changes: 177 additions & 0 deletions documentation/competitions/Toothfairy2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Introduction

This document describes our submission to the [Toothfairy2 Challenge](https://toothfairy2.grand-challenge.org/toothfairy2/).
Our model is essentially a nnU-Net ResEnc L with the patch size upscaled to 160x320x320 pixels. We disable left/right
mirroring and train for 1500 instead of the standard 1000 epochs. Training was either done on 2xA100 40GB or one GH200 96GB.

# Dataset Conversion

# Experiment Planning and Preprocessing

## Extract fingerprint:
`nnUNetv2_extract_fingerprint -d 119 -np 48`

## Run planning:
`nnUNetv2_plan_experiment -d 119 -pl nnUNetPlannerResEncL_torchres`

This planner not only uses the ResEncL configuration but also replaces the default resampling scheme with one that is
faster (but less precise). Since all images in the challenge (train and test) should already have 0.3x0.3x0.3 spacing
resampling is not required. This is just here as a safety measure. The speed is needed at inference time because grand
challenge imposes a limit of 10 minutes per case.

## Edit the plans files
Add the following configuration to the generated plans file:

```json
"3d_fullres_torchres_ps160x320x320_bs2": {
"inherits_from": "3d_fullres",
"patch_size": [
160,
320,
320
],
"architecture": {
"network_class_name": "dynamic_network_architectures.architectures.unet.ResidualEncoderUNet",
"arch_kwargs": {
"n_stages": 7,
"features_per_stage": [
32,
64,
128,
256,
320,
320,
320
],
"conv_op": "torch.nn.modules.conv.Conv3d",
"kernel_sizes": [
[
3,
3,
3
],
[
3,
3,
3
],
[
3,
3,
3
],
[
3,
3,
3
],
[
3,
3,
3
],
[
3,
3,
3
],
[
3,
3,
3
]
],
"strides": [
[
1,
1,
1
],
[
2,
2,
2
],
[
2,
2,
2
],
[
2,
2,
2
],
[
2,
2,
2
],
[
2,
2,
2
],
[
1,
2,
2
]
],
"n_blocks_per_stage": [
1,
3,
4,
6,
6,
6,
6
],
"n_conv_per_stage_decoder": [
1,
1,
1,
1,
1,
1
],
"conv_bias": true,
"norm_op": "torch.nn.modules.instancenorm.InstanceNorm3d",
"norm_op_kwargs": {
"eps": 1e-05,
"affine": true
},
"dropout_op": null,
"dropout_op_kwargs": null,
"nonlin": "torch.nn.LeakyReLU",
"nonlin_kwargs": {
"inplace": true
}
},
"_kw_requires_import": [
"conv_op",
"norm_op",
"dropout_op",
"nonlin"
]
}
}
```
Aside from changing the patch size this makes the architecture one stage deeper (one more pooling + res blocks), enabling
it to make effective use of the larger input

# Training
We train two models on all training cases:

```bash
nnUNetv2_train 119 3d_fullres_torchres_ps160x320x320_bs2 all -p nnUNetResEncUNetLPlans -tr nnUNetTrainer_onlyMirror01_1500ep
nnUNet_results=${nnUNet_results}_2 nnUNetv2_train 119 3d_fullres_torchres_ps160x320x320_bs2 all -p nnUNetResEncUNetLPlans -tr nnUNetTrainer_onlyMirror01_1500ep
```
Note how in the second line we overwrite the nnUNet_results variable in order to be able to train the same model twice without overwriting the results

# Inference
We ensemble the two models from above. On a technical level we copy the two fold_all folders into one training output
directory and rename them to fold_0 and fold_1. This lets us use nnU-Net's cross-validation ensembling strategy which
is more computationally efficient (needed for time limit on grand-challenge.org).

Run inference with the inference script
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union, Tuple, List

import numpy as np
import torch
from batchgeneratorsv2.helpers.scalar_type import RandomScalar
from batchgeneratorsv2.transforms.base.basic_transform import BasicTransform
from batchgeneratorsv2.transforms.intensity.brightness import MultiplicativeBrightnessTransform
Expand Down Expand Up @@ -51,6 +52,13 @@ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes


class nnUNetTrainer_onlyMirror01_1500ep(nnUNetTrainer_onlyMirror01):
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
device: torch.device = torch.device('cuda')):
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
self.num_epochs = 1500


class nnUNetTrainer_onlyMirror01_DASegOrd0(nnUNetTrainer_onlyMirror01):
@staticmethod
def get_training_transforms(
Expand Down

0 comments on commit 6f184c8

Please sign in to comment.