From 59b3ad4a8b21591b84825aa64439576c69cd06e3 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Tue, 20 Aug 2024 18:08:30 -0400 Subject: [PATCH] pass all constants to _fused_variance_kernel2 as device scalars * avoids an issue with an internal numpy.can_cast call in CuPy's kernel fusion code when using NumPy 2.0 --- .../src/cucim/skimage/segmentation/_chan_vese.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/cucim/src/cucim/skimage/segmentation/_chan_vese.py b/python/cucim/src/cucim/skimage/segmentation/_chan_vese.py index ea33f26b5..8a08838d2 100644 --- a/python/cucim/src/cucim/skimage/segmentation/_chan_vese.py +++ b/python/cucim/src/cucim/skimage/segmentation/_chan_vese.py @@ -7,6 +7,8 @@ from .._shared.utils import _supported_float_type from .._vendored import pad +_one = cp.asarray(1.0, dtype=cp.float32) + @cp.fuse() def _fused_variance_kernel1(eta, x_start, x_mid, x_end, y_start, y_mid, y_end): @@ -53,7 +55,7 @@ def _fused_hphi_hinv(phi): @cp.fuse() def _fused_variance_kernel2( - image, c1, c2, lam1, lam2, phi, K, dt, mu, delta_phi, Csum + image, c1, c2, lam1, lam2, phi, K, dt, mu, delta_phi, Csum, one ): difference_term = image - c1 difference_term *= difference_term @@ -65,7 +67,7 @@ def _fused_variance_kernel2( difference_term += term2 new_phi = phi + (dt * delta_phi) * (mu * K + difference_term) - out = new_phi / (1 + mu * dt * delta_phi * Csum) + out = new_phi / (one + mu * dt * delta_phi * Csum) return out @@ -105,7 +107,7 @@ def _cv_calculate_variation(image, phi, mu, lambda1, lambda2, dt): c1, c2 = _cv_calculate_averages(image, Hphi, Hinv) delta_phi = _cv_delta(phi) out = _fused_variance_kernel2( - image, c1, c2, lambda1, lambda2, phi, K, dt, mu, delta_phi, Csum + image, c1, c2, lambda1, lambda2, phi, K, dt, mu, delta_phi, Csum, _one ) return out @@ -430,6 +432,9 @@ def chan_vese( phivar = tol + 1 dt = cp.asarray(dt, dtype=float_dtype) + mu = cp.asarray(mu, dtype=float_dtype) + lambda1 = cp.asarray(lambda1, dtype=float_dtype) + lambda2 = cp.asarray(lambda2, dtype=float_dtype) while phivar > tol and i < max_num_iter: # Save old level set values oldphi = phi