Skip to content

Commit

Permalink
pass all constants to _fused_variance_kernel2 as device scalars
Browse files Browse the repository at this point in the history
* avoids an issue with an internal numpy.can_cast call in CuPy's kernel fusion code when using NumPy 2.0
  • Loading branch information
grlee77 committed Aug 20, 2024
1 parent ef99071 commit 59b3ad4
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions python/cucim/src/cucim/skimage/segmentation/_chan_vese.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .._shared.utils import _supported_float_type
from .._vendored import pad

_one = cp.asarray(1.0, dtype=cp.float32)


@cp.fuse()
def _fused_variance_kernel1(eta, x_start, x_mid, x_end, y_start, y_mid, y_end):
Expand Down Expand Up @@ -53,7 +55,7 @@ def _fused_hphi_hinv(phi):

@cp.fuse()
def _fused_variance_kernel2(
image, c1, c2, lam1, lam2, phi, K, dt, mu, delta_phi, Csum
image, c1, c2, lam1, lam2, phi, K, dt, mu, delta_phi, Csum, one
):
difference_term = image - c1
difference_term *= difference_term
Expand All @@ -65,7 +67,7 @@ def _fused_variance_kernel2(
difference_term += term2

new_phi = phi + (dt * delta_phi) * (mu * K + difference_term)
out = new_phi / (1 + mu * dt * delta_phi * Csum)
out = new_phi / (one + mu * dt * delta_phi * Csum)
return out


Expand Down Expand Up @@ -105,7 +107,7 @@ def _cv_calculate_variation(image, phi, mu, lambda1, lambda2, dt):
c1, c2 = _cv_calculate_averages(image, Hphi, Hinv)
delta_phi = _cv_delta(phi)
out = _fused_variance_kernel2(
image, c1, c2, lambda1, lambda2, phi, K, dt, mu, delta_phi, Csum
image, c1, c2, lambda1, lambda2, phi, K, dt, mu, delta_phi, Csum, _one
)
return out

Expand Down Expand Up @@ -430,6 +432,9 @@ def chan_vese(
phivar = tol + 1

dt = cp.asarray(dt, dtype=float_dtype)
mu = cp.asarray(mu, dtype=float_dtype)
lambda1 = cp.asarray(lambda1, dtype=float_dtype)
lambda2 = cp.asarray(lambda2, dtype=float_dtype)
while phivar > tol and i < max_num_iter:
# Save old level set values
oldphi = phi
Expand Down

0 comments on commit 59b3ad4

Please sign in to comment.