
Few-shot-segmenter

few-shot-segmenter is an open-source Python package for segmenting complex microstructures in 3D tomographic data, built mainly with PyTorch.

Benchmarks

We benchmark this model against Trainable Weka Segmentation and DeepLabV3-ResNet-50 on a 3D tomographic dataset of a carbonaceous chondrite meteorite. The metric is Intersection over Union (IoU), computed between the model prediction and the ground truth across 1,000 images in the dataset.

| Model          | Framboid | Plaquette | Cauliflower | Average |
|----------------|----------|-----------|-------------|---------|
| Trainable Weka | 80.55    | 20.66     | 22.33       | 41.18   |
| ResNet-50      | 86.24    | 41.08     | 73.90       | 67.07   |
| 5-shot         | 94.05    | 71.24     | 76.59       | 80.62   |
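The IoU values above can be reproduced per image pair with a few lines of NumPy. This is a minimal sketch of the metric, not the package's internal implementation:

```python
import numpy as np

def iou(pred: np.ndarray, target: np.ndarray) -> float:
    """Intersection over Union between two binary masks."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    # Two empty masks agree perfectly by convention
    return float(intersection) / union if union else 1.0

pred = np.array([[1, 1, 0], [0, 1, 0]])
target = np.array([[1, 0, 0], [0, 1, 1]])
print(iou(pred, target))  # intersection=2, union=4 -> 0.5
```

The per-phase scores in the table are averages of this quantity over all slices of the stack.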

3D visualisation of the segmentation results using Dragonfly.


Qualitative comparison of the segmentation results produced by the different models.


Installation

1. Environment setup

Create a Python>=3.9 environment with Conda:

conda create -n fewshot python=3.9
conda activate fewshot

2. Installation

Install few-shot-segmenter from source using pip:

git clone https://github.com/poyentung/few-shot-segmenter.git
cd few-shot-segmenter
pip install -e ./

Getting Started

1. Prepare data and masks

The demo_data/ folder contains all the images needed for training and evaluation, organised as follows:

# Training data
datapath/
    └── specimen/            
        ├── phase0/
        │   ├── annotation/        # target masks
        │   └── image/             # input images
        ├── phase1/
        │   ├── annotation/
        │   └── image/
        └── ...
# Evaluation data
datapath/
    ├── query_set/  
    │       ├── image1.tiff
    │       ├── image2.tiff
    │       └── ...
    ├── (optional) query_ground_truth/                     
    │       ├── phase0/
    │       │   ├── mask1.tiff       
    │       │   ├── mask2.tiff       
    │       │   └── ... 
    │       ├── phase1/
    │       │   ├── mask1.tiff
    │       │   ├── mask2.tiff
    │       │   └── ...
    │       └── ...
    └── support_set/            
            ├── phase0/
            │   ├── annotation/      # target masks
            │   └── image/           # input images
            ├── phase1/
            │   ├── annotation/
            │   └── image/
            └── ...
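Since a misplaced folder is a common source of dataloader errors, it can help to sanity-check the layout before training. The following is a hypothetical helper (not part of the package) that verifies each phase folder has the expected annotation/ and image/ subfolders:

```python
import tempfile
from pathlib import Path

def check_training_layout(datapath) -> list[str]:
    """Report phase folders missing the expected annotation/ or image/ subfolders."""
    problems = []
    specimen = Path(datapath) / "specimen"
    for phase in sorted(p for p in specimen.iterdir() if p.is_dir()):
        for sub in ("annotation", "image"):
            if not (phase / sub).is_dir():
                problems.append(f"{phase.name}: missing {sub}/")
    return problems

# Demo on a throwaway directory: phase1 is deliberately incomplete
root = Path(tempfile.mkdtemp())
(root / "specimen" / "phase0" / "annotation").mkdir(parents=True)
(root / "specimen" / "phase0" / "image").mkdir()
(root / "specimen" / "phase1" / "image").mkdir(parents=True)
print(check_training_layout(root))  # ['phase1: missing annotation/']
```

Image and mask filenames within each phase should also correspond one-to-one, so a prediction can be matched to its ground truth.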

2. Setup configuration file

All the config parameters for the training modules are stored in the conf/ folder and overridden by train.yaml. For example, we can set the data augmentation of the datamodule in train.yaml:

......

datamodule:
    # Configuration
    datapath: ${original_work_dir}/demo_data/train    # directory of training data
    nshot: 3                                          # number of shots for the episodic learning technique
    nsamples: 500                                     # number of images (256*256) cropped from the large image for training

    # hyperparams
    val_data_ratio: 0.15                              # proportion for validation data
    batch_size: 5                                     # batch size for each mini-batch
    n_cpu: 8                                          # number of CPU workers for data loading

    # Data augmentation
    contrast: [0.5,1.5]                               # varying contrast of the image with the boundary condition
    rotation_degrees: 90.0                            # randomly rotate the image by up to 90 degrees before cropping
    scale: [0.2,0.3]                                  # randomly rescale the image before cropping
    crop_size: 256                                    # crop size of the image for training
    copy_paste_prob: 0.15                             # probability of copy-paste for the training data

......
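To make the augmentation parameters concrete, here is a toy NumPy sketch of two of the configured operations (random contrast jitter followed by a random crop). This is an illustration under assumed semantics, not the package's actual augmentation pipeline:

```python
import numpy as np

rng = np.random.default_rng(0)

def augment(image: np.ndarray, contrast=(0.5, 1.5), crop_size=256) -> np.ndarray:
    """Toy version of the config above: random contrast scaling, then a random crop."""
    # Contrast factor drawn uniformly from the configured [min, max] range
    factor = rng.uniform(*contrast)
    out = np.clip(image * factor, 0.0, 1.0)
    # Random crop_size x crop_size window from the large image
    h, w = out.shape
    top = rng.integers(0, h - crop_size + 1)
    left = rng.integers(0, w - crop_size + 1)
    return out[top:top + crop_size, left:left + crop_size]

large = rng.random((512, 512))      # stand-in for one large training image
patch = augment(large)
print(patch.shape)  # (256, 256)
```

With nsamples: 500, the datamodule draws 500 such random patches per phase from the large images.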

3. Training model

Run the training with 10 epochs:

python train.py trainer.max_epoch=10

We can also override some of the parameters directly on the command line. For example:

python train.py model_name=test2 datamodule.nshot=5 datamodule.batch_size=10
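The dotted override syntax maps onto the nested structure of the YAML config: each `a.b=value` argument updates key `b` inside section `a`. A toy illustration of this mechanism (the real project uses a full config framework for this; `apply_overrides` is a hypothetical helper):

```python
def apply_overrides(cfg: dict, overrides: list) -> dict:
    """Apply dotted command-line overrides like 'datamodule.batch_size=10'
    to a nested config dict."""
    for item in overrides:
        keys, value = item.split("=", 1)
        node = cfg
        *parents, leaf = keys.split(".")
        for k in parents:
            node = node.setdefault(k, {})
        # Cast numeric strings so batch_size=10 becomes an int
        node[leaf] = int(value) if value.isdigit() else value
    return cfg

cfg = {"model_name": "test", "datamodule": {"nshot": 3, "batch_size": 5}}
apply_overrides(cfg, ["model_name=test2", "datamodule.nshot=5", "datamodule.batch_size=10"])
print(cfg["datamodule"])  # {'nshot': 5, 'batch_size': 10}
```

Any key visible in the YAML files under conf/ can be overridden the same way.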

4. Evaluation

Each call to segment.py segments a single phase. Please note that this process is GPU-memory-intensive: reduce annotator.batch_size if you hit an out-of-memory error. The phase specified on the command line is the phase name used in the data folder. For example, to segment cauliflower with the model test (specified as model_name in the YAML file) and a batch size of 5, we can run:

python segment.py model_name=test phase=cauliflower annotator.batch_size=5

We can also segment multiple phases in a run:

python segment.py --multirun phase=framboid,plaquette,cauliflower
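Once every phase has been segmented slice by slice, the per-slice masks can be stacked into a 3D volume for quantification, e.g. a phase's volume fraction. A NumPy sketch with synthetic masks (`volume_fraction` is a hypothetical helper, not part of the package):

```python
import numpy as np

def volume_fraction(slices) -> float:
    """Stack per-slice binary masks into a 3D volume and return the
    fraction of voxels belonging to the phase."""
    volume = np.stack([s.astype(bool) for s in slices], axis=0)
    return float(volume.mean())

# Synthetic stand-in for a stack of segmented slices (~20% foreground)
rng = np.random.default_rng(0)
slices = [rng.random((64, 64)) > 0.8 for _ in range(10)]
print(round(volume_fraction(slices), 3))
```

In practice the slices would be read from the segmentation output (e.g. the per-phase TIFF masks) before stacking.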

See the example notebook on few-shot-segmenter model training and predictions for more details.

License

Few-shot-segmenter is released under the GPL-3.0 license.
