Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync dev #184

Merged
merged 17 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions experiments/vision-transformer/unetr/dsb/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Checking distance loss based segmentation on DSB, with U-Net and U-Netr.
40 changes: 40 additions & 0 deletions experiments/vision-transformer/unetr/dsb/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch_em
from torch_em.data.datasets import get_dsb_loader
from torch_em.transform.raw import get_raw_transform
from micro_sam.training import identity


def get_loaders(normalize_raw, batch_size=4, patch_shape=(1, 256, 256), data_root="./data"):
if normalize_raw:
raw_trafo = get_raw_transform()
else:
raw_trafo = identity

label_trafo = torch_em.transform.label.PerObjectDistanceTransform(
distances=True,
boundary_distances=True,
directed_distances=False,
foreground=True,
min_size=25,
)

train_loader = get_dsb_loader(
data_root, patch_shape=patch_shape, split="train",
download=True, batch_size=batch_size, ndim=2,
label_transform=label_trafo, raw_transform=raw_trafo
)
val_loader = get_dsb_loader(
data_root, patch_shape=patch_shape, split="test", batch_size=batch_size,
label_transform=label_trafo, raw_transform=raw_trafo, ndim=2,
)

return train_loader, val_loader


# TODO visualize the loader
def main():
pass


if __name__ == "__main__":
main()
42 changes: 42 additions & 0 deletions experiments/vision-transformer/unetr/dsb/train_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
import torch_em
from torch_em.model import UNet2d

from common import get_loaders


def train_unet(use_dice, mask_background):
model = UNet2d(in_channels=1, out_channels=3, initial_features=64, final_activation="Sigmoid")

n_iterations = 10_000

if use_dice:
loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=mask_background)
name = "distance_unet-dice"
else:
loss = torch_em.loss.DistanceLoss(mask_distances_in_bg=mask_background)
name = "distance_unet-dist-loss"

if mask_background:
name += "-mask-bg"

train_loader, val_loader = get_loaders(True)

trainer = torch_em.default_segmentation_trainer(
name=name,
model=model,
train_loader=train_loader,
val_loader=val_loader,
loss=loss,
metric=loss,
learning_rate=1e-4,
device=torch.device("cuda"),
mixed_precision=True,
log_image_interval=100,
compile_model=False,
)
trainer.fit(n_iterations)


if __name__ == "__main__":
train_unet(use_dice=False, mask_background=True)
58 changes: 58 additions & 0 deletions experiments/vision-transformer/unetr/dsb/train_unetr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch
import torch_em
from torch_em.model.unetr import UNETR

from common import get_loaders


def get_model(pretrained):
checkpoint = "/home/nimcpape/.cache/micro_sam/models/vit_b" if pretrained else None
model = UNETR(
backbone="sam", encoder="vit_b", out_channels=3,
encoder_checkpoint_path=checkpoint,
use_sam_stats=pretrained, final_activation="Sigmoid",
)
return model


def train_unetr(pretrained, use_dice, mask_background):
model = get_model(pretrained)

n_iterations = 10_000

if use_dice:
loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=mask_background)
name = "distance_unetr-dice"
else:
loss = torch_em.loss.DistanceLoss(mask_distances_in_bg=mask_background)
name = "distance_unetr-dist-loss"

if mask_background:
name += "-mask-bg"

if pretrained:
name += "-pretrained"

train_loader, val_loader = get_loaders(True)

# the trainer object that handles the training details
# the model checkpoints will be saved in "checkpoints/dsb-boundary-model"
# the tensorboard logs will be saved in "logs/dsb-boundary-model"
trainer = torch_em.default_segmentation_trainer(
name=name,
model=model,
train_loader=train_loader,
val_loader=val_loader,
loss=loss,
metric=loss,
learning_rate=1e-4,
device=torch.device("cuda"),
mixed_precision=True,
log_image_interval=100,
compile_model=False,
)
trainer.fit(n_iterations)


if __name__ == "__main__":
train_unetr(pretrained=True, use_dice=False, mask_background=True)
7 changes: 4 additions & 3 deletions experiments/vision-transformer/unetr/initialize_with_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNETR(encoder="vit_h", out_channels=1,
encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_h_4b8939.pth")
model = UNETR(
backbone="mae", encoder="vit_b", out_channels=1, use_sam_stats=False
)
model.to(device)

x = torch.randn(1, 3, 1024, 1024).to(device=device)
x = torch.randn(1, 1, 512, 512).to(device=device)

y = model(x)
print(y.shape)
9 changes: 9 additions & 0 deletions experiments/vision-transformer/unetr/livecell/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Using different `UNETR` settings on LIVECell

- Binary Segmentation - TODO
- Foreground-Boundary Segmentation - TODO
- Affinities - TODO
- Distance Maps (HoVerNet-style)
```python
python livecell_all_hovernet [--train / --predict / --evaluate] -i <LIVECELL_DATA> -s <SAVE_ROOT> --save_dir <PREDICTION_DIR>
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os

import imageio.v2 as imageio
import napari

LIVECELL_FOLDER = "/home/pape/Work/data/incu_cyte/livecell"


def check_hv_segmentation(image, gt):
from torch_em.transform.label import PerObjectDistanceTransform
from common import opencv_hovernet_instance_segmentation

# This transform gives only directed boundary distances
# and foreground probabilities.
trafo = PerObjectDistanceTransform(
distances=False,
boundary_distances=False,
directed_distances=True,
foreground=True,
min_size=10,
)
target = trafo(gt)
seg = opencv_hovernet_instance_segmentation(target)

v = napari.Viewer()
v.add_image(image)
v.add_image(target)
v.add_labels(gt)
v.add_labels(seg)
napari.run()


def check_distance_segmentation(image, gt):
from torch_em.transform.label import PerObjectDistanceTransform
from torch_em.util.segmentation import watershed_from_center_and_boundary_distances

# This transform gives distance to the centroid,
# to the boundaries and the foreground probabilities
trafo = PerObjectDistanceTransform(
distances=True,
boundary_distances=True,
directed_distances=False,
foreground=True,
min_size=10,
)
target = trafo(gt)

# run the segmentation
fg, cdist, bdist = target
seg = watershed_from_center_and_boundary_distances(
cdist, bdist, fg, min_size=50,
)

# visualize it
v = napari.Viewer()
v.add_image(image)
v.add_image(target)
v.add_labels(gt)
v.add_labels(seg)
napari.run()


def main():
# load image and ground-truth from LiveCELL
fname = "A172_Phase_A7_1_01d00h00m_1.tif"
image_path = os.path.join(LIVECELL_FOLDER, "images/livecell_train_val_images", fname)
image = imageio.imread(image_path)

label_path = os.path.join(LIVECELL_FOLDER, "annotations/livecell_train_val_images/A172", fname)
gt = imageio.imread(label_path)

# Check the hovernet instance segmentation on GT.
check_hv_segmentation(image, gt)

# Check the new distance based segmentation on GT.
check_distance_segmentation(image, gt)


if __name__ == "__main__":
main()
Loading