Skip to content

Commit

Permalink
Added decorator to support scalar arguments in methods of droplet class
Browse files Browse the repository at this point in the history
This functionality was ported from `py-pde` where it is now deprecated.
  • Loading branch information
david-zwicker committed Oct 15, 2024
1 parent 080535a commit bef2a64
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 10 deletions.
20 changes: 10 additions & 10 deletions droplets/droplets.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@
from pde.fields import ScalarField
from pde.grids.base import GridBase
from pde.tools.cuboid import Cuboid
from pde.tools.misc import preserve_scalars
from pde.tools.plotting import PlotReference, plot_on_axes

from .tools import spherical
from .tools.misc import enable_scalar_args

TDroplet = TypeVar("TDroplet", bound="DropletBase")
DTypeList = list[Union[tuple[str, type[Any]], tuple[str, type[Any], tuple[int, ...]]]]
Expand Down Expand Up @@ -399,7 +399,7 @@ def merge_data(drop1: np.ndarray, drop2: np.ndarray, out: np.ndarray) -> None:

return merge_data # type: ignore

@preserve_scalars
@enable_scalar_args
def interface_position(self, *args) -> np.ndarray:
r"""Calculates the position of the interface of the droplet.
Expand Down Expand Up @@ -914,7 +914,7 @@ def __init__(
"the highest mode."
)

@preserve_scalars
@enable_scalar_args
def interface_distance(self, φ: np.ndarray) -> np.ndarray: # type: ignore
"""Calculates the distance of the droplet interface to the origin.
Expand All @@ -933,7 +933,7 @@ def interface_distance(self, φ: np.ndarray) -> np.ndarray: # type: ignore
dist += b * np.cos(n * φ)
return self.radius * dist

@preserve_scalars
@enable_scalar_args
def interface_position(self, φ: np.ndarray) -> np.ndarray:
"""Calculates the position of the interface of the droplet.
Expand All @@ -948,7 +948,7 @@ def interface_position(self, φ: np.ndarray) -> np.ndarray:
pos = dist[:, None] * np.transpose([np.cos(φ), np.sin(φ)])
return self.position[None, :] + pos # type: ignore

@preserve_scalars
@enable_scalar_args
def interface_curvature(self, φ: np.ndarray) -> np.ndarray: # type: ignore
r"""Calculates the mean curvature of the interface of the droplet.
Expand Down Expand Up @@ -1097,7 +1097,7 @@ def __init__(
opt_modes,
)

@preserve_scalars
@enable_scalar_args
def interface_distance( # type: ignore
self, θ: np.ndarray, φ: np.ndarray | None = None
) -> np.ndarray:
Expand All @@ -1122,7 +1122,7 @@ def interface_distance( # type: ignore
dist += a * spherical.spherical_harmonic_real_k(k, θ, φ) # type: ignore
return self.radius * dist

@preserve_scalars
@enable_scalar_args
def interface_position(
self, θ: np.ndarray, φ: np.ndarray | None = None
) -> np.ndarray:
Expand All @@ -1146,7 +1146,7 @@ def interface_position(
pos = dist[:, None] * np.transpose(unit_vector)
return self.position[None, :] + pos # type: ignore

@preserve_scalars
@enable_scalar_args
def interface_curvature( # type: ignore
self, θ: np.ndarray, φ: np.ndarray | None = None
) -> np.ndarray:
Expand Down Expand Up @@ -1231,7 +1231,7 @@ def check_data(self):
if not np.allclose(self.position[:2], 0):
raise ValueError("Droplet must lie on z-axis")

@preserve_scalars
@enable_scalar_args
def interface_distance(self, θ: np.ndarray) -> np.ndarray: # type: ignore
r"""Calculates the distance of the droplet interface to the origin.
Expand All @@ -1248,7 +1248,7 @@ def interface_distance(self, θ: np.ndarray) -> np.ndarray: # type: ignore
dist += a * spherical.spherical_harmonic_symmetric(order, θ) # type: ignore
return self.radius * dist

@preserve_scalars
@enable_scalar_args
def interface_curvature(self, θ: np.ndarray) -> np.ndarray: # type: ignore
r"""Calculates the mean curvature of the interface of the droplet.
Expand Down
45 changes: 45 additions & 0 deletions droplets/tools/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Miscellaneous functions.
.. autosummary::
:nosignatures:
enable_scalar_args
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

from __future__ import annotations

import functools
from typing import Any, Callable, TypeVar

from pde.tools.misc import number_array

TFunc = TypeVar("TFunc", bound=Callable[..., Any])


def enable_scalar_args(method: TFunc) -> TFunc:
"""Decorator that makes vectorized methods work with scalars.
This decorator allows to call functions that are written to work on numpy arrays to
also accept python scalars, like `int` and `float`. Essentially, this wrapper turns
them into an array and unboxes the result. Note that the dtype of the returned value
will always be double or cdouble even if the function is called with an integer.
Args:
method: The method being decorated
Returns:
The decorated method
"""

@functools.wraps(method)
def wrapper(self, *args):
args = [number_array(arg, copy=None) for arg in args]
if args[0].ndim == 0:
args = [arg[None] for arg in args]
return method(self, *args)[0]
else:
return method(self, *args)

return wrapper # type: ignore
24 changes: 24 additions & 0 deletions tests/tools/test_misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de>
"""

import numpy as np
from scipy import integrate

from droplets.tools import misc


def test_enable_scalar_args():
"""Test the enable_scalar_args decorator."""

class Test:
@misc.enable_scalar_args
def meth(self, arr):
return arr + 1

t = Test()

assert t.meth(1) == 2
assert isinstance(t.meth(1), float)
assert isinstance(t.meth(1.0), float)
np.testing.assert_equal(t.meth(np.ones(2)), np.full(2, 2))
File renamed without changes.

0 comments on commit bef2a64

Please sign in to comment.