From 93af2a579874edb914657ffbc27fc315bff6abf3 Mon Sep 17 00:00:00 2001 From: Max Charles Date: Tue, 15 Oct 2024 13:42:55 +1100 Subject: [PATCH] added 'method' argument for psf.convolve (#278) added 'method' argument for psf.convolve --- src/dLux/psfs.py | 9 +++++++-- tests/test_psfs.py | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/dLux/psfs.py b/src/dLux/psfs.py index f7fa6795..08cad9a6 100644 --- a/src/dLux/psfs.py +++ b/src/dLux/psfs.py @@ -79,7 +79,7 @@ def downsample(self: PSF, n: int) -> PSF: """ return self.set("data", dlu.downsample(self.data, n, mean=False)) - def convolve(self: PSF, other: Array) -> PSF: + def convolve(self: PSF, other: Array, method: str = "auto") -> PSF: """ Convolves the psf with some input array. @@ -87,13 +87,18 @@ def convolve(self: PSF, other: Array) -> PSF: ---------- other : Array The psf to convolve with. + method : str = "auto" + The method to use for the convolution. Can be "auto", "direct", + or "fft". Is "auto" by default, which calls "direct". Returns ------- psf : PSF The convolved psf. """ - return self.set("data", convolve(self.data, other, mode="same")) + return self.set( + "data", convolve(self.data, other, mode="same", method=method) + ) def rotate(self: PSF, angle: float, order: int = 1) -> PSF: """ diff --git a/tests/test_psfs.py b/tests/test_psfs.py index 8449c5c5..653d93fa 100644 --- a/tests/test_psfs.py +++ b/tests/test_psfs.py @@ -21,6 +21,7 @@ def test_properties(self, psf): def test_methods(self, psf): assert psf.downsample(2).npixels == 8 assert isinstance(psf.convolve(np.ones((2, 2))), PSF) + assert isinstance(psf.convolve(np.ones((2, 2)), method="fft"), PSF) assert isinstance(psf.rotate(np.pi), PSF) assert isinstance(psf.resize(8), PSF) assert isinstance(psf.flip(0), PSF)