Skip to content

Commit

Permalink
Merge pull request #875 from MouseLand/cli_restore
Browse files Browse the repository at this point in the history
adding CLI for restore
  • Loading branch information
carsen-stringer authored Feb 23, 2024
2 parents d4857f5 + 307bcf1 commit 83f00ce
Show file tree
Hide file tree
Showing 13 changed files with 621 additions and 253 deletions.
59 changes: 45 additions & 14 deletions cellpose/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from natsort import natsorted
from tqdm import tqdm
from cellpose import utils, models, io, version_str, train
from cellpose import utils, models, io, version_str, train, denoise
from cellpose.cli import get_arg_parser

try:
Expand Down Expand Up @@ -90,9 +90,18 @@ def main():
else:
pretrained_model = args.pretrained_model

restore_type = args.restore_type
if restore_type is not None:
try:
denoise.model_path(restore_type)
except Exception as e:
raise ValueError("restore_type invalid")
if args.train or args.train_size:
raise ValueError("restore_type cannot be used with training on CLI yet")

model_type = None
if pretrained_model and not os.path.exists(pretrained_model):
model_type = pretrained_model if pretrained_model is not None else "cyto"
model_type = pretrained_model if pretrained_model is not None else "cyto3"
model_strings = models.get_user_models()
all_models = models.MODEL_NAMES.copy()
all_models.extend(model_strings)
Expand Down Expand Up @@ -127,26 +136,39 @@ def main():
">>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s"
% (nimg, cstr0[channels[0]], cstr1[channels[1]]))

# handle built-in model exceptions; bacterial ones get no size model
if builtin_size:
# handle built-in model exceptions
if builtin_size and restore_type is None:
model = models.Cellpose(gpu=gpu, device=device, model_type=model_type)
else:
builtin_size = False
if args.all_channels:
channels = None
pretrained_model = None if model_type is not None else pretrained_model
model = models.CellposeModel(gpu=gpu, device=device,
pretrained_model=pretrained_model,
model_type=model_type)
if restore_type is None:
model = models.CellposeModel(gpu=gpu, device=device,
pretrained_model=pretrained_model,
model_type=model_type)
else:
model = denoise.CellposeDenoiseModel(gpu=gpu, device=device,
pretrained_model=pretrained_model,
model_type=model_type,
restore_type=restore_type,
chan2_restore=args.chan2_restore)

# handle diameters
if args.diameter == 0:
if builtin_size:
diameter = None
logger.info(">>>> estimating diameter for each image")
else:
logger.info(
">>>> not using cyto, cyto2, or nuclei model, cannot auto-estimate diameter"
)
if restore_type is None:
logger.info(
">>>> not using cyto3, cyto, cyto2, or nuclei model, cannot auto-estimate diameter"
)
else:
logger.info(
">>>> cannot auto-estimate diameter for image restoration"
)
diameter = model.diam_labels
logger.info(">>>> using diameter %0.3f for all images" % diameter)
else:
Expand All @@ -168,17 +190,26 @@ def main():
channel_axis=args.channel_axis, z_axis=args.z_axis,
anisotropy=args.anisotropy, niter=args.niter)
masks, flows = out[:2]
if len(out) > 3:
if len(out) > 3 and restore_type is None:
diams = out[-1]
else:
diams = diameter
ratio = 1.
if restore_type is not None:
imgs_dn = out[-1]
ratio = diams / model.dn.diam_mean if "upsample" in restore_type else 1.
diams = model.dn.diam_mean if "upsample" in restore_type and model.dn.diam_mean > diams else diams
else:
imgs_dn = None
if args.exclude_on_edges:
masks = utils.remove_edge_masks(masks)
if not args.no_npy:
io.masks_flows_to_seg(image, masks, flows, image_name,
channels=channels, diams=diams)
io.masks_flows_to_seg(image, masks, flows, image_name, imgs_restore=imgs_dn,
channels=channels, diams=diams,
restore_type=restore_type, ratio=1.)
if saving_something:
io.save_masks(image, masks, flows, image_name, png=args.save_png,
io.save_masks(image, masks, flows, image_name,
png=args.save_png,
tif=args.save_tif, save_flows=args.save_flows,
save_outlines=args.save_outlines,
dir_above=args.dir_above, savedir=args.savedir,
Expand Down
5 changes: 5 additions & 0 deletions cellpose/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def get_arg_parser():
model_args.add_argument("--pretrained_model", required=False, default="cyto",
type=str,
help="model to use for running or starting training")
model_args.add_argument("--restore_type", required=False, default=None,
type=str,
help="model to use for image restoration")
model_args.add_argument("--chan2_restore", action="store_true",
help="use nuclei restore model for second channel")
model_args.add_argument(
"--add_model", required=False, default=None, type=str,
help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
Expand Down
16 changes: 9 additions & 7 deletions cellpose/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,19 +464,19 @@ def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
class CellposeDenoiseModel():
""" model to run Cellpose and Image restoration """
def __init__(self, gpu=False, pretrained_model=False, model_type=None,
restore_type="denoise_cyto3", chan2_denoise=False,
restore_type="denoise_cyto3", chan2_restore=False,
device=None):

self.dn = DenoiseModel(gpu=gpu, model_type=restore_type,
chan2=chan2_denoise, device=device)
chan2=chan2_restore, device=device)
self.cp = CellposeModel(gpu=gpu, model_type=model_type,
pretrained_model=pretrained_model, device=device)

def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1,
resample=True, invert=False, flow_threshold=0.4, cellprob_threshold=0.0,
do_3D=False, anisotropy=None, stitch_threshold=0.0, min_size=15,
niter=None, interp=True):
augment=False, resample=True, invert=False, flow_threshold=0.4,
cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
min_size=15, niter=None, interp=True):
"""
Restore array or list of images using the image restoration model, and then segment.
Expand Down Expand Up @@ -510,6 +510,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
tile (bool, optional): tiles image to ensure GPU/CPU memory usage limited (recommended). Defaults to True.
tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
Expand Down Expand Up @@ -549,7 +550,8 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
masks, flows, styles = self.cp.eval(img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
normalize=normalize_params, rescale=rescale, diameter=diameter,
tile=tile, tile_overlap=tile_overlap, resample=resample, invert=invert,
tile=tile, tile_overlap=tile_overlap, augment=augment,
resample=resample, invert=invert,
flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold,
do_3D=do_3D, anisotropy=anisotropy, stitch_threshold=stitch_threshold,
min_size=min_size, niter=niter, interp=interp)
Expand Down Expand Up @@ -644,7 +646,7 @@ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
)
if chan2 and builtin:
chan2_path = model_path(os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei")
print(f"loading model for chan2: {os.path.split(str(chan2_path)[-1])}")
print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
mkldnn=self.mkldnn, max_pool=True,
diam_mean=17.).to(self.device)
Expand Down
Loading

0 comments on commit 83f00ce

Please sign in to comment.