Skip to content

Commit

Permalink
Added constrain_magnitude linear plane fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
sivborg committed Sep 16, 2024
1 parent b1aa171 commit 827473b
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
23 changes: 18 additions & 5 deletions pyxem/signals/beam_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def make_linear_plane(self, **kwargs):
self.data = s_linear_plane.data
self.events.data_changed.trigger(None)

def get_linear_plane(self, mask=None, fit_corners=None):
def get_linear_plane(self, mask=None, fit_corners=None, constrain_magnitude=False):
"""Fit linear planes to the beam shifts, and returns a BeamShift signal
with the planes.
Expand Down Expand Up @@ -81,6 +81,14 @@ def get_linear_plane(self, mask=None, fit_corners=None):
fit_corners : float, optional
Make a mask so that the planes are fitted to the corners of the
signal. This mush be set with a number, like 0.05 (5%) or 0.10 (10%).
constrain_magnitude : bool, optional
Fits the linear planes so there are deflection with constant magnitude.
In the presence of electromagnetic fields in the sample area, least squares
fitting can give inaccurate results. If the region is expected to have
uniform field strength, we can fit planes by trying to minimise the variance
of the magnitudes, giving a constant deflection magnitude.
Note that for this to work several field directions must be present. Extra
care must be taken in presence of significant noise.
Examples
--------
Expand Down Expand Up @@ -116,10 +124,15 @@ def get_linear_plane(self, mask=None, fit_corners=None):
s_shift_x = self.isig[0].T
s_shift_y = self.isig[1].T
if mask is not None:
mask = mask.__array__()
plane_image_x = bst._get_linear_plane_from_signal2d(s_shift_x, mask=mask)
plane_image_y = bst._get_linear_plane_from_signal2d(s_shift_y, mask=mask)
plane_image = np.stack((plane_image_x, plane_image_y), -1)
mask = mask.__array__()
if mask.dtype != bool:
raise ValueError("mask needs to be an array of bools")
if constrain_magnitude:
plane_image = bst._get_linear_xy_planes_from_signal2d(self, mask=mask)
else:
plane_image_x = bst._get_linear_plane_from_signal2d(s_shift_x, mask=mask)
plane_image_y = bst._get_linear_plane_from_signal2d(s_shift_y, mask=mask)
plane_image = np.stack((plane_image_x, plane_image_y), -1)
s_bs = self._deepcopy_with_new_data(plane_image)
return s_bs

Expand Down
39 changes: 39 additions & 0 deletions pyxem/utils/_beam_shift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ def _f_min(X, p):
def _residuals(params, X):
return _f_min(X, params)

def _distance_residuals(params, X):
plane_x_xy = params[0:2]
plane_y_xy = params[3:5]
distance_x = (plane_x_xy * X[:, :2]).sum(axis=1) + params[2]
distance_y = (plane_y_xy * X[:, :2]).sum(axis=1) + params[5]
distance = np.sqrt(((distance_x - X[:, 2]) ** 2 + (distance_y - X[:, 3]) ** 2))

return distance - np.mean(distance)


def _plane_parameters_to_image(p, xaxis, yaxis):
"""Get a plane 2D array from plane parameters.
Expand Down Expand Up @@ -158,6 +167,36 @@ def _get_linear_plane_from_signal2d(signal, mask=None, initial_values=None):
return plane


def _get_linear_xy_planes_from_signal2d(signal, mask=None, initial_values=None):
if len(signal.axes_manager.navigation_axes) != 2:
raise ValueError("signal needs to have 1 navigation dimensions")
if len(signal.axes_manager.signal_axes) != 1:
raise ValueError("signal needs to have 2 signal dimension")
if initial_values is None:
initial_values = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]

signal = signal.T
sam = signal.axes_manager.signal_axes
xaxis, yaxis = sam[0].axis, sam[1].axis
x, y = np.meshgrid(xaxis, yaxis)
xx, yy = x.flatten(), y.flatten()
values_x = signal.data[0].flatten()
values_y = signal.data[1].flatten()
points = np.stack((xx, yy, values_x, values_y)).T
if mask is not None:
if mask.__array__().shape != signal.T.__array__().shape[:2]:
raise ValueError("signal and mask need to have the same navigation shape")
points = points[np.invert(mask).flatten()]

p = opt.leastsq(_distance_residuals, initial_values, args=points)[0]

x, y = np.meshgrid(xaxis, yaxis)
z_x = p[0] * x + p[1] * y + p[2]
z_y = p[3] * x + p[4] * y + p[5]

return np.stack((z_x, z_y), axis=-1)


def _get_limits_from_array(data, sigma=4, ignore_zeros=False, ignore_edges=False):
if ignore_edges:
x_lim = int(data.shape[0] * 0.05)
Expand Down

0 comments on commit 827473b

Please sign in to comment.