Skip to content

Commit

Permalink
Use matplotlib tri interpolators (#104)
Browse files Browse the repository at this point in the history
* Use matplotlib tri interpolators instead of scipy.interpolate. This is much faster because the matplotlib versions can use the existing triangulation.
  • Loading branch information
loganbvh authored Aug 3, 2023
1 parent cc8ae2e commit 7bf7f78
Show file tree
Hide file tree
Showing 13 changed files with 119 additions and 92 deletions.
10 changes: 5 additions & 5 deletions docs/notebooks/field-sources.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@
{
"data": {
"text/html": [
"<table><tr><th>Software</th><th>Version</th></tr><tr><td>SuperScreen</td><td>0.9.1</td></tr><tr><td>Numpy</td><td>1.23.3</td></tr><tr><td>Numba</td><td>0.57.0</td></tr><tr><td>SciPy</td><td>1.9.1</td></tr><tr><td>matplotlib</td><td>3.6.0</td></tr><tr><td>IPython</td><td>8.5.0</td></tr><tr><td>Python</td><td>3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n",
"[Clang 13.0.1 ]</td></tr><tr><td>OS</td><td>posix [darwin]</td></tr><tr><td>Number of CPUs</td><td>Physical: 10, Logical: 10</td></tr><tr><td>BLAS Info</td><td>OPENBLAS</td></tr><tr><td colspan='2'>Wed May 17 10:59:46 2023 PDT</td></tr></table>"
"<table><tr><th>Software</th><th>Version</th></tr><tr><td>SuperScreen</td><td>0.10.0</td></tr><tr><td>Numpy</td><td>1.23.3</td></tr><tr><td>Numba</td><td>0.57.0</td></tr><tr><td>SciPy</td><td>1.9.1</td></tr><tr><td>matplotlib</td><td>3.6.0</td></tr><tr><td>IPython</td><td>8.5.0</td></tr><tr><td>Python</td><td>3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n",
"[Clang 13.0.1 ]</td></tr><tr><td>OS</td><td>posix [darwin]</td></tr><tr><td>Number of CPUs</td><td>Physical: 10, Logical: 10</td></tr><tr><td>BLAS Info</td><td>OPENBLAS</td></tr><tr><td colspan='2'>Thu Aug 03 13:18:33 2023 PDT</td></tr></table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand Down Expand Up @@ -142,9 +142,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"x = 0.5517062950352116\n",
"y = 0.708038021214086\n",
"z = 0.6209502532087992\n",
"x = 0.45250693865850633\n",
"y = 0.26481336104450504\n",
"z = 0.4510648754100197\n",
"field(x, y, z) = 5.0\n"
]
}
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/logo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@
{
"data": {
"text/html": [
"<table><tr><th>Software</th><th>Version</th></tr><tr><td>SuperScreen</td><td>0.9.1</td></tr><tr><td>Numpy</td><td>1.23.3</td></tr><tr><td>Numba</td><td>0.57.0</td></tr><tr><td>SciPy</td><td>1.9.1</td></tr><tr><td>matplotlib</td><td>3.6.0</td></tr><tr><td>IPython</td><td>8.5.0</td></tr><tr><td>Python</td><td>3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n",
"[Clang 13.0.1 ]</td></tr><tr><td>OS</td><td>posix [darwin]</td></tr><tr><td>Number of CPUs</td><td>Physical: 10, Logical: 10</td></tr><tr><td>BLAS Info</td><td>OPENBLAS</td></tr><tr><td colspan='2'>Wed May 17 11:00:35 2023 PDT</td></tr></table>"
"<table><tr><th>Software</th><th>Version</th></tr><tr><td>SuperScreen</td><td>0.10.0</td></tr><tr><td>Numpy</td><td>1.23.3</td></tr><tr><td>Numba</td><td>0.57.0</td></tr><tr><td>SciPy</td><td>1.9.1</td></tr><tr><td>matplotlib</td><td>3.6.0</td></tr><tr><td>IPython</td><td>8.5.0</td></tr><tr><td>Python</td><td>3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n",
"[Clang 13.0.1 ]</td></tr><tr><td>OS</td><td>posix [darwin]</td></tr><tr><td>Number of CPUs</td><td>Physical: 10, Logical: 10</td></tr><tr><td>BLAS Info</td><td>OPENBLAS</td></tr><tr><td colspan='2'>Thu Aug 03 13:18:57 2023 PDT</td></tr></table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand Down
20 changes: 10 additions & 10 deletions docs/notebooks/polygons.ipynb

Large diffs are not rendered by default.

54 changes: 38 additions & 16 deletions docs/notebooks/quickstart.ipynb

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions docs/notebooks/scanning-squid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@
{
"data": {
"text/html": [
"<table><tr><th>Software</th><th>Version</th></tr><tr><td>SuperScreen</td><td>0.9.2</td></tr><tr><td>Numpy</td><td>1.23.3</td></tr><tr><td>Numba</td><td>0.57.0</td></tr><tr><td>SciPy</td><td>1.9.1</td></tr><tr><td>matplotlib</td><td>3.6.0</td></tr><tr><td>IPython</td><td>8.5.0</td></tr><tr><td>Python</td><td>3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n",
"[Clang 13.0.1 ]</td></tr><tr><td>OS</td><td>posix [darwin]</td></tr><tr><td>Number of CPUs</td><td>Physical: 10, Logical: 10</td></tr><tr><td>BLAS Info</td><td>OPENBLAS</td></tr><tr><td colspan='2'>Thu May 18 14:28:01 2023 PDT</td></tr></table>"
"<table><tr><th>Software</th><th>Version</th></tr><tr><td>SuperScreen</td><td>0.10.0</td></tr><tr><td>Numpy</td><td>1.23.3</td></tr><tr><td>Numba</td><td>0.57.0</td></tr><tr><td>SciPy</td><td>1.9.1</td></tr><tr><td>matplotlib</td><td>3.6.0</td></tr><tr><td>IPython</td><td>8.5.0</td></tr><tr><td>Python</td><td>3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n",
"[Clang 13.0.1 ]</td></tr><tr><td>OS</td><td>posix [darwin]</td></tr><tr><td>Number of CPUs</td><td>Physical: 10, Logical: 10</td></tr><tr><td>BLAS Info</td><td>OPENBLAS</td></tr><tr><td colspan='2'>Thu Aug 03 13:20:42 2023 PDT</td></tr></table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand Down Expand Up @@ -160,10 +160,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Solver iterations: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:02<00:00, 2.32it/s]\n",
"Solver iterations: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 2.74it/s]\n",
"Solver iterations: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:04<00:00, 1.19it/s]\n",
"Solver iterations: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:03<00:00, 1.29it/s]\n"
"Solver iterations: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:02<00:00, 2.41it/s]\n",
"Solver iterations: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 2.59it/s]\n",
"Solver iterations: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:03<00:00, 1.28it/s]\n",
"Solver iterations: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:03<00:00, 1.42it/s]\n"
]
}
],
Expand Down
24 changes: 16 additions & 8 deletions docs/notebooks/terminal-currents.ipynb

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions superscreen/device/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sp
from matplotlib.tri import Triangulation

from ..distance import q_matrix
from ..fem import gradient_vertices, laplace_operator
Expand Down Expand Up @@ -51,9 +52,19 @@ def __init__(
self.triangle_areas = np.asarray(triangle_areas)
self.edge_mesh = edge_mesh
self.operators: Optional[MeshOperators] = None
self._triangulation: Optional[Triangulation] = None
if build_operators:
self.operators = MeshOperators.from_mesh(self)

@property
def triangulation(self) -> Triangulation:
"""Matplotlib triangulation of the mesh."""
if self._triangulation is None:
self._triangulation = Triangulation(
self.sites[:, 0], self.sites[:, 1], self.elements
)
return self._triangulation

def stats(self) -> Dict[str, Union[int, float]]:
"""Returns a dictionary of information about the mesh."""
edge_lengths = self.edge_mesh.edge_lengths
Expand Down
48 changes: 18 additions & 30 deletions superscreen/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import h5py
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
import numpy as np
import pint
from scipy import interpolate

from .about import version_dict
from .device import Device, Polygon
Expand All @@ -33,7 +33,7 @@

logger = logging.getLogger("solution")

InterpolatorType = Literal["nearest", "linear", "cubic"]
InterpolatorType = Literal["linear", "cubic"]


class Fluxoid(NamedTuple):
Expand Down Expand Up @@ -271,9 +271,8 @@ def version_info(self) -> Dict[str, str]:
@staticmethod
def _select_interpolator(method: InterpolatorType) -> type:
return {
"nearest": interpolate.NearestNDInterpolator,
"linear": interpolate.LinearNDInterpolator,
"cubic": interpolate.CloughTocher2DInterpolator,
"linear": mtri.LinearTriInterpolator,
"cubic": mtri.CubicTriInterpolator,
}[method]

def interp_current_density(
Expand All @@ -284,20 +283,14 @@ def interp_current_density(
method: InterpolatorType = "linear",
units: Optional[str] = None,
with_units: bool = False,
**kwargs,
) -> np.ndarray:
"""Interpolates the current density ``J = [dg/dy, -dg/dx]`` within a film.
Additional keyword arguments are passed to the relevant interpolator:
:class:`scipy.interpolate.NearestNDInterpolator`,
:class:`scipy.interpolate.LinearNDInterpolator`, or
:class:`scipy.interpolate.CloughTocher2DInterpolator`.
Args:
positions: Shape ``(m, 2)`` array of x, y coordinates at which to evaluate
the current density.
film: The name of the film in which to interpolate current density.
method: Interpolation method to use ("nearest", "linear" or "cubic").
method: Interpolation method to use ("linear" or "cubic").
units: The desired units for the current density. Defaults to
``self.current_units / self.device.length_units``.
with_units: Whether to return arrays of pint.Quantities with units attached.
Expand All @@ -310,12 +303,13 @@ def interp_current_density(
if units is None:
units = default_units
positions = np.atleast_2d(positions)
interpolator = self._select_interpolator(method)
xy = device.meshes[film].sites
xv, yv = positions.T
interp_type = self._select_interpolator(method)
mesh = device.meshes[film]
J = self.film_solutions[film].current_density
Jx_interp = interpolator(xy, J[:, 0], **kwargs)
Jy_interp = interpolator(xy, J[:, 1], **kwargs)
J = np.stack([Jx_interp(positions), Jy_interp(positions)], axis=1)
Jx_interp = interp_type(mesh.triangulation, J[:, 0])
Jy_interp = interp_type(mesh.triangulation, J[:, 1])
J = np.array([Jx_interp(xv, yv).data, Jy_interp(xv, yv).data]).T
in_film = device.films[film].contains_points(positions)
J[~in_film] = 0
J[~np.isfinite(J).all(axis=1)] = 0
Expand All @@ -339,7 +333,7 @@ def current_through_path(
path_coords: An ``(n, 2)`` array of ``(x, y)`` coordinates defining
the path.
film: The name of the film in which to interpolate current density.
interp_method: Interpolation method to use ("nearest", "linear" or "cubic").
interp_method: Interpolation method to use ("linear" or "cubic").
units: The current units to return.
with_units: Whether to return a :class:`pint.Quantity` with units attached.
Expand Down Expand Up @@ -378,22 +372,16 @@ def interp_field(
method: InterpolatorType = "linear",
units: Optional[str] = None,
with_units: bool = False,
**kwargs,
):
"""Interpolates the z-component of the field within a film.
Additional keyword arguments are passed to the relevant interpolator:
:class:`scipy.interpolate.NearestNDInterpolator`,
:class:`scipy.interpolate.LinearNDInterpolator`, or
:class:`scipy.interpolate.CloughTocher2DInterpolator`.
Args:
positions: Shape ``(m, 2)`` array of x, y coordinates at which to evaluate
the fields.
film: The name of the film in which to interpolate the field.
dataset: The dataset to interpolate. One of 'field', 'self_field',
'applied_field', or 'field_from_other_films'.
method: Interpolation method to use: 'nearest', 'linear', or 'cubic'.
method: Interpolation method to use: 'linear' or 'cubic'.
units: The desired units for the current density. Defaults to
``self.field_units``.
with_units: Whether to return arrays of pint.Quantities with units attached.
Expand All @@ -403,7 +391,7 @@ def interp_field(
"""
from .solver import convert_field

interpolator = self._select_interpolator(method)
interp_type = self._select_interpolator(method)
device = self.device
if units is None:
units = self.field_units
Expand All @@ -413,7 +401,7 @@ def interp_field(
"applied_field",
"field_from_other_films",
)
points = self.device.meshes[film].sites
mesh = self.device.meshes[film]
if dataset not in valid_datasets:
raise ValueError(
f"Invalid dataset: {dataset!r}. Expected one of {valid_datasets!r}"
Expand All @@ -427,11 +415,11 @@ def interp_field(
else:
field = self.film_solutions[film].field_from_other_films
if field is None:
field = np.zeros(len(points))
field = np.zeros(len(mesh.sites))
positions = np.atleast_2d(positions)
Hz_interp = interpolator(points, field, **kwargs)
Hz_interp = interp_type(mesh.triangulation, field)
Hz = convert_field(
Hz_interp(positions),
Hz_interp(positions[:, 0], positions[:, 1]).data,
units,
old_units=self.field_units,
ureg=device.ureg,
Expand Down
4 changes: 2 additions & 2 deletions superscreen/sources/vortex.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ def pearl_vortex(
hzk = np.fft.fftshift(hzk)
hz = np.abs(np.fft.fftshift(np.fft.ifft2(hzk))) / (dx * dy)
# Interpolate to x, y, z coordinates
XY = np.stack([X.ravel(), Y.ravel()], axis=1)
XY = np.array([X.ravel(), Y.ravel()]).T
interp = LinearNDInterpolator(XY, hz.ravel())
return interp(np.stack([x, y], axis=1)).squeeze()
return interp(np.array([x, y]).T).squeeze()


def PearlVortexField(
Expand Down
4 changes: 2 additions & 2 deletions superscreen/test/test_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def test_fluxoid_simply_connected(
@pytest.mark.parametrize(
"film, method, positions",
[
("disk", "nearest", [0, 0]),
("disk", "linear", [0, 0]),
("ring", "linear", np.array([[1, 0], [0, 1]])),
("disk", "cubic", None),
],
Expand Down Expand Up @@ -340,7 +340,7 @@ def test_bz_from_vector_potential(

@pytest.mark.parametrize("units", [None, "mT", "mA/um"])
@pytest.mark.parametrize("with_units", [False, True])
@pytest.mark.parametrize("method", ["nearest", "linear", "cubic"])
@pytest.mark.parametrize("method", ["linear", "cubic"])
def test_interp_field(solution2: sc.Solution, units, with_units, method):
solution = solution2
positions = np.random.random(size=(100, 2))
Expand Down
4 changes: 2 additions & 2 deletions superscreen/test/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def test_plot_streams(solution, films, units):
]
thetas = np.linspace(0, 2 * np.pi, endpoint=False)
cross_section_coord_params.append(
[r * np.stack([np.cos(thetas), np.sin(thetas)], axis=1) for r in (0.5, 1.0, 1.5)]
[r * np.array([np.cos(thetas), np.sin(thetas)]).T for r in (0.5, 1.0, 1.5)]
)


@pytest.mark.parametrize("interp_method", ["nearest", "linear", "cubic"])
@pytest.mark.parametrize("interp_method", ["linear", "cubic"])
@pytest.mark.parametrize("cross_section_coords", cross_section_coord_params)
def test_cross_section(solution, cross_section_coords, interp_method):
if cross_section_coords is None:
Expand Down
2 changes: 1 addition & 1 deletion superscreen/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version_info__ = (0, 9, 2)
__version_info__ = (0, 10, 0)
__version__ = ".".join(map(str, __version_info__))
14 changes: 6 additions & 8 deletions superscreen/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,24 +219,22 @@ def cross_section(
cross_section_coords: A shape (m, 2) array of (x, y) coordinates specifying
the cross-section path (or a list of such arrays for multiple
cross sections).
interp_method: The interpolation method to use: "nearest", "linear", "cubic".
interp_method: The interpolation method to use: "linear" or "cubic".
Returns:
A list of coordinate arrays, a list of curvilinear coordinate (path) arrays,
and a list of cross section values.
"""
valid_methods = ("nearest", "linear", "cubic")
valid_methods = ("linear", "cubic")
if interp_method not in valid_methods:
raise ValueError(
f"Interpolation method must be one of {valid_methods} "
f"(got {interp_method})."
)
if interp_method == "nearest":
interpolator = interpolate.NearestNDInterpolator
elif interp_method == "linear":
interpolator = interpolate.LinearNDInterpolator
else: # "cubic"
interpolator = interpolate.CloughTocher2DInterpolator
interpolator = {
"linear": interpolate.LinearNDInterpolator,
"cubic": interpolate.CloughTocher2DInterpolator,
}[interp_method]

if not (isinstance(cross_section_coords, Sequence)):
cross_section_coords = [cross_section_coords]
Expand Down

0 comments on commit 7bf7f78

Please sign in to comment.