Skip to content

Commit

Permalink
Added more test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
sivborg committed Sep 17, 2024
1 parent 12af338 commit e317ccb
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 4 deletions.
5 changes: 3 additions & 2 deletions pyxem/signals/beam_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ def get_linear_plane(self, mask=None, fit_corners=None, initial_values=None, con
[d/dx, d/dy, x_0, d/dx, d/dy, y_0]
where the first three entries are for the x-shift, being in order the
step in x, step in y and the initial value at (0, 0). Similarly for the
last three entries for the y-shift.
last three entries for the y-shift. Currently only implemented for the
case when constrain_magnitude is `True`.
constrain_magnitude : bool, optional
Fits the linear planes so there are deflection with constant magnitude.
Fits the linear planes to deflections with constant magnitude.
In the presence of electromagnetic fields in the sample area, least squares
fitting can give inaccurate results. If the region is expected to have
uniform field strength, we can fit planes by trying to minimise the variance
Expand Down
89 changes: 89 additions & 0 deletions pyxem/tests/signals/test_beam_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,95 @@ def test_constrain_magnitude(self):
s_lp = s.get_linear_plane(constrain_magnitude=True)
assert np.allclose(s_lp.data, base_plane.data, rtol=1e-7)

def test_constrain_magnitude_mask(self):

p = [0.5]*6 # Plane parameters
x, y = np.meshgrid(np.arange(256), np.arange(256))
base_plane_x = p[0] * x + p[1] * y + p[2]
base_plane_y = p[3] * x + p[4] * y + p[5]

base_plane = np.stack((base_plane_x,base_plane_y)).T
data = base_plane.copy()

shifts = np.zeros_like(data)
shifts[:128,128:] =(10,10)
shifts[:128,:128] =(-10,-10)
shifts[128:,128:] =(-10,10)
shifts[128:,:128] =(10,10)
shifts[:,-10:] =(9999,1321)
shifts[:,:10] =(2213,-9879)
data += shifts

mask = np.zeros((256, 256), dtype=bool)
mask[:,-10:] =True
mask[:,:10] =True

s = BeamShift(data)

s_lp = s.get_linear_plane(constrain_magnitude=True)
assert not np.allclose(s_lp.data, base_plane.data, rtol=1e-7)

s_lp = s.get_linear_plane(constrain_magnitude=True, mask=mask)
assert np.allclose(s_lp.data, base_plane.data, rtol=1e-7)

def test_constrain_magnitude_mask(self):

p = [0.5]*6 # Plane parameters
x, y = np.meshgrid(np.arange(256), np.arange(256))
base_plane_x = p[0] * x + p[1] * y + p[2]
base_plane_y = p[3] * x + p[4] * y + p[5]

base_plane = np.stack((base_plane_x,base_plane_y)).T
data = base_plane.copy()

shifts = np.zeros_like(data)
shifts[:128,128:] =(10,10)
shifts[:128,:128] =(-10,-10)
shifts[128:,128:] =(-10,10)
shifts[128:,:128] =(10,10)
shifts[:,-10:] =(9999,1321)
shifts[:,:10] =(2213,-9879)
data += shifts

mask = np.zeros((256, 256), dtype=bool)
mask[:,-10:] =True
mask[:,:10] =True

s = BeamShift(data)

s_lp = s.get_linear_plane(constrain_magnitude=True)
assert not np.allclose(s_lp.data, base_plane.data, rtol=1e-7)

s_lp = s.get_linear_plane(constrain_magnitude=True, mask=mask)
assert np.allclose(s_lp.data, base_plane.data, rtol=1e-7)

def test_constrain_magnitude_initial_values(self):

p = [0.5]*6 # Plane parameters
x, y = np.meshgrid(np.arange(256), np.arange(256))
base_plane_x = p[0] * x + p[1] * y + p[2]
base_plane_y = p[3] * x + p[4] * y + p[5]

base_plane = np.stack((base_plane_x,base_plane_y)).T
data = base_plane.copy()

shifts = np.zeros_like(data)
shifts[:128,128:] =(10,10)
shifts[:128,:128] =(-10,-10)
shifts[128:, :] =(-10,10)
data += shifts

s = BeamShift(data)

# Plane fitting does poorly here, likely due to not enough different domains
s_lp = s.get_linear_plane(constrain_magnitude=True)
assert not np.allclose(s_lp.data, base_plane.data, rtol=1e-7)

# Varying the initial values around can help find different planes
initial_values= [1.0]*6
s_lp = s.get_linear_plane(constrain_magnitude=True, initial_values=initial_values)
assert np.allclose(s_lp.data, base_plane.data, rtol=1e-7)

def test_lazy_input_error(self):
s = LazyBeamShift(da.zeros((50, 40, 2)))
with pytest.raises(ValueError):
Expand Down
4 changes: 2 additions & 2 deletions pyxem/utils/_beam_shift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def _get_linear_plane_from_signal2d(signal, mask=None, initial_values=None):

def _get_linear_plane_by_minimizing_magnitude_variance(signal, mask=None, initial_values=None):
if len(signal.axes_manager.navigation_axes) != 2:
raise ValueError("signal needs to have 1 navigation dimensions")
raise ValueError("signal needs to have 2 navigation dimensions")
if len(signal.axes_manager.signal_axes) != 1:
raise ValueError("signal needs to have 2 signal dimension")
raise ValueError("signal needs to have 1 signal dimension")
if initial_values is None:
initial_values = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]

Expand Down

0 comments on commit e317ccb

Please sign in to comment.