diff --git a/cellpose/__main__.py b/cellpose/__main__.py index 7636a1bf..072e931f 100644 --- a/cellpose/__main__.py +++ b/cellpose/__main__.py @@ -96,6 +96,8 @@ def main(): denoise.model_path(restore_type) except Exception as e: raise ValueError("restore_type invalid") + if args.train or args.train_size: + raise ValueError("restore_type cannot be used with training on CLI yet") model_type = None if pretrained_model and not os.path.exists(pretrained_model): @@ -134,16 +136,24 @@ def main(): ">>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s" % (nimg, cstr0[channels[0]], cstr1[channels[1]])) - # handle built-in model exceptions; bacterial ones get no size model - if builtin_size and restore_type is not None: + # handle built-in model exceptions + if builtin_size and restore_type is None: model = models.Cellpose(gpu=gpu, device=device, model_type=model_type) else: + builtin_size = False if args.all_channels: channels = None pretrained_model = None if model_type is not None else pretrained_model - model = models.CellposeModel(gpu=gpu, device=device, - pretrained_model=pretrained_model, - model_type=model_type) + if restore_type is None: + model = models.CellposeModel(gpu=gpu, device=device, + pretrained_model=pretrained_model, + model_type=model_type) + else: + model = denoise.CellposeDenoiseModel(gpu=gpu, device=device, + pretrained_model=pretrained_model, + model_type=model_type, + restore_type=restore_type, + chan2_restore=args.chan2_restore) # handle diameters if args.diameter == 0: @@ -151,9 +161,14 @@ def main(): diameter = None logger.info(">>>> estimating diameter for each image") else: - logger.info( - ">>>> not using cyto, cyto2, or nuclei model, cannot auto-estimate diameter" - ) + if restore_type is None: + logger.info( + ">>>> not using cyto3, cyto, cyto2, or nuclei model, cannot auto-estimate diameter" + ) + else: + logger.info( + ">>>> cannot auto-estimate diameter for image restoration" + ) diameter = model.diam_labels logger.info(">>>> using diameter %0.3f for all images" % diameter) else: @@ -175,17 +190,26 @@ def main(): channel_axis=args.channel_axis, z_axis=args.z_axis, anisotropy=args.anisotropy, niter=args.niter) masks, flows = out[:2] - if len(out) > 3: + if len(out) > 3 and restore_type is None: diams = out[-1] else: diams = diameter + ratio = 1. + if restore_type is not None: + imgs_dn = out[-1] + ratio = diams / model.dn.diam_mean if "upsample" in restore_type else 1. + diams = model.dn.diam_mean if "upsample" in restore_type and model.dn.diam_mean > diams else diams + else: + imgs_dn = None if args.exclude_on_edges: masks = utils.remove_edge_masks(masks) if not args.no_npy: - io.masks_flows_to_seg(image, masks, flows, image_name, - channels=channels, diams=diams) + io.masks_flows_to_seg(image, masks, flows, image_name, imgs_restore=imgs_dn, + channels=channels, diams=diams, + restore_type=restore_type, ratio=1.) if saving_something: - io.save_masks(image, masks, flows, image_name, png=args.save_png, + io.save_masks(image, masks, flows, image_name, + png=args.save_png, tif=args.save_tif, save_flows=args.save_flows, save_outlines=args.save_outlines, dir_above=args.dir_above, savedir=args.savedir, diff --git a/cellpose/denoise.py b/cellpose/denoise.py index 76c536ce..e2638564 100644 --- a/cellpose/denoise.py +++ b/cellpose/denoise.py @@ -474,9 +474,9 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None, def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1, - resample=True, invert=False, flow_threshold=0.4, cellprob_threshold=0.0, - do_3D=False, anisotropy=None, stitch_threshold=0.0, min_size=15, - niter=None, interp=True): + augment=False, resample=True, invert=False, flow_threshold=0.4, + cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0, + min_size=15, niter=None, interp=True): """ Restore array or list of images using the image restoration model, and then segment. @@ -510,6 +510,7 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, if diameter is None, set to diam_mean or diam_train if available. Defaults to None. tile (bool, optional): tiles image to ensure GPU/CPU memory usage limited (recommended). Defaults to True. tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1. + augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False. resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True. invert (bool, optional): invert image pixel intensity before running network. Defaults to False. flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4. @@ -549,7 +550,8 @@ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter masks, flows, styles = self.cp.eval(img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1, normalize=normalize_params, rescale=rescale, diameter=diameter, - tile=tile, tile_overlap=tile_overlap, resample=resample, invert=invert, + tile=tile, tile_overlap=tile_overlap, augment=augment, + resample=resample, invert=invert, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy, stitch_threshold=stitch_threshold, min_size=min_size, niter=niter, interp=interp) @@ -644,7 +646,7 @@ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None, ) if chan2 and builtin: chan2_path = model_path(os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei") - print(f"loading model for chan2: {os.path.split(str(chan2_path)[-1])}") + print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}") self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn, max_pool=True, diam_mean=17.).to(self.device) diff --git a/cellpose/dynamics.py b/cellpose/dynamics.py index ef32deb2..acdf1492 100644 --- a/cellpose/dynamics.py +++ b/cellpose/dynamics.py @@ -142,10 +142,10 @@ def masks_to_flows_gpu(masks, device=None, niter=None): masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels. Returns: - mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1]. - If masks are 3D, flows in Z = mu[0]. - mu_c (float, 2D or 3D array): For each pixel, the distance to the center of the mask - in which it resides. + tuple containing + - mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1]. + If masks are 3D, flows in Z = mu[0]. + - meds_p (float, 2D or 3D array): cell centers """ if device is None: device = torch.device("cuda") @@ -200,8 +200,9 @@ def masks_to_flows_gpu_3d(masks, device=None): masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels. Returns: - mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1]. If masks are 3D, flows in Z = mu[0]. - mu_c (float, 2D or 3D array): For each pixel, the distance to the center of the mask in which it resides. + tuple containing + - mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1]. If masks are 3D, flows in Z = mu[0]. + - mu_c (float, 2D or 3D array): zeros """ if device is None: device = torch.device("cuda") @@ -276,8 +277,10 @@ def masks_to_flows_cpu(masks, device=None, niter=None): masks (int, 2D or 3D array): Labelled masks 0=NO masks; 1,2,...=mask labels Returns: - mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1]. If masks are 3D, flows in Z = mu[0]. - mu_c (float, 2D or 3D array): For each pixel, the distance to the center of the mask in which it resides + tuple containing + - mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1]. + If masks are 3D, flows in Z = mu[0]. + - meds (float, 2D or 3D array): cell centers """ Ly, Lx = masks.shape mu = np.zeros((2, Ly, Lx), np.float64) @@ -323,8 +326,8 @@ def masks_to_flows(masks, device=None, niter=None): masks (int, 2D or 3D array): Labelled masks 0=NO masks; 1,2,...=mask labels Returns: - mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1]. If masks are 3D, flows in Z = mu[0]. - mu_c (float, 2D or 3D array): For each pixel, the distance to the center of the mask in which it resides + mu (float, 3D or 4D array): Flows in Y = mu[-2], flows in X = mu[-1]. + If masks are 3D, flows in Z = mu[0]. """ if masks.max() == 0: dynamics_logger.warning("empty masks!") @@ -360,32 +363,29 @@ def masks_to_flows(masks, device=None, niter=None): def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None): - """ convert labels (list of masks or flows) to flows for training model + """Converts labels (list of masks or flows) to flows for training model. - if files is not None, flows are saved to files to be reused - - Parameters - -------------- - - labels: list of ND-arrays - labels[k] can be 2D or 3D, if [3 x Ly x Lx] then it is assumed that flows were precomputed. - Otherwise labels[k][0] or labels[k] (if 2D) is used to create flows and cell probabilities. - - Returns - -------------- - - flows: list of [4 x Ly x Lx] arrays - flows[k][0] is labels[k], flows[k][1] is cell distance transform, flows[k][2] is Y flow, - flows[k][3] is X flow, and flows[k][4] is heat distribution + Args: + labels (list of ND-arrays): The labels to convert. labels[k] can be 2D or 3D. If [3 x Ly x Lx], + it is assumed that flows were precomputed. Otherwise, labels[k][0] or labels[k] (if 2D) + is used to create flows and cell probabilities. + files (list of str, optional): The files to save the flows to. If provided, flows are saved to + files to be reused. Defaults to None. + device (str, optional): The device to use for computation. Defaults to None. + redo_flows (bool, optional): Whether to recompute the flows. Defaults to False. + niter (int, optional): The number of iterations for computing flows. Defaults to None. + Returns: + list of [4 x Ly x Lx] arrays: The flows for training the model. flows[k][0] is labels[k], + flows[k][1] is cell distance transform, flows[k][2] is Y flow, flows[k][3] is X flow, + and flows[k][4] is heat distribution. """ nimg = len(labels) if labels[0].ndim < 3: labels = [labels[n][np.newaxis, :, :] for n in range(nimg)] - if labels[0].shape[ - 0] == 1 or labels[0].ndim < 3 or redo_flows: # flows need to be recomputed - + # flows need to be recomputed + if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows: dynamics_logger.info("computing flows for labels") # compute flows; labels are fixed here to be unique, so they need to be passed back @@ -418,17 +418,16 @@ def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=Non ], cache=True) def map_coordinates(I, yc, xc, Y): """ - bilinear interpolation of image "I" in-place with ycoordinates yc and xcoordinates xc to Y + Bilinear interpolation of image "I" in-place with y-coordinates yc and x-coordinates xc to Y. - Parameters - ------------- - I : C x Ly x Lx - yc : ni - new y coordinates - xc : ni - new x coordinates - Y : C x ni - I sampled at (yc,xc) + Args: + I (numpy.ndarray): Input image of shape (C, Ly, Lx). + yc (numpy.ndarray): New y-coordinates. + xc (numpy.ndarray): New x-coordinates. + Y (numpy.ndarray): Output array of shape (C, ni). + + Returns: + None """ C, Ly, Lx = I.shape yc_floor = yc.astype(np.int32) @@ -450,6 +449,24 @@ def map_coordinates(I, yc, xc, Y): def steps2D_interp(p, dP, niter, device=None): + """ Run dynamics of pixels to recover masks in 2D, with interpolation between pixel values. + + Euler integration of dynamics dP for niter steps. + + Args: + p (numpy.ndarray): Array of shape (n_points, 2) representing the initial pixel locations. + dP (numpy.ndarray): Array of shape (2, Ly, Lx) representing the flow field. + niter (int): Number of iterations to perform. + device (torch.device, optional): Device to use for computation. Defaults to None. + + Returns: + numpy.ndarray: Array of shape (n_points, 2) representing the final pixel locations. + + Raises: + None + + """ + shape = dP.shape[1:] if device is not None and device.type == "cuda": shape = np.array(shape)[[ @@ -494,31 +511,18 @@ def steps2D_interp(p, dP, niter, device=None): @njit("(float32[:,:,:,:],float32[:,:,:,:], int32[:,:], int32)", nogil=True) def steps3D(p, dP, inds, niter): - """ run dynamics of pixels to recover masks in 3D - - Euler integration of dynamics dP for niter steps - - Parameters - ---------------- - - p: float32, 4D array - pixel locations [axis x Lz x Ly x Lx] (start at initial meshgrid) - - dP: float32, 4D array - flows [axis x Lz x Ly x Lx] - - inds: int32, 2D array - non-zero pixels to run dynamics on [npixels x 3] - - niter: int32 - number of iterations of dynamics to run + """ Run dynamics of pixels to recover masks in 3D. - Returns - --------------- + Euler integration of dynamics dP for niter steps. - p: float32, 4D array - final locations of each pixel after dynamics + Args: + p (np.ndarray): Pixel locations [axis x Lz x Ly x Lx] (start at initial meshgrid). + dP (np.ndarray): Flows [axis x Lz x Ly x Lx]. + inds (np.ndarray): Non-zero pixels to run dynamics on [npixels x 3]. + niter (int): Number of iterations of dynamics to run. + Returns: + np.ndarray: Final locations of each pixel after dynamics. """ shape = p.shape[1:] for t in range(niter): @@ -536,31 +540,18 @@ def steps3D(p, dP, inds, niter): @njit("(float32[:,:,:], float32[:,:,:], int32[:,:], int32)", nogil=True) def steps2D(p, dP, inds, niter): - """ run dynamics of pixels to recover masks in 2D - - Euler integration of dynamics dP for niter steps - - Parameters - ---------------- - - p: float32, 3D array - pixel locations [axis x Ly x Lx] (start at initial meshgrid) + """Run dynamics of pixels to recover masks in 2D. - dP: float32, 3D array - flows [axis x Ly x Lx] + Euler integration of dynamics dP for niter steps. - inds: int32, 2D array - non-zero pixels to run dynamics on [npixels x 2] - - niter: int32 - number of iterations of dynamics to run - - Returns - --------------- - - p: float32, 3D array - final locations of each pixel after dynamics + Args: + p (np.ndarray): Pixel locations [axis x Ly x Lx] (start at initial meshgrid). + dP (np.ndarray): Flows [axis x Ly x Lx]. + inds (np.ndarray): Non-zero pixels to run dynamics on [npixels x 2]. + niter (int): Number of iterations of dynamics to run. + Returns: + np.ndarray: Final locations of each pixel after dynamics. """ shape = p.shape[1:] for t in range(niter): @@ -576,40 +567,22 @@ def steps2D(p, dP, inds, niter): def follow_flows(dP, mask=None, niter=200, interp=True, device=None): - """ define pixels and run dynamics to recover masks in 2D - - Pixels are meshgrid. Only pixels with non-zero cell-probability - are used (as defined by inds) - - Parameters - ---------------- - - dP: float32, 3D or 4D array - flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx] - - mask: (optional, default None) - pixel mask to seed masks. Useful when flows have low magnitudes. - - niter: int (optional, default 200) - number of iterations of dynamics to run - - interp: bool (optional, default True) - interpolate during 2D dynamics (not available in 3D) - (in previous versions + paper it was False) - - use_gpu: bool (optional, default False) - use GPU to run interpolated dynamics (faster than CPU) - - - Returns - --------------- + """ Run dynamics to recover masks in 2D or 3D. - p: float32, 3D or 4D array - final locations of each pixel after dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx] + Pixels are represented as a meshgrid. Only pixels with non-zero cell-probability + are used (as defined by inds). - inds: int32, 3D or 4D array - indices of pixels used for dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx] + Args: + dP (np.ndarray): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. + mask (np.ndarray, optional): Pixel mask to seed masks. Useful when flows have low magnitudes. + niter (int, optional): Number of iterations of dynamics to run. Default is 200. + interp (bool, optional): Interpolate during 2D dynamics (not available in 3D). Default is True. + use_gpu (bool, optional): Use GPU to run interpolated dynamics (faster than CPU). Default is False. + Returns: + tuple containing: + - p (np.ndarray): Final locations of each pixel after dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. + - inds (np.ndarray): Indices of pixels used for dynamics; [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. """ shape = np.array(dP.shape[1:]).astype(np.int32) niter = np.uint32(niter) @@ -640,33 +613,22 @@ def follow_flows(dP, mask=None, niter=200, interp=True, device=None): def remove_bad_flow_masks(masks, flows, threshold=0.4, device=None): - """ remove masks which have inconsistent flows - + """Remove masks which have inconsistent flows. + Uses metrics.flow_error to compute flows from predicted masks - and compare flows to predicted flows from network. Discards + and compare flows to predicted flows from the network. Discards masks with flow errors greater than the threshold. - Parameters - ---------------- - - masks: int, 2D or 3D array - labelled masks, 0=NO masks; 1,2,...=mask labels, - size [Ly x Lx] or [Lz x Ly x Lx] - - flows: float, 3D or 4D array - flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx] - - threshold: float (optional, default 0.4) - masks with flow error greater than threshold are discarded. - - Returns - --------------- + Args: + masks (int, 2D or 3D array): Labelled masks, 0=NO masks; 1,2,...=mask labels, + size [Ly x Lx] or [Lz x Ly x Lx]. + flows (float, 3D or 4D array): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. + threshold (float, optional): Masks with flow error greater than threshold are discarded. + Default is 0.4. - masks: int, 2D or 3D array - masks with inconsistent flow masks removed, - 0=NO masks; 1,2,...=mask labels, - size [Ly x Lx] or [Lz x Ly x Lx] - + Returns: + masks (int, 2D or 3D array): Masks with inconsistent flow masks removed, + 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx]. """ device0 = device if masks.size > 10000 * 10000 and (device is not None and device.type == "cuda"): @@ -699,38 +661,24 @@ def mem_info(): def get_masks(p, iscell=None, rpad=20): - """ create masks using pixel convergence after running dynamics - + """Create masks using pixel convergence after running dynamics. + Makes a histogram of final pixel locations p, initializes masks at peaks of histogram and extends the masks from the peaks so that they include all pixels with more than 2 final pixels p. Discards masks with flow errors greater than the threshold. - Parameters - ---------------- - p: float32, 3D or 4D array - final locations of each pixel after dynamics, - size [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. - iscell: bool, 2D or 3D array - if iscell is not None, set pixels that are - iscell False to stay in their original location. - rpad: int (optional, default 20) - histogram edge padding - threshold: float (optional, default 0.4) - masks with flow error greater than threshold are discarded - (if flows is not None) - flows: float, 3D or 4D array (optional, default None) - flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. If flows - is not None, then masks with inconsistent flows are removed using - `remove_bad_flow_masks`. - Returns - --------------- - M0: int, 2D or 3D array - masks with inconsistent flow masks removed, - 0=NO masks; 1,2,...=mask labels, - size [Ly x Lx] or [Lz x Ly x Lx] - - """ + Parameters: + p (float32, 3D or 4D array): Final locations of each pixel after dynamics, + size [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. + iscell (bool, 2D or 3D array): If iscell is not None, set pixels that are + iscell False to stay in their original location. + rpad (int, optional): Histogram edge padding. Default is 20. + + Returns: + M0 (int, 2D or 3D array): Masks with inconsistent flow masks removed, + 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx]. + """ pflows = [] edges = [] shape0 = p.shape[1:] @@ -807,11 +755,27 @@ def get_masks(p, iscell=None, rpad=20): M0 = np.reshape(M0, shape0) return M0 - def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, flow_threshold=0.4, interp=True, do_3D=False, min_size=15, resize=None, device=None): - """ compute masks using dynamics from dP, cellprob, and boundary """ + """Compute masks using dynamics from dP and cellprob, and resizes masks if resize is not None. + + Args: + dP (numpy.ndarray): The dynamics flow field array. + cellprob (numpy.ndarray): The cell probability array. + p (numpy.ndarray, optional): The pixels on which to run dynamics. Defaults to None + niter (int, optional): The number of iterations for mask computation. Defaults to 200. + cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0. + flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4. + interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True. + do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False. + min_size (int, optional): The minimum size of the masks. Defaults to 15. + resize (tuple, optional): The desired size for resizing the masks. Defaults to None. + device (str, optional): The torch device to use for computation. Defaults to None. + + Returns: + tuple: A tuple containing the computed masks and the final pixel locations. + """ mask, p = compute_masks(dP, cellprob, p=p, niter=niter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, interp=interp, do_3D=do_3D, @@ -831,8 +795,23 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, flow_threshold=0.4, interp=True, do_3D=False, min_size=15, device=None): - """ compute masks using dynamics from dP, cellprob, and boundary """ + """Compute masks using dynamics from dP and cellprob. + Args: + dP (numpy.ndarray): The dynamics flow field array. + cellprob (numpy.ndarray): The cell probability array. + p (numpy.ndarray, optional): The pixels on which to run dynamics. Defaults to None + niter (int, optional): The number of iterations for mask computation. Defaults to 200. + cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0. + flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4. + interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True. + do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False. + min_size (int, optional): The minimum size of the masks. Defaults to 15. + device (str, optional): The torch device to use for computation. Defaults to None. + + Returns: + tuple: A tuple containing the computed masks and the final pixel locations. + """ cp_mask = cellprob > cellprob_threshold if np.any(cp_mask): #mask at this point is a cell cluster binary map, not labels @@ -877,9 +856,6 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, p = np.zeros((len(shape), *shape), np.uint16) return mask, p - # moving the cleanup to the end helps avoid some bugs arising from scaling... - # maybe better would be to rescale the min_size and hole_size parameters to do the - # cleanup at the prediction scale, or switch depending on which one is bigger... mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size) if mask.dtype == np.uint32: diff --git a/cellpose/gui/gui.py b/cellpose/gui/gui.py index efb70bc4..7dcfa20f 100644 --- a/cellpose/gui/gui.py +++ b/cellpose/gui/gui.py @@ -2051,7 +2051,14 @@ def compute_saturation(self, return_img=False): self.saturation[-1].append([x01, x99]) else: for n in range(self.NZ): - self.saturation[-1].append([0, 255]) + self.saturation[-1].append([0, 255.]) + # if only 2 restore channels, add blue + if len(self.saturation) < 3: + for i in range(3 - len(self.saturation)): + self.saturation.append([]) + for n in range(self.NZ): + self.saturation[-1].append([0, 255.]) + print(self.saturation[2][self.currentZ]) if invert: img_norm = 255. - img_norm @@ -2407,8 +2414,6 @@ def compute_segmentation(self, custom=False, model_name=None): channels = self.get_channels() if self.restore is not None and self.restore != "filter": data = self.stack_filtered.copy().squeeze() - if channels[1] != 0: - channels = [1, 2] # assuming aligned with denoising else: data = self.stack.copy().squeeze() flow_threshold, cellprob_threshold = self.get_thresholds() diff --git a/cellpose/gui/io.py b/cellpose/gui/io.py index 8e7a1885..fab19842 100644 --- a/cellpose/gui/io.py +++ b/cellpose/gui/io.py @@ -299,6 +299,9 @@ def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False) else: parent.filename = image_file + parent.restore = None + parent.ratio = 1. + if "normalize_params" in dat: parent.restore = None if "restore" not in dat else dat["restore"] print(f"GUI_INFO: restore: {parent.restore}") @@ -306,7 +309,31 @@ def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False) parent.set_restore_button() if "img_restore" in dat: - parent.stack_filtered = dat["img_restore"] + img = dat["img_restore"] + img_min = img.min() + img_max = img.max() + parent.stack_filtered = img.astype("float32") + parent.stack_filtered -= img_min + if img_max > img_min + 1e-3: + parent.stack_filtered /= (img_max - img_min) + parent.stack_filtered *= 255 + if parent.stack_filtered.ndim < 4: + parent.stack_filtered = parent.stack_filtered[np.newaxis,...] + if parent.stack_filtered.ndim < 4: + parent.stack_filtered = parent.stack_filtered[...,np.newaxis] + shape = parent.stack_filtered.shape + if shape[-1] == 2: + if "chan_choose" in dat: + channels = np.array(dat["chan_choose"]) - 1 + img = np.zeros((*shape[:-1], 3), dtype="float32") + img[..., channels] = parent.stack_filtered + parent.stack_filtered = img + else: + parent.stack_filtered = np.concatenate( + (parent.stack_filtered, np.zeros((*shape[:-1], 1), dtype="float32")), axis=-1) + elif shape[-1] > 3: + parent.stack_filtered = parent.stack_filtered[..., :3] + parent.restore = dat["restore"] parent.ViewDropDown.model().item(parent.ViewDropDown.count() - 1).setEnabled(True) @@ -314,8 +341,11 @@ def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False) if parent.restore and "upsample" in parent.restore: print(parent.stack_filtered.shape, image.shape) parent.ratio = dat["ratio"] + + parent.set_restore_button() _initialize_images(parent, image, load_3D=load_3D) + print(parent.stack.shape, parent.stack_filtered.shape) if "chan_choose" in dat: parent.ChannelChoose[0].setCurrentIndex(dat["chan_choose"][0]) parent.ChannelChoose[1].setCurrentIndex(dat["chan_choose"][1]) diff --git a/cellpose/io.py b/cellpose/io.py index aae1266c..2abcd6b4 100644 --- a/cellpose/io.py +++ b/cellpose/io.py @@ -469,7 +469,8 @@ def load_train_test_data(train_dir, test_dir=None, image_filter=None, return images, labels, image_names, test_images, test_labels, test_image_names -def masks_flows_to_seg(images, masks, flows, file_names, diams=30., channels=None): +def masks_flows_to_seg(images, masks, flows, file_names, diams=30., channels=None, + imgs_restore=None, restore_type=None, ratio=1.): """Save output of model eval to be loaded in GUI. Can be list output (run on multiple images) or single output (run on single image). @@ -494,13 +495,16 @@ def masks_flows_to_seg(images, masks, flows, file_names, diams=30., channels=Non if isinstance(masks, list): if not isinstance(diams, (list, np.ndarray)): diams = diams * np.ones(len(masks), np.float32) + if imgs_restore is None: + imgs_restore = [] * len(masks) for k, [image, mask, flow, diam, - file_name] in enumerate(zip(images, masks, flows, diams, file_names)): + file_name, img_restore] in enumerate(zip(images, masks, flows, diams, file_names, imgs_restore)): channels_img = channels if channels_img is not None and len(channels) > 2: channels_img = channels[k] masks_flows_to_seg(image, mask, flow, file_name, diams=diam, - channels=channels_img) + channels=channels_img, imgs_restore=img_restore, + restore_type=restore_type, ratio=ratio) return if len(channels) == 1: @@ -531,53 +535,30 @@ def masks_flows_to_seg(images, masks, flows, file_names, diams=30., channels=Non flowi.append(np.concatenate((flows[1], flows[2][np.newaxis, ...]), axis=0)) outlines = masks * utils.masks_to_outlines(masks) base = os.path.splitext(file_names)[0] - if masks.ndim == 3: - np.save( - base + "_seg.npy", { - "outlines": - outlines.astype(np.uint16) - if outlines.max() < 2**16 - 1 else outlines.astype(np.uint32), - "masks": - masks.astype(np.uint16) - if outlines.max() < 2**16 - 1 else masks.astype(np.uint32), - "chan_choose": - channels, - "img": - images, - "ismanual": - np.zeros(masks.max(), bool), - "filename": - file_names, - "flows": - flowi, - "est_diam": - diams - }) - else: - if images.shape[0] < 8: - np.transpose(images, (1, 2, 0)) - np.save( - base + "_seg.npy", { - "img": - images, - "outlines": - outlines.astype(np.uint16) - if outlines.max() < 2**16 - 1 else outlines.astype(np.uint32), - "masks": - masks.astype(np.uint16) - if masks.max() < 2**16 - 1 else masks.astype(np.uint32), - "chan_choose": - channels, - "ismanual": - np.zeros(masks.max(), bool), - "filename": - file_names, - "flows": - flowi, - "est_diam": - diams - }) + dat = {"outlines": + outlines.astype(np.uint16) + if outlines.max() < 2**16 - 1 else outlines.astype(np.uint32), + "masks": + masks.astype(np.uint16) + if outlines.max() < 2**16 - 1 else masks.astype(np.uint32), + "chan_choose": + channels, + "ismanual": + np.zeros(masks.max(), bool), + "filename": + file_names, + "flows": + flowi, + "diameter": + diams + } + if restore_type is not None and imgs_restore is not None: + dat["restore"] = restore_type + dat["ratio"] = ratio + dat["img_restore"] = imgs_restore + + np.save(base + "_seg.npy", dat) def save_to_png(images, masks, flows, file_names): """ deprecated (runs io.save_masks with png=True) diff --git a/tests/test_denoise.py b/tests/test_denoise.py index a011fbc1..5b32d453 100644 --- a/tests/test_denoise.py +++ b/tests/test_denoise.py @@ -34,7 +34,7 @@ def test_class_2D(data_dir, image_names): img_restore = model.eval(img, diameter=diams[m], channels=[chan[m], chan2[m]]) assert img_restore.shape == shapes[m] - io.imsave(str(data_dir.joinpath("2D").joinpath(f"rgb_2D_{model_type}.tif")), img_restore) + io.imsave(str(data_dir.joinpath("2D").joinpath(f"gray_2D_{model_type}.tif")), img_restore) clear_output(data_dir, image_names) @@ -60,10 +60,10 @@ def test_dn_cp_class_2D(data_dir, image_names): def test_cli_2D(data_dir, image_names): clear_output(data_dir, image_names) model_types = ["denoise_cyto3"] - chan = [1] - chan2 = [2] + chan = [2] + chan2 = [1] for m, model_type in enumerate(model_types): - cmd = "python -m cellpose --dir %s --pretrained_model %s --restore_type %s --chan %d --chan2 %d --chan2_denoise --diameter 30 --no_interp --save_npy" % ( + cmd = "python -m cellpose --dir %s --pretrained_model %s --restore_type %s --chan %d --chan2 %d --chan2_restore --diameter 30" % ( str(data_dir.joinpath("2D")), "cyto3", model_type, chan[m], chan2[m]) try: cmd_stdout = check_output(cmd, stderr=STDOUT, shell=True).decode() @@ -71,6 +71,5 @@ def test_cli_2D(data_dir, image_names): except Exception as e: print(e) raise ValueError(e) - compare_masks(data_dir, image_names, "2D", model_type) clear_output(data_dir, image_names) \ No newline at end of file