diff --git a/python/cucim/src/cucim/skimage/_shared/utils.py b/python/cucim/src/cucim/skimage/_shared/utils.py index b4371ff1b..51b25e198 100644 --- a/python/cucim/src/cucim/skimage/_shared/utils.py +++ b/python/cucim/src/cucim/skimage/_shared/utils.py @@ -9,17 +9,48 @@ from ._warnings import all_warnings, warn # noqa +__all__ = ['deprecate_func', 'get_bound_method_class', 'all_warnings', + 'safe_as_int', 'check_shape_equality', 'check_nD', 'warn', + 'reshape_nd', 'identity', 'slice_at_axis'] -class skimage_deprecation(Warning): - """Create our own deprecation class, since Python >= 2.7 - silences deprecations by default. +def _get_stack_rank(func): + """Return function rank in the call stack.""" + if _is_wrapped(func): + return 1 + _get_stack_rank(func.__wrapped__) + else: + return 0 + + +def _is_wrapped(func): + return "__wrapped__" in dir(func) + + +def _get_stack_length(func): + """Return function call stack length.""" + return _get_stack_rank(func.__globals__.get(func.__name__, func)) + + +class _DecoratorBaseClass: + """Used to manage decorators' warnings stacklevel. + + The `_stack_length` class variable is used to store the number of + times a function is wrapped by a decorator. + + Let `stack_length` be the total number of times a decorated + function is wrapped, and `stack_rank` be the rank of the decorator + in the decorators stack. The stacklevel of a warning is then + `stacklevel = 1 + stack_length - stack_rank`. """ - pass + _stack_length = {} + + def get_stack_length(self, func): + return self._stack_length.get(func.__name__, + _get_stack_length(func)) -class change_default_value: +class change_default_value(_DecoratorBaseClass): """Decorator for changing the default value of an argument. Parameters @@ -48,6 +79,7 @@ def __call__(self, func): arg_idx = list(parameters.keys()).index(self.arg_name) old_value = parameters[self.arg_name].default + stack_rank = _get_stack_rank(func) if self.warning_msg is None: self.warning_msg = ( f"The new recommended value for {self.arg_name} is " @@ -59,15 +91,17 @@ def __call__(self, func): @functools.wraps(func) def fixed_func(*args, **kwargs): + stacklevel = 1 + self.get_stack_length(func) - stack_rank if len(args) < arg_idx + 1 and self.arg_name not in kwargs.keys(): # warn that arg_name default value changed: - warnings.warn(self.warning_msg, FutureWarning, stacklevel=2) + warnings.warn(self.warning_msg, FutureWarning, + stacklevel=stacklevel) return func(*args, **kwargs) return fixed_func -class remove_arg: +class remove_arg(_DecoratorBaseClass): """Decorator to remove an argument from function's signature. Parameters @@ -100,21 +134,25 @@ def __call__(self, func): if self.help_msg is not None: warning_msg += f" {self.help_msg}" + stack_rank = _get_stack_rank(func) + @functools.wraps(func) def fixed_func(*args, **kwargs): + stacklevel = 1 + self.get_stack_length(func) - stack_rank if len(args) > arg_idx or self.arg_name in kwargs.keys(): # warn that arg_name is deprecated - warnings.warn(warning_msg, FutureWarning, stacklevel=2) + warnings.warn(warning_msg, FutureWarning, + stacklevel=stacklevel) return func(*args, **kwargs) return fixed_func -def docstring_add_deprecated(func, kwarg_mapping, deprecated_version): +def _docstring_add_deprecated(func, kwarg_mapping, deprecated_version): """Add deprecated kwarg(s) to the "Other Params" section of a docstring. Parameters - --------- + ---------- func : function The function whose docstring we wish to update. 
kwarg_mapping : dict @@ -176,7 +214,7 @@ def docstring_add_deprecated(func, kwarg_mapping, deprecated_version): return final_docstring -class deprecate_kwarg: +class deprecate_kwarg(_DecoratorBaseClass): """Decorator ensuring backward compatibility when argument names are modified in a function definition. @@ -213,14 +251,18 @@ def __init__(self, kwarg_mapping, deprecated_version, warning_msg=None, def __call__(self, func): + stack_rank = _get_stack_rank(func) + @functools.wraps(func) def fixed_func(*args, **kwargs): + stacklevel = 1 + self.get_stack_length(func) - stack_rank for old_arg, new_arg in self.kwarg_mapping.items(): if old_arg in kwargs: # warn that the function interface has changed: warnings.warn(self.warning_msg.format( old_arg=old_arg, func_name=func.__name__, - new_arg=new_arg), FutureWarning, stacklevel=2) + new_arg=new_arg), FutureWarning, + stacklevel=stacklevel) # Substitute new_arg to old_arg kwargs[new_arg] = kwargs.pop(old_arg) @@ -228,13 +270,13 @@ def fixed_func(*args, **kwargs): return func(*args, **kwargs) if func.__doc__ is not None: - newdoc = docstring_add_deprecated(func, self.kwarg_mapping, - self.deprecated_version) + newdoc = _docstring_add_deprecated(func, self.kwarg_mapping, + self.deprecated_version) fixed_func.__doc__ = newdoc return fixed_func -class channel_as_last_axis(): +class channel_as_last_axis: """Decorator for automatically making channels axis last for all arrays. This decorator reorders axes for compatibility with functions that only @@ -311,53 +353,66 @@ def fixed_func(*args, **kwargs): return fixed_func -class deprecated(object): - """Decorator to mark deprecated functions with warning. +class deprecate_func(_DecoratorBaseClass): + """Decorate a deprecated function and warn when it is called. Adapted from . Parameters ---------- - alt_func : str - If given, tell user what function to use instead. - behavior : {'warn', 'raise'} - Behavior during call to deprecated function: 'warn' = warn user that - function is deprecated; 'raise' = raise error. + deprecated_version : str + The package version when the deprecation was introduced. removed_version : str The package version in which the deprecated function will be removed. + hint : str, optional + A hint on how to address this deprecation, + e.g., "Use `skimage.submodule.alternative_func` instead." + Examples + -------- + >>> @deprecate_func( + ... deprecated_version="1.0.0", + ... removed_version="1.2.0", + ... hint="Use `bar` instead." + ... ) + ... def foo(): + ... pass + Calling ``foo`` will warn with:: + FutureWarning: `foo` is deprecated since version 1.0.0 + and will be removed in version 1.2.0. Use `bar` instead. """ - def __init__(self, alt_func=None, behavior='warn', removed_version=None): - self.alt_func = alt_func - self.behavior = behavior + def __init__(self, *, deprecated_version, removed_version=None, hint=None): + self.deprecated_version = deprecated_version self.removed_version = removed_version + self.hint = hint def __call__(self, func): - alt_msg = '' - if self.alt_func is not None: - alt_msg = f' Use ``{self.alt_func}`` instead.' - rmv_msg = '' - if self.removed_version is not None: - rmv_msg = f' and will be removed in version {self.removed_version}' + message = ( + f"`{func.__name__}` is deprecated since version " + f"{self.deprecated_version}" + ) + if self.removed_version: + message += ( + f" and will be removed in version {self.removed_version}." + ) + if self.hint: + message += f" {self.hint.rstrip('.')}." 
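The stacklevel arithmetic managed by ``_DecoratorBaseClass`` above can be seen in isolation with a small standalone sketch (two hypothetical decorators, not the ones in this diff): the innermost wrapper has ``stack_rank`` 0 and the function is wrapped twice, so a warning issued there needs ``stacklevel = 1 + 2 - 0 = 3`` to point at the user's call site::

    import functools
    import warnings

    def outer(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper

    def inner(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # 1 (this frame) + 1 (outer's wrapper) + 1 (caller) = stacklevel 3
            warnings.warn("deprecated", FutureWarning, stacklevel=3)
            return func(*args, **kwargs)
        return wrapper

    @outer
    @inner
    def f():
        return None

    f()  # the FutureWarning is attributed to this line, not to a wrapper frame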
- msg = f'Function ``{func.__name__}`` is deprecated{rmv_msg}.{alt_msg}' + stack_rank = _get_stack_rank(func) @functools.wraps(func) def wrapped(*args, **kwargs): - if self.behavior == 'warn': - func_code = func.__code__ - warnings.simplefilter('always', skimage_deprecation) - warnings.warn_explicit(msg, - category=skimage_deprecation, - filename=func_code.co_filename, - lineno=func_code.co_firstlineno + 1) - elif self.behavior == 'raise': - raise skimage_deprecation(msg) + stacklevel = 1 + self.get_stack_length(func) - stack_rank + warnings.warn( + message, + category=FutureWarning, + stacklevel=stacklevel + ) return func(*args, **kwargs) # modify doc string to display deprecation warning - doc = '**Deprecated function**.' + alt_msg + doc = f'**Deprecated:** {message}' if wrapped.__doc__ is None: wrapped.__doc__ = doc else: diff --git a/python/cucim/src/cucim/skimage/color/__init__.py b/python/cucim/src/cucim/skimage/color/__init__.py index 390668712..ef9049f34 100644 --- a/python/cucim/src/cucim/skimage/color/__init__.py +++ b/python/cucim/src/cucim/skimage/color/__init__.py @@ -9,12 +9,14 @@ rgb_from_bex, rgb_from_bpx, rgb_from_bro, rgb_from_fgx, rgb_from_gdx, rgb_from_hax, rgb_from_hdx, rgb_from_hed, rgb_from_hpx, rgb_from_rbd, rgba2rgb, rgbcie2rgb, - separate_stains, xyz2lab, xyz2luv, xyz2rgb, ycbcr2rgb, - ydbdr2rgb, yiq2rgb, ypbpr2rgb, yuv2rgb) + separate_stains, xyz2lab, xyz2luv, xyz2rgb, + xyz_tristimulus_values, ycbcr2rgb, ydbdr2rgb, yiq2rgb, + ypbpr2rgb, yuv2rgb) from .colorlabel import color_dict, label2rgb from .delta_e import deltaE_cie76, deltaE_ciede94, deltaE_ciede2000, deltaE_cmc __all__ = ['convert_colorspace', + 'xyz_tristimulus_values', 'rgba2rgb', 'rgb2hsv', 'hsv2rgb', diff --git a/python/cucim/src/cucim/skimage/color/colorconv.py b/python/cucim/src/cucim/skimage/color/colorconv.py index ece630492..f7a1e3678 100644 --- a/python/cucim/src/cucim/skimage/color/colorconv.py +++ b/python/cucim/src/cucim/skimage/color/colorconv.py @@ -56,7 +56,7 @@ from scipy import linalg from .._shared.utils import (_supported_float_type, channel_as_last_axis, - identity) + deprecate_func, identity) from ..util import dtype, dtype_limits @@ -546,7 +546,7 @@ def hsv2rgb(hsv, *, channel_axis=-1): # ---------- # .. [1] https://en.wikipedia.org/wiki/Standard_illuminant -illuminants = \ +_illuminants = \ {"A": {'2': (1.098466069456375, 1, 0.3558228003436005), '10': (1.111420406956693, 1, 0.3519978321919493), 'R': (1.098466069456375, 1, 0.3558228003436005)}, @@ -573,7 +573,83 @@ def hsv2rgb(hsv, *, channel_axis=-1): 'R': (1.0, 1.0, 1.0)}} -def get_xyz_coords(illuminant, observer): +def xyz_tristimulus_values(*, illuminant, observer, dtype=None): + """Get the CIE XYZ tristimulus values. + + Given an illuminant and observer, this function returns the CIE XYZ + tristimulus values [2]_ scaled such that :math:`Y = 1`. + + Parameters + ---------- + illuminant : {"A", "B", "C", "D50", "D55", "D65", "D75", "E"} + The name of the illuminant (the function is NOT case sensitive). + observer : {"2", "10", "R"} + One of: 2-degree observer, 10-degree observer, or 'R' observer as in + R function ``grDevices::convertColor`` [3]_. + dtype : np.dtype, optional + This argument is ignored in the cuCIM implementation of + `xyz_tristimulus_values` since an array is not returned. The output is + always a 3-tuple of float. + + Returns + ------- + values : 3-tuple of float + Three elements :math:`X, Y, Z` containing the CIE XYZ tristimulus values + of the given illuminant. 
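As a quick numeric illustration of the formula stated in the Notes section of this docstring (X = x / y, Y = 1, Z = (1 - x - y) / y), using the commonly cited CIE 1931 2-degree chromaticity coordinates of D65 as an assumed input (the coordinates are not taken from this diff)::

    x, y = 0.31272, 0.32903             # assumed D65, 2-degree observer
    X, Y, Z = x / y, 1.0, (1 - x - y) / y
    print(round(X, 4), Y, round(Z, 4))  # 0.9504 1.0 1.0888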
+ + Raises + ------ + ValueError + If either the illuminant or the observer angle are not supported or + unknown. + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Standard_illuminant#White_points_of_standard_illuminants + .. [2] https://en.wikipedia.org/wiki/CIE_1931_color_space#Meaning_of_X,_Y_and_Z + .. [3] https://www.rdocumentation.org/packages/grDevices/versions/3.6.2/topics/convertColor + + Notes + ----- + The return type of this function differs from the one in scikit-image as it + always returns a 3-tuple of float rather than an array with a + user-specified dtype. + + The CIE XYZ tristimulus values are calculated from :math:`x, y` [1]_, using the + formula + + .. math:: X = x / y + + .. math:: Y = 1 + + .. math:: Z = (1 - x - y) / y + + The only exception is the illuminant "D65" with aperture angle 2° for + backward-compatibility reasons. + + Examples + -------- + Get the CIE XYZ tristimulus values for a "D65" illuminant for a 10 degree + field of view + + >>> xyz_tristimulus_values(illuminant="D65", observer="10") + array([0.94809668, 1. , 1.07305136]) + """ # noqa + illuminant = illuminant.upper() + observer = observer.upper() + try: + return _illuminants[illuminant][observer] + except KeyError: + raise ValueError(f'Unknown illuminant/observer combination ' + f'(`{illuminant}`, `{observer}`)') + + +@deprecate_func( + hint="Use `skimage.color.xyz_tristimulus_values` instead.", + deprecated_version="23.08", + removed_version="24.06", +) +def get_xyz_coords(illuminant, observer, dtype=float): """Get the XYZ coordinates of the given illuminant and observer [1]_. Parameters @@ -602,13 +678,7 @@ def get_xyz_coords(illuminant, observer): ---------- .. [1] https://en.wikipedia.org/wiki/Standard_illuminant """ - illuminant = illuminant.upper() - observer = observer.upper() - try: - return illuminants[illuminant][observer] - except KeyError: - raise ValueError(f'Unknown illuminant/observer combination ' - f'(`{illuminant}`, `{observer}`)') + return xyz_tristimulus_values(illuminant=illuminant, observer=observer) # Haematoxylin-Eosin-DAB colorspace @@ -625,7 +695,6 @@ def get_xyz_coords(illuminant, observer): [0.27, 0.57, 0.78]]) hed_from_rgb = linalg.inv(rgb_from_hed) - # Following matrices are adapted form the Java code written by G.Landini. # The original code is available at: # https://web.archive.org/web/20160624145052/http://www.mecourse.com/landinig/software/cdeconv/cdeconv.html @@ -1071,9 +1140,12 @@ def gray2rgb(image, *, channel_axis=-1): @cp.memoize(for_each_device=True) def _get_xyz_to_lab_kernel(xyz_ref_white, name='xyz2lab'): _xyz_to_lab = f""" + // scale by CIE XYZ tristimulus values of the reference white point arr[3*i] /= {xyz_ref_white[0]}; arr[3*i + 1] /= {xyz_ref_white[1]}; arr[3*i + 2] /= {xyz_ref_white[2]}; + + // Nonlinear distortion and linear transformation for (int ch=0; ch < 3; ch++) {{ if (arr[3*i + ch] > 0.008856) @@ -1083,6 +1155,8 @@ def _get_xyz_to_lab_kernel(xyz_ref_white, name='xyz2lab'): arr[3*i + ch] = 7.787 * arr[3*i + ch] + 16.0 / 116.0; }} }} + + // Vector scaling lab[3*i] = (116. * arr[3*i + 1]) - 16.0; lab[3*i + 1] = 500.0 * (arr[3*i] - arr[3*i + 1]); lab[3*i + 2] = 200.0 * (arr[3*i + 1] - arr[3*i + 2]); @@ -1130,13 +1204,13 @@ def xyz2lab(xyz, illuminant="D65", observer="2", *, channel_axis=-1): Notes ----- By default Observer="2", Illuminant="D65". CIE XYZ tristimulus values - x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for - a list of supported illuminants. + x_ref=95.047, y_ref=100., z_ref=108.883. 
See function + :func:`~.xyz_tristimulus_values` for a list of supported illuminants. References ---------- - .. [1] http://www.easyrgb.com/index.php?X=MATH&H=07 - .. [2] https://en.wikipedia.org/wiki/Lab_color_space + .. [1] http://www.easyrgb.com/en/math.php + .. [2] https://en.wikipedia.org/wiki/CIELAB_color_space Examples -------- @@ -1150,8 +1224,9 @@ def xyz2lab(xyz, illuminant="D65", observer="2", *, channel_axis=-1): xyz = _prepare_colorarray(xyz, force_copy=True, force_c_contiguous=True, channel_axis=channel_axis) - xyz_ref_white = get_xyz_coords(illuminant, observer) - + xyz_ref_white = xyz_tristimulus_values( + illuminant=illuminant, observer=observer + ) name = f'xyz2lab_{xyz.dtype.char}' kern = _get_xyz_to_lab_kernel(xyz_ref_white, name=name) lab = cp.empty_like(xyz) @@ -1222,7 +1297,7 @@ def lab2xyz(lab, illuminant="D65", observer="2", *, channel_axis=-1): Returns ------- out : (..., 3, ...) ndarray - The image in XYZ color space. Same dimensions as input. + The image in XYZ color space, of same shape as input. Raises ------ @@ -1237,18 +1312,49 @@ def lab2xyz(lab, illuminant="D65", observer="2", *, channel_axis=-1): Notes ----- The CIE XYZ tristimulus values are x_ref = 95.047, y_ref = 100., and - z_ref = 108.883. See function :func:`~.get_xyz_coords` for a list of + z_ref = 108.883. See function :func:`~.xyz_tristimulus_values` for a list of supported illuminants. + See Also + -------- + xyz2lab + References ---------- - .. [1] http://www.easyrgb.com/index.php?X=MATH&H=07 + .. [1] http://www.easyrgb.com/en/math.php .. [2] https://en.wikipedia.org/wiki/CIELAB_color_space """ + xyz, n_invalid = _lab2xyz(lab, illuminant, observer, channel_axis) + if n_invalid > 0: + warn( + "Conversion from CIE-LAB to XYZ color space resulted in " + f"{n_invalid} negative Z values that have been clipped to zero", + stacklevel=3, + ) + + return xyz + + +def _lab2xyz(lab, illuminant, observer, channel_axis): + """Convert CIE-LAB to XYZ color space. + + Internal function for :func:`~.lab2xyz` and others. In addition to the + converted image, return the number of invalid pixels in the Z channel for + correct warning propagation. + + Returns + ------- + out : (..., 3, ...) ndarray + The image in XYZ format. Same dimensions as input. + n_invalid : int + Number of invalid pixels in the Z channel after conversion. + """ lab = _prepare_colorarray(lab, force_c_contiguous=True, channel_axis=channel_axis) - xyz_ref_white = get_xyz_coords(illuminant, observer) + xyz_ref_white = xyz_tristimulus_values( + illuminant=illuminant, observer=observer + ) name = f'lab2xyz_{lab.dtype.char}' kern = _get_lab_to_xyz_kernel(xyz_ref_white, name=name) @@ -1258,12 +1364,8 @@ def lab2xyz(lab, illuminant="D65", observer="2", *, channel_axis=-1): # operations? warnings = cp.zeros(lab.shape[:-1], dtype=np.int32) kern(lab, xyz, warnings, size=lab.size // 3) - - nwarn = int(cp.count_nonzero(warnings)) - if nwarn > 0: # synchronize! - warn(f'Color data out of range: Z < 0 in {nwarn} pixels', - stacklevel=3) - return xyz + n_invalid = int(cp.count_nonzero(warnings)) # synchronize! + return xyz, n_invalid @channel_as_last_axis() @@ -1302,8 +1404,8 @@ def rgb2lab(rgb, illuminant="D65", observer="2", *, channel_axis=-1): This function uses rgb2xyz and xyz2lab. By default Observer="2", Illuminant="D65". CIE XYZ tristimulus values - x_ref=95.047, y_ref=100., z_ref=108.883. See function `get_xyz_coords` for - a list of supported illuminants. + x_ref=95.047, y_ref=100., z_ref=108.883. 
See function + :func:`~.xyz_tristimulus_values` for a list of supported illuminants. References ---------- @@ -1335,7 +1437,7 @@ def lab2rgb(lab, illuminant="D65", observer="2", *, channel_axis=-1): Returns ------- out : (..., 3, ...) ndarray - The image in RGB format. Same dimensions as input. + The image in sRGB color space, of same shape as input. Raises ------ @@ -1344,17 +1446,28 @@ def lab2rgb(lab, illuminant="D65", observer="2", *, channel_axis=-1): Notes ----- - This function uses lab2xyz and xyz2rgb. + This function uses :func:`~.lab2xyz` and :func:`~.xyz2rgb`. The CIE XYZ tristimulus values are x_ref = 95.047, y_ref = 100., and - z_ref = 108.883. See function :func:`~.get_xyz_coords` for a list of + z_ref = 108.883. See function :func:`~.xyz_tristimulus_values` for a list of supported illuminants. + See Also + -------- + rgb2lab + References ---------- .. [1] https://en.wikipedia.org/wiki/Standard_illuminant .. [2] https://en.wikipedia.org/wiki/CIELAB_color_space """ - return xyz2rgb(lab2xyz(lab, illuminant, observer)) + xyz, n_invalid = _lab2xyz(lab, illuminant, observer, channel_axis) + if n_invalid != 0: + warn( + "Conversion from CIE-LAB, via XYZ to sRGB color space resulted in " + f"{n_invalid} negative Z values that have been clipped to zero", + stacklevel=3, + ) + return xyz2rgb(xyz, channel_axis=channel_axis) @cp.memoize(for_each_device=True) @@ -1442,12 +1555,12 @@ def xyz2luv(xyz, illuminant="D65", observer="2", *, channel_axis=-1): ----- By default XYZ conversion weights use observer=2A. Reference whitepoint for D65 Illuminant, with XYZ tristimulus values of ``(95.047, 100., - 108.883)``. See function 'get_xyz_coords' for a list of supported - illuminants. + 108.883)``. See function :func:`~.xyz_tristimulus_values` for a list of + supported illuminants. References ---------- - .. [1] http://www.easyrgb.com/index.php?X=MATH&H=16#text16 + .. [1] http://www.easyrgb.com/en/math.php .. [2] https://en.wikipedia.org/wiki/CIELUV Examples @@ -1466,7 +1579,9 @@ def xyz2luv(xyz, illuminant="D65", observer="2", *, channel_axis=-1): xyz = _prepare_colorarray(xyz, force_c_contiguous=True, channel_axis=channel_axis) - xyz_ref_white = get_xyz_coords(illuminant, observer) + xyz_ref_white = xyz_tristimulus_values( + illuminant=illuminant, observer=observer + ) kern = _get_xyz_to_luv_kernel(xyz_ref_white, xyz.dtype) luv = cp.empty_like(xyz) kern(xyz, luv, size=xyz.size // 3) @@ -1546,16 +1661,19 @@ def luv2xyz(luv, illuminant="D65", observer="2", *, channel_axis=-1): ----- XYZ conversion weights use observer=2A. Reference whitepoint for D65 Illuminant, with XYZ tristimulus values of ``(95.047, 100., 108.883)``. See - function 'get_xyz_coords' for a list of supported illuminants. + function :func:`~.xyz_tristimulus_values` for a list of supported + illuminants. References ---------- - .. [1] http://www.easyrgb.com/index.php?X=MATH&H=16#text16 + .. [1] http://www.easyrgb.com/en/math.php .. [2] https://en.wikipedia.org/wiki/CIELUV """ luv = _prepare_colorarray(luv, force_c_contiguous=True, channel_axis=channel_axis) - xyz_ref_white = get_xyz_coords(illuminant, observer) + xyz_ref_white = xyz_tristimulus_values( + illuminant=illuminant, observer=observer + ) kern = _get_luv_to_xyz_kernel(xyz_ref_white, luv.dtype) xyz = cp.empty_like(luv) kern(luv, xyz, size=luv.size // 3) @@ -1591,9 +1709,8 @@ def rgb2luv(rgb, *, channel_axis=-1): References ---------- - .. [1] http://www.easyrgb.com/index.php?X=MATH&H=16#text16 - .. [2] http://www.easyrgb.com/index.php?X=MATH&H=02#text2 - .. 
[3] https://en.wikipedia.org/wiki/CIELUV + .. [1] http://www.easyrgb.com/en/math.php + .. [2] https://en.wikipedia.org/wiki/CIELUV """ return xyz2luv(rgb2xyz(rgb)) diff --git a/python/cucim/src/cucim/skimage/color/delta_e.py b/python/cucim/src/cucim/skimage/color/delta_e.py index 14ed971db..bf07162a8 100644 --- a/python/cucim/src/cucim/skimage/color/delta_e.py +++ b/python/cucim/src/cucim/skimage/color/delta_e.py @@ -29,7 +29,7 @@ def _float_inputs(lab1, lab2, allow_float32=True): if allow_float32: - float_dtype = _supported_float_type([lab1.dtype, lab2.dtype]) + float_dtype = _supported_float_type((lab1.dtype, lab2.dtype)) else: float_dtype = cp.float64 lab1 = lab1.astype(float_dtype, copy=False) @@ -375,7 +375,7 @@ def get_dH2(lab1, lab2, *, channel_axis=-1): """ # This function needs double precision internally for accuracy input_is_float_32 = _supported_float_type( - [lab1.dtype, lab2.dtype] + (lab1.dtype, lab2.dtype) ) == cp.float32 lab1, lab2 = _float_inputs(lab1, lab2, allow_float32=False) a1, b1 = cp.moveaxis(lab1, source=channel_axis, destination=0)[1:3] diff --git a/python/cucim/src/cucim/skimage/color/tests/test_colorconv.py b/python/cucim/src/cucim/skimage/color/tests/test_colorconv.py index 0d50cf119..f3bd45a1e 100644 --- a/python/cucim/src/cucim/skimage/color/tests/test_colorconv.py +++ b/python/cucim/src/cucim/skimage/color/tests/test_colorconv.py @@ -666,9 +666,13 @@ def test_lab_full_gamut(self): a, b = cp.meshgrid(cp.arange(-100, 100), cp.arange(-100, 100)) L = cp.ones(a.shape) lab = cp.dstack((L, a, b)) + regex = ( + "Conversion from CIE-LAB to XYZ color space resulted in " + "\\d+ negative Z values that have been clipped to zero" + ) for value in [0, 10, 20]: lab[:, :, 0] = value - with expected_warnings(['Color data out of range']): + with pytest.warns(UserWarning, match=regex): lab2xyz(lab) @pytest.mark.parametrize("channel_axis", [0, 1, -1, -2]) @@ -800,6 +804,17 @@ def test_rgb2yiq_conversion(self): ) assert_array_almost_equal(yiq, gt, decimal=2) + @pytest.mark.parametrize("func", [lab2rgb, lab2xyz]) + def test_warning_stacklevel(self, func): + regex = ( + "Conversion from CIE-LAB.* XYZ.*color space resulted in " + "1 negative Z values that have been clipped to zero" + ) + with pytest.warns(UserWarning, match=regex) as messages: + func(lab=cp.array([[[0, 0, 300.]]])) + assert len(messages) == 1 + assert messages[0].filename == __file__, "warning points at wrong file" + def test_gray2rgb(): x = cp.asarray([0, 0.5, 1]) diff --git a/python/cucim/src/cucim/skimage/exposure/_adapthist.py b/python/cucim/src/cucim/skimage/exposure/_adapthist.py index 4d42e2ae3..939601973 100644 --- a/python/cucim/src/cucim/skimage/exposure/_adapthist.py +++ b/python/cucim/src/cucim/skimage/exposure/_adapthist.py @@ -4,14 +4,7 @@ http://tog.acm.org/resources/GraphicsGems/ -The Graphics Gems code is copyright-protected. In other words, you cannot -claim the text of the code as your own and resell it. Using the code is -permitted in any program, product, or library, non-commercial or commercial. -Giving credit is not required, though is a nice gesture. The code comes as-is, -and if there are any flaws or problems with any Gems code, nobody involved with -Gems - authors, editors, publishers, or webmasters - are to be held -responsible. Basically, don't be a jerk, and remember that anything free -comes with no guarantee. +Relicensed with permission of the author under the Modified BSD license. 
""" import functools import itertools @@ -97,7 +90,7 @@ def equalize_adapthist(image, kernel_size=None, elif isinstance(kernel_size, numbers.Number): kernel_size = (kernel_size,) * image.ndim elif len(kernel_size) != image.ndim: - ValueError(f'Incorrect value of `kernel_size`: {kernel_size}') + raise ValueError(f'Incorrect value of `kernel_size`: {kernel_size}') kernel_size = [int(k) for k in kernel_size] diff --git a/python/cucim/src/cucim/skimage/exposure/exposure.py b/python/cucim/src/cucim/skimage/exposure/exposure.py index 4230f36c9..c6038cae4 100644 --- a/python/cucim/src/cucim/skimage/exposure/exposure.py +++ b/python/cucim/src/cucim/skimage/exposure/exposure.py @@ -72,8 +72,14 @@ def _bincount_histogram(image, source_range, bin_centers=None): """ if bin_centers is None: bin_centers = _bincount_histogram_centers(image, source_range) - image_min, image_max = bin_centers[0], bin_centers[-1] - image = _offset_array(image, image_min.item(), image_max.item()) # synchronize # noqa + image_min, image_max = bin_centers[0].item(), bin_centers[-1].item() + image = _offset_array(image, image_min, image_max) # synchronize # noqa + + # Casting back to unsigned dtype seems necessary to avoid incorrect + # results for larger integer ranges with CUDA 12.x. + unsigned_dtype_char = image.dtype.char.upper() + image = image.astype(unsigned_dtype_char, copy=False) + hist = cp.bincount( image.ravel(), minlength=image_max - min(image_min, 0) + 1 ) @@ -178,7 +184,9 @@ def _get_numpy_hist_range(image, source_range): elif source_range == 'dtype': hist_range = dtype_limits(image, clip_negative=False) else: - ValueError('Wrong value for the `source_range` argument') + raise ValueError( + f'Incorrect value for `source_range` argument: {source_range}' + ) return hist_range diff --git a/python/cucim/src/cucim/skimage/exposure/tests/test_exposure.py b/python/cucim/src/cucim/skimage/exposure/tests/test_exposure.py index 7087611f0..11f6fb673 100644 --- a/python/cucim/src/cucim/skimage/exposure/tests/test_exposure.py +++ b/python/cucim/src/cucim/skimage/exposure/tests/test_exposure.py @@ -25,7 +25,8 @@ def test_wrong_source_range(): im = cp.array([-1, 100], dtype=cp.int8) - with pytest.raises(ValueError): + match = "Incorrect value for `source_range` argument" + with pytest.raises(ValueError, match=match): frequencies, bin_centers = exposure.histogram( im, source_range="foobar" ) @@ -592,6 +593,12 @@ def norm_brightness_err(img1, img2): return nbe +def test_adapthist_incorrect_kernel_size(): + img = cp.ones((8, 8), dtype=float) + with pytest.raises(ValueError, match="Incorrect value of `kernel_size`"): + exposure.equalize_adapthist(img, (3, 3, 3)) + + # Test Gamma Correction # ===================== diff --git a/python/cucim/src/cucim/skimage/feature/__init__.py b/python/cucim/src/cucim/skimage/feature/__init__.py index d26dfb347..fcfa62c03 100644 --- a/python/cucim/src/cucim/skimage/feature/__init__.py +++ b/python/cucim/src/cucim/skimage/feature/__init__.py @@ -1,4 +1,3 @@ -from .._shared.utils import deprecated from ._basic_features import multiscale_basic_features from ._canny import canny from ._daisy import daisy diff --git a/python/cucim/src/cucim/skimage/feature/_canny.py b/python/cucim/src/cucim/skimage/feature/_canny.py index edc826fab..1d60dbf86 100644 --- a/python/cucim/src/cucim/skimage/feature/_canny.py +++ b/python/cucim/src/cucim/skimage/feature/_canny.py @@ -358,6 +358,9 @@ def canny(image, sigma=1., low_threshold=None, high_threshold=None, mask=None, # mask by one and then mask the output. 
We also mask out the border points # because who knows what lies beyond the edge of the image? + if (image.dtype.kind in 'iu' and image.dtype.itemsize >= 8): + raise ValueError("64-bit or larger integer images are not supported") + check_nD(image, 2) dtype_max = dtype_limits(image, clip_negative=False)[1] diff --git a/python/cucim/src/cucim/skimage/feature/tests/test_canny.py b/python/cucim/src/cucim/skimage/feature/tests/test_canny.py index 6e06114fc..8f3ab90ff 100644 --- a/python/cucim/src/cucim/skimage/feature/tests/test_canny.py +++ b/python/cucim/src/cucim/skimage/feature/tests/test_canny.py @@ -77,7 +77,7 @@ def test_mask_none(self): assert cp.all(result1 == result2) @cp.testing.with_requires("scikit-image>=0.18") - @pytest.mark.parametrize('image_dtype', [cp.uint8, cp.int64, cp.float32, + @pytest.mark.parametrize('image_dtype', [cp.uint8, cp.int32, cp.float32, cp.float64]) def test_use_quantiles(self, image_dtype): dtype = cp.dtype(image_dtype) @@ -146,3 +146,24 @@ def test_dtype(self): feature.canny(image_float, 1.0, low, high), feature.canny(image_uint8, 1.0, 255 * low, 255 * high) ) + + def test_full_mask_matches_no_mask(self): + """The masked and unmasked algorithms should return the same result. + + """ + image = cp.array(data.camera()) + + for mode in ('constant', 'nearest', 'reflect'): + cp.testing.assert_array_equal( + feature.canny(image, mode=mode), + feature.canny(image, mode=mode, + mask=cp.ones_like(image, dtype=bool)) + ) + + @pytest.mark.parametrize('dtype', (cp.int64, cp.uint64)) + def test_unsupported_int64(self, dtype): + image = cp.zeros((10, 10), dtype=dtype) + image[3, 3] = cp.iinfo(dtype).max + match = "64-bit or larger integer images are not supported" + with pytest.raises(ValueError, match=match): + feature.canny(image) diff --git a/python/cucim/src/cucim/skimage/filters/__init__.py b/python/cucim/src/cucim/skimage/filters/__init__.py index 29ebabac1..32cfe78d8 100644 --- a/python/cucim/src/cucim/skimage/filters/__init__.py +++ b/python/cucim/src/cucim/skimage/filters/__init__.py @@ -4,7 +4,9 @@ __name__, # submodules={'rank'}, submod_attrs={ - 'lpi_filter': ['filter_inverse', 'wiener', 'LPIFilter2D'], + 'lpi_filter': [ + 'filter_forward', 'filter_inverse', 'wiener', 'LPIFilter2D' + ], '_gaussian': ['gaussian', 'difference_of_gaussians'], 'edges': ['sobel', 'sobel_h', 'sobel_v', 'scharr', 'scharr_h', 'scharr_v', @@ -14,12 +16,14 @@ 'farid', 'farid_h', 'farid_v'], '_rank_order': ['rank_order'], '_gabor': ['gabor_kernel', 'gabor'], - 'thresholding': ['threshold_local', 'threshold_otsu', 'threshold_yen', - 'threshold_isodata', 'threshold_li', 'threshold_minimum', - 'threshold_mean', 'threshold_triangle', - 'threshold_niblack', 'threshold_sauvola', - 'threshold_multiotsu', 'try_all_threshold', - 'apply_hysteresis_threshold'], + 'thresholding': [ + 'threshold_local', 'threshold_otsu', 'threshold_yen', + 'threshold_isodata', 'threshold_li', 'threshold_minimum', + 'threshold_mean', 'threshold_triangle', + 'threshold_niblack', 'threshold_sauvola', + 'threshold_multiotsu', 'try_all_threshold', + 'apply_hysteresis_threshold' + ], 'ridges': ['meijering', 'sato', 'frangi', 'hessian'], '_median': ['median'], '_sparse': ['correlate_sparse'], diff --git a/python/cucim/src/cucim/skimage/filters/lpi_filter.py b/python/cucim/src/cucim/skimage/filters/lpi_filter.py index d6cd59c0b..e443a0723 100644 --- a/python/cucim/src/cucim/skimage/filters/lpi_filter.py +++ b/python/cucim/src/cucim/skimage/filters/lpi_filter.py @@ -7,7 +7,7 @@ import numpy as np from cupyx.scipy import fft 
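One practical reason for the 64-bit integer guard added to ``canny`` earlier in this diff is that such images cannot round-trip through the float64 arithmetic used for smoothing; a standalone illustration (plain NumPy, not code from this repository)::

    import numpy as np

    big = np.int64(2**53) + 1
    print(float(big) == int(big))                      # False: the cast changed the value
    print(np.float64(2**53) == np.float64(2**53 + 1))  # True: adjacent integers collide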
-from .._shared.utils import _supported_float_type, check_nD, deprecated +from .._shared.utils import _supported_float_type, check_nD, deprecate_func eps = np.finfo(float).eps @@ -19,7 +19,7 @@ def _min_limit(x, val=eps): def _center(x, oshape): """Return an array of shape ``oshape`` from the center of array ``x``.""" - start = (np.array(x.shape) - np.array(oshape)) // 2 + 1 + start = (np.array(x.shape) - np.array(oshape)) // 2 out = x[tuple(slice(s, s + n) for s, n in zip(start, oshape))] return out @@ -65,11 +65,10 @@ def __init__(self, impulse_response, **filter_params): Examples -------- - Gaussian filter: Use a 1-D gaussian in each direction without - normalization coefficients. + Gaussian filter without normalization of coefficients: >>> def filt_func(r, c, sigma = 1): - ... return cp.exp(-cp.hypot(r, c)/sigma) + ... return cp.exp(-(r**2 + c**2)/(2 * sigma**2)) >>> filter = LPIFilter2D(filt_func) """ @@ -83,14 +82,22 @@ def __init__(self, impulse_response, **filter_params): def _prepare(self, data): """Calculate filter and data FFT in preparation for filtering.""" dshape = np.array(data.shape) - dshape += dshape % 2 == 0 # all filter dimensions must be uneven - oshape = np.array(data.shape) * 2 - 1 + # all filter dimensions must be uneven + even_offset = tuple(int(s % 2 == 0) for s in data.shape) + dshape = tuple( + s + offset for s, offset in zip(data.shape, even_offset) + ) + + oshape = tuple(s * 2 - 1 for s in data.shape) float_dtype = _supported_float_type(data.dtype) data = data.astype(float_dtype, copy=False) if self._cache is None or np.any(self._cache.shape != oshape): - coords = cp.mgrid[[slice(0, float(n)) for n in dshape]] + coords = cp.mgrid[ + [slice(0 + offset, float(n + offset)) + for (n, offset) in zip(dshape, even_offset)] + ] # this steps over two sets of coordinates, # not over the coordinates individually for k, coord in enumerate(coords): @@ -127,7 +134,7 @@ def __call__(self, data): return out -def filter_forward(data, impulse_response=None, filter_params={}, +def filter_forward(data, impulse_response=None, filter_params=None, predefined_filter=None): """Apply the given filter to data. @@ -137,7 +144,7 @@ def filter_forward(data, impulse_response=None, filter_params={}, Input data. impulse_response : callable `f(r, c, **filter_params)` Impulse response of the filter. See LPIFilter2D.__init__. - filter_params : dict + filter_params : dict, optional Additional keyword parameters to the impulse_response function. Other Parameters @@ -148,32 +155,35 @@ def filter_forward(data, impulse_response=None, filter_params={}, Examples -------- + Gaussian filter without normalization: - Gaussian filter: - - >>> import cupy as cp - >>> def filt_func(r, c): - ... return cp.exp(-cp.hypot(r, c)/1) + >>> def filt_func(r, c, sigma=1): + ... 
return cp.exp(-(r**2 + c**2)/(2 * sigma**2))
     >>>
     >>> from skimage import data
     >>> filtered = filter_forward(cp.array(data.coins()), filt_func)
     """
+    if filter_params is None:
+        filter_params = {}
     check_nD(data, 2, 'data')
     if predefined_filter is None:
         predefined_filter = LPIFilter2D(impulse_response, **filter_params)
     return predefined_filter(data)
 
 
-@deprecated(alt_func='cucim.skimage.filters.lpi_filter.filter_inverse',
-            removed_version='2023.06.00')
-def inverse(data, impulse_response=None, filter_params={}, max_gain=2,
+@deprecate_func(
+    hint="use cucim.skimage.filters.lpi_filter.filter_inverse instead",
+    deprecated_version="",
+    removed_version="2023.12.00",
+)
+def inverse(data, impulse_response=None, filter_params=None, max_gain=2,
             predefined_filter=None):
     return filter_inverse(data, impulse_response, filter_params, max_gain,
                           predefined_filter)
 
 
-def filter_inverse(data, impulse_response=None, filter_params={}, max_gain=2,
+def filter_inverse(data, impulse_response=None, filter_params=None, max_gain=2,
                    predefined_filter=None):
     """Apply the filter in reverse to the given data.
 
@@ -182,21 +192,24 @@ def filter_inverse(data, impulse_response=None, filter_params={}, max_gain=2,
     data : (M, N) ndarray
         Input data.
     impulse_response : callable `f(r, c, **filter_params)`
-        Impulse response of the filter. See LPIFilter2D.__init__.
-    filter_params : dict
+        Impulse response of the filter. See :class:`~.LPIFilter2D`. This is a
+        required argument unless a `predefined_filter` is provided.
+    filter_params : dict, optional
         Additional keyword parameters to the impulse_response function.
-    max_gain : float
+    max_gain : float, optional
         Limit the filter gain. Often, the filter contains zeros, which
         would cause the inverse filter to have infinite gain. High gain
         causes amplification of artefacts, so a conservative limit is
         recommended.
 
     Other Parameters
     ----------------
-    predefined_filter : LPIFilter2D
+    predefined_filter : LPIFilter2D, optional
         If you need to apply the same filter multiple times over different
         images, construct the LPIFilter2D and specify it here.
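The change from ``filter_params={}`` to ``filter_params=None`` in these signatures avoids Python's shared-mutable-default pitfall. A minimal standalone illustration with toy functions (not the filters in this module)::

    def bad(value, params={}):
        params.setdefault("count", 0)
        params["count"] += 1
        return params["count"]

    def good(value, params=None):
        if params is None:      # the pattern used in the updated functions
            params = {}
        params.setdefault("count", 0)
        params["count"] += 1
        return params["count"]

    print(bad(1), bad(2))    # 1 2 -> the same dict object is reused across calls
    print(good(1), good(2))  # 1 1 -> each call gets a fresh dict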
""" + if filter_params is None: + filter_params = {} check_nD(data, 2, 'data') if not isinstance(K, float): diff --git a/python/cucim/src/cucim/skimage/filters/ridges.py b/python/cucim/src/cucim/skimage/filters/ridges.py index 5d861b8f7..3615278bc 100644 --- a/python/cucim/src/cucim/skimage/filters/ridges.py +++ b/python/cucim/src/cucim/skimage/filters/ridges.py @@ -14,13 +14,13 @@ class of ridge filters relies on the eigenvalues of the Hessian matrix of import cupy as cp import numpy as np -from .._shared.utils import _supported_float_type, check_nD, deprecated +from .._shared.utils import _supported_float_type, check_nD, deprecate_func from ..feature.corner import (_symmetric_compute_eigenvalues, hessian_matrix, hessian_matrix_eigvals) from ..util import img_as_float -@deprecated(removed_version="2023.06.01") +@deprecate_func(deprecated_version="", removed_version="2023.06.01") def compute_hessian_eigenvalues(image, sigma, sorting='none', mode='constant', cval=0, use_gaussian_derivatives=False): diff --git a/python/cucim/src/cucim/skimage/filters/thresholding.py b/python/cucim/src/cucim/skimage/filters/thresholding.py index 367eb35d8..c2dedb2fa 100644 --- a/python/cucim/src/cucim/skimage/filters/thresholding.py +++ b/python/cucim/src/cucim/skimage/filters/thresholding.py @@ -704,6 +704,10 @@ def threshold_li(image, *, tolerance=None, initial_guess=None, weights=hist[foreground]) mean_back = np.average(bin_centers[background], weights=hist[background]) + + if mean_back == 0: + break + eps = 100 * np.finfo(float).eps mean_back = float(mean_back) mean_fore = float(mean_fore) @@ -726,6 +730,9 @@ def threshold_li(image, *, tolerance=None, initial_guess=None, mean_fore = float(cp.mean(image[foreground])) mean_back = float(cp.mean(image[~foreground])) + if mean_back == 0: + break + t_next = ((mean_back - mean_fore) / (math.log(mean_back + eps) - math.log(mean_fore + eps))) diff --git a/python/cucim/src/cucim/skimage/morphology/_skeletonize.py b/python/cucim/src/cucim/skimage/morphology/_skeletonize.py index d4ca12fde..a58ba8681 100644 --- a/python/cucim/src/cucim/skimage/morphology/_skeletonize.py +++ b/python/cucim/src/cucim/skimage/morphology/_skeletonize.py @@ -153,10 +153,10 @@ def thin(image, max_num_iter=None): # --------- Skeletonization by medial axis transform -------- -def _get_tiebreaker(n, random_seed): +def _get_tiebreaker(n, seed): # CuPy generator doesn't currently have the permutation method, so # fall back to cp.random.permutation instead. - cp.random.seed(random_seed) + cp.random.seed(seed) if n < 2 << 31: dtype = np.int32 else: @@ -165,7 +165,12 @@ def _get_tiebreaker(n, random_seed): return tiebreaker -def medial_axis(image, mask=None, return_distance=False, *, random_state=None): +@deprecate_kwarg( + {'random_state': 'seed'}, + deprecated_version='23.08', + removed_version='24.06' +) +def medial_axis(image, mask=None, return_distance=False, *, seed=None): """Compute the medial axis transform of a binary image. Parameters @@ -177,13 +182,12 @@ def medial_axis(image, mask=None, return_distance=False, *, random_state=None): value in `mask` are used for computing the medial axis. return_distance : bool, optional If true, the distance transform is returned as well as the skeleton. - random_state : {None, int, `numpy.random.Generator`}, optional - If `random_state` is None the `numpy.random.Generator` singleton is + seed : {None, int, `numpy.random.Generator`}, optional + If `seed` is None, the `numpy.random.Generator` singleton is used. 
+        If `seed` is an int, a new ``Generator`` instance is used, seeded with
+        `seed`.
+        If `seed` is already a ``Generator`` instance, then that instance is used.
-        If `random_state` is an int, a new ``Generator`` instance is used,
-        seeded with `random_state`.
-        If `random_state` is already a ``Generator`` instance then that
-        instance is used.
 
         .. versionadded:: 0.19
 
@@ -295,7 +299,7 @@ def medial_axis(image, mask=None, return_distance=False, *, random_state=None):
     # We use a random # for tiebreaking. Assign each pixel in the image a
     # predictable, random # so that masking doesn't affect arbitrary choices
     # of skeletons
-    tiebreaker = _get_tiebreaker(n=distance.size, random_seed=random_state)
+    tiebreaker = _get_tiebreaker(n=distance.size, seed=seed)
     order = cp.lexsort(
         cp.stack(
             (tiebreaker, corner_score[masked_image], distance),
diff --git a/python/cucim/src/cucim/skimage/restoration/__init__.py b/python/cucim/src/cucim/skimage/restoration/__init__.py
index aa0b63ecb..07eaeb6b1 100644
--- a/python/cucim/src/cucim/skimage/restoration/__init__.py
+++ b/python/cucim/src/cucim/skimage/restoration/__init__.py
@@ -1,6 +1,6 @@
 from ._denoise import denoise_tv_chambolle
 from .deconvolution import richardson_lucy, unsupervised_wiener, wiener
-from .j_invariant import calibrate_denoiser
+from .j_invariant import calibrate_denoiser, denoise_invariant
 
 __all__ = [
     "wiener",
@@ -8,4 +8,5 @@
     "richardson_lucy",
     "denoise_tv_chambolle",
     "calibrate_denoiser",
+    "denoise_invariant",
 ]
diff --git a/python/cucim/src/cucim/skimage/restoration/deconvolution.py b/python/cucim/src/cucim/skimage/restoration/deconvolution.py
index 923505a4a..f9dda2e30 100644
--- a/python/cucim/src/cucim/skimage/restoration/deconvolution.py
+++ b/python/cucim/src/cucim/skimage/restoration/deconvolution.py
@@ -146,8 +146,10 @@ def wiener(image, psf, balance, reg=None, is_real=True, clip=True):
     return deconv
 
 
+@deprecate_kwarg({'random_state': 'seed'}, deprecated_version="23.08.00",
+                 removed_version="24.06.00")
 def unsupervised_wiener(image, psf, reg=None, user_params=None, is_real=True,
-                        clip=True, *, random_state=None):
+                        clip=True, *, seed=None):
     """Unsupervised Wiener-Hunt deconvolution.
 
     Return the deconvolution with a Wiener-Hunt approach, where the
@@ -172,13 +174,12 @@ def unsupervised_wiener(image, psf, reg=None, user_params=None, is_real=True,
     clip : boolean, optional
         True by default. If true, pixel values of the result above 1 or
         under -1 are thresholded for skimage pipeline compatibility.
-    random_state : {None, int, `cupy.random.Generator`}, optional
-        If `random_state` is None the `cupy.random.Generator` singleton is
+    seed : {None, int, `numpy.random.Generator`}, optional
+        If `seed` is None, the `numpy.random.Generator` singleton is
         used.
+        If `seed` is an int, a new ``Generator`` instance is used, seeded with
+        `seed`.
+        If `seed` is already a ``Generator`` instance, then that instance is used.
-        If `random_state` is an int, a new ``Generator`` instance is used,
-        seeded with `random_state`.
-        If `random_state` is already a ``Generator`` instance then that
-        instance is used.
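The ``deprecate_kwarg`` decorator applied to ``unsupervised_wiener`` above warns and forwards the old keyword to the new one. A simplified standalone sketch of that mechanism (my own reduced version, not the library implementation)::

    import functools
    import warnings

    import numpy as np

    def rename_kwarg(old, new):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if old in kwargs:
                    warnings.warn(f"`{old}` is a deprecated argument name; "
                                  f"use `{new}` instead", FutureWarning,
                                  stacklevel=2)
                    kwargs[new] = kwargs.pop(old)
                return func(*args, **kwargs)
            return wrapper
        return decorator

    @rename_kwarg("random_state", "seed")
    def make_rng(*, seed=None):
        return np.random.default_rng(seed)

    rng = make_rng(random_state=42)  # warns, but still seeds the generator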
Returns ------- @@ -304,10 +305,10 @@ def unsupervised_wiener(image, psf, reg=None, user_params=None, is_real=True, data_spectrum = uft.ufft2(image) try: - rng = cp.random.default_rng(random_state) + rng = cp.random.default_rng(seed) except AttributeError: # older CuPy without default_rng - rng = cp.random.RandomState(random_state) + rng = cp.random.RandomState(seed) # Gibbs sampling for iteration in range(params["max_num_iter"]): @@ -383,8 +384,6 @@ def unsupervised_wiener(image, psf, reg=None, user_params=None, is_real=True, return (x_postmean, {'noise': gn_chain, 'prior': gx_chain}) -@deprecate_kwarg({'iterations': 'num_iter'}, removed_version="23.02.00", - deprecated_version="22.02.00") def richardson_lucy(image, psf, num_iter=50, clip=True, filter_epsilon=None): """Richardson-Lucy deconvolution. diff --git a/python/cucim/src/cucim/skimage/restoration/j_invariant.py b/python/cucim/src/cucim/skimage/restoration/j_invariant.py index e8a58dab2..1e19a92ca 100644 --- a/python/cucim/src/cucim/skimage/restoration/j_invariant.py +++ b/python/cucim/src/cucim/skimage/restoration/j_invariant.py @@ -89,14 +89,16 @@ def _generate_grid_slice(shape, *, offset, stride=3): return mask -def _invariant_denoise(image, denoise_function, *, stride=4, - masks=None, denoiser_kwargs=None): +def denoise_invariant(image, denoise_function, *, stride=4, masks=None, + denoiser_kwargs=None): """Apply a J-invariant version of `denoise_function`. Parameters ---------- - image : ndarray - Input data to be denoised (converted using `img_as_float`). + image : ndarray ([M[, N[, ...P]][, C]) of ints, uints or floats + Input data to be denoised. `image` can be of any numeric type, + but it is cast into a ndarray of floats (using `img_as_float`) for the + computation of the denoised image. denoise_function : function Original denoising function. stride : int, optional @@ -112,6 +114,32 @@ def _invariant_denoise(image, denoise_function, *, stride=4, ------- output : ndarray Denoised image, of same shape as `image`. + + Notes + ----- + A denoising function is J-invariant if the prediction it makes for each + pixel does not depend on the value of that pixel in the original image. + The prediction for each pixel may instead use all the relevant information + contained in the rest of the image, which is typically quite significant. + Any function can be converted into a J-invariant one using a simple masking + procedure, as described in [1]. + + The pixel-wise error of a J-invariant denoiser is uncorrelated to the noise, + so long as the noise in each pixel is independent. Consequently, the average + difference between the denoised image and the oisy image, the + *self-supervised loss*, is the same as the difference between the denoised + image and the original clean image, the *ground-truth loss* (up to a + constant). + + This means that the best J-invariant denoiser for a given image can be found + using the noisy data alone, by selecting the denoiser minimizing the self- + supervised loss. + + References + ---------- + .. [1] J. Batson & L. Royer. Noise2Self: Blind Denoising by + Self-Supervision, International Conference on Machine Learning, + p. 524-533 (2019). 
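The Notes above describe the masking procedure in words; the toy sketch below (my own example, with a plain 3x3 mean filter standing in for the denoiser) makes the J-invariance property concrete: the prediction written at each masked pixel never reads that pixel's own noisy value::

    import cupy as cp
    from cupyx.scipy import ndimage as ndi

    def jinvariant_mean_filter(noisy, stride=4):
        # centerless cross kernel: interpolation at a pixel ignores its own value
        kernel = cp.array([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]]) / 4.0
        interp = ndi.convolve(noisy, kernel, mode='mirror')
        out = cp.empty_like(noisy)
        for i in range(stride):
            for j in range(stride):
                mask = cp.zeros(noisy.shape, dtype=bool)
                mask[i::stride, j::stride] = True
                masked_input = noisy.copy()
                masked_input[mask] = interp[mask]
                # any denoiser could be plugged in here
                out[mask] = ndi.uniform_filter(masked_input, size=3)[mask]
        return out

    rng = cp.random.default_rng(0)
    clean = cp.zeros((64, 64))
    clean[16:48, 16:48] = 1.0
    noisy = clean + 0.3 * rng.standard_normal(clean.shape)
    denoised = jinvariant_mean_filter(noisy)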
""" image = img_as_float(image) @@ -252,7 +280,7 @@ def calibrate_denoiser(image, denoise_function, denoise_parameters, *, best_parameters = parameters_tested[idx] best_denoise_function = functools.partial( - _invariant_denoise, + denoise_invariant, denoise_function=denoise_function, stride=stride, denoiser_kwargs=best_parameters, @@ -303,7 +331,7 @@ def _calibrate_denoiser_search(image, denoise_function, denoise_parameters, *, multichannel = \ denoiser_kwargs.get('channel_axis', None) is not None if not approximate_loss: - denoised = _invariant_denoise( + denoised = denoise_invariant( image, denoise_function, stride=stride, denoiser_kwargs=denoiser_kwargs @@ -315,7 +343,7 @@ def _calibrate_denoiser_search(image, denoise_function, denoise_parameters, *, mask = _generate_grid_slice(image.shape[:spatialdims], offset=n_masks // 2, stride=stride) - masked_denoised = _invariant_denoise( + masked_denoised = denoise_invariant( image, denoise_function, masks=[mask], denoiser_kwargs=denoiser_kwargs diff --git a/python/cucim/src/cucim/skimage/restoration/tests/test_j_invariant.py b/python/cucim/src/cucim/skimage/restoration/tests/test_j_invariant.py index 88d814cbb..0386d8950 100644 --- a/python/cucim/src/cucim/skimage/restoration/tests/test_j_invariant.py +++ b/python/cucim/src/cucim/skimage/restoration/tests/test_j_invariant.py @@ -10,7 +10,7 @@ from cucim.skimage.data import binary_blobs from cucim.skimage.metrics import mean_squared_error as mse from cucim.skimage.restoration import calibrate_denoiser, denoise_tv_chambolle -from cucim.skimage.restoration.j_invariant import _invariant_denoise +from cucim.skimage.restoration.j_invariant import denoise_invariant from cucim.skimage.util import img_as_float, random_noise test_img = img_as_float(cp.asarray(camera())) @@ -36,9 +36,9 @@ def _denoise_wavelet(image, rescale_sigma=True, **kwargs): ) -def test_invariant_denoise(): - # denoised_img = _invariant_denoise(noisy_img, _denoise_wavelet) - denoised_img = _invariant_denoise(noisy_img, denoise_tv_chambolle) +def test_denoise_invariant(): + # denoised_img = denoise_invariant(noisy_img, _denoise_wavelet) + denoised_img = denoise_invariant(noisy_img, denoise_tv_chambolle) denoised_mse = mse(denoised_img, test_img) original_mse = mse(noisy_img, test_img) @@ -46,8 +46,8 @@ def test_invariant_denoise(): @pytest.mark.parametrize('dtype', [cp.float16, cp.float32, cp.float64]) -def test_invariant_denoise_color(dtype): - denoised_img_color = _invariant_denoise( +def test_denoise_invariant_color(dtype): + denoised_img_color = denoise_invariant( noisy_img_color.astype(dtype), _denoise_wavelet, denoiser_kwargs=dict(channel_axis=-1), @@ -59,8 +59,8 @@ def test_invariant_denoise_color(dtype): assert denoised_img_color.dtype == _supported_float_type(dtype) -def test_invariant_denoise_3d(): - denoised_img_3d = _invariant_denoise(noisy_img_3d, _denoise_wavelet) +def test_denoise_invariant_3d(): + denoised_img_3d = denoise_invariant(noisy_img_3d, _denoise_wavelet) denoised_mse = mse(denoised_img_3d, test_img_3d) original_mse = mse(noisy_img_3d, test_img_3d) @@ -76,8 +76,8 @@ def test_calibrate_denoiser_extra_output(): extra_output=True ) - all_denoised = [_invariant_denoise(noisy_img, _denoise_wavelet, - denoiser_kwargs=denoiser_kwargs) + all_denoised = [denoise_invariant(noisy_img, _denoise_wavelet, + denoiser_kwargs=denoiser_kwargs) for denoiser_kwargs in parameters_tested] ground_truth_losses = [float(mse(img, test_img)) for img in all_denoised] diff --git 
a/python/cucim/src/cucim/skimage/restoration/tests/test_restoration.py b/python/cucim/src/cucim/skimage/restoration/tests/test_restoration.py index 5d5ccd5a9..839f19f96 100644 --- a/python/cucim/src/cucim/skimage/restoration/tests/test_restoration.py +++ b/python/cucim/src/cucim/skimage/restoration/tests/test_restoration.py @@ -82,8 +82,7 @@ def test_unsupervised_wiener(dtype): psf = cp.asarray(psf, dtype=dtype) data = cp.asarray(data, dtype=dtype) - deconvolved, _ = restoration.unsupervised_wiener(data, psf, - random_state=seed) + deconvolved, _ = restoration.unsupervised_wiener(data, psf, seed=seed) float_type = _supported_float_type(dtype) assert deconvolved.dtype == float_type @@ -109,7 +108,7 @@ def test_unsupervised_wiener(dtype): user_params={"callback": lambda x: None, "max_num_iter": 200, "min_num_iter": 30}, - random_state=seed, + seed=seed, )[0] assert deconvolved2.real.dtype == float_type @@ -129,7 +128,8 @@ def test_unsupervised_wiener_deprecated_user_param(): otf = uft.ir2tf(psf, data.shape, is_real=False) _, laplacian = uft.laplacian(2, data.shape) with expected_warnings(["`max_iter` is a deprecated key", - "`min_iter` is a deprecated key"]): + "`min_iter` is a deprecated key", + "`random_state` is a deprecated argument name"]): restoration.unsupervised_wiener( data, otf, reg=laplacian, is_real=False, user_params={"max_iter": 200, "min_iter": 30}, random_state=5 @@ -175,17 +175,6 @@ def test_richardson_lucy(): cp.testing.assert_allclose(deconvolved, np.load(path), rtol=1e-4) -def test_richardson_lucy_deprecated_iterations_kwarg(): - psf = np.ones((5, 5)) / 25 - data = signal.convolve2d(cp.asnumpy(test_img), psf, 'same') - np.random.seed(0) - data += 0.1 * data.std() * np.random.standard_normal(data.shape) - data = cp.array(data) - psf = cp.array(psf) - with expected_warnings(["`iterations` is a deprecated argument"]): - restoration.richardson_lucy(data, psf, iterations=5) - - @pytest.mark.parametrize('dtype_image', [cp.float16, cp.float32, cp.float64]) @pytest.mark.parametrize('dtype_psf', [cp.float32, cp.float64]) @testing.with_requires("scikit-image>=0.18") diff --git a/python/cucim/src/cucim/skimage/segmentation/_chan_vese.py b/python/cucim/src/cucim/skimage/segmentation/_chan_vese.py index bab2f1887..84a7feb2b 100644 --- a/python/cucim/src/cucim/skimage/segmentation/_chan_vese.py +++ b/python/cucim/src/cucim/skimage/segmentation/_chan_vese.py @@ -8,37 +8,6 @@ from .._vendored import pad -@cp.fuse() -def _fused_curvature(phi, x_start, x_end, y_start, y_end, ul, ur, ll, lr): - fy = (y_end - y_start) / 2.0 - fx = (x_end - x_start) / 2.0 - fyy = y_end + y_start - 2 * phi - fxx = x_end + x_start - 2 * phi - fxy = .25 * (lr + ul - ur - ll) - grad2 = fx**2 + fy**2 - K = (fxx * fy**2 - 2 * fxy * fx * fy + fyy * fx**2) - K /= (grad2 * cp.sqrt(grad2) + 1e-8) - return K - - -def _cv_curvature(phi): - """Returns the 'curvature' of a level set 'phi'. 
-    """
-    P = pad(phi, 1, mode='edge')
-    y_start = P[:-2, 1:-1]
-    y_end = P[2:, 1:-1]
-    x_start = P[1:-1, :-2]
-    x_end = P[1:-1, 2:]
-
-    lower_right = P[2:, 2:]
-    lower_left = P[2:, :-2]
-    upper_right = P[:-2, 2:]
-    upper_left = P[:-2, :-2]
-    K = _fused_curvature(phi, x_start, x_end, y_start, y_end, upper_left,
-                         upper_right, lower_left, lower_right)
-    return K
-
-
 @cp.fuse()
 def _fused_variance_kernel1(eta, x_start, x_mid, x_end, y_start, y_mid, y_end):
     phixp = x_end - x_mid
@@ -102,6 +71,21 @@ def _fused_variance_kernel2(
 
 def _cv_calculate_variation(image, phi, mu, lambda1, lambda2, dt):
     """Returns the variation of level set 'phi' based on algorithm parameters.
+
+    This corresponds to equation (22) of the paper by Pascal Getreuer,
+    which computes the next iteration of the level set based on a current
+    level set.
+
+    A full explanation regarding all the terms is beyond the scope of the
+    present description, but there is one difference of particular import.
+    In the original algorithm, convergence is accelerated, and required
+    memory is reduced, by using a single array. This array, therefore, is a
+    combination of non-updated and updated values. If this were to be
+    implemented in Python, this would require a double loop, where the
+    benefits of having fewer iterations would be outweighed by massively
+    increasing the time required to perform each individual iteration. A
+    similar approach is used by Rami Cohen, and it is from there that the
+    C1-4 notation is taken.
     """
     eta = 1e-16
     P = pad(phi, 1, mode='edge')
@@ -183,17 +167,44 @@ def _cv_difference_from_average_term(image, Hphi, lambda_pos, lambda_neg):
     return out
 
 
+@cp.fuse()
+def _fused_edge_length(mu, phi, x_start, x_end, y_start, y_end):
+    fy = (y_end - y_start) / 2.0
+    fx = (x_end - x_start) / 2.0
+    grad2 = fx**2 + fy**2
+
+    del_phi = 1.0 / (1.0 + phi * phi)  # _cv_delta(phi)
+    K = mu * del_phi * cp.sqrt(grad2)
+    return K
+
+
 def _cv_edge_length_term(phi, mu):
-    """Returns the 'energy' contribution due to the length of the
-    edge between regions at each point, multiplied by a factor 'mu'.
+    """Returns the 'curvature' of a level set 'phi'.
     """
-    e = _cv_curvature(phi)
-    e *= mu
-    return e
+    P = pad(phi, 1, mode='edge')
+    y_start = P[:-2, 1:-1]
+    y_end = P[2:, 1:-1]
+    x_start = P[1:-1, :-2]
+    x_end = P[1:-1, 2:]
+
+    K = _fused_edge_length(mu, phi, x_start, x_end, y_start, y_end)
+    return K
 
 
 def _cv_energy(image, phi, mu, lambda1, lambda2):
     """Returns the total 'energy' of the current level set function.
+
+    This corresponds to equation (7) of the paper by Pascal Getreuer,
+    which is the weighted sum of the following:
+    (A) the length of the contour produced by the zero values of the
+        level set,
+    (B) the area of the "foreground" (area of the image where the
+        level set is positive),
+    (C) the variance of the image inside the foreground,
+    (D) the variance of the image outside of the foreground.
+
+    Each value is computed for each pixel, and then summed. The weight
+    of (B) is set to 0 in this implementation.
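For reference, the energy described above can be written down directly with a smoothed Heaviside; the following standalone sketch (my own illustration of the structure of equation (7), with the area term weighted by zero, not the fused kernels used in this module)::

    import cupy as cp

    def chan_vese_energy(image, phi, mu=0.25, lambda1=1.0, lambda2=1.0, eps=1e-8):
        H = 0.5 * (1.0 + (2.0 / cp.pi) * cp.arctan(phi))       # smoothed Heaviside
        c1 = (image * H).sum() / (H.sum() + eps)                # mean inside
        c2 = (image * (1 - H)).sum() / ((1 - H).sum() + eps)    # mean outside
        gy, gx = cp.gradient(H)
        length = cp.sqrt(gx * gx + gy * gy).sum()               # (A) contour length
        inside = ((image - c1) ** 2 * H).sum()                  # (C) inside variance
        outside = ((image - c2) ** 2 * (1 - H)).sum()           # (D) outside variance
        return float(mu * length + lambda1 * inside + lambda2 * outside)

    img = cp.zeros((32, 32))
    img[8:24, 8:24] = 1.0
    yy, xx = cp.mgrid[:32, :32]
    phi = 10.0 - cp.sqrt((yy - 16.0) ** 2 + (xx - 16.0) ** 2)   # circular level set
    print(chan_vese_energy(img, phi))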
""" H = _cv_heavyside(phi) avgenergy = _cv_difference_from_average_term(image, H, lambda1, lambda2) diff --git a/python/cucim/src/cucim/skimage/segmentation/_join.py b/python/cucim/src/cucim/skimage/segmentation/_join.py index 3744de969..6227a2fea 100644 --- a/python/cucim/src/cucim/skimage/segmentation/_join.py +++ b/python/cucim/src/cucim/skimage/segmentation/_join.py @@ -3,7 +3,7 @@ from ..util._map_array import ArrayMap, map_array -def join_segmentations(s1, s2): +def join_segmentations(s1, s2, return_mapping: bool = False): """Return the join of the two input segmentations. The join J of S1 and S2 is defined as the segmentation in which two @@ -14,11 +14,18 @@ def join_segmentations(s1, s2): ---------- s1, s2 : numpy arrays s1 and s2 are label fields of the same shape. + return_mapping : bool, optional + If true, return mappings for joined segmentation labels to the original + labels. Returns ------- j : numpy array The join segmentation of s1 and s2. + map_j_to_s1 : ArrayMap, optional + Mapping from labels of the joined segmentation j to labels of s1. + map_j_to_s2 : ArrayMap, optional + Mapping from labels of the joined segmentation j to labels of s2. Examples -------- @@ -38,11 +45,21 @@ def join_segmentations(s1, s2): if s1.shape != s2.shape: raise ValueError("Cannot join segmentations of different shape. " f"s1.shape: {s1.shape}, s2.shape: {s2.shape}") - s1 = relabel_sequential(s1)[0] - s2 = relabel_sequential(s2)[0] - j = (s2.max() + 1) * s1 + s2 - j = relabel_sequential(j)[0] - return j + s1_relabeled, _, backward_map1 = relabel_sequential(s1) + s2_relabeled, _, backward_map2 = relabel_sequential(s2) + factor = s2.max() + 1 + j_initial = factor * s1_relabeled + s2_relabeled + j, _, map_j_to_j_initial = relabel_sequential(j_initial) + if not return_mapping: + return j + # Determine label mapping + labels_j = cp.unique(j_initial) + labels_s1_relabeled, labels_s2_relabeled = cp.divmod(labels_j, factor) + map_j_to_s1 = ArrayMap(map_j_to_j_initial.in_values, + backward_map1[labels_s1_relabeled]) + map_j_to_s2 = ArrayMap(map_j_to_j_initial.in_values, + backward_map2[labels_s2_relabeled]) + return j, map_j_to_s1, map_j_to_s2 def relabel_sequential(label_field, offset=1): diff --git a/python/cucim/src/cucim/skimage/segmentation/tests/test_join.py b/python/cucim/src/cucim/skimage/segmentation/tests/test_join.py index af528073c..e68df7b8e 100644 --- a/python/cucim/src/cucim/skimage/segmentation/tests/test_join.py +++ b/python/cucim/src/cucim/skimage/segmentation/tests/test_join.py @@ -23,6 +23,12 @@ def test_join_segmentations(): [0, 5, 3, 2], [4, 5, 5, 3]]) assert_array_equal(j, j_ref) + + # test correct mapping + j, m1, m2 = join_segmentations(s1, s2, return_mapping=True) + assert_array_equal(m1[j], s1) + assert_array_equal(m2[j], s2) + # fmt: on # test correct exception when arrays are different shapes s3 = cp.array([[0, 0, 1, 1], [0, 2, 2, 1]]) diff --git a/python/cucim/src/cucim/skimage/transform/_geometric.py b/python/cucim/src/cucim/skimage/transform/_geometric.py index bb6bde4bf..4da5b11f4 100644 --- a/python/cucim/src/cucim/skimage/transform/_geometric.py +++ b/python/cucim/src/cucim/skimage/transform/_geometric.py @@ -487,6 +487,39 @@ class EssentialMatrixTransform(FundamentalMatrixTransform): params : (3, 3) ndarray Essential matrix. + Examples + -------- + >>> import cupy as cp + >>> from cucim.skimage import transform + >>> + >>> tform_matrix = transform.EssentialMatrixTransform( + ... rotation=cp.eye(3), translation=cp.array([0, 0, 1]) + ... 
+    >>> tform_matrix.params
+    array([[ 0., -1.,  0.],
+           [ 1.,  0.,  0.],
+           [ 0.,  0.,  0.]])
+    >>> src = cp.array([[ 1.839035,  1.924743],
+    ...                 [ 0.543582,  0.375221],
+    ...                 [ 0.47324 ,  0.142522],
+    ...                 [ 0.96491 ,  0.598376],
+    ...                 [ 0.102388,  0.140092],
+    ...                 [15.994343,  9.622164],
+    ...                 [ 0.285901,  0.430055],
+    ...                 [ 0.09115 ,  0.254594]])
+    >>> dst = cp.array([[1.002114, 1.129644],
+    ...                 [1.521742, 1.846002],
+    ...                 [1.084332, 0.275134],
+    ...                 [0.293328, 0.588992],
+    ...                 [0.839509, 0.08729 ],
+    ...                 [1.779735, 1.116857],
+    ...                 [0.878616, 0.602447],
+    ...                 [0.642616, 1.028681]])
+    >>> tform_matrix.estimate(src, dst)
+    True
+    >>> tform_matrix.residuals(src, dst)
+    array([0.42455187, 0.01460448, 0.13847034, 0.12140951, 0.27759346,
+           0.32453118, 0.00210776, 0.26512283])
     """

     # CuPy Backend: if matrix is None cannot infer array module from it
@@ -874,14 +907,26 @@ class AffineTransform(ProjectiveTransform):

     Has the following form::

-        X = a0*x + a1*y + a2 =
-          = sx*x*cos(rotation) - sy*y*sin(rotation + shear) + a2
+        X = a0 * x + a1 * y + a2
+          = sx * x * [cos(rotation) + tan(shear_y) * sin(rotation)]
+            - sy * y * [tan(shear_x) * cos(rotation) + sin(rotation)]
+            + translation_x
+
+        Y = b0 * x + b1 * y + b2
+          = sx * x * [sin(rotation) - tan(shear_y) * cos(rotation)]
+            - sy * y * [tan(shear_x) * sin(rotation) - cos(rotation)]
+            + translation_y
+
+    where ``sx`` and ``sy`` are scale factors in the x and y directions.

-        Y = b0*x + b1*y + b2 =
-          = sx*x*sin(rotation) + sy*y*cos(rotation + shear) + b2
+    This is equivalent to applying the operations in the following order:

-    where ``sx`` and ``sy`` are scale factors in the x and y directions,
-    and the homogeneous transformation matrix is::
+    1. Scale
+    2. Shear
+    3. Rotate
+    4. Translate
+
+    The homogeneous transformation matrix is::

         [[a0  a1  a2]
          [b0  b1  b2]
@@ -909,11 +954,12 @@ class AffineTransform(ProjectiveTransform):

         .. versionadded:: 0.17
            Added support for supplying a single scalar value.
     rotation : float, optional
-        Rotation angle in counter-clockwise direction as radians. Only
-        available for 2D.
-    shear : float, optional
-        Shear angle in counter-clockwise direction as radians. Only available
-        for 2D.
+        Rotation angle, clockwise, as radians. Only available for 2D.
+    shear : float or 2-tuple of float, optional
+        The x and y shear angles, clockwise, by which these axes are
+        rotated around the origin [2]_.
+        If a single value is given, take that to be the x shear angle, with
+        the y angle remaining 0. Only available in 2D.
     translation : (tx, ty) as ndarray, list or tuple, optional
         Translation parameters. Only available for 2D.
     dimensionality : int, optional
@@ -929,6 +975,39 @@ class AffineTransform(ProjectiveTransform):
     ------
     ValueError
         If both ``matrix`` and any of the other parameters are provided.
+
+    Examples
+    --------
+    >>> import cupy as cp
+    >>> from cucim.skimage import transform
+    >>> from skimage import data
+    >>> img = cp.array(data.astronaut())
+
+    Define source and destination points:
+
+    >>> src = cp.array([[150, 150],
+    ...                 [250, 100],
+    ...                 [150, 200]])
+    >>> dst = cp.array([[200, 200],
+    ...                 [300, 150],
+    ...                 [150, 400]])
+
+    Estimate the transformation matrix:

+    >>> tform = transform.AffineTransform()
+    >>> tform.estimate(src, dst)
+    True
+
+    Apply the transformation:
+
+    >>> warped = transform.warp(img, inverse_map=tform.inverse)
+
+    References
+    ----------
+    .. [1] Wikipedia, "Affine transformation",
+           https://en.wikipedia.org/wiki/Affine_transformation#Image_transformation
+    .. [2] Wikipedia, "Shear mapping",
+           https://en.wikipedia.org/wiki/Shear_mapping
     """

     def __init__(self, matrix=None, scale=None, rotation=None, shear=None,
@@ -968,17 +1047,27 @@ def __init__(self, matrix=None, scale=None, rotation=None, shear=None,
             else:
                 sx, sy = scale

-            # fmt: off
-            self.params = np.array(
-                [
-                    [sx * _cos(rotation), -sy * _sin(rotation + shear), 0],  # NOQA
-                    [sx * _sin(rotation),  sy * _cos(rotation + shear), 0],  # NOQA
-                    [                  0,                            0, 1],  # NOQA
-                ]
+            if np.isscalar(shear):
+                shear_x, shear_y = (shear, 0)
+            else:
+                shear_x, shear_y = shear
+
+            a0 = sx * (
+                math.cos(rotation) + math.tan(shear_y) * math.sin(rotation)
             )
-            # fmt: on
-            self.params[0:2, 2] = translation
-            self.params = xp.asarray(self.params)
+            a1 = -sy * (
+                math.tan(shear_x) * math.cos(rotation) + math.sin(rotation)
+            )
+            a2 = translation[0]
+
+            b0 = sx * (
+                math.sin(rotation) - math.tan(shear_y) * math.cos(rotation)
+            )
+            b1 = -sy * (
+                math.tan(shear_x) * math.sin(rotation) - math.cos(rotation)
+            )
+            b2 = translation[1]
+            self.params = xp.array([[a0, a1, a2], [b0, b1, b2], [0, 0, 1]])
         else:
             # default to an identity transform
             self.params = xp.eye(dimensionality + 1)
@@ -986,9 +1075,14 @@ def __init__(self, matrix=None, scale=None, rotation=None, shear=None,
     @property
     def scale(self):
         xp = cp.get_array_module(self.params)
-        return xp.sqrt(xp.sum(self.params * self.params, axis=0))[
-            : self.dimensionality
-        ]
+        if self.dimensionality != 2:
+            return xp.sqrt(xp.sum(self.params * self.params, axis=0))[
+                : self.dimensionality
+            ]
+        else:
+            ss = xp.sum(self.params * self.params, axis=0)
+            ss[1] = ss[1] / (math.tan(self.shear) ** 2 + 1)
+            return xp.sqrt(ss)[:self.dimensionality]

     @property
     def rotation(self):
@@ -1066,7 +1160,13 @@ def estimate(self, src, dst):

         # find affine mapping from source positions to destination
         self.affines = []
-        for tri in xp.asarray(self._tesselation.vertices):
+
+        try:
+            tesselation_simplices = self._tesselation.simplices
+        except AttributeError:
+            # vertices is deprecated and scheduled for removal in SciPy 1.11
+            tesselation_simplices = self._tesselation.vertices
+        for tri in xp.asarray(tesselation_simplices):
             affine = AffineTransform(dimensionality=ndim)
             ok &= affine.estimate(src[tri, :], dst[tri, :])
             self.affines.append(affine)
@@ -1077,7 +1177,12 @@ def estimate(self, src, dst):
         self._inverse_tesselation = spatial.Delaunay(cp.asnumpy(dst))
         # find affine mapping from source positions to destination
         self.inverse_affines = []
-        for tri in xp.asarray(self._inverse_tesselation.vertices):
+        try:
+            inv_tesselation_simplices = self._inverse_tesselation.simplices
+        except AttributeError:
+            # vertices is deprecated and scheduled for removal in SciPy 1.11
+            inv_tesselation_simplices = self._inverse_tesselation.vertices
+        for tri in xp.asarray(inv_tesselation_simplices):
             affine = AffineTransform(dimensionality=ndim)
             ok &= affine.estimate(dst[tri, :], src[tri, :])
             self.inverse_affines.append(affine)
@@ -1112,7 +1217,12 @@ def __call__(self, coords):

         # coordinates outside of mesh
         out[simplex == -1, :] = -1
-        for index in range(len(self._tesselation.vertices)):
+        try:
+            tesselation_simplices = self._tesselation.simplices
+        except AttributeError:
+            # vertices is deprecated and scheduled for removal in SciPy 1.11
+            tesselation_simplices = self._tesselation.vertices
+        for index in range(len(tesselation_simplices)):
             # affine transform for triangle
             affine = self.affines[index]
             # all coordinates within triangle
@@ -1150,7 +1260,12 @@ def inverse(self, coords):

         # coordinates outside of mesh
         out[simplex == -1, :] = -1
-        for index in range(len(self._inverse_tesselation.vertices)):
+        try:
+            inv_tesselation_simplices = self._inverse_tesselation.simplices
+        except AttributeError:
+            # vertices is deprecated and scheduled for removal in SciPy 1.11
+            inv_tesselation_simplices = self._inverse_tesselation.vertices
+        for index in range(len(inv_tesselation_simplices)):
             # affine transform for triangle
             affine = self.inverse_affines[index]
             # all coordinates within triangle
@@ -1244,7 +1359,7 @@ class EuclideanTransform(ProjectiveTransform):
     matrix : (D+1, D+1) ndarray, optional
         Homogeneous transformation matrix.
     rotation : float or sequence of float, optional
-        Rotation angle in counter-clockwise direction as radians. If given as
+        Rotation angle, clockwise, as radians. If given as
        a vector, it is interpreted as Euler rotation angles [1]_. Only 2D
        (single rotation) and 3D (Euler rotations) values are supported. For
        higher dimensions, you must provide or estimate the transformation
@@ -1390,7 +1505,7 @@ class SimilarityTransform(EuclideanTransform):
     scale : float, optional
         Scale factor. Implemented only for 2D and 3D.
     rotation : float, optional
-        Rotation angle in counter-clockwise direction as radians.
+        Rotation angle, clockwise, as radians.
         Implemented only for 2D and 3D. For 3D, this is given in XZX Euler
         angles.
     translation : (dim,) ndarray-like, optional
diff --git a/python/cucim/src/cucim/skimage/transform/_warps.py b/python/cucim/src/cucim/skimage/transform/_warps.py
index 88a395591..343f1f0f7 100644
--- a/python/cucim/src/cucim/skimage/transform/_warps.py
+++ b/python/cucim/src/cucim/skimage/transform/_warps.py
@@ -166,9 +166,6 @@ def resize(image, output_shape, order=None, mode='reflect', cval=0, clip=None,
     if clip is None:
         clip = True if order > 1 else False

-    # Save input value range for clip
-    img_bounds = (image.min(), image.max()) if clip else None
-
     # Translate modes used by np.pad to those used by scipy.ndimage
     ndi_mode = _to_ndimage_mode(mode)
     if anti_aliasing:
@@ -192,13 +189,15 @@ def resize(image, output_shape, order=None, mode='reflect', cval=0, clip=None,
         _ndi_mode = {'grid-constant': 'constant', 'grid-wrap': 'wrap'}.get(ndi_mode, ndi_mode)  # noqa
         # keep ndi.gaussian_filter rather than cucim.skimage.filters.gaussian
         # to avoid undesired dtype coercion
-        image = ndi.gaussian_filter(image, anti_aliasing_sigma, cval=cval,
-                                    mode=_ndi_mode)
+        filtered = ndi.gaussian_filter(image, anti_aliasing_sigma, cval=cval,
+                                       mode=_ndi_mode)
+    else:
+        filtered = image

     zoom_factors = [1 / f for f in factors]
-    out = ndi.zoom(image, zoom_factors, order=order, mode=ndi_mode,
+    out = ndi.zoom(filtered, zoom_factors, order=order, mode=ndi_mode,
                    cval=cval, grid_mode=True)
-    _clip_warp_output(img_bounds, out, mode, cval, order, clip)
+    _clip_warp_output(image, out, mode, cval, order, clip)

     return out

@@ -821,7 +820,7 @@ def _clip_warp_output(input_image, output_image, mode, cval, order, clip):

         cp.clip(output_image, min_val, max_val, out=output_image)


-def warp(image, inverse_map, map_args={}, output_shape=None, order=None,
+def warp(image, inverse_map, map_args=None, output_shape=None, order=None,
          mode='constant', cval=0., clip=None, preserve_range=False):
     """Warp an image according to a given coordinate transformation.
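Note on the ``map_args={}`` -> ``map_args=None`` change above: default argument
values are evaluated once at function definition time, so a mutable default
such as ``{}`` is shared by every call. The ``None`` sentinel plus a per-call
assignment (added in the next hunk) avoids that. A minimal sketch of the
pitfall, using a purely illustrative ``remember`` helper that is not part of
cuCIM::

    def remember(value, seen=[]):
        # the same list object is reused by every call (pitfall)
        seen.append(value)
        return seen

    remember(1)   # [1]
    remember(2)   # [1, 2]  <- state leaks between unrelated calls

    def remember_fixed(value, seen=None):
        # same pattern now used by warp(): create a fresh object per call
        if seen is None:
            seen = []
        seen.append(value)
        return seen

    remember_fixed(1)   # [1]
    remember_fixed(2)   # [2]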
@@ -963,6 +962,9 @@ def warp(image, inverse_map, map_args={}, output_shape=None, order=None,
     >>> warped = warp(cube, coords)

     """  # noqa
+    if map_args is None:
+        map_args = {}
+
     if image.size == 0:
         raise ValueError("Cannot warp empty image with dimensions",
                          image.shape)
diff --git a/python/cucim/src/cucim/skimage/transform/tests/test_geometric.py b/python/cucim/src/cucim/skimage/transform/tests/test_geometric.py
index 2b2891542..d9aee75db 100644
--- a/python/cucim/src/cucim/skimage/transform/tests/test_geometric.py
+++ b/python/cucim/src/cucim/skimage/transform/tests/test_geometric.py
@@ -269,6 +269,38 @@ def test_affine_init():
                               AffineTransform(scale=(0.5, 0.5)).scale)


+@pytest.mark.parametrize("xp", [np, cp])
+def test_affine_shear(xp):
+    shear = 0.1
+    # expected horizontal shear transform
+    cx = -np.tan(shear)
+    # fmt: off
+    expected = xp.array([
+        [1, cx, 0],
+        [0,  1, 0],  # noqa: E241
+        [0,  0, 1],  # noqa: E241
+    ])
+    # fmt: on
+
+    tform = AffineTransform(shear=shear, xp=xp)
+    xp.testing.assert_array_almost_equal(tform.params, expected)
+
+    shear = (1.2, 0.8)
+    # expected x, y shear transform
+    cx = -np.tan(shear[0])
+    cy = -np.tan(shear[1])
+    # fmt: off
+    expected = xp.array([
+        [ 1, cx, 0],  # noqa: E201
+        [cy,  1, 0],  # noqa: E241
+        [ 0,  0, 1],  # noqa: E241, E201
+    ])
+    # fmt: on
+
+    tform = AffineTransform(shear=shear, xp=xp)
+    xp.testing.assert_array_almost_equal(tform.params, expected)
+
+
 def test_piecewise_affine():
     tform = PiecewiseAffineTransform()
     assert tform.estimate(SRC, DST)
diff --git a/python/cucim/src/cucim/skimage/transform/tests/test_warps.py b/python/cucim/src/cucim/skimage/transform/tests/test_warps.py
index de7bbb06c..dbb47738e 100644
--- a/python/cucim/src/cucim/skimage/transform/tests/test_warps.py
+++ b/python/cucim/src/cucim/skimage/transform/tests/test_warps.py
@@ -472,10 +472,12 @@ def test_resize_clip(order, preserve_range, anti_aliasing, dtype):
     x = cp.ones((5, 5), dtype=dtype)
     if dtype == cp.uint8:
         x *= 255
+    else:
+        x[0, 0] = cp.NaN
     resized = resize(x, (3, 3), order=order, preserve_range=preserve_range,
                      anti_aliasing=anti_aliasing)
-    assert abs(float(resized.max()) - expected_max) < 1e-14
+    assert abs(float(cp.nanmax(resized)) - expected_max) < 1e-14


 @pytest.mark.parametrize('dtype', [cp.float16, cp.float32, cp.float64])
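Note on the switch from ``resized.max()`` to ``cp.nanmax(resized)`` above:
floating-point test inputs now contain a NaN, which propagates through the
interpolation, so only the non-NaN output values can be compared against the
expected maximum. A minimal sketch of the behaviour being exercised (the
shape, order, and tolerance below are illustrative choices, not the test's
own parameters)::

    import cupy as cp
    from cucim.skimage.transform import resize

    x = cp.ones((5, 5), dtype=cp.float64)
    x[0, 0] = cp.nan                        # NaN in the input ...
    out = resize(x, (3, 3), order=1, anti_aliasing=False)
    assert bool(cp.isnan(out).any())        # ... propagates into the output,
    # so the remaining (finite) values are checked with nanmax instead of max
    assert abs(float(cp.nanmax(out)) - 1.0) < 1e-12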