few-shot-segmenter
is an open-source Python package for complex microstructure segmentation, built mainly with PyTorch.
We benchmark this model against Trainable Weka Segmentation and DeepLabV3-ResNet-50 on a 3D tomographic dataset of a carbonaceous chondrite meteorite. The evaluation metric is Intersection over Union (IoU), computed between the model prediction and the ground truth and averaged over the 1000 images in the dataset (values in the table are in %); a minimal sketch of the metric follows the table.
Model | Framboid | Plaquette | Cauliflower | Average
---|---|---|---|---
Trainable Weka | 80.55 | 20.66 | 22.33 | 41.18
DeepLabV3-ResNet-50 | 86.24 | 41.08 | 73.90 | 67.07
5-shot (ours) | 94.05 | 71.24 | 76.59 | 80.62
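For clarity, here is a minimal sketch of the metric, assuming binary NumPy masks (the repo's own evaluation code may be implemented differently):

```python
# Minimal IoU sketch for binary masks; illustrative only.
import numpy as np

def iou(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    """Intersection over Union between two binary masks."""
    pred, target = pred.astype(bool), target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return float((intersection + eps) / (union + eps))

# Average over a stack of slices (random placeholders stand in for real data).
preds = np.random.rand(10, 256, 256) > 0.5
gts = np.random.rand(10, 256, 256) > 0.5
print(f"mean IoU: {np.mean([iou(p, g) for p, g in zip(preds, gts)]):.4f}")
```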
3D visualisation of the segmentation results using Dragonfly.
Qualitative comparison of the segmentation results produced by different models.
Create a Python>=3.9 environment with Conda:
conda create -n fewshot python=3.9
conda activate fewshot
Install few-shot-segmenter from source using pip:
git clone https://github.com/poyentung/few-shot-segmenter.git
cd few-shot-segmenter
pip install -e ./
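If the install succeeded, pip should list the package (assuming the distribution is registered under the repository's name):
pip show few-shot-segmenter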
The demo_data/ folder contains all the necessary images for training and evaluation, organised as follows:
# Training data
datapath/
└── specimen/
    ├── phase0/
    │   ├── annotation/   # target masks
    │   └── image/        # input images
    ├── phase1/
    │   ├── annotation/
    │   └── image/
    └── ...
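The pairing between image/ and annotation/ can be sanity-checked with a few lines of standard-library Python (the datapath below is a hypothetical example):

```python
# Check that every training image has a matching annotation mask.
from pathlib import Path

datapath = Path("demo_data/train/specimen")  # hypothetical example path
for phase_dir in sorted(p for p in datapath.iterdir() if p.is_dir()):
    images = sorted((phase_dir / "image").glob("*.tiff"))
    masks = sorted((phase_dir / "annotation").glob("*.tiff"))
    assert len(images) == len(masks), f"unpaired files in {phase_dir.name}"
    print(f"{phase_dir.name}: {len(images)} image/mask pairs")
```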
# Evaluation data
datapath/
├── query_set/
│   ├── image1.tiff
│   ├── image2.tiff
│   └── ...
├── (optional) query_ground_truth/
│   ├── phase0/
│   │   ├── mask1.tiff
│   │   ├── mask2.tiff
│   │   └── ...
│   ├── phase1/
│   │   ├── mask1.tiff
│   │   ├── mask2.tiff
│   │   └── ...
│   └── ...
└── support_set/
    ├── phase0/
    │   ├── annotation/   # target masks
    │   └── image/        # input images
    ├── phase1/
    │   ├── annotation/
    │   └── image/
    └── ...
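Here the support set holds the few annotated examples that condition the model, while the query set holds the unlabelled slices to segment. A rough loading sketch for one phase (Pillow and the eval path are illustrative assumptions, not the repo's API):

```python
# Illustrative loading of one phase's support pairs and the query stack.
from pathlib import Path
import numpy as np
from PIL import Image

eval_path = Path("demo_data/eval")  # hypothetical example path
support_dir = eval_path / "support_set" / "phase0"
support_imgs = [np.array(Image.open(p)) for p in sorted((support_dir / "image").glob("*.tiff"))]
support_anns = [np.array(Image.open(p)) for p in sorted((support_dir / "annotation").glob("*.tiff"))]
query_imgs = [np.array(Image.open(p)) for p in sorted((eval_path / "query_set").glob("*.tiff"))]
print(f"{len(support_imgs)} support pairs, {len(query_imgs)} query images")
```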
All the config parameters for the training modules are saved in the conf/ folder and can be overridden in train.yaml. For example, we can set the data augmentation of the datamodule in train.yaml (a rough torchvision equivalent is sketched after the snippet):
......
datamodule:
  # Configuration
  datapath: ${original_work_dir}/demo_data/train # directory of the training data
  nshot: 3 # number of shots for the episodic learning technique
  nsamples: 500 # number of images (256*256) cropped from the large image for training
  # Hyperparameters
  val_data_ratio: 0.15 # proportion of the data held out for validation
  batch_size: 5 # batch size for each mini-batch
  n_cpu: 8 # number of CPU workers for data loading
  # Data augmentation
  contrast: [0.5,1.5] # randomly vary the image contrast within these bounds
  rotation_degrees: 90.0 # randomly rotate the image by up to 90 degrees before cropping
  scale: [0.2,0.3] # randomly rescale the image before cropping
  crop_size: 256 # crop size of the image for training
  copy_paste_prob: 0.15 # probability of applying copy-paste augmentation to the training data
......
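As a rough illustration only, the photometric and geometric parts of this config correspond to a torchvision pipeline along these lines (not the repo's actual implementation; in practice the geometric transforms must be applied identically to the masks, and copy-paste is custom logic):

```python
# Approximate torchvision equivalent of the augmentation config above.
import torchvision.transforms as T

augment = T.Compose([
    T.ColorJitter(contrast=(0.5, 1.5)),           # contrast: [0.5, 1.5]
    T.RandomRotation(degrees=90.0),               # rotation_degrees: 90.0
    T.RandomAffine(degrees=0, scale=(0.2, 0.3)),  # scale: [0.2, 0.3]
    T.RandomCrop(256, pad_if_needed=True),        # crop_size: 256
    # copy_paste_prob: 0.15 -> copy-paste augmentation (pasting labelled
    # regions from one sample onto another) is not a stock transform.
])
```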
Run the training for 10 epochs:
python train.py trainer.max_epoch=10
We can also override some of the parameters directly on the command line. For example,
python train.py model_name=test2 datamodule.nshot=5 datamodule.batch_size=10
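The dotted-override syntax above is Hydra's (which the --multirun flag below also uses). A minimal sketch of how such a script composes conf/train.yaml with command-line overrides, assuming train.py is a standard Hydra app:

```python
# Minimal Hydra app sketch; overrides such as datamodule.nshot=5 replace
# the corresponding entries of the composed config before main() runs.
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="conf", config_name="train", version_base=None)
def main(cfg: DictConfig) -> None:
    print(cfg.datamodule.nshot, cfg.datamodule.batch_size)

if __name__ == "__main__":
    main()
```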
We segment a single phase each time we call the function. Please note that this process is GPU-memory-intensive; reduce annotator.batch_size if an out-of-memory error occurs. The phase specified on the command line is the filename in the data folder. For example, to segment cauliflower with the model test (specified as model_name in the yaml file) and a batch_size of 5, we can run:
python segment.py model_name=test phase=cauliflower annotator.batch_size=5
We can also segment multiple phases in a single run:
python segment.py --multirun phase=framboid,plaquette,cauliflower
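For intuition on what the support set does at inference time, here is a generic prototype-based few-shot segmentation sketch (masked average pooling followed by cosine similarity, as in many few-shot segmentation papers); the actual model in this repo may use a different architecture:

```python
# Generic few-shot segmentation sketch; illustrative, not this repo's model.
import torch
import torch.nn.functional as F

def predict_query_mask(support_feats, support_masks, query_feats, tau=20.0):
    """support_feats: (S, C, H, W); support_masks: (S, 1, H, W) in {0, 1};
    query_feats: (B, C, H, W). Returns (B, H, W) foreground probabilities."""
    # Masked average pooling: one foreground prototype per shot, then
    # averaged over the S shots (this is where nshot enters).
    masked = support_feats * support_masks                          # (S, C, H, W)
    proto = masked.sum(dim=(2, 3)) / support_masks.sum(dim=(2, 3)).clamp(min=1e-6)
    proto = proto.mean(dim=0)                                       # (C,)
    # Cosine similarity between the prototype and every query location.
    q = F.normalize(query_feats, dim=1)                             # (B, C, H, W)
    p = F.normalize(proto, dim=0).view(1, -1, 1, 1)                 # (1, C, 1, 1)
    return torch.sigmoid(tau * (q * p).sum(dim=1))                  # (B, H, W)
```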
See the example notebook on few-shot-segmenter model training and predictions for more details.
Few-shot-segmenter is released under the GPL-3.0 license.