This is the implementation of the paper: Semi-supervised Learning using Robust Loss
This repo is forked from the official implementation of TransBTS: Multimodal Brain Tumor Segmentation Using Transformer. The multimodal brain tumor datasets (BraTS 2019 and BraTS 2020) can be acquired from here.
python -m torch.distributed.launch --nproc_per_node=2 train_cv.py --corrupt_r=0.5 --train_partial=False --beta=0.001 --experiment='test_run' --fold=0
python /scratch1/wenhuicu/robust_seg/TransBTS/validation.py --test_file='model_epoch_last.pth' --valid_file='test_list.txt' --submission='' --experiment='test_run_f0' --csv_name='test_run_f0.csv'
- python 3.7
- pytorch 1.6.0
- torchvision 0.7.0
- pickle
- nibabel
After downloading the dataset from here, preprocess the data: convert the .nii files to .pkl files and apply data normalization.
python3 preprocess.py
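As a rough illustration of the preprocessing step described above, the sketch below z-score normalizes each modality over the non-zero (brain) voxels and pickles the result. It is not the code in preprocess.py: the function names `normalize_modalities` and `save_case`, the channel-last layout, and the choice of brain-mask normalization are assumptions made for illustration. Loading the .nii files with nibabel is left out so the snippet runs on synthetic data without the dataset.

```python
import os
import pickle
import tempfile

import numpy as np


def normalize_modalities(images):
    """Z-score each modality over brain voxels; background stays 0.

    images: array of shape (H, W, D, C), one channel per modality.
    Assumption: a voxel is background when it is 0 in every modality.
    """
    images = images.astype(np.float32, copy=True)
    mask = np.any(images != 0, axis=-1)  # True inside the brain
    for c in range(images.shape[-1]):
        channel = images[..., c]  # view, so writes go into `images`
        voxels = channel[mask]
        std = voxels.std()
        if std > 0:
            channel[mask] = (voxels - voxels.mean()) / std
    return images


def save_case(images, label, path):
    """Pickle one (images, label) pair (hypothetical file layout)."""
    with open(path, "wb") as f:
        pickle.dump((images, label), f)


# Synthetic stand-in for four stacked MRI modalities.
rng = np.random.default_rng(0)
volume = rng.uniform(1.0, 100.0, size=(8, 8, 8, 4))
volume[:2] = 0.0  # fake background slab
label = np.zeros((8, 8, 8), dtype=np.uint8)

normalized = normalize_modalities(volume)
out_path = os.path.join(tempfile.mkdtemp(), "case_0001.pkl")
save_case(normalized, label, out_path)
```

Normalizing only inside the brain mask keeps the large zero-valued background from dominating the mean and standard deviation.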
Run the training script on the BraTS dataset. Distributed training is available for training the proposed TransBTS, where --nproc_per_node sets the number of GPUs and --master_port specifies the port number.
- train_cv.py is the main training file, used for baseline training, CE loss training, and robust loss training. It performs 3-fold cross-validation. To train on a different fold, specify the fold number with --fold.
- train_main.py is the same as train_cv.py except that it has no cross-validation part.
- train_cps.py implements cross pseudo supervision. For comparison with the robust loss, it is based on train_cv.py.
python -m torch.distributed.launch --nproc_per_node=2 train_cv.py
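To make the role of the fold number concrete, here is a minimal sketch of a 3-fold split over a list of case names. How train_cv.py actually builds its folds is not described here, so the helper `split_fold`, the shuffling seed, and the case names are all hypothetical.

```python
import random


def split_fold(cases, fold, n_folds=3, seed=0):
    """Return (train_cases, val_cases) for one cross-validation fold."""
    if not 0 <= fold < n_folds:
        raise ValueError(f"fold must be in [0, {n_folds}), got {fold}")
    shuffled = sorted(cases)               # deterministic starting order
    random.Random(seed).shuffle(shuffled)  # identical shuffle for every fold
    val = shuffled[fold::n_folds]          # every n_folds-th case, offset by fold
    val_set = set(val)
    train = [c for c in shuffled if c not in val_set]
    return train, val


cases = [f"case_{i:03d}" for i in range(10)]  # hypothetical case names
splits = [split_fold(cases, fold) for fold in range(3)]
for fold, (train, val) in enumerate(splits):
    print(f"fold {fold}: {len(train)} train / {len(val)} val")
```

Because the shuffle uses a fixed seed, each fold index always maps to the same validation subset, so runs launched separately per fold remain comparable.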
Run python validation.py
- validation.py is used for performance evaluation. It calculates Dice scores and Hausdorff distance. Results are saved in a CSV file; use the calc_mean_var() function in plot.py to calculate the mean Dice scores across the 3 folds.
- predict.py contains the code for the actual model evaluation and metric calculation. The validate_performance() function calculates Dice scores and Hausdorff distance (HD) and saves the means in a separate CSV file for each fold. compare_performance() generates predicted segmentation maps and saves them if the --submission argument is specified. It also saves the Dice scores of every subject in a txt file for later analysis.
- TransBTS_downsample8x_skipconnection_lw.py is the model variant that uses one layer in the transformer module and a half-sized hidden layer (last flatten layer).
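For reference, the Dice score used in evaluation measures the overlap between a predicted mask A and a ground-truth mask B as 2|A ∩ B| / (|A| + |B|). The sketch below computes it for binary masks and then aggregates one score per fold, in the spirit of calc_mean_var(). It is an illustration, not the code in predict.py or plot.py: `dice_score`, `mean_and_variance`, the empty-mask convention, the use of population variance, and the example numbers are all assumptions.

```python
import statistics

import numpy as np


def dice_score(pred, target):
    """Dice overlap between two binary masks of the same shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0  # both masks empty: counted as perfect agreement (a convention)
    return float(2.0 * np.logical_and(pred, target).sum() / denom)


def mean_and_variance(fold_scores):
    """Mean and population variance of one score per fold."""
    return statistics.mean(fold_scores), statistics.pvariance(fold_scores)


pred = np.zeros((4, 4, 4), dtype=bool)
target = np.zeros((4, 4, 4), dtype=bool)
pred[:2] = True     # 32 voxels predicted
target[1:3] = True  # 32 voxels in the ground truth, 16 of them shared
example_dice = dice_score(pred, target)  # 2 * 16 / (32 + 32) = 0.5

fold_dices = [0.80, 0.85, 0.90]  # made-up per-fold mean Dice scores
mean_dice, var_dice = mean_and_variance(fold_dices)
```

For multi-class BraTS labels, the same function can be applied once per tumor region after converting the label map to a binary mask for that region.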