-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added decorator to support scalar arguments in methods of droplet class
This functionality was ported from `py-pde` where it is now deprecated.
- Loading branch information
1 parent
080535a
commit bef2a64
Showing
4 changed files
with
79 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
"""Miscellaneous functions. | ||
.. autosummary:: | ||
:nosignatures: | ||
enable_scalar_args | ||
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de> | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import functools | ||
from typing import Any, Callable, TypeVar | ||
|
||
from pde.tools.misc import number_array | ||
|
||
TFunc = TypeVar("TFunc", bound=Callable[..., Any]) | ||
|
||
|
||
def enable_scalar_args(method: TFunc) -> TFunc: | ||
"""Decorator that makes vectorized methods work with scalars. | ||
This decorator allows to call functions that are written to work on numpy arrays to | ||
also accept python scalars, like `int` and `float`. Essentially, this wrapper turns | ||
them into an array and unboxes the result. Note that the dtype of the returned value | ||
will always be double or cdouble even if the function is called with an integer. | ||
Args: | ||
method: The method being decorated | ||
Returns: | ||
The decorated method | ||
""" | ||
|
||
@functools.wraps(method) | ||
def wrapper(self, *args): | ||
args = [number_array(arg, copy=None) for arg in args] | ||
if args[0].ndim == 0: | ||
args = [arg[None] for arg in args] | ||
return method(self, *args)[0] | ||
else: | ||
return method(self, *args) | ||
|
||
return wrapper # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
""" | ||
.. codeauthor:: David Zwicker <david.zwicker@ds.mpg.de> | ||
""" | ||
|
||
import numpy as np | ||
from scipy import integrate | ||
|
||
from droplets.tools import misc | ||
|
||
|
||
def test_enable_scalar_args(): | ||
"""Test the enable_scalar_args decorator.""" | ||
|
||
class Test: | ||
@misc.enable_scalar_args | ||
def meth(self, arr): | ||
return arr + 1 | ||
|
||
t = Test() | ||
|
||
assert t.meth(1) == 2 | ||
assert isinstance(t.meth(1), float) | ||
assert isinstance(t.meth(1.0), float) | ||
np.testing.assert_equal(t.meth(np.ones(2)), np.full(2, 2)) |
File renamed without changes.