From acfc508ebd35ad46bbccbd19f2eb57d34609f0e7 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 26 Sep 2024 12:38:36 +0800 Subject: [PATCH] Fix dtype bug in ScaleIntensityRangePercentile (#8109) Fixes #8108 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/intensity/array.py | 2 +- tests/test_scale_intensity_range_percentiles.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 3b813809e4..20000c52c4 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1411,7 +1411,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: else: img_t = self._normalize(img=img_t) - return convert_to_dst_type(img_t, dst=img)[0] + return convert_to_dst_type(img_t, dst=img, dtype=self.dtype)[0] class MaskIntensity(Transform): diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 7c3a684a00..a7390efe72 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -14,6 +14,7 @@ import unittest import numpy as np +import torch from monai.transforms.intensity.array import ScaleIntensityRangePercentiles from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose @@ -34,6 +35,7 @@ def test_scaling(self): scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8) for p in TEST_NDARRAYS: result = scaler(p(img)) + self.assertEqual(result.dtype, torch.uint8) assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4) def test_relative_scaling(self):