diff --git a/.github/workflows/python-app.yml b/.github/workflows/build-flake.yml
similarity index 100%
rename from .github/workflows/python-app.yml
rename to .github/workflows/build-flake.yml
diff --git a/docs/source/api/emd.rst b/docs/source/api/emd.rst
index 660f9777f..8e4223496 100644
--- a/docs/source/api/emd.rst
+++ b/docs/source/api/emd.rst
@@ -12,7 +12,6 @@ Classes
 .. autoclass:: emdfile.PointList
 .. autoclass:: emdfile.PointListArray
 .. autoclass:: emdfile.Root
-.. autoclass:: emdfile.RootedNode


 Functions
@@ -26,4 +25,4 @@ Functions
 .. autofunction:: emdfile.read
 .. autofunction:: emdfile.save
 .. autofunction:: emdfile.set_author
-.. autofunction:: emdfile.tqdmnd
\ No newline at end of file
+.. autofunction:: emdfile.tqdmnd
diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py
index dc4f765ae..0d85948b1 100644
--- a/py4DSTEM/__init__.py
+++ b/py4DSTEM/__init__.py
@@ -2,7 +2,9 @@

 from emdfile import tqdmnd

-# io classes
+### io
+
+# substructure
 from emdfile import (
     Node,
     Root,
@@ -10,54 +12,83 @@
     Array,
     PointList,
     PointListArray,
-    Custom
+    Custom,
+    print_h5_tree,
 )
+_emd_hook = True

-# processing classes
-from py4DSTEM.classes import (
-    DataCube,
+# structure
+from py4DSTEM import io
+from py4DSTEM.io import import_file,read,save
+
+
+
+### basic data classes
+
+# data
+from py4DSTEM.data import (
+    Data,
+    Calibration,
     DiffractionSlice,
     RealSlice,
-    VirtualDiffraction,
-    VirtualImage,
-    Probe,
     QPoints,
-    Calibration,
-    Data,
 )
-from py4DSTEM.process.diskdetection import (
+
+# datacube
+from py4DSTEM.datacube import (
+    DataCube,
+    VirtualImage,
+    VirtualDiffraction
+)
+
+
+
+### visualization
+
+from py4DSTEM import visualize
+from py4DSTEM.visualize import show, show_complex
+
+### analysis classes
+
+# braggvectors
+from py4DSTEM.braggvectors import (
+    Probe,
     BraggVectors,
     BraggVectorMap,
 )
+
+# strain
+from py4DSTEM.process import StrainMap
+
+# TODO - crystal
+# TODO - ptycho
+# TODO - others
+
+# TODO - where
 from py4DSTEM.process import (
     PolarDatacube,
 )

-# submodules
-from py4DSTEM import io
+
+### more submodules
+# TODO
+
 from py4DSTEM import preprocess
 from py4DSTEM import process

-from py4DSTEM import classes
-from py4DSTEM import visualize

+### utilities

-# functions
-from emdfile import print_h5_tree
-from py4DSTEM.visualize import show
-from py4DSTEM.io import import_file,read,save
+# config
 from py4DSTEM.utils.configuration_checker import check_config
+# TODO - config .toml

-
-# test paths
+# testing
 from os.path import dirname,join
 _TESTPATH = join(dirname(__file__), "../test/unit_test_data")
-
-# hook for emd _get_class
-_emd_hook = True
-
diff --git a/py4DSTEM/braggvectors/__init__.py b/py4DSTEM/braggvectors/__init__.py
new file mode 100644
index 000000000..030fe6358
--- /dev/null
+++ b/py4DSTEM/braggvectors/__init__.py
@@ -0,0 +1,8 @@
+from py4DSTEM.braggvectors.probe import Probe
+from py4DSTEM.braggvectors.braggvectors import BraggVectors
+from py4DSTEM.braggvectors.braggvector_methods import BraggVectorMap
+from py4DSTEM.braggvectors.diskdetection import *
+from py4DSTEM.braggvectors.probe import *
+#from .diskdetection_aiml import *
+#from .diskdetection_parallel_new import *
+
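
A sketch of the reorganized import surface these changes produce (names as exported in the new `__init__.py` files above; the data file path is hypothetical):

>>> import py4DSTEM
>>> from py4DSTEM import DataCube, VirtualImage, show
>>> from py4DSTEM.braggvectors import Probe, BraggVectors
>>> datacube = py4DSTEM.import_file("path/to/data.dm4")  # hypothetical path
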
diff --git a/py4DSTEM/process/diskdetection/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py
similarity index 67%
rename from py4DSTEM/process/diskdetection/braggvector_methods.py
rename to py4DSTEM/braggvectors/braggvector_methods.py
index bb03852ed..be766ad49 100644
--- a/py4DSTEM/process/diskdetection/braggvector_methods.py
+++ b/py4DSTEM/braggvectors/braggvector_methods.py
@@ -2,10 +2,11 @@
 import numpy as np
 from scipy.ndimage import gaussian_filter
-from emdfile import Array,Metadata
-from emdfile import _read_metadata
-from py4DSTEM.process.calibration.origin import set_measured_origin, set_fit_origin
-from py4DSTEM.process.utils import get_CoM
+from warnings import warn
+import inspect
+
+from emdfile import Array, Metadata, tqdmnd, _read_metadata
+from py4DSTEM.datacube import VirtualImage


 class BraggVectorMethods:
@@ -23,7 +24,8 @@ def histogram(
         weights_thresh = 0.005,
     ):
         """
-        Returns a 2D histogram of Bragg vector intensities in diffraction space.
+        Returns a 2D histogram of Bragg vector intensities in diffraction space,
+        aka a Bragg vector map.

         Parameters
         ----------
@@ -177,9 +179,149 @@ def histogram(



+    # bragg virtual imaging
+
+    def get_virtual_image(
+        self,
+        mode = None,
+        geometry = None,
+        name = 'bragg_virtual_image',
+        returncalc = True,
+        center = True,
+        ellipse = True,
+        pixel = True,
+        rotate = True,
+    ):
+        '''
+        Calculates a virtual image based on the values of the Bragg vectors
+        integrated over some detector function determined by `mode` and
+        `geometry`.
+
+        Parameters
+        ----------
+        mode : str
+            defines the type of detector used. Options:
+              - 'circular', 'circle': uses round detector, like bright field
+              - 'annular', 'annulus': uses annular detector, like dark field
+        geometry : variable
+            expected value depends on the value of `mode`, as follows:
+              - 'circle', 'circular': nested 2-tuple, ((qx,qy),radius)
+              - 'annular' or 'annulus': nested 2-tuple,
+                ((qx,qy),(radius_i,radius_o))
+            Values can be in pixels or calibrated units. Note that (qx,qy)
+            can be skipped, which assumes peaks centered at (0,0).
+        center: bool
+            Apply calibration - center coordinate.
+        ellipse: bool
+            Apply calibration - elliptical correction.
+        pixel: bool
+            Apply calibration - pixel size.
+        rotate: bool
+            Apply calibration - QR rotation.
+
+        Returns
+        -------
+        virtual_im : VirtualImage
+        '''
+
+        # parse inputs
+        circle_modes = ['circular','circle']
+        annulus_modes = ['annular','annulus']
+        modes = circle_modes + annulus_modes + [None]
+        assert(mode in modes), f"Unrecognized mode {mode}"
+
+        # set geometry
+        if mode is None:
+            if geometry is None:
+                qxy_center = None
+                radial_range = np.array((0,np.inf))
+            else:
+                if len(geometry[0]) == 0:
+                    qxy_center = None
+                else:
+                    qxy_center = np.array(geometry[0])
+                if isinstance(geometry[1], int) or isinstance(geometry[1], float):
+                    radial_range = np.array((0,geometry[1]))
+                elif len(geometry[1]) == 0:
+                    radial_range = None
+                else:
+                    radial_range = np.array(geometry[1])
+        elif mode == 'circular' or mode == 'circle':
+            radial_range = np.array((0,geometry[1]))
+            if len(geometry[0]) == 0:
+                qxy_center = None
+            else:
+                qxy_center = np.array(geometry[0])
+        elif mode == 'annular' or mode == 'annulus':
+            radial_range = np.array(geometry[1])
+            if len(geometry[0]) == 0:
+                qxy_center = None
+            else:
+                qxy_center = np.array(geometry[0])
+
+        # allocate space
+        im_virtual = np.zeros(self.shape)
+
+        # generate image
+        for rx,ry in tqdmnd(
+            self.shape[0],
+            self.shape[1],
+        ):
+            # Get user-specified Bragg vectors
+            p = self.get_vectors(
+                rx,
+                ry,
+                center = center,
+                ellipse = ellipse,
+                pixel = pixel,
+                rotate = rotate,
+            )
+
+            if p.data.shape[0] > 0:
+                if radial_range is None:
+                    im_virtual[rx,ry] = np.sum(p.I)
+                else:
+                    if qxy_center is None:
+                        qr = np.hypot(p.qx,p.qy)
+                    else:
+                        qr = np.hypot(
+                            p.qx - qxy_center[0],
+                            p.qy - qxy_center[1])
+                    sub = np.logical_and(
+                        qr >= radial_range[0],
+                        qr < radial_range[1])
+                    if np.sum(sub) > 0:
+                        im_virtual[rx,ry] = np.sum(p.I[sub])
+
+        # wrap in VirtualImage class
+        ans = VirtualImage(
+            data = im_virtual,
+            name = name
+        )
+        # add generating params as metadata
+        ans.metadata = Metadata(
+            name = 'gen_params',
+            data = {
+                '_calling_method' : inspect.stack()[0][3],
+                '_calling_class' : __class__.__name__,
+                'mode' : mode,
+                'geometry' : geometry,
+                'name' : name,
+                'returncalc' : returncalc
+            }
+        )
+        # attach to the tree
+        self.attach(ans)
+
+        # return
+        if returncalc:
+            return ans
+
+
+
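A minimal usage sketch for the method above, assuming an existing calibrated BraggVectors instance `braggvectors` (the instance name and geometry values are illustrative):

>>> im = braggvectors.get_virtual_image(
...     mode = 'annulus',
...     geometry = ((0,0),(0.5,1.0)),
... )
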
+    # calibration measurements

-    @set_measured_origin
     def measure_origin(
         self,
         center_guess = None,
@@ -242,6 +384,7 @@ def measure_origin(
                 np.argmax(gaussian_filter(bvm, 10)), (Q_Nx, Q_Ny)
             )
         else:
+            from py4DSTEM.process.utils import get_CoM
             x0, y0 = get_CoM(bvm)
     else:
         x0, y0 = center_guess
@@ -269,11 +412,14 @@ def measure_origin(
             qx0[Rx, Ry] = x0
             qy0[Rx, Ry] = y0

+        # set calibration metadata
+        self.calibration.set_origin_meas((qx0,qy0))
+        self.calibration.set_origin_meas_mask(mask)
+
         # return
         return qx0, qy0, mask


-    @set_measured_origin
     def measure_origin_beamstop(
         self,
         center_guess,
@@ -353,14 +499,18 @@ def measure_origin_beamstop(
             y0_curr = np.mean(centers[found_center,1])
             center_curr = x0_curr,y0_curr

-        # return
+        # collect answers
         mask = found_center
         qx0,qy0 = centers[:,:,0],centers[:,:,1]

+        # set calibration metadata
+        self.calibration.set_origin_meas((qx0,qy0))
+        self.calibration.set_origin_meas_mask(mask)
+
+        # return
         return qx0,qy0,mask


-    @set_fit_origin
     def fit_origin(
         self,
         mask=None,
@@ -418,7 +568,7 @@ def fit_origin(
         try:
             self.calibration.set_origin([qx0_fit,qy0_fit])
         except AttributeError:
-            # should a warning be raised?
+            warn("No calibration found on this datacube - fit values are not being stored")
             pass
         if plot:
             from py4DSTEM.visualize import show_image_grid
@@ -450,12 +600,18 @@ def fit_origin(
                 W = 3,
                 cmap = cmap,
                 axsize = axsize,
+                title = [
+                    'measured origin, x', 'fit origin, x', 'residuals, x',
+                    'measured origin, y', 'fit origin, y', 'residuals, y'
+                ],
                 vmin = -1*plot_range,
                 vmax = 1*plot_range,
                 intensity_range = "absolute",
                 **kwargs,
             )

+        # update calibration metadata
+        self.calibration.set_origin((qx0_fit,qy0_fit))
         self.setcal()

         if returncalc:
@@ -472,11 +628,11 @@ def fit_p_ellipse(
     ):
         """
         Args:
-            bvm (BraggVectorMap): a 2D array used for ellipse fitting
+            bvm (BraggVectorMap): a 2D array used for ellipse fitting
             center (2-tuple of floats): the center (x0,y0) of the annular fitting region
             fitradii (2-tuple of floats): inner and outer radii (ri,ro) of the fit region
             mask (ar-shaped ndarray of bools): ignore data wherever mask==True
-
+
         Returns:
             p_ellipse if returncal is True
         """
@@ -505,318 +661,119 @@ def fit_p_ellipse(
         if returncalc:
             return p_ellipse

-
-    # Deprecated??
-
-    # Lattice vectors
-    def choose_lattice_vectors(
+    def mask_in_Q(
         self,
-        index_g0,
-        index_g1,
-        index_g2,
-        mode = 'centered',
-        plot = True,
-        subpixel = 'multicorr',
-        upsample_factor = 16,
-        sigma=0,
-        minAbsoluteIntensity=0,
-        minRelativeIntensity=0,
-        relativeToPeak=0,
-        minSpacing=0,
-        edgeBoundary=1,
-        maxNumPeaks=10,
-        bvm_vis_params = {},
-        returncalc = False,
+        mask,
+        update_inplace = False,
+        returncalc = True
     ):
         """
-        Choose which lattice vectors to use for strain mapping.
-
-        Args:
-            index_g0 (int): origin
-            index_g1 (int): second point of vector 1
-            index_g2 (int): second point of vector 2
-            mode (str): centered or raw bragg map
-            plot (bool): plot bragg vector maps and vectors
-            subpixel (str): specifies the subpixel resolution algorithm to use.
-                must be in ('pixel','poly','multicorr'), which correspond
-                to pixel resolution, subpixel resolution by fitting a
-                parabola, and subpixel resultion by Fourier upsampling.
-            upsample_factor: the upsampling factor for the 'multicorr'
-                algorithm
-            sigma: if >0, applies a gaussian filter
-            maxNumPeaks: the maximum number of maxima to return
-            minAbsoluteIntensity, minRelativeIntensity, relativeToPeak,
-                minSpacing, edgeBoundary, maxNumPeaks: filtering applied
-                after maximum detection and before subpixel refinement
-        """
-        from py4DSTEM.process.utils import get_maxima_2D
-
-        if mode == "centered":
-            bvm = self.bvm_centered
-        else:
-            bvm = self.bvm_raw
-
-        g = get_maxima_2D(
-            bvm,
-            subpixel = subpixel,
-            upsample_factor = upsample_factor,
-            sigma = sigma,
-            minAbsoluteIntensity = minAbsoluteIntensity,
-            minRelativeIntensity = minRelativeIntensity,
-            relativeToPeak = relativeToPeak,
-            minSpacing = minSpacing,
-            edgeBoundary = edgeBoundary,
-            maxNumPeaks = maxNumPeaks,
-        )
-
-        self.g = g
-
-        from py4DSTEM.visualize import select_lattice_vectors
-        g1,g2 = select_lattice_vectors(
-            bvm,
-            gx = g['x'],
-            gy = g['y'],
-            i0 = index_g0,
-            i1 = index_g1,
-            i2 = index_g2,
-            **bvm_vis_params,
-        )
-
-        self.g1 = g1
-        self.g2 = g2
+        Remove peaks which fall inside the diffraction-space-shaped boolean
+        array `mask`, in raw (uncalibrated) peak positions.
- if returncalc: - return g1, g2 - - def index_bragg_directions( - self, - x0 = None, - y0 = None, - plot = True, - bvm_vis_params = {}, - returncalc = False, - ): - """ - From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and an pair of - lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the - reciprocal lattice directions. + Parameters + ---------- + mask : 2d boolean array + The mask. Must be diffraction space shaped + update_inplace : bool + If False (default) copies this BraggVectors instance and + removes peaks from the copied instance. If True, removes + peaks from this instance. + returncalc : bool + Toggles returning the answer - Args: - x0 (float): x-coord of origin - y0 (float): y-coord of origin - Plot (bool): plot results + Returns + ------- + bvects : BraggVectors """ + # Copy peaks, if requested + if update_inplace: + v = self._v_uncal + else: + v = self._v_uncal.copy( name='_v_uncal' ) - if x0 is None: - x0 = self.Qshape[0]/2 - if y0 is None: - y0 = self.Qshape[0]/2 - - from py4DSTEM.process.latticevectors import index_bragg_directions - _, _, braggdirections = index_bragg_directions( - x0, - y0, - self.g['x'], - self.g['y'], - self.g1, - self.g2 - ) - - self.braggdirections = braggdirections + # Loop and remove masked peaks + for rx in range(v.shape[0]): + for ry in range(v.shape[1]): + p = v[rx,ry] + xs = np.round(p.data["qx"]).astype(int) + ys = np.round(p.data["qy"]).astype(int) + sub = mask[xs,ys] + p.remove(sub) - if plot: - from py4DSTEM.visualize import show_bragg_indexing - show_bragg_indexing( - self.bvm_centered, - **bvm_vis_params, - braggdirections = braggdirections, - points = True - ) + # assign the return value + if update_inplace: + ans = self + else: + ans = self.copy( name=self.name+'_masked' ) + ans.set_raw_vectors( v ) + # return if returncalc: - return braggdirections - - + return ans + else: + return - def add_indices_to_braggpeaks( + # alias + def get_masked_peaks( self, - maxPeakSpacing, - mask = None, - returncalc = False, - ): - """ - Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, - identify the indices for each peak in the PointListArray braggpeaks. - Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus - three additional data columns -- 'h','k', and 'index_mask' -- specifying the peak - indices with the ints (h,k) and indicating whether the peak was successfully indexed - or not with the bool index_mask. If `mask` is specified, only the locations where - mask is True are indexed. - - Args: - maxPeakSpacing (float): Maximum distance from the ideal lattice points - to include a peak for indexing - qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList - relative to the `braggpeaks` PointListArray - mask (bool): Boolean mask, same shape as the pointlistarray, indicating which - locations should be indexed. This can be used to index different regions of - the scan with different lattices - """ - from py4DSTEM.process.latticevectors import add_indices_to_braggpeaks - - bragg_peaks_indexed = add_indices_to_braggpeaks( - self.vectors, - self.braggdirections, - maxPeakSpacing = maxPeakSpacing, - qx_shift = self.Qshape[0]/2, - qy_shift = self.Qshape[1]/2, - ) - - self.bragg_peaks_indexed = bragg_peaks_indexed - - if returncalc: - return bragg_peaks_indexed - - - def fit_lattice_vectors_all_DPs(self, returncalc = False): - """ - Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some - known (h,k) indexing. 
- - - """ - - from py4DSTEM.process.latticevectors import fit_lattice_vectors_all_DPs - g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_peaks_indexed) - self.g1g2_map = g1g2_map - if returncalc: - return g1g2_map - - def get_strain_from_reference_region(self, mask, returncalc = False): + mask, + update_inplace = False, + returncalc = True): """ - Gets a strain map from the reference region of real space specified by mask and the - lattice vector map g1g2_map. - - Args: - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions - wherever mask==True - + Alias for `mask_in_Q`. """ - from py4DSTEM.process.latticevectors import get_strain_from_reference_region - - strainmap_median_g1g2 = get_strain_from_reference_region( - self.g1g2_map, + warn("`.get_masked_peaks` is deprecated and will be removed in a future version. Use `.mask_in_Q`") + return self.mask_in_Q( mask = mask, + update_inplace = update_inplace, + returncalc = returncalc ) - self.strainmap_median_g1g2 = strainmap_median_g1g2 - - if returncalc: - return strainmap_median_g1g2 - - - def get_strain_from_reference_g1g2(self, mask, returncalc = False): - """ - Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map - g1g2_map. - - - Args: - mask (ndarray of bools): use lattice vectors from g1g2_map scan positions - wherever mask==True - - """ - from py4DSTEM.process.latticevectors import get_reference_g1g2 - g1_ref,g2_ref = get_reference_g1g2(self.g1g2_map, mask) - - from py4DSTEM.process.latticevectors import get_strain_from_reference_g1g2 - strainmap_reference_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref) - - self.strainmap_reference_g1g2 = strainmap_reference_g1g2 - - if returncalc: - return strainmap_reference_g1g2 - - def get_rotated_strain_map(self, mode, g_reference = None, returncalc = True, flip_theta = False): - """ - Starting from a strain map defined with respect to the xy coordinate system of - diffraction space, i.e. where exx and eyy are the compression/tension along the Qx - and Qy directions, respectively, get a strain map defined with respect to some other - right-handed coordinate system, in which the x-axis is oriented along (xaxis_x, - xaxis_y). - - Args: - g_referencce (tupe): reference coordinate system for xaxis_x and xaxis_y - """ - - assert mode in ("median","reference") - if g_reference is None: - g_reference = np.subtract(self.g1, self.g2) - - from py4DSTEM.process.latticevectors import get_rotated_strain_map - - if mode == "median": - strainmap_raw = self.strainmap_median_g1g2 - elif mode == "reference": - strainmap_raw = self.strainmap_reference_g1g2 - - strainmap = get_rotated_strain_map( - strainmap_raw, - xaxis_x = g_reference[0], - xaxis_y = g_reference[1], - flip_theta = flip_theta - ) - - if returncalc: - return strainmap - - - - def get_masked_peaks( + def mask_in_R( self, mask, update_inplace = False, - returncalc = True): + returncalc = True + ): """ - Removes all bragg peaks which fall inside `mask` in the raw - (uncalibrated) positions. + Remove peaks which fall inside the real space shaped boolean array + `mask`. - Args: - mask (bool): binary image where peaks will be deleted - update_inplace (bool): if True, removes peaks from this - BraggVectors instance. If False, returns a new - BraggVectors instance with the requested peaks removed - returncalc (bool): if True, return the BraggVectors + Parameters + ---------- + mask : 2d boolean array + The mask. 
Must be real space shaped + update_inplace : bool + If False (default) copies this BraggVectors instance and + removes peaks from the copied instance. If True, removes + peaks from this instance. + returncalc : bool + Toggles returning the answer - Returns: - (BraggVectors or None) + Returns + ------- + bvects : BraggVectors """ + # Copy peaks, if requested + if update_inplace: + v = self._v_uncal + else: + v = self._v_uncal.copy( name='_v_uncal' ) - # Copy peaks - v = self._v_uncal.copy( name='_v_uncal' ) - - # Loop over all peaks + # Loop and remove masked peaks for rx in range(v.shape[0]): for ry in range(v.shape[1]): - p = v.get_pointlist(rx,ry) - sub = mask.ravel()[np.ravel_multi_index(( - np.round(p.data["qx"]).astype('int'), - np.round(p.data["qy"]).astype('int')), - self.Qshape)] - p.remove(sub) + if mask[rx,ry]: + p = v[rx,ry] + p.remove(np.ones(len(p),dtype=bool)) - # if modifying this BraggVectors instance was requested + # assign the return value if update_inplace: - self._v_uncal = v ans = self - - # if a new instance was requested else: ans = self.copy( name=self.name+'_masked' ) - ans._v_uncal = v - - # re-calibrate - ans.calibrate() + ans.set_raw_vectors( v ) # return if returncalc: @@ -827,6 +784,7 @@ def get_masked_peaks( + ######### END BraggVectorMethods CLASS ######## diff --git a/py4DSTEM/process/diskdetection/braggvectors.py b/py4DSTEM/braggvectors/braggvectors.py similarity index 78% rename from py4DSTEM/process/diskdetection/braggvectors.py rename to py4DSTEM/braggvectors/braggvectors.py index 3a22bd27b..f1ff406d0 100644 --- a/py4DSTEM/process/diskdetection/braggvectors.py +++ b/py4DSTEM/braggvectors/braggvectors.py @@ -1,8 +1,8 @@ # Defines the BraggVectors class -from py4DSTEM.classes import Data +from py4DSTEM.data import Data from emdfile import Custom,PointListArray,PointList,Metadata -from py4DSTEM.process.diskdetection.braggvector_methods import BraggVectorMethods +from py4DSTEM.braggvectors.braggvector_methods import BraggVectorMethods from os.path import basename import numpy as np from warnings import warn @@ -45,6 +45,19 @@ class BraggVectors(Custom,BraggVectorMethods,Data): >>> vects.qx,vects.qy,vects.I >>> vects['qx'],vects['qy'],vects['intensity'] + Alternatively, you can access the centered vectors in pixel units with + + >>> vects.get_vectors( + >>> scan_x, + >>> scan_y, + >>> center = bool, + >>> ellipse = bool, + >>> pixel = bool, + >>> rotate = bool + >>> ) + + which will return the vectors at scan position (scan_x,scan_y) with the + requested calibrations applied. """ def __init__( @@ -52,12 +65,13 @@ def __init__( Rshape, Qshape, name = 'braggvectors', - verbose = True, + verbose = False, + calibration = None ): Custom.__init__(self,name=name) + Data.__init__(self,calibration=calibration) self.Rshape = Rshape - self.shape = self.Rshape self.Qshape = Qshape self.verbose = verbose @@ -78,9 +92,28 @@ def __init__( "pixel" : False, "rotate" : False, } + + # register with calibrations + self.calibration.register_target(self) + + # setup vector getters + self._set_raw_vector_getter() + self._set_cal_vector_getter() + + + # set new raw vectors + def set_raw_vectors(self,x): + """ Given some PointListArray x of the correct shape, sets this to the raw vectors + """ + assert(isinstance(x,PointListArray)), f"Raw vectors must be set to a PointListArray, not type {type(x)}" + assert(x.shape == self.Rshape), "Shapes don't match!" 
+ self._v_uncal = x self._set_raw_vector_getter() self._set_cal_vector_getter() + + # calibration state, vector getters + @property def calstate(self): return self._calstate @@ -94,6 +127,10 @@ def _set_cal_vector_getter(self): ) + # shape + @property + def shape(self): + return self.Rshape # raw vectors @@ -111,7 +148,6 @@ def raw(self): return self._raw_vector_getter - # calibrated vectors @property @@ -190,7 +226,7 @@ def setcal( if pixel: assert(c.get_Q_pixel_size() is not None), "Requested calibration not found" if rotate: - assert(c.get_RQ_rotflip() is not None), "Requested calibration not found" + assert(c.get_QR_rotflip() is not None), "Requested calibration not found" # set the calibrations self._calstate = { @@ -200,17 +236,70 @@ def setcal( "rotate" : rotate, } if self.verbose: - print('current calstate: ', self.calstate) + print('current calibration state: ', self.calstate) pass + def calibrate(self): + """ + Autoupdate the calstate when relevant calibrations are set + """ + self.setcal() + - # copy + # vector getter method + + def get_vectors( + self, + scan_x, + scan_y, + center, + ellipse, + pixel, + rotate + ): + """ + Returns the bragg vectors at the specified scan position with + the specified calibration state. + + Parameters + ---------- + scan_x : int + scan_y : int + center : bool + ellipse : bool + pixel : bool + rotate : bool + + Returns + ------- + vectors : BVects + """ + ans = self._v_uncal[scan_x,scan_y].data + ans = self.cal._transform( + data = ans, + cal = self.calibration, + scanxy = (scan_x,scan_y), + center = center, + ellipse = ellipse, + pixel = pixel, + rotate = rotate, + ) + return BVects(ans) + + + # copy def copy(self, name=None): name = name if name is not None else self.name+"_copy" - braggvector_copy = BraggVectors(self.Rshape, self.Qshape, name=name) - braggvector_copy._v_uncal = self._v_uncal.copy() + braggvector_copy = BraggVectors( + self.Rshape, + self.Qshape, + name=name, + calibration = self.calibration.copy() + ) + + braggvector_copy.set_raw_vectors( self._v_uncal.copy() ) for k in self.metadata.keys(): braggvector_copy.metadata = self.metadata[k].copy() return braggvector_copy @@ -389,13 +478,16 @@ def _transform( if center: origin = cal.get_origin(x,y) + assert(origin is not None), "Requested calibration was not found!" ans['qx'] -= origin[0] ans['qy'] -= origin[1] # ellipse if ellipse: - a,b,theta = cal.get_ellipse(x,y) + ell = cal.get_ellipse(x,y) + assert(ell is not None), "Requested calibration was not found!" + a,b,theta = ell # Get the transformation matrix e = b/a sint, cost = np.sin(theta-np.pi/2.), np.cos(theta-np.pi/2.) @@ -415,6 +507,7 @@ def _transform( # pixel size if pixel: qpix = cal.get_Q_pixel_size() + assert(qpix is not None), "Requested calibration was not found!" ans['qx'] *= qpix ans['qy'] *= qpix @@ -423,6 +516,8 @@ def _transform( if rotate: flip = cal.get_QR_flip() theta = cal.get_QR_rotation_degrees() + assert(flip is not None), "Requested calibration was not found!" + assert(theta is not None), "Requested calibration was not found!" 
# rotation matrix R = np.array([ [np.cos(theta), -np.sin(theta)], @@ -438,6 +533,4 @@ def _transform( # return - return ans - - + return ans \ No newline at end of file diff --git a/py4DSTEM/process/diskdetection/diskdetection.py b/py4DSTEM/braggvectors/diskdetection.py similarity index 97% rename from py4DSTEM/process/diskdetection/diskdetection.py rename to py4DSTEM/braggvectors/diskdetection.py index 1674fdc93..fb7755349 100644 --- a/py4DSTEM/process/diskdetection/diskdetection.py +++ b/py4DSTEM/braggvectors/diskdetection.py @@ -5,11 +5,12 @@ from scipy.ndimage import gaussian_filter from emdfile import tqdmnd -from py4DSTEM.process.diskdetection.braggvectors import BraggVectors -from py4DSTEM.classes import DataCube, QPoints +from py4DSTEM.braggvectors.braggvectors import BraggVectors +from py4DSTEM.data import QPoints +from py4DSTEM.datacube import DataCube from py4DSTEM.preprocess.utils import get_maxima_2D from py4DSTEM.process.utils.cross_correlate import get_cross_correlation_FT -from py4DSTEM.process.diskdetection.diskdetection_aiml import find_Bragg_disks_aiml +from py4DSTEM.braggvectors.diskdetection_aiml import find_Bragg_disks_aiml @@ -572,7 +573,7 @@ def _find_Bragg_disks_CUDA_unbatched( ): # compute - from py4DSTEM.process.diskdetection.diskdetection_cuda import find_Bragg_disks_CUDA + from py4DSTEM.braggvectors.diskdetection_cuda import find_Bragg_disks_CUDA peaks = find_Bragg_disks_CUDA( datacube, probe, @@ -618,7 +619,7 @@ def _find_Bragg_disks_CUDA_batched( ): # compute - from py4DSTEM.process.diskdetection.diskdetection_cuda import find_Bragg_disks_CUDA + from py4DSTEM.braggvectors.diskdetection_cuda import find_Bragg_disks_CUDA peaks = find_Bragg_disks_CUDA( datacube, probe, @@ -668,7 +669,7 @@ def _find_Bragg_disks_ipp( ): # compute - from py4DSTEM.process.diskdetection.diskdetection_parallel import find_Bragg_disks_ipp + from py4DSTEM.braggvectors.diskdetection_parallel import find_Bragg_disks_ipp peaks = find_Bragg_disks_ipp( datacube, probe, @@ -720,7 +721,7 @@ def _find_Bragg_disks_dask( ): # compute - from py4DSTEM.process.diskdetection.diskdetection_parallel import find_Bragg_disks_dask + from py4DSTEM.braggvectors.diskdetection_parallel import find_Bragg_disks_dask peaks = find_Bragg_disks_dask( datacube, probe, diff --git a/py4DSTEM/process/diskdetection/diskdetection_aiml.py b/py4DSTEM/braggvectors/diskdetection_aiml.py similarity index 99% rename from py4DSTEM/process/diskdetection/diskdetection_aiml.py rename to py4DSTEM/braggvectors/diskdetection_aiml.py index e076e0732..6d9bea623 100644 --- a/py4DSTEM/process/diskdetection/diskdetection_aiml.py +++ b/py4DSTEM/braggvectors/diskdetection_aiml.py @@ -14,10 +14,10 @@ from numbers import Number from emdfile import tqdmnd, PointList, PointListArray -from py4DSTEM.process.diskdetection.braggvectors import BraggVectors -from py4DSTEM.classes import QPoints +from py4DSTEM.braggvectors.braggvectors import BraggVectors +from py4DSTEM.data import QPoints from py4DSTEM.process.utils import get_maxima_2D -# from py4DSTEM.process.diskdetection import universal_threshold +# from py4DSTEM.braggvectors import universal_threshold def find_Bragg_disks_aiml_single_DP(DP, probe, num_attempts = 5, @@ -430,7 +430,7 @@ def find_Bragg_disks_aiml_serial(datacube, probe, int(t2/60), int(t2%60))) if global_threshold == True: - from py4DSTEM.process.diskdetection import universal_threshold + from py4DSTEM.braggvectors import universal_threshold peaks = universal_threshold(peaks, minGlobalIntensity, metric, minPeakSpacing, 
maxNumPeaks) @@ -624,7 +624,7 @@ def _parse_distributed(distributed): name=name, filter_function=filter_function) elif _check_cuda_device_available(): - from py4DSTEM.process.diskdetection.diskdetection_aiml_cuda import find_Bragg_disks_aiml_CUDA + from py4DSTEM.braggvectors.diskdetection_aiml_cuda import find_Bragg_disks_aiml_CUDA return find_Bragg_disks_aiml_CUDA(datacube, probe, num_attempts = num_attempts, diff --git a/py4DSTEM/process/diskdetection/diskdetection_aiml_cuda.py b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py similarity index 98% rename from py4DSTEM/process/diskdetection/diskdetection_aiml_cuda.py rename to py4DSTEM/braggvectors/diskdetection_aiml_cuda.py index 1b6c38565..d714feda9 100644 --- a/py4DSTEM/process/diskdetection/diskdetection_aiml_cuda.py +++ b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py @@ -7,11 +7,12 @@ from time import time from emdfile import tqdmnd -from py4DSTEM.process.diskdetection.braggvectors import BraggVectors -from py4DSTEM.classes import PointList, PointListArray, QPoints -from py4DSTEM.process.diskdetection.kernels import kernels -from py4DSTEM.process.diskdetection.diskdetection_aiml import _get_latest_model -# from py4DSTEM.process.diskdetection.diskdetection import universal_threshold +from py4DSTEM.braggvectors.braggvectors import BraggVectors +from emdfile import PointList, PointListArray +from py4DSTEM.data import QPoints +from py4DSTEM.braggvectors.kernels import kernels +from py4DSTEM.braggvectors.diskdetection_aiml import _get_latest_model +# from py4DSTEM.braggvectors.diskdetection import universal_threshold try: import cupy as cp @@ -194,7 +195,7 @@ def find_Bragg_disks_aiml_CUDA(datacube, probe, print("Analyzed {} diffraction patterns in {}h {}m {}s".format(datacube.R_N, int(t2/3600), int(t2/60), int(t2%60))) if global_threshold == True: - from py4DSTEM.process.diskdetection import universal_threshold + from py4DSTEM.braggvectors import universal_threshold peaks = universal_threshold(peaks, minGlobalIntensity, metric, minPeakSpacing, maxNumPeaks) peaks.name = name diff --git a/py4DSTEM/process/diskdetection/diskdetection_cuda.py b/py4DSTEM/braggvectors/diskdetection_cuda.py similarity index 99% rename from py4DSTEM/process/diskdetection/diskdetection_cuda.py rename to py4DSTEM/braggvectors/diskdetection_cuda.py index 2cd96be51..b912bde48 100644 --- a/py4DSTEM/process/diskdetection/diskdetection_cuda.py +++ b/py4DSTEM/braggvectors/diskdetection_cuda.py @@ -12,7 +12,7 @@ from emdfile import tqdmnd from py4DSTEM import PointList, PointListArray -from py4DSTEM.process.diskdetection.kernels import kernels +from py4DSTEM.braggvectors.kernels import kernels def find_Bragg_disks_CUDA( @@ -198,9 +198,9 @@ def find_Bragg_disks_CUDA( threads=threads, ) - # clean up - del batched_subcube, batched_crosscorr, subFFT, cc, ccc - cp.get_default_memory_pool().free_all_blocks() + # clean up + del batched_subcube, batched_crosscorr, subFFT, cc, ccc + cp.get_default_memory_pool().free_all_blocks() else: diff --git a/py4DSTEM/process/diskdetection/diskdetection_parallel.py b/py4DSTEM/braggvectors/diskdetection_parallel.py similarity index 99% rename from py4DSTEM/process/diskdetection/diskdetection_parallel.py rename to py4DSTEM/braggvectors/diskdetection_parallel.py index f6d2624e1..5a6d6dc11 100644 --- a/py4DSTEM/process/diskdetection/diskdetection_parallel.py +++ b/py4DSTEM/braggvectors/diskdetection_parallel.py @@ -9,7 +9,7 @@ # local import py4DSTEM -from py4DSTEM.classes import PointListArray +from emdfile import PointListArray def 
_find_Bragg_disks_single_DP_FK(DP, probe_kernel_FT, diff --git a/py4DSTEM/process/diskdetection/diskdetection_parallel_new.py b/py4DSTEM/braggvectors/diskdetection_parallel_new.py similarity index 98% rename from py4DSTEM/process/diskdetection/diskdetection_parallel_new.py rename to py4DSTEM/braggvectors/diskdetection_parallel_new.py index c9e224eee..241cae2f7 100644 --- a/py4DSTEM/process/diskdetection/diskdetection_parallel_new.py +++ b/py4DSTEM/braggvectors/diskdetection_parallel_new.py @@ -17,8 +17,8 @@ import distributed import py4DSTEM -from py4DSTEM.classes import PointListArray, PointList -from py4DSTEM.process.diskdetection.diskdetection import _find_Bragg_disks_single_DP_FK +from emdfile import PointListArray, PointList +from py4DSTEM.braggvectors.diskdetection import _find_Bragg_disks_single_DP_FK diff --git a/py4DSTEM/process/diskdetection/kernels.py b/py4DSTEM/braggvectors/kernels.py similarity index 100% rename from py4DSTEM/process/diskdetection/kernels.py rename to py4DSTEM/braggvectors/kernels.py diff --git a/py4DSTEM/process/diskdetection/multicorr_col_kernel.cu b/py4DSTEM/braggvectors/multicorr_col_kernel.cu similarity index 100% rename from py4DSTEM/process/diskdetection/multicorr_col_kernel.cu rename to py4DSTEM/braggvectors/multicorr_col_kernel.cu diff --git a/py4DSTEM/process/diskdetection/multicorr_row_kernel.cu b/py4DSTEM/braggvectors/multicorr_row_kernel.cu similarity index 100% rename from py4DSTEM/process/diskdetection/multicorr_row_kernel.cu rename to py4DSTEM/braggvectors/multicorr_row_kernel.cu diff --git a/py4DSTEM/braggvectors/probe.py b/py4DSTEM/braggvectors/probe.py new file mode 100644 index 000000000..9195d06bd --- /dev/null +++ b/py4DSTEM/braggvectors/probe.py @@ -0,0 +1,617 @@ +# Defines the Probe class + +import numpy as np +from typing import Optional +from warnings import warn + +from py4DSTEM.data import DiffractionSlice, Data +from scipy.ndimage import ( + binary_opening, binary_dilation, distance_transform_edt) + + + +class Probe(DiffractionSlice,Data): + """ + Stores a vacuum probe. + + Both a vacuum probe and a kernel for cross-correlative template matching + derived from that probe are stored and can be accessed at + + >>> p.probe + >>> p.kernel + + respectively, for some Probe instance `p`. If a kernel has not been computed + the latter expression returns None. 
+ + + """ + + def __init__( + self, + data: np.ndarray, + name: Optional[str] = 'probe' + ): + """ + Accepts: + data (2D or 3D np.ndarray): the vacuum probe, or + the vacuum probe + kernel + name (str): a name + + Returns: + (Probe) + """ + # if only the probe is passed, make space for the kernel + if data.ndim == 2: + data = np.stack([ + data, + np.zeros_like(data) + ]) + + # initialize as a DiffractionSlice + DiffractionSlice.__init__( + self, + name = name, + data = data, + slicelabels = [ + 'probe', + 'kernel' + ] + ) + + + ## properties + + @property + def probe(self): + return self.get_slice('probe').data + @probe.setter + def probe(self,x): + assert(x.shape == (self.data.shape[1:])) + self.data[0,:,:] = x + @property + def kernel(self): + return self.get_slice('kernel').data + @kernel.setter + def kernel(self,x): + assert(x.shape == (self.data.shape[1:])) + self.data[1,:,:] = x + + + # read + @classmethod + def _get_constructor_args(cls,group): + """ + Returns a dictionary of args/values to pass to the class constructor + """ + ar_constr_args = DiffractionSlice._get_constructor_args(group) + args = { + 'data' : ar_constr_args['data'], + 'name' : ar_constr_args['name'], + } + return args + + + + # generation methods + + @classmethod + def from_vacuum_data( + cls, + data, + mask = None, + threshold = 0.2, + expansion = 12, + opening = 3 + ): + """ + Generates and returns a vacuum probe Probe instance from either a + 2D vacuum image or a 3D stack of vacuum diffraction patterns. + + The probe is multiplied by `mask`, if it's passed. An additional + masking step zeros values outside of a mask determined by `threshold`, + `expansion`, and `opening`, generated by first computing the binary image + probe < max(probe)*threshold, then applying a binary expansion and + then opening to this image. No alignment is performed - i.e. it is assumed + that the beam was stationary during acquisition of the stack. To align + the images, use the DataCube .get_vacuum_probe method. + + Parameters + ---------- + data : 2D or 3D array + the vacuum diffraction data. For 3D stacks, use shape (N,Q_Nx,Q_Ny) + mask : boolean array, optional + mask applied to the probe + threshold : float + threshold determining mask which zeros values outside of probe + expansion : int + number of pixels by which the zeroing mask is expanded to capture + the full probe + opening : int + size of binary opening used to eliminate stray bright pixels + + Returns + ------- + probe : Probe + the vacuum probe + """ + assert(isinstance(data,np.ndarray)) + if data.ndim == 3: + probe = np.average(data,axis=0) + elif data.ndim == 2: + probe = data + else: + raise Exception(f"data must be 2- or 3-D, not {data.ndim}-D") + + if mask is not None: + probe *= mask + + mask = probe > np.max(probe)*threshold + mask = binary_opening(mask, iterations=opening) + mask = binary_dilation(mask, iterations=1) + mask = np.cos((np.pi/2)*np.minimum( + distance_transform_edt(np.logical_not(mask)) / expansion, 1))**2 + + probe = cls(probe*mask) + return probe + + + @classmethod + def generate_synthetic_probe( + cls, + radius, + width, + Qshape + ): + """ + Makes a synthetic probe, with the functional form of a disk blurred by a + sigmoid (a logistic function). + + Parameters + ---------- + radius : float + the probe radius + width : float + the blurring of the probe edge. 
width represents the + full width of the blur, with x=-w/2 to x=+w/2 about the edge + spanning values of ~0.12 to 0.88 + Qshape : 2 tuple + the diffraction plane dimensions + + Returns + ------- + probe : Probe + the probe + """ + # Make coords + Q_Nx,Q_Ny = Qshape + qy,qx = np.meshgrid(np.arange(Q_Ny),np.arange(Q_Nx)) + qy,qx = qy - Q_Ny/2., qx-Q_Nx/2. + qr = np.sqrt(qx**2+qy**2) + + # Shift zero to disk edge + qr = qr - radius + + # Calculate logistic function + probe = 1/(1+np.exp(4*qr/width)) + + return cls(probe) + + + + # calibration methods + + def measure_disk( + self, + thresh_lower=0.01, + thresh_upper=0.99, + N=100, + returncalc=True, + data=None, + ): + """ + Finds the center and radius of an average probe image. + + A naive algorithm. Creates a series of N binary masks by thresholding + the probe image a linspace of N thresholds from thresh_lower to + thresh_upper, relative to the image max/min. For each mask, we find the + square root of the number of True valued pixels divided by pi to + estimate a radius. Because the central disk is intense relative to the + remainder of the image, the computed radii are expected to vary very + little over a wider range threshold values. A range of r values + considered trustworthy is estimated by taking the derivative + r(thresh)/dthresh identifying where it is small, and the mean of this + range is returned as the radius. A center is estimated using a binary + thresholded image in combination with the center of mass operator. + + Parameters + ---------- + thresh_lower : float, 0 to 1 + the lower limit of threshold values + thresh_upper : float, 0 to 1) + the upper limit of threshold values + N : int + the number of thresholds / masks to use + returncalc : True + toggles returning the answer + data : 2d array, optional + if passed, uses this 2D array in place of the probe image when + performing the computation. This also supresses storing the + results in the Probe's calibration metadata + + Returns + ------- + r, x0, y0 : (3-tuple) + the radius and origin + """ + from py4DSTEM.process.utils import get_CoM + + # set the image + im = self.probe if data is None else data + + # define the thresholds + thresh_vals = np.linspace(thresh_lower, thresh_upper, N) + r_vals = np.zeros(N) + + # get binary images and compute a radius for each + immax = np.max(im) + for i,val in enumerate(thresh_vals): + mask = im > immax * val + r_vals[i] = np.sqrt(np.sum(mask) / np.pi) + + # Get derivative and determine trustworthy r-values + dr_dtheta = np.gradient(r_vals) + mask = (dr_dtheta <= 0) * (dr_dtheta >= 2 * np.median(dr_dtheta)) + r = np.mean(r_vals[mask]) + + # Get origin + thresh = np.mean(thresh_vals[mask]) + mask = im > immax * thresh + x0, y0 = get_CoM(im * mask) + + # Store metadata and return + ans = r,x0,y0 + if data is None: + try: + self.calibration.set_probe_param(ans) + except AttributeError: + warn(f"Couldn't store the probe parameters in metadata as no calibration was found for this Probe instance, {self}") + pass + if returncalc: + return ans + + + + + + + + + # Kernel generation methods + + def get_kernel( + self, + mode = 'flat', + origin = None, + data = None, + returncalc = True, + **kwargs + ): + """ + Creates a cross-correlation kernel from the vacuum probe. + + Specific behavior and valid keyword arguments depend on the `mode` + specified. In each case, the center of the probe is shifted to the + origin and the kernel normalized such that it sums to 1. This is the + only processing performed if mode is 'flat'. 
+    # Kernel generation methods
+
+    def get_kernel(
+        self,
+        mode = 'flat',
+        origin = None,
+        data = None,
+        returncalc = True,
+        **kwargs
+    ):
+        """
+        Creates a cross-correlation kernel from the vacuum probe.
+
+        Specific behavior and valid keyword arguments depend on the `mode`
+        specified. In each case, the center of the probe is shifted to the
+        origin and the kernel normalized such that it sums to 1. This is the
+        only processing performed if mode is 'flat'. Otherwise, a centrosymmetric
+        region of negative intensity is added around the probe intended to promote
+        edge-filtering-like behavior during cross correlation, with the
+        functional form of the subtracted region defined by `mode` and the
+        relevant **kwargs. For normalization, flat probes integrate to 1, and the
+        remaining probes integrate to 1 before subtraction and 0 after. Required
+        keyword arguments are:
+
+          - 'flat': No required arguments. This mode is recommended for bullseye
+            or other structured probes
+          - 'gaussian': Required arg `sigma` (number), the width (standard
+            deviation) of a centered gaussian to be subtracted.
+          - 'sigmoid': Required arg `radii` (2-tuple), the inner and outer radii
+            (ri,ro) of an annular region with a sine-squared sigmoidal radial
+            profile to be subtracted.
+          - 'sigmoid_log': Required arg `radii` (2-tuple), the inner and outer radii
+            (ri,ro) of an annular region with a logistic sigmoidal radial
+            profile to be subtracted.
+
+        Parameters
+        ----------
+        mode : str
+            must be in 'flat','gaussian','sigmoid','sigmoid_log'
+        origin : 2-tuple, optional
+            specify the origin. If not passed, looks for a value for the probe
+            origin in metadata. If not found there, calls .measure_disk.
+        data : 2d array, optional
+            if specified, uses this array instead of the probe image to compute
+            the kernel
+        **kwargs
+            see descriptions above
+
+        Returns
+        -------
+        kernel : 2D array
+        """
+
+        modes = [
+            'flat',
+            'gaussian',
+            'sigmoid',
+            'sigmoid_log'
+        ]
+
+        # parse args
+        assert mode in modes, f"mode must be in {modes}. Received {mode}"
+
+        # get function
+        function_dict = {
+            'flat' : self.get_probe_kernel_flat,
+            'gaussian' : self.get_probe_kernel_edge_gaussian,
+            'sigmoid' : self._get_probe_kernel_edge_sigmoid_sine_squared,
+            'sigmoid_log' : self._get_probe_kernel_edge_sigmoid_logistic
+        }
+        fn = function_dict[mode]
+
+        # check for the origin
+        if origin is None:
+            try:
+                x = self.calibration.get_probe_params()
+            except AttributeError:
+                x = None
+            finally:
+                if x is None:
+                    origin = None
+                else:
+                    r,x,y = x
+                    origin = (x,y)
+
+        # get the data
+        probe = data if data is not None else self.probe
+
+        # compute
+        kern = fn(
+            probe,
+            origin = origin,
+            **kwargs
+        )
+
+        # add to the Probe
+        self.kernel = kern
+
+        # return
+        if returncalc:
+            return kern
+
+
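A sketch pairing get_kernel with the disk radius measured above (assumes the `probe` and `r` from the previous example; the factor of 2 for the outer radius is an arbitrary illustrative choice):

>>> kern = probe.get_kernel(mode = 'sigmoid', radii = (r, 2*r))
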
+    @staticmethod
+    def get_probe_kernel_flat(
+        probe,
+        origin=None,
+        bilinear=False
+    ):
+        """
+        Creates a cross-correlation kernel from the vacuum probe by normalizing
+        and shifting the center.
+
+        Parameters
+        ----------
+        probe : 2d array
+            the vacuum probe
+        origin : 2-tuple (optional)
+            the origin of diffraction space. If not specified, finds the origin
+            using get_probe_size.
+        bilinear : bool (optional)
+            By default probe is shifted via a Fourier transform. Setting this to
+            True overrides it and uses bilinear shifting. Not recommended!
+
+        Returns
+        -------
+        kernel : ndarray
+            the cross-correlation kernel corresponding to the probe, in real
+            space
+        """
+        from py4DSTEM.process.utils import get_shifted_ar
+
+        Q_Nx, Q_Ny = probe.shape
+
+        # Get CoM
+        if origin is None:
+            from py4DSTEM.process.calibration import get_probe_size
+            _,xCoM,yCoM = get_probe_size(probe)
+        else:
+            xCoM,yCoM = origin
+
+        # Normalize
+        probe = probe/np.sum(probe)
+
+        # Shift center to corners of array
+        probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear)
+
+        # Return
+        return probe_kernel
+
+
+    @staticmethod
+    def get_probe_kernel_edge_gaussian(
+        probe,
+        sigma,
+        origin=None,
+        bilinear=True,
+    ):
+        """
+        Creates a cross-correlation kernel from the probe, subtracting a
+        gaussian from the normalized probe such that the kernel integrates to
+        zero, then shifting the center of the probe to the array corners.
+
+        Parameters
+        ----------
+        probe : ndarray
+            the diffraction pattern corresponding to the probe over vacuum
+        sigma : float
+            the width of the gaussian to subtract, relative to the standard
+            deviation of the probe
+        origin : 2-tuple (optional)
+            the origin of diffraction space. If not specified, finds the origin
+            using get_probe_size.
+        bilinear : bool
+            By default probe is shifted via a Fourier transform. Setting this to
+            True overrides it and uses bilinear shifting. Not recommended!
+
+        Returns
+        -------
+        kernel : ndarray
+            the cross-correlation kernel
+        """
+        from py4DSTEM.process.utils import get_shifted_ar
+
+        Q_Nx, Q_Ny = probe.shape
+
+        # Get CoM
+        if origin is None:
+            from py4DSTEM.process.calibration import get_probe_size
+            _,xCoM,yCoM = get_probe_size(probe)
+        else:
+            xCoM,yCoM = origin
+
+        # Shift probe to origin
+        probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear)
+
+        # Generate normalization kernel
+        # Coordinates
+        qy,qx = np.meshgrid(
+            np.mod(np.arange(Q_Ny) + Q_Ny//2, Q_Ny) - Q_Ny//2,
+            np.mod(np.arange(Q_Nx) + Q_Nx//2, Q_Nx) - Q_Nx//2)
+        qr2 = (qx**2 + qy**2)
+        # Calculate Gaussian normalization kernel
+        qstd2 = np.sum(qr2*probe_kernel) / np.sum(probe_kernel)
+        kernel_norm = np.exp(-qr2 / (2*qstd2*sigma**2))
+
+        # Output normalized kernel
+        probe_kernel = probe_kernel/np.sum(probe_kernel) - kernel_norm/np.sum(kernel_norm)
+
+        return probe_kernel
+
+
+    @staticmethod
+    def get_probe_kernel_edge_sigmoid(
+        probe,
+        radii,
+        origin=None,
+        type='sine_squared',
+        bilinear=True,
+    ):
+        """
+        Creates a convolution kernel from an average probe, subtracting an annular
+        trench about the probe such that the kernel integrates to zero, then
+        shifting the center of the probe to the array corners.
+
+        Parameters
+        ----------
+        probe : ndarray
+            the diffraction pattern corresponding to the probe over vacuum
+        radii : 2-tuple
+            the sigmoid inner and outer radii
+        origin : 2-tuple (optional)
+            the origin of diffraction space. If not specified, finds the origin
+            using get_probe_size.
+        type : string
+            must be 'logistic' or 'sine_squared'
+        bilinear : bool
+            By default probe is shifted via a Fourier transform. Setting this to
+            True overrides it and uses bilinear shifting. Not recommended!
+ + Returns + ------- + kernel : 2d array + the cross-correlation kernel + """ + from py4DSTEM.process.utils import get_shifted_ar + + # parse inputs + if isinstance(probe,Probe): + probe = probe.probe + + valid_types = ('logistic','sine_squared') + assert(type in valid_types), "type must be in {}".format(valid_types) + Q_Nx, Q_Ny = probe.shape + ri,ro = radii + + # Get CoM + if origin is None: + from py4DSTEM.process.calibration import get_probe_size + _,xCoM,yCoM = get_probe_size(probe) + else: + xCoM,yCoM = origin + + # Shift probe to origin + probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear) + + # Generate normalization kernel + # Coordinates + qy,qx = np.meshgrid( + np.mod(np.arange(Q_Ny) + Q_Ny//2, Q_Ny) - Q_Ny//2, + np.mod(np.arange(Q_Nx) + Q_Nx//2, Q_Nx) - Q_Nx//2) + qr = np.sqrt(qx**2 + qy**2) + # Calculate sigmoid + if type == 'logistic': + r0 = 0.5*(ro+ri) + sigma = 0.25*(ro-ri) + sigmoid = 1/(1+np.exp((qr-r0)/sigma)) + elif type == 'sine_squared': + sigmoid = (qr - ri) / (ro - ri) + sigmoid = np.minimum(np.maximum(sigmoid, 0.0), 1.0) + sigmoid = np.cos((np.pi/2)*sigmoid)**2 + else: + raise Exception("type must be in {}".format(valid_types)) + + # Output normalized kernel + probe_kernel = probe_kernel/np.sum(probe_kernel) - sigmoid/np.sum(sigmoid) + + return probe_kernel + + + def _get_probe_kernel_edge_sigmoid_sine_squared( + self, + probe, + radii, + origin=None, + **kwargs, + ): + return self.get_probe_kernel_edge_sigmoid( + probe, + radii, + origin = origin, + type='sine_squared', + **kwargs, + ) + + def _get_probe_kernel_edge_sigmoid_logistic( + self, + probe, + radii, + origin=None, + **kwargs, + ): + return self.get_probe_kernel_edge_sigmoid( + probe, + radii, + origin = origin, + type='logistic', + **kwargs + ) + + + + + + + + + diff --git a/py4DSTEM/process/diskdetection/threshold.py b/py4DSTEM/braggvectors/threshold.py similarity index 100% rename from py4DSTEM/process/diskdetection/threshold.py rename to py4DSTEM/braggvectors/threshold.py diff --git a/py4DSTEM/classes/__init__.py b/py4DSTEM/classes/__init__.py deleted file mode 100644 index 27ea3e353..000000000 --- a/py4DSTEM/classes/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -_emd_hook = True - -from py4DSTEM.classes.calibration import Calibration -from py4DSTEM.classes.data import Data -from py4DSTEM.classes.datacube import DataCube -from py4DSTEM.classes.diffractionslice import DiffractionSlice -from py4DSTEM.classes.realslice import RealSlice -from py4DSTEM.classes.probe import Probe -from py4DSTEM.classes.virtualdiffraction import VirtualDiffraction -from py4DSTEM.classes.virtualimage import VirtualImage -from py4DSTEM.classes.qpoints import QPoints - - diff --git a/py4DSTEM/classes/calibration.py b/py4DSTEM/classes/calibration.py deleted file mode 100644 index dbeef3e2b..000000000 --- a/py4DSTEM/classes/calibration.py +++ /dev/null @@ -1,454 +0,0 @@ -# Defines the Calibration class, which stores calibration metadata - -import numpy as np -from numbers import Number -from typing import Optional - -from emdfile import Metadata -from py4DSTEM.classes.propagating_calibration import propagating_calibration - -class Calibration(Metadata): - """ - Stores calibration measurements. - - Usage: - - >>> c = Calibration() - >>> c.set_p(p) - >>> p = c.get_p() - - If the parameter has not been set, the getter methods return None. For - parameters with multiple values, they're returned as a tuple. If any of - the multiple values can't be found, a single None is returned instead. 
- Some parameters may have distinct values for each scan position; these - are stored as 2D arrays, and - - >>> c.get_p() - - will return the entire 2D array, while - - >>> c.get_p(rx,ry) - - will return the value of `p` at position `rx,ry`. - - The Calibration object is capable of automatically calling the ``calibrate`` method - of any other py4DSTEM objects when certain calibrations are updated. The methods - that trigger propagation of calibration information are tagged with the - @propagating_calibration decorator. Use the ``register_target`` method - to set up an object to recieve calls to ``calibrate`` - - """ - def __init__( - self, - name: Optional[str] = 'calibration', - datacube = None - ): - """ - Args: - name (optional, str): - """ - Metadata.__init__( - self, - name=name) - - # List to hold objects that will re-`calibrate` when - # certain properties are changed - self._targets = [] - - # set datacube - self._datacube = datacube - - # set initial pixel values - self.set_Q_pixel_size(1) - self.set_R_pixel_size(1) - self.set_Q_pixel_units('pixels') - self.set_R_pixel_units('pixels') - - - - # datacube - - @property - def datacube(self): - return self._datacube - - - - - ### getter/setter methods - - - - # pixel size/units - - @propagating_calibration - def set_Q_pixel_size(self,x): - if self._has_datacube(): - self.datacube.set_dim(2,x) - self.datacube.set_dim(3,x) - self._params['Q_pixel_size'] = x - def get_Q_pixel_size(self): - return self._get_value('Q_pixel_size') - - @propagating_calibration - def set_R_pixel_size(self,x): - if self._has_datacube(): - self.datacube.set_dim(0,x) - self.datacube.set_dim(1,x) - self._params['R_pixel_size'] = x - def get_R_pixel_size(self): - return self._get_value('R_pixel_size') - - @propagating_calibration - def set_Q_pixel_units(self,x): - assert(x in ('pixels','A^-1','mrad')), f"Q pixel units must be 'A^-1', 'mrad' or 'pixels'." - if self._has_datacube(): - self.datacube.set_dim_units(2,x) - self.datacube.set_dim_units(3,x) - self._params['Q_pixel_units'] = x - def get_Q_pixel_units(self): - return self._get_value('Q_pixel_units') - - @propagating_calibration - def set_R_pixel_units(self,x): - if self._has_datacube(): - self.datacube.set_dim_units(0,x) - self.datacube.set_dim_units(1,x) - self._params['R_pixel_units'] = x - def get_R_pixel_units(self): - return self._get_value('R_pixel_units') - - - - # datacube shape - - def get_R_Nx(self): - self._validate_datacube() - return self.datacube.R_Nx - def get_R_Ny(self): - self._validate_datacube() - return self.datacube.R_Ny - def get_Q_Nx(self): - self._validate_datacube() - return self.datacube.Q_Nx - def get_Q_Ny(self): - self._validate_datacube() - return self.datacube.Q_Ny - def get_datacube_shape(self): - self._validate_datacube() - """ (R_Nx,R_Ny,Q_Nx,Q_Ny) - """ - return self.datacube.data.dshape - def get_Qshape(self,x): - self._validate_datacube() - return self.data.Qshape - def get_Rshape(self,x): - self._validate_datacube() - return self.data.Rshape - - # is there a datacube? 
- def _validate_datacube(self): - assert(self.datacube is not None), "Can't find shape attr because Calibration doesn't point to a DataCube" - def _has_datacube(self): - return(self.datacube is not None) - - - - - # origin - def set_qx0(self,x): - self._params['qx0'] = x - x = np.asarray(x) - qx0_mean = np.mean(x) - qx0_shift = x-qx0_mean - self._params['qx0_mean'] = qx0_mean - self._params['qx0_shift'] = qx0_shift - def set_qx0_mean(self,x): - self._params['qx0_mean'] = x - def get_qx0(self,rx=None,ry=None): - return self._get_value('qx0',rx,ry) - def get_qx0_mean(self): - return self._get_value('qx0_mean') - def get_qx0shift(self,rx=None,ry=None): - return self._get_value('qx0_shift',rx,ry) - - def set_qy0(self,x): - self._params['qy0'] = x - x = np.asarray(x) - qy0_mean = np.mean(x) - qy0_shift = x-qy0_mean - self._params['qy0_mean'] = qy0_mean - self._params['qy0_shift'] = qy0_shift - def set_qy0_mean(self,x): - self._params['qy0_mean'] = x - def get_qy0(self,rx=None,ry=None): - return self._get_value('qy0',rx,ry) - def get_qy0_mean(self): - return self._get_value('qy0_mean') - def get_qy0shift(self,rx=None,ry=None): - return self._get_value('qy0_shift',rx,ry) - - def set_qx0_meas(self,x): - self._params['qx0_meas'] = x - def get_qx0_meas(self,rx=None,ry=None): - return self._get_value('qx0_meas',rx,ry) - - def set_qy0_meas(self,x): - self._params['qy0_meas'] = x - def get_qy0_meas(self,rx=None,ry=None): - return self._get_value('qy0_meas',rx,ry) - - def set_origin_meas_mask(self,x): - self._params['origin_meas_mask'] = x - def get_origin_meas_mask(self,rx=None,ry=None): - return self._get_value('origin_meas_mask',rx,ry) - - @propagating_calibration - def set_origin(self,x): - """ - Args: - x (2-tuple of numbers or of 2D, R-shaped arrays): the origin - """ - qx0,qy0 = x - self.set_qx0(qx0) - self.set_qy0(qy0) - def get_origin(self,rx=None,ry=None): - qx0 = self._get_value('qx0',rx,ry) - qy0 = self._get_value('qy0',rx,ry) - ans = (qx0,qy0) - if any([x is None for x in ans]): - ans = None - return ans - def get_origin_mean(self): - qx0 = self._get_value('qx0_mean') - qy0 = self._get_value('qy0_mean') - return qx0,qy0 - def get_origin_shift(self,rx=None,ry=None): - qx0 = self._get_value('qx0_shift',rx,ry) - qy0 = self._get_value('qy0_shift',rx,ry) - ans = (qx0,qy0) - if any([x is None for x in ans]): - ans = None - return ans - - def set_origin_meas(self,x): - """ - Args: - x (2-tuple or 3 uple of 2D R-shaped arrays): qx0,qy0,[mask] - """ - qx0,qy0 = x[0],x[1] - self.set_qx0_meas(qx0) - self.set_qy0_meas(qy0) - try: - m = x[2] - self.set_origin_meas_mask(m) - except IndexError: - pass - def get_origin_meas(self,rx=None,ry=None): - qx0 = self._get_value('qx0_meas',rx,ry) - qy0 = self._get_value('qy0_meas',rx,ry) - ans = (qx0,qy0) - if any([x is None for x in ans]): - ans = None - return ans - - def set_probe_semiangle(self,x): - self._params['probe_semiangle'] = x - def get_probe_semiangle(self): - return self._get_value('probe_semiangle') - def set_probe_param(self, x): - """ - Args: - x (3-tuple): (probe size, x0, y0) - """ - probe_semiangle, qx0, qy0 = x - self.set_probe_semiangle(probe_semiangle) - self.set_qx0_mean(qx0) - self.set_qy0_mean(qy0) - def get_probe_param(self): - probe_semiangle = self._get_value('probe_semiangle') - qx0 = self._get_value('qx0') - qy0 = self._get_value('qy0') - ans = (probe_semiangle,qx0,qy0) - if any([x is None for x in ans]): - ans = None - return ans - - - # ellipse - def set_a(self,x): - self._params['a'] = x - def get_a(self,rx=None,ry=None): - return 
self._get_value('a',rx,ry) - def set_b(self,x): - self._params['b'] = x - def get_b(self,rx=None,ry=None): - return self._get_value('b',rx,ry) - def set_theta(self,x): - self._params['theta'] = x - def get_theta(self,rx=None,ry=None): - return self._get_value('theta',rx,ry) - - @propagating_calibration - def set_ellipse(self,x): - """ - Args: - x (3-tuple): (a,b,theta) - """ - a,b,theta = x - self._params['a'] = a - self._params['b'] = b - self._params['theta'] = theta - - @propagating_calibration - def set_p_ellipse(self,x): - """ - Args: - x (5-tuple): (qx0,qy0,a,b,theta) NOTE: does *not* change qx0,qy0! - """ - _,_,a,b,theta = x - self._params['a'] = a - self._params['b'] = b - self._params['theta'] = theta - def get_ellipse(self,rx=None,ry=None): - a = self.get_a(rx,ry) - b = self.get_b(rx,ry) - theta = self.get_theta(rx,ry) - ans = (a,b,theta) - if any([x is None for x in ans]): - ans = None - return ans - def get_p_ellipse(self,rx=None,ry=None): - qx0,qy0 = self.get_origin(rx,ry) - a,b,theta = self.get_ellipse(rx,ry) - return (qx0,qy0,a,b,theta) - - # Q/R-space rotation and flip - def set_QR_rotation_degrees(self,x): - self._params['QR_rotation_degrees'] = x - def get_QR_rotation_degrees(self): - return self._get_value('QR_rotation_degrees') - - def set_QR_flip(self,x): - self._params['QR_flip'] = x - def get_QR_flip(self): - return self._get_value('QR_flip') - - @propagating_calibration - def set_QR_rotflip(self, rot_flip): - """ - Args: - rot_flip (tuple), (rot, flip) where: - rot (number): rotation in degrees - flip (bool): True indicates a Q/R axes flip - """ - rot,flip = rot_flip - self.set_QR_rotation_degrees(rot) - self.set_QR_flip(flip) - def get_QR_rotflip(self): - rot = self.get_QR_rotation_degrees() - flip = self.get_QR_flip() - if rot is None or flip is None: - return None - return (rot,flip) - - - # probe - def set_convergence_semiangle_pixels(self,x): - self._params['convergence_semiangle_pixels'] = x - def get_convergence_semiangle_pixels(self): - return self._get_value('convergence_semiangle_pixels') - def set_convergence_semiangle_pixels(self,x): - self._params['convergence_semiangle_mrad'] = x - def get_convergence_semiangle_pixels(self): - return self._get_value('convergence_semiangle_mrad') - def set_probe_center(self,x): - self._params['probe_center'] = x - def get_probe_center(self): - return self._get_value('probe_center') - - - # For parameters which can have 2D or (2+n)D array values, - # this function enables returning the value(s) at a 2D position, - # rather than the whole array - def _get_value(self,p,rx=None,ry=None): - """ Enables returning the value of a pixel (rx,ry), - if these are passed and `p` is an appropriate array - """ - v = self._params.get(p) - - if v is None: - return v - - if (rx is None) or (ry is None) or (not isinstance(v,np.ndarray)): - return v - - else: - er = f"`rx` and `ry` must be ints; got values {rx} and {ry}" - assert np.all([isinstance(i,(int,np.integer)) for i in (rx,ry)]), er - return v[rx,ry] - - - - def copy(self,name=None): - """ - """ - if name is None: name = self.name+"_copy" - cal = Calibration(name=name) - cal._params.update(self._params) - return cal - - - # Methods for assigning objects which will be - # auto-calibrated when the Calibration instance is updated - - def register_target(self,new_target): - """ - Register an object to recieve calls to it `calibrate` - method when certain calibrations get updated - """ - self._targets.append(new_target) - - def unregister_target(self,target): - """ - Unlink an object 
from recieving calls to `calibrate` when - certain calibration values are changed - """ - if target in self._targets: - self._targets.remove(target) - - - # HDF5 i/o - - # write is inherited from Metadata - - # read - def from_h5(group): - """ - Takes a valid group for an HDF5 file object which is open in - read mode. Determines if it's a valid Metadata representation, and - if so loads and returns it as a Calibration instance. Otherwise, - raises an exception. - - Accepts: - group (HDF5 group) - - Returns: - A Calibration instance - """ - # load the group as a Metadata instance - metadata = Metadata.from_h5(group) - - # convert it to a Calibration instance - cal = Calibration(name = metadata.name) - cal._params.update(metadata._params) - - # return - return cal - - - - -########## End of class ########## - - diff --git a/py4DSTEM/classes/data.py b/py4DSTEM/classes/data.py deleted file mode 100644 index 194d5af35..000000000 --- a/py4DSTEM/classes/data.py +++ /dev/null @@ -1,39 +0,0 @@ -# Base class for all py4DSTEM data -# which adds a pointer to 'calibration' metadata - -import warnings - -from emdfile import Node -from py4DSTEM.classes import Calibration - - -class Data: - - def __init__(self): - assert(isinstance(self,Node)), "Data instances must alse inherit from Node" - pass - - - # calibration - - @property - def calibration(self): - try: - return self.root.metadata['calibration'] - except KeyError: - return None - except AttributeError: - return None - - @calibration.setter - def calibration(self, x): - assert( isinstance( x, Calibration) ) - if 'calibration' in self.root.metadata.keys(): - warnings.warn("A 'calibration' key already exists in root.metadata - overwriting...") - x.name = 'calibration' - self.root.metadata['calibration'] = x - - - - - diff --git a/py4DSTEM/classes/datacube.py b/py4DSTEM/classes/datacube.py deleted file mode 100644 index 3f2edcf62..000000000 --- a/py4DSTEM/classes/datacube.py +++ /dev/null @@ -1,246 +0,0 @@ -# Defines the DataCube class, which stores 4D-STEM datacubes - -from typing import Optional,Union -import numpy as np - -from emdfile import Array,RootedNode -from py4DSTEM.classes import Data, Calibration -from py4DSTEM.classes.methods import DataCubeMethods - -class DataCube(Array,RootedNode,Data,DataCubeMethods): - """ - Storage and processing methods for 4D-STEM datasets. - """ - - def __init__( - self, - data: np.ndarray, - name: Optional[str] = 'datacube', - calibration: Optional[Union[Calibration,None]] = None, - slicelabels: Optional[Union[bool,list]] = None, - ): - """ - Accepts: - data (np.ndarray): the data - name (str): the name of the datacube - calibration (None or Calibration or 'pass'): default (None) - creates and attaches a new Calibration instance to root - metadata, or, passing a Calibration instance uses this instead. - 'skip' is for internal use for the reader - slicelabels (None or list): names for slices if this is a - stack of datacubes - - Returns: - A new DataCube instance. 
- """ - - # initialize DataCubeMethods - super(DataCubeMethods).__init__() - - # initialize as an Array - Array.__init__( - self, - data = data, - name = name, - units = 'pixel intensity', - dim_names = [ - 'Rx', - 'Ry', - 'Qx', - 'Qy' - ], - slicelabels = slicelabels - ) - - # set up EMD tree - RootedNode.__init__(self) - - # set up calibration - self._setup_calibration( calibration ) - - # cartesian coords - # TODO - tmp hack, needs to be refactored - - # this will break when preprocess methods are called - self.qyy,self.qxx = np.meshgrid( - np.arange(0,self.Q_Ny), - np.arange(0,self.Q_Nx) - ) - - # polar coords - self.polar = None - - - - def _setup_calibration(self, cal): - """ - Ensures that a calibration instance exists. Passing None - makes a new Calibration instance, puts it in root.calibration, and - makes a two way link. Passing a Calibration instance attaches that - instance. `'skip'` does nothing (internal use, on read from file). - """ - if cal is None: - self.calibration = Calibration( datacube = self ) - elif cal == 'skip': - pass - else: - assert(isinstance(cal, Calibration)), "`calibration` must be a Calibration instance, not type f{type(cal)}" - self.calibration = cal - cal._datacube = self - - - def copy(self): - """ - Copys datacube - """ - from py4DSTEM import DataCube - new_datacube = DataCube( - data = self.data.copy(), - name = self.name, - calibration = self.calibration.copy(), - slicelabels = self.slicelabels, - ) - - Qpixsize = new_datacube.calibration.get_Q_pixel_size() - Qpixunits = new_datacube.calibration.get_Q_pixel_units() - Rpixsize = new_datacube.calibration.get_R_pixel_size() - Rpixunits = new_datacube.calibration.get_R_pixel_units() - - new_datacube.set_dim( - 0, - [0,Rpixsize], - units = Rpixunits, - name = 'Rx' - ) - new_datacube.set_dim( - 1, - [0,Rpixsize], - units = Rpixunits, - name = 'Ry' - ) - - new_datacube.set_dim( - 2, - [0,Qpixsize], - units = Qpixunits, - name = 'Qx' - ) - new_datacube.set_dim( - 3, - [0,Qpixsize], - units = Qpixunits, - name = 'Qy' - ) - - return new_datacube - - - # properties - - - ## pixel size / units - - # Q - @property - def Q_pixel_size(self): - return self.calibration.get_Q_pixel_size() - @property - def Q_pixel_units(self): - return self.calibration.get_Q_pixel_units() - - # R - @property - def R_pixel_size(self): - return self.calibration.get_R_pixel_size() - @property - def R_pixel_units(self): - return self.calibration.get_R_pixel_units() - - # aliases - qpixsize = Q_pixel_size - qpixunit = Q_pixel_units - rpixsize = R_pixel_size - rpixunit = R_pixel_units - - - ## shape - - # FOV - @property - def R_Nx(self): - return self.data.shape[0] - @property - def R_Ny(self): - return self.data.shape[1] - @property - def Q_Nx(self): - return self.data.shape[2] - @property - def Q_Ny(self): - return self.data.shape[3] - - @property - def Rshape(self): - return (self.data.shape[0],self.data.shape[1]) - @property - def Qshape(self): - return (self.data.shape[2],self.data.shape[3]) - - @property - def R_N(self): - return self.R_Nx*self.R_Ny - - # aliases - qnx = Q_Nx - qny = Q_Ny - rnx = R_Nx - rny = R_Ny - rshape = Rshape - qshape = Qshape - rn = R_N - - - - - - # HDF5 i/o - - # to_h5 is inherited from Array - - # read - @classmethod - def _get_constructor_args(cls,group): - """ Construct a datacube with no calibration / metadata - """ - # We only need some of the Array constructors; - # dim vector/units are passed through when Calibration - # is loaded, and the runtim dim vectors are then set - # in _add_root_links - 
ar_args = Array._get_constructor_args(group) - - args = { - 'data': ar_args['data'], - 'name': ar_args['name'], - 'slicelabels': ar_args['slicelabels'], - 'calibration': 'skip' - } - - return args - - - def _add_root_links(self,group): - """ When reading from file, link to calibration metadata, - then use it to populate the datacube dim vectors - """ - # Link to the datacube - self.calibration._datacube = self - - # Populate dim vectors - self.calibration.set_Q_pixel_size( self.calibration.get_Q_pixel_size() ) - self.calibration.set_R_pixel_size( self.calibration.get_R_pixel_size() ) - self.calibration.set_Q_pixel_units( self.calibration.get_Q_pixel_units() ) - self.calibration.set_R_pixel_units( self.calibration.get_R_pixel_units() ) - - return - - - diff --git a/py4DSTEM/classes/methods/__init__.py b/py4DSTEM/classes/methods/__init__.py deleted file mode 100644 index ab06bce4e..000000000 --- a/py4DSTEM/classes/methods/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from py4DSTEM.classes.methods.datacube_methods import DataCubeMethods -from py4DSTEM.classes.methods.probe_methods import ProbeMethods - diff --git a/py4DSTEM/classes/methods/probe_methods.py b/py4DSTEM/classes/methods/probe_methods.py deleted file mode 100644 index 9fd040ff6..000000000 --- a/py4DSTEM/classes/methods/probe_methods.py +++ /dev/null @@ -1,76 +0,0 @@ -# Functions to become Probe methods - -import numpy as np - - - -class ProbeMethods: - """ - A container for Probe object instance methods. - """ - - def __init__(self): - pass - - - - # Kernel generation - - def get_kernel( - self, - mode = 'flat', - returncalc = True, - **kwargs - ): - """ - Creates a kernel from the probe for cross-correlative template matching. - - Precise behavior and valid keyword arguments depend on the `mode` - selected. In each case, the center of the probe is shifted to the - origin and the kernel normalized such that it sums to 1. In 'flat' - mode, this is the only processing performed. In the remaining modes, - some additional processing is performed which adds a ring of - negative intensity around the central probe, which results in - edge-filetering-like behavior during cross correlation. Valid modes, - and the required additional kwargs, if any, for each, are: - - - 'flat': creates a flat probe kernel. For bullseye or other - structured probes, this mode is recommended. No required - arguments, optional arg `origin` (2 tuple) - - 'gaussian': subtracts a gaussian with a width of standard - deviation `sigma`, which is a required argument. Optional - arg `origin`. - - 'sigmoid': subtracts an annulus with inner and outer radii - of (ri,ro) and a sine-squared sigmoid radial profile from - the probe template. Required arg: `radii` (2 tuple). Optional - args `origin` (2-tuple) - - 'sigmoid_log': subtracts an annulus with inner and outer radii - of (ri,ro) and a logistic sigmoid radial profile from - the probe template. Required arg: `radii` (2 tuple). 
Optional - args `origin` (2-tuple) - - Returns: - (2D array) - """ - - # perform computation - from py4DSTEM.process.probe import get_kernel - kern = get_kernel( - self.probe, - mode = mode, - **kwargs - ) - - # add to the Probe - self.kernel = kern - - # return - if returncalc: - return kern - - - - - - - diff --git a/py4DSTEM/classes/probe.py b/py4DSTEM/classes/probe.py deleted file mode 100644 index fc6dade97..000000000 --- a/py4DSTEM/classes/probe.py +++ /dev/null @@ -1,84 +0,0 @@ - -from py4DSTEM.classes import DiffractionSlice, Data -from py4DSTEM.classes.methods import ProbeMethods - -from typing import Optional -import numpy as np - -class Probe(DiffractionSlice,ProbeMethods,Data): - """ - Stores a vacuum probe. - """ - - def __init__( - self, - data: np.ndarray, - name: Optional[str] = 'probe' - ): - """ - Accepts: - data (2D or 3D np.ndarray): the vacuum probe, or - the vacuum probe + kernel - name (str): a name - - Returns: - (Probe) - """ - # if only the probe is passed, make space for the kernel - if data.ndim == 2: - data = np.stack([ - data, - np.zeros_like(data) - ]) - - # initialize as a ProbeMethods instance - super(ProbeMethods).__init__() - - # initialize as a DiffractionSlice - DiffractionSlice.__init__( - self, - name = name, - data = data, - slicelabels = [ - 'probe', - 'kernel' - ] - ) - - - ## properties - - @property - def probe(self): - return self.get_slice('probe').data - @probe.setter - def probe(self,x): - assert(x.shape == (self.data.shape[1:])) - self.data[0,:,:] = x - @property - def kernel(self): - return self.get_slice('kernel').data - @kernel.setter - def kernel(self,x): - assert(x.shape == (self.data.shape[1:])) - self.data[1,:,:] = x - - - - - # read - @classmethod - def _get_constructor_args(cls,group): - """ - Returns a dictionary of args/values to pass to the class constructor - """ - ar_constr_args = DiffractionSlice._get_constructor_args(group) - args = { - 'data' : ar_constr_args['data'], - 'name' : ar_constr_args['name'], - } - return args - - - - diff --git a/py4DSTEM/classes/propagating_calibration.py b/py4DSTEM/classes/propagating_calibration.py deleted file mode 100644 index 8b1196ee6..000000000 --- a/py4DSTEM/classes/propagating_calibration.py +++ /dev/null @@ -1,46 +0,0 @@ -import warnings - -class propagating_calibration(object): - """ - A decorator which, when attached to a method of Calibration, - causes `calibrate` to be called on any objects in the - Calibration object's `_targets` list, following execution of - the decorated function. - This allows objects associated with the Calibration to - automatically respond to changes in the calibration state. - """ - def __init__(self, func): - self.func = func - - def __call__(self, *args, **kwargs): - """ - Update the parameters the caller wanted by calling the wrapped - method, then loop through the list of targetsand call their - `calibrate` methods. - """ - self.func(*args,**kwargs) - - calibration = args[0] - assert hasattr(calibration, "_targets"), "Calibration object appears to be in an invalid state. _targets attribute is missing." 
- for target in calibration._targets: - if hasattr(target,'calibrate') and callable(target.calibrate): - try: - target.calibrate() - except Exception as err: - print(f"Attempted to calibrate object {target} but this raised an error: {err}") - else: - warnings.warn(f"{target} is registered as a target for calibration propagation but does not appear to have a calibrate() method") - - def __get__(self, instance, owner): - """ - This is some magic to make sure that the Calibration instance - on which the decorator was called gets passed through and - everything dispatches correctly (by making sure `instance`, - the Calibration instance to which the call was directed, gets - placed in the `self` slot of the wrapped method (which is *not* - actually bound to the instance due to this decoration.) using - partial application of the method.) - """ - from functools import partial - return partial(self.__call__, instance) - diff --git a/py4DSTEM/classes/virtualdiffraction.py b/py4DSTEM/classes/virtualdiffraction.py deleted file mode 100644 index 9d1bb4317..000000000 --- a/py4DSTEM/classes/virtualdiffraction.py +++ /dev/null @@ -1,50 +0,0 @@ - -from typing import Optional -import numpy as np - -from py4DSTEM.classes import DiffractionSlice,Data - - - -class VirtualDiffraction(DiffractionSlice,Data): - """ - Stores a diffraction-space shaped 2D image with metadata - indicating how this image was generated from a datacube. - """ - def __init__( - self, - data: np.ndarray, - name: Optional[str] = 'virtualdiffraction', - ): - """ - Args: - data (np.ndarray) : the 2D data - name (str) : the name - - Returns: - A new VirtualDiffraction instance - """ - # initialize as a DiffractionSlice - DiffractionSlice.__init__( - self, - data = data, - name = name, - ) - - - # read - @classmethod - def _get_constructor_args(cls,group): - """ - Returns a dictionary of args/values to pass to the class constructor - """ - ar_constr_args = DiffractionSlice._get_constructor_args(group) - args = { - 'data' : ar_constr_args['data'], - 'name' : ar_constr_args['name'], - } - return args - - - - diff --git a/py4DSTEM/classes/virtualimage.py b/py4DSTEM/classes/virtualimage.py deleted file mode 100644 index a14d83e99..000000000 --- a/py4DSTEM/classes/virtualimage.py +++ /dev/null @@ -1,54 +0,0 @@ - -from typing import Optional -import numpy as np - -from py4DSTEM.classes import RealSlice,Data - - - -class VirtualImage(RealSlice,Data): - """ - Stores a real-space shaped 2D image with metadata - indicating how this image was generated from a datacube. 
- """ - def __init__( - self, - data: np.ndarray, - name: Optional[str] = 'virtualimage', - ): - """ - Args: - data (np.ndarray) : the 2D data - name (str) : the name - - Returns: - A new VirtualImage instance - """ - # initialize as a RealSlice - RealSlice.__init__( - self, - data = data, - name = name, - ) - - - - # read - @classmethod - def _get_constructor_args(cls,group): - """ - Returns a dictionary of args/values to pass to the class constructor - """ - ar_constr_args = RealSlice._get_constructor_args(group) - args = { - 'data' : ar_constr_args['data'], - 'name' : ar_constr_args['name'], - } - return args - - - - - - - diff --git a/py4DSTEM/data/__init__.py b/py4DSTEM/data/__init__.py new file mode 100644 index 000000000..ff9429fe0 --- /dev/null +++ b/py4DSTEM/data/__init__.py @@ -0,0 +1,8 @@ +_emd_hook = True + +from py4DSTEM.data.calibration import Calibration +from py4DSTEM.data.data import Data +from py4DSTEM.data.diffractionslice import DiffractionSlice +from py4DSTEM.data.realslice import RealSlice +from py4DSTEM.data.qpoints import QPoints + diff --git a/py4DSTEM/data/calibration.py b/py4DSTEM/data/calibration.py new file mode 100644 index 000000000..50ec8f6f9 --- /dev/null +++ b/py4DSTEM/data/calibration.py @@ -0,0 +1,816 @@ +# Defines the Calibration class, which stores calibration metadata + +import numpy as np +from numbers import Number +from typing import Optional +from warnings import warn + +from emdfile import Metadata,Root +from py4DSTEM.data.propagating_calibration import call_calibrate + +class Calibration(Metadata): + """ + Stores calibration measurements. + + Usage + ----- + For some calibration instance `c` + + >>> c['x'] = y + + will set the value of some calibration item called 'x' to y, and + + >>> _y = c['x'] + + will return the value currently stored as 'x' and assign it to _y. + Additionally, for calibration items in the list `l` given below, + the syntax + + >>> c.set_p(p) + >>> p = c.get_p() + + is equivalent to + + >>> c.p = p + >>> p = c.p + + is equivalent to + + >>> c['p'] = p + >>> p = c['p'] + + where in the first line of each couplet the parameter `p` is set and in + the second it's retrieved, for parameters p in the list + + calibrate + --------- + l = [ + Q_pixel_size, * + R_pixel_size, * + Q_pixel_units, * + R_pixel_units, * + qx0, + qy0, + qx0_mean, + qy0_mean, + qx0shift, + qy0shift, + origin, * + origin_meas, + origin_meas_mask, + origin_shift, + a, * + b, * + theta, * + p_ellipse, * + ellipse, * + QR_rotation_degrees, * + QR_flip, * + QR_rotflip, * + probe_semiangle, + probe_param, + probe_center, + probe_convergence_semiangle_pixels, + probe_convergence_semiangle_mrad, + ] + + There are two advantages to using the getter/setter syntax for parameters + in `l` (e.g. either c.set_p or c.p) instead of the normal dictionary-like + getter/setter syntax (i.e. c['p']). These are (1) enabling retrieving + parameters by beam scan position, and (2) enabling propagation of any + calibration changes to downstream data objects which are affected by the + altered calibrations. See below. + + Get a parameter by beam scan position + ------------------------------------- + Some parameters support retrieval by beam scan position. In these cases, + calling + + >>> c.get_p(rx,ry) + + will return the value of parameter p at beam position (rx,ry). This works + only for the above syntax. Using either of + + >>> c.p + >>> c['p'] + + will return an R-space shaped array. 
+ + Trigger downstream calibrations + ------------------------------- + Some objects store their own internal calibration state, which depends on + the calibrations stored here. For example, a DataCube stores dimension + vectors which calibrate its 4 dimensions, and which depend on the pixel + sizes and the origin position. + + Modifying certain parameters therefore can trigger other objects which + depend on these parameters to re-calibrate themselves by calling their + .calibrate() method, if the object has one. Methods marked with a * in the + list `l` above have this property. Only objects registered with the + Calibration instance will have their .calibrate method triggered by changing + these parameters. An object `data` can be registered by calling + + >>> c.register_target( data ) + + and unregistered with + + >>> c.unregister_target( data ) + + If an object without a .calibrate method is registered when a * method is + called, nothing happens. + + The .calibrate methods are triggered by setting some parameter `p` using + either + + >>> c.set_p( val ) + + or + + >>> c.p = val + + syntax. Setting the parameter with + + >>> c['p'] = val + + will not trigger re-calibrations. + + Calibration + Data + ------------------ + Data in py4DSTEM is stored in filetree-like representations, and + Calibration instances are the top-level objects in these trees, + in that they live here: + + Root + |--metadata + | |-- *****---> calibration <---***** + | + |--some_object(e.g.datacube) + | |--another_object(e.g.max_dp) + | |--etc. + |--etc. + : + + Every py4DSTEM Data object has a tree with a calibration, and calling + + >>> data.calibration + + will return that Calibration instance. See also the docstring + for the `Data` class. + + Attaching an object to a different Calibration + ---------------------------------------------- + To modify the calibration associated with some object `data`, use + + >>> c.attach( data ) + + where `c` is the new calibration instance. This (1) moves `data` into the + top level of `c`'s data tree, which means the new calibration will now be + accessible normally at + + >>> data.calibration + + and (2) if and only if `data` was registered with its old calibration, + de-registers it there and registers it with the new calibration. If + `data` was not registered with the old calibration and it should be + registered with the new one, `c.register_target( data )` should be + called. + + To attach `data` to a different location in the calibration instance's + tree, use `node.attach( data )`. See the Data.attach docstring.
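+ + For example (an illustrative sketch, where `datacube` stands for any + object with a .calibrate method which has been registered as above): + + >>> c.register_target( datacube ) + >>> c.Q_pixel_size = 0.02 # a * parameter: triggers datacube.calibrate() + >>> c['Q_pixel_size'] = 0.02 # sets the value without re-calibrating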
+ """ + def __init__( + self, + name: Optional[str] = 'calibration', + root: Optional[Root] = None, + ): + """ + Args: + name (optional, str): + """ + Metadata.__init__( + self, + name=name + ) + + # Set the root + if root is None: + root = Root( name="py4DSTEM_root" ) + self.set_root(root) + + # List to hold objects that will re-`calibrate` when + # certain properties are changed + self._targets = [] + + # set initial pixel values + self['Q_pixel_size'] = 1 + self['R_pixel_size'] = 1 + self['Q_pixel_units'] = 'pixels' + self['R_pixel_units'] = 'pixels' + + + # EMD root property + @property + def root(self): + return self._root + @root.setter + def root(self): + raise Exception("Calibration.root does not support assignment; to change the root, use self.set_root") + def set_root(self,root): + assert(isinstance(root,Root)), f"root must be a Root, not type {type(root)}" + self._root = root + + + # Attach data to the calibration instance + def attach(self,data): + """ + Attach `data` to this calibration instance, placing it in the top + level of the Calibration instance's tree. If `data` was in a + different data tree, remove it. If `data` was registered with + a different calibration instance, de-register it there and + register it here. If `data` was not previously registerd and it + should be, after attaching it run `self.register_target(data)`. + """ + from py4DSTEM.data import Data + assert(isinstance(data,Data)), f"data must be a Data instance" + self.root.attach(data) + + + # Register for auto-calibration + def register_target(self,new_target): + """ + Register an object to recieve calls to it `calibrate` + method when certain calibrations get updated + """ + if new_target not in self._targets: + self._targets.append(new_target) + + def unregister_target(self,target): + """ + Unlink an object from recieving calls to `calibrate` when + certain calibration values are changed + """ + if target in self._targets: + self._targets.remove(target) + + @property + def targets(self): + return tuple(self._targets) + + + + ######### Begin Calibration Metadata Params ######### + + # pixel size/units + + @call_calibrate + def set_Q_pixel_size(self,x): + self._params['Q_pixel_size'] = x + def get_Q_pixel_size(self): + return self._get_value('Q_pixel_size') + # aliases + @property + def Q_pixel_size(self): + return self.get_Q_pixel_size() + @Q_pixel_size.setter + def Q_pixel_size(self,x): + self.set_Q_pixel_size(x) + @property + def qpixsize(self): + return self.get_Q_pixel_size() + @qpixsize.setter + def qpixsize(self,x): + self.set_Q_pixel_size(x) + + @call_calibrate + def set_R_pixel_size(self,x): + self._params['R_pixel_size'] = x + def get_R_pixel_size(self): + return self._get_value('R_pixel_size') + # aliases + @property + def R_pixel_size(self): + return self.get_R_pixel_size() + @R_pixel_size.setter + def R_pixel_size(self,x): + self.set_R_pixel_size(x) + @property + def qpixsize(self): + return self.get_R_pixel_size() + @qpixsize.setter + def qpixsize(self,x): + self.set_R_pixel_size(x) + + @call_calibrate + def set_Q_pixel_units(self,x): + assert(x in ('pixels','A^-1','mrad')), f"Q pixel units must be 'A^-1', 'mrad' or 'pixels'." 
+ self._params['Q_pixel_units'] = x + def get_Q_pixel_units(self): + return self._get_value('Q_pixel_units') + # aliases + @property + def Q_pixel_units(self): + return self.get_Q_pixel_units() + @Q_pixel_units.setter + def Q_pixel_units(self,x): + self.set_Q_pixel_units(x) + @property + def qpixunits(self): + return self.get_Q_pixel_units() + @qpixunits.setter + def qpixunits(self,x): + self.set_Q_pixel_units(x) + + @call_calibrate + def set_R_pixel_units(self,x): + self._params['R_pixel_units'] = x + def get_R_pixel_units(self): + return self._get_value('R_pixel_units') + # aliases + @property + def R_pixel_units(self): + return self.get_R_pixel_units() + @R_pixel_units.setter + def R_pixel_units(self,x): + self.set_R_pixel_units(x) + @property + def rpixunits(self): + return self.get_R_pixel_units() + @rpixunits.setter + def rpixunits(self,x): + self.set_R_pixel_units(x) + + + # origin + + # qx0,qy0 + def set_qx0(self,x): + self._params['qx0'] = x + x = np.asarray(x) + qx0_mean = np.mean(x) + qx0_shift = x-qx0_mean + self._params['qx0_mean'] = qx0_mean + self._params['qx0_shift'] = qx0_shift + def set_qx0_mean(self,x): + self._params['qx0_mean'] = x + def get_qx0(self,rx=None,ry=None): + return self._get_value('qx0',rx,ry) + def get_qx0_mean(self): + return self._get_value('qx0_mean') + def get_qx0shift(self,rx=None,ry=None): + return self._get_value('qx0_shift',rx,ry) + + def set_qy0(self,x): + self._params['qy0'] = x + x = np.asarray(x) + qy0_mean = np.mean(x) + qy0_shift = x-qy0_mean + self._params['qy0_mean'] = qy0_mean + self._params['qy0_shift'] = qy0_shift + def set_qy0_mean(self,x): + self._params['qy0_mean'] = x + def get_qy0(self,rx=None,ry=None): + return self._get_value('qy0',rx,ry) + def get_qy0_mean(self): + return self._get_value('qy0_mean') + def get_qy0shift(self,rx=None,ry=None): + return self._get_value('qy0_shift',rx,ry) + + def set_qx0_meas(self,x): + self._params['qx0_meas'] = x + def get_qx0_meas(self,rx=None,ry=None): + return self._get_value('qx0_meas',rx,ry) + + def set_qy0_meas(self,x): + self._params['qy0_meas'] = x + def get_qy0_meas(self,rx=None,ry=None): + return self._get_value('qy0_meas',rx,ry) + + def set_origin_meas_mask(self,x): + self._params['origin_meas_mask'] = x + def get_origin_meas_mask(self,rx=None,ry=None): + return self._get_value('origin_meas_mask',rx,ry) + + # aliases + @property + def qx0(self): + return self.get_qx0() + @qx0.setter + def qx0(self,x): + self.set_qx0(x) + @property + def qx0_mean(self): + return self.get_qx0_mean() + @qx0_mean.setter + def qx0_mean(self,x): + self.set_qx0_mean(x) + @property + def qx0shift(self): + return self.get_qx0shift() + @property + def qy0(self): + return self.get_qy0() + @qy0.setter + def qy0(self,x): + self.set_qy0(x) + @property + def qy0_mean(self): + return self.get_qy0_mean() + @qy0_mean.setter + def qy0_mean(self,x): + self.set_qy0_mean(x) + @property + def qy0shift(self): + return self.get_qy0shift() + @property + def qx0_meas(self): + return self.get_qx0_meas() + @qx0_meas.setter + def qx0_meas(self,x): + self.set_qx0_meas(x) + @property + def qy0_meas(self): + return self.get_qy0_meas() + @qy0_meas.setter + def qy0_meas(self,x): + self.set_qy0_meas(x) + @property + def origin_meas_mask(self): + return self.get_origin_meas_mask() + @origin_meas_mask.setter + def origin_meas_mask(self,x): + self.set_origin_meas_mask(x) + + + # origin = (qx0,qy0) + + @call_calibrate + def set_origin(self,x): + """ + Args: + x (2-tuple of numbers or of 2D, R-shaped arrays): the origin + """ + qx0,qy0 = x +
self.set_qx0(qx0) + self.set_qy0(qy0) + def get_origin(self,rx=None,ry=None): + qx0 = self._get_value('qx0',rx,ry) + qy0 = self._get_value('qy0',rx,ry) + ans = (qx0,qy0) + if any([x is None for x in ans]): + ans = None + return ans + def get_origin_mean(self): + qx0 = self._get_value('qx0_mean') + qy0 = self._get_value('qy0_mean') + return qx0,qy0 + def get_origin_shift(self,rx=None,ry=None): + qx0 = self._get_value('qx0_shift',rx,ry) + qy0 = self._get_value('qy0_shift',rx,ry) + ans = (qx0,qy0) + if any([x is None for x in ans]): + ans = None + return ans + + def set_origin_meas(self,x): + """ + Args: + x (2-tuple or 3-tuple of 2D R-shaped arrays): qx0,qy0,[mask] + """ + qx0,qy0 = x[0],x[1] + self.set_qx0_meas(qx0) + self.set_qy0_meas(qy0) + try: + m = x[2] + self.set_origin_meas_mask(m) + except IndexError: + pass + def get_origin_meas(self,rx=None,ry=None): + qx0 = self._get_value('qx0_meas',rx,ry) + qy0 = self._get_value('qy0_meas',rx,ry) + ans = (qx0,qy0) + if any([x is None for x in ans]): + ans = None + return ans + + # aliases + @property + def origin(self): + return self.get_origin() + @origin.setter + def origin(self,x): + self.set_origin(x) + @property + def origin_meas(self): + return self.get_origin_meas() + @origin_meas.setter + def origin_meas(self,x): + self.set_origin_meas(x) + @property + def origin_shift(self): + return self.get_origin_shift() + + + # ellipse + + @call_calibrate + def set_a(self,x): + self._params['a'] = x + def get_a(self,rx=None,ry=None): + return self._get_value('a',rx,ry) + @call_calibrate + def set_b(self,x): + self._params['b'] = x + def get_b(self,rx=None,ry=None): + return self._get_value('b',rx,ry) + @call_calibrate + def set_theta(self,x): + self._params['theta'] = x + def get_theta(self,rx=None,ry=None): + return self._get_value('theta',rx,ry) + + @call_calibrate + def set_ellipse(self,x): + """ + Args: + x (3-tuple): (a,b,theta) + """ + a,b,theta = x + self._params['a'] = a + self._params['b'] = b + self._params['theta'] = theta + + @call_calibrate + def set_p_ellipse(self,x): + """ + Args: + x (5-tuple): (qx0,qy0,a,b,theta) NOTE: does *not* change qx0,qy0!
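+ + For example (illustrative values): + + >>> c.set_p_ellipse( (64.0, 63.5, 1.0, 0.98, 0.1) ) + >>> c.get_ellipse() + (1.0, 0.98, 0.1)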
+ """ + _,_,a,b,theta = x + self._params['a'] = a + self._params['b'] = b + self._params['theta'] = theta + def get_ellipse(self,rx=None,ry=None): + a = self.get_a(rx,ry) + b = self.get_b(rx,ry) + theta = self.get_theta(rx,ry) + ans = (a,b,theta) + if any([x is None for x in ans]): + ans = None + return ans + def get_p_ellipse(self,rx=None,ry=None): + qx0,qy0 = self.get_origin(rx,ry) + a,b,theta = self.get_ellipse(rx,ry) + return (qx0,qy0,a,b,theta) + + # aliases + @property + def a(self): + return self.get_a() + @a.setter + def a(self,x): + self.set_a(x) + @property + def b(self): + return self.get_b() + @b.setter + def b(self,x): + self.set_b(x) + @property + def theta(self): + return self.get_theta() + @theta.setter + def theta(self,x): + self.set_theta(x) + @property + def p_ellipse(self): + return self.get_p_ellipse() + @p_ellipse.setter + def p_ellipse(self,x): + self.set_p_ellipse(x) + @property + def ellipse(self): + return self.get_ellipse() + @ellipse.setter + def ellipse(self,x): + self.set_ellipse(x) + + + + # Q/R-space rotation and flip + + @call_calibrate + def set_QR_rotation_degrees(self,x): + self._params['QR_rotation_degrees'] = x + def get_QR_rotation_degrees(self): + return self._get_value('QR_rotation_degrees') + + @call_calibrate + def set_QR_flip(self,x): + self._params['QR_flip'] = x + def get_QR_flip(self): + return self._get_value('QR_flip') + + @call_calibrate + def set_QR_rotflip(self, rot_flip): + """ + Args: + rot_flip (tuple), (rot, flip) where: + rot (number): rotation in degrees + flip (bool): True indicates a Q/R axes flip + """ + rot,flip = rot_flip + self._params['QR_rotation_degrees'] = rot + self._params['QR_flip'] = flip + def get_QR_rotflip(self): + rot = self.get_QR_rotation_degrees() + flip = self.get_QR_flip() + if rot is None or flip is None: + return None + return (rot,flip) + + # aliases + @property + def QR_rotation_degrees(self): + return self.get_QR_rotation_degrees() + @QR_rotation_degrees.setter + def QR_rotation_degrees(self,x): + self.set_QR_rotation_degrees(x) + @property + def QR_flip(self): + return self.get_QR_flip() + @QR_flip.setter + def QR_flip(self,x): + self.set_QR_flip(x) + @property + def QR_rotflip(self): + return self.get_QR_rotflip() + @QR_rotflip.setter + def QR_rotflip(self,x): + self.set_QR_rotflip(x) + + + + + + # probe + + def set_probe_semiangle(self,x): + self._params['probe_semiangle'] = x + def get_probe_semiangle(self): + return self._get_value('probe_semiangle') + def set_probe_param(self, x): + """ + Args: + x (3-tuple): (probe size, x0, y0) + """ + probe_semiangle, qx0, qy0 = x + self.set_probe_semiangle(probe_semiangle) + self.set_qx0_mean(qx0) + self.set_qy0_mean(qy0) + def get_probe_param(self): + probe_semiangle = self._get_value('probe_semiangle') + qx0 = self._get_value('qx0') + qy0 = self._get_value('qy0') + ans = (probe_semiangle,qx0,qy0) + if any([x is None for x in ans]): + ans = None + return ans + + def set_convergence_semiangle_pixels(self,x): + self._params['convergence_semiangle_pixels'] = x + def get_convergence_semiangle_pixels(self): + return self._get_value('convergence_semiangle_pixels') + def set_convergence_semiangle_mrad(self,x): + self._params['convergence_semiangle_mrad'] = x + def get_convergence_semiangle_mrad(self): + return self._get_value('convergence_semiangle_mrad') + def set_probe_center(self,x): + self._params['probe_center'] = x + def get_probe_center(self): + return self._get_value('probe_center') + + #aliases + @property + def probe_semiangle(self): + return 
self.get_probe_semiangle() + @probe_semiangle.setter + def probe_semiangle(self,x): + self.set_probe_semiangle(x) + @property + def probe_param(self): + return self.get_probe_param() + @probe_param.setter + def probe_param(self,x): + self.set_probe_param(x) + @property + def probe_center(self): + return self.get_probe_center() + @probe_center.setter + def probe_center(self,x): + self.set_probe_center(x) + @property + def probe_convergence_semiangle_pixels(self): + return self.get_convergence_semiangle_pixels() + @probe_convergence_semiangle_pixels.setter + def probe_convergence_semiangle_pixels(self,x): + self.set_convergence_semiangle_pixels(x) + @property + def probe_convergence_semiangle_mrad(self): + return self.get_convergence_semiangle_mrad() + @probe_convergence_semiangle_mrad.setter + def probe_convergence_semiangle_mrad(self,x): + self.set_convergence_semiangle_mrad(x) + + + + + + ######## End Calibration Metadata Params ######## + + + + # calibrate targets + @call_calibrate + def calibrate(self): + pass + + + + # For parameters which can have 2D or (2+n)D array values, + # this function enables returning the value(s) at a 2D position, + # rather than the whole array + def _get_value(self,p,rx=None,ry=None): + """ Enables returning the value of a pixel (rx,ry), + if these are passed and `p` is an appropriate array + """ + v = self._params.get(p) + + if v is None: + return v + + if (rx is None) or (ry is None) or (not isinstance(v,np.ndarray)): + return v + + else: + er = f"`rx` and `ry` must be ints; got values {rx} and {ry}" + assert np.all([isinstance(i,(int,np.integer)) for i in (rx,ry)]), er + return v[rx,ry] + + + + def copy(self,name=None): + """ + Returns a copy of the Calibration instance, named `name` if passed. + """ + if name is None: name = self.name+"_copy" + cal = Calibration(name=name) + cal._params.update(self._params) + return cal + + + + # HDF5 i/o + + # write is inherited from Metadata + def to_h5(self,group): + """ + Saves the metadata dictionary _params to group, then adds the + calibration's targets list + """ + # Add targets list to metadata + targets = [x._treepath for x in self.targets] + self['_target_paths'] = targets + # Save the metadata + Metadata.to_h5(self,group) + del(self._params['_target_paths']) + + # read + @classmethod + def from_h5(cls,group): + """ + Takes a valid group for an HDF5 file object which is open in + read mode. Determines if it's a valid Metadata representation, and + if so loads and returns it as a Calibration instance. Otherwise, + raises an exception. + + Accepts: + group (HDF5 group) + + Returns: + A Calibration instance + """ + # load the group as a Metadata instance + metadata = Metadata.from_h5(group) + + # convert it to a Calibration instance + cal = Calibration(name = metadata.name) + cal._params.update(metadata._params) + + # return + return cal + + + + +########## End of class ########## + + diff --git a/py4DSTEM/data/data.py b/py4DSTEM/data/data.py new file mode 100644 index 000000000..3a7b415db --- /dev/null +++ b/py4DSTEM/data/data.py @@ -0,0 +1,160 @@ +# Base class for all py4DSTEM data +# which adds an EMD root and a pointer to 'calibration' metadata + +import warnings + +from emdfile import Node, Root +from py4DSTEM.data import Calibration + + +class Data: + """ + The purpose of the `Data` class is to ensure calibrations are linked + to data-containing class instances, while allowing multiple objects + to share a single Calibration.
The calibration of a Data instance + `data` is accessible as + + >>> data.calibration + + In py4DSTEM, Data-containing objects are stored internally in + filetree-like representations, defined by the EMD1.0 and `emdfile` specifications, + e.g. + + Root + |--metadata + | |--calibration + | + |--some_object(e.g.datacube) + | |--another_object(e.g.max_dp) + | |--etc. + | + |--one_more_object(e.g.crystal) + | |--etc. + : + + Calibrations are metadata which always live in the root of such a tree. + Running `data.calibration` returns the calibration from the tree root, + and therefore the same calibration instance is referred to by all objects + in the same tree. The root itself is accessible from any Data instance + as + + >>> data.root + + To examine the tree of a Data instance, in a Python interpreter do + + >>> data.tree(True) + + to display the whole data tree, and + + >>> data.tree() + + to display the tree from the current node on, i.e. the branch + downstream of `data`. + + Calling + + >>> data.calibration + + will raise a warning and return None if no root calibrations are found. + + Some objects should be modified when the calibrations change - these + objects must have a .calibrate() method, which is called any time relevant + calibration parameters change if the object has been registered with + the calibrations. + + To transfer another object `node` from its current tree into the tree of + `data`, use + + >>> data.attach(node) + + which will move `node` to the new tree. If `node` was registered with + its old calibrations, this will also de-register it there and register + it with the new calibrations such that .calibrate() is called when it + should be. + + See also the Calibration docstring. + """ + + def __init__( + self, + calibration = None + ): + assert(isinstance(self,Node)), "Data instances must inherit from Node" + assert(calibration is None or isinstance(calibration,Calibration)), f"calibration must be None or a Calibration instance, not type {type(calibration)}" + + + # set up calibration + EMD tree + if calibration is None: + if self.root is None: + root = Root( name=self.name+"_root" ) + root.tree( self ) + self.calibration = Calibration() + elif 'calibration' not in self.root.metadata: + self.calibration = Calibration() + else: + pass + elif calibration.root is None: + if self.root is None: + root = Root( name=self.name+"_root" ) + root.tree(self) + self.calibration = calibration + elif 'calibration' not in self.root.metadata: + self.calibration = calibration + else: + warnings.warn("A calibration was passed to instantiate a new Data instance, but the instance already has a calibration. The passed calibration *WAS NOT* attached. To attach the new calibration and overwrite the existing calibration, use `data.calibration = new_calibration`") + pass + else: + if self.root is None: + calibration.root.tree(self) + self.calibration = calibration + elif 'calibration' not in self.root.metadata: + self.calibration = calibration + warnings.warn("A calibration was passed to instantiate a new Data instance. The Data already had a root but no calibration, and the calibration already exists in a different root. The calibration has been added and now lives in both roots, and can therefore be modified from either place!") + else: + warnings.warn("A calibration was passed to instantiate a new Data instance, however the Data already has a root and calibration, and the calibration already has a root!! The passed calibration *WAS NOT* attached.
To attach the new calibration and overwrite the existing calibration, use `data.calibration = new_calibration`.") + + + # calibration property + + @property + def calibration(self): + try: + return self.root.metadata['calibration'] + except KeyError: + warnings.warn("No calibration metadata found in root, returning None") + return None + except AttributeError: + warnings.warn("No root or root metadata found, returning None") + return None + + @calibration.setter + def calibration(self, x): + assert( isinstance( x, Calibration) ) + if 'calibration' in self.root.metadata.keys(): + warnings.warn("A 'calibration' key already exists in root.metadata - overwriting...") + x.name = 'calibration' + self.root.metadata['calibration'] = x + + + # transfer trees + + def attach(self,node): + """ + Attach `node` to the current object's tree, attaching calibration and detaching + calibrations as needed. + """ + assert(isinstance(node,Node)), f"node must be a Node, not type {type(node)}" + register = False + if hasattr(node,'calibration'): + if node.calibration is not None: + if node in node.calibration._targets: + register = True + node.calibration.unregister_target(node) + if node.root is None: + self.tree(node) + else: + self.graft(node) + if register: + self.calibration.register_target(node) + + + diff --git a/py4DSTEM/classes/diffractionslice.py b/py4DSTEM/data/diffractionslice.py similarity index 86% rename from py4DSTEM/classes/diffractionslice.py rename to py4DSTEM/data/diffractionslice.py index 9f75dc43d..40104cf80 100644 --- a/py4DSTEM/classes/diffractionslice.py +++ b/py4DSTEM/data/diffractionslice.py @@ -2,7 +2,7 @@ # diffraction-shaped data from emdfile import Array -from py4DSTEM.classes import Data +from py4DSTEM.data import Data from typing import Optional,Union import numpy as np @@ -18,7 +18,8 @@ def __init__( data: np.ndarray, name: Optional[str] = 'diffractionslice', units: Optional[str] = 'intensity', - slicelabels: Optional[Union[bool,list]] = None + slicelabels: Optional[Union[bool,list]] = None, + calibration = None ): """ Accepts: @@ -39,6 +40,13 @@ def __init__( units = units, slicelabels = slicelabels ) + # initialize as Data + Data.__init__( + self, + calibration + ) + + # read diff --git a/py4DSTEM/data/propagating_calibration.py b/py4DSTEM/data/propagating_calibration.py new file mode 100644 index 000000000..a80382338 --- /dev/null +++ b/py4DSTEM/data/propagating_calibration.py @@ -0,0 +1,85 @@ +# Define decorators call_* which, when used to decorate class methods, +# cause all objects in a list `_targets` to call some method *. + +import warnings + + +# This is the abstract pattern: + +class call_method(object): + """ + A decorator which, when attached to a method of SomeClass, + causes `method` to be called on any objects in the + instance's `_targets` list, following execution of + the decorated function. + """ + def __init__(self, func): + self.func = func + + def __call__(self, *args, **kwargs): + """ + Update the parameters the caller wanted by calling the wrapped + method, then loop through the list of targets and call `method` + on each. + """ + self.func(*args,**kwargs) + some_object = args[0] + assert hasattr(some_object, "_targets"), "SomeObject object appears to be in an invalid state. _targets attribute is missing."
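+ # each target registered in `_targets` should expose a callable `.method`; + # those that do are called in turn, and the rest trigger a warning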
+ for target in some_object._targets: + if hasattr(target,'method') and callable(target.method): + try: + target.method() + except Exception as err: + print(f"Attempted to call .method(), but this raised an error: {err}") + else: + # warn or pass or error out here, as needs be + #pass + warnings.warn(f"{target} is registered as a target but does not appear to have a .method() callable") + + def __get__(self, instance, owner): + """ + This is some magic to make sure that the Calibration instance + on which the decorator was called gets passed through and + everything dispatches correctly (by making sure `instance`, + the Calibration instance to which the call was directed, gets + placed in the `self` slot of the wrapped method (which is *not* + actually bound to the instance due to this decoration.) using + partial application of the method.) + """ + from functools import partial + return partial(self.__call__, instance) + + +# This is a functional decorator, @call_calibrate: + +# calls: calibrate() +# targets: _targets + +class call_calibrate(object): + """ + Decorated methods cause all targets in _targets to call .calibrate(). + """ + def __init__(self, func): + self.func = func + + def __call__(self, *args, **kwargs): + """ + """ + self.func(*args,**kwargs) + calibration = args[0] + assert hasattr(calibration, "_targets"), "Calibration object appears to be in an invalid state. _targets attribute is missing." + for target in calibration._targets: + if hasattr(target,'calibrate') and callable(target.calibrate): + try: + target.calibrate() + except Exception as err: + print(f"Attempted to calibrate object {target} but this raised an error: {err}") + else: + pass + + def __get__(self, instance, owner): + """ + """ + from functools import partial + return partial(self.__call__, instance) + diff --git a/py4DSTEM/classes/qpoints.py b/py4DSTEM/data/qpoints.py similarity index 97% rename from py4DSTEM/classes/qpoints.py rename to py4DSTEM/data/qpoints.py index 62788f518..c29127406 100644 --- a/py4DSTEM/classes/qpoints.py +++ b/py4DSTEM/data/qpoints.py @@ -1,7 +1,7 @@ # Defines the QPoints class, which stores PointLists with fields 'qx','qy','intensity' from emdfile import PointList -from py4DSTEM.classes import Data +from py4DSTEM.data import Data from typing import Optional import numpy as np diff --git a/py4DSTEM/classes/realslice.py b/py4DSTEM/data/realslice.py similarity index 86% rename from py4DSTEM/classes/realslice.py rename to py4DSTEM/data/realslice.py index 1a75f1e8e..205cbc1ab 100644 --- a/py4DSTEM/classes/realslice.py +++ b/py4DSTEM/data/realslice.py @@ -1,7 +1,7 @@ # Defines the RealSlice class, which stores 2(+1)D real-space shaped data from emdfile import Array -from py4DSTEM.classes import Data +from py4DSTEM.data import Data from typing import Optional,Union import numpy as np @@ -17,7 +17,8 @@ def __init__( data: np.ndarray, name: Optional[str] = 'realslice', units: Optional[str] = 'intensity', - slicelabels: Optional[Union[bool,list]] = None + slicelabels: Optional[Union[bool,list]] = None, + calibration = None ): """ Accepts: @@ -37,6 +38,11 @@ def __init__( units = 'intensity', slicelabels = slicelabels ) + # initialize as Data + Data.__init__( + self, + calibration + ) # read diff --git a/py4DSTEM/datacube/__init__.py b/py4DSTEM/datacube/__init__.py new file mode 100644 index 000000000..881966e2f --- /dev/null +++ b/py4DSTEM/datacube/__init__.py @@ -0,0 +1,7 @@ +_emd_hook = True + +from py4DSTEM.datacube.datacube import DataCube +from py4DSTEM.datacube.virtualimage import 
VirtualImage +from py4DSTEM.datacube.virtualdiffraction import VirtualDiffraction + + diff --git a/py4DSTEM/classes/methods/datacube_methods.py b/py4DSTEM/datacube/datacube.py similarity index 56% rename from py4DSTEM/classes/methods/datacube_methods.py rename to py4DSTEM/datacube/datacube.py index e6ce58632..ae3a82a36 100644 --- a/py4DSTEM/classes/methods/datacube_methods.py +++ b/py4DSTEM/datacube/datacube.py @@ -1,18 +1,341 @@ -# Functions to become DataCube methods +# Defines the DataCube class, which stores 4D-STEM datacubes import numpy as np -from scipy.ndimage import distance_transform_edt, binary_fill_holes, gaussian_filter1d from scipy.interpolate import interp1d +from scipy.ndimage import (binary_opening, binary_dilation, + distance_transform_edt, binary_fill_holes, gaussian_filter1d, gaussian_filter) +from typing import Optional,Union -from emdfile import Array, Metadata, Node +from emdfile import Array, Metadata, Node, Root, tqdmnd +from py4DSTEM.data import Data, Calibration +from py4DSTEM.datacube.virtualimage import DataCubeVirtualImager +from py4DSTEM.datacube.virtualdiffraction import DataCubeVirtualDiffraction - -class DataCubeMethods: +class DataCube( + Array, + Data, + DataCubeVirtualImager, + DataCubeVirtualDiffraction, + ): """ - A container for DataCube object instance methods. + Storage and processing methods for 4D-STEM datasets. """ + def __init__( + self, + data: np.ndarray, + name: Optional[str] = 'datacube', + slicelabels: Optional[Union[bool,list]] = None, + calibration: Optional[Union[Calibration,None]] = None, + ): + """ + Accepts: + data (np.ndarray): the data + name (str): the name of the datacube + calibration (None or Calibration): default (None) creates and + attaches a new Calibration instance to the root metadata; + passing a Calibration instance attaches that instance instead. + slicelabels (None or list): names for slices if this is a + stack of datacubes + + Returns: + A new DataCube instance. + """ + # initialize as an Array + Array.__init__( + self, + data = data, + name = name, + units = 'pixel intensity', + dim_names = [ + 'Rx', + 'Ry', + 'Qx', + 'Qy' + ], + slicelabels = slicelabels + ) + + # initialize as Data + Data.__init__( + self, + calibration + ) + + # register with calibration + self.calibration.register_target(self) + + # cartesian coords + self.calibrate() + + # polar coords + self.polar = None + + + + + def calibrate(self): + """ + Calibrate the coordinate axes of the datacube. Using the calibrations + at self.calibration, sets the 4 dim vectors (Qx,Qy,Rx,Ry) according + to the pixel size, units and origin positions, then updates the + meshgrids representing Q and R space. + """ + assert(self.calibration is not None), "No calibration found!"
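+ # the dim vectors and meshgrids are recomputed from scratch below, so + # re-running this method is safe whenever the calibration state changes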
+ + # Get calibration values + rpixsize = self.calibration.get_R_pixel_size() + rpixunits = self.calibration.get_R_pixel_units() + qpixsize = self.calibration.get_Q_pixel_size() + qpixunits = self.calibration.get_Q_pixel_units() + origin = self.calibration.get_origin_mean() + if origin is None or origin==(None,None): + origin = (0,0) + + # Calc dim vectors + dim_rx = np.arange(self.R_Nx)*rpixsize + dim_ry = np.arange(self.R_Ny)*rpixsize + dim_qx = -origin[0] + np.arange(self.Q_Nx)*qpixsize + dim_qy = -origin[1] + np.arange(self.Q_Ny)*qpixsize + + # Set dim vectors + self.set_dim( + 0, + dim_rx, + units = rpixunits + ) + self.set_dim( + 1, + dim_ry, + units = rpixunits + ) + self.set_dim( + 2, + dim_qx, + units = qpixunits + ) + self.set_dim( + 3, + dim_qy, + units = qpixunits + ) + + # Set meshgrids + self._qxx,self._qyy = np.meshgrid( dim_qx,dim_qy ) + self._rxx,self._ryy = np.meshgrid( dim_rx,dim_ry ) + + self._qyy_raw,self._qxx_raw = np.meshgrid( np.arange(self.Q_Ny),np.arange(self.Q_Nx) ) + self._ryy_raw,self._rxx_raw = np.meshgrid( np.arange(self.R_Ny),np.arange(self.R_Nx) ) + + + + # coordinate meshgrids + @property + def rxx(self): + return self._rxx + @property + def ryy(self): + return self._ryy + @property + def qxx(self): + return self._qxx + @property + def qyy(self): + return self._qyy + @property + def rxx_raw(self): + return self._rxx_raw + @property + def ryy_raw(self): + return self._ryy_raw + @property + def qxx_raw(self): + return self._qxx_raw + @property + def qyy_raw(self): + return self._qyy_raw + + # coordinate meshgrids with shifted origin + def qxxs(self,rx,ry): + qx0_shift = self.calibration.get_qx0shift(rx,ry) + if qx0_shift is None: + raise Exception("Can't compute shifted meshgrid - origin shift is not defined") + return self.qxx - qx0_shift + def qyys(self,rx,ry): + qy0_shift = self.calibration.get_qy0shift(rx,ry) + if qy0_shift is None: + raise Exception("Can't compute shifted meshgrid - origin shift is not defined") + return self.qyy - qy0_shift + + + + # shape properties + + ## shape + + # FOV + @property + def R_Nx(self): + return self.data.shape[0] + @property + def R_Ny(self): + return self.data.shape[1] + @property + def Q_Nx(self): + return self.data.shape[2] + @property + def Q_Ny(self): + return self.data.shape[3] + + @property + def Rshape(self): + return (self.data.shape[0],self.data.shape[1]) + @property + def Qshape(self): + return (self.data.shape[2],self.data.shape[3]) + + @property + def R_N(self): + return self.R_Nx*self.R_Ny + + # aliases + qnx = Q_Nx + qny = Q_Ny + rnx = R_Nx + rny = R_Ny + rshape = Rshape + qshape = Qshape + rn = R_N + + + + ## pixel size / units + + # Q + @property + def Q_pixel_size(self): + return self.calibration.get_Q_pixel_size() + @property + def Q_pixel_units(self): + return self.calibration.get_Q_pixel_units() + + # R + @property + def R_pixel_size(self): + return self.calibration.get_R_pixel_size() + @property + def R_pixel_units(self): + return self.calibration.get_R_pixel_units() + + # aliases + qpixsize = Q_pixel_size + qpixunit = Q_pixel_units + rpixsize = R_pixel_size + rpixunit = R_pixel_units + + + + + + + def copy(self): + """ + Copies the datacube + """ + from py4DSTEM import DataCube + new_datacube = DataCube( + data = self.data.copy(), + name = self.name, + calibration = self.calibration.copy(), + slicelabels = self.slicelabels, + ) + + Qpixsize = new_datacube.calibration.get_Q_pixel_size() + Qpixunits = new_datacube.calibration.get_Q_pixel_units() + Rpixsize = new_datacube.calibration.get_R_pixel_size() +
Rpixunits = new_datacube.calibration.get_R_pixel_units() + + new_datacube.set_dim( + 0, + [0,Rpixsize], + units = Rpixunits, + name = 'Rx' + ) + new_datacube.set_dim( + 1, + [0,Rpixsize], + units = Rpixunits, + name = 'Ry' + ) + + new_datacube.set_dim( + 2, + [0,Qpixsize], + units = Qpixunits, + name = 'Qx' + ) + new_datacube.set_dim( + 3, + [0,Qpixsize], + units = Qpixunits, + name = 'Qy' + ) + + return new_datacube + + + + + + + + + # I/O + + # to_h5 is inherited from Array + + # read + @classmethod + def _get_constructor_args(cls,group): + """ Construct a datacube with no calibration / metadata + """ + # We only need some of the Array constructors; + # dim vector/units are passed through when Calibration + # is loaded, and the runtim dim vectors are then set + # in _add_root_links + ar_args = Array._get_constructor_args(group) + + args = { + 'data': ar_args['data'], + 'name': ar_args['name'], + 'slicelabels': ar_args['slicelabels'], + 'calibration': None + } + + return args + + + def _add_root_links(self,group): + """ When reading from file, link to calibration metadata, + then use it to populate the datacube dim vectors + """ + # Link to the datacube + self.calibration._datacube = self + + # Populate dim vectors + self.calibration.set_Q_pixel_size( self.calibration.get_Q_pixel_size() ) + self.calibration.set_R_pixel_size( self.calibration.get_R_pixel_size() ) + self.calibration.set_Q_pixel_units( self.calibration.get_Q_pixel_units() ) + self.calibration.set_R_pixel_units( self.calibration.get_R_pixel_units() ) + + return + + + + + # Class methods + def add( self, data, @@ -28,7 +351,7 @@ def add( data = data, name = name ) - self.tree( data ) + self.attach( data ) def set_scan_shape( self, @@ -108,16 +431,26 @@ def crop_R( def bin_Q( self, - N + N, + dtype = None ): """ Bins the data in diffraction space by bin factor N - Accepts: - N (int): the binning factor + Parameters + ---------- + N : int + The binning factor + dtype : a datatype (optional) + Specify the datatype for the output. If not passed, the datatype + is left unchanged + + Returns + ------ + datacube : DataCube """ from py4DSTEM.preprocess import bin_data_diffraction - d = bin_data_diffraction(self,N) + d = bin_data_diffraction(self,N,dtype) return d def pad_Q( @@ -233,647 +566,103 @@ def filter_hot_pixels( - - - - # Diffraction imaging - - def get_virtual_diffraction( - self, - method = 'max', - mode = None, - geometry = None, - calibrated = False, - shift_center = False, - verbose = True, - name = 'virtual_diffracton', - returncalc = True, - ): - """ - Function to calculate virtual diffraction patterns - - Args: - datacube (Datacube) : datacube class object which stores 4D-dataset - needed for calculation - method (str) : defines method used for diffraction pattern, options - are 'mean', 'median', and 'max' - mode (str) : defines mode for selecting area in real space to use - for virtual diffraction. The default is None, which means no - geometry will be applied and the whole datacube will be used - for the calculation. Options: - - 'point' uses singular point as detector - - 'circle' or 'circular' uses round detector, like bright - field - - 'annular' or 'annulus' uses annular detector, like dark - field - - 'rectangle', 'square', 'rectangular', uses rectangular - detector - - 'mask' flexible detector, any 2D array - geometry (variable) : valid entries are determined by the `mode`, - values in pixels argument, as follows. 
The default is None, which - means no geometry will be applied and the whole datacube will be - used for the calculation. If mode is None the geometry will not - be applied. - - 'point': 2-tuple, (rx,ry), ints - - 'circle' or 'circular': nested 2-tuple, ((rx,ry),radius), - - 'annular' or 'annulus': nested 2-tuple, - ((rx,ry),(radius_i,radius_o)) - - 'rectangle', 'square', 'rectangular': 4-tuple, - (rxmin,rxmax,rymin,rymax) - - `mask`: flexible detector, any boolean or floating point 2D - array with the same shape as datacube.Rshape - calibrated (bool): if True, geometry is specified in units of 'A' - instead of pixels. The datacube's calibrations must have its - `"R_pixel_units"` parameter set to "A". If mode is None the - geometry and calibration will not be applied. - shift_center (bool): if True, the difraction patterns are shifted to - account for beam shift or the changing of the origin through the - scan. The datacube's calibration['origin'] parameter must be set. - Only 'max' and 'mean' supported for this option. - verbose (bool): if True, show progress bar - - Returns: - (VirtualDiffraction): the diffraction image - """ - - # perform computation - from py4DSTEM.classes.virtualdiffraction import VirtualDiffraction - from py4DSTEM.process.virtualdiffraction import get_virtual_diffraction - dp = get_virtual_diffraction( - self, - method = method, - mode = mode, - geometry = geometry, - shift_center = shift_center, - calibrated = calibrated, - verbose = verbose, - ) - - # wrap with a py4dstem class - dp = VirtualDiffraction( - data = dp, - name = name - ) - - # add the args used to gen this dp as metadata - dp.metadata = Metadata( - name='gen_params', - data = { - #'gen_func' : - 'method' : method, - 'mode' : mode, - 'geometry' : geometry, - 'shift_center' : shift_center, - 'calibrated' : calibrated, - } - ) - - # add to the tree - self.tree( dp ) - - # return - if returncalc: - return dp - - - def get_dp_max( - self, - method = 'max', - mode = None, - geometry = None, - calibrated = False, - shift_center = False, - verbose = True, - name = 'dp_max', - returncalc = True, - ): - """ - Function to calculate maximum virtual diffraction. Default captures - pattern across entire 4D-dataset. - - Args: - datacube (Datacube) : datacube class object which stores 4D-dataset - needed for calculation - mode (str) : defines mode for selecting area in real space to use for - virtual diffraction. The default is None, which means no - geometry will be applied and the whole datacube will be used - for the calculation. Options: - - 'point' uses singular point as detector - - 'circle' or 'circular' uses round detector, like bright - field - - 'annular' or 'annulus' uses annular detector, like dark - field - - 'rectangle', 'square', 'rectangular', uses rectangular - detector - - 'mask' flexible detector, any 2D array - geometry (variable) : valid entries are determined by the `mode`, - values in pixels argument, as follows. The default is None, which - means no geometry will be applied and the whole datacube will be - used for the calculation. If mode is None the geometry will not - be applied. 
- - 'point': 2-tuple, (rx,ry), - rx and ry are each single float or int to define center - - 'circle' or 'circular': nested 2-tuple, ((rx,ry),radius), - - 'annular' or 'annulus': nested 2-tuple, - ((rx,ry),(radius_i,radius_o)), - - 'rectangle', 'square', 'rectangular': 4-tuple, - (xmin,xmax,ymin,ymax) - - `mask`: flexible detector, any boolean or floating point 2D - array with the same shape as datacube.Rshape - calibrated (bool): if True, geometry is specified in units of 'A' - instead of pixels. The datacube's calibrations must have its - `"R_pixel_units"` parameter set to "A". If mode is None the - geometry and calibration will not be applied. - shift_center (bool) : if True, the difraction patterns are shifted to - account for beam shift or the changing of the origin through the - scan. The datacube's calibration['origin'] parameter must be set. - Only 'max' and 'mean' supported for this option. - verbose (bool): if True, show progress bar - - Returns: - (VirtualDiffraction): the diffraction image - """ - - # perform computation - from py4DSTEM.classes.virtualdiffraction import VirtualDiffraction - from py4DSTEM.process.virtualdiffraction import get_virtual_diffraction - dp = get_virtual_diffraction( - self, - method = method, - mode = mode, - geometry = geometry, - shift_center = shift_center, - calibrated = calibrated, - verbose = verbose, - ) - - # wrap with a py4dstem class - dp = VirtualDiffraction( - data = dp, - name = name - ) - - # add the args used to gen this dp as metadata - dp.metadata = Metadata( - name='gen_params', - data = { - #'gen_func' : - 'method' : method, - 'mode' : mode, - 'geometry' : geometry, - 'shift_center' : shift_center, - 'calibrated' : calibrated, - } - ) - - # add to the tree - self.tree( dp ) - - # return - if returncalc: - return dp - - - def get_dp_mean( - self, - method = 'mean', - mode = None, - geometry = None, - calibrated = False, - shift_center = False, - verbose = True, - name = 'dp_mean', - returncalc = True, - ): - """ - Function to calculate mean virtual diffraction. Default captures pattern - across entire 4D-dataset. - - Args: - datacube (Datacube) : datacube class object which stores 4D-dataset - needed for calculation - mode (str) : defines mode for selecting area in real space to use for - virtual diffraction. The default is None, which means no - geometry will be applied and the whole datacube will be used - for the calculation. Options: - - 'point' uses singular point as detector - - 'circle' or 'circular' uses round detector, like bright - field - - 'annular' or 'annulus' uses annular detector, like dark - field - - 'rectangle', 'square', 'rectangular', uses rectangular - detector - - 'mask' flexible detector, any 2D array - geometry (variable) : valid entries are determined by the `mode`, - values in pixels argument, as follows. The default is None, - which means no geometry will be applied and the whole datacube - will be used for the calculation. If mode is None the geometry - will not be applied. 
- - 'point': 2-tuple, (rx,ry), - qx and qy are each single float or int to define center - - 'circle' or 'circular': nested 2-tuple, ((rx,ry),radius), - qx, qy and radius, are each single float or int - - 'annular' or 'annulus': nested 2-tuple, - ((rx,ry),(radius_i,radius_o)), - - 'rectangle', 'square', 'rectangular': 4-tuple, - (xmin,xmax,ymin,ymax) - - `mask`: flexible detector, any boolean or floating point 2D - array with the same shape as datacube.Rshape - calibrated (bool): if True, geometry is specified in units of 'A' - instead of pixels. The datacube's calibrations must have its - `"R_pixel_units"` parameter set to "A". If mode is None the - geometry and calibration will not be applied. - shift_center (bool): if True, the diffraction patterns are shifted to - account for beam shift or the changing of the origin through the - scan. The datacube's calibration['origin'] parameter must be set. - Only 'max' and 'mean' supported for this option. - verbose (bool) : if True, show progress bar - - Returns: - (VirtualDiffraction): the diffraction image - """ - - # perform computation - from py4DSTEM.classes.virtualdiffraction import VirtualDiffraction - from py4DSTEM.process.virtualdiffraction import get_virtual_diffraction - dp = get_virtual_diffraction( - self, - method = method, - mode = mode, - geometry = geometry, - shift_center = shift_center, - calibrated = calibrated, - verbose = verbose, - ) - - # wrap with a py4dstem class - dp = VirtualDiffraction( - data = dp, - name = name, - ) - - # add the args used to gen this dp as metadata - dp.metadata = Metadata( - name='gen_params', - data = { - #'gen_func' : - 'method' : method, - 'mode' : mode, - 'geometry' : geometry, - 'shift_center' : shift_center, - 'calibrated' : calibrated, - } - ) - - # add to the tree - self.tree( dp ) - - # return - if returncalc: - return dp - - def get_dp_median( - self, - method = 'median', - mode = None, - geometry = None, - calibrated = False, - shift_center = False, - verbose = True, - name = 'dp_median', - returncalc = True, - ): - """ - Function to calculate median virtual diffraction. Default captures - pattern across entire 4D-dataset. - - Args: - datacube (Datacube) : datacube class object which stores 4D-dataset - needed for calculation - mode (str) : defines mode for selecting area in real space to use for - virtual diffraction. The default is None, which means no - geometry will be applied and the whole datacube will be used - for the calculation. Options: - - 'point' uses singular point as detector - - 'circle' or 'circular' uses round detector, like bright - field - - 'annular' or 'annulus' uses annular detector, like dark - field - - 'rectangle', 'square', 'rectangular', uses rectangular - detector - - 'mask' flexible detector, any 2D array - geometry (variable) : valid entries are determined by the `mode`, - values in pixels argument, as follows. The default is None, - which means no geometry will be applied and the whole datacube - will be used for the calculation. If mode is None the geometry - will not be applied. - - 'point': 2-tuple, (rx,ry), - - 'circle' or 'circular': nested 2-tuple, ((rx,ry),radius), - - 'annular' or 'annulus': nested 2-tuple, - ((rx,ry),(radius_i,radius_o)), - - 'rectangle', 'square', 'rectangular': 4-tuple, - (xmin,xmax,ymin,ymax) - - `mask`: flexible detector, any boolean or floating point 2D - array with the same shape as datacube.Rshape - calibrated (bool): if True, geometry is specified in units of 'A' - instead of pixels. 
The datacube's calibrations must have its - `"R_pixel_units"` parameter set to "A". If mode is None the - geometry and calibration will not be applied. - shift_center (bool) : if True, the diffraction patterns are shifted to - account for beam shift or the changing of the origin through the - scan. The datacube's calibration['origin'] parameter must be set. - Only 'max' and 'mean' supported for this option. - verbose (bool): if True, show progress bar - - Returns: - (VirtualDiffraction): the diffraction image - """ - - # perform computation - from py4DSTEM.classes.virtualdiffraction import VirtualDiffraction - from py4DSTEM.process.virtualdiffraction import get_virtual_diffraction - dp = get_virtual_diffraction( - self, - method = method, - mode = mode, - geometry = geometry, - shift_center = shift_center, - calibrated = calibrated, - verbose = verbose, - ) - - # wrap with a py4dstem class - dp = VirtualDiffraction( - data = dp, - name = name, - ) - - # add the args used to gen this dp as metadata - dp.metadata = Metadata( - name='gen_params', - data = { - #'gen_func' : - 'method' : method, - 'mode' : mode, - 'geometry' : geometry, - 'shift_center' : shift_center, - 'calibrated' : calibrated, - } - ) - - # add to the tree - self.tree( dp ) - - # return - if returncalc: - return dp - - - - # Virtual imaging - - def get_virtual_image( - self, - mode, - geometry, - centered = False, - calibrated = False, - shift_center = False, - verbose = True, - dask = False, - return_mask = False, - name = 'virtual_image', - returncalc = True, - test_config = False - ): - """ - Get a virtual image and store it in `datacube`s tree under `name`. - The kind of virtual image is specified by the `mode` argument. - - Args: - mode (str): defines geometry mode for calculating virtual image - options: - - 'point' uses singular point as detector - - 'circle' or 'circular' uses round detector, like bright - field - - 'annular' or 'annulus' uses annular detector, like dark - field - - 'rectangle', 'square', 'rectangular', uses rectangular - detector - - 'mask' flexible detector, any 2D array - geometry (variable) : valid entries are determined by the `mode`, - values in - pixels argument, as follows: - - 'point': 2-tuple, (qx,qy), ints - - 'circle' or 'circular': nested 2-tuple, ((qx,qy),radius), - - 'annular' or 'annulus': nested 2-tuple, - ((qx,qy),(radius_i,radius_o)), - - 'rectangle', 'square', 'rectangular': 4-tuple, - (xmin,xmax,ymin,ymax) - - `mask`: any boolean or floating point 2D array with the same - size as datacube.Qshape - centered (bool): if False, the origin is in the upper left corner. - If True, the origin is set to the mean origin in the datacube - calibrations, so that a bright-field image could be specified - with, e.g., geometry = ((0,0),R). The origin can set with - datacube.calibration.set_origin(). For `mode="mask"`, - has no effect. Default is False. - calibrated (bool): if True, geometry is specified in units of 'A^-1' - instead of pixels. The datacube's calibrations must have its - `"Q_pixel_units"` parameter set to "A^-1". For `mode="mask"`, has - no effect. Default is False. - shift_center (bool): if True, the mask is shifted at each real space - position to account for any shifting of the origin of the - diffraction images. The datacube's calibration['origin'] - parameter must be set. The shift applied to each pattern is the - difference between the local origin position and the mean origin - position over all patterns, rounded to the nearest integer for - speed. Default is False. 
- verbose (bool): if True, show progress bar - dask (bool): if True, use dask arrays - return_mask (bool): if False (default) returns a virtual image as - usual. If True, does *not* generate or return a virtual image, - instead returning the mask that would be used in virtual image - computation for any call to this function where - `shift_center = False`. Otherwise, must be a 2-tuple of integers - corresponding to a scan position (rx,ry); in this case, returns - the mask that would be used for virtual image computation at this - scan position with `shift_center` set to `True`. Setting - return_mask to True does not add anything to the datacube's tree - name (str): the output object's name - returncalc (bool): if True, returns the output - Returns: - (Optional): if returncalc is True, returns the VirtualImage - """ - # perform computation - from py4DSTEM.classes.virtualimage import VirtualImage - from py4DSTEM.process.virtualimage import get_virtual_image - im = get_virtual_image( - self, - mode = mode, - geometry = geometry, - centered = centered, - calibrated = calibrated, - shift_center = shift_center, - verbose = verbose, - dask = dask, - return_mask = return_mask, - test_config = test_config - ) - - # if a mask is requested, skip the remaining i/o functionality - if return_mask is not False: - return im - - # wrap with a py4dstem class - im = VirtualImage( - data = im, - name = name, - ) - - # add generating params as metadata - im.metadata = Metadata( - name = 'gen_params', - data = { - 'mode' : mode, - 'geometry' : geometry, - 'shift_center' : shift_center, - 'centered' : centered, - 'calibrated' : calibrated, - 'verbose' : verbose, - 'dask' : dask, - 'return_mask' : return_mask, - 'test_config' : test_config - } - ) - - # add to the tree - self.tree( im ) - - # return - if returncalc: - return im - - - # Position detector - - def position_detector( - self, - mode, - geometry, - scan_position = None, - centered = None, - calibrated = None, - shift_center = None, - invert = False, - color = 'r', - alpha = 0.7, - ): - """ - Display a diffraction space image with an overlaid mask representing - a virtual detector. - - Args: - mode: see py4DSTEM.process.get_virtual_image - geometry: see py4DSTEM.process.get_virtual_image - scan_position: if None, positions the unshifted detector over the - mean or max diffraction pattern. Otherwise, must be a tuple - (rx,ry) of ints, and a detector is positioned over the - diffraction pattern at this position, including shifts if they - would be applied for this dataset (i.e. if it contains the - appropriate calibrations) - centered (bool): if False, the origin is in the upper left corner. - If True, the origin is set to the mean origin in the datacube - calibrations, so that a bright-field image could be specified - with, e.g., geometry = ((0,0),R). The origin can set with - datacube.calibration.set_origin(). For `mode="mask"`, - has no effect. Default is False. - calibrated (bool): if True, geometry is specified in units of 'A^-1' - instead of pixels. The datacube's calibrations must have its - `"Q_pixel_units"` parameter set to "A^-1". For `mode="mask"`, has - no effect. - shift_center (bool): if True, the mask is shifted at each real space - position to account for any shifting of the origin of the - diffraction images. The datacube's calibration['origin'] - parameter must be set. 
The shift applied to each pattern is the - difference between the local origin position and the mean origin - position over all patterns, rounded to the nearest integer for - speed. - invert (bool): if True, invert the display mask - """ - - # parse inputs - if scan_position is None: - data = self - else: - data = (self,scan_position[0],scan_position[1]) - - # make and show visualization - from py4DSTEM.visualize import position_detector - position_detector( - data, - mode, - geometry, - centered, - calibrated, - shift_center, - invert = invert, - color = color, - alpha = alpha, - ) - - - - # Probe def get_vacuum_probe( self, ROI = None, - name = 'probe', - returncalc = True, + align = True, + mask = None, + threshold = 0.2, + expansion = 12, + opening = 3, + verbose = False, + returncalc = True ): """ - Computes a vacuum probe from the DataCube by aligning and averaging - either all or some subset of the diffraction patterns. + Computes a vacuum probe. - Args: - ROI (None or boolean array or tuple): if None, uses the whole - datacube. Otherwise, uses a subset of diffraction patterns. - If `ROI` is a boolean array, it should be Rspace shaped, and - diffraction patterns where True are used. Else should be - a 4-tuple representing (Rxmin,Rxmax,Rymin,Rymax) of a - rectangular region to use. + Which diffraction patterns are included in the calculation is specified + by the `ROI` parameter. Diffraction patterns are aligned before averaging + if `align` is True (default). A global mask is applied to each diffraction + pattern before aligning/averaging if `mask` is specified. After averaging, + a final masking step is applied according to the parameters `threshold`, + `expansion`, and `opening`. - Returns: - (Probe) a Probe instance + Parameters + ---------- + ROI : optional, boolean array or len 4 list/tuple + If unspecified, uses the whole datacube. If a boolean array is + passed must be real-space shaped, and True pixels are used. 
If a
+            4-tuple is passed, uses the region inside the limits
+            (rx_min,rx_max,ry_min,ry_max)
+        align : optional, bool
+            if True, aligns the probes before averaging
+        mask : optional, array
+            mask applied to each diffraction pattern before alignment and
+            averaging
+        threshold : float
+            in the final masking step, values less than max(probe)*threshold
+            are considered outside the probe
+        expansion : int
+            number of pixels by which the final mask is expanded after
+            thresholding
+        opening : int
+            size of the binary opening applied to the final mask to eliminate
+            stray bright pixels
+        verbose : bool
+            toggles verbose output
+        returncalc : bool
+            if True, returns the answer
+
+        Returns
+        -------
+        probe : Probe, optional
+            the vacuum probe
         """
+        from py4DSTEM.process.utils import get_shifted_ar, get_shift
+        from py4DSTEM.braggvectors import Probe
 
-        # perform computation
-        from py4DSTEM.classes.probe import Probe
-        from py4DSTEM.process.probe import get_vacuum_probe
+        # parse region to use
         if ROI is None:
-            x = get_vacuum_probe(
-                self
-            )
+            ROI = np.ones(self.Rshape,dtype=bool)
+        elif isinstance(ROI,tuple):
+            assert(len(ROI)==4), "if ROI is a tuple must be length 4"
+            # build a boolean mask which is True inside the specified limits
+            _ROI = np.zeros(self.Rshape,dtype=bool)
+            _ROI[ROI[0]:ROI[1],ROI[2]:ROI[3]] = True
+            ROI = _ROI
         else:
-            x = get_vacuum_probe(
-                self,
-                ROI = ROI
-            )
-
-        # wrap with a py4dstem class
-        x = Probe(
-            data = x
-        )
-
-        # add to the tree
-        self.tree( x )
-
-        # return
+            assert(isinstance(ROI,np.ndarray))
+            assert(ROI.shape == self.Rshape)
+        xy = np.vstack(np.nonzero(ROI))
+        length = xy.shape[1]
+
+        # setup global mask
+        if mask is None:
+            mask = 1
+        else:
+            assert(mask.shape == self.Qshape)
+
+        # compute the average probe as a running mean
+        probe = self.data[xy[0,0],xy[1,0],:,:] * mask
+        for n in tqdmnd(range(1,length)):
+            curr_DP = self.data[xy[0,n],xy[1,n],:,:] * mask
+            if align:
+                xshift,yshift = get_shift(probe, curr_DP)
+                curr_DP = get_shifted_ar(curr_DP, xshift, yshift)
+            # after this step `probe` is the mean of the first n+1 patterns
+            probe = probe*n/(n+1) + curr_DP/(n+1)
+
+        # mask
+        mask = probe > np.max(probe)*threshold
+        mask = binary_opening(mask, iterations=opening)
+        mask = binary_dilation(mask, iterations=1)
+        mask = np.cos((np.pi/2)*np.minimum(
+            distance_transform_edt(np.logical_not(mask)) / expansion, 1))**2
+        probe *= mask
+
+        # make a probe, add to tree, and return
+        probe = Probe(probe)
+        self.attach(probe)
         if returncalc:
-            return x
+            return probe
 
 
@@ -883,7 +672,7 @@ def get_probe_size(
         thresh_lower=0.01,
         thresh_upper=0.99,
         N=100,
-        plot = True,
+        plot = False,
         returncal = True,
         write_to_cal = True,
         **kwargs,
@@ -933,13 +722,13 @@ def get_probe_size(
         from py4DSTEM.process.calibration import get_probe_size
 
         if dp is None:
-            assert 'dp_mean' in self._branch.keys(), "calculate .get_dp_mean()"
+            assert('dp_mean' in self.treekeys), "calculate .get_dp_mean() or pass a `dp` arg"
             DP = self.tree( 'dp_mean' ).data
         elif type(dp) == str:
-            assert dp in self._branch.keys(), "mode not found"
+            assert(dp in self.treekeys), f"mode {dp} not found in the tree"
             DP = self.tree( dp )
         elif type(dp) == np.ndarray:
-            assert len(dp.shape) == 2, "must be a 2D array"
+            assert(dp.shape == self.Qshape), "must be a diffraction space shaped 2D array"
            DP = dp

        x = get_probe_size(
@@ -954,8 +743,7 @@ def get_probe_size(
        try:
            self.calibration.set_probe_param(x)
        except AttributeError:
-            # should we raise an error here?
-            pass
+            raise Exception('writing to calibration was requested, but could not be completed')

        #plot results
        if plot:
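
For orientation, a minimal usage sketch of the reworked vacuum probe workflow; the file path, ROI limits, and threshold are illustrative values, not taken from this patch. The running average above satisfies mean_(n+1) = mean_n * n/(n+1) + DP/(n+1).

```python
import py4DSTEM

# load a 4D-STEM scan containing a vacuum region (hypothetical file)
datacube = py4DSTEM.import_file("vacuum_scan.dm4")

# align and average the patterns inside a rectangular scan region
probe = datacube.get_vacuum_probe(
    ROI = (0, 10, 0, 10),   # (rx_min, rx_max, ry_min, ry_max)
    align = True,           # cross-correlate and shift each pattern first
    threshold = 0.2,        # final mask keeps pixels > 0.2 * max(probe)
)

# with dp=None, get_probe_size measures the 'dp_mean' tree entry
datacube.get_dp_mean()
probe_size = datacube.get_probe_size(write_to_cal=True)
```
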
""" - from py4DSTEM.process.diskdetection import find_Bragg_disks + from py4DSTEM.braggvectors import find_Bragg_disks sigma_cc = sigma if sigma is not None else sigma_cc @@ -1229,7 +1017,7 @@ def find_Bragg_disks( # add to tree if data is None: - self.tree( peaks ) + self.attach( peaks ) # return if returncalc: @@ -1242,8 +1030,11 @@ def find_Bragg_disks( def get_beamstop_mask( self, threshold = 0.25, - distance_edge = 4.0, + distance_edge = 2.0, include_edges = True, + sigma = 0, + use_max_dp = False, + scale_radial = None, name = "mask_beamstop", returncalc = True, ): @@ -1258,6 +1049,12 @@ def get_beamstop_mask( distance_edge (float): How many pixels to expand the mask. include_edges (bool): If set to True, edge pixels will be included in the mask. + sigma (float): + Gaussain blur std to apply to image before thresholding. + use_max_dp (bool): + Use the max DP instead of the mean DP. + scale_radial (float): + Scale from center of image by this factor (can help with edge) name (string): Name of the output array. returncalc (bool): Set to true to return the result. @@ -1266,19 +1063,42 @@ def get_beamstop_mask( """ - # Calculate dp_mean if needed - if not "dp_mean" in self._branch.keys(): - self.get_dp_mean(); - - # normalized dp_mean - int_sort = np.sort(self.tree("dp_mean").data.ravel()) + if scale_radial is not None: + x = np.arange(self.data.shape[2]) * 2.0 / self.data.shape[2] + y = np.arange(self.data.shape[3]) * 2.0 / self.data.shape[3] + ya, xa = np.meshgrid(y - np.mean(y), x - np.mean(x)) + im_scale = 1.0 + np.sqrt(xa**2 + ya**2)*scale_radial + + # Get image for beamstop mask + if use_max_dp: + # if not "dp_mean" in self.tree.keys(): + # self.get_dp_max(); + # im = self.tree["dp_max"].data.astype('float') + if not "dp_max" in self._branch.keys(): + self.get_dp_max(); + im = self.tree("dp_max").data.copy().astype('float') + else: + if not "dp_mean" in self._branch.keys(): + self.get_dp_mean(); + im = self.tree("dp_mean").data.copy() + + # if not "dp_mean" in self.tree.keys(): + # self.get_dp_mean(); + # im = self.tree["dp_mean"].data.astype('float') + + # smooth and scale if needed + if sigma > 0.0: + im = gaussian_filter(im, sigma, mode='nearest') + if scale_radial is not None: + im *= im_scale + + # Calculate beamstop mask + int_sort = np.sort(im.ravel()) ind = np.round(np.clip( int_sort.shape[0]*threshold, 0,int_sort.shape[0])).astype('int') intensity_threshold = int_sort[ind] - - # Use threshold to calculate initial mask - mask_beamstop = self.tree("dp_mean").data >= intensity_threshold + mask_beamstop = im >= intensity_threshold # clean up mask mask_beamstop = np.logical_not(binary_fill_holes(np.logical_not(mask_beamstop))) @@ -1315,7 +1135,7 @@ def get_beamstop_mask( ) # Add to tree - self.tree( mask_beamstop ) + self.tree(x) # return if returncalc: @@ -1366,7 +1186,7 @@ def get_radial_bkgrnd( # define the 2D cartesian coordinate system origin = self.calibration.get_origin() origin = origin[0][rx,ry],origin[1][rx,ry] - qxx,qyy = self.qxx-origin[0], self.qyy-origin[1] + qxx,qyy = self.qxx_raw-origin[0], self.qyy_raw-origin[1] # get distance qr in polar-elliptical coords ellipse = self.calibration.get_ellipse() @@ -1651,11 +1471,6 @@ def get_braggmask( vects = braggvectors.raw[rx,ry] # loop for idx in range(len(vects.data)): - qr = np.hypot(self.qxx-vects.qx[idx], self.qyy-vects.qy[idx]) + qr = np.hypot(self.qxx_raw-vects.qx[idx], self.qyy_raw-vects.qy[idx]) mask = np.logical_and(mask, qr>radius) - return mask - - - - - + return mask \ No newline at end of file diff --git 
diff --git a/py4DSTEM/datacube/virtualdiffraction.py b/py4DSTEM/datacube/virtualdiffraction.py
new file mode 100644
index 000000000..f211a35a9
--- /dev/null
+++ b/py4DSTEM/datacube/virtualdiffraction.py
@@ -0,0 +1,362 @@
+# Virtual diffraction from a datacube. Includes:
+#  * VirtualDiffraction - a container for virtual diffraction data + metadata
+#  * DataCubeVirtualDiffraction - methods inherited by DataCube for virt diffraction
+
+import numpy as np
+import dask.array as da
+from typing import Optional
+import inspect
+
+from emdfile import tqdmnd,Metadata
+from py4DSTEM.data import Calibration, DiffractionSlice, Data
+from py4DSTEM.visualize.show import show
+
+# Virtual diffraction container class
+
+class VirtualDiffraction(DiffractionSlice,Data):
+    """
+    Stores a diffraction-space shaped 2D image with metadata
+    indicating how this image was generated from a datacube.
+    """
+    def __init__(
+        self,
+        data: np.ndarray,
+        name: Optional[str] = 'virtualdiffraction',
+        ):
+        """
+        Args:
+            data (np.ndarray) : the 2D data
+            name (str) : the name
+
+        Returns:
+            A new VirtualDiffraction instance
+        """
+        # initialize as a DiffractionSlice
+        DiffractionSlice.__init__(
+            self,
+            data = data,
+            name = name,
+        )
+
+    # read
+    @classmethod
+    def _get_constructor_args(cls,group):
+        """
+        Returns a dictionary of args/values to pass to the class constructor
+        """
+        ar_constr_args = DiffractionSlice._get_constructor_args(group)
+        args = {
+            'data' : ar_constr_args['data'],
+            'name' : ar_constr_args['name'],
+        }
+        return args
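
The container class itself is thin: it subclasses DiffractionSlice for array behavior and Data for calibration access, and `_get_constructor_args` is the hook emdfile uses to rebuild the class on read. A minimal round-trip sketch, assuming the `save`/`read` signatures re-exported by py4DSTEM and a hypothetical file path:

```python
import numpy as np
import py4DSTEM
from py4DSTEM.datacube.virtualdiffraction import VirtualDiffraction

# wrap a diffraction-space shaped 2D array in the container class
vd = VirtualDiffraction(
    data = np.zeros((128, 128)),
    name = 'my_virtual_diffraction',
)

# write to an EMD/HDF5 file and read it back; on read, emdfile calls
# VirtualDiffraction._get_constructor_args to rebuild the instance
py4DSTEM.save("virtual_diffraction.h5", vd)
vd2 = py4DSTEM.read("virtual_diffraction.h5")
```
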
+
+
+# DataCube virtual diffraction methods
+
+class DataCubeVirtualDiffraction:
+
+    def __init__(self):
+        pass
+
+    def get_virtual_diffraction(
+        self,
+        method,
+        mask = None,
+        shift_center = False,
+        subpixel = False,
+        verbose = True,
+        name = 'virtual_diffraction',
+        returncalc = True
+        ):
+        """
+        Function to calculate virtual diffraction images.
+
+        Parameters
+        ----------
+        method : str
+            defines the method used for averaging/combining diffraction
+            patterns. Options are ('mean', 'median', 'max')
+        mask : None or 2D array
+            if None (default), all pixels are used. Otherwise, must be a boolean
+            or floating point or complex array with the same shape as real space.
+            For bool arrays, only True pixels are used in the computation.
+            Otherwise a weighted average is performed.
+        shift_center : bool
+            toggles shifting the diffraction patterns to account for beam shift.
+            Currently only supported for 'max' and 'mean' modes. Default is
+            False.
+        subpixel : bool
+            if shift_center is True, toggles subpixel shifts via Fourier
+            interpolation. Ignored if shift_center is False.
+        verbose : bool
+            toggles progress bar
+        name : string
+            name for the output VirtualDiffraction instance
+        returncalc : bool
+            toggles returning the output
+
+        Returns
+        -------
+        virt_diffraction : VirtualDiffraction
+        """
+        # parse inputs
+        assert method in ('max', 'median', 'mean'), \
+            'check doc strings for supported methods'
+        assert(mask is None or mask.shape == self.Rshape), "mask must be None or real-space shaped"
+
+
+        # Calculate
+
+        # ...with no center shifting
+        if shift_center == False:
+
+            # ...for the whole pattern
+            if mask is None:
+                if method == 'mean':
+                    virtual_diffraction = np.mean(self.data, axis=(0,1))
+                elif method == 'max':
+                    virtual_diffraction = np.max(self.data, axis=(0,1))
+                else:
+                    virtual_diffraction = np.median(self.data, axis=(0,1))
+
+            # ...for boolean masks
+            elif mask.dtype == bool:
+                mask_indices = np.nonzero(mask)
+                if method == 'mean':
+                    virtual_diffraction = np.mean(
+                        self.data[mask_indices[0],mask_indices[1],:,:], axis=0)
+                elif method == 'max':
+                    virtual_diffraction = np.max(
+                        self.data[mask_indices[0],mask_indices[1],:,:], axis=0)
+                else:
+                    virtual_diffraction = np.median(
+                        self.data[mask_indices[0],mask_indices[1],:,:], axis=0)
+
+            # ...for complex and floating point masks
+            else:
+                # allocate space
+                if mask.dtype == 'complex':
+                    virtual_diffraction = np.zeros(self.Qshape, dtype='complex')
+                else:
+                    virtual_diffraction = np.zeros(self.Qshape)
+                # set computation method
+                if method == 'mean':
+                    fn = np.sum
+                elif method == 'max':
+                    fn = np.max
+                else:
+                    fn = np.median
+                # loop
+                for qx,qy in tqdmnd(
+                    self.Q_Nx,
+                    self.Q_Ny,
+                    disable = not verbose,
+                    ):
+                    virtual_diffraction[qx,qy] = fn( np.squeeze(self.data[:,:,qx,qy])*mask )
+                # normalize weighted means
+                if method == 'mean':
+                    virtual_diffraction /= np.sum(mask)
+
+
+        # ...with center shifting
+        else:
+            assert method in ('max', 'mean'),\
+                "only 'mean' and 'max' are supported for center-shifted virtual diffraction"
+
+            # Get calibration metadata
+            assert(self.calibration.get_origin() is not None), "origin is not calibrated"
+            x0, y0 = self.calibration.get_origin()
+            x0_mean, y0_mean = self.calibration.get_origin_mean()
+
+            # get shifts
+            qx_shift = x0_mean-x0
+            qy_shift = y0_mean-y0
+
+
+            # ...for integer shifts
+            if not subpixel:
+
+                # round shifts -> int
+                qx_shift = qx_shift.round().astype(int)
+                qy_shift = qy_shift.round().astype(int)
+
+                # ...for boolean masks and unmasked
+                if mask is None or mask.dtype==bool:
+                    # get scan points
+                    mask = np.ones(self.Rshape,dtype=bool) if mask is None else mask
+                    mask_indices = np.nonzero(mask)
+                    # allocate space
+                    virtual_diffraction = np.zeros(self.Qshape)
+                    # loop
+                    for rx,ry in zip(mask_indices[0],mask_indices[1]):
+                        # get shifted DP
+                        DP = np.roll(
+                            self.data[rx,ry, :,:,],
+                            (qx_shift[rx,ry], qy_shift[rx,ry]),
+                            axis=(0,1),
+                            )
+                        # compute
+                        if method == 'mean':
+                            virtual_diffraction += DP
+                        elif method == 'max':
+                            virtual_diffraction = np.maximum(virtual_diffraction, DP)
+                    # normalize means
+                    if method == 'mean':
+                        virtual_diffraction /= len(mask_indices[0])
+
+                # ...for floating point and complex masks
+                else:
+                    # allocate space
+                    if mask.dtype == 'complex':
+                        virtual_diffraction = np.zeros(self.Qshape, dtype = 'complex')
+                    else:
+                        virtual_diffraction = np.zeros(self.Qshape)
+                    # loop
+                    for rx,ry in tqdmnd(
+                        self.R_Nx,
+                        self.R_Ny,
+                        disable = not verbose,
+                        ):
+                        # get shifted DP
+                        DP = np.roll(
+                            self.data[rx,ry, :,:,],
+                            (qx_shift[rx,ry], qy_shift[rx,ry]),
+                            axis=(0,1),
+                            )
+                        # compute
+                        w = mask[rx,ry]
+                        if method == 'mean':
+                            virtual_diffraction += DP*w
+                        elif method == 'max':
+                            virtual_diffraction = np.maximum(virtual_diffraction, DP*w)
+                    if method == 'mean':
+                        virtual_diffraction /= np.sum(mask)
+
+            # TODO subpixel shifting
+            else:
+                raise Exception("subpixel shifting has not been implemented yet!")
+
+
+        # wrap, add to tree, and return
+
+        # wrap in a VirtualDiffraction instance
+        ans = VirtualDiffraction(
+            data = virtual_diffraction,
+            name = name
+        )
+
+        # add the args used to gen this dp as metadata
+        ans.metadata = Metadata(
+            name='gen_params',
+            data = {
+                '_calling_method' : inspect.stack()[0][3],
+                '_calling_class' : __class__.__name__,
+                'method' : method,
+                'mask' : mask,
+                'shift_center' : shift_center,
+                'subpixel' : subpixel,
+                'verbose' : verbose,
+                'name' : name,
+                'returncalc' : returncalc
+            }
+        )
+
+        # add to the tree
+        self.attach( ans )
+
+        # return
+        if returncalc:
+            return ans
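
Boolean masks select scan positions while float or complex masks weight them, so in 'mean' mode the result is sum_i(w_i * DP_i) / sum_i(w_i). Usage sketches, assuming an existing `datacube` (and a calibrated origin for the shifted case):

```python
import numpy as np

# mean pattern over the full scan
dp_mean = datacube.get_virtual_diffraction(method='mean')

# max pattern over a boolean region of real space
region = np.zeros(datacube.Rshape, dtype=bool)
region[10:20, 10:20] = True
dp_max_roi = datacube.get_virtual_diffraction(method='max', mask=region)

# origin-corrected mean: each pattern is rolled by the difference
# between its local origin and the mean origin before combining
dp_shifted = datacube.get_virtual_diffraction(method='mean', shift_center=True)
```
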
+
+
+    # additional interfaces
+
+    def get_dp_max(
+        self,
+        returncalc = True,
+        ):
+        """
+        Calculates the max diffraction pattern.
+
+        Calls `DataCube.get_virtual_diffraction` - see that method's docstring
+        for more customizable virtual diffraction.
+
+        Parameters
+        ----------
+        returncalc : bool
+            toggles returning the answer
+
+        Returns
+        -------
+        max_dp : VirtualDiffraction
+        """
+        return self.get_virtual_diffraction(
+            method = 'max',
+            mask = None,
+            shift_center = False,
+            subpixel = False,
+            verbose = True,
+            name = 'dp_max',
+            returncalc = returncalc
+        )
+
+    def get_dp_mean(
+        self,
+        returncalc = True,
+        ):
+        """
+        Calculates the mean diffraction pattern.
+
+        Calls `DataCube.get_virtual_diffraction` - see that method's docstring
+        for more customizable virtual diffraction.
+
+        Parameters
+        ----------
+        returncalc : bool
+            toggles returning the answer
+
+        Returns
+        -------
+        mean_dp : VirtualDiffraction
+        """
+        return self.get_virtual_diffraction(
+            method = 'mean',
+            mask = None,
+            shift_center = False,
+            subpixel = False,
+            verbose = True,
+            name = 'dp_mean',
+            returncalc = returncalc
+        )
+
+    def get_dp_median(
+        self,
+        returncalc = True,
+        ):
+        """
+        Calculates the median diffraction pattern.
+
+        Calls `DataCube.get_virtual_diffraction` - see that method's docstring
+        for more customizable virtual diffraction.
+
+        Parameters
+        ----------
+        returncalc : bool
+            toggles returning the answer
+
+        Returns
+        -------
+        median_dp : VirtualDiffraction
+        """
+        return self.get_virtual_diffraction(
+            method = 'median',
+            mask = None,
+            shift_center = False,
+            subpixel = False,
+            verbose = True,
+            name = 'dp_median',
+            returncalc = returncalc
+        )
+
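
The virtual imaging counterpart follows in the new virtualimage.py module below; under its `centered`/`calibrated` conventions a bright-field and an annular dark-field image reduce to calls like these (radii are illustrative pixel values, and a calibrated origin is assumed):

```python
# bright field: round detector centered on the calibrated mean origin
bf = datacube.get_virtual_image(
    mode = 'circle',
    geometry = ((0, 0), 20),   # (center, radius), relative to the origin
    centered = True,
)

# annular dark field, re-centering the detector at each scan position
# using the calibrated origin shifts
adf = datacube.get_virtual_image(
    mode = 'annulus',
    geometry = ((0, 0), (40, 80)),
    centered = True,
    shift_center = True,
)
```
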
+ """ + def __init__( + self, + data: np.ndarray, + name: Optional[str] = 'virtualimage', + ): + """ + Parameters + ---------- + data : np.ndarray + the 2D data + name : str + the name + """ + # initialize as a RealSlice + RealSlice.__init__( + self, + data = data, + name = name, + ) + + # read + @classmethod + def _get_constructor_args(cls,group): + """ + Returns a dictionary of args/values to pass to the class constructor + """ + ar_constr_args = RealSlice._get_constructor_args(group) + args = { + 'data' : ar_constr_args['data'], + 'name' : ar_constr_args['name'], + } + return args + + + + + +# DataCube virtual imaging methods + +class DataCubeVirtualImager: + + def __init__(self): + pass + + + def get_virtual_image( + self, + mode, + geometry, + centered = False, + calibrated = False, + shift_center = False, + verbose = True, + dask = False, + return_mask = False, + name = 'virtual_image', + returncalc = True, + test_config = False + ): + """ + Calculate a virtual image. + + The detector is determined by the combination of the `mode` and + `geometry` arguments, supporting point, circular, rectangular, + annular, and custom mask detectors. The values passed to geometry + may be given with respect to an origin at the corner of the detector + array or with respect to the calibrated center position, and in units of + pixels or real calibrated units, depending on the values of the + `centered` and `calibrated` arguments, respectively. The mask may be + shifted pattern-by-pattern to account for diffraction scan shifts using + the `shift_center` argument. + + The computed virtual image is stored in the datacube's tree, and is + also returned by default. + + Parameters + ---------- + mode : str + defines geometry mode for calculating virtual image, and the + expected input for the `geometry` argument. options: + - 'point': uses a single pixel detector + - 'circle', 'circular': uses a round detector, like bright + field + - 'annular', 'annulus': uses an annular detector, like dark + field + - 'rectangle', 'square', 'rectangular': uses rectangular + detector + - 'mask': any diffraction-space shaped 2D array, representing + a flexible detector + geometry : variable + the expected value of this argument is determined by `mode` as + follows: + - 'point': 2-tuple, (qx,qy), ints + - 'circle', 'circular': nested 2-tuple, ((qx,qy),radius), + - 'annular', 'annulus': nested 2-tuple, + ((qx,qy),(radius_i,radius_o)), + - 'rectangle', 'square', 'rectangular': 4-tuple, + (xmin,xmax,ymin,ymax) + - `mask`: any boolean or floating point 2D array with the same + size as datacube.Qshape + centered : bool + if False, the origin is in the upper left corner. If True, the origin + is set to the mean origin in the datacube calibrations, so that a + bright-field image could be specified with, e.g., geometry=((0,0),R). + The origin can set with datacube.calibration.set_origin(). For + `mode="mask"`, has no effect. Default is False. + calibrated : bool + if True, geometry is specified in units of 'A^-1' instead of pixels. + The datacube's calibrations must have its `"Q_pixel_units"` parameter + set to "A^-1". For `mode="mask"`, has no effect. Default is False. + shift_center : bool + if True, the mask is shifted at each real space position to account + for any shifting of the origin of the diffraction images. The + datacube's calibration['origin'] parameter must be set. 
The shift
+            applied to each pattern is the difference between the local origin
+            position and the mean origin position over all patterns, rounded
+            to the nearest integer for speed. Default is False. If
+            `shift_center` is True, `centered` is automatically set to True.
+        verbose : bool
+            toggles a progress bar
+        dask : bool
+            if True, use dask to distribute the calculation
+        return_mask : bool
+            if False (default) returns a virtual image as usual. Otherwise
+            does *not* compute or return a virtual image, instead finding and
+            returning the mask that will be used in subsequent calls to this
+            function using these same parameters. In this case, must be
+            either `True` or a 2-tuple of integers corresponding to `(rx,ry)`.
+            If True is passed, returns the mask used if `shift_center` is set
+            to False. If a 2-tuple is passed, returns the mask used at scan
+            position (rx,ry) if `shift_center` is set to True. Nothing is
+            added to the datacube's tree.
+        name : str
+            the output object's name
+        returncalc : bool
+            if True, returns the output
+        test_config : bool
+            if True, prints the boolean values of
+            (`centered`,`calibrated`,`shift_center`). Does not compute the
+            virtual image.
+
+        Returns
+        -------
+        virt_im : VirtualImage (optional, if returncalc is True)
+        """
+        # parse inputs
+        assert mode in ('point', 'circle', 'circular', 'annulus', 'annular', 'rectangle', 'square', 'rectangular', 'mask'),\
+            'check doc strings for supported modes'
+        if shift_center == True:
+            centered = True
+        if test_config:
+            for x,y in zip(['centered','calibrated','shift_center'],
+                           [centered,calibrated,shift_center]):
+                print(f"{x} = {y}")
+
+        # Get geometry
+        g = self.get_calibrated_detector_geometry(
+            self.calibration,
+            mode,
+            geometry,
+            centered,
+            calibrated
+        )
+
+        # Get mask
+        mask = self.make_detector(self.Qshape, mode, g)
+        # if return_mask is True, skip computation
+        if return_mask == True and shift_center == False:
+            return mask
+
+
+        # Calculate virtual image
+
+        # no center shifting
+        if shift_center == False:
+
+            # single CPU
+            if not dask:
+
+                # allocate space
+                if mask.dtype == 'complex':
+                    virtual_image = np.zeros(self.Rshape, dtype = 'complex')
+                else:
+                    virtual_image = np.zeros(self.Rshape)
+                # compute
+                for rx,ry in tqdmnd(
+                    self.R_Nx,
+                    self.R_Ny,
+                    disable = not verbose,
+                    ):
+                    virtual_image[rx,ry] = np.sum(self.data[rx,ry]*mask)
+
+            # dask
+            if dask == True:
+
+                # set up a generalized universal function for dask distribution
+                def _apply_mask_dask(self,mask):
+                    virtual_image = np.sum(np.multiply(self.data,mask), dtype=np.float64)
+                apply_mask_dask = da.as_gufunc(
+                    _apply_mask_dask,signature='(i,j),(i,j)->()',
+                    output_dtypes=np.float64,
+                    axes=[(2,3),(0,1),()],
+                    vectorize=True
+                )
+
+                # compute
+                virtual_image = apply_mask_dask(self.data, mask)
+
+        # with center shifting
+        else:
+
+            # get shifts
+            assert(self.calibration.get_origin_shift() is not None), "origin needs to be calibrated"
+            qx_shift,qy_shift = self.calibration.get_origin_shift()
+            qx_shift = qx_shift.round().astype(int)
+            qy_shift = qy_shift.round().astype(int)
+
+            # if return_mask is True, get+return the mask and skip the computation
+            if return_mask is not False:
+                try:
+                    rx,ry = return_mask
+                except TypeError:
+                    raise Exception(f"if `shift_center=True`, return_mask must be a 2-tuple of ints or False, but received an input value of {return_mask}")
+                _mask = np.roll(
+                    mask,
+                    (qx_shift[rx,ry], qy_shift[rx,ry]),
+                    axis=(0,1)
+                )
+                return _mask
+
+            # allocate space
+            if mask.dtype == 'complex':
+                virtual_image = 
np.zeros(self.Rshape, dtype = 'complex') + else: + virtual_image = np.zeros(self.Rshape) + + # loop + for rx,ry in tqdmnd( + self.R_Nx, + self.R_Ny, + disable = not verbose, + ): + # get shifted mask + _mask = np.roll( + mask, + (qx_shift[rx,ry], qy_shift[rx,ry]), + axis=(0,1) + ) + # add to output array + virtual_image[rx,ry] = np.sum(self.data[rx,ry]*_mask) + + + # data handling + + # wrap with a py4dstem class + ans = VirtualImage( + data = virtual_image, + name = name, + ) + + # add generating params as metadata + ans.metadata = Metadata( + name = 'gen_params', + data = { + '_calling_method' : inspect.stack()[0][3], + '_calling_class' : __class__.__name__, + 'mode' : mode, + 'geometry' : geometry, + 'centered' : centered, + 'calibrated' : calibrated, + 'shift_center' : shift_center, + 'verbose' : verbose, + 'dask' : dask, + 'return_mask' : return_mask, + 'name' : name, + 'returncalc' : True, + 'test_config' : test_config + } + ) + + # add to the tree + self.attach( ans ) + + # return + if returncalc: + return ans + + + + + # Position detector + + def position_detector( + self, + mode, + geometry, + data = None, + centered = None, + calibrated = None, + shift_center = False, + scan_position = None, + invert = False, + color = 'r', + alpha = 0.7, + **kwargs + ): + """ + Position a virtual detector by displaying a mask over a diffraction + space image. Calling `.get_virtual_image()` using the same `mode` + and `geometry` parameters will compute a virtual image using this + detector. + + Parameters + ---------- + mode : str + see the DataCube.get_virtual_image docstring + geometry : variable + see the DataCube.get_virtual_image docstring + data : None or 2d-array or 2-tuple of ints + The diffraction image to overlay the mask on. If `None` (default), + looks for a max or mean or median diffraction image in this order + and if found, uses it, otherwise, uses the diffraction pattern at + scan position (0,0). If a 2d array is passed, must be diffraction + space shaped array. If a 2-tuple is passed, uses the diffraction + pattern at scan position (rx,ry). + centered : bool + see the DataCube.get_virtual_image docstring + calibrated : bool + see the DataCube.get_virtual_image docstring + shift_center : None or bool or 2-tuple of ints + If `None` (default) and `data` is either None or an array, the mask + is not shifted. If `None` and `data` is a 2-tuple, shifts the mask + according to the origin at the scan position (rx,ry) specified in + `data`. If False, does not shift the mask. If True and `data` is + a 2-tuple, shifts the mask accordingly, and if True and `data` is + any other value, raises an error. If `shift_center` is a 2-tuple, + shifts the mask according to the origin value at this 2-tuple + regardless of the value of `data` (enabling e.g. overlaying the + mask for a specific scan position on a max or mean diffraction + image.) + invert : bool + if True, invert the masked pixel (i.e. 
pixels *outside* the detector
+            are overlaid with a mask)
+        color : any matplotlib color specification
+            the mask color
+        alpha : number
+            the mask transparency
+        kwargs : dict
+            Any additional arguments are passed on to the show() function
+        """
+        # parse inputs
+
+        # mode
+        assert mode in ('point', 'circle', 'circular', 'annulus', 'annular', 'rectangle', 'square', 'rectangular', 'mask'),\
+            'check doc strings for supported modes'
+
+        # data
+        if data is None:
+            image = None
+            keys = ['dp_mean','dp_max','dp_median']
+            for k in keys:
+                try:
+                    image = self.tree(k)
+                    break
+                except:
+                    pass
+            if image is None:
+                image = self[0,0]
+        elif isinstance(data, np.ndarray):
+            assert(data.shape == self.Qshape), f"Can't position a detector over an image with a shape that is different from diffraction space. Diffraction space in this dataset has shape {self.Qshape} but the image passed has shape {data.shape}"
+            image = data
+        elif isinstance(data, DiffractionSlice):
+            assert(data.shape == self.Qshape), f"Can't position a detector over an image with a shape that is different from diffraction space. Diffraction space in this dataset has shape {self.Qshape} but the image passed has shape {data.shape}"
+            image = data.data
+        elif isinstance(data,tuple):
+            rx,ry = data[:2]
+            image = self[rx,ry]
+        else:
+            raise Exception(f"Invalid argument passed to `data`. Expected None or np.ndarray or tuple, not type {type(data)}")
+
+        # shift center
+        if shift_center is None:
+            shift_center = False
+        elif shift_center == True:
+            assert(isinstance(data,tuple)), "If shift_center is set to True, `data` should be a 2-tuple (rx,ry). To shift the detector mask while using some other input for `data`, set `shift_center` to a 2-tuple (rx,ry)"
+        elif isinstance(shift_center,tuple):
+            rx,ry = shift_center[:2]
+            shift_center = True
+        else:
+            shift_center = False
+
+
+        # Get the mask
+
+        # Get geometry
+        g = self.get_calibrated_detector_geometry(
+            calibration = self.calibration,
+            mode = mode,
+            geometry = geometry,
+            centered = centered,
+            calibrated = calibrated
+        )
+
+        # Get mask
+        mask = self.make_detector(image.shape, mode, g)
+        if not(invert):
+            mask = np.logical_not(mask)
+
+        # Shift center
+        if shift_center:
+            try:
+                rx,ry
+            except NameError:
+                raise Exception("if `shift_center` is True then `data` or `shift_center` must specify a scan position as a 2-tuple (rx,ry)")
+            # get shifts
+            assert(self.calibration.get_origin_shift() is not None), "origin shifts need to be calibrated"
+            qx_shift,qy_shift = self.calibration.get_origin_shift()
+            qx_shift = int(np.round(qx_shift[rx,ry]))
+            qy_shift = int(np.round(qy_shift[rx,ry]))
+            mask = np.roll(
+                mask,
+                (qx_shift, qy_shift),
+                axis=(0,1)
+            )
+
+        # Show
+        show(
+            image,
+            mask = mask,
+            mask_color = color,
+            mask_alpha = alpha,
+            **kwargs
+        )
+        return
+
+
+
+    @staticmethod
+    def get_calibrated_detector_geometry(
+        calibration,
+        mode,
+        geometry,
+        centered,
+        calibrated
+        ):
+        """
+        Determine the detector geometry in pixels, given some mode and geometry
+        in calibrated units, where the calibration state is specified by
+        {centered, calibrated}
+
+        Parameters
+        ----------
+        calibration : Calibration
+            Used to retrieve the center positions. 
If `None`, confirms that + centered and calibrated are False then passes, otherwise raises + an exception + mode : str + see the DataCube.get_virtual_image docstring + geometry : variable + see the DataCube.get_virtual_image docstring + centered : bool + see the DataCube.get_virtual_image docstring + calibrated : bool + see the DataCube.get_virtual_image docstring + + Returns + ------- + geo : tuple + the geometry in detector pixels + """ + # Parse inputs + g = geometry + if calibration is None: + assert calibrated is False and centered is False, "No calibration found - set a calibration or set `centered` and `calibrated` to False" + return g + else: + assert(isinstance(calibration, Calibration)) + cal = calibration + + # Get calibration metadata + if centered: + assert cal.get_qx0_mean() is not None, "origin needs to be calibrated" + x0_mean, y0_mean = cal.get_origin_mean() + + if calibrated: + assert cal['Q_pixel_units'] == 'A^-1', \ + 'check calibration - must be calibrated in A^-1 to use `calibrated=True`' + unit_conversion = cal.get_Q_pixel_size() + + + # Convert units into detector pixels + + # Shift center + if centered == True: + if mode == 'point': + g = (g[0] + x0_mean, g[1] + y0_mean) + if mode in('circle', 'circular', 'annulus', 'annular'): + g = ((g[0][0] + x0_mean, g[0][1] + y0_mean), g[1]) + if mode in('rectangle', 'square', 'rectangular') : + g = (g[0] + x0_mean, g[1] + x0_mean, g[2] + y0_mean, g[3] + y0_mean) + + # Scale by the detector pixel size + if calibrated == True: + if mode == 'point': + g = (g[0]/unit_conversion, g[1]/unit_conversion) + if mode in('circle', 'circular'): + g = ((g[0][0]/unit_conversion, g[0][1]/unit_conversion), + (g[1]/unit_conversion)) + if mode in('annulus', 'annular'): + g = ((g[0][0]/unit_conversion, g[0][1]/unit_conversion), + (g[1][0]/unit_conversion, g[1][1]/unit_conversion)) + if mode in('rectangle', 'square', 'rectangular') : + g = (g[0]/unit_conversion, g[1]/unit_conversion, + g[2]/unit_conversion, g[3]/unit_conversion) + + return g + + + @staticmethod + def make_detector( + shape, + mode, + geometry, + ): + ''' + Generate a 2D mask representing a detector function. + + Parameters + ---------- + shape : 2-tuple + defines shape of mask. Should be the shape of diffraction space. + mode : str + defines geometry mode for calculating virtual image. See the + docstring for DataCube.get_virtual_image + geometry : variable + defines geometry for calculating virtual image. 
See the + docstring for DataCube.get_virtual_image + + Returns + ------- + detector_mask : 2d array + ''' + g = geometry + + #point mask + if mode == 'point': + assert(isinstance(g,tuple) and len(g)==2), 'specify qx and qy as tuple (qx, qy)' + mask = np.zeros(shape, dtype=bool) + + qx = int(g[0]) + qy = int(g[1]) + + mask[qx,qy] = 1 + + #circular mask + if mode in('circle', 'circular'): + assert(isinstance(g,tuple) and len(g)==2 and len(g[0])==2 and isinstance(g[1],(float,int))), \ + 'specify qx, qy, radius_i as ((qx, qy), radius)' + + qxa, qya = np.indices(shape) + mask = (qxa - g[0][0]) ** 2 + (qya - g[0][1]) ** 2 < g[1] ** 2 + + #annular mask + if mode in('annulus', 'annular'): + assert(isinstance(g,tuple) and len(g)==2 and len(g[0])==2 and len(g[1])==2), \ + 'specify qx, qy, radius_i, radius_0 as ((qx, qy), (radius_i, radius_o))' + + assert g[1][1] > g[1][0], "Inner radius must be smaller than outer radius" + + qxa, qya = np.indices(shape) + mask1 = (qxa - g[0][0]) ** 2 + (qya - g[0][1]) ** 2 > g[1][0] ** 2 + mask2 = (qxa - g[0][0]) ** 2 + (qya - g[0][1]) ** 2 < g[1][1] ** 2 + mask = np.logical_and(mask1, mask2) + + #rectangle mask + if mode in('rectangle', 'square', 'rectangular') : + assert(isinstance(g,tuple) and len(g)==4), \ + 'specify x_min, x_max, y_min, y_max as (x_min, x_max, y_min, y_max)' + mask = np.zeros(shape, dtype=bool) + + xmin = int(np.round(g[0])) + xmax = int(np.round(g[1])) + ymin = int(np.round(g[2])) + ymax = int(np.round(g[3])) + + mask[xmin:xmax, ymin:ymax] = 1 + + #flexible mask + if mode == 'mask': + assert type(g) == np.ndarray, '`geometry` type should be `np.ndarray`' + assert (g.shape == shape), 'mask and diffraction pattern shapes do not match' + mask = g + return mask + + + + + # TODO where should this go? + def make_bragg_mask( + self, + Qshape, + g1, + g2, + radius, + origin, + max_q, + return_sum = True, + **kwargs, + ): + ''' + Creates and returns a mask consisting of circular disks + about the points of a 2D lattice. 
+
+        Args:
+            Qshape (2 tuple): the shape of diffraction space
+            g1,g2 (len 2 array or tuple): the lattice vectors
+            radius (number): the disk radius
+            origin (len 2 array or tuple): the origin
+            max_q (number): the maximum distance to tile out to
+            return_sum (bool): if True, return a single 2D mask of all the
+                disks; if False, return a 3D array, where each slice contains
+                a single disk
+
+        Returns:
+            (2 or 3D array) the mask
+        '''
+        nas = np.asarray
+        g1,g2,origin = nas(g1),nas(g2),nas(origin)
+
+        # Get H,K, the maximum indices to tile out to
+        L1 = np.sqrt(np.sum(g1**2))
+        H = int(max_q/L1) + 1
+        L2 = np.hypot(-g2[0]*g1[1],g2[1]*g1[0])/np.sqrt(np.sum(g1**2))
+        K = int(max_q/L2) + 1
+
+        # Compute number of points
+        N = 0
+        for h in range(-H,H+1):
+            for k in range(-K,K+1):
+                v = h*g1 + k*g2
+                if np.sqrt(v.dot(v)) < max_q:
+                    N += 1
+
+        # create mask
+        mask = np.zeros((Qshape[0], Qshape[1], N), dtype=bool)
+        N = 0
+        for h in range(-H,H+1):
+            for k in range(-K,K+1):
+                v = h*g1 + k*g2
+                if np.sqrt(v.dot(v)) < max_q:
+                    center = origin + v
+                    mask[:,:,N] = self.make_detector(
+                        Qshape,
+                        mode = 'circle',
+                        geometry = (center, radius),
+                    )
+                    N += 1
+
+
+        if return_sum:
+            mask = np.sum(mask, axis = 2)
+        return mask
+
+
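
make_bragg_mask tiles the lattice h*g1 + k*g2 out to max_q and stamps a circular detector at each lattice point; a sketch with an illustrative square lattice, assuming an existing `datacube`:

```python
# single 2D mask of disks on a square reciprocal lattice
bragg_mask = datacube.make_bragg_mask(
    Qshape = datacube.Qshape,
    g1 = (20, 0),      # lattice vectors, in pixels
    g2 = (0, 20),
    radius = 4,        # disk radius, in pixels
    origin = (64, 64),
    max_q = 60,        # tile lattice points out to this distance
    return_sum = True, # False returns one disk per slice of a 3D array
)
```
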
+
+    Args:
+        filename: str with path to file
+        mem (str): Must be "RAM" or "MEMMAP". Specifies how the data is
+            loaded; "RAM" transfers the data from storage to RAM, while
+            "MEMMAP" leaves the data in storage and creates a memory map
+            which points to the diffraction patterns, allowing them to be
+            retrieved individually from storage.
+        binfactor (int): Diffraction space binning factor for bin-on-load.
+
+    Returns:
+        DataCube
+    """
+    assert mem == "RAM", "read_abTEM does not support memory mapping"
+    assert binfactor == 1, "abTEM files can only be read at full resolution"
+
+    with h5py.File(filename, "r") as f:
+        datasets = {}
+        for key in f.keys():
+            datasets[key] = f.get(key)[()]
+
+    data = datasets["array"]
+
+    sampling = datasets["sampling"]
+    units = datasets["units"]
+
+    assert len(data.shape) in (2, 4), "abTEM reader supports only 2D and 4D data"
+
+    if len(data.shape) == 4:
+
+        datacube = DataCube(data=data)
+
+        datacube.calibration.set_R_pixel_size(sampling[0])
+        if sampling[0] != sampling[1]:
+            print(
+                "Warning: py4DSTEM currently only handles uniform x,y sampling. Setting sampling with x calibration"
+            )
+        datacube.calibration.set_Q_pixel_size(sampling[2])
+        if sampling[2] != sampling[3]:
+            print(
+                "Warning: py4DSTEM currently only handles uniform qx,qy sampling. Setting sampling with qx calibration"
+            )
+
+        if units[0] == b"\xc3\x85":
+            datacube.calibration.set_R_pixel_units("A")
+        else:
+            datacube.calibration.set_R_pixel_units(units[0].decode("utf-8"))
+
+        datacube.calibration.set_Q_pixel_units(units[2].decode("utf-8"))
+
+        return datacube
+
+    else:
+        if units[0] == b"mrad":
+            diffraction = DiffractionSlice(data=data)
+            if sampling[0] != sampling[1]:
+                print(
+                    "Warning: py4DSTEM currently only handles uniform qx,qy sampling. Setting sampling with x calibration"
+                )
+            diffraction.calibration.set_Q_pixel_units(units[0].decode("utf-8"))
+            diffraction.calibration.set_Q_pixel_size(sampling[0])
+            return diffraction
+        else:
+            image = RealSlice(data=data)
+            if sampling[0] != sampling[1]:
+                print(
+                    "Warning: py4DSTEM currently only handles uniform x,y sampling. Setting sampling with x calibration"
+                )
+            image.calibration.set_Q_pixel_units("A")
+            image.calibration.set_Q_pixel_size(sampling[0])
+            return image
diff --git a/py4DSTEM/io/filereaders/read_arina.py b/py4DSTEM/io/filereaders/read_arina.py
new file mode 100644
index 000000000..323b5643f
--- /dev/null
+++ b/py4DSTEM/io/filereaders/read_arina.py
@@ -0,0 +1,115 @@
+import h5py
+import hdf5plugin
+import numpy as np
+from py4DSTEM.datacube import DataCube
+from py4DSTEM.preprocess.utils import bin2D
+
+
+def read_arina(
+    filename,
+    scan_width=1,
+    mem="RAM",
+    binfactor: int = 1,
+    dtype_bin: float = None,
+    flatfield: np.ndarray = None,
+):
+
+    """
+    File reader for Arina 4D-STEM datasets
+    Args:
+        filename: str with path to master file
+        scan_width: x dimension of scan
+        mem (str): Must be "RAM" or "MEMMAP". Specifies how the data is
+            loaded; "RAM" transfers the data from storage to RAM, while
+            "MEMMAP" leaves the data in storage and creates a memory map
+            which points to the diffraction patterns, allowing them to be
+            retrieved individually from storage.
+        binfactor (int): Diffraction space binning factor for bin-on-load.
+        dtype_bin (float): specify the datatype for bin-on-load if something
+            other than uint16 is needed
+        flatfield (np.ndarray):
+            flatfield for correction factors
+
+    Returns:
+        DataCube
+    """
+    assert mem == "RAM", "read_arina does not support memory mapping"
+
+    f = h5py.File(filename, "r")
+    nimages = 0
+
+    # Count the number of images in all datasets
+    for dset in f["entry"]["data"]:
+        nimages = nimages + f["entry"]["data"][dset].shape[0]
+        height = f["entry"]["data"][dset].shape[1]
+        width = f["entry"]["data"][dset].shape[2]
+        dtype = f["entry"]["data"][dset].dtype
+
+    width = width // binfactor
+    height = height // binfactor
+
+    assert (
+        nimages % scan_width < 1e-6
+    ), "nimages must be an integer multiple of scan_width"
+
+    if dtype.type is np.uint32:
+        print("Dataset is uint32 but will be converted to uint16")
+        dtype = np.dtype(np.uint16)
+
+    if dtype_bin:
+        array_3D = np.empty((nimages, width, height), dtype=dtype_bin)
+    else:
+        array_3D = np.empty((nimages, width, height), dtype=dtype)
+
+    image_index = 0
+
+    if flatfield is None:
+        correction_factors = 1
+    else:
+        # Avoid div by 0 errors -> pixels with value 0 will be set to the median
+        flatfield[flatfield == 0] = 1
+        correction_factors = np.median(flatfield) / flatfield
+
+    for dset in f["entry"]["data"]:
+        image_index = _processDataSet(
+            f["entry"]["data"][dset],
+            image_index,
+            array_3D,
+            binfactor,
+            correction_factors,
+        )
+
+    if f.__bool__():
+        f.close()
+
+    scan_height = int(nimages / scan_width)
+
+    datacube = DataCube(
+        np.flip(
+            array_3D.reshape(
+                scan_width, scan_height, array_3D.data.shape[1], array_3D.data.shape[2]
+            ),
+            0,
+        )
+    )
+
+    return datacube
+
+
+def _processDataSet(dset, start_index, array_3D, binfactor, correction_factors):
+    image_index = start_index
+    nimages_dset = dset.shape[0]
+
+    for i in range(nimages_dset):
+        if binfactor == 1:
+            array_3D[image_index] = np.multiply(
+                dset[i].astype(array_3D.dtype), correction_factors
+            )
+        else:
+            array_3D[image_index] = bin2D(
+                np.multiply(dset[i].astype(array_3D.dtype), correction_factors),
+                binfactor,
+            )
+
+        image_index = image_index + 1
+    return image_index
diff --git a/py4DSTEM/io/filereaders/read_dm.py b/py4DSTEM/io/filereaders/read_dm.py
index 153e6d939..bc41695e1 100644
--- a/py4DSTEM/io/filereaders/read_dm.py
+++ b/py4DSTEM/io/filereaders/read_dm.py
@@ -5,7 +5,7 @@
 from ncempy.io import dm
 from emdfile import tqdmnd, Array
-from py4DSTEM.classes import DataCube
+from py4DSTEM.datacube import DataCube
 from py4DSTEM.preprocess.utils import bin2D
@@ -30,7 +30,7 @@ def read_dm(
             Metadata instance.
kwargs: "dtype": a numpy dtype specifier to use for data binned on load, - defaults to np.float32 + defaults to the data's current dtype Returns: DataCube if a 4D dataset is found, else an ND Array @@ -96,7 +96,7 @@ def read_dm( elif Q_pixel_units == "1/nm": Q_pixel_units = "A^-1" Q_pixel_size /= 10 - + pixel_size_found = True except Exception as err: pass @@ -108,10 +108,12 @@ def read_dm( if binfactor == 1: _data = dmFile.getDataset(dataset_index)["data"] else: + # get a memory map + _mmap = dmFile.getMemmap(dataset_index) + # get the dtype for the binned data - dtype = kwargs.get("dtype", np.float32) + dtype = kwargs.get("dtype", _mmap[0,0].dtype) - _mmap = dmFile.getMemmap(dataset_index) if titan_shape is not None: # NCEM TitanX tags were found _mmap = np.reshape(_mmap, titan_shape + _mmap.shape[-2:]) @@ -120,7 +122,7 @@ def read_dm( _mmap.shape[2] // binfactor, _mmap.shape[3] // binfactor, ) - _data = np.zeros(new_shape) + _data = np.zeros(new_shape, dtype=dtype) for rx, ry in tqdmnd(*_data.shape[:2]): _data[rx, ry] = bin2D(_mmap[rx, ry], binfactor, dtype=dtype) diff --git a/py4DSTEM/io/filereaders/read_mib.py b/py4DSTEM/io/filereaders/read_mib.py index cfb0b10b8..e9dccff90 100644 --- a/py4DSTEM/io/filereaders/read_mib.py +++ b/py4DSTEM/io/filereaders/read_mib.py @@ -3,7 +3,7 @@ # Based on the PyXEM load_mib module https://github.com/pyxem/pyxem/blob/563a3bb5f3233f46cd3e57f3cd6f9ddf7af55ad0/pyxem/utils/io_utils.py import numpy as np -from py4DSTEM.classes import DataCube +from py4DSTEM.datacube import DataCube import os def load_mib( diff --git a/py4DSTEM/io/google_drive_downloader.py b/py4DSTEM/io/google_drive_downloader.py index 00b81d39d..86ad1a9f4 100644 --- a/py4DSTEM/io/google_drive_downloader.py +++ b/py4DSTEM/io/google_drive_downloader.py @@ -1,146 +1,278 @@ import gdown -from typing import Union -import pathlib import os -# Built-in sample datasets +### File IDs # single files -sample_file_ids = { - 'FCU-Net' : '1-KX0saEYfhZ9IJAOwabH38PCVtfXidJi', - 'sample_diffraction_pattern':'1ymYMnuDC0KV6dqduxe2O1qafgSd0jjnU', - 'Au_sim':'1PmbCYosA1eYydWmmZebvf6uon9k_5g_S', - 'carbon_nanotube':'1bHv3u61Cr-y_GkdWHrJGh1lw2VKmt3UM', - 'Si_SiGe_exp':'1fXNYSGpe6w6E9RBA-Ai_owZwoj3w8PNC', - 'Si_SiGe_probe':'141Tv0YF7c5a-MCrh3CkY_w4FgWtBih80', - 'Si_SiGe_EELS_strain':'1klkecq8IuEOYB-bXchO7RqOcgCl4bmDJ', - 'AuAgPd_wire':'1OQYW0H6VELsmnLTcwicP88vo2V5E3Oyt', - 'AuAgPd_wire_probe':'17OduUKpxVBDumSK_VHtnc2XKkaFVN8kq', - 'polycrystal_2D_WS2':'1AWB3-UTPiTR9dgrEkNFD7EJYsKnbEy0y', - 'WS2cif':'13zBl6aFExtsz_sew-L0-_ALYJfcgHKjo', - 'polymers':'1lK-TAMXN1MpWG0Q3_4vss_uEZgW2_Xh7', +file_ids = { + 'sample_diffraction_pattern' : ( + 'a_diffraction_pattern.h5', + '1ymYMnuDC0KV6dqduxe2O1qafgSd0jjnU', + ), + 'Au_sim' : ( + 'Au_sim.h5', + '1PmbCYosA1eYydWmmZebvf6uon9k_5g_S', + ), + 'carbon_nanotube' : ( + 'carbon_nanotube.h5', + '1bHv3u61Cr-y_GkdWHrJGh1lw2VKmt3UM', + ), + 'Si_SiGe_exp' : ( + 'Si_SiGe_exp.h5', + '1fXNYSGpe6w6E9RBA-Ai_owZwoj3w8PNC', + ), + 'Si_SiGe_probe' : ( + 'Si_SiGe_probe.h5', + '141Tv0YF7c5a-MCrh3CkY_w4FgWtBih80', + ), + 'Si_SiGe_EELS_strain' : ( + 'Si_SiGe_EELS_strain.h5', + '1klkecq8IuEOYB-bXchO7RqOcgCl4bmDJ', + ), + 'AuAgPd_wire' : ( + 'AuAgPd_wire.h5', + '1OQYW0H6VELsmnLTcwicP88vo2V5E3Oyt', + ), + 'AuAgPd_wire_probe' : ( + 'AuAgPd_wire_probe.h5', + '17OduUKpxVBDumSK_VHtnc2XKkaFVN8kq', + ), + 'polycrystal_2D_WS2' : ( + 'polycrystal_2D_WS2.h5', + '1AWB3-UTPiTR9dgrEkNFD7EJYsKnbEy0y', + ), + 'WS2cif' : ( + 'WS2.cif', + '13zBl6aFExtsz_sew-L0-_ALYJfcgHKjo', + ), + 'polymers' : ( + 'polymers.h5', + 
'1lK-TAMXN1MpWG0Q3_4vss_uEZgW2_Xh7', + ), + 'vac_probe' : ( + 'vac_probe.h5', + '1QTcSKzZjHZd1fDimSI_q9_WsAU25NIXe', + ), + 'small_dm3_3Dstack' : ( + 'small_dm3_3Dstack.dm3', + '1B-xX3F65JcWzAg0v7f1aVwnawPIfb5_o' + ), + 'FCU-Net' : ( + 'filename.name', + '1-KX0saEYfhZ9IJAOwabH38PCVtfXidJi', + ), + 'small_datacube' : ( + 'small_datacube.dm4', + # TODO - change this file to something smaller - ideally e.g. shape (4,8,256,256) ~= 4.2MB' + '1QTcSKzZjHZd1fDimSI_q9_WsAU25NIXe' + ), + 'legacy_v0.9' : ( + 'legacy_v0.9_simAuNanoplatelet_bin.h5', + '1AIRwpcj87vK3ubLaKGj1UiYXZByD2lpu' + ), + 'legacy_v0.13' : ( + 'legacy_v0.13.h5', + '1VEqUy0Gthama7YAVkxwbjQwdciHpx8rA' + ), + 'legacy_v0.14' : ( + 'legacy_v0.14.h5', + '1eOTEJrpHnNv9_DPrWgZ4-NTN21UbH4aR', + ), + 'test_realslice_io' : ( + 'test_realslice_io.h5', + '1siH80-eRJwG5R6AnU4vkoqGWByrrEz1y' + ), + 'test_arina_master' : ( + 'STO_STEM_bench_20us_master.h5', + '1q_4IjFuWRkw5VM84NhxrNTdIq4563BOC' + ), + 'test_arina_01' : ( + 'STO_STEM_bench_20us_data_000001.h5', + '1_3Dbm22-hV58iffwK9x-3vqJUsEXZBFQ' + ), + 'test_arina_02' : ( + 'STO_STEM_bench_20us_data_000002.h5', + '1x29RzHLnCzP0qthLhA1kdlUQ09ENViR8' + ), + 'test_arina_03' : ( + 'STO_STEM_bench_20us_data_000003.h5', + '1qsbzdEVD8gt4DYKnpwjfoS_Mg4ggObAA' + ), + 'test_arina_04' : ( + 'STO_STEM_bench_20us_data_000004.h5', + '1Lcswld0Y9fNBk4-__C9iJbc854BuHq-h' + ), + 'test_arina_05' : ( + 'STO_STEM_bench_20us_data_000005.h5', + '13YTO2ABsTK5nObEr7RjOZYCV3sEk3gt9' + ), + 'test_arina_06' : ( + 'STO_STEM_bench_20us_data_000006.h5', + '1RywPXt6HRbCvjgjSuYFf60QHWlOPYXwy' + ), + 'test_arina_07' : ( + 'STO_STEM_bench_20us_data_000007.h5', + '1GRoBecCvAUeSIujzsPywv1vXKSIsNyoT' + ), + 'test_arina_08' : ( + 'STO_STEM_bench_20us_data_000008.h5', + '1sTFuuvgKbTjZz1lVUfkZbbTDTQmwqhuU' + ), + 'test_arina_09' : ( + 'STO_STEM_bench_20us_data_000009.h5', + '1JmBiMg16iMVfZ5wz8z_QqcNPVRym1Ezh' + ), + 'test_arina_10' : ( + 'STO_STEM_bench_20us_data_000010.h5', + '1_90xAfclNVwMWwQ-YKxNNwBbfR1nfHoB' + ), + 'test_strain' : ( + 'downsample_Si_SiGe_analysis_braggdisks_cal.h5', + '1bYgDdAlnWHyFmY-SwN3KVpMutWBI5MhP' + ) } # collections of files -sample_collection_ids = { - 'unit_test_data' : ( - ('dm_test_file.dm3', '1RxI1QY6vYMDqqMVPt5GBN6Q_iCwHFU4B'), - ), +collection_ids = { 'tutorials' : ( - ('simulated_Au_single+polyXtal.h5', sample_file_ids['Au_sim']), - ('carbon_nanotube.h5', sample_file_ids['carbon_nanotube']), - ('Si_SiGe_exp.h5',sample_file_ids['Si_SiGe_exp']), - ('Si_SiGe_probe.h5',sample_file_ids['Si_SiGe_probe']), - ('Si_SiGe_EELS_strain.h5',sample_file_ids['Si_SiGe_EELS_strain']), - ('AuAgPd_wire.h5',sample_file_ids['AuAgPd_wire']), - ('AuAgPd_wire_probe.h5',sample_file_ids['AuAgPd_wire_probe']), - ('polycrystal_2D_WS2.h5',sample_file_ids['polycrystal_2D_WS2']), - ('WS2.cif',sample_file_ids['WS2cif']), - ('polymers.h5',sample_file_ids['polymers']), + 'Au_sim', + 'carbon_nanotube', + 'Si_SiGe_exp', + 'Si_SiGe_probe', + 'Si_SiGe_EELS_strain', + 'AuAgPd_wire', + 'AuAgPd_wire_probe', + 'polycrystal_2D_WS2', + 'WS2cif', + 'polymers', + 'vac_probe', + ), + 'test_io' : ( + 'small_dm3_3Dstack', + 'vac_probe', + 'legacy_v0.9', + 'legacy_v0.13', + 'legacy_v0.14', + 'test_realslice_io', + ), + 'test_arina' : ( + 'test_arina_master', + 'test_arina_01', + 'test_arina_02', + 'test_arina_03', + 'test_arina_04', + 'test_arina_05', + 'test_arina_06', + 'test_arina_07', + 'test_arina_08', + 'test_arina_09', + 'test_arina_10', + ), + 'test_braggvectors' : ( + 'Au_sim', ), + 'strain' : ( + 'test_strain', + ) } +def 
get_sample_file_ids():
+    return {
+        'files' : file_ids.keys(),
+        'collections' : collection_ids.keys()
+    }

-def download_file_from_google_drive(id_, destination, overwrite=False):
-    """
-    Downloads a file or collection of files from google drive to the
-    destination file path.
-
-    Args:
-        id_ (str): File ID for the desired file. May be:
-            - the file id, i.e. for
-              https://drive.google.com/file/d/1bHv3u61Cr-y_GkdWHrJGh1lw2VKmt3UM/
-              id='1bHv3u61Cr-y_GkdWHrJGh1lw2VKmt3UM'
-            - the complete URL,
-            - a special string denoting a sample dataset or collection of
-              datasets. For a list of sample datasets and their keys, run
-              get_sample_data_ids().
-        destination (str or Path): path file will be downloaded to. For
-            collections of files, this should point to an existing directory;
-            a subdirectory will be created inside this directory whose name
-            will be given by `id_`, and the colletion of files will be placed
-            inside that subdirectory. If a subdirectory of this name already
-            exists, aborts or deletes and overwrite the entire subdirectory,
-            depending on the value of `overwrite`.
-        overwrite (bool): turn overwrite protection on/off
-    """
-    # handle paths
-    # for collections of files
-    if id_ in sample_collection_ids.keys():

-        # check if directory exists
-        assert os.path.exists(destination), "specified directory does not exist; check filepath"
+### Downloader

-        # update the path with a new sub-directory name
-        destination = os.path.join(destination, id_)
+def gdrive_download(
+    id_,
+    destination = None,
+    overwrite = False,
+    filename = None,
+    verbose = True,
+    ):
+    """
+    Downloads a file or collection of files from google drive.

-        # check if it already exists
-        if os.path.exists(destination):
+    Parameters
+    ----------
+    id_ : str
+        File ID for the desired file. May be either a key from the list
+        of files and collections of files accessible at get_sample_file_ids(),
+        or a complete url, or the portion of a google drive link specifying
+        its google file ID, i.e. for the address
+        https://drive.google.com/file/d/1bHv3u61Cr-y_GkdWHrJGh1lw2VKmt3UM/,
+        the id string '1bHv3u61Cr-y_GkdWHrJGh1lw2VKmt3UM'.
+    destination : None or str
+        The location files are downloaded to. If a collection of files has been
+        specified, creates a new directory at the specified destination and
+        downloads the collection there. If None, downloads to the current
+        working directory. Otherwise must be a string or Path pointing to
+        a valid location on the filesystem.
+    overwrite : bool
+        Turns overwrite protection on/off.
+    filename : None or str
+        Used only if `id_` is a url or gdrive id. In these cases, specifies
+        the name of the output file. If left as None, saves to
+        'gdrivedownload.file'. If `id_` is a key from the sample file id list,
+        this parameter is ignored.
+    verbose : bool
+        Toggles verbose output
+    """
+    # parse destination
+    if destination is None:
+        destination = os.getcwd()
+    assert(os.path.exists(destination)), f"`destination` must exist on filesystem. Received {destination}"

-            # check if it contains any files
-            paths = os.listdir(destination)
-            if len(paths) > 0:
+    # download single files
+    if id_ not in collection_ids:

-                # check `overwrite`
-                if overwrite:
-                    for p in paths:
-                        os.remove(os.path.join(destination, p))
-                else:
-                    raise Exception('a populated directory exists at the specified location.
to overwrite set `overwrite=True`.')

+        # assign the name and id
+        kwargs = {
+            'fuzzy' : True
+        }
+        if id_ in file_ids:
+            f = file_ids[id_]
+            filename = f[0]
+            kwargs['id'] = f[1]
+        else:
+            filename = 'gdrivedownload.file' if filename is None else filename
+            kwargs['url'] = id_

-            # if it doesn't exist, make a new directory
+        # download
+        kwargs['output'] = os.path.join(destination, filename)
+        if not(overwrite) and os.path.exists(kwargs['output']):
+            if verbose:
+                print(f"A file already exists at {kwargs['output']}, skipping...")
         else:
+            gdown.download( **kwargs )
+
+    # download a collection of files
+    else:
+
+        # set destination
+        destination = os.path.join(destination, id_)
+        if not os.path.exists(destination):
             os.mkdir(destination)

-        # get the files
-        for file_name,file_id in sample_collection_ids[id_]:
-            if file_id[:4].lower() == 'http':
-                gdown.download(
-                    url = file_id,
-                    output = os.path.join(destination, file_name),
-                    fuzzy = True
-                )
+        # loop
+        for x in collection_ids[id_]:
+            file_name,file_id = file_ids[x]
+            output = os.path.join(destination, file_name)
+            # download
+            if not(overwrite) and os.path.exists(output):
+                if verbose:
+                    print(f"A file already exists at {output}, skipping...")
             else:
                 gdown.download(
                     id = file_id,
-                    output = os.path.join(destination, file_name),
+                    output = output,
                     fuzzy = True
                 )
-        return None
-
-
-    # for single files
-
-    # Check and handle if a file at the destination filepath exists
-    if os.path.exists(destination):
-
-        # check `overwrite`
-        if overwrite == True:
-            print(f"File already existed, downloading and overwriting")
-            os.remove(destination)
-        else:
-            raise Exception('File already exists; aborting. To overwrite, use `overwrite=True`.')
-
-    # Check if the id_ is a key pointing to a known sample dataset
-    if id_ in sample_file_ids.keys():
-        id_ = sample_file_ids[id_]
-
-    # Check if `id_` is a file ID or a URL, and download
-    if id_[:4].lower()=='http':
-        gdown.download(url=id_,output=destination,fuzzy=True)
-    else:
-        gdown.download(id=id_,output=destination,fuzzy=True)
-
-    return None
-
-
-
-def get_sample_data_ids():
-    return {'files' : sample_file_ids.keys(),
-            'collections' : sample_collection_ids.keys()}
diff --git a/py4DSTEM/io/importfile.py b/py4DSTEM/io/importfile.py
index 17b052601..20a3759a2 100644
--- a/py4DSTEM/io/importfile.py
+++ b/py4DSTEM/io/importfile.py
@@ -1,18 +1,18 @@
 # Reader functions for non-native file types
 import pathlib
-from os.path import exists, splitext
-from typing import Union, Optional
+from os.path import exists
+from typing import Optional, Union

-from py4DSTEM.io.parsefiletype import _parse_filetype
 from py4DSTEM.io.filereaders import (
-    read_empad,
+    load_mib,
+    read_abTEM,
+    read_arina,
     read_dm,
+    read_empad,
     read_gatan_K2_bin,
-    load_mib
 )
-
-
+from py4DSTEM.io.parsefiletype import _parse_filetype

 def import_file(
@@ -37,6 +37,7 @@
             from storage.
         binfactor (int): Diffraction space binning factor for bin-on-load.
         filetype (str): Used to override automatic filetype detection.
+            Options include "dm", "empad", "gatan_K2_bin", "mib", "arina", "abTEM"
         **kwargs: any additional kwargs are passed to the downstream reader -
             refer to the individual filetype reader function call signatures
             and docstrings for more details.
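As an aside, a minimal usage sketch of the updated `import_file` entry point may be useful here; the filenames below are hypothetical, and `scan_width` is one of the reader-specific kwargs forwarded through `**kwargs` to `read_arina`:

    import py4DSTEM

    # the filetype is normally detected automatically from the file
    datacube = py4DSTEM.import_file("scan.dm4", mem="RAM", binfactor=2)

    # or it can be specified explicitly, e.g. for an Arina master file
    datacube = py4DSTEM.import_file(
        "STO_bench_master.h5",
        filetype="arina",
        scan_width=128,  # x dimension of the scan, passed through to read_arina
    )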
@@ -55,9 +56,7 @@ def import_file(
         "RAM",
         "MEMMAP",
     ], 'Error: argument mem must be either "RAM" or "MEMMAP"'
-    assert isinstance(
-        binfactor, int
-    ), "Error: argument binfactor must be an integer"
+    assert isinstance(binfactor, int), "Error: argument binfactor must be an integer"
     assert binfactor >= 1, "Error: binfactor must be >= 1"
     if binfactor > 1:
         assert (
@@ -66,13 +65,17 @@
     filetype = _parse_filetype(filepath) if filetype is None else filetype

-    if filetype == 'EMD':
-        raise Exception("EMD file detected - use py4DSTEM.read, not py4DSTEM.import_file!")
+    if filetype in ("emd", "legacy"):
+        raise Exception(
+            "EMD or legacy py4DSTEM file detected - use py4DSTEM.read, not py4DSTEM.import_file!"
+        )
     assert filetype in [
         "dm",
         "empad",
         "gatan_K2_bin",
-        "mib"
+        "mib",
+        "arina",
+        "abTEM"
         # "kitware_counted",
     ], "Error: filetype not recognized"
@@ -85,10 +88,12 @@
     # elif filetype == "kitware_counted":
     #     data = read_kitware_counted(filepath, mem, binfactor, metadata=metadata, **kwargs)
     elif filetype == "mib":
-        data = load_mib(filepath, mem=mem, binfactor=binfactor,**kwargs)
+        data = load_mib(filepath, mem=mem, binfactor=binfactor, **kwargs)
+    elif filetype == "arina":
+        data = read_arina(filepath, mem=mem, binfactor=binfactor, **kwargs)
+    elif filetype == "abTEM":
+        data = read_abTEM(filepath, mem=mem, binfactor=binfactor, **kwargs)
     else:
         raise Exception("Bad filetype!")
     return data
-
-
diff --git a/py4DSTEM/io/legacy/legacy12/read_v0_12.py b/py4DSTEM/io/legacy/legacy12/read_v0_12.py
index a327a567b..470e64075 100644
--- a/py4DSTEM/io/legacy/legacy12/read_v0_12.py
+++ b/py4DSTEM/io/legacy/legacy12/read_v0_12.py
@@ -9,11 +9,11 @@
     PointList,
     PointListArray
 )
-from py4DSTEM.classes import (
-    DataCube,
+from py4DSTEM.data import (
     DiffractionSlice,
     RealSlice,
 )
+from py4DSTEM.datacube import DataCube
 from emdfile import tqdmnd

 def read_v0_12(fp, **kwargs):
diff --git a/py4DSTEM/io/legacy/legacy12/read_v0_5.py b/py4DSTEM/io/legacy/legacy12/read_v0_5.py
index 109c0a4c3..ae89ba86b 100644
--- a/py4DSTEM/io/legacy/legacy12/read_v0_5.py
+++ b/py4DSTEM/io/legacy/legacy12/read_v0_5.py
@@ -9,11 +9,11 @@
     PointList,
     PointListArray
 )
-from py4DSTEM.classes import (
-    DataCube,
+from py4DSTEM.data import (
     DiffractionSlice,
     RealSlice,
 )
+from py4DSTEM.datacube import DataCube
 from emdfile import tqdmnd

 def read_v0_5(fp, **kwargs):
diff --git a/py4DSTEM/io/legacy/legacy12/read_v0_6.py b/py4DSTEM/io/legacy/legacy12/read_v0_6.py
index ded73a8c4..d3b6d0d5f 100644
--- a/py4DSTEM/io/legacy/legacy12/read_v0_6.py
+++ b/py4DSTEM/io/legacy/legacy12/read_v0_6.py
@@ -9,11 +9,11 @@
     PointList,
     PointListArray
 )
-from py4DSTEM.classes import (
-    DataCube,
+from py4DSTEM.data import (
     DiffractionSlice,
     RealSlice,
 )
+from py4DSTEM.datacube import DataCube
 from emdfile import tqdmnd

 def read_v0_6(fp, **kwargs):
diff --git a/py4DSTEM/io/legacy/legacy12/read_v0_7.py b/py4DSTEM/io/legacy/legacy12/read_v0_7.py
index 7635c56f6..34fd917f4 100644
--- a/py4DSTEM/io/legacy/legacy12/read_v0_7.py
+++ b/py4DSTEM/io/legacy/legacy12/read_v0_7.py
@@ -9,11 +9,11 @@
     PointList,
     PointListArray
 )
-from py4DSTEM.classes import (
-    DataCube,
+from py4DSTEM.data import (
     DiffractionSlice,
     RealSlice,
 )
+from py4DSTEM.datacube import DataCube
 from emdfile import tqdmnd

 def read_v0_7(fp, **kwargs):
diff --git a/py4DSTEM/io/legacy/legacy12/read_v0_9.py b/py4DSTEM/io/legacy/legacy12/read_v0_9.py
index b3fe4583b..37e995871 100644
--- a/py4DSTEM/io/legacy/legacy12/read_v0_9.py
+++ b/py4DSTEM/io/legacy/legacy12/read_v0_9.py
@@ -9,11 +9,11 @@ PointList, PointListArray ) -from py4DSTEM.classes import ( - DataCube, +from py4DSTEM.data import ( DiffractionSlice, RealSlice, ) +from py4DSTEM.datacube import DataCube from emdfile import tqdmnd def read_v0_9(fp, **kwargs): diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py index 4b385a694..8b20779f8 100644 --- a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py +++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py @@ -322,8 +322,10 @@ def set_dim( values for the n'th dim vector. Accepts: n (int): specifies which dim vector - dim (list or array): length must be either 2, or equal to the - length of the n'th axis of the data array + dim (list or array): length must be either 1 or 2, or equal to the + length of the n'th axis of the data array. If length is 1 specifies step + size of dim vector and starts at 0. If length is 2, specifies start + and step of dim vector. units (Optional, str): name: (Optional, str): """ diff --git a/py4DSTEM/io/legacy/legacy13/v13_to_14.py b/py4DSTEM/io/legacy/legacy13/v13_to_14.py index 1b0deafaa..18c08c777 100644 --- a/py4DSTEM/io/legacy/legacy13/v13_to_14.py +++ b/py4DSTEM/io/legacy/legacy13/v13_to_14.py @@ -1,5 +1,6 @@ # Convert v13 to v14 classes +import numpy as np from emdfile import tqdmnd @@ -36,17 +37,17 @@ PointListArray ) -from py4DSTEM.classes import ( +from py4DSTEM.data import ( Calibration, - DataCube, DiffractionSlice, - VirtualDiffraction, RealSlice, - VirtualImage, - Probe, QPoints, ) -from py4DSTEM.process.diskdetection.braggvectors import BraggVectors +from py4DSTEM.datacube import ( + DataCube, + VirtualImage, + VirtualDiffraction, +) @@ -92,7 +93,7 @@ def _populate_tree(node13,node14,root14): if isinstance(newnode14,Metadata): pass else: - node14.tree(newnode14) + node14.tree(newnode14,force=True) _populate_tree(newnode13,newnode14,root14) @@ -137,9 +138,13 @@ def _v13_to_14_cls(obj): ) elif isinstance(obj, DiffractionSlice13): + if obj.is_stack: + data = np.rollaxis(obj.data, axis=2) + else: + data = obj.data x = DiffractionSlice( name = obj.name, - data = obj.data, + data = data, units = obj.units, slicelabels = obj.slicelabels ) @@ -151,9 +156,13 @@ def _v13_to_14_cls(obj): ) elif isinstance(obj, RealSlice13): + if obj.is_stack: + data = np.rollaxis(obj.data, axis=2) + else: + data = obj.data x = RealSlice( name = obj.name, - data = obj.data, + data = data, units = obj.units, slicelabels = obj.slicelabels ) @@ -167,6 +176,7 @@ def _v13_to_14_cls(obj): pass elif isinstance(obj, Probe13): + from py4DSTEM.braggvectors import Probe x = Probe( name = obj.name, data = obj.data @@ -179,6 +189,7 @@ def _v13_to_14_cls(obj): ) elif isinstance(obj, BraggVectors13): + from py4DSTEM.braggvectors import BraggVectors x = BraggVectors( name = obj.name, Rshape = obj.Rshape, @@ -195,9 +206,13 @@ def _v13_to_14_cls(obj): elif isinstance(obj, Array13): # prepare arguments + if obj.is_stack: + data = np.rollaxis(obj.data, axis=2) + else: + data = obj.data args = { 'name' : obj.name, - 'data' : obj.data + 'data' : data } if hasattr(obj,'units'): args['units'] = obj.units if hasattr(obj,'dim_names'): args['dim_names'] = obj.dim_names diff --git a/py4DSTEM/io/legacy/read_legacy_13.py b/py4DSTEM/io/legacy/read_legacy_13.py index ebf86d5c3..56e931ea9 100644 --- a/py4DSTEM/io/legacy/read_legacy_13.py +++ b/py4DSTEM/io/legacy/read_legacy_13.py @@ -257,7 +257,7 @@ def _get_v13_class(grp): 'QPoints' : QPoints, 'BraggVectors' : BraggVectors } - print(grp) + if 
'py4dstem_class' in grp.attrs:
         classname = grp.attrs['py4dstem_class']
     elif 'emd_group_type' in grp.attrs:
diff --git a/py4DSTEM/io/parsefiletype.py b/py4DSTEM/io/parsefiletype.py
index 5903ce814..1838f89b6 100644
--- a/py4DSTEM/io/parsefiletype.py
+++ b/py4DSTEM/io/parsefiletype.py
@@ -1,9 +1,18 @@
 # File parser utility
 from os.path import splitext
+
+import emdfile as emd
+import h5py
+import py4DSTEM.io.legacy as legacy
+

 def _parse_filetype(fp):
-    """ Accepts a path to a data file, and returns the file type as a string.
+    """
+    Accepts a path to a data file, and returns the file type as a string.
     """
     _, fext = splitext(fp)
     fext = fext.lower()
@@ -13,7 +22,20 @@
         ".py4dstem",
         ".emd",
     ]:
-        return "H5"
+        if emd._is_EMD_file(fp):
+            return "emd"
+
+        elif legacy.is_py4DSTEM_file(fp):
+            return "legacy"
+
+        elif _is_arina(fp):
+            return "arina"
+
+        elif _is_abTEM(fp):
+            return "abTEM"
+        else:
+            raise Exception("unsupported HDF5 file type")
+
     elif fext in [
         ".dm",
         ".dm3",
@@ -21,17 +43,67 @@
     ]:
         return "dm"
     elif fext in [".raw"]:
-       return "empad"
+        return "empad"
     elif fext in [".mrc"]:
-       return "mrc_relativity"
+        return "mrc_relativity"
     elif fext in [".gtg", ".bin"]:
-       return "gatan_K2_bin"
+        return "gatan_K2_bin"
     elif fext in [".kitware_counted"]:
-       return "kitware_counted"
+        return "kitware_counted"
     elif fext in [".mib", ".MIB"]:
         return "mib"
     else:
         raise Exception(f"Unrecognized file extension {fext}.")

+
+def _is_arina(filepath):
+    """
+    Check if an h5 file is an Arina file.
+    """
+    with h5py.File(filepath, "r") as f:
+        try:
+            assert "entry" in f.keys()
+        except AssertionError:
+            return False
+        try:
+            assert "NX_class" in f["entry"].attrs.keys()
+        except AssertionError:
+            return False
+    return True
+
+
+def _is_abTEM(filepath):
+    """
+    Check if an h5 file is an abTEM file.
+    """
+    with h5py.File(filepath, "r") as f:
+        try:
+            assert "array" in f.keys()
+        except AssertionError:
+            return False
+    return True
diff --git a/py4DSTEM/io/read.py b/py4DSTEM/io/read.py
index 3fe94ca41..bab555eaf 100644
--- a/py4DSTEM/io/read.py
+++ b/py4DSTEM/io/read.py
@@ -1,24 +1,23 @@
 # Reader for native files
-from pathlib import Path
+import warnings
 from os.path import exists
-from typing import Optional,Union
+from pathlib import Path
+from typing import Optional, Union

-import py4DSTEM
 import emdfile as emd
-from py4DSTEM.io.parsefiletype import _parse_filetype
 import py4DSTEM.io.legacy as legacy
-
-
+from py4DSTEM.data import Data
+from py4DSTEM.io.parsefiletype import _parse_filetype

 def read(
-    filepath: Union[str,Path],
+    filepath: Union[str, Path],
     datapath: Optional[str] = None,
-    tree: Optional[Union[bool,str]] = True,
+    tree: Optional[Union[bool, str]] = True,
     verbose: Optional[bool] = False,
     **kwargs,
-    ):
+):
     """
     A file reader for native py4DSTEM / EMD files. To read non-native
     formats, use `py4DSTEM.import_file`.
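For reference, a minimal sketch of how the reworked native reader is called; the filenames are hypothetical, and the legacy root group names depend on the file:

    import py4DSTEM

    # EMD 1.0+ files: loads the data tree and re-links calibrations
    data = py4DSTEM.read("analysis.h5", verbose=True)

    # legacy files with several root groups: read() returns the group names,
    # one of which can then be passed back in as `datapath`
    roots = py4DSTEM.read("legacy_scan.h5")
    data = py4DSTEM.read("legacy_scan.h5", datapath=roots[0])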
@@ -65,85 +64,133 @@ def read( # parse filetype er1 = f"filepath must be a string or Path, not {type(filepath)}" er2 = f"specified filepath '{filepath}' does not exist" - assert(isinstance(filepath, (str,Path) )), er1 - assert(exists(filepath)), er2 + assert isinstance(filepath, (str, Path)), er1 + assert exists(filepath), er2 filetype = _parse_filetype(filepath) - assert filetype == "H5", f"`py4DSTEM.read` loads native HDF5 formatted files, but a file of type {filetype} was detected. Try loading it using py4DSTEM.import_file" + assert filetype in ( + "emd", + "legacy", + ), f"`py4DSTEM.read` loads native HDF5 formatted files, but a file of type {filetype} was detected. Try loading it using py4DSTEM.import_file" + # support older `root` input + if datapath is None: + if "root" in kwargs: + datapath = kwargs["root"] # EMD 1.0 formatted files (py4DSTEM v0.14+) - if emd._is_EMD_file(filepath): + if filetype == "emd": + + # check version version = emd._get_EMD_version(filepath) - if verbose: print(f"EMD version {version[0]}.{version[1]}.{version[2]} detected. Reading...") - assert emd._version_is_geq(version,(1,0,0)), f"EMD version {version} detected. Expected version >= 1.0.0" - data = emd.read( - filepath, - emdpath = datapath, - tree = tree - ) - if verbose: print("Done.") + if verbose: + print( + f"EMD version {version[0]}.{version[1]}.{version[2]} detected. Reading..." + ) + assert emd._version_is_geq( + version, (1, 0, 0) + ), f"EMD version {version} detected. Expected version >= 1.0.0" + + # read + data = emd.read(filepath, emdpath=datapath, tree=tree) + if verbose: + print("Data was read from file. Adding calibration links...") + + # add calibration links + if isinstance(data, Data): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + cal = data.calibration + elif isinstance(data, emd.Root): + try: + cal = data.metadata["calibration"] + except KeyError: + cal = None + else: + cal = None + if cal is not None: + try: + root_treepath = cal["_root_treepath"] + target_paths = cal["_target_paths"] + del cal._params["_target_paths"] + for p in target_paths: + try: + p = p.replace(root_treepath, "") + d = data.root.tree(p) + cal.register_target(d) + if hasattr(d, "setcal"): + d.setcal() + except AssertionError: + pass + except KeyError: + pass + cal.calibrate() + + # return + if verbose: + print("Done.") return data - # legacy py4DSTEM files (v <= 0.13) else: - assert legacy.is_py4DSTEM_file(filepath), "path points to an H5 file which is neither an EMD 1.0+ file, nor a recognized legacy py4DSTEM file." - + assert ( + filetype == "legacy" + ), "path points to an H5 file which is neither an EMD 1.0+ file, nor a recognized legacy py4DSTEM file." # read v13 if legacy.is_py4DSTEM_version13(filepath): # load the data - if verbose: print(f"Legacy py4DSTEM version 13 file detected. Reading...") - kwargs['root'] = datapath - kwargs['tree'] = tree + if verbose: + print("Legacy py4DSTEM version 13 file detected. 
Reading...") + kwargs["root"] = datapath + kwargs["tree"] = tree data = legacy.read_legacy13( filepath=filepath, **kwargs, ) - if verbose: print("Done.") + if verbose: + print("Done.") return data - # read <= v12 else: # parse the root/data_id from the datapath arg if datapath is not None: - datapath = datapath.split('/') + datapath = datapath.split("/") try: - datapath.remove('') + datapath.remove("") except ValueError: pass rootgroup = datapath[0] - if len(datapath)>1: - datapath = '/'.join(rootgroup[1:]) + if len(datapath) > 1: + datapath = "/".join(rootgroup[1:]) else: datapath = None else: rootgroups = legacy.get_py4DSTEM_topgroups(filepath) - if len(rootgroups)>1: - print('multiple root groups in a legacy file found - returning list of root names; please pass one as `datapath`') + if len(rootgroups) > 1: + print( + "multiple root groups in a legacy file found - returning list of root names; please pass one as `datapath`" + ) return rootgroups - elif len(rootgroups)==0: - raise Exception('No rootgroups found') + elif len(rootgroups) == 0: + raise Exception("No rootgroups found") else: rootgroup = rootgroups[0] datapath = None - # load the data - if verbose: print(f"Legacy py4DSTEM version <= 12 file detected. Reading...") - kwargs['topgroup'] = rootgroup + if verbose: + print("Legacy py4DSTEM version <= 12 file detected. Reading...") + kwargs["topgroup"] = rootgroup if datapath is not None: - kwargs['data_id'] = datapath + kwargs["data_id"] = datapath data = legacy.read_legacy12( filepath=filepath, **kwargs, ) - if verbose: print("Done.") + if verbose: + print("Done.") return data - - - diff --git a/py4DSTEM/io/save.py b/py4DSTEM/io/save.py index 3d2ae95d8..ce0076724 100644 --- a/py4DSTEM/io/save.py +++ b/py4DSTEM/io/save.py @@ -1,4 +1,40 @@ -from emdfile import save +from emdfile import save as _save +import warnings + +def save( + filepath, + data, + mode='w', + emdpath=None, + tree=True + ): + """ + Saves data to an EMD 1.0 formatted HDF5 file at filepath. + + For the full docstring, see py4DSTEM.emdfile.save. + """ + # This function wraps emdfile's save and adds a small piece + # of metadata to the calibration to allow linking to calibrated + # data items on read + + cal = None + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + if hasattr(data,'calibration') and data.calibration is not None: + cal = data.calibration + rp = '/'.join(data._treepath.split('/')[:-1]) + cal['_root_treepath'] = rp + + _save( + filepath, + data = data, + mode = mode, + emdpath = emdpath, + tree = tree + ) + + if cal is not None: + del(cal._params['_root_treepath']) diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py index de8ff62c1..30ccc9c0f 100644 --- a/py4DSTEM/preprocess/preprocess.py +++ b/py4DSTEM/preprocess/preprocess.py @@ -208,15 +208,30 @@ def crop_data_real(datacube, crop_Rx_min, crop_Rx_max, crop_Ry_min, crop_Ry_max) return datacube -def bin_data_diffraction(datacube, bin_factor): +def bin_data_diffraction( + datacube, + bin_factor, + dtype=None + ): """ Performs diffraction space binning of data by bin_factor. + + Parameters + ---------- + N : int + The binning factor + dtype : a datatype (optional) + Specify the datatype for the output. If not passed, the datatype + is left unchanged + """ # validate inputs assert(type(bin_factor) is int ), f"Error: binning factor {bin_factor} is not an int." 
if bin_factor == 1: return datacube + if dtype is None: + dtype = datacube.data.dtype # get shape R_Nx, R_Ny, Q_Nx, Q_Ny = ( @@ -245,12 +260,13 @@ def bin_data_diffraction(datacube, bin_factor): bin_factor, int(Q_Ny / bin_factor), bin_factor, - ).sum(axis=(3, 5)) - + ).sum(axis=(3, 5)).astype(dtype) # set dim vectors Qpixsize = datacube.calibration.get_Q_pixel_size() * bin_factor Qpixunits = datacube.calibration.get_Q_pixel_units() + + datacube.set_dim( 2, [0,Qpixsize], @@ -263,9 +279,11 @@ def bin_data_diffraction(datacube, bin_factor): units = Qpixunits, name = 'Qy' ) + # set calibration pixel size datacube.calibration.set_Q_pixel_size(Qpixsize) + # return return datacube @@ -625,10 +643,17 @@ def resample_data_diffraction( ) resampling_factor = resampling_factor[0] + old_size = datacube.data.shape + datacube.data = fourier_resample( datacube.data, scale=resampling_factor, output_size=output_size ) + if not resampling_factor: + resampling_factor = output_size[0] / old_size[2] + if datacube.calibration.get_Q_pixel_size() is not None: + datacube.calibration.set_Q_pixel_size(datacube.calibration.get_Q_pixel_size() / resampling_factor) + elif method == "bilinear": from scipy.ndimage import zoom @@ -659,6 +684,8 @@ def resample_data_diffraction( resampling_factor = np.concatenate(((1, 1), resampling_factor)) datacube.data = zoom(datacube.data, resampling_factor, order=1) + datacube.calibration.set_Q_pixel_size(datacube.calibration.get_Q_pixel_size() / resampling_factor[2]) + else: raise ValueError( f"'method' needs to be one of 'bilinear' or 'fourier', not {method}." @@ -724,6 +751,24 @@ def pad_data_diffraction(datacube, pad_factor=None, output_size=None): datacube.data = np.pad(datacube.data, pad_width=pad_width, mode="constant") + Qpixsize = datacube.calibration.get_Q_pixel_size() + Qpixunits = datacube.calibration.get_Q_pixel_units() + + datacube.set_dim( + 2, + [0,Qpixsize], + units = Qpixunits, + name = 'Qx' + ) + datacube.set_dim( + 3, + [0,Qpixsize], + units = Qpixunits, + name = 'Qy' + ) + + datacube.calibrate() + return datacube diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py index 841314fbb..73088068b 100644 --- a/py4DSTEM/process/__init__.py +++ b/py4DSTEM/process/__init__.py @@ -1,12 +1,8 @@ from py4DSTEM.process.polar import PolarDatacube +from py4DSTEM.process.strain import StrainMap -from py4DSTEM.process import diskdetection -from py4DSTEM.process import virtualdiffraction -from py4DSTEM.process import virtualimage -from py4DSTEM.process import probe from py4DSTEM.process import latticevectors from py4DSTEM.process import phase -from py4DSTEM.process import probe from py4DSTEM.process import calibration from py4DSTEM.process import utils from py4DSTEM.process import classification diff --git a/py4DSTEM/process/calibration/origin.py b/py4DSTEM/process/calibration/origin.py index bc8c0be0c..19d2f0c55 100644 --- a/py4DSTEM/process/calibration/origin.py +++ b/py4DSTEM/process/calibration/origin.py @@ -6,55 +6,55 @@ from scipy.optimize import leastsq from emdfile import tqdmnd, PointListArray -from py4DSTEM.classes import DataCube +from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration.probe import get_probe_size from py4DSTEM.process.fit import plane,parabola,bezier_two,fit_2D from py4DSTEM.process.utils import get_CoM, add_to_2D_array_from_floats, get_maxima_2D - -# origin setting decorators - -def set_measured_origin(fun): - """ - This is intended as a decorator function to wrap other functions which measure - the position of the 
origin. If some function `get_the_origin` returns the
-    position of the origin as a tuple of two (R_Nx,R_Ny)-shaped arrays, then
-    decorating the function definition like
-
-    >>> @measure_origin
-    >>> def get_the_origin(...):
-
-    will make the function also save those arrays as the measured origin in the
-    calibration associated with the data used for the measurement. Any existing
-    measured origin value will be overwritten.
-
-    For the wrapper to work, the decorated function's first argument must have
-    a .calibration property, and its first two return values must be qx0,qy0.
-    """
-    @functools.wraps(fun)
-    def wrapper(*args,**kwargs):
-        ans = fun(*args,**kwargs)
-        data = args[0]
-        cali = data.calibration
-        cali.set_origin_meas((ans[0],ans[1]))
-        return ans
-    return wrapper
-
-
-def set_fit_origin(fun):
-    """
-    See docstring for `set_measured_origin`
-    """
-    @functools.wraps(fun)
-    def wrapper(*args,**kwargs):
-        ans = fun(*args,**kwargs)
-        data = args[0]
-        cali = data.calibration
-        cali.set_origin((ans[0],ans[1]))
-        return ans
-    return wrapper
-
+#
+# # origin setting decorators
+#
+# def set_measured_origin(fun):
+#     """
+#     This is intended as a decorator function to wrap other functions which measure
+#     the position of the origin. If some function `get_the_origin` returns the
+#     position of the origin as a tuple of two (R_Nx,R_Ny)-shaped arrays, then
+#     decorating the function definition like
+#
+#     >>> @measure_origin
+#     >>> def get_the_origin(...):
+#
+#     will make the function also save those arrays as the measured origin in the
+#     calibration associated with the data used for the measurement. Any existing
+#     measured origin value will be overwritten.
+#
+#     For the wrapper to work, the decorated function's first argument must have
+#     a .calibration property, and its first two return values must be qx0,qy0.
+#     """
+#     @functools.wraps(fun)
+#     def wrapper(*args,**kwargs):
+#         ans = fun(*args,**kwargs)
+#         data = args[0]
+#         cali = data.calibration
+#         cali.set_origin_meas((ans[0],ans[1]))
+#         return ans
+#     return wrapper
+#
+#
+# def set_fit_origin(fun):
+#     """
+#     See docstring for `set_measured_origin`
+#     """
+#     @functools.wraps(fun)
+#     def wrapper(*args,**kwargs):
+#         ans = fun(*args,**kwargs)
+#         data = args[0]
+#         cali = data.calibration
+#         cali.set_origin((ans[0],ans[1]))
+#         return ans
+#     return wrapper
+#
@@ -73,11 +73,12 @@ def fit_origin(
 ):
     """
     Fits the position of the origin of diffraction space to a plane or parabola,
-    given some 2D arrays (qx0_meas,qy0_meas) of measured center positions, optionally
-    masked by the Boolean array `mask`. The 2D data arrays may be passed directly as
-    a 2-tuple to the arg `data`, or, if `data` is either a DataCube or Calibration
-    instance, they will be retreived automatically. If a DataCube or Calibration are
-    passed, fitted origin and residuals are stored there directly.
+    given some 2D arrays (qx0_meas,qy0_meas) of measured center positions,
+    optionally masked by the Boolean array `mask`. The 2D data arrays may be
+    passed directly as a 2-tuple to the arg `data`, or, if `data` is either a
+    DataCube or Calibration instance, they will be retrieved automatically. If a
+    DataCube or Calibration are passed, fitted origin and residuals are stored
+    there directly.
Args: data (2-tuple of 2d arrays): the measured origin position (qx0,qy0) @@ -133,14 +134,14 @@ def fit_origin( # Fit data if mask is None: - popt_x, pcov_x, qx0_fit = fit_2D( + popt_x, pcov_x, qx0_fit, _ = fit_2D( f, qx0_meas, robust=robust, robust_steps=robust_steps, robust_thresh=robust_thresh, ) - popt_y, pcov_y, qy0_fit = fit_2D( + popt_y, pcov_y, qy0_fit, _ = fit_2D( f, qy0_meas, robust=robust, @@ -149,7 +150,7 @@ def fit_origin( ) else: - popt_x, pcov_x, qx0_fit = fit_2D( + popt_x, pcov_x, qx0_fit, _ = fit_2D( f, qx0_meas, robust=robust, @@ -157,7 +158,7 @@ def fit_origin( robust_thresh=robust_thresh, data_mask=mask == True, ) - popt_y, pcov_y, qy0_fit = fit_2D( + popt_y, pcov_y, qy0_fit, _ = fit_2D( f, qy0_meas, robust=robust, @@ -358,4 +359,3 @@ def get_origin_beamstop(datacube: DataCube, mask: np.ndarray, **kwargs): return qx0, qy0 - diff --git a/py4DSTEM/process/calibration/probe.py b/py4DSTEM/process/calibration/probe.py index d77f88500..f42a353b6 100644 --- a/py4DSTEM/process/calibration/probe.py +++ b/py4DSTEM/process/calibration/probe.py @@ -34,6 +34,12 @@ def get_probe_size(DP, thresh_lower=0.01, thresh_upper=0.99, N=100): * **x0**: *(float)* the x position of the central disk center * **y0**: *(float)* the y position of the central disk center """ + from py4DSTEM.braggvectors import Probe + + # parse input + if isinstance(DP,Probe): + DP = DP.probe + thresh_vals = np.linspace(thresh_lower, thresh_upper, N) r_vals = np.zeros(N) diff --git a/py4DSTEM/process/classification/classutils.py b/py4DSTEM/process/classification/classutils.py index 1a7bccfda..bd4d0a053 100644 --- a/py4DSTEM/process/classification/classutils.py +++ b/py4DSTEM/process/classification/classutils.py @@ -3,7 +3,7 @@ import numpy as np from emdfile import tqdmnd, PointListArray -from py4DSTEM.classes import DataCube +from py4DSTEM.datacube import DataCube from py4DSTEM.process.utils import get_shifted_ar def get_class_DP(datacube, class_image, thresh=0.01, xshifts=None, yshifts=None, diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py index 228602692..1c4ac9073 100644 --- a/py4DSTEM/process/diffraction/crystal.py +++ b/py4DSTEM/process/diffraction/crystal.py @@ -898,12 +898,25 @@ def calculate_bragg_peak_histogram( k = np.arange(k_min, k_max + k_step, k_step) k_num = k.shape[0] - # experimental data histogram + # set rotate and ellipse based on their availability + rotate = bragg_peaks.calibration.get_QR_rotation_degrees() + ellipse = bragg_peaks.calibration.get_ellipse() + rotate = False if rotate is None else True + ellipse = False if ellipse is None else True + + # concatenate all peaks bigpl = np.concatenate( [ - bragg_peaks.cal[i, j].data - for i in range(bragg_peaks.shape[0]) - for j in range(bragg_peaks.shape[1]) + bragg_peaks.get_vectors( + rx, + ry, + center = True, + ellipse = ellipse, + pixel = True, + rotate = rotate, + ).data + for rx in range(bragg_peaks.shape[0]) + for ry in range(bragg_peaks.shape[1]) ] ) qr = np.sqrt(bigpl["qx"] ** 2 + bigpl["qy"] ** 2) diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py index b303ee916..bffa5b620 100644 --- a/py4DSTEM/process/diffraction/crystal_ACOM.py +++ b/py4DSTEM/process/diffraction/crystal_ACOM.py @@ -4,10 +4,12 @@ from typing import Union, Optional from emdfile import tqdmnd, PointList, PointListArray -from py4DSTEM.classes import RealSlice +from py4DSTEM.data import RealSlice from py4DSTEM.process.diffraction.utils import Orientation, OrientationMap, 
axisEqual3D
 from py4DSTEM.process.utils import electron_wavelength_angstrom

+from warnings import warn
+
 from numpy.linalg import lstsq
 try:
     import cupy as cp
@@ -767,6 +769,18 @@ def match_orientations(
         num_x=bragg_peaks_array.shape[0],
         num_y=bragg_peaks_array.shape[1],
         num_matches=num_matches_return)
+
+    # check cal state
+    if bragg_peaks_array.calstate['ellipse'] == False:
+        ellipse = False
+        warn('bragg peaks not elliptically calibrated')
+    else:
+        ellipse = True
+    if bragg_peaks_array.calstate['rotate'] == False:
+        rotate = False
+        warn('bragg peaks not rotationally calibrated')
+    else:
+        rotate = True

     for rx, ry in tqdmnd(
         *bragg_peaks_array.shape,
@@ -774,9 +788,17 @@
         unit=" PointList",
         disable=not progress_bar,
     ):
+        vectors = bragg_peaks_array.get_vectors(
+            scan_x=rx,
+            scan_y=ry,
+            center=True,
+            ellipse=ellipse,
+            pixel=True,
+            rotate=rotate
+        )
         orientation = self.match_single_pattern(
-            bragg_peaks_array.get_pointlist(rx, ry),
+            bragg_peaks=vectors,
             num_matches_return=num_matches_return,
             min_number_peaks=min_number_peaks,
             inversion_symmetry=inversion_symmetry,
@@ -1616,9 +1638,10 @@ def calculate_strain(
     # Initialize empty strain maps
     strain_map = RealSlice(
         data=np.zeros((
+            5,
             bragg_peaks_array.shape[0],
-            bragg_peaks_array.shape[1],
-            5)),
+            bragg_peaks_array.shape[1]
+        )),
         slicelabels=('e_xx','e_yy','e_xy','theta','mask'),
         name='strain_map')
     if mask_from_corr:
@@ -1638,6 +1661,18 @@
         corr_kernel_size = self.orientation_kernel_size
     radius_max_2 = corr_kernel_size**2

+    # check cal state
+    if bragg_peaks_array.calstate['ellipse'] == False:
+        ellipse = False
+        warn('bragg peaks not elliptically calibrated')
+    else:
+        ellipse = True
+    if bragg_peaks_array.calstate['rotate'] == False:
+        rotate = False
+        warn('bragg peaks not rotationally calibrated')
+    else:
+        rotate = True
+
     # Loop over all probe positions
     for rx, ry in tqdmnd(
         *bragg_peaks_array.shape,
@@ -1646,7 +1681,14 @@
         unit=" PointList",
         disable=not progress_bar,
     ):
         # Get bragg peaks from experiment and reference
-        p = bragg_peaks_array.get_pointlist(rx,ry)
+        p = bragg_peaks_array.get_vectors(
+            scan_x=rx,
+            scan_y=ry,
+            center=True,
+            ellipse=ellipse,
+            pixel=True,
+            rotate=rotate
+        )

         if p.data.shape[0] >= min_num_peaks:
             p_ref = self.generate_diffraction_pattern(
@@ -2069,5 +2111,4 @@
     }
     # "-3m": ["fiber", [0, 0, 1], [90.0, 60.0]],
-    # "-3m": ["fiber", [0, 0, 1], [180.0, 30.0]],
-
+    # "-3m": ["fiber", [0, 0, 1], [180.0, 30.0]],
\ No newline at end of file
diff --git a/py4DSTEM/process/diffraction/crystal_calibrate.py b/py4DSTEM/process/diffraction/crystal_calibrate.py
index 2d08cd03c..1b65480f5 100644
--- a/py4DSTEM/process/diffraction/crystal_calibrate.py
+++ b/py4DSTEM/process/diffraction/crystal_calibrate.py
@@ -24,7 +24,7 @@ def calibrate_pixel_size(
     k_step = 0.002,
     k_broadening = 0.002,
     fit_all_intensities = True,
-    set_calibration = True,
+    set_calibration_in_place = False,
     verbose = True,
     plot_result = False,
     figsize: Union[list, tuple, np.ndarray] = (12, 6),
@@ -60,8 +60,13 @@
         figsize (list, tuple, np.ndarray): Figure size of the plot.
         returnfig (bool): Return handles figure and axis

-    Returns:
-        fig, ax (handles): Optional figure and axis handles, if returnfig=True.
+    Returns
+    -------
+    fig, ax: handles, optional
+        Figure and axis handles, if returnfig=True.
""" @@ -112,17 +117,21 @@ def fit_profile(k, *coefs): # Get the answer pix_size_prev = bragg_peaks.calibration.get_Q_pixel_size() - ans = pix_size_prev / scale_pixel_size + pixel_size_new = pix_size_prev / scale_pixel_size - # if requested, apply calibrations - if set_calibration: - bragg_peaks.calibration.set_Q_pixel_size( ans ) + # if requested, apply calibrations in place + if set_calibration_in_place: + bragg_peaks.calibration.set_Q_pixel_size( pixel_size_new ) bragg_peaks.calibration.set_Q_pixel_units('A^-1') - bragg_peaks.setcal() - # Output + # Output calibrated Bragg peaks + bragg_peaks_cali = bragg_peaks.copy() + bragg_peaks_cali.calibration.set_Q_pixel_size( pixel_size_new ) + bragg_peaks_cali.calibration.set_Q_pixel_units('A^-1') + + # Output pixel size if verbose: - print(f"Calibrated pixel size = {np.round(ans, decimals=8)} A^-1") + print(f"Calibrated pixel size = {np.round(pixel_size_new, decimals=8)} A^-1") # Plotting if plot_result: @@ -163,9 +172,9 @@ def fit_profile(k, *coefs): # return if returnfig and plot_result: - return ans, (fig,ax) + return bragg_peaks_cali, (fig,ax) else: - return ans + return bragg_peaks_cali @@ -463,4 +472,4 @@ def fitfun(self, k, *coefs_fit): "432": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], #Cubic "-43m": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], #Cubic "m-3m": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], #Cubic - } + } \ No newline at end of file diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py index 9c1f5b667..a9420fee4 100644 --- a/py4DSTEM/process/diffraction/crystal_viz.py +++ b/py4DSTEM/process/diffraction/crystal_viz.py @@ -296,30 +296,48 @@ def plot_scattering_intensity( bragg_k_power=0.0, bragg_intensity_power=1.0, bragg_k_broadening=0.005, - figsize: Union[list, tuple, np.ndarray] = (12, 6), + figsize: Union[list, tuple, np.ndarray] = (10, 4), returnfig: bool = False, ): """ 1D plot of the structure factors - Args: - k_min (float): min k value for profile range. - k_max (float): max k value for profile range. - k_step (float): step size of k in profile range. - k_broadening (float): Broadening of simulated pattern. - k_power_scale (float): Scale SF intensities by k**k_power_scale. - int_power_scale (float): Scale SF intensities**int_power_scale. - int_scale (float): Scale output profile by this value. - remove_origin (bool): Remove origin from plot. - bragg_peaks (BraggVectors): Passed in bragg_peaks for comparison with simulated pattern. - bragg_k_power (float): bragg_peaks scaled by k**bragg_k_power. - bragg_intensity_power (float): bragg_peaks scaled by intensities**bragg_intensity_power. - bragg_k_broadening float): Broadening applied to bragg_peaks. - figsize (list, tuple, np.ndarray): Figure size for plot. - returnfig (bool): Return figure and axes handles if this is True. - - Returns: - fig, ax (optional) figure and axes handles + Parameters + -------- + + k_min: float + min k value for profile range. + k_max: float + max k value for profile range. + k_step: float + Step size of k in profile range. + k_broadening: float + Broadening of simulated pattern. + k_power_scale: float + Scale SF intensities by k**k_power_scale. + int_power_scale: float + Scale SF intensities**int_power_scale. + int_scale: float + Scale output profile by this value. + remove_origin: bool + Remove origin from plot. + bragg_peaks: BraggVectors + Passed in bragg_peaks for comparison with simulated pattern. 
+    bragg_k_power: float
+        bragg_peaks scaled by k**bragg_k_power.
+    bragg_intensity_power: float
+        bragg_peaks scaled by intensities**bragg_intensity_power.
+    bragg_k_broadening: float
+        Broadening applied to bragg_peaks.
+    figsize: list, tuple, np.ndarray
+        Figure size for plot.
+    returnfig: bool
+        Return figure and axes handles if this is True.
+
+    Returns
+    --------
+    fig, ax (optional)
+        figure and axes handles
     """

     # k coordinates
@@ -342,12 +360,25 @@
     # If Bragg peaks are passed in, compute 1D integral
     if bragg_peaks is not None:
+
+        # set rotate and ellipse based on their availability
+        rotate = bragg_peaks.calibration.get_QR_rotation_degrees()
+        ellipse = bragg_peaks.calibration.get_ellipse()
+        rotate = False if rotate is None else True
+        ellipse = False if ellipse is None else True
+
         # concatenate all peaks
         bigpl = np.concatenate(
             [
-                bragg_peaks.cal[i, j].data
-                for i in range(bragg_peaks.shape[0])
-                for j in range(bragg_peaks.shape[1])
+                bragg_peaks.get_vectors(
+                    rx,
+                    ry,
+                    center = True,
+                    ellipse = ellipse,
+                    pixel = True,
+                    rotate = rotate,
+                ).data
+                for rx in range(bragg_peaks.shape[0])
+                for ry in range(bragg_peaks.shape[1])
             ]
         )
@@ -903,6 +934,9 @@ def plot_diffraction_pattern(
     ax.set_ylabel("$q_x$ [Å$^{-1}$]")

     if plot_range_kx_ky is not None:
+        plot_range_kx_ky = np.array(plot_range_kx_ky)
+        if plot_range_kx_ky.ndim == 0:
+            plot_range_kx_ky = np.array((plot_range_kx_ky,plot_range_kx_ky))
         ax.set_xlim((-plot_range_kx_ky[0], plot_range_kx_ky[0]))
         ax.set_ylim((-plot_range_kx_ky[1], plot_range_kx_ky[1]))
     else:
@@ -1846,4 +1880,4 @@
         plt.show()

     if returnfig:
-        return fig, ax
+        return fig, ax
\ No newline at end of file
diff --git a/py4DSTEM/process/diffraction/flowlines.py b/py4DSTEM/process/diffraction/flowlines.py
index c2fa224fa..cf84f69f5 100644
--- a/py4DSTEM/process/diffraction/flowlines.py
+++ b/py4DSTEM/process/diffraction/flowlines.py
@@ -15,7 +15,7 @@

 def make_orientation_histogram(
-    bragg_peaks: PointList = None,
+    bragg_peaks: PointListArray = None,
     radial_ranges: np.ndarray = None,
     orientation_map = None,
     orientation_ind: int = 0,
@@ -519,6 +519,7 @@ def make_flowline_rainbow_image(
     power_scaling = 1.0,
     sum_radial_bins = False,
     plot_images = True,
+    figsize = None,
     ):
     """
     Generate RGB output images from the flowline arrays.
@@ -535,6 +536,7 @@
         power_scaling (float): Power law scaling for flowline intensity output.
         sum_radial_bins (bool): Sum all radial bins (alternative is to output separate images).
         plot_images (bool): Plot the outputs for quick visualization.
+        figsize (2-tuple): Size of output figure.

     Returns:
         im_flowline (array): 3D or 4D array containing flowline images
@@ -613,7 +615,14 @@
         im_flowline = np.min(im_flowline,axis=0)[None,:,:,:]

     if plot_images is True:
-        fig,ax = plt.subplots(im_flowline.shape[0],1,figsize=(10,im_flowline.shape[0]*10))
+        if figsize is None:
+            fig,ax = plt.subplots(
+                im_flowline.shape[0],1,
+                figsize=(10,im_flowline.shape[0]*10))
+        else:
+            fig,ax = plt.subplots(
+                im_flowline.shape[0],1,
+                figsize=figsize)

         if im_flowline.shape[0] > 1:
             for a0 in range(im_flowline.shape[0]):
@@ -729,6 +738,7 @@ def make_flowline_combined_image(
     power_scaling = 1.0,
     sum_radial_bins = True,
     plot_images = True,
+    figsize = None,
     ):
     """
     Generate RGB output images from the flowline arrays.
@@ -742,6 +752,7 @@
         power_scaling (float): Power law scaling for flowline intensities.
sum_radial_bins (bool): Sum outputs over radial bins. plot_images (bool): Plot the output images for quick visualization. + figsize (2-tuple): Size of output figure. Returns: im_flowline (array): flowline images @@ -787,7 +798,14 @@ def make_flowline_combined_image( if plot_images is True: - fig,ax = plt.subplots(im_flowline.shape[0],1,figsize=(10,im_flowline.shape[0]*10)) + if figsize is None: + fig,ax = plt.subplots( + im_flowline.shape[0],1, + figsize=(10,im_flowline.shape[0]*10)) + else: + fig,ax = plt.subplots( + im_flowline.shape[0],1, + figsize=figsize) if im_flowline.shape[0] > 1: for a0 in range(im_flowline.shape[0]): @@ -1143,4 +1161,4 @@ def set_intensity(orient,xy_t_int): mode=['clip','clip','wrap']) orient.ravel()[inds_1D] = orient.ravel()[inds_1D] + xy_t_int[:,3]*( dx)*( dy)*( dt) - return orient + return orient \ No newline at end of file diff --git a/py4DSTEM/process/diskdetection/__init__.py b/py4DSTEM/process/diskdetection/__init__.py deleted file mode 100644 index 76b0ea25c..000000000 --- a/py4DSTEM/process/diskdetection/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from py4DSTEM.process.diskdetection.braggvectors import BraggVectors -from py4DSTEM.process.diskdetection.braggvector_methods import BraggVectorMap -from py4DSTEM.process.diskdetection.diskdetection import * - -#from .diskdetection_aiml import * -#from .diskdetection_parallel_new import * - diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py index e3bb4d86e..32809ddb1 100644 --- a/py4DSTEM/process/fit/fit.py +++ b/py4DSTEM/process/fit/fit.py @@ -29,44 +29,49 @@ def fit_1D_gaussian(xdata,ydata,xmin,xmax): A,mu,sigma = scale*popt[0],popt[1],popt[2] return A,mu,sigma -def fit_2D(function, data, data_mask=None, popt=None, - robust=False, robust_steps=3, robust_thresh=2): +def fit_2D( + function, + data, + data_mask=None, + popt=None, + robust=False, + robust_steps=3, + robust_thresh=2, + ): """ - Performs a 2D fit, where the fit function takes its first input in the form of a - length 2 vector (ndarray) of (x,y) positions, followed by the remaining parameters, - and the data to fit takes the form of an (n,m) shaped array. Robust fitting can be - enabled to iteratively reject outlier data points, which have a root-mean-square - error beyond the user-specified threshold. - - Args: - function: First input should be a length 2 array xy, where (xy[0],xy[1]) are the - (x,y) coordinates - data: Data to fit, in an (n,m) shaped ndarray - data_mask: Optional parameter. If specified, must be a boolean array of the same - shape as data, specifying which elements of data to use in the fit - return_ar: Optional parameter. If False, only the fit parameters and covariance - matrix are returned. If True, return an array of the same shape as data with - the fit values. Defaults to True - popt: Optional parameter for input. If specified, should be a tuple of initial - guesses for the fit parameters. - robust: Optional parameter. If set to True, fit will be repeated with outliers - removed. - robust_steps: Optional parameter. Number of robust iterations performed after - initial fit. - robust_thresh: Optional parameter. Threshold for including points, in units of - root-mean-square (standard deviations) error of the predicted values after - fitting. + Performs a 2D fit. 
- Returns: - (3-tuple) A 3-tuple containing: + TODO: make returning the mask optional + + Parameters + ---------- + function : callable + Some `function( xy, **p)` where `xy` is a length 2 vector (1D np array) + specifying the pixel position (x,y), and `p` is the function parameters + data : ndarray + Some 2D array of any shape (n,m) + data_mask : None or boolean array of shape (n,m), optional + If specified, fits only the pixels in `data` where this array is True + popt : tuple or ndarray, optional + Initial guess at the parameters `p` of `function`. Note that positions + in pixels (i.e. the xy positions) are linearly scaled to the space [0,1] + robust : bool + Toggles robust fitting, which iteratively rejects outlier data points + which have a root-mean-square error beyond `robust_thresh` + robust_steps : int + The number of robust fitting iterations to perform + robust_thresh : number + The robust fitting threshold, in units of the fit's root-mean-square error - * **popt**: optimal fit parameters to function - * **pcov**: the covariance matrix - * **fit_ar**: optional. If return_ar==True, fit_ar is returned, and is an - array of the same shape as data, containing the fit values + Returns + ------- + (popt, pcov, fit_ar, mask) : 4-tuple + The optimal fit parameters, the fitting covariance matrix, the + fit array with the returned `popt` params, and the mask """ + # get shape shape = data.shape shape1D = [1,np.prod(shape)] + # x and y coordinates normalized from 0 to 1 x,y = np.linspace(0, 1, shape[0]),np.linspace(0, 1, shape[1]) ry,rx = np.meshgrid(y,x) @@ -78,8 +83,10 @@ def fit_2D(function, data, data_mask=None, popt=None, if robust==False: robust_steps=0 - # least squares fitting - 1st iteration + # least squares fitting for k in range(robust_steps+1): + + # in 1st iteration, set up params and mask if k == 0: if popt is None: popt = np.zeros((1,len(signature(function).parameters)-1)) @@ -87,21 +94,26 @@ def fit_2D(function, data, data_mask=None, popt=None, mask = data_mask else: mask = np.ones(shape,dtype=bool) + + # otherwise, get fitting error and add high error pixels to mask else: - fit_mean_square_error = (function(xy,*popt).reshape(shape) - data)**2 - mask = fit_mean_square_error <= np.mean(fit_mean_square_error) * robust_thresh**2 - # include user-specified mask if provided - if data_mask is not None: - mask[data_mask==False] = False + fit_mean_square_error = ( + function(xy,*popt).reshape(shape) - data)**2 + _mask = fit_mean_square_error > np.mean( + fit_mean_square_error) * robust_thresh**2 + mask[_mask] = False # perform fitting - popt, pcov = curve_fit(function, - np.vstack((rx_1D[mask.reshape(shape1D)],ry_1D[mask.reshape(shape1D)])), + popt, pcov = curve_fit( + function, + np.vstack(( + rx_1D[mask.reshape(shape1D)], + ry_1D[mask.reshape(shape1D)])), data[mask], p0=popt) fit_ar = function(xy,*popt).reshape(shape) - return popt, pcov, fit_ar + return popt, pcov, fit_ar, mask # Functions for fitting @@ -127,5 +139,133 @@ def bezier_two(xy, c00, c01, c02, c10, c11, c12, c20, c21, c22): c12*2*(1-xy[0])*xy[0] * (xy[1]**2) + \ c22 *(xy[0]**2) * (xy[1]**2) +def polar_gaussian_2D( + tq, + I0, + mu_t, + mu_q, + sigma_t, + sigma_q, + C, + ): + # unpack position + t,q = tq + # set theta value to its closest periodic reflection to mu_t + #t = np.square(t-mu_t) + #t2 = np.min(np.vstack([t,1-t])) + t2 = np.square(t-mu_t) + return \ + I0 * np.exp( + - ( t2/(2*sigma_t**2) + \ + (q-mu_q)**2/(2*sigma_q**2) ) ) + C
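For reference, a minimal sketch of calling the updated fit_2D, which now returns the robust-fitting mask as a fourth element. The plane model and the synthetic data here are illustrative only and not part of this changeset:

import numpy as np
from py4DSTEM.process.fit.fit import fit_2D

# an illustrative fit model, a tilted plane f(x,y) = a + b*x + c*y;
# fit_2D passes `xy` as a length-2 stack of (x,y) coordinates scaled to [0,1]
def plane(xy, a, b, c):
    return a + b * xy[0] + c * xy[1]

# synthetic 64x64 data: a plane plus noise and one outlier
data = 1 + 2 * np.linspace(0, 1, 64)[:, None] + np.random.normal(0, 0.05, (64, 64))
data[10, 10] = 50  # outlier, rejected by the robust iterations

# note the 4-tuple return - the robust mask is now included
popt, pcov, fit_ar, mask = fit_2D(plane, data, robust=True, robust_steps=2, robust_thresh=2)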
+ + +def polar_twofold_gaussian_2D( + tq, + I0, + mu_t, + mu_q, + sigma_t, + sigma_q, + ): + + # unpack position + t,q = tq + + # theta periodicity + dt = np.mod(t - mu_t + np.pi/2, np.pi) - np.pi/2 + + # output intensity + return I0 * np.exp( + (dt**2 / (-2.0*sigma_t**2)) + \ + ((q - mu_q)**2 / (-2.0*sigma_q**2)) ) + +def polar_twofold_gaussian_2D_background( + tq, + I0, + mu_t, + mu_q, + sigma_t, + sigma_q, + C, + ): + + # unpack position + t,q = tq + # theta periodicity + dt = np.mod(t - mu_t + np.pi/2, np.pi) - np.pi/2 + + # output intensity + return C + I0 * np.exp( + (dt**2 / (-2.0*sigma_t**2)) + \ + ((q - mu_q)**2 / (-2.0*sigma_q**2)) ) + + +def fit_2D_polar_gaussian( + data, + mask = None, + p0 = None, + robust = False, + robust_steps = 3, + robust_thresh = 2, + constant_background = False, + ): + """ + + NOTE - this cannot work without using pixel coordinates - something is wrong in the workflow. + + + Fits a 2D gaussian to the pixels in `data` which are set to True in `mask`. + + The gaussian is anisotropic and oriented along (t,q), centered at + (mu_t,mu_q), has standard deviations (sigma_t,sigma_q), maximum of I0, + and an optional constant offset of C, and is periodic in t. + + f(x,y) = I0 * exp( -( (x-mu_x)^2/(2sig_x^2) + (y-mu_y)^2/(2sig_y^2) ) ) + or + f(x,y) = I0 * exp( -( (x-mu_x)^2/(2sig_x^2) + (y-mu_y)^2/(2sig_y^2) ) ) + C + + Parameters + ---------- + data : 2d array + the data to fit + p0 : 5- or 6-tuple + initial guess at fit parameters, (I0,mu_t,mu_q,sigma_t,sigma_q), + plus a trailing C if constant_background is True + mask : 2d boolean array + ignore pixels where mask is False + robust : bool + toggle robust fitting + robust_steps : int + number of robust fit iterations + robust_thresh : number + the robust fitting threshold + constant_background : bool + whether or not to include constant background + + Returns + ------- + (popt,pcov,fit_ar,mask) : 4-tuple + the optimal fit parameters, the covariance matrix, the fit array, + and the robust-fitting mask + """ + if constant_background: + return fit_2D( + polar_twofold_gaussian_2D_background, + data = data, + data_mask = mask, + popt = p0, + robust = robust, + robust_steps = robust_steps, + robust_thresh = robust_thresh + ) + else: + return fit_2D( + polar_twofold_gaussian_2D, + data = data, + data_mask = mask, + popt = p0, + robust = robust, + robust_steps = robust_steps, + robust_thresh = robust_thresh + ) diff --git a/py4DSTEM/process/latticevectors/fit.py b/py4DSTEM/process/latticevectors/fit.py index dae5d485f..fef72aca3 100644 --- a/py4DSTEM/process/latticevectors/fit.py +++ b/py4DSTEM/process/latticevectors/fit.py @@ -4,7 +4,7 @@ from numpy.linalg import lstsq from emdfile import tqdmnd, PointList, PointListArray -from py4DSTEM.classes import RealSlice +from py4DSTEM.data import RealSlice def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5): """ @@ -104,13 +104,22 @@ def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5): # Make RealSlice to contain outputs slicelabels = ('x0','y0','g1x','g1y','g2x','g2y','error','mask') - g1g2_map = RealSlice(data=np.zeros((braggpeaks.shape[0],braggpeaks.shape[1],8)), - slicelabels=slicelabels, name='g1g2_map') + g1g2_map = RealSlice( + data=np.zeros( + (8, braggpeaks.shape[0],braggpeaks.shape[1]) + ), + slicelabels=slicelabels, name='g1g2_map' + ) # Fit lattice vectors for (Rx, Ry) in tqdmnd(braggpeaks.shape[0],braggpeaks.shape[1]): braggpeaks_curr = braggpeaks.get_pointlist(Rx,Ry) - qx0,qy0,g1x,g1y,g2x,g2y,error = fit_lattice_vectors(braggpeaks_curr, x0, y0, minNumPeaks) + qx0,qy0,g1x,g1y,g2x,g2y,error = fit_lattice_vectors( + braggpeaks_curr, + x0, + y0, + minNumPeaks + ) # Store data if g1x is not None: g1g2_map.get_slice('x0').data[Rx,Ry] = qx0 diff --git
a/py4DSTEM/process/latticevectors/index.py b/py4DSTEM/process/latticevectors/index.py index cdf6b00fd..189e7f10f 100644 --- a/py4DSTEM/process/latticevectors/index.py +++ b/py4DSTEM/process/latticevectors/index.py @@ -80,6 +80,9 @@ def index_bragg_directions(x0, y0, gx, gy, g1, g2): temp_array = np.zeros([], dtype = coords) bragg_directions = PointList(data = temp_array) bragg_directions.add_data_by_field((gx,gy,h,k)) + mask = np.zeros(bragg_directions['qx'].shape[0]) + mask[0] = 1 + bragg_directions.remove(mask) return h,k, bragg_directions @@ -152,8 +155,14 @@ def generate_lattice(ux,uy,vx,vy,x0,y0,Q_Nx,Q_Ny,h_max=None,k_max=None): return ideal_lattice -def add_indices_to_braggpeaks(braggpeaks, lattice, maxPeakSpacing, qx_shift=0, - qy_shift=0, mask=None): +def add_indices_to_braggvectors( + braggpeaks, + lattice, + maxPeakSpacing, + qx_shift=0, + qy_shift=0, + mask=None + ): """ Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice, identify the indices for each peak in the PointListArray braggpeaks. @@ -181,43 +190,41 @@ def add_indices_to_braggpeaks(braggpeaks, lattice, maxPeakSpacing, qx_shift=0, 'h', 'k', containing the indices of each indexable peak. """ - assert isinstance(braggpeaks,PointListArray) - assert np.all([name in braggpeaks.dtype.names for name in ('qx','qy','intensity')]) - assert isinstance(lattice, PointList) - assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) + # assert isinstance(braggpeaks,BraggVectors) + # assert isinstance(lattice, PointList) + # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')]) if mask is None: - mask = np.ones(braggpeaks.shape,dtype=bool) + mask = np.ones(braggpeaks.Rshape,dtype=bool) - assert mask.shape == braggpeaks.shape, 'mask must have same shape as pointlistarray' + assert mask.shape == braggpeaks.Rshape, 'mask must have same shape as pointlistarray' assert mask.dtype == bool, 'mask must be boolean' - indexed_braggpeaks = braggpeaks.copy() - # add the coordinates if they don't exist - if not ('h' in braggpeaks.dtype.names): - indexed_braggpeaks = indexed_braggpeaks.add_fields([('h',int)]) - if not ('k' in braggpeaks.dtype.names): - indexed_braggpeaks = indexed_braggpeaks.add_fields([('k',int)]) + coords = [('qx',float),('qy',float),('intensity',float),('h',int),('k',int)] + + indexed_braggpeaks = PointListArray( + dtype = coords, + shape = braggpeaks.Rshape, + ) # loop over all the scan positions for Rx, Ry in tqdmnd(mask.shape[0],mask.shape[1]): - if mask[Rx,Ry]: - pl = indexed_braggpeaks.get_pointlist(Rx,Ry) - rm_peak_mask = np.zeros(pl.length,dtype=bool) - - for i in range(pl.length): + if mask[Rx,Ry]: + pl = braggpeaks.cal[Rx,Ry] + for i in range(pl.data.shape[0]): r2 = (pl.data['qx'][i]-lattice.data['qx'] + qx_shift)**2 + \ (pl.data['qy'][i]-lattice.data['qy'] + qy_shift)**2 ind = np.argmin(r2) if r2[ind] <= maxPeakSpacing**2: - pl.data['h'][i] = lattice.data['h'][ind] - pl.data['k'][i] = lattice.data['k'][ind] - else: - rm_peak_mask[i] = True - pl.remove(rm_peak_mask) + indexed_braggpeaks[Rx,Ry].add_data_by_field(( + pl.data['qx'][i], + pl.data['qy'][i], + pl.data['intensity'][i], + lattice.data['h'][ind], + lattice.data['k'][ind] + )) - indexed_braggpeaks.name = braggpeaks.name + "_indexed" return indexed_braggpeaks diff --git a/py4DSTEM/process/latticevectors/strain.py b/py4DSTEM/process/latticevectors/strain.py index 0521f703c..7a586bd69 100644 --- a/py4DSTEM/process/latticevectors/strain.py +++ b/py4DSTEM/process/latticevectors/strain.py @@ -3,7 +3,7 @@ 
import numpy as np from numpy.linalg import lstsq -from py4DSTEM.classes import RealSlice +from py4DSTEM.data import RealSlice def get_reference_g1g2(g1g2_map, mask): """ @@ -71,9 +71,11 @@ def get_strain_from_reference_g1g2(g1g2_map, g1, g2): # Get RealSlice for output storage R_Nx,R_Ny = g1g2_map.get_slice('g1x').shape - strain_map = RealSlice(data=np.zeros((R_Nx,R_Ny,5)), - slicelabels=('e_xx','e_yy','e_xy','theta','mask'), - name='strain_map') + strain_map = RealSlice( + data=np.zeros((5, R_Nx, R_Ny)), + slicelabels=('e_xx','e_yy','e_xy','theta','mask'), + name='strain_map' + ) # Get reference lattice matrix g1x,g1y = g1 @@ -130,7 +132,8 @@ def get_strain_from_reference_region(g1g2_map, mask): Note 1: the strain matrix has been symmetrized, so e_xy and e_yx are identical """ assert isinstance(g1g2_map, RealSlice) - assert np.all([name in g1g2_map.slicelabels for name in ('g1x','g1y','g2x','g2y','mask')]) + assert np.all( + [name in g1g2_map.slicelabels for name in ('g1x','g1y','g2x','g2y','mask')]) assert mask.dtype == bool g1,g2 = get_reference_g1g2(g1g2_map,mask) @@ -169,18 +172,20 @@ def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta): sint2 = sint**2 Rx,Ry = unrotated_strain_map.get_slice('e_xx').data.shape - rotated_strain_map = RealSlice(data=np.zeros((Rx,Ry,5)), - slicelabels=['e_xx','e_xy','e_yy','theta','mask'], - name=unrotated_strain_map.name+"_rotated".format(np.degrees(theta))) - - rotated_strain_map.data[:,:,0] = cost2*unrotated_strain_map.get_slice('e_xx').data - 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + sint2*unrotated_strain_map.get_slice('e_yy').data - rotated_strain_map.data[:,:,1] = cost*sint*(unrotated_strain_map.get_slice('e_xx').data-unrotated_strain_map.get_slice('e_yy').data) + (cost2-sint2)*unrotated_strain_map.get_slice('e_xy').data - rotated_strain_map.data[:,:,2] = sint2*unrotated_strain_map.get_slice('e_xx').data + 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + cost2*unrotated_strain_map.get_slice('e_yy').data + rotated_strain_map = RealSlice( + data=np.zeros((5, Rx,Ry)), + slicelabels=['e_xx','e_xy','e_yy','theta','mask'], + name=unrotated_strain_map.name+"_rotated".format(np.degrees(theta)) + ) + + rotated_strain_map.data[0,:,:] = cost2*unrotated_strain_map.get_slice('e_xx').data - 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + sint2*unrotated_strain_map.get_slice('e_yy').data + rotated_strain_map.data[1,:,:] = cost*sint*(unrotated_strain_map.get_slice('e_xx').data-unrotated_strain_map.get_slice('e_yy').data) + (cost2-sint2)*unrotated_strain_map.get_slice('e_xy').data + rotated_strain_map.data[2,:,:] = sint2*unrotated_strain_map.get_slice('e_xx').data + 2*cost*sint*unrotated_strain_map.get_slice('e_xy').data + cost2*unrotated_strain_map.get_slice('e_yy').data if flip_theta == True: - rotated_strain_map.data[:,:,3] = -unrotated_strain_map.get_slice('theta').data + rotated_strain_map.data[3,:,:] = -unrotated_strain_map.get_slice('theta').data else: - rotated_strain_map.data[:,:,3] = unrotated_strain_map.get_slice('theta').data - rotated_strain_map.data[:,:,4] = unrotated_strain_map.get_slice('mask').data + rotated_strain_map.data[3,:,:] = unrotated_strain_map.get_slice('theta').data + rotated_strain_map.data[4,:,:] = unrotated_strain_map.get_slice('mask').data return rotated_strain_map diff --git a/py4DSTEM/process/phase/.gitignore b/py4DSTEM/process/phase/.gitignore new file mode 100644 index 000000000..c97f963b3 --- /dev/null +++ b/py4DSTEM/process/phase/.gitignore @@ -0,0 +1 @@ +*.sh diff 
--git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py index c2e141399..178079349 100644 --- a/py4DSTEM/process/phase/__init__.py +++ b/py4DSTEM/process/phase/__init__.py @@ -1,3 +1,5 @@ +# fmt: off + _emd_hook = True from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction @@ -20,3 +22,9 @@ from py4DSTEM.process.phase.iterative_singleslice_ptychography import ( SingleslicePtychographicReconstruction, ) +from py4DSTEM.process.phase.parameter_optimize import ( + OptimizationParameter, + PtychographyOptimizer, +) + +# fmt: on diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py index c2acbee34..ae4c92d4b 100644 --- a/py4DSTEM/process/phase/iterative_base_class.py +++ b/py4DSTEM/process/phase/iterative_base_class.py @@ -3,7 +3,6 @@ """ import warnings -from typing import Sequence import matplotlib.pyplot as plt import numpy as np @@ -18,7 +17,8 @@ cp = None from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd -from py4DSTEM.classes import Calibration, DataCube +from py4DSTEM.data import Calibration +from py4DSTEM.datacube import DataCube from py4DSTEM.process.calibration import fit_origin from py4DSTEM.process.phase.iterative_ptychographic_constraints import ( PtychographicConstraints, @@ -142,6 +142,12 @@ def _preprocess_datacube_and_vacuum_probe( datacube: Datacube Resampled and Padded datacube """ + if com_shifts is not None: + if np.isscalar(com_shifts[0]): + com_shifts = ( + np.ones(self._datacube.Rshape) * com_shifts[0], + np.ones(self._datacube.Rshape) * com_shifts[1], + ) if diffraction_intensities_shape is not None: Qx, Qy = datacube.shape[-2:] @@ -155,8 +161,6 @@ def _preprocess_datacube_and_vacuum_probe( "Datacube calibration can only handle uniform Q-sampling." ) - Q_pixel_size = datacube.calibration.get_Q_pixel_size() - if com_shifts is not None: com_shifts = ( com_shifts[0] * resampling_factor_x, @@ -193,7 +197,6 @@ def _preprocess_datacube_and_vacuum_probe( output_size=diffraction_intensities_shape, force_nonnegative=True, ) - datacube.calibration.set_Q_pixel_size(Q_pixel_size / resampling_factor_x) if probe_roi_shape is not None: Qx, Qy = datacube.shape[-2:] @@ -221,6 +224,9 @@ def _extract_intensities_and_calibrations_from_datacube( self, datacube: DataCube, require_calibrations: bool = False, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, ): """ Method to extract intensities and calibrations from datacube. 
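For reference, a minimal sketch of how the new sampling overrides surface in the reconstruction classes' preprocess() methods later in this changeset; the object name ptycho and the numeric values are illustrative only:

# override missing or wrong DataCube calibrations at preprocess time,
# instead of editing the Calibration object by hand; note that
# force_angular_sampling and force_reciprocal_sampling are mutually
# exclusive, since either one fixes the other through the beam energy
ptycho = ptycho.preprocess(
    force_scan_sampling = 0.25,      # real-space scan step, in Angstrom
    force_angular_sampling = 0.12,   # reciprocal-space pixel size, in mrad
)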
@@ -231,6 +237,12 @@ def _extract_intensities_and_calibrations_from_datacube( Input 4D diffraction pattern intensities require_calibrations: bool If False, warning is issued instead of raising an error + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 Assigns -------- @@ -266,8 +278,7 @@ def _extract_intensities_and_calibrations_from_datacube( """ # Copies intensities to device casting to float32 - xp = self._xp - intensities = xp.asarray(datacube.data, dtype=xp.float32) + intensities = datacube.data self._grid_scan_shape = intensities.shape[:2] # Extracts calibrations @@ -276,80 +287,121 @@ def _extract_intensities_and_calibrations_from_datacube( reciprocal_space_units = calibration.get_Q_pixel_units() # Real-space - if real_space_units == "pixels": - if require_calibrations: - raise ValueError("Real-space calibrations must be given in 'A'") - - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "real-space calibrations in 'A'" - ), - UserWarning, - ) + if force_scan_sampling is not None: + self._scan_sampling = (force_scan_sampling, force_scan_sampling) + self._scan_units = "A" + else: + if real_space_units == "pixels": + if require_calibrations: + raise ValueError("Real-space calibrations must be given in 'A'") - self._scan_sampling = (1.0, 1.0) - self._scan_units = ("pixels",) * 2 + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "real-space calibrations in 'A'" + ), + UserWarning, + ) - elif real_space_units == "A": - self._scan_sampling = (calibration.get_R_pixel_size(),) * 2 - self._scan_units = ("A",) * 2 - elif real_space_units == "nm": - self._scan_sampling = (calibration.get_R_pixel_size() * 10,) * 2 - self._scan_units = ("A",) * 2 - else: - raise ValueError( - f"Real-space calibrations must be given in 'A', not {real_space_units}" - ) + self._scan_sampling = (1.0, 1.0) + self._scan_units = ("pixels",) * 2 - # Reciprocal-space - if reciprocal_space_units == "pixels": - if require_calibrations: + elif real_space_units == "A": + self._scan_sampling = (calibration.get_R_pixel_size(),) * 2 + self._scan_units = ("A",) * 2 + elif real_space_units == "nm": + self._scan_sampling = (calibration.get_R_pixel_size() * 10,) * 2 + self._scan_units = ("A",) * 2 + else: raise ValueError( - "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" + f"Real-space calibrations must be given in 'A', not {real_space_units}" ) - warnings.warn( - ( - "Iterative reconstruction will not be quantitative unless you specify " - "appropriate reciprocal-space calibrations" - ), - UserWarning, - ) + # Reciprocal-space + if force_angular_sampling is not None or force_reciprocal_sampling is not None: + # there is no xor keyword in Python! + angular = force_angular_sampling is not None + reciprocal = force_reciprocal_sampling is not None + assert (angular and not reciprocal) or ( + not angular and reciprocal + ), "Only one of angular or reciprocal calibration can be forced!" 
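The angular and reciprocal branches that follow are two views of the same calibration, related through the electron wavelength by q = theta / lambda. A numeric sketch of the conversion applied below, assuming a 300 kV beam and using the same electron_wavelength_angstrom helper this module already imports:

lam = electron_wavelength_angstrom(300e3)            # ~0.0197 A at 300 kV
angular_sampling = 0.1                               # mrad
reciprocal_sampling = angular_sampling / lam / 1e3   # ~5.1e-3 A^-1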
+ + # angular calibration specified + if angular: + self._angular_sampling = (force_angular_sampling,) * 2 + self._angular_units = ("mrad",) * 2 - self._angular_sampling = (1.0, 1.0) - self._angular_units = ("pixels",) * 2 - self._reciprocal_sampling = (1.0, 1.0) - self._reciprocal_units = ("pixels",) * 2 + if self._energy is not None: + self._reciprocal_sampling = ( + force_angular_sampling + / electron_wavelength_angstrom(self._energy) + / 1e3, + ) * 2 + self._reciprocal_units = ("A^-1",) * 2 + + # reciprocal calibration specified + if reciprocal: + self._reciprocal_sampling = (force_reciprocal_sampling,) * 2 + self._reciprocal_units = ("A^-1",) * 2 - elif reciprocal_space_units == "A^-1": - reciprocal_size = calibration.get_Q_pixel_size() - self._reciprocal_sampling = (reciprocal_size,) * 2 - self._reciprocal_units = ("A^-1",) * 2 + if self._energy is not None: + self._angular_sampling = ( + force_reciprocal_sampling + * electron_wavelength_angstrom(self._energy) + * 1e3, + ) * 2 + self._angular_units = ("mrad",) * 2 - if self._energy is not None: - self._angular_sampling = ( - reciprocal_size * electron_wavelength_angstrom(self._energy) * 1e3, - ) * 2 - self._angular_units = ("mrad",) * 2 + else: + if reciprocal_space_units == "pixels": + if require_calibrations: + raise ValueError( + "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'" + ) + + warnings.warn( + ( + "Iterative reconstruction will not be quantitative unless you specify " + "appropriate reciprocal-space calibrations" + ), + UserWarning, + ) - elif reciprocal_space_units == "mrad": - angular_size = calibration.get_Q_pixel_size() - self._angular_sampling = (angular_size,) * 2 - self._angular_units = ("mrad",) * 2 + self._angular_sampling = (1.0, 1.0) + self._angular_units = ("pixels",) * 2 + self._reciprocal_sampling = (1.0, 1.0) + self._reciprocal_units = ("pixels",) * 2 - if self._energy is not None: - self._reciprocal_sampling = ( - angular_size / electron_wavelength_angstrom(self._energy) / 1e3, - ) * 2 + elif reciprocal_space_units == "A^-1": + reciprocal_size = calibration.get_Q_pixel_size() + self._reciprocal_sampling = (reciprocal_size,) * 2 self._reciprocal_units = ("A^-1",) * 2 - else: - raise ValueError( - ( - "Reciprocal-space calibrations must be given in 'A^-1' or 'mrad', " - f"not {reciprocal_space_units}" + + if self._energy is not None: + self._angular_sampling = ( + reciprocal_size + * electron_wavelength_angstrom(self._energy) + * 1e3, + ) * 2 + self._angular_units = ("mrad",) * 2 + + elif reciprocal_space_units == "mrad": + angular_size = calibration.get_Q_pixel_size() + self._angular_sampling = (angular_size,) * 2 + self._angular_units = ("mrad",) * 2 + + if self._energy is not None: + self._reciprocal_sampling = ( + angular_size / electron_wavelength_angstrom(self._energy) / 1e3, + ) * 2 + self._reciprocal_units = ("A^-1",) * 2 + else: + raise ValueError( + ( + "Reciprocal-space calibrations must be given in 'A^-1' or 'mrad', " + f"not {reciprocal_space_units}" + ) ) - ) return intensities @@ -359,6 +411,7 @@ def _calculate_intensities_center_of_mass( dp_mask: np.ndarray = None, fit_function: str = "plane", com_shifts: np.ndarray = None, + com_measured: np.ndarray = None, ): """ Common preprocessing function to compute and fit diffraction intensities CoM @@ -371,9 +424,10 @@ def _calculate_intensities_center_of_mass( If not None, apply mask to datacube amplitude fit_function: str, optional 2D fitting function for CoM fitting. 
One of 'plane','parabola','bezier_two' - com_shifts, np.ndarray, optional + com_shifts: tuple of ndarrays (CoMx fitted, CoMy fitted), optional If not None, com_shifts are fitted on the measured CoM values. - + com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + If not None, these are used directly as com_measured_x, com_measured_y. Returns ------- @@ -394,39 +448,48 @@ def _calculate_intensities_center_of_mass( xp = self._xp asnumpy = self._asnumpy - # Coordinates - kx = xp.arange(intensities.shape[-2], dtype=xp.float32) - ky = xp.arange(intensities.shape[-1], dtype=xp.float32) - kya, kxa = xp.meshgrid(ky, kx) + intensities = xp.asarray(intensities, dtype=xp.float32) + + # for ptycho + if com_measured: + com_measured_x, com_measured_y = com_measured - # calculate CoM - if dp_mask is not None: - if dp_mask.shape != intensities.shape[-2:]: - raise ValueError( - ( - f"Mask shape should be (Qx,Qy):{intensities.shape[-2:]}, " - f"not {dp_mask.shape}" - ) - ) - intensities_mask = intensities * xp.asarray(dp_mask, dtype=xp.float32) else: - intensities_mask = intensities + # Coordinates + kx = xp.arange(intensities.shape[-2], dtype=xp.float32) + ky = xp.arange(intensities.shape[-1], dtype=xp.float32) + kya, kxa = xp.meshgrid(ky, kx) - intensities_sum = xp.sum(intensities_mask, axis=(-2, -1)) - com_measured_x = ( - xp.sum(intensities_mask * kxa[None, None], axis=(-2, -1)) / intensities_sum - ) - com_measured_y = ( - xp.sum(intensities_mask * kya[None, None], axis=(-2, -1)) / intensities_sum - ) + # calculate CoM + if dp_mask is not None: + if dp_mask.shape != intensities.shape[-2:]: + raise ValueError( + ( + f"Mask shape should be (Qx,Qy):{intensities.shape[-2:]}, " + f"not {dp_mask.shape}" + ) + ) + intensities_mask = intensities * xp.asarray(dp_mask, dtype=xp.float32) + else: + intensities_mask = intensities + + intensities_sum = xp.sum(intensities_mask, axis=(-2, -1)) + com_measured_x = ( + xp.sum(intensities_mask * kxa[None, None], axis=(-2, -1)) + / intensities_sum + ) + com_measured_y = ( + xp.sum(intensities_mask * kya[None, None], axis=(-2, -1)) + / intensities_sum + ) - # Fit function to center of mass if com_shifts is None: com_shifts = fit_origin( (asnumpy(com_measured_x), asnumpy(com_measured_y)), fitfunction=fit_function, ) + # Fit function to center of mass com_fitted_x = xp.asarray(com_shifts[0], dtype=xp.float32) com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32) @@ -527,7 +590,7 @@ def _solve_for_center_of_mass_relative_rotation( warnings.warn( ( "Best fit rotation forced to " - f"{str(np.round(force_com_rotation))} degrees." + f"{force_com_rotation:.0f} degrees." ), UserWarning, ) @@ -697,12 +760,7 @@ def _solve_for_center_of_mass_relative_rotation( _rotation_best_rad = rotation_angles_rad[ind_min] if self._verbose: - print( - ( - "Best fit rotation = " - f"{str(np.round(rotation_best_deg))} degrees."
- ) - ) + print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees.")) if _rotation_best_transpose: print("Diffraction intensities should be transposed.") else: @@ -948,8 +1001,8 @@ def _solve_for_center_of_mass_relative_rotation( cmap = kwargs.pop("cmap", "RdBu_r") extent = [ 0, - self._scan_sampling[1] * self._intensities.shape[1], - self._scan_sampling[0] * self._intensities.shape[0], + self._scan_sampling[1] * _com_measured_x.shape[1], + self._scan_sampling[0] * _com_measured_x.shape[0], 0, ] @@ -986,8 +1039,8 @@ def _solve_for_center_of_mass_relative_rotation( extent = [ 0, - self._scan_sampling[1] * self._intensities.shape[1], - self._scan_sampling[0] * self._intensities.shape[0], + self._scan_sampling[1] * com_x.shape[1], + self._scan_sampling[0] * com_x.shape[0], 0, ] @@ -1171,6 +1224,15 @@ def to_h5(self, group): data=self._polar_parameters, ) + # object + self._object_emd = Array( + name="reconstruction_object", + data=asnumpy(self._xp.asarray(self._object)), + ) + + # probe + self._probe_emd = Array(name="reconstruction_probe", data=asnumpy(self._probe)) + if is_stack: iterations_labels = [f"iteration_{i:03}" for i in iterations] @@ -1178,32 +1240,20 @@ def to_h5(self, group): object_iterations = [ np.asarray(self.object_iterations[i]) for i in iterations ] - self._object_emd = Array( - name="reconstruction_object", + self._object_iterations_emd = Array( + name="reconstruction_object_iterations", data=np.stack(object_iterations, axis=0), slicelabels=iterations_labels, ) # probe probe_iterations = [self.probe_iterations[i] for i in iterations] - self._probe_emd = Array( - name="reconstruction_probe", + self._probe_iterations_emd = Array( + name="reconstruction_probe_iterations", data=np.stack(probe_iterations, axis=0), slicelabels=iterations_labels, ) - else: - # object - self._object_emd = Array( - name="reconstruction_object", - data=asnumpy(self._xp.asarray(self._object)), - ) - - # probe - self._probe_emd = Array( - name="reconstruction_probe", data=asnumpy(self._probe) - ) - # exit_waves if self._save_exit_waves: self._exit_waves_emd = Array( @@ -1244,13 +1294,8 @@ def _get_constructor_args(cls, group): else: dc = None - # Check if stack - if dict_data["_object_emd"].is_stack: - obj = dict_data["_object_emd"][-1].data - probe = dict_data["_probe_emd"][-1].data - else: - obj = dict_data["_object_emd"].data - probe = dict_data["_probe_emd"].data + obj = dict_data["_object_emd"].data + probe = dict_data["_probe_emd"].data # Populate args and return kwargs = { @@ -1308,8 +1353,8 @@ def _populate_instance(self, group): # Check if stack if hasattr(error, "__len__"): - self.object_iterations = list(dict_data["_object_emd"].data) - self.probe_iterations = list(dict_data["_probe_emd"].data) + self.object_iterations = list(dict_data["_object_iterations_emd"].data) + self.probe_iterations = list(dict_data["_probe_iterations_emd"].data) self.error_iterations = error self.error = error[-1] else: @@ -1318,7 +1363,7 @@ def _populate_instance(self, group): # Slim preprocessing to enable visualize self._positions_px_com = xp.mean(self._positions_px, axis=0) self.object = asnumpy(self._object) - self.probe = asnumpy(self._probe) + self.probe = self.probe_centered self._preprocessed = True def _set_polar_parameters(self, parameters: dict): @@ -1329,11 +1374,6 @@ def _set_polar_parameters(self, parameters: dict): ---------- parameters: dict Mapping from aberration symbols to their corresponding values. 
- - Mutates - ------- - self._polar_parameters: dict - Updated polar aberrations dictionary """ for symbol, value in parameters.items(): @@ -1359,11 +1399,6 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): Input probe positions in Å. If None, a raster scan using experimental parameters is constructed. - Mutates - ------- - self._object_padding_px: np.ndarray - Object array padding in pixels - Returns ------- positions_in_px: (J,2) np.ndarray @@ -1412,74 +1447,23 @@ def _calculate_scan_positions_in_pixels(self, positions: np.ndarray): positions -= np.min(positions, axis=0) if self._object_padding_px is None: - self._object_padding_px = self._region_of_interest_shape / 2 - positions += self._object_padding_px - - return positions - - def _wrapped_indices_2D_window( - self, - center_position: np.ndarray, - window_shape: Sequence[int], - array_shape: Sequence[int], - ): - """ - Computes periodic indices for a window_shape probe centered at center_position, - in object of size array_shape. - - Parameters - ---------- - center_position: (2,) np.ndarray - The window center positions in pixels - window_shape: (2,) Sequence[int] - The pixel dimensions of the window - array_shape: (2,) Sequence[int] - The pixel dimensions of the array the window will be embedded in - - Returns - ------- - window_indices: length-2 tuple of - The 2D indices of the window - """ - - asnumpy = self._asnumpy - sx, sy = array_shape - nx, ny = window_shape - - cx, cy = np.round(asnumpy(center_position)).astype(int) - ox, oy = (cx - nx // 2, cy - ny // 2) - - return np.ix_(np.arange(ox, ox + nx) % sx, np.arange(oy, oy + ny) % sy) - - def _sum_overlapping_patches(self, patches: np.ndarray): - """ - Sum overlapping patches defined into object shaped array - - Parameters - ---------- - patches: (Rx*Ry,Sx,Sy) np.ndarray - Patches to sum - - Returns - ------- - out_array: (Px,Py) np.ndarray - Summed array - """ - xp = self._xp - positions = self._positions_px - patch_shape = self._region_of_interest_shape - array_shape = self._object_shape + float_padding = self._region_of_interest_shape / 2 + self._object_padding_px = (float_padding, float_padding) + elif np.isscalar(self._object_padding_px[0]): + self._object_padding_px = ( + (self._object_padding_px[0],) * 2, + (self._object_padding_px[1],) * 2, + ) - out_array = xp.zeros(array_shape, patches.dtype) - for ind, pos in enumerate(positions): - indices = self._wrapped_indices_2D_window(pos, patch_shape, array_shape) - out_array[indices] += patches[ind] + positions[:, 0] += self._object_padding_px[0][0] + positions[:, 1] += self._object_padding_px[1][0] - return out_array + return positions def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): """ Base bincouts overlapping patches sum function, operating on real-valued arrays. + Note this assumes the probe is corner-centered.
Parameters ---------- @@ -1496,8 +1480,8 @@ def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray): y0 = xp.round(self._positions_px[:, 1]).astype("int") roi_shape = self._region_of_interest_shape - x_ind = xp.round(xp.arange(roi_shape[0]) - roi_shape[0] / 2).astype("int") - y_ind = xp.round(xp.arange(roi_shape[1]) - roi_shape[1] / 2).astype("int") + x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") + y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") flat_weights = patches.ravel() indices = ( @@ -1539,6 +1523,7 @@ def _sum_overlapping_patches_bincounts(self, patches: np.ndarray): def _extract_vectorized_patch_indices(self): """ Sets the vectorized row/col indices used for the overlap projection + Note this assumes the probe is corner-centered. Returns ------- @@ -1552,8 +1537,8 @@ def _extract_vectorized_patch_indices(self): y0 = xp.round(self._positions_px[:, 1]).astype("int") roi_shape = self._region_of_interest_shape - x_ind = xp.round(xp.arange(roi_shape[0]) - roi_shape[0] / 2).astype("int") - y_ind = xp.round(xp.arange(roi_shape[1]) - roi_shape[1] / 2).astype("int") + x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int") + y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int") obj_shape = self._object_shape vectorized_patch_indices_row = ( @@ -1725,8 +1710,7 @@ def tune_angle_and_defocus( fig = plt.figure(figsize=figsize) - progress_bar = kwargs.get("progress_bar", False) - kwargs.pop("progress_bar", None) + progress_bar = kwargs.pop("progress_bar", False) # run loop and plot along the way self._verbose = False for flat_index, (angle, defocus) in enumerate( @@ -1994,7 +1978,7 @@ def _return_fourier_probe( ): """ Returns complex fourier probe shifted to center of array from - complex real space probe in center + corner-centered complex real space probe Parameters ---------- @@ -2013,16 +1997,61 @@ def _return_fourier_probe( else: probe = xp.asarray(probe, dtype=xp.complex64) - return xp.fft.fftshift( - xp.fft.fft2(xp.fft.ifftshift(probe, axes=(-2, -1))), axes=(-2, -1) - ) + return xp.fft.fftshift(xp.fft.fft2(probe), axes=(-2, -1)) + + def _return_fourier_probe_from_centered_probe( + self, + probe=None, + ): + """ + Returns complex fourier probe shifted to center of array from + centered complex real space probe + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + + Returns + ------- + fourier_probe: np.ndarray + Fourier-transformed and center-shifted probe. + """ + xp = self._xp + return self._return_fourier_probe(xp.fft.ifftshift(probe, axes=(-2, -1))) + + def _return_centered_probe( + self, + probe=None, + ): + """ + Returns complex probe centered in middle of the array. + + Parameters + ---------- + probe: complex array, optional + if None is specified, uses self._probe + + Returns + ------- + centered_probe: np.ndarray + Center-shifted probe. 
+ """ + xp = self._xp + + if probe is None: + probe = self._probe + else: + probe = xp.asarray(probe, dtype=xp.complex64) + + return xp.fft.fftshift(probe, axes=(-2, -1)) def _return_object_fft( self, obj=None, ): """ - Returns obj fft shifted to center of array + Returns absolute value of obj fft shifted to center of array Parameters ---------- @@ -2040,7 +2069,7 @@ def _return_object_fft( obj = self._object obj = self._crop_rotate_object_fov(asnumpy(obj)) - return np.abs(np.fft.fftshift(np.fft.fft2(obj))) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) def show_fourier_probe( self, @@ -2131,13 +2160,21 @@ def show_object_fft(self, obj=None, **kwargs): @property def probe_fourier(self): """Current probe estimate in Fourier space""" - if not hasattr(self, "_probe"): return None asnumpy = self._asnumpy return asnumpy(self._return_fourier_probe(self._probe)) + @property + def probe_centered(self): + """Current probe estimate shifted to the center""" + if not hasattr(self, "_probe"): + return None + + asnumpy = self._asnumpy + return asnumpy(self._return_centered_probe(self._probe)) + @property def object_fft(self): """Fourier transform of current object estimate""" @@ -2180,7 +2217,7 @@ def positions(self): return asnumpy(positions) @property - def _object_cropped(self): - """ cropped and rotated object """ + def object_cropped(self): + """cropped and rotated object""" return self._crop_rotate_object_fov(self._object) diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py index b2d8ef891..4c80ed177 100644 --- a/py4DSTEM/process/phase/iterative_dpc.py +++ b/py4DSTEM/process/phase/iterative_dpc.py @@ -4,7 +4,7 @@ """ import warnings -from typing import Tuple +from typing import Sequence, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -17,7 +17,8 @@ cp = None from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd -from py4DSTEM.classes import Calibration, DataCube +from py4DSTEM.data import Calibration +from py4DSTEM.datacube import DataCube from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction warnings.simplefilter(action="always", category=UserWarning) @@ -238,7 +239,8 @@ def preprocess( fit_function: str = "plane", force_com_rotation: float = None, force_com_transpose: bool = None, - force_com_shifts: float = None, + force_com_shifts: Union[Sequence[np.ndarray], Sequence[float]] = None, + force_com_measured: Sequence[np.ndarray] = None, plot_center_of_mass: str = "default", plot_rotation: bool = True, **kwargs, @@ -269,6 +271,8 @@ def preprocess( Force whether diffraction intensities need to be transposed. force_com_shifts: tuple of ndarrays (CoMx, CoMy) Force CoM fitted shifts + force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured) + Force CoM measured shifts plot_center_of_mass: str, optional If 'default', the corrected CoM arrays will be displayed If 'all', the computed and fitted CoM arrays will be displayed @@ -286,12 +290,18 @@ def preprocess( self._dp_mask = dp_mask if self._datacube is None: - raise ValueError( - ( - "The preprocess() method requires a DataCube. " - "Please run dpc.attach_datacube(DataCube) first." + if force_com_measured is None: + raise ValueError( + ( + "The preprocess() method requires either a DataCube " + "or `force_com_measured`. " + "Please run dpc.attach_datacube(DataCube) to attach DataCube." 
+ ) + ) + else: + self._datacube = DataCube( + data=np.empty(force_com_measured[0].shape + (1, 1)) ) - ) self._intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, @@ -310,6 +320,7 @@ def preprocess( dp_mask=self._dp_mask, fit_function=fit_function, com_shifts=force_com_shifts, + com_measured=force_com_measured, ) ( @@ -337,13 +348,13 @@ def preprocess( padded_object_shape = np.round( np.array(self._grid_scan_shape) * padding_factor ).astype("int") - self._padded_phase_object = xp.zeros(padded_object_shape, dtype=xp.float32) + self._padded_object_phase = xp.zeros(padded_object_shape, dtype=xp.float32) if self._object_phase is not None: - self._padded_phase_object[ + self._padded_object_phase[ : self._grid_scan_shape[0], : self._grid_scan_shape[1] ] = xp.asarray(self._object_phase, dtype=xp.float32) - self._padded_phase_object_initial = self._padded_phase_object.copy() + self._padded_object_phase_initial = self._padded_object_phase.copy() # Fourier coordinates and operators kx = xp.fft.fftfreq(padded_object_shape[0], d=self._scan_sampling[0]) @@ -357,6 +368,10 @@ def preprocess( self._preprocessed = True + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _forward( @@ -420,9 +435,6 @@ def _forward( xp.mean(self._com_x.ravel() ** 2 + self._com_y.ravel() ** 2) ) - if new_error > error: - step_size /= 2 - return obj_dx, obj_dy, new_error, step_size def _adjoint( @@ -481,7 +493,7 @@ def _update( Returns -------- - updated_padded_phase_object: np.ndarray + updated_padded_object_phase: np.ndarray Updated padded phase object estimate """ @@ -504,15 +516,16 @@ def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): constrained_object: np.ndarray Constrained object estimate """ - xp = self._xp gaussian_filter = self._gaussian_filter - gaussian_filter_sigma /= xp.sqrt(self.sampling[0] ** 2 + self.sampling[1] ** 2) + gaussian_filter_sigma /= self.sampling[0] current_object = gaussian_filter(current_object, gaussian_filter_sigma) return current_object - def _object_butterworth_constraint(self, current_object, q_lowpass, q_highpass): + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): """ Butterworth filter used for low/high-pass filtering. @@ -524,6 +537,8 @@ def _object_butterworth_constraint(self, current_object, q_lowpass, q_highpass): Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. 
Smaller gives a smoother filter Returns -------- constrained_object: np.ndarray Constrained object estimate """ @@ -539,9 +554,9 @@ def _object_butterworth_constraint(self, current_object, q_lowpass, q_highpass): env = xp.ones_like(qra) if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** 4) + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** 4) + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) current_object_mean = xp.mean(current_object) current_object -= current_object_mean @@ -550,6 +565,37 @@ return xp.real(current_object) + def _object_anti_gridding_constraint(self, current_object): + """ + Zero outer pixels of object fft to remove gridding artifacts + + Parameters + -------- + current_object: np.ndarray + Current object estimate + + Returns + -------- + constrained_object: np.ndarray + Constrained object estimate + """ + xp = self._xp + + # find indices to zero + width_x = current_object.shape[0] + width_y = current_object.shape[1] + ind_min_x = int(xp.floor(width_x / 2) - 2) + ind_max_x = int(xp.ceil(width_x / 2) + 2) + ind_min_y = int(xp.floor(width_y / 2) - 2) + ind_max_y = int(xp.ceil(width_y / 2) + 2) + + # zero pixels + object_fft = xp.fft.fft2(current_object) + object_fft[ind_min_x:ind_max_x] = 0 + object_fft[:, ind_min_y:ind_max_y] = 0 + + return xp.real(xp.fft.ifft2(object_fft)) + def _constraints( self, current_object, @@ -558,6 +604,8 @@ butterworth_filter, q_lowpass, q_highpass, + butterworth_order, + anti_gridding, ): """ DPC constraints operator. @@ -576,6 +624,11 @@ Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter + anti_gridding: bool + If True, zero outer pixels of object fft to remove + gridding artifacts Returns -------- @@ -592,6 +645,12 @@ current_object, q_lowpass, q_highpass, + butterworth_order, + ) + + if anti_gridding: + current_object = self._object_anti_gridding_constraint( + current_object, ) return current_object @@ -602,12 +661,15 @@ def reconstruct( max_iter: int = 64, step_size: float = None, stopping_criterion: float = 1e-6, + backtrack: bool = True, progress_bar: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, + butterworth_order: float = 2, + anti_gridding: bool = True, store_iterations: bool = False, ): """ @@ -623,6 +685,10 @@ Reconstruction update step size stopping_criterion: float, optional step_size below which reconstruction exits + backtrack: bool, optional + If True, steps that increase the error metric are rejected + and iteration continues with a reduced step size from the + previous iteration progress_bar: bool, optional If True, reconstruction progress bar will be printed gaussian_filter_sigma: float, optional @@ -635,6 +701,11 @@ Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order.
Smaller gives a smoother filter + anti_gridding: bool + If true, zero outer pixels of object fft to remove + gridding artifacts store_iterations: bool, optional If True, all reconstruction iterations will be stored @@ -664,7 +735,7 @@ def reconstruct( if reset: self.error = np.inf self._step_size = step_size if step_size is not None else 0.5 - self._padded_phase_object = self._padded_phase_object_initial.copy() + self._padded_object_phase = self._padded_object_phase_initial.copy() self.error = getattr(self, "error", np.inf) @@ -673,7 +744,7 @@ def reconstruct( else: self._step_size = step_size - mask = xp.zeros(self._padded_phase_object.shape, dtype="bool") + mask = xp.zeros(self._padded_object_phase.shape, dtype="bool") mask[: self._grid_scan_shape[0], : self._grid_scan_shape[1]] = True mask_inv = xp.logical_not(mask) @@ -687,22 +758,33 @@ def reconstruct( if self._step_size < stopping_criterion: break + previous_iteration = self._padded_object_phase.copy() + # forward operator - com_dx, com_dy, self.error, self._step_size = self._forward( - self._padded_phase_object, mask, mask_inv, self.error, self._step_size + com_dx, com_dy, new_error, self._step_size = self._forward( + self._padded_object_phase, mask, mask_inv, self.error, self._step_size ) + # if the error went up after the previous step, go back to the step + # before the error rose and continue with the halved step size + if (new_error > self.error) and backtrack: + self._padded_object_phase = previous_iteration + self._step_size /= 2 + print(f"Iteration {a0}, step reduced to {self._step_size}") + continue + self.error = new_error + # adjoint operator phase_update = self._adjoint(com_dx, com_dy, self._kx_op, self._ky_op) # update - self._padded_phase_object = self._update( - self._padded_phase_object, phase_update, self._step_size + self._padded_object_phase = self._update( + self._padded_object_phase, phase_update, self._step_size ) # constraints - self._padded_phase_object = self._constraints( - self._padded_phase_object, + self._padded_object_phase = self._constraints( + self._padded_object_phase, gaussian_filter=a0 < gaussian_filter_iter and gaussian_filter_sigma is not None, gaussian_filter_sigma=gaussian_filter_sigma, @@ -710,12 +792,14 @@ def reconstruct( and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, + butterworth_order=butterworth_order, + anti_gridding=anti_gridding, ) if store_iterations: self.object_phase_iterations.append( asnumpy( - self._padded_phase_object[ + self._padded_object_phase[ : self._grid_scan_shape[0], : self._grid_scan_shape[1] ].copy() ) @@ -729,11 +813,15 @@ def reconstruct( ) # crop result - self._object_phase = self._padded_phase_object[ + self._object_phase = self._padded_object_phase[ : self._grid_scan_shape[0], : self._grid_scan_shape[1] ] self.object_phase = asnumpy(self._object_phase) + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _visualize_last_iteration( diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py index d09b5a339..56fec1004 100644 --- a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py +++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py @@ -24,7 +24,6 @@ ComplexProbe, fft_shift, generate_batches, - orthogonalize, polar_aliases, polar_symbols, ) @@ -53,7 +52,9 @@ class MixedstatePtychographicReconstruction(PtychographicReconstruction): num_probes: int, optional Number 
of mixed-state probes semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels rolloff: float, optional Semiangle rolloff for the initial probe guess vacuum_probe_intensity: np.ndarray, optional @@ -92,6 +93,7 @@ def __init__( datacube: DataCube = None, num_probes: int = None, semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, rolloff: float = 2.0, vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, @@ -172,6 +174,7 @@ def __init__( self._scan_positions = initial_scan_positions self._energy = energy self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px @@ -197,6 +200,9 @@ def preprocess( force_com_rotation: float = None, force_com_transpose: float = None, force_com_shifts: float = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, **kwargs, ): @@ -246,6 +252,12 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. 
Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded @@ -290,6 +302,9 @@ def preprocess( self._intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, ) ( @@ -345,15 +360,22 @@ def preprocess( self._scan_positions ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + # Object Initialization if self._object is None: - pad_x, pad_y = self._object_padding_px - p, q = np.max(self._positions_px, axis=0) + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px, axis=0)) p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - int + "int" ) q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - int + "int" ) if self._object_type == "potential": self._object = xp.zeros((p, q), dtype=xp.float32) @@ -400,12 +422,10 @@ def preprocess( self._vacuum_probe_intensity, device=self._device, ) - shift_x = self._region_of_interest_shape[0] // 2 - probe_x0 - shift_y = self._region_of_interest_shape[1] // 2 - probe_y0 self._vacuum_probe_intensity = get_shifted_ar( self._vacuum_probe_intensity, - shift_x, - shift_y, + -probe_x0, + -probe_y0, bilinear=True, device=self._device, ) @@ -444,16 +464,10 @@ def preprocess( # Randomly shift phase of other probes for i_probe in range(1, self._num_probes): shift_x = xp.exp( - -2j - * np.pi - * (xp.random.rand() - 0.5) - * ((xp.arange(sx) + 0.5) / sx - 0.5) + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx) ) shift_y = xp.exp( - -2j - * np.pi - * (xp.random.rand() - 0.5) - * ((xp.arange(sy) + 0.5) / sy - 0.5) + -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy) ) self._probe[i_probe] = ( self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None] @@ -467,6 +481,7 @@ def preprocess( self._probe = xp.asarray(self._probe, dtype=xp.complex64) self._probe_initial = self._probe.copy() + self._probe_initial_aperture = None # Doesn't really make sense for mixed-state self._known_aberrations_array = ComplexProbe( energy=self._energy, @@ -476,8 +491,6 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._known_aberrations_array = xp.fft.ifftshift(self._known_aberrations_array) - # overlaps shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp) probe_intensities = xp.abs(shifted_probes) ** 2 @@ -491,7 +504,7 @@ def preprocess( self._object_fov_mask_inverse = np.invert(self._object_fov_mask) if plot_probe_overlaps: - figsize = kwargs.pop("figsize", (9, 4)) + figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4)) cmap = kwargs.pop("cmap", "Greys_r") vmin = kwargs.pop("vmin", None) vmax = kwargs.pop("vmax", None) @@ -500,7 +513,7 @@ def preprocess( # initial probe complex_probe_rgb = Complex2RGB( - asnumpy(self._probe[0]), + self.probe_centered, vmin=vmin, vmax=vmax, hue_start=hue_start, @@ -521,45 +534,50 @@ def preprocess( 0, ] - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize) - - ax1.imshow( - complex_probe_rgb, - extent=probe_extent, - **kwargs, - ) + fig, axs = plt.subplots(1, self._num_probes + 1, figsize=figsize) - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad="2.5%") - add_colorbar_arg( - 
cax1, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert - ) - ax1.set_ylabel("x [A]") - ax1.set_xlabel("y [A]") - ax1.set_title("Initial Probe") + for i in range(self._num_probes): + axs[i].imshow( + complex_probe_rgb[i], + extent=probe_extent, + **kwargs, + ) + axs[i].set_ylabel("x [A]") + axs[i].set_xlabel("y [A]") + axs[i].set_title(f"Initial Probe[{i}]") + + divider = make_axes_locatable(axs[i]) + cax = divider.append_axes("right", size="5%", pad="2.5%") + add_colorbar_arg( + cax, vmin=vmin, vmax=vmax, hue_start=hue_start, invert=invert + ) - ax2.imshow( + axs[-1].imshow( asnumpy(probe_overlap), extent=extent, cmap=cmap, **kwargs, ) - ax2.scatter( + axs[-1].scatter( self.positions[:, 1], self.positions[:, 0], s=2.5, color=(1, 0, 0, 1), ) - ax2.set_ylabel("x [A]") - ax2.set_xlabel("y [A]") - ax2.set_xlim((extent[0], extent[1])) - ax2.set_ylim((extent[2], extent[3])) - ax2.set_title("Object Field of View") + axs[-1].set_ylabel("x [A]") + axs[-1].set_xlabel("y [A]") + axs[-1].set_xlim((extent[0], extent[1])) + axs[-1].set_ylim((extent[2], extent[3])) + axs[-1].set_title("Object Field of View") fig.tight_layout() self._preprocessed = True + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _overlap_projection(self, current_object, current_probe): @@ -1021,8 +1039,8 @@ def _adjoint( def _probe_center_of_mass_constraint(self, current_probe): """ - Ptychographic threshold constraint. - Used for avoiding the scaling ambiguity between probe and object. + Ptychographic center of mass constraint. + Used for centering corner-centered probe intensity. Parameters -------- @@ -1035,15 +1053,12 @@ def _probe_center_of_mass_constraint(self, current_probe): Constrained probe estimate """ xp = self._xp - asnumpy = self._asnumpy + probe_intensity = xp.abs(current_probe[0]) ** 2 - probe_center = xp.array(self._region_of_interest_shape) / 2 - probe_intensity = asnumpy(xp.abs(current_probe[0]) ** 2) - - probe_x0, probe_y0 = get_CoM(probe_intensity) - shifted_probe = fft_shift( - current_probe, probe_center - xp.array([probe_x0, probe_y0]), xp + probe_x0, probe_y0 = get_CoM( + probe_intensity, device=self._device, corner_centered=True ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) return shifted_probe @@ -1051,6 +1066,7 @@ def _probe_orthogonalization_constraint(self, current_probe): """ Ptychographic probe-orthogonalization constraint. Used to ensure mixed states are orthogonal to each other. 
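Note that preprocess() (and, further below, reconstruct()) now ends by returning CuPy's cached allocations to the device when running on GPU. A sketch of the same cleanup written against the public CuPy API; the diff uses the internal alias xp._default_memory_pool, and xp.clear_memo() additionally clears CuPy's memoization cache:

# GPU cleanup pattern added at the end of preprocess()/reconstruct();
# assumes a CUDA-capable environment with cupy installed
import cupy as cp

def free_cached_gpu_memory():
    # return unused device memory held by CuPy's caching allocator
    cp.get_default_memory_pool().free_all_blocks()
    # also release cached pinned (page-locked) host memory, if any
    cp.get_default_pinned_memory_pool().free_all_blocks()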
+ Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690 Parameters -------- @@ -1063,10 +1079,25 @@ Orthogonalized probe estimate """ xp = self._xp + n_probes = self._num_probes - return orthogonalize(current_probe.reshape((self._num_probes, -1)), xp).reshape( - current_probe.shape - ) + # compute upper half of P* @ P + pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype) + + for i in range(n_probes): + for j in range(i, n_probes): + pairwise_dot_product[i, j] = xp.sum( + current_probe[i].conj() * current_probe[j] + ) + + # compute eigenvectors (effectively cheaper way of computing V* from SVD) + _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U") + current_probe = xp.tensordot(evecs.T, current_probe, axes=1) + + # sort by real-space intensity + intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1)) + intensities_order = xp.argsort(intensities, axis=None)[::-1] + return current_probe[intensities_order] def _constraints( self, @@ -1075,15 +1106,17 @@ current_positions, pure_phase_object, fix_com, - symmetrize_probe, - probe_gaussian_filter, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude, - fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, fix_positions, global_affine_transformation, gaussian_filter, @@ -1112,25 +1145,28 @@ If True, object amplitude is set to unity fix_com: bool If True, probe CoM is fixed to the center - probe_gaussian_filter: bool - If True, applies reciprocal-space gaussian filtering on residual aberrations - probe_gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed - symmetrize_probe: bool - If True, the probe is radially-averaged - fix_probe_amplitude: bool + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: int + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: int + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking.
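The eigendecomposition introduced above orthogonalizes the mixed-state probes without an SVD of the full probe stack: the Hermitian Gram matrix of pairwise overlaps is diagonalized, and its eigenvectors mix the probes into mutually orthogonal modes. A self-contained NumPy sketch of the same construction (the in-class version runs on xp, which may be NumPy or CuPy):

import numpy as np

def orthogonalize_probes(probes):
    # probes: complex array of shape (n_probes, sx, sy)
    n = probes.shape[0]
    # Gram matrix G[i, j] = <probe_i | probe_j>; only the upper triangle
    # is filled, matching eigh(..., UPLO="U") below
    G = np.empty((n, n), dtype=probes.dtype)
    for i in range(n):
        for j in range(i, n):
            G[i, j] = np.sum(probes[i].conj() * probes[j])
    # eigenvectors of the Hermitian Gram matrix mix the probes into
    # orthogonal modes (cheaper than an SVD of the flattened stack)
    _, evecs = np.linalg.eigh(G, UPLO="U")
    probes = np.tensordot(evecs.T, probes, axes=1)
    # order the modes by decreasing real-space intensity
    order = np.argsort(np.sum(np.abs(probes) ** 2, axis=(-2, -1)))[::-1]
    return probes[order]

The final intensity sort keeps the dominant mode at index 0, which is convenient since the plotting and storage code in this diff treats probe[0] as the principal probe.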
+ constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_probe_aperture: bool, + If True, probe fourier amplitude is replaced by initial probe aperture. + initial_probe_aperture: np.ndarray + initial probe aperture to use in replacing probe fourier amplitude fix_positions: bool If True, positions are not updated gaussian_filter: bool @@ -1195,13 +1231,13 @@ def _constraints( current_probe = self._probe_center_of_mass_constraint(current_probe) # These constraints don't _really_ make sense for mixed-state - if probe_gaussian_filter: + if fix_probe_aperture: raise NotImplementedError() - if symmetrize_probe: + elif constrain_probe_fourier_amplitude: raise NotImplementedError() - if fix_probe_amplitude: + if fit_probe_aberrations: raise NotImplementedError() - elif fix_probe_fourier_amplitude: + if constrain_probe_amplitude: raise NotImplementedError() if orthogonalize_probe: @@ -1224,28 +1260,32 @@ def reconstruct( max_iter: int = 64, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, max_batch_size: int = None, seed_random: int = None, - step_size: float = 0.9, + step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, pure_phase_object_iter: int = 0, fix_com: bool = True, orthogonalize_probe: bool = True, fix_probe_iter: int = 0, - symmetrize_probe_iter: int = 0, - fix_probe_amplitude_iter: int = 0, - fix_probe_amplitude_relative_radius: float = 0.5, - fix_probe_amplitude_relative_width: float = 0.05, - fix_probe_fourier_amplitude_iter: int = 0, - fix_probe_fourier_amplitude_threshold: float = 0.9, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, - probe_gaussian_filter_sigma: float = None, - probe_gaussian_filter_residual_aberrations_iter: int = np.inf, - probe_gaussian_filter_fix_amplitude: bool = True, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -1267,7 +1307,7 @@ def reconstruct( Maximum number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: - "generalized-projection", + "generalized-projections", "DM_AP" (or "difference-map_alternating-projections"), "RAAR" (or "relaxed-averaged-alternating-reflections"), "RRR" (or "relax-reflect-reflect"), @@ -1275,6 +1315,12 @@ def reconstruct( "GD" (or "gradient_descent") reconstruction_parameter: float, optional Reconstruction parameter for various reconstruction methods above. 
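With the "generalized-projections" spelling fixed, the three projection coefficients are now passed as separate keyword arguments instead of packed into a single 3-element reconstruction_parameter. A condensed sketch of the new validation and unpacking, mirroring the ValueError check in the hunk above; the (-1, 1, 2) values in the usage line are assumed, illustrative coefficients, not defaults:

def resolve_generalized_projections(a=None, b=None, c=None):
    # all three coefficients must be given explicitly; partial
    # specification is rejected, matching the ValueError in the diff
    if a is None or b is None or c is None:
        raise ValueError(
            "reconstruction_parameter_a/b/c must all be specified "
            "when using reconstruction_method='generalized-projections'."
        )
    return a, b, c

projection_a, projection_b, projection_c = resolve_generalized_projections(
    a=-1.0, b=1.0, c=2.0
)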
+ reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. max_batch_size: int, optional Max number of probes to update at once seed_random: int, optional @@ -1291,19 +1337,20 @@ def reconstruct( If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate - symmetrize_probe_iter: int, optional - Number of iterations to run before radially-averaging the probe - fix_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
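The relative_radius/relative_width pair documented above describes a sigmoid-edged disk support applied to the real-space probe amplitude. A minimal sketch of such a top-hat support; the function name and details are illustrative, and the actual class method additionally handles calibrated sampling and normalization:

import numpy as np

def tophat_support(shape, relative_radius=0.5, relative_width=0.05):
    # smooth disk support: ~1 inside relative_radius (as a fraction of the
    # array size), rolling off over a sigmoid of width relative_width
    nx, ny = shape
    x = np.arange(nx)[:, None] - nx / 2
    y = np.arange(ny)[None, :] - ny / 2
    r = np.sqrt(x**2 + y**2) / min(nx, ny)  # fractional radius from center
    return 1.0 / (1.0 + np.exp((r - relative_radius) / relative_width))

probe = np.ones((64, 64), dtype=np.complex64)
probe *= tophat_support(probe.shape)  # suppress amplitude outside the disk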
fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate global_affine_transformation: bool, optional @@ -1312,12 +1359,12 @@ def reconstruct( Standard deviation of gaussian kernel in A gaussian_filter_iter: int, optional Number of iterations to run using object smoothness constraint - probe_gaussian_filter_sigma: float, optional - Standard deviation of probe gaussian kernel in A^-1 - probe_gaussian_filter_residual_aberrations_iter: int, optional - Number of iterations to run using probe smoothing of residual aberrations - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -1352,17 +1399,23 @@ def reconstruct( # Reconstruction method - if reconstruction_method == "generalized-projection": - if np.array(reconstruction_parameter).shape != (3,): + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): raise ValueError( ( - "reconstruction_parameter must be a list of three numbers " - "when using `reconstriction_method`=generalized-projection." + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." ) ) use_projection_scheme = True - projection_a, projection_b, projection_c = reconstruction_parameter + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c step_size = None elif ( reconstruction_method == "DM_AP" @@ -1421,7 +1474,8 @@ def reconstruct( else: raise ValueError( ( - "reconstruction_method must be one of 'DM_AP' (or 'difference-map_alternating-projections'), " + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " "'RRR' (or 'relax-reflect-reflect'), " "'SUPERFLIP' (or 'charge-flipping'), " @@ -1627,19 +1681,21 @@ def reconstruct( self._probe, self._positions_px, fix_com=fix_com and a0 >= fix_probe_iter, - symmetrize_probe=a0 < symmetrize_probe_iter, - probe_gaussian_filter=a0 - < probe_gaussian_filter_residual_aberrations_iter - and probe_gaussian_filter_sigma is not None, - probe_gaussian_filter_sigma=probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude=probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude=a0 < fix_probe_amplitude_iter + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter and a0 >= fix_probe_iter, - fix_probe_amplitude_relative_radius=fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width=fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude=a0 < fix_probe_fourier_amplitude_iter + 
constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter and a0 >= fix_probe_iter, - fix_probe_fourier_amplitude_threshold=fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -1663,13 +1719,17 @@ def reconstruct( self.error_iterations.append(error.item()) if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(asnumpy(self._probe.copy())) + self.probe_iterations.append(self.probe_centered) # store result self.object = asnumpy(self._object) - self.probe = asnumpy(self._probe) + self.probe = self.probe_centered self.error = error.item() + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _visualize_last_iteration_figax( @@ -1779,12 +1839,20 @@ def _visualize_last_iteration( 0, ] - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] if plot_convergence: if plot_probe or plot_fourier_probe: @@ -1849,20 +1917,22 @@ def _visualize_last_iteration( probe_array = Complex2RGB( self.probe_fourier[0], hue_start=hue_start, invert=invert ) - ax.set_title("Reconstructed Fourier probe") + ax.set_title("Reconstructed Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( self.probe[0], hue_start=hue_start, invert=invert ) - ax.set_title("Reconstructed probe") + ax.set_title("Reconstructed probe[0]") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, **kwargs, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") if cbar: divider = make_axes_locatable(ax) @@ -2006,12 +2076,20 @@ def _visualize_all_iterations( 0, ] - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] if plot_convergence: if plot_probe or plot_fourier_probe: @@ -2066,16 +2144,24 @@ def 
_visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: probe_array = Complex2RGB( - asnumpy(self._return_fourier_probe(probes[grid_range[n]][0])), + asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]][0] + ) + ), hue_start=hue_start, invert=invert, ) - ax.set_title(f"Iter: {grid_range[n]} Fourier probe") + ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( probes[grid_range[n]][0], hue_start=hue_start, invert=invert ) - ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_title(f"Iter: {grid_range[n]} probe[0]") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") im = ax.imshow( probe_array, @@ -2083,9 +2169,6 @@ def _visualize_all_iterations( **kwargs, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: add_colorbar_arg( grid.cbar_axes[n], hue_start=hue_start, invert=invert diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py index a7fe53589..92f8c0bf3 100644 --- a/py4DSTEM/process/phase/iterative_multislice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py @@ -56,7 +56,9 @@ class MultislicePtychographicReconstruction(PtychographicReconstruction): datacube: DataCube, optional Input 4D diffraction pattern intensities semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels rolloff: float, optional Semiangle rolloff for the initial probe guess vacuum_probe_intensity: np.ndarray, optional @@ -99,6 +101,7 @@ def __init__( slice_thicknesses: Union[float, Sequence[float]], datacube: DataCube = None, semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, rolloff: float = 2.0, vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, @@ -175,6 +178,7 @@ def __init__( self._scan_positions = initial_scan_positions self._energy = energy self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px @@ -271,6 +275,9 @@ def preprocess( force_com_rotation: float = None, force_com_transpose: float = None, force_com_shifts: float = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, **kwargs, ): @@ -320,6 +327,12 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. 
Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded @@ -364,6 +377,9 @@ def preprocess( self._intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, ) ( @@ -419,15 +435,22 @@ def preprocess( self._scan_positions ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + # Object Initialization if self._object is None: - pad_x, pad_y = self._object_padding_px - p, q = np.max(self._positions_px, axis=0) + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px, axis=0)) p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - int + "int" ) q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - int + "int" ) if self._object_type == "potential": self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32) @@ -470,14 +493,13 @@ def preprocess( self._vacuum_probe_intensity, dtype=xp.float32 ) probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device + self._vacuum_probe_intensity, + device=self._device, ) - shift_x = self._region_of_interest_shape[0] // 2 - probe_x0 - shift_y = self._region_of_interest_shape[1] // 2 - probe_y0 self._vacuum_probe_intensity = get_shifted_ar( self._vacuum_probe_intensity, - shift_x, - shift_y, + -probe_x0, + -probe_y0, bilinear=True, device=self._device, ) @@ -521,6 +543,7 @@ def preprocess( self._probe = xp.asarray(self._probe, dtype=xp.complex64) self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) self._known_aberrations_array = ComplexProbe( energy=self._energy, @@ -530,8 +553,6 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._known_aberrations_array = xp.fft.ifftshift(self._known_aberrations_array) - # Precomputed propagator arrays self._propagator_arrays = self._precompute_propagator_arrays( self._region_of_interest_shape, @@ -562,7 +583,7 @@ def preprocess( # initial probe complex_probe_rgb = Complex2RGB( - asnumpy(self._probe), + self.probe_centered, vmin=vmin, vmax=vmax, hue_start=hue_start, @@ -577,7 +598,7 @@ def preprocess( propagated_probe, self._propagator_arrays[s] ) complex_propagated_rgb = Complex2RGB( - asnumpy(propagated_probe), + asnumpy(self._return_centered_probe(propagated_probe)), vmin=vmin, vmax=vmax, hue_start=hue_start, @@ -652,6 +673,10 @@ def preprocess( self._preprocessed = True + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _overlap_projection(self, current_object, current_probe): @@ -949,7 +974,7 @@ def _gradient_descent_adjoint( ) # back-transmit - exit_waves *= xp.conj(obj) / xp.abs(obj) ** 2 + exit_waves *= xp.conj(obj) #/ xp.abs(obj) ** 2 if s > 0: # back-propagate @@ -1051,7 +1076,7 @@ def _projection_sets_adjoint( ) # back-transmit - exit_waves_copy *= xp.conj(obj) / xp.abs(obj) ** 2 + exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2 if s > 0: # back-propagate @@ -1331,6 +1356,8 @@ def _object_butterworth_constraint( Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth 
filter order. Smaller gives a smoother filter Returns -------- @@ -1428,15 +1455,17 @@ def _constraints( current_probe, current_positions, fix_com, - symmetrize_probe, - probe_gaussian_filter, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude, - fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, fix_positions, global_affine_transformation, gaussian_filter, @@ -1469,25 +1498,28 @@ def _constraints( Current positions estimate fix_com: bool If True, probe CoM is fixed to the center - symmetrize_probe: bool - If True, the probe is radially-averaged - probe_gaussian_filter: bool - If True, applies reciprocal-space gaussian filtering on residual aberrations - probe_gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed - fix_probe_amplitude: bool + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
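The butterworth_order parameter documented above enters the filter envelope as an exponent of 2 * order, as the _object_butterworth_constraint hunk further below makes explicit. A standalone sketch of that envelope:

import numpy as np

def butterworth_envelope(qra, q_lowpass=None, q_highpass=None, butterworth_order=2):
    # qra: radial spatial-frequency magnitudes in A^-1; matches the env
    # expression in the constraint hunk below, where a smaller order
    # gives a softer rolloff
    env = np.ones_like(qra)
    if q_highpass:
        env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order))
    if q_lowpass:
        env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order))
    return env

qra = np.abs(np.fft.fftfreq(128, d=0.1))        # assumed 0.1 A real-space sampling
env = butterworth_envelope(qra, q_lowpass=1.0)  # low-pass cutoff at 1.0 A^-1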
+ fix_probe_aperture: bool + If True, probe Fourier amplitude is replaced by initial_probe_aperture + initial_probe_aperture: np.ndarray + Initial probe aperture to use in replacing probe Fourier amplitude fix_positions: bool If True, positions are not updated gaussian_filter: bool @@ -1579,27 +1611,30 @@ def _constraints( if fix_com: current_probe = self._probe_center_of_mass_constraint(current_probe) - if probe_gaussian_filter: - current_probe = self._probe_residual_aberration_filtering_constraint( + if fix_probe_aperture: + current_probe = self._probe_aperture_constraint( current_probe, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe = self._probe_fourier_amplitude_constraint( + current_probe, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, ) - if symmetrize_probe: - current_probe = self._probe_radial_symmetrization_constraint(current_probe) - - if fix_probe_amplitude: - current_probe = self._probe_amplitude_constraint( + if fit_probe_aberrations: + current_probe = self._probe_aberration_fitting_constraint( current_probe, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, ) - elif fix_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( + + if constrain_probe_amplitude: + current_probe = self._probe_amplitude_constraint( current_probe, - fix_probe_fourier_amplitude_threshold, - fix_probe_amplitude_relative_width, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, ) if not fix_positions: @@ -1619,27 +1654,31 @@ def reconstruct( max_iter: int = 64, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, max_batch_size: int = None, seed_random: int = None, - step_size: float = 0.9, + step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, fix_com: bool = True, fix_probe_iter: int = 0, - symmetrize_probe_iter: int = 0, - fix_probe_amplitude_iter: int = 0, - fix_probe_amplitude_relative_radius: float = 0.5, - fix_probe_amplitude_relative_width: float = 0.05, - fix_probe_fourier_amplitude_iter: int = 0, - fix_probe_fourier_amplitude_threshold: float = 0.9, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, constrain_position_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, - probe_gaussian_filter_sigma: float = None, - probe_gaussian_filter_residual_aberrations_iter: int = np.inf, - probe_gaussian_filter_fix_amplitude: bool = True, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -1668,7 +1707,7 @@ def reconstruct( Maximum 
number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: - "generalized-projection", + "generalized-projections", "DM_AP" (or "difference-map_alternating-projections"), "RAAR" (or "relaxed-averaged-alternating-reflections"), "RRR" (or "relax-reflect-reflect"), @@ -1676,8 +1715,12 @@ def reconstruct( "GD" (or "gradient_descent") reconstruction_parameter: float, optional Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter: float, optional - Tuning parameter to interpolate b/w DM-AP and DM-RAAR + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. max_batch_size: int, optional Max number of probes to update at once seed_random: int, optional @@ -1692,19 +1735,20 @@ def reconstruct( If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate - symmetrize_probe_iter: int, optional - Number of iterations to run before radially-averaging the probe - fix_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
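The _probe_aperture_constraint dispatched in the constraints hunk above enforces the initial_probe_aperture stored during preprocess (computed there as xp.abs(xp.fft.fft2(self._probe))). A hedged NumPy sketch of what such an aperture projection typically does: keep the current Fourier phase, swap in the stored amplitude, and renormalize the total intensity. The real method may differ in details:

import numpy as np

def probe_aperture_constraint(probe, initial_aperture):
    # sketch only: project onto the set of waves whose Fourier amplitude
    # equals the initial aperture, preserving total probe intensity
    total = np.sum(np.abs(probe) ** 2)
    fourier_phase = np.angle(np.fft.fft2(probe))
    probe = np.fft.ifft2(initial_aperture * np.exp(1j * fourier_phase))
    return probe * np.sqrt(total / np.sum(np.abs(probe) ** 2))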
fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate global_affine_transformation: bool, optional @@ -1713,12 +1757,12 @@ def reconstruct( Standard deviation of gaussian kernel in A gaussian_filter_iter: int, optional Number of iterations to run using object smoothness constraint - probe_gaussian_filter_sigma: float, optional - Standard deviation of probe gaussian kernel in A^-1 - probe_gaussian_filter_residual_aberrations_iter: int, optional - Number of iterations to run using probe smoothing of residual aberrations - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -1767,17 +1811,23 @@ def reconstruct( # Reconstruction method - if reconstruction_method == "generalized-projection": - if np.array(reconstruction_parameter).shape != (3,): + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): raise ValueError( ( - "reconstruction_parameter must be a list of three numbers " - "when using `reconstriction_method`=generalized-projection." + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." ) ) use_projection_scheme = True - projection_a, projection_b, projection_c = reconstruction_parameter + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c step_size = None elif ( reconstruction_method == "DM_AP" @@ -1836,7 +1886,8 @@ def reconstruct( else: raise ValueError( ( - "reconstruction_method must be one of 'DM_AP' (or 'difference-map_alternating-projections'), " + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " "'RRR' (or 'relax-reflect-reflect'), " "'SUPERFLIP' (or 'charge-flipping'), " @@ -2043,19 +2094,21 @@ def reconstruct( self._probe, self._positions_px, fix_com=fix_com and a0 >= fix_probe_iter, - symmetrize_probe=a0 < symmetrize_probe_iter, - probe_gaussian_filter=a0 - < probe_gaussian_filter_residual_aberrations_iter - and probe_gaussian_filter_sigma is not None, - probe_gaussian_filter_sigma=probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude=probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude=a0 < fix_probe_amplitude_iter + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter and a0 >= fix_probe_iter, - fix_probe_amplitude_relative_radius=fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width=fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude=a0 < fix_probe_fourier_amplitude_iter + 
constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter and a0 >= fix_probe_iter, - fix_probe_fourier_amplitude_threshold=fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -2088,13 +2141,17 @@ def reconstruct( self.error_iterations.append(error.item()) if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(asnumpy(self._probe.copy())) + self.probe_iterations.append(self.probe_centered) # store result self.object = asnumpy(self._object) - self.probe = asnumpy(self._probe) + self.probe = self.probe_centered self.error = error.item() + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _visualize_last_iteration_figax( @@ -2213,12 +2270,20 @@ def _visualize_last_iteration( 0, ] - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] if plot_convergence: if plot_probe or plot_fourier_probe: @@ -2285,19 +2350,21 @@ def _visualize_last_iteration( self.probe_fourier, hue_start=hue_start, invert=invert ) ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( self.probe, hue_start=hue_start, invert=invert ) ax.set_title("Reconstructed probe") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, **kwargs, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") if cbar: divider = make_axes_locatable(ax) @@ -2443,12 +2510,20 @@ def _visualize_all_iterations( 0, ] - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] if plot_convergence: if plot_probe or plot_fourier_probe: @@ -2503,16 +2578,24 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: probe_array = Complex2RGB( - 
asnumpy(self._return_fourier_probe(probes[grid_range[n]])), + asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]] + ) + ), hue_start=hue_start, invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( probes[grid_range[n]], hue_start=hue_start, invert=invert ) ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") im = ax.imshow( probe_array, @@ -2520,9 +2603,6 @@ def _visualize_all_iterations( **kwargs, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: add_colorbar_arg( grid.cbar_axes[n], hue_start=hue_start, invert=invert @@ -2619,6 +2699,7 @@ def show_transmitted_probe( """ xp = self._xp + asnumpy = self._asnumpy transmitted_probe_intensities = xp.sum( xp.abs(self._transmitted_probes) ** 2, axis=(-2, -1) @@ -2631,7 +2712,7 @@ def show_transmitted_probe( ] mean_transmitted = self._transmitted_probes.mean(0) probes = [ - self._asnumpy(probe) + asnumpy(self._return_centered_probe(probe)) for probe in [ mean_transmitted, min_intensity_transmitted, @@ -2646,7 +2727,7 @@ def show_transmitted_probe( if plot_fourier_probe: bottom_row = [ - self._asnumpy(self._return_fourier_probe(probe)) + asnumpy(self._return_fourier_probe(probe)) for probe in [ mean_transmitted, min_intensity_transmitted, @@ -2986,4 +3067,4 @@ def _return_object_fft( obj = np.angle(obj) obj = self._crop_rotate_object_fov(np.sum(obj, axis=0)) - return np.abs(np.fft.fftshift(np.fft.fft2(obj))) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))) \ No newline at end of file diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py index 9eda852a1..8691a121d 100644 --- a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py @@ -63,7 +63,9 @@ class OverlapMagneticTomographicReconstruction(PtychographicReconstruction): - \beta tilt around x-axis - -\alpha tilt around z-axis semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels rolloff: float, optional Semiangle rolloff for the initial probe guess vacuum_probe_intensity: np.ndarray, optional @@ -106,6 +108,7 @@ def __init__( tilt_angles_deg: Sequence[Tuple[float, float]], datacube: Sequence[DataCube] = None, semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, rolloff: float = 2.0, vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, @@ -171,6 +174,7 @@ def __init__( self._scan_positions = initial_scan_positions self._energy = energy self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px @@ -327,9 +331,9 @@ def _euler_angle_rotate_volume( """ Rotate 3D volume using alpha, beta, gamma Euler angles according to convention: - - \-alpha tilt around first axis (z) - - \beta tilt around second axis (x) - - \alpha tilt around first axis (z) + - \\-alpha tilt around first axis (z) + - \\beta tilt around second axis (x) + - \\alpha tilt around first axis (z) Note: since we store array as zxy, the x- and y-axis rotations flip sign below. 
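The doubled backslashes in the Euler-angle docstring above are not cosmetic: in a regular (non-raw) Python string, \b and \a are escape sequences (backspace and bell), so "\beta" silently corrupts the rendered docstring, and the invalid escape \- raises a DeprecationWarning. A quick illustration:

print(repr("\beta"))    # '\x08eta' : \b became a backspace character
print(repr("\\beta"))   # '\\beta'  : the doubled backslash survives as \beta
print(repr(r"\beta"))   # '\\beta'  : a raw string is the other common fix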
@@ -422,6 +426,9 @@ def preprocess( diffraction_patterns_transpose: bool = None, force_com_shifts: Sequence[float] = None, progress_bar: bool = True, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, **kwargs, ): @@ -458,6 +465,12 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. One tuple per tilt. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded @@ -551,6 +564,9 @@ def preprocess( intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube[tilt_index], require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, ) ( @@ -600,10 +616,17 @@ def preprocess( self._scan_positions[tilt_index] ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + # Object Initialization if self._object is None: - pad_x, pad_y = self._object_padding_px - p, q = np.max(self._positions_px_all, axis=0) + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px_all, axis=0)) p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( "int" ) @@ -654,12 +677,10 @@ def preprocess( probe_x0, probe_y0 = get_CoM( self._vacuum_probe_intensity, device=self._device ) - shift_x = self._region_of_interest_shape[0] // 2 - probe_x0 - shift_y = self._region_of_interest_shape[1] // 2 - probe_y0 self._vacuum_probe_intensity = get_shifted_ar( self._vacuum_probe_intensity, - shift_x, - shift_y, + -probe_x0, + -probe_y0, bilinear=True, device=self._device, ) @@ -708,6 +729,7 @@ def preprocess( self._probe = xp.asarray(self._probe, dtype=xp.complex64) self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) self._known_aberrations_array = ComplexProbe( energy=self._energy, @@ -717,8 +739,6 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._known_aberrations_array = xp.fft.ifftshift(self._known_aberrations_array) - # Precomputed propagator arrays self._slice_thicknesses = np.tile( self._object_shape[1] * self.sampling[1] / self._num_slices, @@ -795,7 +815,7 @@ def preprocess( # initial probe complex_probe_rgb = Complex2RGB( - asnumpy(self._probe), + self.probe_centered, vmin=vmin, vmax=vmax, hue_start=hue_start, @@ -810,7 +830,7 @@ def preprocess( propagated_probe, self._propagator_arrays[s] ) complex_propagated_rgb = Complex2RGB( - asnumpy(propagated_probe), + asnumpy(self._return_centered_probe(propagated_probe)), vmin=vmin, vmax=vmax, hue_start=hue_start, @@ -885,6 +905,10 @@ def preprocess( self._preprocessed = True + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _overlap_projection( @@ -1584,14 +1608,15 @@ def _object_gaussian_constraint(self, 
current_object, gaussian_filter_sigma): Constrained object estimate """ gaussian_filter = self._gaussian_filter - xp = self._xp - gaussian_filter_sigma /= xp.sqrt(self.sampling[0] ** 2 + self.sampling[1] ** 2) + gaussian_filter_sigma /= self.sampling[0] current_object = gaussian_filter(current_object, gaussian_filter_sigma) return current_object - def _object_butterworth_constraint(self, current_object, q_lowpass, q_highpass): + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): """ Butterworth filter @@ -1603,6 +1628,8 @@ def _object_butterworth_constraint(self, current_object, q_lowpass, q_highpass): Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter Returns -------- @@ -1618,9 +1645,9 @@ def _object_butterworth_constraint(self, current_object, q_lowpass, q_highpass): env = xp.ones_like(qra) if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** 4) + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** 4) + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) current_object_mean = xp.mean(current_object) current_object -= current_object_mean @@ -1658,15 +1685,17 @@ def _constraints( current_probe, current_positions, fix_com, - symmetrize_probe, - probe_gaussian_filter, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude, - fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, fix_positions, global_affine_transformation, gaussian_filter, @@ -1677,6 +1706,7 @@ def _constraints( q_lowpass_m, q_highpass_e, q_highpass_m, + butterworth_order, object_positivity, shrinkage_rad, object_mask, @@ -1695,25 +1725,28 @@ def _constraints( Current positions estimate fix_com: bool If True, probe CoM is fixed to the center - symmetrize_probe: bool - If True, the probe is radially-averaged - probe_gaussian_filter: bool - If True, applies reciprocal-space gaussian filtering on residual aberrations - probe_gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed - fix_probe_amplitude: bool + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat 
sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_probe_aperture: bool, + If True, probe Fourier amplitude is replaced by initial probe aperture. + initial_probe_aperture: np.ndarray, + Initial probe aperture to use in replacing probe Fourier amplitude. fix_positions: bool If True, positions are not updated gaussian_filter: bool @@ -1732,6 +1765,8 @@ def _constraints( Cut-off frequency in A^-1 for high-pass filtering electrostatic object q_highpass_m: float Cut-off frequency in A^-1 for high-pass filtering magnetic object + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter object_positivity: bool If True, forces object to be positive shrinkage_rad: float @@ -1766,21 +1801,25 @@ def _constraints( current_object[0], q_lowpass_e, q_highpass_e, + butterworth_order, ) current_object[1] = self._object_butterworth_constraint( current_object[1], q_lowpass_m, q_highpass_m, + butterworth_order, ) current_object[2] = self._object_butterworth_constraint( current_object[2], q_lowpass_m, q_highpass_m, + butterworth_order, ) current_object[3] = self._object_butterworth_constraint( current_object[3], q_lowpass_m, q_highpass_m, + butterworth_order, ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1796,27 +1835,30 @@ def _constraints( if fix_com: current_probe = self._probe_center_of_mass_constraint(current_probe) - if probe_gaussian_filter: - current_probe = self._probe_residual_aberration_filtering_constraint( + if fix_probe_aperture: + current_probe = self._probe_aperture_constraint( current_probe, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe = self._probe_fourier_amplitude_constraint( + current_probe, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, ) - if symmetrize_probe: - current_probe = self._probe_radial_symmetrization_constraint(current_probe) - - if fix_probe_amplitude: - current_probe = self._probe_amplitude_constraint( + if fit_probe_aberrations: + current_probe = self._probe_aberration_fitting_constraint( current_probe, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, ) - elif fix_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( + + if constrain_probe_amplitude: + current_probe = self._probe_amplitude_constraint( current_probe, - fix_probe_fourier_amplitude_threshold, - fix_probe_amplitude_relative_width, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, ) if not fix_positions: @@ -1836,33 +1878,38 @@ def reconstruct( max_iter: int = 64, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: 
float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, max_batch_size: int = None, seed_random: int = None, - step_size: float = 0.9, + step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, fix_com: bool = True, fix_probe_iter: int = 0, - symmetrize_probe_iter: int = 0, - fix_probe_amplitude_iter: int = 0, - fix_probe_amplitude_relative_radius: float = 0.5, - fix_probe_amplitude_relative_width: float = 0.05, - fix_probe_fourier_amplitude_iter: int = 0, - fix_probe_fourier_amplitude_threshold: float = 0.9, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, constrain_position_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma_e: float = None, gaussian_filter_sigma_m: float = None, gaussian_filter_iter: int = np.inf, - probe_gaussian_filter_sigma: float = None, - probe_gaussian_filter_residual_aberrations_iter: int = np.inf, - probe_gaussian_filter_fix_amplitude: bool = True, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, butterworth_filter_iter: int = np.inf, q_lowpass_e: float = None, q_lowpass_m: float = None, q_highpass_e: float = None, q_highpass_m: float = None, + butterworth_order: float = 2, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1880,7 +1927,7 @@ def reconstruct( Maximum number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: - "generalized-projection", + "generalized-projections", "DM_AP" (or "difference-map_alternating-projections"), "RAAR" (or "relaxed-averaged-alternating-reflections"), "RRR" (or "relax-reflect-reflect"), @@ -1888,8 +1935,12 @@ def reconstruct( "GD" (or "gradient_descent") reconstruction_parameter: float, optional Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter: float, optional - Tuning parameter to interpolate b/w DM-AP and DM-RAAR + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. 
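The many *_iter arguments in this signature follow one convention: each constraint is applied for the first N iterations, and probe-side constraints additionally wait until the probe itself starts updating, as the gating expressions in the loop hunk further below show. A small runnable sketch of that schedule; the function is hypothetical and only the gating logic mirrors the diff:

import numpy as np

def constraint_schedule(a0, fix_probe_iter=0, constrain_probe_amplitude_iter=0,
                        fix_probe_aperture_iter=0, fix_positions_iter=np.inf):
    # a0 is the zero-based iteration index, as in the reconstruction loop;
    # in the source, fix_com is additionally ANDed with the user's fix_com flag
    return dict(
        constrain_probe_amplitude=(
            a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter
        ),
        fix_probe_aperture=a0 < fix_probe_aperture_iter,
        fix_positions=a0 < fix_positions_iter,
    )

# probe amplitude support is active on iterations 5..9 only:
constraint_schedule(7, fix_probe_iter=5, constrain_probe_amplitude_iter=10)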
max_batch_size: int, optional Max number of probes to update at once seed_random: int, optional @@ -1904,19 +1955,20 @@ def reconstruct( If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate - symmetrize_probe_iter: int, optional - Number of iterations to run before radially-averaging the probe - fix_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate constrain_position_distance: float, optional @@ -1930,18 +1982,20 @@ def reconstruct( Standard deviation of gaussian kernel for magnetic object in A gaussian_filter_iter: int, optional Number of iterations to run using object smoothness constraint - probe_gaussian_filter_sigma: float, optional - Standard deviation of probe gaussian kernel in A^-1 - probe_gaussian_filter_residual_aberrations_iter: int, optional - Number of iterations to run using probe smoothing of residual aberrations - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: int + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: int + Max radial order of probe aberrations basis functions butterworth_filter_iter: int, optional Number of iterations to run using high-pass butterworth filter q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order.
Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1963,17 +2017,23 @@ def reconstruct( # Reconstruction method - if reconstruction_method == "generalized-projection": - if np.array(reconstruction_parameter).shape != (3,): + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): raise ValueError( ( - "reconstruction_parameter must be a list of three numbers " - "when using `reconstriction_method`=generalized-projection." + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." ) ) use_projection_scheme = True - projection_a, projection_b, projection_c = reconstruction_parameter + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c step_size = None elif ( reconstruction_method == "DM_AP" @@ -2032,7 +2092,8 @@ def reconstruct( else: raise ValueError( ( - "reconstruction_method must be one of 'DM_AP' (or 'difference-map_alternating-projections'), " + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " "'RRR' (or 'relax-reflect-reflect'), " "'SUPERFLIP' (or 'charge-flipping'), " @@ -2382,20 +2443,21 @@ def reconstruct( self._probe, self._positions_px_all[start_tilt:end_tilt], fix_com=fix_com and a0 >= fix_probe_iter, - symmetrize_probe=a0 < symmetrize_probe_iter, - probe_gaussian_filter=a0 - < probe_gaussian_filter_residual_aberrations_iter - and probe_gaussian_filter_sigma is not None, - probe_gaussian_filter_sigma=probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude=probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude=a0 < fix_probe_amplitude_iter + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter and a0 >= fix_probe_iter, - fix_probe_amplitude_relative_radius=fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width=fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude=a0 - < fix_probe_fourier_amplitude_iter + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter and a0 >= fix_probe_iter, - fix_probe_fourier_amplitude_threshold=fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -2408,6 +2470,7 @@ def reconstruct( q_lowpass_m=q_lowpass_m, q_highpass_e=q_highpass_e, q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, object_positivity=object_positivity, 
shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -2424,24 +2487,30 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - (self._object, self._probe, _,) = self._constraints( + ( + self._object, + self._probe, + _, + ) = self._constraints( self._object, self._probe, None, fix_com=fix_com and a0 >= fix_probe_iter, - symmetrize_probe=a0 < symmetrize_probe_iter, - probe_gaussian_filter=a0 - < probe_gaussian_filter_residual_aberrations_iter - and probe_gaussian_filter_sigma is not None, - probe_gaussian_filter_sigma=probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude=probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude=a0 < fix_probe_amplitude_iter + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, - fix_probe_amplitude_relative_radius=fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width=fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude=a0 < fix_probe_fourier_amplitude_iter + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter and a0 >= fix_probe_iter, - fix_probe_fourier_amplitude_threshold=fix_probe_fourier_amplitude_threshold, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, fix_positions=True, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -2454,6 +2523,7 @@ def reconstruct( q_lowpass_m=q_lowpass_m, q_highpass_e=q_highpass_e, q_highpass_m=q_highpass_m, + butterworth_order=butterworth_order, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -2465,13 +2535,17 @@ def reconstruct( self.error_iterations.append(error.item()) if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(asnumpy(self._probe.copy())) + self.probe_iterations.append(self.probe_centered) # store result self.object = asnumpy(self._object) - self.probe = asnumpy(self._probe) + self.probe = self.probe_centered self.error = error.item() + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _crop_rotate_object_manually( @@ -3010,7 +3084,7 @@ def _return_object_fft( rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims ) - return np.abs(np.fft.fftshift(np.fft.fft2(rotated_object))) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rotated_object)))) def show_object_fft( self, diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py index 3337d02cc..d6bee12fd 100644 --- a/py4DSTEM/process/phase/iterative_overlap_tomography.py +++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py @@ -58,7 +58,9 @@ class 
OverlapTomographicReconstruction(PtychographicReconstruction): tilt_angles_deg: Sequence[float] List of tilt angles in degrees, semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels rolloff: float, optional Semiangle rolloff for the initial probe guess vacuum_probe_intensity: np.ndarray, optional @@ -101,6 +103,7 @@ def __init__( tilt_angles_deg: Sequence[float], datacube: Sequence[DataCube] = None, semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, rolloff: float = 2.0, vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, @@ -172,6 +175,7 @@ def __init__( self._scan_positions = initial_scan_positions self._energy = energy self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px @@ -254,40 +258,6 @@ def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray): return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array) - def _expand_or_project_sliced_object(self, array: np.ndarray, output_z): - """ - OLD Version - - Expands supersliced object or projects voxel-sliced object. - - Parameters - ---------- - array: np.ndarray - 3D array to expand/project - output_z: int - Output_dimension to expand/project array to. - If output_z > array.shape[0] array is expanded, else it's projected - - Returns - ------- - expanded_or_projected_array: np.ndarray - expanded or projected array - """ - zoom = self._zoom - input_z = array.shape[0] - - return ( - zoom( - array, - (output_z / input_z, 1, 1), - order=0, - mode="nearest", - grid_mode=True, - ) - * input_z - / output_z - ) - def _project_sliced_object(self, array: np.ndarray, output_z): """ Expands supersliced object or projects voxel-sliced object. @@ -365,6 +335,9 @@ def preprocess( diffraction_patterns_rotate_degrees: float = None, diffraction_patterns_transpose: bool = None, force_com_shifts: Sequence[float] = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, progress_bar: bool = True, object_fov_mask: np.ndarray = None, **kwargs, @@ -402,6 +375,12 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. One tuple per tilt. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. 
Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded @@ -494,6 +473,9 @@ def preprocess( intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube[tilt_index], require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, ) ( @@ -543,10 +525,17 @@ def preprocess( self._scan_positions[tilt_index] ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + # Object Initialization if self._object is None: - pad_x, pad_y = self._object_padding_px - p, q = np.max(self._positions_px_all, axis=0) + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px_all, axis=0)) p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( "int" ) @@ -597,12 +586,10 @@ def preprocess( probe_x0, probe_y0 = get_CoM( self._vacuum_probe_intensity, device=self._device ) - shift_x = self._region_of_interest_shape[0] // 2 - probe_x0 - shift_y = self._region_of_interest_shape[1] // 2 - probe_y0 self._vacuum_probe_intensity = get_shifted_ar( self._vacuum_probe_intensity, - shift_x, - shift_y, + -probe_x0, + -probe_y0, bilinear=True, device=self._device, ) @@ -651,6 +638,7 @@ def preprocess( self._probe = xp.asarray(self._probe, dtype=xp.complex64) self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) self._known_aberrations_array = ComplexProbe( energy=self._energy, @@ -660,8 +648,6 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._known_aberrations_array = xp.fft.ifftshift(self._known_aberrations_array) - # Precomputed propagator arrays self._slice_thicknesses = np.tile( self._object_shape[1] * self.sampling[1] / self._num_slices, @@ -741,7 +727,7 @@ def preprocess( # initial probe complex_probe_rgb = Complex2RGB( - asnumpy(self._probe), + self.probe_centered, vmin=vmin, vmax=vmax, hue_start=hue_start, @@ -756,7 +742,7 @@ def preprocess( propagated_probe, self._propagator_arrays[s] ) complex_propagated_rgb = Complex2RGB( - asnumpy(propagated_probe), + asnumpy(self._return_centered_probe(propagated_probe)), vmin=vmin, vmax=vmax, hue_start=hue_start, @@ -831,6 +817,10 @@ def preprocess( self._preprocessed = True + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _overlap_projection(self, current_object, current_probe): @@ -1491,14 +1481,15 @@ def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma): Constrained object estimate """ gaussian_filter = self._gaussian_filter - xp = self._xp - gaussian_filter_sigma /= xp.sqrt(self.sampling[0] ** 2 + self.sampling[1] ** 2) + gaussian_filter_sigma /= self.sampling[0] current_object = gaussian_filter(current_object, gaussian_filter_sigma) return current_object - def _object_butterworth_constraint(self, current_object, q_lowpass, q_highpass): + def _object_butterworth_constraint( + self, current_object, q_lowpass, q_highpass, butterworth_order + ): """ Butterworth filter @@ -1510,7 +1501,8 @@ def _object_butterworth_constraint(self, current_object, q_lowpass, q_highpass): Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter - + butterworth_order: float + Butterworth filter order. 
Smaller gives a smoother filter Returns -------- constrained_object: np.ndarray @@ -1525,9 +1517,9 @@ def _object_butterworth_constraint(self, current_object, q_lowpass, q_highpass): env = xp.ones_like(qra) if q_highpass: - env *= 1 - 1 / (1 + (qra / q_highpass) ** 4) + env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order)) if q_lowpass: - env *= 1 / (1 + (qra / q_lowpass) ** 4) + env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order)) current_object_mean = xp.mean(current_object) current_object -= current_object_mean @@ -1541,15 +1533,17 @@ def _constraints( current_probe, current_positions, fix_com, - symmetrize_probe, - probe_gaussian_filter, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude, - fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, fix_positions, global_affine_transformation, gaussian_filter, @@ -1557,6 +1551,7 @@ def _constraints( butterworth_filter, q_lowpass, q_highpass, + butterworth_order, object_positivity, shrinkage_rad, object_mask, @@ -1574,25 +1569,28 @@ def _constraints( Current positions estimate fix_com: bool If True, probe CoM is fixed to the center - symmetrize_probe_iter: int, optional - Number of iterations to run before radially-averaging the probe - probe_gaussian_filter: bool - If True, applies reciprocal-space gaussian filtering on residual aberrations - probe_gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed - fix_probe_amplitude: bool + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: int + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: int + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity.
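# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): the band-pass envelope that
# _object_butterworth_constraint now builds, in plain NumPy. The array shape
# and sampling values are illustrative assumptions; only the envelope formula,
# with its new 2 * butterworth_order exponent, comes from the hunk above.

import numpy as np

def butterworth_envelope(shape, sampling, q_lowpass, q_highpass, order=2.0):
    qx = np.fft.fftfreq(shape[0], sampling[0])
    qy = np.fft.fftfreq(shape[1], sampling[1])
    qra = np.sqrt(qx[:, None] ** 2 + qy[None, :] ** 2)
    env = np.ones_like(qra)
    if q_highpass:
        env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * order))
    if q_lowpass:
        env *= 1 / (1 + (qra / q_lowpass) ** (2 * order))
    return env  # smaller `order` gives the smoother roll-off noted above

env = butterworth_envelope((256, 256), (0.2, 0.2), q_lowpass=1.0, q_highpass=0.05)
# ---------------------------------------------------------------------------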
+ fix_probe_aperture: bool, + If True, probe Fourier amplitude is replaced by initial probe aperture. + initial_probe_aperture: np.ndarray, + Initial probe aperture to use in replacing probe Fourier amplitude. fix_positions: bool If True, positions are not updated gaussian_filter: bool @@ -1605,6 +1603,8 @@ def _constraints( Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter object_positivity: bool If True, forces object to be positive shrinkage_rad: float @@ -1632,6 +1632,7 @@ def _constraints( current_object, q_lowpass, q_highpass, + butterworth_order, ) if shrinkage_rad > 0.0 or object_mask is not None: @@ -1647,27 +1648,30 @@ def _constraints( if fix_com: current_probe = self._probe_center_of_mass_constraint(current_probe) - if probe_gaussian_filter: - current_probe = self._probe_residual_aberration_filtering_constraint( + if fix_probe_aperture: + current_probe = self._probe_aperture_constraint( current_probe, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe = self._probe_fourier_amplitude_constraint( + current_probe, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, ) - if symmetrize_probe: - current_probe = self._probe_radial_symmetrization_constraint(current_probe) - - if fix_probe_amplitude: - current_probe = self._probe_amplitude_constraint( + if fit_probe_aberrations: + current_probe = self._probe_aberration_fitting_constraint( current_probe, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, ) - elif fix_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( + + if constrain_probe_amplitude: + current_probe = self._probe_amplitude_constraint( current_probe, - fix_probe_fourier_amplitude_threshold, - fix_probe_amplitude_relative_width, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, ) if not fix_positions: @@ -1687,30 +1691,35 @@ def reconstruct( max_iter: int = 64, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, max_batch_size: int = None, seed_random: int = None, - step_size: float = 0.9, + step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, fix_com: bool = True, fix_probe_iter: int = 0, - symmetrize_probe_iter: int = 0, - fix_probe_amplitude_iter: int = 0, - fix_probe_amplitude_relative_radius: float = 0.5, - fix_probe_amplitude_relative_width: float = 0.05, - fix_probe_fourier_amplitude_iter: int = 0, - fix_probe_fourier_amplitude_threshold: float = 0.9, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, constrain_position_distance: float = None, global_affine_transformation: bool = True, 
gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, - probe_gaussian_filter_sigma: float = None, - probe_gaussian_filter_residual_aberrations_iter: int = np.inf, - probe_gaussian_filter_fix_amplitude: bool = True, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, + butterworth_order: float = 2, object_positivity: bool = True, shrinkage_rad: float = 0.0, fix_potential_baseline: bool = True, @@ -1728,7 +1737,7 @@ def reconstruct( Maximum number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: - "generalized-projection", + "generalized-projections", "DM_AP" (or "difference-map_alternating-projections"), "RAAR" (or "relaxed-averaged-alternating-reflections"), "RRR" (or "relax-reflect-reflect"), @@ -1736,8 +1745,12 @@ def reconstruct( "GD" (or "gradient_descent") reconstruction_parameter: float, optional Reconstruction parameter for various reconstruction methods above. - reconstruction_parameter: float, optional - Tuning parameter to interpolate b/w DM-AP and DM-RAAR + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. max_batch_size: int, optional Max number of probes to update at once seed_random: int, optional @@ -1752,19 +1765,20 @@ def reconstruct( If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate - symmetrize_probe_iter: int, optional - Number of iterations to run with a fixed probe before updating probe estimate - fix_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
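# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): the scheduling convention behind
# the renamed `*_iter` arguments above. Each constraint is active for the
# first `*_iter` iterations, and probe constraints additionally wait until
# the probe itself is being updated (a0 >= fix_probe_iter), matching the
# `a0 < constrain_..._iter and a0 >= fix_probe_iter` gating in the call
# sites further down. Helper name is illustrative; the patch inlines this.

def probe_constraint_active(a0, constraint_iter, fix_probe_iter=0):
    return a0 < constraint_iter and a0 >= fix_probe_iter

assert probe_constraint_active(a0=3, constraint_iter=10, fix_probe_iter=2)
assert not probe_constraint_active(a0=1, constraint_iter=10, fix_probe_iter=2)
assert not probe_constraint_active(a0=12, constraint_iter=10, fix_probe_iter=2)
# ---------------------------------------------------------------------------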
fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate constrain_position_distance: float, optional @@ -1776,18 +1790,20 @@ def reconstruct( Standard deviation of gaussian kernel in A gaussian_filter_iter: int, optional Number of iterations to run using object smoothness constraint - probe_gaussian_filter_sigma: float, optional - Standard deviation of probe gaussian kernel in A^-1 - probe_gaussian_filter_residual_aberrations_iter: int, optional - Number of iterations to run using probe smoothing of residual aberrations - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: int + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: int + Max radial order of probe aberrations basis functions butterworth_filter_iter: int, optional Number of iterations to run using high-pass butterworth filter q_lowpass: float Cut-off frequency in A^-1 for low-pass butterworth filter q_highpass: float Cut-off frequency in A^-1 for high-pass butterworth filter + butterworth_order: float + Butterworth filter order. Smaller gives a smoother filter object_positivity: bool, optional If True, forces object to be positive shrinkage_rad: float @@ -1809,17 +1825,23 @@ def reconstruct( # Reconstruction method - if reconstruction_method == "generalized-projection": - if np.array(reconstruction_parameter).shape != (3,): + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): raise ValueError( ( - "reconstruction_parameter must be a list of three numbers " - "when using `reconstriction_method`=generalized-projection." + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'."
) ) use_projection_scheme = True - projection_a, projection_b, projection_c = reconstruction_parameter + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c step_size = None elif ( reconstruction_method == "DM_AP" @@ -1878,7 +1900,8 @@ def reconstruct( else: raise ValueError( ( - "reconstruction_method must be one of 'DM_AP' (or 'difference-map_alternating-projections'), " + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " "'RRR' (or 'relax-reflect-reflect'), " "'SUPERFLIP' (or 'charge-flipping'), " @@ -2149,20 +2172,21 @@ def reconstruct( self._probe, self._positions_px_all[start_tilt:end_tilt], fix_com=fix_com and a0 >= fix_probe_iter, - symmetrize_probe=a0 < symmetrize_probe_iter, - probe_gaussian_filter=a0 - < probe_gaussian_filter_residual_aberrations_iter - and probe_gaussian_filter_sigma is not None, - probe_gaussian_filter_sigma=probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude=probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude=a0 < fix_probe_amplitude_iter + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter and a0 >= fix_probe_iter, - fix_probe_amplitude_relative_radius=fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width=fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude=a0 - < fix_probe_fourier_amplitude_iter + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter and a0 >= fix_probe_iter, - fix_probe_fourier_amplitude_threshold=fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -2172,6 +2196,7 @@ def reconstruct( and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, + butterworth_order=butterworth_order, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -2186,24 +2211,30 @@ def reconstruct( if collective_tilt_updates: self._object += collective_object / self._num_tilts - (self._object, self._probe, _,) = self._constraints( + ( + self._object, + self._probe, + _, + ) = self._constraints( self._object, self._probe, None, fix_com=fix_com and a0 >= fix_probe_iter, - symmetrize_probe=a0 < symmetrize_probe_iter, - probe_gaussian_filter=a0 - < probe_gaussian_filter_residual_aberrations_iter - and probe_gaussian_filter_sigma is not None, - probe_gaussian_filter_sigma=probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude=probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude=a0 < fix_probe_amplitude_iter 
+ constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter and a0 >= fix_probe_iter, - fix_probe_amplitude_relative_radius=fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width=fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude=a0 < fix_probe_fourier_amplitude_iter + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter and a0 >= fix_probe_iter, - fix_probe_fourier_amplitude_threshold=fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, fix_positions=True, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -2213,6 +2244,7 @@ def reconstruct( and (q_lowpass is not None or q_highpass is not None), q_lowpass=q_lowpass, q_highpass=q_highpass, + butterworth_order=butterworth_order, object_positivity=object_positivity, shrinkage_rad=shrinkage_rad, object_mask=self._object_fov_mask_inverse @@ -2224,13 +2256,17 @@ def reconstruct( self.error_iterations.append(error.item()) if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(asnumpy(self._probe.copy())) + self.probe_iterations.append(self.probe_centered) # store result self.object = asnumpy(self._object) - self.probe = asnumpy(self._probe) + self.probe = self.probe_centered self.error = error.item() + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _crop_rotate_object_manually( @@ -2424,12 +2460,20 @@ def _visualize_last_iteration( 0, ] - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] if plot_convergence: if plot_probe or plot_fourier_probe: @@ -2493,19 +2537,21 @@ def _visualize_last_iteration( self.probe_fourier, hue_start=hue_start, invert=invert ) ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( self.probe, hue_start=hue_start, invert=invert ) ax.set_title("Reconstructed probe") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, **kwargs, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") if cbar: divider = make_axes_locatable(ax) @@ -2670,12 +2716,20 @@ def _visualize_all_iterations( 0, ] - 
probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] if plot_convergence: if plot_probe or plot_fourier_probe: @@ -2729,16 +2783,24 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: probe_array = Complex2RGB( - asnumpy(self._return_fourier_probe(probes[grid_range[n]])), + asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]] + ) + ), hue_start=hue_start, invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( probes[grid_range[n]], hue_start=hue_start, invert=invert ) ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") im = ax.imshow( probe_array, @@ -2746,9 +2808,6 @@ def _visualize_all_iterations( **kwargs, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: add_colorbar_arg( grid.cbar_axes[n], hue_start=hue_start, invert=invert @@ -2893,7 +2952,7 @@ def _return_object_fft( rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims ) - return np.abs(np.fft.fftshift(np.fft.fft2(rotated_object))) + return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rotated_object)))) def show_object_fft( self, diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py index 94f32ad12..80cdd8cd8 100644 --- a/py4DSTEM/process/phase/iterative_parallax.py +++ b/py4DSTEM/process/phase/iterative_parallax.py @@ -79,6 +79,7 @@ def __init__( # Metadata self._energy = energy self._verbose = verbose + self._device = device self._object_padding_px = object_padding_px self._preprocessed = False @@ -170,7 +171,7 @@ def preprocess( self._datacube, require_calibrations=True, ) - + self._intensities = xp.asarray(self._intensities, dtype=xp.float32) # make sure mean diffraction pattern is shaped correctly if (self._dp_mean.shape[0] != self._intensities.shape[2]) or ( self._dp_mean.shape[1] != self._intensities.shape[3] @@ -375,6 +376,11 @@ def preprocess( ax.set_title("Average Bright Field Image") self._preprocessed = True + + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def tune_angle_and_defocus( @@ -725,7 +731,7 @@ def reconstruct( G_ref, G, upsample_factor=upsample_factor, - device="cpu" if xp is np else "gpu", + device=self._device, ) dx = ( @@ -825,6 +831,10 @@ def reconstruct( self.recon_BF = asnumpy(self._recon_BF) + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def aberration_fit( @@ -874,6 +884,10 @@ def aberration_fit( ) / 2.0 # factor /2 for A1 astigmatism? /4? 
self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0 + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + # Print results if self._verbose: print( @@ -1059,6 +1073,10 @@ def aberration_correct( self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr)) self.recon_phase_corrected = asnumpy(self._recon_phase_corrected) + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + # plotting if plot_corrected_phase: figsize = kwargs.pop("figsize", (6, 6)) @@ -1217,6 +1235,11 @@ def depth_section( ax.set_yticks([]) ax.set_title(f"Depth section: {dz}A") + if self._device == "gpu": + xp = self._xp + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return stack_depth def _crop_padded_object( diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py index d0a637204..9af22ba92 100644 --- a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py +++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py @@ -3,6 +3,8 @@ array_slice, estimate_global_transformation_ransac, fft_shift, + fit_aberration_surface, + regularize_probe_amplitude, ) from py4DSTEM.process.utils import get_CoM @@ -120,7 +122,7 @@ def _object_gaussian_constraint( """ xp = self._xp gaussian_filter = self._gaussian_filter - gaussian_filter_sigma /= xp.sqrt(self.sampling[0] ** 2 + self.sampling[1] ** 2) + gaussian_filter_sigma /= self.sampling[0] if pure_phase_object: phase = xp.angle(current_object) @@ -315,7 +317,7 @@ def _object_denoise_tv_chambolle( def _probe_center_of_mass_constraint(self, current_probe): """ Ptychographic center of mass constraint. - Used for centering probe intensity. + Used for centering corner-centered probe intensity. 
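# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): what "corner-centered" means in the
# _probe_center_of_mass_constraint docstring above. The probe is stored with
# its origin at array index (0, 0), so the CoM is measured on fftfreq-style
# coordinates and the probe is shifted by its negative. Pure-NumPy
# illustration only; py4DSTEM's get_CoM (corner_centered=True) and fft_shift
# provide the actual CPU/GPU implementations used in the hunk below.

import numpy as np

def corner_centered_com(intensity):
    nx, ny = intensity.shape
    x = np.fft.fftfreq(nx, d=1 / nx)[:, None]  # 0, 1, ..., -2, -1
    y = np.fft.fftfreq(ny, d=1 / ny)[None, :]
    total = intensity.sum()
    return (x * intensity).sum() / total, (y * intensity).sum() / total

peak = np.zeros((8, 8))
peak[1, 0] = 1.0
assert corner_centered_com(peak) == (1.0, 0.0)
# ---------------------------------------------------------------------------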
Parameters -------- @@ -329,78 +331,15 @@ def _probe_center_of_mass_constraint(self, current_probe): """ xp = self._xp - probe_center = xp.array(self._region_of_interest_shape) / 2 probe_intensity = xp.abs(current_probe) ** 2 probe_x0, probe_y0 = get_CoM( - probe_intensity, device="cpu" if xp is np else "gpu" - ) - shifted_probe = fft_shift( - current_probe, probe_center - xp.array([probe_x0, probe_y0]), xp + probe_intensity, device=self._device, corner_centered=True ) + shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp) return shifted_probe - def _probe_radial_symmetrization_constraint_base( - self, - current_probe, - num_bins=None, - center=None, - ): - xp = self._xp - - sx, sy = current_probe.shape - - if center is None: - center = (sx // 2, sy // 2) - - if num_bins is None: - num_bins = np.maximum(sx, sy) * 2 + 1 - - cx, cy = center - X, Y = xp.ogrid[0:sx, 0:sy] - r = xp.hypot(X - cx, Y - cy) - - rbin = (num_bins * r / r.max()).astype("int") - num = xp.bincount(rbin.ravel(), current_probe.ravel()) - denom = xp.bincount(rbin.ravel()) - denom[denom == 0] = 1 - - radial_mean = num / denom - - for r_bin, r_mean in enumerate(radial_mean): - if r_bin != 0.0: - current_probe[np.where(rbin == r_bin)] = r_mean - - return current_probe - - def _probe_radial_symmetrization_constraint( - self, - current_probe, - num_bins=None, - center=None, - ): - xp = self._xp - - current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) - fourier_probe = self._return_fourier_probe(current_probe) - - fourier_probe_real = fourier_probe.real.copy() - fourier_probe_imag = fourier_probe.imag.copy() - - fourier_probe_real = self._probe_radial_symmetrization_constraint_base( - fourier_probe_real, num_bins, center - ) - fourier_probe_imag = self._probe_radial_symmetrization_constraint_base( - fourier_probe_imag, num_bins, center - ) - - fourier_probe = fourier_probe_real + 1.0j * fourier_probe_imag - current_probe = xp.fft.ifftshift(xp.fft.ifft2(xp.fft.fftshift(fourier_probe))) - current_probe *= xp.sqrt(current_probe_sum / np.sum(np.abs(current_probe) ** 2)) - - return current_probe - def _probe_amplitude_constraint( self, current_probe, relative_radius, relative_width ): @@ -425,24 +364,26 @@ def _probe_amplitude_constraint( erf = self._erf probe_intensity = xp.abs(current_probe) ** 2 - current_probe_sum = xp.sum(probe_intensity) + #current_probe_sum = xp.sum(probe_intensity) - x = xp.linspace(-1 / 2, 1 / 2, current_probe.shape[0]) - y = xp.linspace(-1 / 2, 1 / 2, current_probe.shape[1]) - xa, ya = xp.meshgrid(x, y, indexing="ij") - ra = xp.sqrt(xa**2 + ya**2) - relative_radius + X = xp.fft.fftfreq(current_probe.shape[0])[:, None] + Y = xp.fft.fftfreq(current_probe.shape[1])[None] + r = xp.hypot(X, Y) - relative_radius sigma = np.sqrt(np.pi) / relative_width - tophat_mask = 0.5 * (1 - erf(sigma * ra / (1 - ra**2))) + tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2))) updated_probe = current_probe * tophat_mask - updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + #updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + #normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - return updated_probe * normalization + return updated_probe #* normalization def _probe_fourier_amplitude_constraint( - self, current_probe, threshold, relative_width + self, + current_probe, + width_max_pixels, + enforce_constant_intensity, ): """ Ptychographic top-hat filtering of Fourier probe. 
@@ -463,38 +404,62 @@ def _probe_fourier_amplitude_constraint( Constrained probe estimate """ xp = self._xp - erf = self._erf + asnumpy = self._asnumpy - current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + #current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) current_probe_fft = xp.fft.fft2(current_probe) - current_probe_fft_amp = xp.abs(current_probe_fft) - threshold_px = xp.argmax( - current_probe_fft_amp < xp.max(current_probe_fft_amp) * threshold + updated_probe_fft, _, _, _ = regularize_probe_amplitude( + asnumpy(current_probe_fft), + width_max_pixels=width_max_pixels, + nearest_angular_neighbor_averaging=5, + enforce_constant_intensity=enforce_constant_intensity, + corner_centered=True, ) - if threshold_px == 0: - return current_probe + updated_probe_fft = xp.asarray(updated_probe_fft) + updated_probe = xp.fft.ifft2(updated_probe_fft) + #updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + #normalization = xp.sqrt(current_probe_sum / updated_probe_sum) - qx = xp.fft.fftfreq(current_probe.shape[0], 1) - qy = xp.fft.fftfreq(current_probe.shape[1], 1) - qya, qxa = xp.meshgrid(qy, qx) - qra = xp.sqrt(qxa**2 + qya**2) - threshold_px / current_probe.shape[0] + return updated_probe #* normalization - sigma = np.sqrt(np.pi) / relative_width - tophat_mask = 0.5 * (1 - erf(sigma * qra / (1 - qra**2))) + def _probe_aperture_constraint( + self, + current_probe, + initial_probe_aperture, + ): + """ + Ptychographic constraint to fix Fourier amplitude to initial aperture. + + Parameters + ---------- + current_probe: np.ndarray + Current probe estimate - updated_probe = xp.fft.ifft2(current_probe_fft * tophat_mask) - updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) - normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + Returns + -------- + constrained_probe: np.ndarray + Constrained probe estimate + """ + xp = self._xp + + #current_probe_sum = xp.sum(xp.abs(current_probe) ** 2) + current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe)) - return updated_probe * normalization + updated_probe = xp.fft.ifft2( + xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture + ) + #updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2) + #normalization = xp.sqrt(current_probe_sum / updated_probe_sum) + + return updated_probe #* normalization - def _probe_residual_aberration_filtering_constraint( + def _probe_aberration_fitting_constraint( self, current_probe, - gaussian_filter_sigma, - fix_amplitude, + max_angular_order, + max_radial_order, ): """ Ptychographic probe aberration-fitting constraint.
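# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): the aperture constraint defined in
# _probe_aperture_constraint above, restated standalone -- keep the probe's
# Fourier phase, replace its Fourier amplitude with the stored initial
# aperture. The test data below is illustrative only.

import numpy as np

def apply_aperture_constraint(probe, initial_aperture):
    phase = np.angle(np.fft.fft2(probe))
    return np.fft.ifft2(np.exp(1j * phase) * initial_aperture)

rng = np.random.default_rng(0)
probe = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))
aperture = np.abs(np.fft.fft2(probe))  # stand-in for _probe_initial_aperture
constrained = apply_aperture_constraint(probe, aperture)
# ---------------------------------------------------------------------------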
@@ -516,25 +481,21 @@ def _probe_residual_aberration_filtering_constraint( """ xp = self._xp - gaussian_filter = self._gaussian_filter - known_aberrations_array = self._known_aberrations_array - gaussian_filter_sigma /= xp.sqrt( - self._reciprocal_sampling[0] ** 2 + self._reciprocal_sampling[1] ** 2 - ) - fourier_probe = self._return_fourier_probe(current_probe) - if fix_amplitude: - fourier_probe_abs = xp.abs(fourier_probe) + fourier_probe = xp.fft.fft2(current_probe) + fourier_probe_abs = xp.abs(fourier_probe) + sampling = self.sampling - fourier_probe *= xp.conjugate(known_aberrations_array) - fourier_probe = gaussian_filter(fourier_probe, gaussian_filter_sigma) - fourier_probe *= known_aberrations_array - - if fix_amplitude: - fourier_probe_angle = xp.angle(fourier_probe) - fourier_probe = fourier_probe_abs * xp.exp(1.0j * fourier_probe_angle) + fitted_angle, _ = fit_aberration_surface( + fourier_probe, + sampling, + max_angular_order, + max_radial_order, + xp=xp, + ) - current_probe = xp.fft.ifftshift(xp.fft.ifft2(xp.fft.fftshift(fourier_probe))) + fourier_probe = fourier_probe_abs * xp.exp(1.0j * fitted_angle) + current_probe = xp.fft.ifft2(fourier_probe) return current_probe diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py index c655b82a1..8881d021c 100644 --- a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py +++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py @@ -53,7 +53,9 @@ class SimultaneousPtychographicReconstruction(PtychographicReconstruction): simultaneous_measurements_mode: str, optional One of '-+', '-0+', '0+', where -/0/+ refer to the sign of the magnetic potential semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess + Semiangle cutoff for the initial probe guess in mrad + semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels rolloff: float, optional Semiangle rolloff for the initial probe guess vacuum_probe_intensity: np.ndarray, optional @@ -95,6 +97,7 @@ def __init__( datacube: Sequence[DataCube] = None, simultaneous_measurements_mode: str = "-+", semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, rolloff: float = 2.0, vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, @@ -160,6 +163,7 @@ def __init__( self._scan_positions = initial_scan_positions self._energy = energy self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px @@ -184,6 +188,9 @@ def preprocess( force_com_rotation: float = None, force_com_transpose: float = None, force_com_shifts: float = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, **kwargs, ): @@ -230,6 +237,12 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. 
Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded @@ -344,6 +357,9 @@ def preprocess( intensities_0 = self._extract_intensities_and_calibrations_from_datacube( measurement_0, require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, ) ( @@ -424,6 +440,9 @@ def preprocess( intensities_1 = self._extract_intensities_and_calibrations_from_datacube( measurement_1, require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, ) ( @@ -505,6 +524,9 @@ def preprocess( intensities_2 = self._extract_intensities_and_calibrations_from_datacube( measurement_2, require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, ) ( @@ -591,15 +613,22 @@ def preprocess( self._scan_positions ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + # Object Initialization if self._object is None: - pad_x, pad_y = self._object_padding_px - p, q = np.max(self._positions_px, axis=0) + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px, axis=0)) p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - int + "int" ) q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - int + "int" ) if self._object_type == "potential": object_e = xp.zeros((p, q), dtype=xp.float32) @@ -647,12 +676,10 @@ def preprocess( probe_x0, probe_y0 = get_CoM( self._vacuum_probe_intensity, device=self._device ) - shift_x = self._region_of_interest_shape[0] // 2 - probe_x0 - shift_y = self._region_of_interest_shape[1] // 2 - probe_y0 self._vacuum_probe_intensity = get_shifted_ar( self._vacuum_probe_intensity, - shift_x, - shift_y, + -probe_x0, + -probe_y0, bilinear=True, device=self._device, ) @@ -695,6 +722,7 @@ def preprocess( self._probe = xp.asarray(self._probe, dtype=xp.complex64) self._probe_initial = self._probe.copy() + self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) self._known_aberrations_array = ComplexProbe( energy=self._energy, @@ -704,8 +732,6 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._known_aberrations_array = xp.fft.ifftshift(self._known_aberrations_array) - # overlaps shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) probe_intensities = xp.abs(shifted_probes) ** 2 @@ -728,7 +754,7 @@ def preprocess( # initial probe complex_probe_rgb = Complex2RGB( - asnumpy(self._probe), + self.probe_centered, vmin=vmin, vmax=vmax, hue_start=hue_start, @@ -788,6 +814,10 @@ def preprocess( self._preprocessed = True + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _warmup_overlap_projection(self, current_object, current_probe): @@ -2173,34 +2203,6 @@ def _adjoint( return current_object, current_probe - def _probe_center_of_mass_constraint(self, current_probe): - """ - Ptychographic threshold constraint. - Used for avoiding the scaling ambiguity between probe and object. 
- - Parameters - -------- - current_probe: np.ndarray - Current probe estimate - - Returns - -------- - constrained_probe: np.ndarray - Constrained probe estimate - """ - xp = self._xp - asnumpy = self._asnumpy - - probe_center = xp.array(self._region_of_interest_shape) / 2 - probe_intensity = asnumpy(xp.abs(current_probe) ** 2) - - probe_x0, probe_y0 = get_CoM(probe_intensity) - shifted_probe = fft_shift( - current_probe, probe_center - xp.array([probe_x0, probe_y0]), xp - ) - - return shifted_probe - def _constraints( self, current_object, @@ -2208,15 +2210,17 @@ def _constraints( current_positions, pure_phase_object, fix_com, - symmetrize_probe, - probe_gaussian_filter, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude, - fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, fix_positions, global_affine_transformation, gaussian_filter, @@ -2248,19 +2252,28 @@ def _constraints( If True, object amplitude is set to unity fix_com: bool If True, probe CoM is fixed to the center - symmetrize_probe: bool - If True, the probe is radially-averaged - fix_probe_amplitude: bool + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: int + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: int + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_probe_aperture: bool, + If True, probe Fourier amplitude is replaced by initial probe aperture. + initial_probe_aperture: np.ndarray, + Initial probe aperture to use in replacing probe Fourier amplitude.
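# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): the top-hat support behind the
# constrain_probe_amplitude_relative_radius/width entries documented above,
# following the corner-centered erf sigmoid used by _probe_amplitude_constraint
# earlier in this patch. NumPy/SciPy stand in for the backend-agnostic xp and
# self._erf; shape and parameter values are illustrative.

import numpy as np
from scipy.special import erf

def tophat_support(shape, relative_radius=0.5, relative_width=0.05):
    X = np.fft.fftfreq(shape[0])[:, None]
    Y = np.fft.fftfreq(shape[1])[None]
    r = np.hypot(X, Y) - relative_radius
    sigma = np.sqrt(np.pi) / relative_width
    return 0.5 * (1 - erf(sigma * r / (1 - r**2)))

mask = tophat_support((64, 64))  # multiplies the real-space probe
# ---------------------------------------------------------------------------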
fix_positions: bool If True, positions are not updated gaussian_filter: bool @@ -2356,27 +2369,30 @@ def _constraints( if fix_com: current_probe = self._probe_center_of_mass_constraint(current_probe) - if probe_gaussian_filter: - current_probe = self._probe_residual_aberration_filtering_constraint( + if fix_probe_aperture: + current_probe = self._probe_aperture_constraint( current_probe, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe = self._probe_fourier_amplitude_constraint( + current_probe, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, ) - if symmetrize_probe: - current_probe = self._probe_radial_symmetrization_constraint(current_probe) - - if fix_probe_amplitude: - current_probe = self._probe_amplitude_constraint( + if fit_probe_aberrations: + current_probe = self._probe_aberration_fitting_constraint( current_probe, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, ) - elif fix_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( + + if constrain_probe_amplitude: + current_probe = self._probe_amplitude_constraint( current_probe, - fix_probe_fourier_amplitude_threshold, - fix_probe_amplitude_relative_width, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, ) if not fix_positions: @@ -2396,30 +2412,34 @@ def reconstruct( max_iter: int = 64, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, max_batch_size: int = None, seed_random: int = None, - step_size: float = 0.9, + step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, pure_phase_object_iter: int = 0, fix_com: bool = True, fix_probe_iter: int = 0, warmup_iter: int = 0, - symmetrize_probe_iter: int = 0, - fix_probe_amplitude_iter: int = 0, - fix_probe_amplitude_relative_radius: float = 0.5, - fix_probe_amplitude_relative_width: float = 0.05, - fix_probe_fourier_amplitude_iter: int = 0, - fix_probe_fourier_amplitude_threshold: float = 0.9, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, constrain_position_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma_e: float = None, gaussian_filter_sigma_m: float = None, gaussian_filter_iter: int = np.inf, - probe_gaussian_filter_sigma: float = None, - probe_gaussian_filter_residual_aberrations_iter: int = np.inf, - probe_gaussian_filter_fix_amplitude: bool = True, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + fit_probe_aberrations_max_radial_order: int = 4, butterworth_filter_iter: int = np.inf, q_lowpass_e: float = None, q_lowpass_m: float = None, @@ -2443,7 +2463,7 @@ def reconstruct( Maximum number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm 
to use, one of: - "generalized-projection", + "generalized-projections", "DM_AP" (or "difference-map_alternating-projections"), "RAAR" (or "relaxed-averaged-alternating-reflections"), "RRR" (or "relax-reflect-reflect"), @@ -2451,6 +2471,12 @@ def reconstruct( "GD" (or "gradient_descent") reconstruction_parameter: float, optional Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. max_batch_size: int, optional Max number of probes to update at once seed_random: int, optional @@ -2467,19 +2493,20 @@ def reconstruct( If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate - symmetrize_probe_iter: int, optional - Number of iterations to run before radially-averaging the probe - fix_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
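A hedged usage sketch of the renamed generalized-projections interface; `ptycho` is a preprocessed reconstruction object and the coefficient values are illustrative (in common generalized-projection conventions, (a, b, c) = (-1, 1, 2) reproduces a difference-map update):
>>> ptycho.reconstruct(
...     reconstruction_method="generalized-projections",
...     reconstruction_parameter_a=-1.0,
...     reconstruction_parameter_b=1.0,
...     reconstruction_parameter_c=2.0,
... )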
fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate constrain_position_distance: float @@ -2493,12 +2520,12 @@ def reconstruct( Standard deviation of gaussian kernel for magnetic object in A gaussian_filter_iter: int, optional Number of iterations to run using object smoothness constraint - probe_gaussian_filter_sigma: float, optional - Standard deviation of probe gaussian kernel in A^-1 - probe_gaussian_filter_residual_aberrations_iter: int, optional - Number of iterations to run using probe smoothing of residual aberrations - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass_e: float @@ -2537,17 +2564,23 @@ def reconstruct( # Reconstruction method - if reconstruction_method == "generalized-projection": - if np.array(reconstruction_parameter).shape != (3,): + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): raise ValueError( ( - "reconstruction_parameter must be a list of three numbers " - "when using `reconstriction_method`=generalized-projection." + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." ) ) use_projection_scheme = True - projection_a, projection_b, projection_c = reconstruction_parameter + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c step_size = None elif ( reconstruction_method == "DM_AP" @@ -2606,7 +2639,8 @@ def reconstruct( else: raise ValueError( ( - "reconstruction_method must be one of 'DM_AP' (or 'difference-map_alternating-projections'), " + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " "'RRR' (or 'relax-reflect-reflect'), " "'SUPERFLIP' (or 'charge-flipping'), " @@ -2836,19 +2870,21 @@ def reconstruct( self._probe, self._positions_px, fix_com=fix_com and a0 >= fix_probe_iter, - symmetrize_probe=a0 < symmetrize_probe_iter, - probe_gaussian_filter=a0 - < probe_gaussian_filter_residual_aberrations_iter - and probe_gaussian_filter_sigma is not None, - probe_gaussian_filter_sigma=probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude=probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude=a0 < fix_probe_amplitude_iter + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter and a0 >= fix_probe_iter, - fix_probe_amplitude_relative_radius=fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width=fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude=a0 < fix_probe_fourier_amplitude_iter + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter and a0 >= 
fix_probe_iter, - fix_probe_fourier_amplitude_threshold=fix_probe_fourier_amplitude_threshold, + constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter + and a0 >= fix_probe_iter, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, global_affine_transformation=global_affine_transformation, warmup_iteration=a0 < warmup_iter, @@ -2885,16 +2921,20 @@ def reconstruct( asnumpy(self._object[1].copy()), ) ) - self.probe_iterations.append(asnumpy(self._probe.copy())) + self.probe_iterations.append(self.probe_centered) # store result if a0 < warmup_iter: self.object = (asnumpy(self._object[0]), None) else: self.object = (asnumpy(self._object[0]), asnumpy(self._object[1])) - self.probe = asnumpy(self._probe) + self.probe = self.probe_centered self.error = error.item() + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _visualize_last_iteration_figax( @@ -3019,12 +3059,20 @@ def _visualize_last_iteration( 0, ] - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] if plot_convergence: if plot_probe or plot_fourier_probe: @@ -3111,19 +3159,21 @@ def _visualize_last_iteration( self.probe_fourier, hue_start=hue_start, invert=invert ) ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( self.probe, hue_start=hue_start, invert=invert ) ax.set_title("Reconstructed probe") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, **kwargs, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") if cbar: divider = make_axes_locatable(ax) diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py index 8a83fa5ba..0480bae8a 100644 --- a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py +++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py @@ -18,7 +18,7 @@ cp = None from emdfile import Custom, tqdmnd -from py4DSTEM.classes import DataCube +from py4DSTEM.datacube import DataCube from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction from py4DSTEM.process.phase.utils import ( ComplexProbe, @@ -50,7 +50,9 @@ class SingleslicePtychographicReconstruction(PtychographicReconstruction): datacube: DataCube Input 4D diffraction pattern intensities semiangle_cutoff: float, optional - Semiangle cutoff for the initial probe guess + Semiangle cutoff for the initial probe guess in mrad + 
semiangle_cutoff_pixels: float, optional + Semiangle cutoff for the initial probe guess in pixels rolloff: float, optional Semiangle rolloff for the initial probe guess vacuum_probe_intensity: np.ndarray, optional @@ -91,6 +93,7 @@ def __init__( energy: float, datacube: DataCube = None, semiangle_cutoff: float = None, + semiangle_cutoff_pixels: float = None, rolloff: float = 2.0, vacuum_probe_intensity: np.ndarray = None, polar_parameters: Mapping[str, float] = None, @@ -156,6 +159,7 @@ def __init__( self._scan_positions = initial_scan_positions self._energy = energy self._semiangle_cutoff = semiangle_cutoff + self._semiangle_cutoff_pixels = semiangle_cutoff_pixels self._rolloff = rolloff self._object_type = object_type self._object_padding_px = object_padding_px @@ -180,6 +184,9 @@ def preprocess( force_com_rotation: float = None, force_com_transpose: float = None, force_com_shifts: float = None, + force_scan_sampling: float = None, + force_angular_sampling: float = None, + force_reciprocal_sampling: float = None, object_fov_mask: np.ndarray = None, **kwargs, ): @@ -229,6 +236,12 @@ def preprocess( Amplitudes come from diffraction patterns shifted with the CoM in the upper left corner for each probe unless shift is overwritten. + force_scan_sampling: float, optional + Override DataCube real space scan pixel size calibrations, in Angstrom + force_angular_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in mrad + force_reciprocal_sampling: float, optional + Override DataCube reciprocal pixel size calibration, in A^-1 object_fov_mask: np.ndarray (boolean) Boolean mask of FOV. Used to calculate additional shrinkage of object If None, probe_overlap intensity is thresholded @@ -273,6 +286,9 @@ def preprocess( self._intensities = self._extract_intensities_and_calibrations_from_datacube( self._datacube, require_calibrations=True, + force_scan_sampling=force_scan_sampling, + force_angular_sampling=force_angular_sampling, + force_reciprocal_sampling=force_reciprocal_sampling, ) ( @@ -328,15 +344,22 @@ def preprocess( self._scan_positions ) + # handle semiangle specified in pixels + if self._semiangle_cutoff_pixels: + self._semiangle_cutoff = ( + self._semiangle_cutoff_pixels * self._angular_sampling[0] + ) + # Object Initialization if self._object is None: - pad_x, pad_y = self._object_padding_px - p, q = np.max(self._positions_px, axis=0) + pad_x = self._object_padding_px[0][1] + pad_y = self._object_padding_px[1][1] + p, q = np.round(np.max(self._positions_px, axis=0)) p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype( - int + "int" ) q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype( - int + "int" ) if self._object_type == "potential": self._object = xp.zeros((p, q), dtype=xp.float32) @@ -379,18 +402,16 @@ def preprocess( self._vacuum_probe_intensity, dtype=xp.float32 ) probe_x0, probe_y0 = get_CoM( - self._vacuum_probe_intensity, device=self._device + self._vacuum_probe_intensity, + device=self._device, ) - shift_x = self._region_of_interest_shape[0] // 2 - probe_x0 - shift_y = self._region_of_interest_shape[1] // 2 - probe_y0 self._vacuum_probe_intensity = get_shifted_ar( self._vacuum_probe_intensity, - shift_x, - shift_y, + -probe_x0, + -probe_y0, bilinear=True, device=self._device, ) - self._probe = ( ComplexProbe( gpts=self._region_of_interest_shape, @@ -429,6 +450,7 @@ def preprocess( self._probe = xp.asarray(self._probe, dtype=xp.complex64) self._probe_initial = self._probe.copy() + 
self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe)) self._known_aberrations_array = ComplexProbe( energy=self._energy, @@ -438,8 +460,6 @@ def preprocess( device=self._device, )._evaluate_ctf() - self._known_aberrations_array = xp.fft.ifftshift(self._known_aberrations_array) - # overlaps shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp) probe_intensities = xp.abs(shifted_probes) ** 2 @@ -462,7 +482,7 @@ def preprocess( # initial probe complex_probe_rgb = Complex2RGB( - asnumpy(self._probe), + self.probe_centered, vmin=vmin, vmax=vmax, hue_start=hue_start, @@ -522,6 +542,10 @@ def preprocess( self._preprocessed = True + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _overlap_projection(self, current_object, current_probe): @@ -980,15 +1004,17 @@ def _constraints( current_positions, pure_phase_object, fix_com, - symmetrize_probe, - probe_gaussian_filter, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude, - fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, + constrain_probe_amplitude, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, + fix_probe_aperture, + initial_probe_aperture, fix_positions, global_affine_transformation, gaussian_filter, @@ -1016,25 +1042,28 @@ def _constraints( If True, object amplitude is set to unity fix_com: bool If True, probe CoM is fixed to the center - symmetrize_probe: bool - If True, the probe is radially-averaged - probe_gaussian_filter: bool - If True, applies reciprocal-space gaussian filtering on residual aberrations - probe_gaussian_filter_sigma: float - Standard deviation of gaussian kernel in A^-1 - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed - fix_probe_amplitude: bool + fit_probe_aberrations: bool + If True, fits the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions + constrain_probe_amplitude: bool If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude: bool + If True, probe aperture is constrained by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. 
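A sketch of the edge model behind these `constrain_probe_fourier_amplitude` entries: the aperture edge is fit per angular frequency with a clipped-linear step of bounded width, as in the `step_model` helper of `regularize_probe_amplitude` later in this diff.
>>> import numpy as np
>>> def step_model(radius, sig_0, rad_0, width):
...     return sig_0 * np.clip((rad_0 - radius) / width, 0.0, 1.0)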
+ constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. + fix_probe_aperture: bool, + If True, probe Fourier amplitude is replaced by initial probe aperture. + initial_probe_aperture: np.ndarray, + Initial probe aperture to use in replacing probe Fourier amplitude. fix_positions: bool If True, positions are not updated gaussian_filter: bool @@ -1096,27 +1125,30 @@ def _constraints( if fix_com: current_probe = self._probe_center_of_mass_constraint(current_probe) - if probe_gaussian_filter: - current_probe = self._probe_residual_aberration_filtering_constraint( + if fix_probe_aperture: + current_probe = self._probe_aperture_constraint( current_probe, - probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude, + initial_probe_aperture, + ) + elif constrain_probe_fourier_amplitude: + current_probe = self._probe_fourier_amplitude_constraint( + current_probe, + constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity, ) - if symmetrize_probe: - current_probe = self._probe_radial_symmetrization_constraint(current_probe) - - if fix_probe_amplitude: - current_probe = self._probe_amplitude_constraint( + if fit_probe_aberrations: + current_probe = self._probe_aberration_fitting_constraint( current_probe, - fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width, + fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order, ) - elif fix_probe_fourier_amplitude: - current_probe = self._probe_fourier_amplitude_constraint( + + if constrain_probe_amplitude: + current_probe = self._probe_amplitude_constraint( current_probe, - fix_probe_fourier_amplitude_threshold, - fix_probe_amplitude_relative_width, + constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width, ) if not fix_positions: @@ -1136,28 +1168,32 @@ def reconstruct( max_iter: int = 64, reconstruction_method: str = "gradient-descent", reconstruction_parameter: float = 1.0, + reconstruction_parameter_a: float = None, + reconstruction_parameter_b: float = None, + reconstruction_parameter_c: float = None, max_batch_size: int = None, seed_random: int = None, - step_size: float = 0.9, + step_size: float = 0.5, normalization_min: float = 1, positions_step_size: float = 0.9, pure_phase_object_iter: int = 0, fix_com: bool = True, fix_probe_iter: int = 0, - symmetrize_probe_iter: int = 0, - fix_probe_amplitude_iter: int = 0, - fix_probe_amplitude_relative_radius: float = 0.5, - fix_probe_amplitude_relative_width: float = 0.05, - fix_probe_fourier_amplitude_iter: int = 0, - fix_probe_fourier_amplitude_threshold: float = 0.9, + fix_probe_aperture_iter: int = 0, + constrain_probe_amplitude_iter: int = 0, + constrain_probe_amplitude_relative_radius: float = 0.5, + constrain_probe_amplitude_relative_width: float = 0.05, + constrain_probe_fourier_amplitude_iter: int = 0, + constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0, + constrain_probe_fourier_amplitude_constant_intensity: bool = False, fix_positions_iter: int = np.inf, constrain_position_distance: float = None, global_affine_transformation: bool = True, gaussian_filter_sigma: float = None, gaussian_filter_iter: int = np.inf, - probe_gaussian_filter_sigma: float = None, - probe_gaussian_filter_residual_aberrations_iter: int = np.inf, - probe_gaussian_filter_fix_amplitude: bool = True, + fit_probe_aberrations_iter: int = 0, + fit_probe_aberrations_max_angular_order: int = 4, + 
fit_probe_aberrations_max_radial_order: int = 4, butterworth_filter_iter: int = np.inf, q_lowpass: float = None, q_highpass: float = None, @@ -1179,7 +1215,7 @@ def reconstruct( Maximum number of iterations to run reconstruction_method: str, optional Specifies which reconstruction algorithm to use, one of: - "generalized-projection", + "generalized-projections", "DM_AP" (or "difference-map_alternating-projections"), "RAAR" (or "relaxed-averaged-alternating-reflections"), "RRR" (or "relax-reflect-reflect"), @@ -1187,6 +1223,12 @@ def reconstruct( "GD" (or "gradient_descent") reconstruction_parameter: float, optional Reconstruction parameter for various reconstruction methods above. + reconstruction_parameter_a: float, optional + Reconstruction parameter a for reconstruction_method='generalized-projections'. + reconstruction_parameter_b: float, optional + Reconstruction parameter b for reconstruction_method='generalized-projections'. + reconstruction_parameter_c: float, optional + Reconstruction parameter c for reconstruction_method='generalized-projections'. max_batch_size: int, optional Max number of probes to update at once seed_random: int, optional @@ -1203,19 +1245,20 @@ def reconstruct( If True, fixes center of mass of probe fix_probe_iter: int, optional Number of iterations to run with a fixed probe before updating probe estimate - symmetrize_probe_iter: int, optional - Number of iterations to run before radially-averaging the probe - fix_probe_amplitude: bool - If True, probe amplitude is constrained by top hat function - fix_probe_amplitude_relative_radius: float + fix_probe_aperture_iter: int, optional + Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate + constrain_probe_amplitude_iter: int, optional + Number of iterations to run while constraining the real-space probe with a top-hat support. + constrain_probe_amplitude_relative_radius: float Relative location of top-hat inflection point, between 0 and 0.5 - fix_probe_amplitude_relative_width: float + constrain_probe_amplitude_relative_width: float Relative width of top-hat sigmoid, between 0 and 0.5 - fix_probe_fourier_amplitude: bool - If True, probe fourier amplitude is constrained by top hat function - fix_probe_fourier_amplitude_threshold: float - Threshold value for current probe fourier mask. Value should - be between 0 and 1, where higher values provide the most masking. + constrain_probe_fourier_amplitude_iter: int, optional + Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency. + constrain_probe_fourier_amplitude_max_width_pixels: float + Maximum pixel width of fitted sigmoid functions. + constrain_probe_fourier_amplitude_constant_intensity: bool + If True, the probe aperture is additionally constrained to a constant intensity. 
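A sketch of the `*_iter` scheduling convention these arguments share (grounded in the `_constraints` calls below): a constraint is applied on 0-indexed iteration a0 while a0 < its `_iter` value, and probe-side constraints additionally wait until the probe is being updated (a0 >= fix_probe_iter).
>>> def constraint_active(a0, constraint_iter, fix_probe_iter=0):
...     return a0 < constraint_iter and a0 >= fix_probe_iter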
fix_positions_iter: int, optional Number of iterations to run with fixed positions before updating positions estimate constrain_position_distance: float, optional @@ -1227,12 +1270,12 @@ def reconstruct( Standard deviation of gaussian kernel in A gaussian_filter_iter: int, optional Number of iterations to run using object smoothness constraint - probe_gaussian_filter_sigma: float, optional - Standard deviation of probe gaussian kernel in A^-1 - probe_gaussian_filter_fix_amplitude: bool - If True, only the probe phase is smoothed - probe_gaussian_filter_residual_aberrations_iter: int, optional - Number of iterations to run using probe smoothing of residual aberrations + fit_probe_aberrations_iter: int, optional + Number of iterations to run while fitting the probe aberrations to a low-order expansion + fit_probe_aberrations_max_angular_order: bool + Max angular order of probe aberrations basis functions + fit_probe_aberrations_max_radial_order: bool + Max radial order of probe aberrations basis functions butterworth_filter_iter: int, optional Number of iterations to run using high-pass butteworth filter q_lowpass: float @@ -1267,17 +1310,23 @@ def reconstruct( # Reconstruction method - if reconstruction_method == "generalized-projection": - if np.array(reconstruction_parameter).shape != (3,): + if reconstruction_method == "generalized-projections": + if ( + reconstruction_parameter_a is None + or reconstruction_parameter_b is None + or reconstruction_parameter_c is None + ): raise ValueError( ( - "reconstruction_parameter must be a list of three numbers " - "when using `reconstriction_method`=generalized-projection." + "reconstruction_parameter_a/b/c must all be specified " + "when using reconstruction_method='generalized-projections'." ) ) use_projection_scheme = True - projection_a, projection_b, projection_c = reconstruction_parameter + projection_a = reconstruction_parameter_a + projection_b = reconstruction_parameter_b + projection_c = reconstruction_parameter_c step_size = None elif ( reconstruction_method == "DM_AP" @@ -1336,7 +1385,8 @@ def reconstruct( else: raise ValueError( ( - "reconstruction_method must be one of 'DM_AP' (or 'difference-map_alternating-projections'), " + "reconstruction_method must be one of 'generalized-projections', " + "'DM_AP' (or 'difference-map_alternating-projections'), " "'RAAR' (or 'relaxed-averaged-alternating-reflections'), " "'RRR' (or 'relax-reflect-reflect'), " "'SUPERFLIP' (or 'charge-flipping'), " @@ -1543,19 +1593,21 @@ def reconstruct( self._probe, self._positions_px, fix_com=fix_com and a0 >= fix_probe_iter, - symmetrize_probe=a0 < symmetrize_probe_iter, - probe_gaussian_filter=a0 - < probe_gaussian_filter_residual_aberrations_iter - and probe_gaussian_filter_sigma is not None, - probe_gaussian_filter_sigma=probe_gaussian_filter_sigma, - probe_gaussian_filter_fix_amplitude=probe_gaussian_filter_fix_amplitude, - fix_probe_amplitude=a0 < fix_probe_amplitude_iter + constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter + and a0 >= fix_probe_iter, + constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius, + constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width, + constrain_probe_fourier_amplitude=a0 + < constrain_probe_fourier_amplitude_iter and a0 >= fix_probe_iter, - fix_probe_amplitude_relative_radius=fix_probe_amplitude_relative_radius, - fix_probe_amplitude_relative_width=fix_probe_amplitude_relative_width, - fix_probe_fourier_amplitude=a0 < fix_probe_fourier_amplitude_iter + 
constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels, + constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity, + fit_probe_aberrations=a0 < fit_probe_aberrations_iter and a0 >= fix_probe_iter, - fix_probe_fourier_amplitude_threshold=fix_probe_fourier_amplitude_threshold, + fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order, + fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order, + fix_probe_aperture=a0 < fix_probe_aperture_iter, + initial_probe_aperture=self._probe_initial_aperture, fix_positions=a0 < fix_positions_iter, global_affine_transformation=global_affine_transformation, gaussian_filter=a0 < gaussian_filter_iter @@ -1578,13 +1630,17 @@ def reconstruct( self.error_iterations.append(error.item()) if store_iterations: self.object_iterations.append(asnumpy(self._object.copy())) - self.probe_iterations.append(asnumpy(self._probe.copy())) + self.probe_iterations.append(self.probe_centered) # store result self.object = asnumpy(self._object) - self.probe = asnumpy(self._probe) + self.probe = self.probe_centered self.error = error.item() + if self._device == "gpu": + xp._default_memory_pool.free_all_blocks() + xp.clear_memo() + return self def _visualize_last_iteration_figax( @@ -1696,12 +1752,20 @@ def _visualize_last_iteration( 0, ] - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] if plot_convergence: if plot_probe or plot_fourier_probe: @@ -1767,19 +1831,21 @@ def _visualize_last_iteration( self.probe_fourier, hue_start=hue_start, invert=invert ) ax.set_title("Reconstructed Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") else: probe_array = Complex2RGB( self.probe, hue_start=hue_start, invert=invert ) ax.set_title("Reconstructed probe") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") im = ax.imshow( probe_array, extent=probe_extent, **kwargs, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") if cbar: divider = make_axes_locatable(ax) @@ -1923,12 +1989,20 @@ def _visualize_all_iterations( 0, ] - probe_extent = [ - 0, - self.sampling[1] * self._region_of_interest_shape[1], - self.sampling[0] * self._region_of_interest_shape[0], - 0, - ] + if plot_fourier_probe: + probe_extent = [ + -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[1] * self._region_of_interest_shape[1] / 2, + self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2, + ] + elif plot_probe: + probe_extent = [ + 0, + self.sampling[1] * self._region_of_interest_shape[1], + self.sampling[0] * self._region_of_interest_shape[0], + 0, + ] if plot_convergence: if plot_probe or plot_fourier_probe: @@ -1984,16 +2058,25 @@ def _visualize_all_iterations( for n, ax in enumerate(grid): if plot_fourier_probe: probe_array = Complex2RGB( - 
asnumpy(self._return_fourier_probe(probes[grid_range[n]])), + asnumpy( + self._return_fourier_probe_from_centered_probe( + probes[grid_range[n]] + ) + ), hue_start=hue_start, invert=invert, ) ax.set_title(f"Iter: {grid_range[n]} Fourier probe") + ax.set_ylabel("kx [mrad]") + ax.set_xlabel("ky [mrad]") + else: probe_array = Complex2RGB( probes[grid_range[n]], hue_start=hue_start, invert=invert ) ax.set_title(f"Iter: {grid_range[n]} probe") + ax.set_ylabel("x [A]") + ax.set_xlabel("y [A]") im = ax.imshow( probe_array, @@ -2001,9 +2084,6 @@ def _visualize_all_iterations( **kwargs, ) - ax.set_ylabel("x [A]") - ax.set_xlabel("y [A]") - if cbar: add_colorbar_arg( grid.cbar_axes[n], hue_start=hue_start, invert=invert diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py new file mode 100644 index 000000000..91a71cb30 --- /dev/null +++ b/py4DSTEM/process/phase/parameter_optimize.py @@ -0,0 +1,598 @@ +from functools import partial +from typing import Callable, Union + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.gridspec import GridSpec +from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction +from py4DSTEM.process.phase.utils import AffineTransform +from skopt import gp_minimize +from skopt.plots import plot_convergence as skopt_plot_convergence +from skopt.plots import plot_evaluations as skopt_plot_evaluations +from skopt.plots import plot_gaussian_process as skopt_plot_gaussian_process +from skopt.plots import plot_objective as skopt_plot_objective +from skopt.space import Categorical, Integer, Real +from skopt.utils import use_named_args +from tqdm import tqdm + + +class PtychographyOptimizer: + """ + Optimize ptychographic hyperparameters with Bayesian Optimization of a + Gaussian process. Any of the scalar-valued real or integer, boolean, or categorical + arguments to the ptychographic init-preprocess-reconstruct pipeline can be optimized over. + """ + + def __init__( + self, + reconstruction_type: type[PhaseReconstruction], + init_args: dict, + preprocess_args: dict = {}, + reconstruction_args: dict = {}, + affine_args: dict = {}, + ): + """ + Parameter optimization for ptychographic reconstruction based on Bayesian Optimization + with Gaussian Process. + + Usage + ----- + Dictionaries of the arguments to __init__, AffineTransform (for distorting the initial + scan positions), preprocess, and reconstruct are required. For parameters not optimized + over, the value in the dictionary is used. To optimize a parameter, instead pass an + OptimizationParameter object inside the dictionary to specify the initial guess, bounds, + and type of parameter, for example: + >>> 'param':OptimizationParameter(initial guess, lower bound, upper bound) + Calling optimize will then run the optimization simultaneously over all + optimization parameters. To obtain the optimized parameters, call get_optimized_arguments + to return a set of dictionaries where the OptimizationParameter objects have been replaced + with the optimal values. These can then be modified for running a full reconstruction. 
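A fuller sketch in the same spirit (all names and numeric values are illustrative; `dataset` is assumed to be a calibrated DataCube):
>>> opt = PtychographyOptimizer(
...     reconstruction_type=SingleslicePtychographicReconstruction,
...     init_args=dict(
...         datacube=dataset,
...         energy=80e3,
...         semiangle_cutoff=OptimizationParameter(20.0, 15.0, 25.0),
...     ),
...     reconstruction_args=dict(
...         max_iter=32,
...         step_size=OptimizationParameter(0.5, 0.1, 0.9),
...     ),
... )
>>> opt.optimize(n_calls=40, n_initial_points=15)
>>> init_opt, prep_opt, reco_opt = opt.get_optimized_arguments()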
+
+        Parameters
+        ----------
+        reconstruction_type: class
+            Type of ptychographic reconstruction to perform
+        init_args: dict
+            Keyword arguments passed to the __init__ method of the reconstruction class
+        preprocess_args: dict
+            Keyword arguments passed to the preprocess method of the reconstruction object
+        reconstruction_args: dict
+            Keyword arguments passed to the reconstruct method of the reconstruction object
+        affine_args: dict
+            Keyword arguments passed to AffineTransform. The transform is applied to the initial
+            scan positions.
+        """
+
+        # loop over each argument dictionary and split into static and optimization variables
+        (
+            self._init_static_args,
+            self._init_optimize_args,
+        ) = self._split_static_and_optimization_vars(init_args)
+        (
+            self._affine_static_args,
+            self._affine_optimize_args,
+        ) = self._split_static_and_optimization_vars(affine_args)
+        (
+            self._preprocess_static_args,
+            self._preprocess_optimize_args,
+        ) = self._split_static_and_optimization_vars(preprocess_args)
+        (
+            self._reconstruction_static_args,
+            self._reconstruction_optimize_args,
+        ) = self._split_static_and_optimization_vars(reconstruction_args)
+
+        # Save list of skopt parameter objects and initial guess
+        self._parameter_list = []
+        self._x0 = []
+        for k, v in (
+            self._init_optimize_args
+            | self._affine_optimize_args
+            | self._preprocess_optimize_args
+            | self._reconstruction_optimize_args
+        ).items():
+            self._parameter_list.append(v._get(k))
+            self._x0.append(v._initial_value)
+
+        self._init_args = init_args
+        self._affine_args = affine_args
+        self._preprocess_args = preprocess_args
+        self._reconstruction_args = reconstruction_args
+
+        self._reconstruction_type = reconstruction_type
+
+        self._set_optimizer_defaults()
+
+    def optimize(
+        self,
+        n_calls: int = 50,
+        n_initial_points: int = 20,
+        error_metric: Union[Callable, str] = "log",
+        **skopt_kwargs: dict,
+    ):
+        """
+        Run the optimizer
+
+        Parameters
+        ----------
+        n_calls: int
+            Number of times to run ptychographic reconstruction
+        n_initial_points: int
+            Number of uniformly spaced trial points to test before
+            beginning Bayesian optimization (must be less than n_calls)
+        error_metric: Callable or str
+            Function used to compute the reconstruction error.
+ When passed as a string, may be one of: + 'log': log(NMSE) of final object + 'linear': NMSE of final object + 'log-converged': log(NMSE) of final object if + NMSE is decreasing, 0 if NMSE increasing + 'linear-converged': NMSE of final object if + NMSE is decreasing, 1 if NMSE increasing + 'TV': sum( abs( grad( object ) ) ) / sum( abs( object ) ) + 'std': negative standard deviation of cropped object + 'std-phase': negative standard deviation of + phase of the cropped object + 'entropy-phase': entropy of the phase of the + cropped object + When passed as a Callable, a function that takes the + PhaseReconstruction object as its only argument + and returns the error metric as a single float + skopt_kwargs: dict + Additional arguments to be passed to skopt.gp_minimize + + """ + + error_metric = self._get_error_metric(error_metric) + + self._optimization_function = self._get_optimization_function( + self._reconstruction_type, + self._parameter_list, + self._init_static_args, + self._affine_static_args, + self._preprocess_static_args, + self._reconstruction_static_args, + self._init_optimize_args, + self._affine_optimize_args, + self._preprocess_optimize_args, + self._reconstruction_optimize_args, + error_metric, + ) + + # Make a progress bar + pbar = tqdm(total=n_calls, desc="Optimizing parameters") + + # We need to wrap the callback because if it returns a value + # the optimizer breaks its loop + def callback(*args, **kwargs): + pbar.update(1) + + self._skopt_result = gp_minimize( + self._optimization_function, + self._parameter_list, + n_calls=n_calls, + n_initial_points=n_initial_points, + x0=self._x0, + callback=callback, + **skopt_kwargs, + ) + + print("Optimized parameters:") + for p, x in zip(self._parameter_list, self._skopt_result.x): + print(f"{p.name}: {x}") + + # Finish the tqdm progressbar so subsequent things behave nicely + pbar.close() + + return self + + def visualize( + self, + plot_gp_model=True, + plot_convergence=False, + plot_objective=True, + plot_evaluations=False, + **kwargs, + ): + """ + Visualize optimization results + + Parameters + ---------- + plot_gp_model: bool + Display fitted Gaussian process model (only available for 1-dimensional problem) + plot_convergence: bool + Display convergence history + plot_objective: bool + Display GP objective function and partial dependence plots + plot_evaluations: bool + Display histograms of sampled points + kwargs: + Passed directly to the skopt plot_gassian_process/plot_objective + """ + ndims = len(self._parameter_list) + if ndims == 1: + if plot_convergence: + figsize = kwargs.pop("figsize", (9, 9)) + spec = GridSpec(nrows=2, ncols=1, height_ratios=[2, 1], hspace=0.15) + else: + figsize = kwargs.pop("figsize", (9, 6)) + spec = GridSpec(nrows=1, ncols=1) + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(spec[0]) + skopt_plot_gaussian_process(self._skopt_result, ax=ax, **kwargs) + + if plot_convergence: + ax = fig.add_subplot(spec[1]) + skopt_plot_convergence(self._skopt_result, ax=ax) + + else: + if plot_convergence: + figsize = kwargs.pop("figsize", (4 * ndims, 4 * (ndims + 0.5))) + spec = GridSpec( + nrows=ndims + 1, + ncols=ndims, + height_ratios=[2] * ndims + [1], + hspace=0.15, + ) + else: + figsize = kwargs.pop("figsize", (4 * ndims, 4 * ndims)) + spec = GridSpec(nrows=ndims, ncols=ndims, hspace=0.15) + + if plot_evaluations: + axs = skopt_plot_evaluations(self._skopt_result) + elif plot_objective: + cmap = kwargs.pop("cmap", "magma") + axs = skopt_plot_objective(self._skopt_result, cmap=cmap, **kwargs) + 
elif plot_convergence: + skopt_plot_convergence(self._skopt_result) + return self + + fig = axs[0, 0].figure + fig.set_size_inches(figsize) + for i in range(ndims): + for j in range(ndims): + ax = axs[i, j] + ax.remove() + ax.figure = fig + fig.add_axes(ax) + ax.set_subplotspec(spec[i, j]) + + if plot_convergence: + ax = fig.add_subplot(spec[ndims, :]) + skopt_plot_convergence(self._skopt_result, ax=ax) + + spec.tight_layout(fig) + + return self + + def get_optimized_arguments(self): + """ + Get argument dictionaries containing optimized hyperparameters + + Returns + ------- + init_opt, prep_opt, reco_opt: dicts + Dictionaries of arguments to __init__, preprocess, and reconstruct + where the OptimizationParameter items have been replaced with the optimal + values obtained from the optimizer + """ + optimized_dict = { + p.name: v for p, v in zip(self._parameter_list, self._skopt_result.x) + } + + filtered_dict = { + k: v for k, v in optimized_dict.items() if k in self._init_args + } + init_opt = self._init_args | filtered_dict + + filtered_dict = { + k: v for k, v in optimized_dict.items() if k in self._affine_args + } + affine_opt = self._affine_args | filtered_dict + + affine_transform = partial(AffineTransform, **self._affine_static_args)( + **affine_opt + ) + scan_positions = self._get_scan_positions( + affine_transform, init_opt["datacube"] + ) + init_opt["initial_scan_positions"] = scan_positions + + filtered_dict = { + k: v for k, v in optimized_dict.items() if k in self._preprocess_args + } + prep_opt = self._preprocess_args | filtered_dict + + filtered_dict = { + k: v for k, v in optimized_dict.items() if k in self._reconstruction_args + } + reco_opt = self._reconstruction_args | filtered_dict + + return init_opt, prep_opt, reco_opt + + def _split_static_and_optimization_vars(self, argdict): + static_args = {} + optimization_args = {} + for k, v in argdict.items(): + if isinstance(v, OptimizationParameter): + optimization_args[k] = v + else: + static_args[k] = v + return static_args, optimization_args + + def _get_scan_positions(self, affine_transform, dataset): + R_pixel_size = dataset.calibration.get_R_pixel_size() + x, y = ( + np.arange(dataset.R_Nx) * R_pixel_size, + np.arange(dataset.R_Ny) * R_pixel_size, + ) + x, y = np.meshgrid(x, y, indexing="ij") + scan_positions = np.stack((x.ravel(), y.ravel()), axis=1) + scan_positions = scan_positions @ affine_transform.asarray() + return scan_positions + + def _get_error_metric(self, error_metric: Union[Callable, str]) -> Callable: + """ + Get error metric as a function, converting builtin method names + to functions + """ + + if callable(error_metric): + return error_metric + + assert error_metric in ( + "log", + "linear", + "log-converged", + "linear-converged", + "TV", + "std", + "std-phase", + "entropy-phase", + ), f"Error metric {error_metric} not recognized." 
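# Any callable taking the reconstruction object and returning a scalar can
# be supplied instead of the named metrics checked above, for example
# (illustrative; it mirrors the built-in 'linear' branch below):
#
#     def last_error_metric(ptycho):
#         return ptycho.error
#
#     opt.optimize(error_metric=last_error_metric)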
+
+        if error_metric == "log":
+
+            def f(ptycho):
+                return np.log(ptycho.error)
+
+        elif error_metric == "linear":
+
+            def f(ptycho):
+                return ptycho.error
+
+        elif error_metric == "log-converged":
+
+            def f(ptycho):
+                converged = ptycho.error_iterations[-1] <= np.min(
+                    ptycho.error_iterations
+                )
+                return np.log(ptycho.error) if converged else 0.0
+
+        elif error_metric == "linear-converged":
+
+            def f(ptycho):
+                converged = ptycho.error_iterations[-1] <= np.min(
+                    ptycho.error_iterations
+                )
+                return ptycho.error if converged else 1.0
+
+        elif error_metric == "TV":
+
+            def f(ptycho):
+                gx, gy = np.gradient(ptycho.object_cropped, axis=(-2, -1))
+                obj_mag = np.sum(np.abs(ptycho.object_cropped))
+                tv = np.sum(np.abs(gx)) + np.sum(np.abs(gy))
+                return tv / obj_mag
+
+        elif error_metric == "std":
+
+            def f(ptycho):
+                return -np.std(ptycho.object_cropped)
+
+        elif error_metric == "std-phase":
+
+            def f(ptycho):
+                return -np.std(np.angle(ptycho.object_cropped))
+
+        elif error_metric == "entropy-phase":
+
+            def f(ptycho):
+                obj = np.angle(ptycho.object_cropped)
+                gx, gy = np.gradient(obj)
+                ghist, _, _ = np.histogram2d(
+                    gx.ravel(), gy.ravel(), bins=obj.shape, density=True
+                )
+                nz = ghist > 0
+                S = np.sum(ghist[nz] * np.log2(ghist[nz]))
+                return S
+
+        else:
+            raise ValueError(f"Error metric {error_metric} not recognized.")
+
+        return f
+
+    def _get_optimization_function(
+        self,
+        cls: type[PhaseReconstruction],
+        parameter_list: list,
+        init_static_args: dict,
+        affine_static_args: dict,
+        preprocess_static_args: dict,
+        reconstruct_static_args: dict,
+        init_optimization_params: dict,
+        affine_optimization_params: dict,
+        preprocess_optimization_params: dict,
+        reconstruct_optimization_params: dict,
+        error_metric: Callable,
+    ):
+        """
+        Wrap the ptychography pipeline into a single function that encapsulates all of the
+        non-optimization arguments and accepts a concatenated set of keyword arguments. The
+        wrapper function returns the final error value from the ptychography run.
+
+        parameter_list is a list of skopt Dimension objects
+
+        Both static and optimization args are passed in dictionaries. The values of the
+        static dictionary are the fixed parameters, and only the keys of the optimization
+        dictionary are used.
+        """
+
+        # Get lists of optimization parameters for each step
+        init_params = list(init_optimization_params.keys())
+        afft_params = list(affine_optimization_params.keys())
+        prep_params = list(preprocess_optimization_params.keys())
+        reco_params = list(reconstruct_optimization_params.keys())
+
+        # Construct partial methods to encapsulate the static parameters.
+        # If only ``reconstruct`` has optimization variables, perform
+        # preprocessing now, store the ptycho object, and use dummy
+        # functions instead of the partials
+        if (len(init_params), len(afft_params), len(prep_params)) == (0, 0, 0):
+            affine_preprocessed = AffineTransform(**affine_static_args)
+            init_args = init_static_args.copy()
+            init_args["initial_scan_positions"] = self._get_scan_positions(
+                affine_preprocessed, init_static_args["datacube"]
+            )
+
+            ptycho_preprocessed = cls(**init_args).preprocess(**preprocess_static_args)
+
+            def obj(**kwargs):
+                return ptycho_preprocessed
+
+            def prep(ptycho, **kwargs):
+                return ptycho
+
+        else:
+            obj = partial(cls, **init_static_args)
+            prep = partial(cls.preprocess, **preprocess_static_args)
+
+        affine = partial(AffineTransform, **affine_static_args)
+        recon = partial(cls.reconstruct, **reconstruct_static_args)
+
+        # Target function for Gaussian process optimization that takes a single
+        # dict of named parameters and returns the ptycho error metric
+        @use_named_args(parameter_list)
+        def f(**kwargs):
+            init_args = {k: kwargs[k] for k in init_params}
+            afft_args = {k: kwargs[k] for k in afft_params}
+            prep_args = {k: kwargs[k] for k in prep_params}
+            reco_args = {k: kwargs[k] for k in reco_params}
+
+            # Create affine transform object
+            tr = affine(**afft_args)
+            # Apply affine transform to pixel grid, using the
+            # calibrations lifted from the dataset
+            dataset = init_static_args["datacube"]
+            init_args["initial_scan_positions"] = self._get_scan_positions(tr, dataset)
+
+            ptycho = obj(**init_args)
+            prep(ptycho, **prep_args)
+            recon(ptycho, **reco_args)
+
+            return error_metric(ptycho)
+
+        return f
+
+    def _set_optimizer_defaults(
+        self,
+        verbose=False,
+        plot_center_of_mass=False,
+        plot_rotation=False,
+        plot_probe_overlaps=False,
+        progress_bar=False,
+        store_iterations=False,
+        reset=True,
+    ):
+        """
+        Set all of the verbose and plotting to False, allowing for user-overwrite.
+        """
+        self._init_static_args["verbose"] = verbose
+
+        self._preprocess_static_args["plot_center_of_mass"] = plot_center_of_mass
+        self._preprocess_static_args["plot_rotation"] = plot_rotation
+        self._preprocess_static_args["plot_probe_overlaps"] = plot_probe_overlaps
+
+        self._reconstruction_static_args["progress_bar"] = progress_bar
+        self._reconstruction_static_args["store_iterations"] = store_iterations
+        self._reconstruction_static_args["reset"] = reset
+
+
+class OptimizationParameter:
+    """
+    Wrapper for scikit-optimize Space objects used for convenient calling in the PtychographyOptimizer
+    """
+
+    def __init__(
+        self,
+        initial_value: Union[float, int, bool],
+        lower_bound: Union[float, int, bool] = None,
+        upper_bound: Union[float, int, bool] = None,
+        scaling: str = "uniform",
+        space: str = "real",
+        categories: list = [],
+    ):
+        """
+        Wrapper for scikit-optimize Space objects used as inputs to PtychographyOptimizer
+
+        Parameters
+        ----------
+        initial_value:
+            Initial value, used for first evaluation in optimizer
+        lower_bound, upper_bound:
+            Bounds on real or integer variables (not needed for bool or categorical)
+        scaling: str
+            Prior knowledge on sensitivity of the variable. Can be 'uniform' or 'log-uniform'
+        space: str
+            Type of variable.
Can be 'real', 'integer', 'bool', or 'categorical' + categories: list + List of options for Categorical parameter + """ + # Check input + space = space.lower() + if space not in ("real", "integer", "bool", "categorical"): + raise ValueError(f"Unknown Parameter type: {space}") + + scaling = scaling.lower() + if scaling not in ("uniform", "log-uniform"): + raise ValueError(f"Unknown scaling: {scaling}") + + # Get the right scikit-optimize class + space_map = { + "real": Real, + "integer": Integer, + "bool": Categorical, + "categorical": Categorical, + } + param = space_map[space] + + # If a boolean property, the categories are True/False + if space == "bool": + categories = [True, False] + + if categories == [] and space in ("categorical", "bool"): + raise ValueError("Empty list of categories!") + + # store necessary information + self._initial_value = initial_value + self._categories = categories + self._lower_bound = lower_bound + self._upper_bound = upper_bound + self._scaling = scaling + self._param_type = param + + def _get(self, name): + self._name = name + if self._param_type is Categorical: + self._skopt_param = self._param_type( + name=self._name, categories=self._categories + ) + else: + self._skopt_param = self._param_type( + name=self._name, + low=self._lower_bound, + high=self._upper_bound, + prior=self._scaling, + ) + return self._skopt_param diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py index bb11db7d4..c2e1d3b77 100644 --- a/py4DSTEM/process/phase/utils.py +++ b/py4DSTEM/process/phase/utils.py @@ -1,7 +1,9 @@ +import functools from typing import Mapping, Tuple, Union import matplotlib.pyplot as plt import numpy as np +from scipy.optimize import curve_fit try: import cupy as cp @@ -10,50 +12,32 @@ cp = None from scipy.fft import dstn, idstn +from py4DSTEM.process.utils import get_CoM from py4DSTEM.process.utils.cross_correlate import align_and_shift_images from py4DSTEM.process.utils.utils import electron_wavelength_angstrom -from scipy.ndimage import gaussian_filter +from scipy.ndimage import gaussian_filter, uniform_filter1d +from skimage.restoration import unwrap_phase + +# fmt: off #: Symbols for the polar representation of all optical aberrations up to the fifth order. polar_symbols = ( - "C10", - "C12", - "phi12", - "C21", - "phi21", - "C23", - "phi23", - "C30", - "C32", - "phi32", - "C34", - "phi34", - "C41", - "phi41", - "C43", - "phi43", - "C45", - "phi45", - "C50", - "C52", - "phi52", - "C54", - "phi54", - "C56", - "phi56", + "C10", "C12", "phi12", + "C21", "phi21", "C23", "phi23", + "C30", "C32", "phi32", "C34", "phi34", + "C41", "phi41", "C43", "phi43", "C45", "phi45", + "C50", "C52", "phi52", "C54", "phi54", "C56", "phi56", ) #: Aliases for the most commonly used optical aberrations. polar_aliases = { - "defocus": "C10", - "astigmatism": "C12", - "astigmatism_angle": "phi12", - "coma": "C21", - "coma_angle": "phi21", - "Cs": "C30", - "C5": "C50", + "defocus": "C10", "astigmatism": "C12", "astigmatism_angle": "phi12", + "coma": "C21", "coma_angle": "phi21", + "Cs": "C30", + "C5": "C50", } +# fmt: on ### Probe functions @@ -80,6 +64,8 @@ class ComplexProbe: Device to perform calculations on. Must be either 'cpu' or 'gpu' rolloff: float, optional Tapers the cutoff edge over the given angular range [mrad]. 
+ vacuum_probe_intensity: np.ndarray, optional + Squared of corner-centered aperture amplitude to use, instead of semiangle_cutoff + rolloff focal_spread: float, optional The 1/e width of the focal spread due to chromatic aberration and lens current instability [Ã…]. angular_spread: float, optional @@ -179,7 +165,7 @@ def evaluate_aperture( self._vacuum_probe_intensity, dtype=xp.float32 ) vacuum_probe_amplitude = xp.sqrt(xp.maximum(vacuum_probe_intensity, 0)) - return xp.fft.ifftshift(vacuum_probe_amplitude) + return vacuum_probe_amplitude if self._semiangle_cutoff == xp.inf: return xp.ones_like(alpha) @@ -431,9 +417,9 @@ def polar_coordinates(self, x, y): return alpha, phi def build(self): - """Builds complex probe in the center of the region of interest.""" + """Builds corner-centered complex probe in the center of the region of interest.""" xp = self._xp - array = xp.fft.fftshift(xp.fft.ifft2(self._evaluate_ctf())) + array = xp.fft.ifft2(self._evaluate_ctf()) array = array / xp.sqrt((xp.abs(array) ** 2).sum()) self._array = array return self @@ -447,7 +433,7 @@ def visualize(self, **kwargs): kwargs.pop("cmap", None) plt.imshow( - asnumpy(xp.abs(self._array) ** 2), + asnumpy(xp.abs(xp.fft.ifftshift(self._array)) ** 2), cmap=cmap, **kwargs, ) @@ -475,20 +461,6 @@ def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float]): ) -def projection(u: np.ndarray, v: np.ndarray, xp): - """Projection of vector u onto vector v.""" - return u * xp.vdot(u, v) / xp.vdot(u, u) - - -def orthogonalize(V: np.ndarray, xp): - """Non-normalized QR decomposition using repeated projections.""" - U = V.copy() - for i in range(1, V.shape[0]): - for j in range(i): - U[i, :] -= projection(U[j, :], V[i, :], xp) - return U - - ### FFT-shift functions @@ -644,6 +616,8 @@ class AffineTransform: x-translation t1: float y-translation + dilation: float + Isotropic expansion (multiplies scale0 and scale1) """ def __init__( @@ -654,9 +628,10 @@ def __init__( angle: float = 0.0, t0: float = 0.0, t1: float = 0.0, + dilation: float = 1.0, ): - self.scale0 = scale0 - self.scale1 = scale1 + self.scale0 = scale0 * dilation + self.scale1 = scale1 * dilation self.shear1 = shear1 self.angle = angle self.t0 = t0 @@ -1284,3 +1259,354 @@ def project_vector_field_divergence(vector_field, spacings=(1, 1, 1), xp=np): p = preconditioned_poisson_solver(div_v, spacings[0], xp=xp) grad_p = compute_gradient(p, spacings, xp=xp) return vector_field - grad_p + + +# Nesterov acceleration functions +# https://blogs.princeton.edu/imabandit/2013/04/01/acceleratedgradientdescent/ + + +@functools.cache +def nesterov_lambda(one_indexed_iter_num): + if one_indexed_iter_num == 0: + return 0 + return (1 + np.sqrt(1 + 4 * nesterov_lambda(one_indexed_iter_num - 1) ** 2)) / 2 + + +def nesterov_gamma(zero_indexed_iter_num): + one_indexed_iter_num = zero_indexed_iter_num + 1 + return (1 - nesterov_lambda(one_indexed_iter_num)) / nesterov_lambda( + one_indexed_iter_num + 1 + ) + + +def cartesian_to_polar_transform_2Ddata( + im_cart, + xy_center, + num_theta_bins=90, + radius_max=None, + corner_centered=False, + xp=np, +): + """ + Quick cartesian to polar conversion. 
+ """ + + # coordinates + if radius_max is None: + if corner_centered: + radius_max = np.min(np.array(im_cart.shape) // 2) + else: + radius_max = np.sqrt(np.sum(np.array(im_cart.shape) ** 2)) // 2 + + r = xp.arange(radius_max) + t = xp.linspace( + 0, + 2.0 * np.pi, + num_theta_bins, + endpoint=False, + ) + ra, ta = xp.meshgrid(r, t) + + # resampling coordinates + x = ra * xp.cos(ta) + xy_center[0] + y = ra * xp.sin(ta) + xy_center[1] + + xf = xp.floor(x).astype("int") + yf = xp.floor(y).astype("int") + dx = x - xf + dy = y - yf + + mode = "wrap" if corner_centered else "clip" + + # resample image + im_polar = ( + im_cart.ravel()[ + xp.ravel_multi_index( + (xf, yf), + im_cart.shape, + mode=mode, + ) + ] + * (1 - dx) + * (1 - dy) + + im_cart.ravel()[ + xp.ravel_multi_index( + (xf + 1, yf), + im_cart.shape, + mode=mode, + ) + ] + * (dx) + * (1 - dy) + + im_cart.ravel()[ + xp.ravel_multi_index( + (xf, yf + 1), + im_cart.shape, + mode=mode, + ) + ] + * (1 - dx) + * (dy) + + im_cart.ravel()[ + xp.ravel_multi_index( + (xf + 1, yf + 1), + im_cart.shape, + mode=mode, + ) + ] + * (dx) + * (dy) + ) + + return im_polar + + +def polar_to_cartesian_transform_2Ddata( + im_polar, + xy_size, + xy_center, + corner_centered=False, + xp=np, +): + """ + Quick polar to cartesian conversion. + """ + + # coordinates + sx, sy = xy_size + cx, cy = xy_center + + if corner_centered: + x = xp.fft.fftfreq(sx, d=1 / sx) + y = xp.fft.fftfreq(sy, d=1 / sy) + else: + x = xp.arange(sx) + y = xp.arange(sy) + + xa, ya = xp.meshgrid(x, y, indexing="ij") + ra = xp.hypot(xa - cx, ya - cy) + ta = xp.arctan2(ya - cy, xa - cx) + + t = xp.linspace(0, 2 * np.pi, im_polar.shape[0], endpoint=False) + t_step = t[1] - t[0] + + # resampling coordinates + t_ind = ta / t_step + r_ind = ra.copy() + tf = xp.floor(t_ind).astype("int") + rf = xp.floor(r_ind).astype("int") + + # resample image + im_cart = im_polar.ravel()[ + xp.ravel_multi_index( + (tf, rf), + im_polar.shape, + mode=("wrap", "clip"), + ) + ] + + return im_cart + + +def regularize_probe_amplitude( + probe_init, + width_max_pixels=2.0, + nearest_angular_neighbor_averaging=5, + enforce_constant_intensity=True, + corner_centered=False, +): + """ + Fits sigmoid for each angular direction. + + Parameters + -------- + probe_init: np.array + 2D complex image of the probe in Fourier space. + width_max_pixels: float + Maximum edge width of the probe in pixels. + nearest_angular_neighbor_averaging: int + Number of nearest angular neighbor pixels to average to make aperture less jagged. + enforce_constant_intensity: bool + Set to true to make intensity inside the aperture constant. + corner_centered: bool + If True, the probe is assumed to be corner-centered + + Returns + -------- + probe_corr: np.ndarray + 2D complex image of the corrected probe in Fourier space. 
+ coefs_all: np.ndarray + coefficients for the sigmoid fits + """ + + # Get probe intensity + probe_amp = np.abs(probe_init) + probe_angle = np.angle(probe_init) + probe_int = probe_amp**2 + + # Center of mass for probe intensity + xy_center = get_CoM(probe_int, device="cpu", corner_centered=corner_centered) + + # Convert intensity to polar coordinates + polar_int = cartesian_to_polar_transform_2Ddata( + probe_int, + xy_center=xy_center, + corner_centered=corner_centered, + xp=np, + ) + + # Fit corrected probe intensity + radius = np.arange(polar_int.shape[1]) + + # estimate initial parameters + sub = polar_int > (np.max(polar_int) * 0.5) + sig_0 = np.mean(polar_int[sub]) + rad_0 = np.max(np.argwhere(np.sum(sub, axis=0))) + width = width_max_pixels * 0.5 + + # init + def step_model(radius, sig_0, rad_0, width): + return sig_0 * np.clip((rad_0 - radius) / width, 0.0, 1.0) + + coefs_all = np.zeros((polar_int.shape[0], 3)) + coefs_all[:, 0] = sig_0 + coefs_all[:, 1] = rad_0 + coefs_all[:, 2] = width + + # bounds + lb = (0.0, 0.0, 1e-4) + ub = (np.inf, np.inf, width_max_pixels) + + # refine parameters, generate polar image + polar_fit = np.zeros_like(polar_int) + for a0 in range(polar_int.shape[0]): + coefs_all[a0, :] = curve_fit( + step_model, + radius, + polar_int[a0, :], + p0=coefs_all[a0, :], + xtol=1e-12, + bounds=(lb, ub), + )[0] + polar_fit[a0, :] = step_model(radius, *coefs_all[a0, :]) + + # Compute best-fit constant intensity inside probe, update bounds + sig_0 = np.median(coefs_all[:, 0]) + coefs_all[:, 0] = sig_0 + lb = (sig_0 - 1e-8, 0.0, 1e-4) + ub = (sig_0 + 1e-8, np.inf, width_max_pixels) + + # refine parameters, generate polar image + polar_int_corr = np.zeros_like(polar_int) + for a0 in range(polar_int.shape[0]): + coefs_all[a0, :] = curve_fit( + step_model, + radius, + polar_int[a0, :], + p0=coefs_all[a0, :], + xtol=1e-12, + bounds=(lb, ub), + )[0] + # polar_int_corr[a0, :] = step_model(radius, *coefs_all[a0, :]) + + # make aperture less jagged, using moving mean + coefs_all = np.apply_along_axis( + uniform_filter1d, + 0, + coefs_all, + size=nearest_angular_neighbor_averaging, + mode="wrap", + ) + for a0 in range(polar_int.shape[0]): + polar_int_corr[a0, :] = step_model(radius, *coefs_all[a0, :]) + + # Convert back to cartesian coordinates + int_corr = polar_to_cartesian_transform_2Ddata( + polar_int_corr, + xy_size=probe_init.shape, + xy_center=xy_center, + corner_centered=corner_centered, + ) + + amp_corr = np.sqrt(np.maximum(int_corr, 0)) + + # Assemble output probe + if not enforce_constant_intensity: + max_coeff = np.sqrt(coefs_all[:, 0]).max() + amp_corr = amp_corr / max_coeff * probe_amp + + probe_corr = amp_corr * np.exp(1j * probe_angle) + + return probe_corr, polar_int, polar_int_corr, coefs_all + + +def aberrations_basis_function( + probe_size, + probe_sampling, + max_angular_order, + max_radial_order, + xp=np, +): + """ """ + sx, sy = probe_size + dx, dy = probe_sampling + qx = xp.fft.fftfreq(sx, dx) + qy = xp.fft.fftfreq(sy, dy) + + qxa, qya = xp.meshgrid(qx, qy, indexing="ij") + q2 = qxa**2 + qya**2 + theta = xp.arctan2(qya, qxa) + + basis = [] + index = [] + + for n in range(max_angular_order + 1): + for m in range((max_radial_order - n) // 2 + 1): + basis.append((q2 ** (m + n / 2) * np.cos(n * theta))) + index.append((m, n, 0)) + if n > 0: + basis.append((q2 ** (m + n / 2) * np.sin(n * theta))) + index.append((m, n, 1)) + + basis = xp.array(basis) + + return basis, index + + +def fit_aberration_surface( + complex_probe, + probe_sampling, + max_angular_order, + 
max_radial_order,
+    xp=np,
+):
+    """
+    Fits an aberration surface to the unwrapped phase of a complex probe,
+    using a least-squares fit of the aberration basis weighted by the
+    probe amplitude.
+    """
+    probe_amp = xp.abs(complex_probe)
+    probe_angle = xp.angle(complex_probe)
+
+    if xp is np:
+        probe_angle = probe_angle.astype(np.float64)
+        unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True).astype(xp.float32)
+    else:
+        probe_angle = xp.asnumpy(probe_angle).astype(np.float64)
+        unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True)
+        unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32)
+
+    basis, _ = aberrations_basis_function(
+        complex_probe.shape,
+        probe_sampling,
+        max_angular_order,
+        max_radial_order,
+        xp=xp,
+    )
+
+    raveled_basis = basis.reshape((basis.shape[0], -1))
+    raveled_weights = probe_amp.ravel()
+
+    Aw = raveled_basis.T * raveled_weights[:, None]
+    bw = unwrapped_angle.ravel() * raveled_weights
+    coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0]
+
+    fitted_angle = xp.tensordot(coeff, basis, axes=1)
+
+    return fitted_angle, coeff
diff --git a/py4DSTEM/process/polar/__init__.py b/py4DSTEM/process/polar/__init__.py
index 741e0c610..79e13a054 100644
--- a/py4DSTEM/process/polar/__init__.py
+++ b/py4DSTEM/process/polar/__init__.py
@@ -1,3 +1,3 @@
 from py4DSTEM.process.polar.polar_datacube import PolarDatacube
 from py4DSTEM.process.polar.polar_fits import fit_amorphous_ring, plot_amorphous_ring
-
+from py4DSTEM.process.polar.polar_peaks import find_peaks_single_pattern, find_peaks, refine_peaks, plot_radial_peaks, plot_radial_background, make_orientation_histogram
\ No newline at end of file
diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py
index c8afa7609..cb206c961 100644
--- a/py4DSTEM/process/polar/polar_datacube.py
+++ b/py4DSTEM/process/polar/polar_datacube.py
@@ -1,5 +1,5 @@
 import numpy as np
-from py4DSTEM.classes import DataCube
+from py4DSTEM.datacube import DataCube
 from scipy.ndimage import binary_opening,binary_closing, gaussian_filter1d
@@ -19,9 +19,9 @@ def __init__(
         n_annular = 180,
         qscale = None,
         mask = None,
-        mask_thresh = 0.25,
+        mask_thresh = 0.1,
         ellipse = True,
-        two_fold_rotation = False,
+        two_fold_symmetry = False,
     ):
         """
         Parameters
@@ -29,14 +29,15 @@
         datacube : DataCube
             The datacube in cartesian coordinates
         qmin : number
-            Minumum radius of the polar transformation
+            Minimum radius of the polar transformation, in pixels
         qmax : number or None
-            Maximum radius of the polar transformation
+            Maximum radius of the polar transformation, in pixels
         qstep : number
-            Width of radial bins
+            Width of radial bins, in pixels
         n_annular : integer
             Number of bins in the annular direction. Bins will each
-            have a width of 360/num_annular_bins, in degrees
+            have a width of 360/n_annular, or 180/n_annular if
+            two_fold_symmetry is set to True, in degrees
         qscale : number or None
             Radial scaling power to apply to polar transform
         mask : boolean array
@@ -49,9 +50,10 @@
             performs an elliptic transform iff elliptic calibrations
             are available.
-        two_fold_rotation : bool
-            Setting to True computes the transform mod(theta,pi), i.e. assumes all patterns
-            posess two-fold rotation (Friedel symmetry). The output angular range in this case
-            becomes [0, pi) as opposed to the default of [0,2*pi).
+        two_fold_symmetry : bool
+            Setting to True computes the transform mod(theta,pi), i.e. assumes
+            all patterns possess two-fold rotation (Friedel symmetry). The
+            output angular range in this case becomes [0, pi) as opposed to the
+            default of [0,2*pi).
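For orientation, a hypothetical construction call reflecting the renamed keyword and the new default (the file path and bin values are placeholders, not from the source):

```python
import py4DSTEM
from py4DSTEM.process.polar import PolarDatacube

datacube = py4DSTEM.read("path/to/scan.h5")   # placeholder path
polarcube = PolarDatacube(
    datacube,
    qmin=0.0, qmax=120.0, qstep=1.0,   # radial bins, in pixels
    n_annular=180,                     # annular bins
    two_fold_symmetry=True,            # angular range becomes [0, pi)
    mask_thresh=0.1,                   # new default per this diff
)
```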
""" # attach datacube @@ -68,17 +70,12 @@ def __init__( # setup sampling - # annular range, depending on if polar transform spans pi or 2*pi - if two_fold_rotation: - self._annular_range = np.pi - else: - self._annular_range = 2.0 * np.pi - # polar self._qscale = qscale if qmax is None: qmax = np.min(self._datacube.Qshape) / np.sqrt(2) - self.set_annular_bins(n_annular) + self._n_annular = n_annular + self.two_fold_symmetry = two_fold_symmetry #implicitly calls set_annular_bins self.set_radial_bins(qmin,qmax,qstep) # cartesian @@ -102,6 +99,16 @@ def __init__( plot_FEM_global, calculate_FEM_local, ) + from py4DSTEM.process.polar.polar_peaks import ( + find_peaks_single_pattern, + find_peaks, + refine_peaks_local, + refine_peaks, + plot_radial_peaks, + plot_radial_background, + model_radial_background, + make_orientation_histogram, + ) # sampling methods + properties @@ -120,8 +127,9 @@ def set_radial_bins( self._qmax, self._qstep ) - self.qscale = self._qscale + self._radial_step = self._datacube.calibration.get_Q_pixel_size() * self._qstep self.set_polar_shape() + self.qscale = self._qscale @property def qmin(self): @@ -176,6 +184,18 @@ def annular_bins(self): @property def annular_step(self): return self._annular_step + @property + def two_fold_symmetry(self): + return self._two_fold_symmetry + @two_fold_symmetry.setter + def two_fold_symmetry(self,x): + assert(isinstance(x,bool)), f"two_fold_symmetry must be boolean, not type {type(x)}" + self._two_fold_symmetry = x + if x: + self._annular_range = np.pi + else: + self._annular_range = 2 * np.pi + self.set_annular_bins(self._n_annular) @property def n_annular(self): @@ -195,8 +215,22 @@ def set_polar_shape(self): # set KDE params self._annular_bin_step = 1 / (self._annular_step * (self.radial_bins + self.qstep * 0.5)) self._sigma_KDE = self._annular_bin_step * 0.5 + # set array indices + self._annular_indices = np.arange(self.polar_shape[0]).astype(int) + self._radial_indices = np.arange(self.polar_shape[1]).astype(int) + # coordinate grid properties + @property + def tt(self): + return self._annular_bins + @property + def tt_deg(self): + return self._annular_bins * 180/np.pi + @property + def qq(self): + return self.radial_bins * self.calibration.get_Q_pixel_size() + # scaling property @@ -207,7 +241,7 @@ def qscale(self): def qscale(self,x): self._qscale = x if x is not None: - self._qscale_ar = np.arange(self.polar_shape[1])**x + self._qscale_ar = (self.qq / self.qq[-1])**x # expose raw data @@ -419,7 +453,7 @@ def _transform( ) # scale the normalization array by the bin density - norm_array = ans_norm*self._polarcube._annular_bin_step[np.newaxis] + norm_array = ans_norm * self._polarcube._annular_bin_step[np.newaxis] mask_bool = norm_array < mask_thresh # apply normalization @@ -442,6 +476,7 @@ def _transform( ) return ans elif returnval == 'nan': + ans[mask_bool] = np.nan return ans elif returnval == 'all': return ans, ans_norm, norm_array, mask_bool @@ -472,7 +507,7 @@ def _transform_array( # get polar coords rr = np.sqrt(x**2 + y**2) tt = np.mod( - np.arctan2(y, x) - np.pi/2, + np.arctan2(y, x), self._polarcube._annular_range) # elliptical @@ -480,21 +515,12 @@ def _transform_array( # unpack ellipse a,b,theta = ellipse - # transformation matrix (elliptic cartesian -> circular cartesian) - A = (a/b)*np.cos(theta) - B = -np.sin(theta) - C = (a/b)*np.sin(theta) - D = np.cos(theta) - det = 1 / (A*D - B*C) - - # get circular cartesian coords - xc = x*D - y*B - yc = -x*C + y*A - - # get polar coords - rr = det * np.hypot(xc,yc) + # Get 
polar coords + xc = x*np.cos(theta) + y*np.sin(theta) + yc = (y*np.cos(theta) - x*np.sin(theta))*(a/b) + rr = (b/a) * np.hypot(xc,yc) tt = np.mod( - np.arctan2(yc,xc) - np.pi/2, + np.arctan2(yc,xc) + theta, self._polarcube._annular_range) # transform to bin sampling @@ -562,5 +588,4 @@ def __repr__(self): space = ' '*len(self.__class__.__name__)+' ' string = f"{self.__class__.__name__}( " string += "Retrieves the diffraction pattern at scan position (x,y) in polar coordinates when sliced with [x,y]." - return string - + return string \ No newline at end of file diff --git a/py4DSTEM/process/polar/polar_fits.py b/py4DSTEM/process/polar/polar_fits.py index ad418939f..e231dda07 100644 --- a/py4DSTEM/process/polar/polar_fits.py +++ b/py4DSTEM/process/polar/polar_fits.py @@ -7,13 +7,15 @@ def fit_amorphous_ring( im, - center, - radial_range, + center = None, + radial_range = None, coefs = None, + mask_dp = None, show_fit_mask = False, + maxfev = None, verbose = False, plot_result = True, - plot_log_scale = True, + plot_log_scale = False, plot_int_scale = (-3,3), figsize = (8,8), return_all_coefs = True, @@ -27,11 +29,19 @@ def fit_amorphous_ring( im: np.array 2D image array to perform fitting on center: np.array - (x,y) center coordinates for fitting mask + (x,y) center coordinates for fitting mask. If not specified + by the user, we will assume the center coordinate is (im.shape-1)/2. radial_range: np.array - (radius_inner, radius_outer) radial range to perform fitting over + (radius_inner, radius_outer) radial range to perform fitting over. + If not specified by the user, we will assume (im.shape[0]/4,im.shape[0]/2). + coefs: np.array (optional) + Array containing initial fitting coefficients for the amorphous fit. + mask_dp: np.array + Dark field mask for fitting, in addition to the radial range specified above. show_fit_mask: bool Set to true to preview the fitting mask and initial guess for the ellipse params + maxfev: int + Max number of fitting evaluations for curve_fit. 
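A hypothetical call exercising the new optional arguments (the mean-pattern and percentile-mask choices below are illustrative, not from the source):

```python
import numpy as np
from py4DSTEM.process.polar import fit_amorphous_ring

dp_mean = datacube.data.mean(axis=(0, 1))          # assumed mean diffraction pattern
coefs = fit_amorphous_ring(
    dp_mean,                                       # center / radial_range now default
    mask_dp=dp_mean < np.percentile(dp_mean, 99),  # e.g. suppress hot pixels
    maxfev=5000,
    plot_result=True,
)
```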
verbose: bool Print fit results plot_result: bool @@ -53,6 +63,14 @@ def fit_amorphous_ring( 11 parameter elliptic fit coefficients """ + # Default values + if center is None: + center = np.array(( + (im.shape[0]-1)/2, + (im.shape[1]-1)/2)) + if radial_range is None: + radial_range = (im.shape[0]/4, im.shape[0]/2) + # coordinates xa,ya = np.meshgrid( np.arange(im.shape[0]), @@ -66,6 +84,9 @@ def fit_amorphous_ring( ra2 >= radial_range[0]**2, ra2 <= radial_range[1]**2, ) + if mask_dp is not None: + # Logical AND the radial mask with the user-provided mask + mask = np.logical_and(mask, mask_dp) vals = im[mask] basis = np.vstack((xa[mask],ya[mask])) @@ -118,9 +139,11 @@ def fit_amorphous_ring( int_range = ( int_med + plot_int_scale[0]*int_std, int_med + plot_int_scale[1]*int_std) - im_plot = np.tile(np.clip( - (np.log(im[:,:,None]) - int_range[0]) / (int_range[1] - int_range[0]), - 0,1),(1,1,3)) + im_plot = np.tile( + np.clip( + (np.log(im[:,:,None]) - int_range[0]) / (int_range[1] - int_range[0]), + 0,1), + (1,1,3)) else: int_med = np.median(vals) @@ -139,14 +162,26 @@ def fit_amorphous_ring( else: # Perform elliptic fitting int_mean = np.mean(vals) - coefs = curve_fit( - amorphous_model, - basis, - vals / int_mean, - p0=coefs, - xtol = 1e-12, - bounds = (lb,ub), - )[0] + + if maxfev is None: + coefs = curve_fit( + amorphous_model, + basis, + vals / int_mean, + p0=coefs, + xtol = 1e-8, + bounds = (lb,ub), + )[0] + else: + coefs = curve_fit( + amorphous_model, + basis, + vals / int_mean, + p0=coefs, + xtol = 1e-8, + bounds = (lb,ub), + maxfev = maxfev, + )[0] coefs[4] = np.mod(coefs[4],2*np.pi) coefs[5:8] *= int_mean # bounds=bounds @@ -346,5 +381,4 @@ def amorphous_model(basis, *coefs): sub = np.logical_not(sub) int_model[sub] += int12*np.exp(dr2[sub]/(-2*sigma2**2)) - return int_model - + return int_model \ No newline at end of file diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py new file mode 100644 index 000000000..6a6e0860a --- /dev/null +++ b/py4DSTEM/process/polar/polar_peaks.py @@ -0,0 +1,1347 @@ + +import numpy as np +import matplotlib.pyplot as plt + +from scipy.ndimage import gaussian_filter, gaussian_filter1d +from scipy.signal import peak_prominences +from skimage.feature import peak_local_max +from scipy.optimize import curve_fit, leastsq +import warnings + +# from emdfile import tqdmnd, PointList, PointListArray +from py4DSTEM import tqdmnd, PointList, PointListArray +from py4DSTEM.process.fit import polar_twofold_gaussian_2D, polar_twofold_gaussian_2D_background + +def find_peaks_single_pattern( + self, + x, + y, + mask = None, + bragg_peaks = None, + bragg_mask_radius = None, + sigma_annular_deg = 10.0, + sigma_radial_px = 3.0, + sigma_annular_deg_max = None, + radial_background_subtract = True, + radial_background_thresh = 0.25, + num_peaks_max = 100, + threshold_abs = 1.0, + threshold_prom_annular = None, + threshold_prom_radial = None, + remove_masked_peaks = False, + scale_sigma_annular = 0.5, + scale_sigma_radial = 0.25, + return_background = False, + plot_result = True, + plot_power_scale = 1.0, + plot_scale_size = 10.0, + figsize = (12,6), + returnfig = False, + ): + """ + Peak detection function for polar transformations. + + Parameters + -------- + x: int + x index of diffraction pattern + y: int + y index of diffraction pattern + mask: np.array + Boolean mask in Cartesian space, to filter detected peaks. 
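The two `curve_fit` branches above differ only in whether `maxfev` is forwarded; an equivalent single call collapses them (behavior-preserving sketch, names as in the surrounding function):

```python
fit_kwargs = dict(p0=coefs, xtol=1e-8, bounds=(lb, ub))
if maxfev is not None:
    fit_kwargs["maxfev"] = maxfev
coefs = curve_fit(amorphous_model, basis, vals / int_mean, **fit_kwargs)[0]
```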
+    bragg_peaks: py4DSTEM.BraggVectors
+        Set of Bragg peaks used to generate a mask in Cartesian space, to filter detected peaks
+    sigma_annular_deg: float
+        smoothing along the annular direction in degrees, periodic
+    sigma_radial_px: float
+        smoothing along the radial direction in pixels, not periodic
+    sigma_annular_deg_max: float
+        Specify this value for the max annular sigma. Peaks larger than this will be split
+        into multiple peaks, depending on the ratio.
+    radial_background_subtract: bool
+        If True, subtract the radial background estimate
+    radial_background_thresh: float
+        Relative order of sorted values to use as background estimate.
+        Setting to 0.5 is equivalent to median, 0.0 is min value.
+    num_peaks_max: int
+        Max number of peaks to return.
+    threshold_abs: float
+        Absolute image intensity threshold for peaks.
+    threshold_prom_annular: float
+        Threshold for prominence, along annular direction.
+    threshold_prom_radial: float
+        Threshold for prominence, along radial direction.
+    remove_masked_peaks: bool
+        Delete peaks that are in the region masked by "mask"
+    scale_sigma_annular: float
+        Scaling of the estimated annular standard deviation.
+    scale_sigma_radial: float
+        Scaling of the estimated radial standard deviation.
+    return_background: bool
+        Return the background signal.
+    plot_result: bool
+        Plot the detected peaks
+    plot_power_scale: float
+        Image intensity power law scaling.
+    plot_scale_size: float
+        Marker scaling in the plot.
+    figsize: 2-tuple
+        Size of the result plotting figure.
+    returnfig: bool
+        Return the figure and axes handles.
+
+    Returns
+    --------
+
+    peaks_polar : pointlist
+        The detected peaks
+    fig, ax : (optional)
+        Figure and axes handles
+
+    """
+
+    # if needed, generate mask from Bragg peaks
+    if bragg_peaks is not None:
+        mask_bragg = self._datacube.get_braggmask(
+            bragg_peaks,
+            x,
+            y,
+            radius = bragg_mask_radius,
+        )
+        if mask is None:
+            mask = mask_bragg
+        else:
+            mask = np.logical_or(mask, mask_bragg)
+
+    # Convert sigma values into units of bins
+    sigma_annular = np.deg2rad(sigma_annular_deg) / self.annular_step
+    sigma_radial = sigma_radial_px / self.qstep
+
+    # Get transformed image and normalization array
+    im_polar, im_polar_norm, norm_array, mask_bool = self.transform(
+        self._datacube.data[x,y],
+        mask = mask,
+        returnval = 'all_zeros',
+    )
+    # Change sign convention of mask
+    mask_bool = np.logical_not(mask_bool)
+
+    # Background subtraction
+    if radial_background_subtract:
+        sig_bg = np.zeros(im_polar.shape[1])
+        for a0 in range(im_polar.shape[1]):
+            if np.any(mask_bool[:,a0]):
+                vals = np.sort(im_polar[mask_bool[:,a0],a0])
+                ind = np.round(radial_background_thresh * (vals.shape[0]-1)).astype('int')
+                sig_bg[a0] = vals[ind]
+        sig_bg_mask = np.sum(mask_bool, axis=0) >= (im_polar.shape[0]//2)
+        im_polar = np.maximum(im_polar - sig_bg[None,:], 0)
+
+    # apply smoothing and normalization
+    im_polar_sm = gaussian_filter(
+        im_polar * norm_array,
+        sigma = (sigma_annular, sigma_radial),
+        mode = ('wrap', 'nearest'),
+    )
+    im_mask = gaussian_filter(
+        norm_array,
+        sigma = (sigma_annular, sigma_radial),
+        mode = ('wrap', 'nearest'),
+    )
+    sub = im_mask > 0.001 * np.max(im_mask)
+    im_polar_sm[sub] /= im_mask[sub]
+
+    # Find local maxima
+    peaks = peak_local_max(
+        im_polar_sm,
+        num_peaks = num_peaks_max,
+        threshold_abs = threshold_abs,
+    )
+
+    # check if peaks should be removed from the polar transformation mask
+    if remove_masked_peaks:
+        peaks = np.delete(
+            peaks,
+            mask_bool[peaks[:,0],peaks[:,1]] == False,
+            axis = 0,
+        )
+
+    # peak
intensity + peaks_int = im_polar_sm[peaks[:,0],peaks[:,1]] + + # Estimate prominance of peaks, and their size in units of pixels + peaks_prom = np.zeros((peaks.shape[0],4)) + annular_ind_center = np.atleast_1d(np.array(im_polar_sm.shape[0]//2).astype('int')) + for a0 in range(peaks.shape[0]): + + # annular + trace_annular = np.roll( + np.squeeze(im_polar_sm[:,peaks[a0,1]]), + annular_ind_center - peaks[a0,0]) + p_annular = peak_prominences( + trace_annular, + annular_ind_center, + ) + sigma_annular = scale_sigma_annular * np.minimum( + annular_ind_center - p_annular[1], + p_annular[2] - annular_ind_center) + + # radial + trace_radial = im_polar_sm[peaks[a0,0],:] + p_radial = peak_prominences( + trace_radial, + np.atleast_1d(peaks[a0,1]), + ) + sigma_radial = scale_sigma_radial * np.minimum( + peaks[a0,1] - p_radial[1], + p_radial[2] - peaks[a0,1]) + + # output + peaks_prom[a0,0] = p_annular[0] + peaks_prom[a0,1] = sigma_annular[0] + peaks_prom[a0,2] = p_radial[0] + peaks_prom[a0,3] = sigma_radial[0] + + # if needed, remove peaks using prominance criteria + if threshold_prom_annular is not None: + remove = peaks_prom[:,0] < threshold_prom_annular + peaks = np.delete( + peaks, + remove, + axis = 0, + ) + peaks_int = np.delete( + peaks_int, + remove, + ) + peaks_prom = np.delete( + peaks_prom, + remove, + axis = 0, + ) + if threshold_prom_radial is not None: + remove = peaks_prom[:,2] < threshold_prom_radial + peaks = np.delete( + peaks, + remove, + axis = 0, + ) + peaks_int = np.delete( + peaks_int, + remove, + ) + peaks_prom = np.delete( + peaks_prom, + remove, + axis = 0, + ) + + # combine peaks into one array + peaks_all = np.column_stack((peaks, peaks_int, peaks_prom)) + + # Split peaks into multiple peaks if they have sigma values larger than user-specified threshold + if sigma_annular_deg_max is not None: + peaks_new = np.zeros((0,peaks_all.shape[1])) + for a0 in range(peaks_all.shape[0]): + if peaks_all[a0,4] >= (1.5*sigma_annular_deg_max): + num = np.round(peaks_all[a0,4] / sigma_annular_deg_max) + sigma_annular_new = peaks_all[a0,4] / num + + v = np.arange(num) + v -= np.mean(v) + t_new = np.mod(peaks_all[a0,0] + 2*v*sigma_annular_new, + self._n_annular) + + for a1 in range(num.astype('int')): + peaks_new = np.vstack(( + peaks_new, + np.array(( + t_new[a1], + peaks_all[a0,1], + peaks_all[a0,2], + peaks_all[a0,3], + sigma_annular_new, + peaks_all[a0,5], + peaks_all[a0,6], + )), + )) + else: + peaks_new = np.vstack(( + peaks_new, + peaks_all[a0,:] + )) + peaks_all = peaks_new + + + # Output data as a pointlist + peaks_polar = PointList( + peaks_all.ravel().view([ + ('qt', float), + ('qr', float), + ('intensity', float), + ('prom_annular', float), + ('sigma_annular', float), + ('prom_radial', float), + ('sigma_radial', float), + ]), + name = 'peaks_polar') + + + if plot_result: + # init + im_plot = im_polar.copy() + im_plot = np.maximum(im_plot, 0) ** plot_power_scale + + t = np.linspace(0,2*np.pi,180+1) + ct = np.cos(t) + st = np.sin(t) + + + fig,ax = plt.subplots(figsize=figsize) + + ax.imshow( + im_plot, + cmap = 'gray', + ) + + # peaks + ax.scatter( + peaks_polar['qr'], + peaks_polar['qt'], + s = peaks_polar['intensity'] * plot_scale_size, + marker='o', + color = (1,0,0), + ) + for a0 in range(peaks_polar.data.shape[0]): + ax.plot( + peaks_polar['qr'][a0] + st * peaks_polar['sigma_radial'][a0], + peaks_polar['qt'][a0] + ct * peaks_polar['sigma_annular'][a0], + linewidth = 1, + color = 'r', + ) + if peaks_polar['qt'][a0] - peaks_polar['sigma_annular'][a0] < 0: + ax.plot( + 
peaks_polar['qr'][a0] + st * peaks_polar['sigma_radial'][a0], + peaks_polar['qt'][a0] + ct * peaks_polar['sigma_annular'][a0] + im_plot.shape[0], + linewidth = 1, + color = 'r', + ) + if peaks_polar['qt'][a0] + peaks_polar['sigma_annular'][a0] > im_plot.shape[0]: + ax.plot( + peaks_polar['qr'][a0] + st * peaks_polar['sigma_radial'][a0], + peaks_polar['qt'][a0] + ct * peaks_polar['sigma_annular'][a0] - im_plot.shape[0], + linewidth = 1, + color = 'r', + ) + + # plot appearance + ax.set_xlim((0,im_plot.shape[1]-1)) + ax.set_ylim((im_plot.shape[0]-1,0)) + + if returnfig and plot_result: + if return_background: + return peaks_polar, sig_bg, sig_bg_mask, fig, ax + else: + return peaks_polar, fig, ax + else: + if return_background: + return peaks_polar, sig_bg, sig_bg_mask + else: + return peaks_polar + + +def find_peaks( + self, + mask = None, + bragg_peaks = None, + bragg_mask_radius = None, + sigma_annular_deg = 10.0, + sigma_radial_px = 3.0, + sigma_annular_deg_max = None, + radial_background_subtract = True, + radial_background_thresh = 0.25, + num_peaks_max = 100, + threshold_abs = 1.0, + threshold_prom_annular = None, + threshold_prom_radial = None, + remove_masked_peaks = False, + scale_sigma_annular = 0.5, + scale_sigma_radial = 0.25, + progress_bar = True, + ): + """ + Peak detection function for polar transformations. Loop through all probe positions, + find peaks. Store the peak positions and background signals. + + Parameters + -------- + sigma_annular_deg: float + smoothing along the annular direction in degrees, periodic + sigma_radial_px: float + smoothing along the radial direction in pixels, not periodic + + Returns + -------- + + """ + + # init + self.bragg_peaks = bragg_peaks + self.bragg_mask_radius = bragg_mask_radius + self.peaks = PointListArray( + dtype = [ + ('qt', ' min_num_pixels_fit: + try: + # perform fitting + p0, pcov = curve_fit( + polar_twofold_gaussian_2D, + tq[:,mask_peak.ravel()], + im_polar[mask_peak], + p0 = p0, + # bounds = bounds, + ) + + # Output parameters + self.peaks[rx,ry]['intensity'][a0] = p0[0] + self.peaks[rx,ry]['qt'][a0] = p0[1] / t_step + self.peaks[rx,ry]['qr'][a0] = p0[2] / q_step + self.peaks[rx,ry]['sigma_annular'][a0] = p0[3] / t_step + self.peaks[rx,ry]['sigma_radial'][a0] = p0[4] / q_step + + except: + pass + + else: + # initial parameters + p0 = [ + p['intensity'][a0], + p['qt'][a0] * t_step, + p['qr'][a0] * q_step, + p['sigma_annular'][a0] * t_step, + p['sigma_radial'][a0] * q_step, + 0, + ] + + # Mask around peak for fitting + dt = np.mod(tt - p0[1] + np.pi/2, np.pi) - np.pi/2 + mask_peak = np.logical_and(mask_bool, + dt**2/(fit_range_sigma_annular*p0[3])**2 \ + + (qq-p0[2])**2/(fit_range_sigma_radial*p0[4])**2 <= 1) + + if np.sum(mask_peak) > min_num_pixels_fit: + try: + # perform fitting + p0, pcov = curve_fit( + polar_twofold_gaussian_2D_background, + tq[:,mask_peak.ravel()], + im_polar[mask_peak], + p0 = p0, + # bounds = bounds, + ) + + # Output parameters + self.peaks[rx,ry]['intensity'][a0] = p0[0] + self.peaks[rx,ry]['qt'][a0] = p0[1] / t_step + self.peaks[rx,ry]['qr'][a0] = p0[2] / q_step + self.peaks[rx,ry]['sigma_annular'][a0] = p0[3] / t_step + self.peaks[rx,ry]['sigma_radial'][a0] = p0[4] / q_step + + except: + pass + + +def plot_radial_peaks( + self, + q_pixel_units = False, + qmin = None, + qmax = None, + qstep = None, + label_y_axis = False, + figsize = (8,4), + returnfig = False, + ): + """ + Calculate and plot the total peak signal as a function of the radial coordinate. 
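A hypothetical workflow for the peak-finding methods defined above, checking a single pattern before looping over the whole scan (indices and thresholds are illustrative; `polarcube` is a PolarDatacube instance as constructed earlier):

```python
# single-pattern preview; these methods are attached to PolarDatacube by this module
peaks = polarcube.find_peaks_single_pattern(
    x=10, y=10,
    sigma_annular_deg=10.0,
    sigma_radial_px=3.0,
    threshold_abs=1.0,
    plot_result=True,
)
print(peaks.data["qt"], peaks.data["qr"])   # annular / radial bin coordinates

# then detect over all probe positions
polarcube.find_peaks(sigma_annular_deg=10.0, sigma_radial_px=3.0)
```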
+ + """ + + # Get all peak data + vects = np.concatenate( + [self.peaks[i,j].data for i in range(self._datacube.Rshape[0]) for j in range(self._datacube.Rshape[1])]) + if q_pixel_units: + qr = vects['qr'] + else: + qr = (vects['qr'] + self.qmin) * self._radial_step + intensity = vects['intensity'] + + # bins + if qmin is None: + qmin = self.qq[0] + if qmax is None: + qmax = self.qq[-1] + if qstep is None: + qstep = self.qq[1] - self.qq[0] + q_bins = np.arange(qmin,qmax,qstep) + q_num = q_bins.shape[0] + if q_pixel_units: + q_bins /= self._radial_step + + # histogram + q_ind = (qr - q_bins[0]) / (q_bins[1] - q_bins[0]) + qf = np.floor(q_ind).astype("int") + dq = q_ind - qf + + sub = np.logical_and(qf >= 0, qf < q_num) + int_peaks = np.bincount( + np.floor(q_ind[sub]).astype("int"), + weights=(1 - dq[sub]) * intensity[sub], + minlength=q_num, + ) + sub = np.logical_and(q_ind >= -1, q_ind < q_num - 1) + int_peaks += np.bincount( + np.floor(q_ind[sub] + 1).astype("int"), + weights=dq[sub] * intensity[sub], + minlength=q_num, + ) + + + # plotting + fig,ax = plt.subplots(figsize = figsize) + ax.plot( + q_bins, + int_peaks, + color = 'r', + linewidth = 2, + ) + ax.set_xlim((q_bins[0],q_bins[-1])) + if q_pixel_units: + ax.set_xlabel( + 'Scattering Angle (pixels)', + fontsize = 14, + ) + else: + ax.set_xlabel( + 'Scattering Angle (' + self.calibration.get_Q_pixel_units() +')', + fontsize = 14, + ) + ax.set_ylabel( + 'Total Peak Signal', + fontsize = 14, + ) + if not label_y_axis: + ax.tick_params( + left = False, + labelleft = False) + + if returnfig: + return fig,ax + + +def model_radial_background( + self, + ring_position = None, + ring_sigma = None, + ring_int = None, + refine_model = True, + plot_result = True, + figsize = (8,4), + ): + """ + User provided radial background model, of the form: + + int = int_const + + int_0 * exp( - q**2 / (2*s0**2) ) + + int_1 * exp( - (q - q_1)**2 / (2*s1**2) ) + + ... + + int_n * exp( - (q - q_n)**2 / (2*sn**2) ) + + where n is the number of amorphous halos / rings included in the fit. 
+ + """ + + # Get mean radial background and mask + self.background_radial_mean = np.sum( + self.background_radial * self.background_radial_mask, + axis=(0,1)) + background_radial_mean_norm = np.sum( + self.background_radial_mask, + axis=(0,1)) + self.background_mask = \ + background_radial_mean_norm > (np.max(background_radial_mean_norm)*0.05) + self.background_radial_mean[self.background_mask] \ + /= background_radial_mean_norm[self.background_mask] + self.background_radial_mean[np.logical_not(self.background_mask)] = 0 + + # init + if ring_position is not None: + ring_position = np.atleast_1d(np.array(ring_position)) + num_rings = ring_position.shape[0] + else: + num_rings = 0 + self.background_coefs = np.zeros(3 + 3*num_rings) + + if ring_sigma is None: + ring_sigma = np.atleast_1d(np.ones(num_rings)) \ + * self.polar_shape[1] * 0.05 * self._radial_step + else: + ring_sigma = np.atleast_1d(np.array(ring_sigma)) + + # Background model initial parameters + int_const = np.min(self.background_radial_mean) + int_0 = np.max(self.background_radial_mean) - int_const + sigma_0 = self.polar_shape[1] * 0.25 * self._radial_step + self.background_coefs[0] = int_const + self.background_coefs[1] = int_0 + self.background_coefs[2] = sigma_0 + + # Additional Gaussians + if ring_int is None: + # Estimate peak intensities + sig_0 = int_const + int_0*np.exp(self.qq**2/(-2*sigma_0**2)) + sig_peaks = np.maximum(self.background_radial_mean - sig_0, 0.0) + + ring_int = np.atleast_1d(np.zeros(num_rings)) + for a0 in range(num_rings): + ind = np.argmin(np.abs(self.qq - ring_position[a0])) + ring_int[a0] = sig_peaks[ind] + + else: + ring_int = np.atleast_1d(np.array(ring_int)) + for a0 in range(num_rings): + self.background_coefs[3*a0+3] = ring_int[a0] + self.background_coefs[3*a0+4] = ring_sigma[a0] + self.background_coefs[3*a0+5] = ring_position[a0] + lb = np.zeros_like(self.background_coefs) + ub = np.ones_like(self.background_coefs) * np.inf + + # Create background model + def background_model(q, *coefs): + coefs = np.squeeze(np.array(coefs)) + num_rings = np.round((coefs.shape[0] - 3)/3).astype('int') + + sig = np.ones(q.shape[0])*coefs[0] + sig += coefs[1]*np.exp(q**2/(-2*coefs[2]**2)) + + for a0 in range(num_rings): + sig += coefs[3*a0+3]*np.exp( + (q-coefs[3*a0+5])**2 / (-2*coefs[3*a0+4]**2)) + + return sig + + self.background_model = background_model + + # Refine background model coefficients + if refine_model: + self.background_coefs = curve_fit( + self.background_model, + self.qq[self.background_mask], + self.background_radial_mean[self.background_mask], + p0 = self.background_coefs, + xtol = 1e-12, + bounds = (lb,ub), + )[0] + + # plotting + if plot_result: + self.plot_radial_background( + q_pixel_units = False, + plot_background_model = True, + figsize = figsize, + ) + + + +def refine_peaks( + self, + mask = None, + # reset_fits_to_init_positions = False, + scale_sigma_estimate = 0.5, + min_num_pixels_fit = 10, + maxfev = None, + progress_bar = True, + ): + """ + Use global fitting model for all images. Requires an background model + specified with self.model_radial_background(). 
+ + TODO: add fitting reset + add min number pixels condition + track any failed fitting points, output as a boolean array + + Parameters + -------- + mask: np.array + Mask image to apply to all images + radial_background_subtract: bool + Subtract radial background before fitting + reset_fits_to_init_positions: bool + Use the initial peak parameters for fitting + scale_sigma_estimate: float + Factor to reduce sigma of peaks by, to prevent fit from running away. + min_num_pixels_fit: int + Minimum number of pixels to perform fitting + maxfev: int + Maximum number of iterations in fit. Set to a low number for a fast fit. + progress_bar: bool + Enable progress bar + + Returns + -------- + + """ + + # coordinate scaling + t_step = self._annular_step + q_step = self._radial_step + + # Background model params + num_rings = np.round((self.background_coefs.shape[0]-3)/3).astype('int') + + # basis + qq,tt = np.meshgrid( + self.qq, + self.tt, + ) + basis = np.zeros((qq.size,3)) + basis[:,0] = tt.ravel() + basis[:,1] = qq.ravel() + basis[:,2] = num_rings + + # init + self.peaks_refine = PointListArray( + dtype = [ + ('qt', 'float'), + ('qr', 'float'), + ('intensity', 'float'), + ('sigma_annular', 'float'), + ('sigma_radial', 'float')], + shape = self._datacube.Rshape, + name = 'peaks_polardata_refined', + ) + self.background_refine = np.zeros(( + self._datacube.Rshape[0], + self._datacube.Rshape[1], + np.round(3*num_rings+3).astype('int'), + )) + + + # Main loop over probe positions + for rx, ry in tqdmnd( + self._datacube.shape[0], + self._datacube.shape[1], + desc="Refining peaks ", + unit=" probe positions", + disable=not progress_bar): + + # Get transformed image and normalization array + im_polar, im_polar_norm, norm_array, mask_bool = self.transform( + self._datacube.data[rx,ry], + mask = mask, + returnval = 'all_zeros', + ) + # Change sign convention of mask + mask_bool = np.logical_not(mask_bool) + + # Get initial peaks, in dimensioned units + p = self.peaks[rx,ry] + qt = p.data['qt'] * t_step + qr = (p.data['qr'] + self.qmin) * q_step + int_peaks = p.data['intensity'] + s_annular = p.data['sigma_annular'] * t_step + s_radial = p.data['sigma_radial'] * q_step + num_peaks = p['qt'].shape[0] + + # unified coefficients + # Note we sharpen sigma estimate for refinement + coefs_all = np.hstack(( + self.background_coefs, + qt, + qr, + int_peaks, + s_annular * scale_sigma_estimate, + s_radial * scale_sigma_estimate, + )) + + # bounds + lb = np.zeros_like(coefs_all) + ub = np.ones_like(coefs_all) * np.inf + + + # Construct fitting model + def fit_image(basis, *coefs): + coefs = np.squeeze(np.array(coefs)) + + num_rings = np.round(basis[0,2]).astype('int') + num_peaks = np.round((coefs.shape[0] - (3*num_rings+3))/5).astype('int') + + coefs_bg = coefs[:(3*num_rings+3)] + coefs_peaks = coefs[(3*num_rings+3):] + + + # Background + sig = self.background_model( + basis[:,1], + coefs_bg) + + # add peaks + for a0 in range(num_peaks): + dt = np.mod(basis[:,0] - coefs_peaks[num_peaks*0+a0] + np.pi/2, np.pi) - np.pi/2 + dq = basis[:,1] - coefs_peaks[num_peaks*1+a0] + + sig += coefs_peaks[num_peaks*2+a0] \ + * np.exp( + dt**2 / (-2*coefs_peaks[num_peaks*3+a0]**2) + \ + dq**2 / (-2*coefs_peaks[num_peaks*4+a0]**2)) + + return sig + + # refine fitting model + try: + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + if maxfev is None: + coefs_all = curve_fit( + fit_image, + basis[mask_bool.ravel(),:], + im_polar[mask_bool], + p0 = coefs_all, + xtol = 1e-12, + bounds = (lb,ub), + )[0] + else: + 
coefs_all = curve_fit(
+                        fit_image,
+                        basis[mask_bool.ravel(),:],
+                        im_polar[mask_bool],
+                        p0 = coefs_all,
+                        xtol = 1e-12,
+                        maxfev = maxfev,
+                        bounds = (lb,ub),
+                    )[0]
+
+            # Output refined peak parameters
+            coefs_peaks = np.reshape(
+                coefs_all[(3*num_rings+3):],
+                (5,num_peaks)).T
+            self.peaks_refine[rx,ry] = PointList(
+                coefs_peaks.ravel().view([
+                    ('qt', float),
+                    ('qr', float),
+                    ('intensity', float),
+                    ('sigma_annular', float),
+                    ('sigma_radial', float),
+                ]),
+                name = 'peaks_polar')
+        except:
+            # if fitting has failed, we still output the last iteration
+            # TODO - add a flag for unconverged fits
+            coefs_peaks = np.reshape(
+                coefs_all[(3*num_rings+3):],
+                (5,num_peaks)).T
+            self.peaks_refine[rx,ry] = PointList(
+                coefs_peaks.ravel().view([
+                    ('qt', float),
+                    ('qr', float),
+                    ('intensity', float),
+                    ('sigma_annular', float),
+                    ('sigma_radial', float),
+                ]),
+                name = 'peaks_polar')
+
+        # Output refined parameters for background
+        coefs_bg = coefs_all[:(3*num_rings+3)]
+        self.background_refine[rx,ry] = coefs_bg
+
+        # # Testing
+        # im_fit = np.reshape(
+        #     fit_image(basis,coefs_all),
+        #     self.polar_shape)
+
+        # fig,ax = plt.subplots(figsize=(8,6))
+        # ax.imshow(
+        #     np.vstack((
+        #         im_polar,
+        #         im_fit,
+        #     )),
+        #     cmap = 'turbo',
+        # )
+
+
+def plot_radial_background(
+    self,
+    q_pixel_units = False,
+    label_y_axis = False,
+    plot_background_model = False,
+    figsize = (8,4),
+    returnfig = False,
+    ):
+    """
+    Calculate and plot the mean radial background signal and its standard deviation.
+
+    """
+
+    # mean
+    self.background_radial_mean = np.sum(
+        self.background_radial * self.background_radial_mask,
+        axis=(0,1))
+    background_radial_mean_norm = np.sum(
+        self.background_radial_mask,
+        axis=(0,1))
+    self.background_mask = \
+        background_radial_mean_norm > (np.max(background_radial_mean_norm)*0.05)
+    self.background_radial_mean[self.background_mask] \
+        /= background_radial_mean_norm[self.background_mask]
+    self.background_radial_mean[np.logical_not(self.background_mask)] = 0
+
+    # variance and standard deviation
+    self.background_radial_var = np.sum(
+        (self.background_radial - self.background_radial_mean[None,None,:])**2 \
+        * self.background_radial_mask,
+        axis=(0,1))
+    # normalize by the number of contributing patterns, not by the variance itself
+    self.background_radial_var[self.background_mask] \
+        /= background_radial_mean_norm[self.background_mask]
+    self.background_radial_var[np.logical_not(self.background_mask)] = 0
+    self.background_radial_std = np.sqrt(self.background_radial_var)
+
+    if q_pixel_units:
+        q_axis = np.arange(self.qq.shape[0])[self.background_mask]
+    else:
+        q_axis = self.qq[self.background_mask]
+
+    fig,ax = plt.subplots(figsize = figsize)
+    ax.fill_between(
+        q_axis,
+        self.background_radial_mean[self.background_mask] \
+            - self.background_radial_std[self.background_mask],
+        self.background_radial_mean[self.background_mask] \
+            + self.background_radial_std[self.background_mask],
+        color = 'r',
+        alpha=0.2,
+    )
+    ax.plot(
+        q_axis,
+        self.background_radial_mean[self.background_mask],
+        color = 'r',
+        linewidth = 2,
+    )
+
+    # overlay fitting model
+    if plot_background_model:
+        sig = self.background_model(
+            self.qq,
+            self.background_coefs,
+        )
+        ax.plot(
+            q_axis,
+            sig[self.background_mask],
+            color = 'k',
+            linewidth = 2,
+            linestyle = '--'
+        )
+
+    # plot appearance
+    ax.set_xlim((
+        q_axis[0],
+        q_axis[-1]))
+    if q_pixel_units:
+        ax.set_xlabel(
+            'Scattering Angle (pixels)',
+            fontsize = 14,
) + ax.set_ylabel( + 'Background Signal', + fontsize = 14, + ) + if not label_y_axis: + ax.tick_params( + left = False, + labelleft = False) + + if returnfig: + return fig,ax + + +def make_orientation_histogram( + self, + radial_ranges: np.ndarray = None, + orientation_flip_sign: bool = False, + orientation_offset_degrees: float = 0.0, + orientation_separate_bins: bool = False, + upsample_factor: float = 4.0, + use_refined_peaks = True, + use_peak_sigma = False, + peak_sigma_samples = 6, + theta_step_deg: float = None, + sigma_x: float = 1.0, + sigma_y: float = 1.0, + sigma_theta: float = 3.0, + normalize_intensity_image: bool = False, + normalize_intensity_stack: bool = True, + progress_bar: bool = True, + ): + """ + Make an orientation histogram, in order to use flowline visualization of orientation maps. + Use peaks attached to polardatacube. + + NOTE - currently assumes two fold rotation symmetry + TODO - add support for non two fold symmetry polardatacube + + Args: + radial_ranges (np array): Size (N x 2) array for N radial bins, or (2,) for a single bin. + orientation_flip_sign (bool): Flip the direction of theta + orientation_offset_degrees (float): Offset for orientation angles + orientation_separate_bins (bool): whether to place multiple angles into multiple radial bins. + upsample_factor (float): Upsample factor + use_refined_peaks (float): Use refined peak positions + use_peak_sigma (float): Spread signal along annular direction using measured std. + theta_step_deg (float): Step size along annular direction in degrees + sigma_x (float): Smoothing in x direction before upsample + sigma_y (float): Smoothing in x direction before upsample + sigma_theta (float): Smoothing in annular direction (units of bins, periodic) + normalize_intensity_image (bool): Normalize to max peak intensity = 1, per image + normalize_intensity_stack (bool): Normalize to max peak intensity = 1, all images + progress_bar (bool): Enable progress bar + + Returns: + orient_hist (array): 4D array containing Bragg peak intensity histogram + [radial_bin x_probe y_probe theta] + """ + + # coordinates + if theta_step_deg is None: + # Get angles from polardatacube + theta = self.tt + else: + theta = np.arange(0,180,theta_step_deg) * np.pi / 180.0 + dtheta = theta[1] - theta[0] + dtheta_deg = dtheta * 180 / np.pi + num_theta_bins = np.size(theta) + + # Input bins + radial_ranges = np.array(radial_ranges) + if radial_ranges.ndim == 1: + radial_ranges = radial_ranges[None,:] + radial_ranges_2 = radial_ranges**2 + num_radii = radial_ranges.shape[0] + size_input = self._datacube.shape[0:2] + + # Output size + size_output = np.round(np.array(size_input).astype('float') * upsample_factor).astype('int') + + # output init + orient_hist = np.zeros([ + num_radii, + size_output[0], + size_output[1], + num_theta_bins]) + + if use_peak_sigma: + v_sigma = np.linspace(-2,2,2*peak_sigma_samples+1) + w_sigma = np.exp(-v_sigma**2/2) + + if use_refined_peaks is False: + warnings.warn("Orientation histogram is using non-refined peak positions") + + # Loop over all probe positions + for a0 in range(num_radii): + t = "Generating histogram " + str(a0) + + for rx, ry in tqdmnd( + *size_input, + desc=t, + unit=" probe positions", + disable=not progress_bar + ): + x = (rx + 0.5)*upsample_factor - 0.5 + y = (ry + 0.5)*upsample_factor - 0.5 + x = np.clip(x,0,size_output[0]-2) + y = np.clip(y,0,size_output[1]-2) + + xF = np.floor(x).astype('int') + yF = np.floor(y).astype('int') + dx = x - xF + dy = y - yF + + add_data = False + if 
use_refined_peaks: + q = self.peaks_refine[rx,ry]['qr'] + else: + q = (self.peaks[rx,ry]['qr'] + self.qmin) * self._radial_step + r2 = q**2 + sub = np.logical_and(r2 >= radial_ranges_2[a0,0], r2 < radial_ranges_2[a0,1]) + + if np.any(sub): + add_data = True + intensity = self.peaks[rx,ry]['intensity'][sub] + + # Angles of all peaks + if use_refined_peaks: + theta = self.peaks_refine[rx,ry]['qt'][sub] + else: + theta = self.peaks[rx,ry]['qt'][sub] * self._annular_step + if orientation_flip_sign: + theta *= -1 + theta += orientation_offset_degrees + + t = theta / dtheta + + # If needed, expand signal using peak sigma to write into multiple bins + if use_peak_sigma: + if use_refined_peaks: + theta_std = self.peaks_refine[rx,ry]['sigma_annular'][sub] / dtheta + else: + theta_std = self.peaks[rx,ry]['sigma_annular'][sub] / dtheta + t = (t[:,None] + theta_std[:,None]*v_sigma[None,:]).ravel() + intensity = (intensity[:,None] * w_sigma[None,:]).ravel() + + if add_data: + tF = np.floor(t).astype('int') + dt = t - tF + + orient_hist[a0,xF ,yF ,:] = orient_hist[a0,xF ,yF ,:] + \ + np.bincount(np.mod(tF ,num_theta_bins), + weights=(1-dx)*(1-dy)*(1-dt)*intensity,minlength=num_theta_bins) + orient_hist[a0,xF ,yF ,:] = orient_hist[a0,xF ,yF ,:] + \ + np.bincount(np.mod(tF+1,num_theta_bins), + weights=(1-dx)*(1-dy)*( dt)*intensity,minlength=num_theta_bins) + + orient_hist[a0,xF+1,yF ,:] = orient_hist[a0,xF+1,yF ,:] + \ + np.bincount(np.mod(tF ,num_theta_bins), + weights=( dx)*(1-dy)*(1-dt)*intensity,minlength=num_theta_bins) + orient_hist[a0,xF+1,yF ,:] = orient_hist[a0,xF+1,yF ,:] + \ + np.bincount(np.mod(tF+1,num_theta_bins), + weights=( dx)*(1-dy)*( dt)*intensity,minlength=num_theta_bins) + + orient_hist[a0,xF ,yF+1,:] = orient_hist[a0,xF ,yF+1,:] + \ + np.bincount(np.mod(tF ,num_theta_bins), + weights=(1-dx)*( dy)*(1-dt)*intensity,minlength=num_theta_bins) + orient_hist[a0,xF ,yF+1,:] = orient_hist[a0,xF ,yF+1,:] + \ + np.bincount(np.mod(tF+1,num_theta_bins), + weights=(1-dx)*( dy)*( dt)*intensity,minlength=num_theta_bins) + + orient_hist[a0,xF+1,yF+1,:] = orient_hist[a0,xF+1,yF+1,:] + \ + np.bincount(np.mod(tF ,num_theta_bins), + weights=( dx)*( dy)*(1-dt)*intensity,minlength=num_theta_bins) + orient_hist[a0,xF+1,yF+1,:] = orient_hist[a0,xF+1,yF+1,:] + \ + np.bincount(np.mod(tF+1,num_theta_bins), + weights=( dx)*( dy)*( dt)*intensity,minlength=num_theta_bins) + + # smoothing / interpolation + if (sigma_x is not None) or (sigma_y is not None) or (sigma_theta is not None): + if num_radii > 1: + print('Interpolating orientation matrices ...', end='') + else: + print('Interpolating orientation matrix ...', end='') + if sigma_x is not None and sigma_x > 0: + orient_hist = gaussian_filter1d( + orient_hist,sigma_x*upsample_factor, + mode='nearest', + axis=1, + truncate=3.0) + if sigma_y is not None and sigma_y > 0: + orient_hist = gaussian_filter1d( + orient_hist,sigma_y*upsample_factor, + mode='nearest', + axis=2, + truncate=3.0) + if sigma_theta is not None and sigma_theta > 0: + orient_hist = gaussian_filter1d( + orient_hist,sigma_theta/dtheta_deg, + mode='wrap', + axis=3, + truncate=2.0) + print(' done.') + + # normalization + if normalize_intensity_stack is True: + orient_hist = orient_hist / np.max(orient_hist) + elif normalize_intensity_image is True: + for a0 in range(num_radii): + orient_hist[a0,:,:,:] = orient_hist[a0,:,:,:] / np.max(orient_hist[a0,:,:,:]) + + return orient_hist \ No newline at end of file diff --git a/py4DSTEM/process/probe/__init__.py b/py4DSTEM/process/probe/__init__.py deleted 
file mode 100644 index a77a14ae6..000000000 --- a/py4DSTEM/process/probe/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from py4DSTEM.process.probe.probe import * -from py4DSTEM.process.probe.kernel import * - diff --git a/py4DSTEM/process/probe/kernel.py b/py4DSTEM/process/probe/kernel.py deleted file mode 100644 index f415ae96d..000000000 --- a/py4DSTEM/process/probe/kernel.py +++ /dev/null @@ -1,279 +0,0 @@ -# Functions for preparing the probe for cross-correlative template matching. - - -import numpy as np - -from py4DSTEM.process.utils import get_shifted_ar -from py4DSTEM.process.calibration import get_probe_size - - -def get_kernel( - probe, - mode = 'flat', - **kwargs - ): - """ - Creates a kernel from the probe for cross-correlative template matching. - - Precise behavior and valid keyword arguments depend on the `mode` - selected. In each case, the center of the probe is shifted to the - origin and the kernel normalized such that it sums to 1. In 'flat' - mode, this is the only processing performed. In the remaining modes, - some additional processing is performed which adds a ring of - negative intensity around the central probe, which results in - edge-filetering-like behavior during cross correlation. Valid modes - are: - - - 'flat': creates a flat probe kernel. For bullseye or other - structured probes, this mode is recommended. - - 'gaussian': subtracts a gaussian with a width of standard - deviation 'sigma' - - 'sigmoid': subtracts an annulus with inner and outer radii - of (ri,ro) and a sine-squared sigmoid radial profile from - the probe template. - - 'sigmoid_log': subtracts an annulus with inner and outer radii - of (ri,ro) and a logistic sigmoid radial profile from - the probe template. - - Each mode accepts 'center' (2-tuple) as a kwarg to manually specify - the center of the probe, which is otherwise autodetected. Modes which - accept additional kwargs and those arguments are: - - - 'gaussian': - sigma (number) - - 'sigmoid': - radii (2-tuple) - - 'sigmoid_log': - radii (2-tuple) - - Accepts: - probe (2D array): - mode (str): must be in 'flat','gaussian','sigmoid','sigmoid_log' - **kwargs: depend on `mode`, see above - - Returns: - (2D array) - """ - - modes = [ - 'flat', - 'gaussian', - 'sigmoid', - 'sigmoid_log' - ] - - # parse args - assert mode in modes, f"mode must be in {modes}. Received {mode}" - - # get function - fn_dict = _make_function_dict() - fn = fn_dict[mode] - - # compute and return - kernel = fn(probe, **kwargs) - return kernel - - - -def _make_function_dict(): - d = { - 'flat' : get_probe_kernel, - 'gaussian' : get_probe_kernel_edge_gaussian, - 'sigmoid' : _get_probe_kernel_edge_sigmoid_sine_squared, - 'sigmoid_log' : _get_probe_kernel_edge_sigmoid_sine_squared - } - return d - - - - - -def get_probe_kernel( - probe, - origin=None, - bilinear=True - ): - """ - Creates a convolution kernel from an average probe, by normalizing, then - shifting the center of the probe to the corners of the array. - - Args: - probe (ndarray): the diffraction pattern corresponding to the probe over - vacuum - origin (2-tuple or None): if None (default), finds the origin using - get_probe_radius. Otherwise, should be a 2-tuple (x0,y0) specifying - the origin position - bilinear (bool): By default probe is shifted via a Fourier transform. - Setting this to True overrides it and uses bilinear shifting. - Not recommended! 
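Although this module is deleted in this PR, the 'flat' mode it documents reduces to normalize-then-recenter. A minimal NumPy sketch, with an integer `np.roll` standing in for the library's subpixel `get_shifted_ar` shift (the helper name is hypothetical):

```python
import numpy as np

def flat_probe_kernel(probe, x0, y0):
    probe = probe / probe.sum()   # normalize to unit total intensity
    # move the probe center (x0, y0) to the array corner (0, 0)
    return np.roll(probe, (-int(round(x0)), -int(round(y0))), axis=(0, 1))
```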
- - Returns: - (ndarray): the convolution kernel corresponding to the probe, in real - space - """ - Q_Nx, Q_Ny = probe.shape - - # Get CoM - if origin is None: - _,xCoM,yCoM = get_probe_size(probe) - else: - xCoM,yCoM = origin - - # Normalize - probe = probe/np.sum(probe) - - # Shift center to corners of array - probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear) - - return probe_kernel - - -def get_probe_kernel_edge_gaussian( - probe, - sigma, - origin=None, - bilinear=True, - ): - """ - Creates a convolution kernel from an average probe, subtracting a gaussian - from the normalized probe such that the kernel integrates to zero, then - shifting the center of the probe to the array corners. - - Args: - probe (ndarray): the diffraction pattern corresponding to the probe - over vacuum - sigma (float): the width of the gaussian to subtract, relative to - the standard deviation of the probe - origin (2-tuple or None): if None (default), finds the origin using - get_probe_radius. Otherwise, should be a 2-tuple (x0,y0) specifying - the origin position - bilinear (bool): By default probe is shifted via a Fourier transform. - Setting this to True overrides it and uses bilinear shifting. - Not recommended! - - Returns: - (ndarray) the convolution kernel corresponding to the probe - """ - Q_Nx, Q_Ny = probe.shape - - # Get CoM - if origin is None: - _,xCoM,yCoM = get_probe_size(probe) - else: - xCoM,yCoM = origin - - # Shift probe to origin - probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear) - - # Generate normalization kernel - # Coordinates - qy,qx = np.meshgrid( - np.mod(np.arange(Q_Ny) + Q_Ny//2, Q_Ny) - Q_Ny//2, - np.mod(np.arange(Q_Nx) + Q_Nx//2, Q_Nx) - Q_Nx//2) - qr2 = (qx**2 + qy**2) - # Calculate Gaussian normalization kernel - qstd2 = np.sum(qr2*probe_kernel) / np.sum(probe_kernel) - kernel_norm = np.exp(-qr2 / (2*qstd2*sigma**2)) - - # Output normalized kernel - probe_kernel = probe_kernel/np.sum(probe_kernel) - kernel_norm/np.sum(kernel_norm) - - return probe_kernel - - -def get_probe_kernel_edge_sigmoid( - probe, - radii, - origin=None, - type='sine_squared', - bilinear=True, - ): - """ - Creates a convolution kernel from an average probe, subtracting an annular - trench about the probe such that the kernel integrates to zero, then - shifting the center of the probe to the array corners. - - Args: - probe (ndarray): the diffraction pattern corresponding to the probe over - vacuum - radii (2-tuple): the sigmoid inner and outer radii, from the probe center - origin (2-tuple or None): if None (default), finds the origin using - get_probe_radius. Otherwise, should be a 2-tuple (x0,y0) specifying - the origin position - type (string): must be 'logistic' or 'sine_squared' - bilinear (bool): By default probe is shifted via a Fourier transform. - Setting this to True overrides it and uses bilinear shifting. - Not recommended! 
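Both edge-subtraction kernels in this (removed) module build corner-centered coordinate grids with the same modular-arithmetic trick; isolated, it reads:

```python
import numpy as np

Q_Nx = Q_Ny = 8
qy, qx = np.meshgrid(
    np.mod(np.arange(Q_Ny) + Q_Ny // 2, Q_Ny) - Q_Ny // 2,
    np.mod(np.arange(Q_Nx) + Q_Nx // 2, Q_Nx) - Q_Nx // 2,
)
qr = np.hypot(qx, qy)
print(qr[0, 0], qr[Q_Nx // 2, Q_Ny // 2])   # 0.0 at the corner, largest at the center
```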
- - Returns: - (ndarray): the convolution kernel corresponding to the probe - """ - valid_types = ('logistic','sine_squared') - assert(type in valid_types), "type must be in {}".format(valid_types) - Q_Nx, Q_Ny = probe.shape - ri,ro = radii - - # Get CoM - if origin is None: - _,xCoM,yCoM = get_probe_size(probe) - else: - xCoM,yCoM = origin - - # Shift probe to origin - probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear) - - # Generate normalization kernel - # Coordinates - qy,qx = np.meshgrid( - np.mod(np.arange(Q_Ny) + Q_Ny//2, Q_Ny) - Q_Ny//2, - np.mod(np.arange(Q_Nx) + Q_Nx//2, Q_Nx) - Q_Nx//2) - qr = np.sqrt(qx**2 + qy**2) - # Calculate sigmoid - if type == 'logistic': - r0 = 0.5*(ro+ri) - sigma = 0.25*(ro-ri) - sigmoid = 1/(1+np.exp((qr-r0)/sigma)) - elif type == 'sine_squared': - sigmoid = (qr - ri) / (ro - ri) - sigmoid = np.minimum(np.maximum(sigmoid, 0.0), 1.0) - sigmoid = np.cos((np.pi/2)*sigmoid)**2 - else: - raise Exception("type must be in {}".format(valid_types)) - - # Output normalized kernel - probe_kernel = probe_kernel/np.sum(probe_kernel) - sigmoid/np.sum(sigmoid) - - return probe_kernel - - - -def _get_probe_kernel_edge_sigmoid_sine_squared( - probe, - radii, - origin=None, - **kwargs, - ): - return get_probe_kernel_edge_sigmoid( - probe, - radii, - origin = origin, - type='sine_squared', - **kwargs, - ) - -def _get_probe_kernel_edge_sigmoid_logistic( - probe, - radii, - origin=None, - **kwargs, - ): - return get_probe_kernel_edge_sigmoid( - probe, - radii, - origin = origin, - type='logistic', - **kwargs - ) - - - diff --git a/py4DSTEM/process/probe/probe.py b/py4DSTEM/process/probe/probe.py deleted file mode 100644 index c9bb98634..000000000 --- a/py4DSTEM/process/probe/probe.py +++ /dev/null @@ -1,392 +0,0 @@ -# Functions for getting images of the vacuum probe - -import numpy as np -from scipy.ndimage import ( - binary_opening, binary_dilation, distance_transform_edt) - -from emdfile import tqdmnd -from py4DSTEM.classes import DataCube, Probe -from py4DSTEM.process.utils import get_shifted_ar, get_shift - - - - - -def get_vacuum_probe( - data, - **kwargs - ): - """ - Takes some data and computes a vacuum probe, using a method - selected based on the type and shape of `data`, and on other - arguments passed. In each case, points outside the center - disk are set to zero. - - Args: - data (variable): behavior and additional arguments depend on - the type of `data`. If `data` is a - - DataCube: computes a probe using all or some subset of - the diffraction patterns in the datacube, aligning and - averaging those patterns. The whole datacube is used - if the `ROI` argument is not passed. If `ROI` is - passed, uses a subset of diffraction patterns based on - the ROI argument's value, which may be a either an - R-space shaped boolean mask, or a 4-tuple representing - (Rxmin,Rxmax,Rymin,Rymax) of a rectangular region. - - 3D array: averages the stack with no alignment - - 2D array: uses this array as the probe. - - None: makes a synthetic probe. Additional required - arguments are - - radius (number): the probe radius - - width (number): the width of the region where - the probe intensity drops off from its maximum - to 0. - - Qshape (2 tuple): the shape of diffraction space - - Returns: - (Probe) a Probe instance - """ - - # select mode of operation - - if isinstance(data,DataCube): - mode = '4D' - elif isinstance(data,np.ndarray): - mode = str(data.ndim)+'D' - elif data is None: - mode = 'synth' - else: - er = f"invalid type {type(data)} for `data`." 
- er += f"must be in (DataCube,np.ndarray,None)" - raise Exception(er) - - if mode == '4D': - if 'ROI' in kwargs.keys(): - roi = kwargs['ROI'] - if isinstance(roi, np.ndarray): - mode = '4D_roi_mask' - else: - mode = '4D_roi_lims' - else: - mode = '4D_full' - - - - # choose and run a function - functions = { - '4D_full' : get_probe_from_vacuum_4Dscan, - '4D_roi_mask' : get_probe_from_4Dscan_ROI_mask, - '4D_roi_lims' : get_probe_from_4Dscan_ROI_lims, - '3D' : get_probe_from_vacuum_3Dstack, - '2D' : get_probe_from_vacuum_2Dimage, - 'synth' : get_probe_synthetic - } - fn = functions[mode] - if mode == 'synth': - probe = fn(**kwargs) - else: - probe = fn(data,**kwargs) - return probe - - - - - -def get_probe_from_vacuum_4Dscan( - datacube, - mask_threshold=0.2, - mask_expansion=12, - mask_opening=3, - verbose=False, - align=True - ): - """ - Averages all diffraction patterns in a datacube, assumed to be taken - over vacuum, to create and average vacuum probe. Optionally (default) - aligns the patterns. - - Values outisde the average probe are zeroed, using a binary mask determined - by the optional parameters mask_threshold, mask_expansion, and mask_opening. - An initial binary mask is created using a threshold of less than - mask_threshold times the maximal probe value. A morphological opening of - mask_opening pixels is performed to eliminate stray pixels (e.g. from - x-rays), followed by a dilation of mask_expansion pixels to ensure the - entire probe is captured. - - Args: - datacube (DataCube): a vacuum scan - mask_threshold (float): threshold determining mask which zeros values - outside of probe - mask_expansion (int): number of pixels by which the zeroing mask is - expanded to capture the full probe - mask_opening (int): size of binary opening used to eliminate stray - bright pixels - verbose (bool): if True, prints progress updates - align (bool): if True, aligns the probes before averaging - - Returns: - (ndarray of shape (datacube.Q_Nx,datacube.Q_Ny)): the average probe - """ - - probe = datacube.data[0,0,:,:] - for n in tqdmnd(range(1,datacube.R_N)): - Rx,Ry = np.unravel_index(n,datacube.data.shape[:2]) - curr_DP = datacube.data[Rx,Ry,:,:] - if verbose: - print(f"Shifting and averaging diffraction pattern {n} of {datacube.R_N}.") - if align: - xshift,yshift = get_shift(probe, curr_DP) - curr_DP = get_shifted_ar(curr_DP, xshift, yshift) - probe = probe*(n-1)/n + curr_DP/n - - mask = probe > np.max(probe)*mask_threshold - mask = binary_opening(mask, iterations=mask_opening) - mask = binary_dilation(mask, iterations=1) - mask = np.cos((np.pi/2)*np.minimum(distance_transform_edt(np.logical_not(mask)) / mask_expansion, 1))**2 - - return probe*mask - - - - -def get_probe_from_4Dscan_ROI_lims( - datacube, - ROI, - mask_threshold=0.2, - mask_expansion=12, - mask_opening=3, - verbose=False, - align=True - ): - """ - Averages all diffraction patterns within a specified ROI of a datacube to - create an average vacuum probe. Optionally (default) aligns the patterns. - - See documentation for get_average_probe_from_vacuum_scan for more detailed - discussion of the algorithm. - - Args: - datacube (DataCube): a vacuum scan - ROI (len 4 list or tuple): the limits (rx_min, rx_max, ry_min, ry_max) - of the selected region. 
- mask_threshold (float): threshold determining mask which zeros values - outside of probe - mask_expansion (int): number of pixels by which the zeroing mask is - expanded to capture the full probe - mask_opening (int): size of binary opening used to eliminate stray - bright pixels - verbose (bool): if True, prints progress updates - align (bool): if True, aligns the probes before averaging - DP_mask (array): array of same shape as diffraction pattern to mask - probes - - Returns: - (ndarray of shape (datacube.Q_Nx,datacube.Q_Ny)): the average probe - """ - assert len(ROI) == 4 - - datacube = DataCube( - data = datacube.data[ROI[0]:ROI[1],ROI[2]:ROI[3]] - ) - return get_probe_from_vacuum_4Dscan( - datacube, - mask_threshold = mask_threshold, - mask_expansion = mask_expansion, - mask_opening = mask_opening, - verbose = verbose, - align = align) - - - - -def get_probe_from_4Dscan_ROI_mask( - datacube, - ROI, - mask_threshold=0.2, - mask_expansion=12, - mask_opening=3, - verbose=False, - align=True, - DP_mask=1 - ): - """ - Averages all diffraction patterns within a specified ROI of a datacube to - create an average vacuum probe. Optionally (default) aligns the patterns. - - See documentation for get_average_probe_from_vacuum_scan for more detailed - discussion of the algorithm. - - Args: - datacube (DataCube): a vacuum scan - ROI (ndarray of dtype=bool and shape (datacube.R_Nx,datacube.R_Ny)): An - array of boolean variables shaped like the real space scan. Only scan - positions where ROI==True are used to create the average probe. - mask_threshold (float): threshold determining mask which zeros values - outside of probe - mask_expansion (int): number of pixels by which the zeroing mask is - expanded to capture the full probe - mask_opening (int): size of binary opening used to eliminate stray - bright pixels - verbose (bool): if True, prints progress updates - align (bool): if True, aligns the probes before averaging - DP_mask (array): array of same shape as diffraction pattern to mask - probes - - Returns: - (ndarray of shape (datacube.Q_Nx,datacube.Q_Ny)): the average probe - """ - assert ROI.shape==(datacube.R_Nx,datacube.R_Ny) - length = ROI.sum() - xy = np.vstack(np.nonzero(ROI)) - probe = datacube.data[xy[0,0],xy[1,0],:,:] - for n in tqdmnd(range(1,length)): - curr_DP = datacube.data[xy[0,n],xy[1,n],:,:] * DP_mask - if align: - xshift,yshift = get_shift(probe, curr_DP) - curr_DP = get_shifted_ar(curr_DP, xshift, yshift) - probe = probe*(n-1)/n + curr_DP/n - - mask = probe > np.max(probe)*mask_threshold - mask = binary_opening(mask, iterations=mask_opening) - mask = binary_dilation(mask, iterations=1) - mask = np.cos((np.pi/2)*np.minimum(distance_transform_edt(np.logical_not(mask)) / mask_expansion, 1))**2 - - return probe*mask - - - - - -def get_probe_from_vacuum_3Dstack( - data, - mask_threshold=0.2, - mask_expansion=12, - mask_opening=3 - ): - """ - Averages all diffraction patterns in a 3D stack of diffraction patterns, - assumed to be taken over vacuum, to create and average vacuum probe. No - alignment is performed - i.e. it is assumed that the beam was stationary - during acquisition of the stack. - - Values outisde the average probe are zeroed, using a binary mask determined - by the optional parameters mask_threshold, mask_expansion, and mask_opening. - An initial binary mask is created using a threshold of less than - mask_threshold times the maximal probe value. A morphological opening of - mask_opening pixels is performed to eliminate stray pixels (e.g. 
from - x-rays), followed by a dilation of mask_expansion pixels to ensure the - entire probe is captured. - - Args: - data (array): a 3D stack of vacuum diffraction patterns, shape - (Q_Nx,Q_Ny,N) - mask_threshold (float): threshold determining mask which zeros values - outside of probe - mask_expansion (int): number of pixels by which the zeroing mask is - expanded to capture the full probe - mask_opening (int): size of binary opening used to eliminate stray - bright pixels - - Returns: - (array of shape (Q_Nx,Q_Ny)): the average probe - """ - probe = np.average(data,axis=2) - - mask = probe > np.max(probe)*mask_threshold - mask = binary_opening(mask, iterations=mask_opening) - mask = binary_dilation(mask, iterations=1) - mask = np.cos((np.pi/2)*np.minimum(distance_transform_edt(np.logical_not(mask)) / mask_expansion, 1))**2 - - return probe*mask - - - - -def get_probe_from_vacuum_2Dimage( - data, - mask_threshold=0.2, - mask_expansion=12, - mask_opening=3 - ): - """ - A single image of the probe over vacuum is processed by zeroing values - outside the central disk, using a binary mask determined by the optional - parameters mask_threshold, mask_expansion, and mask_opening. An initial - binary mask is created using a threshold of less than mask_threshold time - the maximal probe value. A morphological opening of mask_opening pixels is - performed to eliminate stray pixels (e.g. from x-rays), followed by a - dilation of mask_expansion pixels to ensure the entire probe is captured. - - Args: - data (array): a 2D array of the vacuum diffraction pattern, shape - (Q_Nx,Q_Ny) - mask_threshold (float): threshold determining mask which zeros values - outside of probe - mask_expansion (int): number of pixels by which the zeroing mask is - expanded to capture the full probe - mask_opening (int): size of binary opening used to eliminate stray - bright pixels - - Returns: - (array of shape (Q_Nx,Q_Ny)) the average probe - """ - mask = data > np.max(data)*mask_threshold - mask = binary_opening(mask, iterations=mask_opening) - mask = binary_dilation(mask, iterations=1) - mask = np.cos((np.pi/2)*np.minimum(distance_transform_edt(np.logical_not(mask)) / mask_expansion, 1))**2 - - return data*mask - - - - -def get_probe_synthetic( - radius, - width, - Qshape - ): - """ - Makes a synthetic probe, with the functional form of a disk blurred by a - sigmoid (a logistic function). - - Args: - radius (float): the probe radius - width (float): the blurring of the probe edge. width represents the - full width of the blur, with x=-w/2 to x=+w/2 about the edge - spanning values of ~0.12 to 0.88 - Qshape (2 tuple): the diffraction plane dimensions - - Returns: - (ndarray of shape (Q_Nx,Q_Ny)): the probe - """ - # Make coords - Q_Nx,Q_Ny = Qshape - qy,qx = np.meshgrid(np.arange(Q_Ny),np.arange(Q_Nx)) - qy,qx = qy - Q_Ny/2., qx-Q_Nx/2. - qr = np.sqrt(qx**2+qy**2) - - # Shift zero to disk edge - qr = qr - radius - - # Calculate logistic function - probe = 1/(1+np.exp(4*qr/width)) - - return probe - - - - - - - - -# Probe templates can be generated from vacuum scans, from a selected ROI of a -# vacuum region of a scan, or synthetic probes. Ultimately the purpose is to -# generate a kernel for convolution with individual diffraction patterns to -# identify Bragg disks. 
Kernel generation will generally proceed in two steps, -# which will each correspond to a function call: first, obtaining or creating -# the diffraction pattern of a probe over vacuum, and second, turning the probe -# DP into a convolution kernel by shifting and normalizing. - - - diff --git a/py4DSTEM/process/rdf/amorph.py b/py4DSTEM/process/rdf/amorph.py index 9c80a2807..a537896b9 100644 --- a/py4DSTEM/process/rdf/amorph.py +++ b/py4DSTEM/process/rdf/amorph.py @@ -111,7 +111,7 @@ def plot_strains(strains, cmap="RdBu_r", vmin=None, vmax=None, mask=None): cmap, vmin, vmax: imshow parameters mask: real space mask of values not to show (black) """ - cmap = matplotlib.cm.get_cmap(cmap) + cmap = plt.get_cmap(cmap) if vmin is None: vmin = np.min(strains) if vmax is None: diff --git a/py4DSTEM/process/strain.py b/py4DSTEM/process/strain.py new file mode 100644 index 000000000..db252f75b --- /dev/null +++ b/py4DSTEM/process/strain.py @@ -0,0 +1,601 @@ +# Defines the Strain class + +from typing import Optional + +import matplotlib.pyplot as plt +import numpy as np +from py4DSTEM import PointList +from py4DSTEM.braggvectors import BraggVectors +from py4DSTEM.data import Data, RealSlice +from py4DSTEM.preprocess.utils import get_maxima_2D +from py4DSTEM.visualize import add_bragg_index_labels, add_pointlabels, add_vector, show + + +class StrainMap(RealSlice, Data): + """ + Stores strain map. + + TODO add docs + + """ + + def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"): + """ + TODO + """ + assert isinstance( + braggvectors, BraggVectors + ), f"braggvectors must be BraggVectors, not type {type(braggvectors)}" + + # initialize as a RealSlice + RealSlice.__init__( + self, + name=name, + data=np.empty( + ( + 6, + braggvectors.Rshape[0], + braggvectors.Rshape[1], + ) + ), + slicelabels=["exx", "eyy", "exy", "theta", "mask", "error"], + ) + + # set up braggvectors + # this assigns the bvs, ensures the origin is calibrated, + # and adds the strainmap to the bvs' tree + self.braggvectors = braggvectors + + # initialize as Data + Data.__init__(self) + + # set calstate + # this property is used only to check to make sure that + # the braggvectors being used throughout a workflow are + # the same. The state of calibration of the vectors is noted + # here, and then checked each time the vectors are used - + # if they differ, an error message and instructions for + # re-calibration are issued + self.calstate = self.braggvectors.calstate + assert self.calstate["center"], "braggvectors must be centered" + # get the BVM + # a new BVM using the current calstate is computed + self.bvm = self.braggvectors.histogram(mode="cal") + + # braggvector properties + + @property + def braggvectors(self): + return self._braggvectors + + @braggvectors.setter + def braggvectors(self, x): + assert isinstance( + x, BraggVectors + ), f".braggvectors must be BraggVectors, not type {type(x)}" + assert ( + x.calibration.origin is not None + ), f"braggvectors must have a calibrated origin" + self._braggvectors = x + self._braggvectors.tree(self, force=True) + + def reset_calstate(self): + """ + Resets the calibration state. This recomputes the BVM, and removes any computations + this StrainMap instance has stored, which will need to be recomputed. 
+ """ + for attr in ( + "g0", + "g1", + "g2", + ): + if hasattr(self, attr): + delattr(self, attr) + self.calstate = self.braggvectors.calstate + pass + + # Class methods + + def choose_lattice_vectors( + self, + index_g0, + index_g1, + index_g2, + subpixel="multicorr", + upsample_factor=16, + sigma=0, + minAbsoluteIntensity=0, + minRelativeIntensity=0, + relativeToPeak=0, + minSpacing=0, + edgeBoundary=1, + maxNumPeaks=10, + figsize=(12, 6), + c_indices="lightblue", + c0="g", + c1="r", + c2="r", + c_vectors="r", + c_vectorlabels="w", + size_indices=20, + width_vectors=1, + size_vectorlabels=20, + vis_params={}, + returncalc=False, + returnfig=False, + ): + """ + Choose which lattice vectors to use for strain mapping. + + Overlays the bvm with the points detected via local 2D + maxima detection, plus an index for each point. User selects + 3 points using the overlaid indices, which are identified as + the origin and the termini of the lattice vectors g1 and g2. + + Parameters + ---------- + index_g0 : int + selected index for the origin + index_g1 : int + selected index for g1 + index_g2 :int + selected index for g2 + subpixel : str in ('pixel','poly','multicorr') + See the docstring for py4DSTEM.preprocess.get_maxima_2D + upsample_factor : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + sigma : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + minAbsoluteIntensity : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + minRelativeIntensity : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + relativeToPeak : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + minSpacing : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + edgeBoundary : number + See the py4DSTEM.preprocess.get_maxima_2D docstring + maxNumPeaks : int + See the py4DSTEM.preprocess.get_maxima_2D docstring + figsize : 2-tuple + the size of the figure + c_indices : color + color of the maxima + c0 : color + color of the origin + c1 : color + color of g1 point + c2 : color + color of g2 point + c_vectors : color + color of the g1/g2 vectors + c_vectorlabels : color + color of the vector labels + size_indices : number + size of the indices + width_vectors : number + width of the vectors + size_vectorlabels : number + size of the vector labels + vis_params : dict + additional visualization parameters passed to `show` + returncalc : bool + toggles returning the answer + returnfig : bool + toggles returning the figure + + Returns + ------- + (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or both of the latter + """ + # validate inputs + for i in (index_g0, index_g1, index_g2): + assert isinstance(i, (int, np.integer)), "indices must be integers!" + # check the calstate + assert ( + self.calstate == self.braggvectors.calstate + ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`." 
+
+        # find the maxima
+        g = get_maxima_2D(
+            self.bvm.data,
+            subpixel=subpixel,
+            upsample_factor=upsample_factor,
+            sigma=sigma,
+            minAbsoluteIntensity=minAbsoluteIntensity,
+            minRelativeIntensity=minRelativeIntensity,
+            relativeToPeak=relativeToPeak,
+            minSpacing=minSpacing,
+            edgeBoundary=edgeBoundary,
+            maxNumPeaks=maxNumPeaks,
+        )
+
+        # get the lattice vectors
+        gx, gy = g["x"], g["y"]
+        g0 = gx[index_g0], gy[index_g0]
+        g1x = gx[index_g1] - g0[0]
+        g1y = gy[index_g1] - g0[1]
+        g2x = gx[index_g2] - g0[0]
+        g2y = gy[index_g2] - g0[1]
+        g1, g2 = (g1x, g1y), (g2x, g2y)
+
+        # make the figure
+        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+        show(self.bvm.data, figax=(fig, ax1), **vis_params)
+        show(self.bvm.data, figax=(fig, ax2), **vis_params)
+
+        # Add indices to left panel
+        d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices}
+        d0 = {
+            "x": gx[index_g0],
+            "y": gy[index_g0],
+            "size": size_indices,
+            "color": c0,
+            "fontweight": "bold",
+            "labels": [str(index_g0)],
+        }
+        d1 = {
+            "x": gx[index_g1],
+            "y": gy[index_g1],
+            "size": size_indices,
+            "color": c1,
+            "fontweight": "bold",
+            "labels": [str(index_g1)],
+        }
+        d2 = {
+            "x": gx[index_g2],
+            "y": gy[index_g2],
+            "size": size_indices,
+            "color": c2,
+            "fontweight": "bold",
+            "labels": [str(index_g2)],
+        }
+        add_pointlabels(ax1, d)
+        add_pointlabels(ax1, d0)
+        add_pointlabels(ax1, d1)
+        add_pointlabels(ax1, d2)
+
+        # Add vectors to right panel
+        dg1 = {
+            "x0": gx[index_g0],
+            "y0": gy[index_g0],
+            "vx": g1[0],
+            "vy": g1[1],
+            "width": width_vectors,
+            "color": c_vectors,
+            "label": r"$g_1$",
+            "labelsize": size_vectorlabels,
+            "labelcolor": c_vectorlabels,
+        }
+        dg2 = {
+            "x0": gx[index_g0],
+            "y0": gy[index_g0],
+            "vx": g2[0],
+            "vy": g2[1],
+            "width": width_vectors,
+            "color": c_vectors,
+            "label": r"$g_2$",
+            "labelsize": size_vectorlabels,
+            "labelcolor": c_vectorlabels,
+        }
+        add_vector(ax2, dg1)
+        add_vector(ax2, dg2)
+
+        # store vectors
+        self.g = g
+        self.g0 = g0
+        self.g1 = g1
+        self.g2 = g2
+
+        # return
+        if returncalc and returnfig:
+            return (g0, g1, g2), (fig, (ax1, ax2))
+        elif returncalc:
+            return (g0, g1, g2)
+        elif returnfig:
+            return (fig, (ax1, ax2))
+        else:
+            return
+
+    def fit_lattice_vectors(
+        self,
+        x0=None,
+        y0=None,
+        max_peak_spacing=2,
+        mask=None,
+        plot=True,
+        vis_params={},
+        returncalc=False,
+    ):
+        """
+        From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and a pair of
+        lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the
+        reciprocal lattice directions.
+
+        Args:
+            x0 : float
+                x-coord of origin
+            y0 : float
+                y-coord of origin
+            max_peak_spacing : float
+                Maximum distance from the ideal lattice points
+                to include a peak for indexing
+            mask : bool
+                Boolean mask, same shape as the pointlistarray, indicating which
+                locations should be indexed. This can be used to index different regions of
+                the scan with different lattices
+            plot : bool
+                plot results if True
+            vis_params : dict
+                additional visualization parameters passed to `show`
+            returncalc : bool
+                if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map
+        """
+        # check the calstate
+        assert (
+            self.calstate == self.braggvectors.calstate
+        ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`."
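# For orientation: indexing a measured peak q against lattice vectors g1, g2
# means solving a 2x2 linear system for (h, k). The diff delegates this to
# py4DSTEM.process.latticevectors.index_bragg_directions; the sketch below
# (not part of this diff) only illustrates the underlying linear algebra.
import numpy as np

def index_peak(qx, qy, origin, g1, g2):
    # q - origin = h*g1 + k*g2  =>  solve A @ [h, k] = q - origin,
    # where the columns of A are g1 and g2
    A = np.array([[g1[0], g2[0]],
                  [g1[1], g2[1]]])
    hk = np.linalg.solve(A, np.array([qx - origin[0], qy - origin[1]]))
    return np.round(hk).astype(int)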
+
+        if x0 is None:
+            x0 = self.braggvectors.Qshape[0] / 2
+        if y0 is None:
+            y0 = self.braggvectors.Qshape[1] / 2
+
+        # index braggvectors
+        from py4DSTEM.process.latticevectors import index_bragg_directions
+
+        _, _, braggdirections = index_bragg_directions(
+            x0, y0, self.g["x"], self.g["y"], self.g1, self.g2
+        )
+
+        self.braggdirections = braggdirections
+
+        if plot:
+            self.show_bragg_indexing(
+                self.bvm,
+                bragg_directions=braggdirections,
+                points=True,
+                **vis_params,
+            )
+
+        # add indices to braggvectors
+        from py4DSTEM.process.latticevectors import add_indices_to_braggvectors
+
+        bragg_vectors_indexed = add_indices_to_braggvectors(
+            self.braggvectors,
+            self.braggdirections,
+            maxPeakSpacing=max_peak_spacing,
+            qx_shift=self.braggvectors.Qshape[0] / 2,
+            qy_shift=self.braggvectors.Qshape[1] / 2,
+            mask=mask,
+        )
+
+        self.bragg_vectors_indexed = bragg_vectors_indexed
+
+        # fit bragg vectors
+        from py4DSTEM.process.latticevectors import fit_lattice_vectors_all_DPs
+
+        g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed)
+        self.g1g2_map = g1g2_map
+
+        if returncalc:
+            return braggdirections, bragg_vectors_indexed, g1g2_map
+
+    def get_strain(
+        self, mask=None, g_reference=None, flip_theta=False, returncalc=False, **kwargs
+    ):
+        """
+        mask: nd.array (bool)
+            Use lattice vectors from g1g2_map scan positions
+            wherever mask==True. If mask is None, gets the median strain
+            map from the entire field of view. If mask is not None, gets
+            reference g1 and g2 from the masked region and then calculates strain.
+        g_reference: nd.array of form [x,y]
+            reference coordinate system for xaxis_x and xaxis_y
+        flip_theta: bool
+            If True, flips rotation coordinate system
+        returncalc: bool
+            If True, returns rotated map
+        """
+        # check the calstate
+        assert (
+            self.calstate == self.braggvectors.calstate
+        ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`."
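# For orientation: the strain computation that follows is delegated to
# py4DSTEM.process.latticevectors. Schematically it rests on the 2x2
# transformation relating measured to reference lattice vectors; the sketch
# below (not part of this diff) shows that decomposition, glossing over the
# real/reciprocal-space inversion and sign conventions handled by the library.
import numpy as np

def strain_from_g1g2(g1, g2, g1_ref, g2_ref):
    G_ref = np.array([g1_ref, g2_ref]).T   # reference vectors as columns
    G = np.array([g1, g2]).T               # measured vectors as columns
    T = G @ np.linalg.inv(G_ref)           # G = T @ G_ref
    exx = T[0, 0] - 1
    eyy = T[1, 1] - 1
    exy = 0.5 * (T[0, 1] + T[1, 0])        # symmetric shear part
    theta = 0.5 * (T[0, 1] - T[1, 0])      # small-angle rotation part
    return exx, eyy, exy, theta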
+
+        if mask is None:
+            mask = np.ones(self.g1g2_map.shape, dtype="bool")
+
+            from py4DSTEM.process.latticevectors import get_strain_from_reference_region
+
+            strainmap_g1g2 = get_strain_from_reference_region(
+                self.g1g2_map,
+                mask=mask,
+            )
+        else:
+            from py4DSTEM.process.latticevectors import get_reference_g1g2
+
+            g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, mask)
+
+            from py4DSTEM.process.latticevectors import get_strain_from_reference_g1g2
+
+            strainmap_g1g2 = get_strain_from_reference_g1g2(
+                self.g1g2_map, g1_ref, g2_ref
+            )
+
+        self.strainmap_g1g2 = strainmap_g1g2
+
+        if g_reference is None:
+            g_reference = np.subtract(self.g1, self.g2)
+
+        from py4DSTEM.process.latticevectors import get_rotated_strain_map
+
+        strainmap_rotated = get_rotated_strain_map(
+            self.strainmap_g1g2,
+            xaxis_x=g_reference[0],
+            xaxis_y=g_reference[1],
+            flip_theta=flip_theta,
+        )
+
+        self.strainmap_rotated = strainmap_rotated
+
+        from py4DSTEM.visualize import show_strain
+
+        figsize = kwargs.pop("figsize", (14, 4))
+        vrange_exx = kwargs.pop("vrange_exx", [-2.0, 2.0])
+        vrange_theta = kwargs.pop("vrange_theta", [-2.0, 2.0])
+        ticknumber = kwargs.pop("ticknumber", 3)
+        bkgrd = kwargs.pop("bkgrd", False)
+        axes_plots = kwargs.pop("axes_plots", ())
+
+        fig, ax = show_strain(
+            self.strainmap_rotated,
+            vrange_exx=vrange_exx,
+            vrange_theta=vrange_theta,
+            ticknumber=ticknumber,
+            axes_plots=axes_plots,
+            bkgrd=bkgrd,
+            figsize=figsize,
+            **kwargs,
+            returnfig=True,
+        )
+
+        if not np.all(mask):
+            ax[0][0].imshow(mask, alpha=0.2, cmap="binary")
+            ax[0][1].imshow(mask, alpha=0.2, cmap="binary")
+            ax[1][0].imshow(mask, alpha=0.2, cmap="binary")
+            ax[1][1].imshow(mask, alpha=0.2, cmap="binary")
+
+        if returncalc:
+            return self.strainmap_rotated
+
+    def show_lattice_vectors(
+        self,
+        ar,
+        x0,
+        y0,
+        g1,
+        g2,
+        color="r",
+        width=1,
+        labelsize=20,
+        labelcolor="w",
+        returnfig=False,
+        **kwargs,
+    ):
+        """Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy)."""
+        fig, ax = show(ar, returnfig=True, **kwargs)
+
+        # Add vectors
+        dg1 = {
+            "x0": x0,
+            "y0": y0,
+            "vx": g1[0],
+            "vy": g1[1],
+            "width": width,
+            "color": color,
+            "label": r"$g_1$",
+            "labelsize": labelsize,
+            "labelcolor": labelcolor,
+        }
+        dg2 = {
+            "x0": x0,
+            "y0": y0,
+            "vx": g2[0],
+            "vy": g2[1],
+            "width": width,
+            "color": color,
+            "label": r"$g_2$",
+            "labelsize": labelsize,
+            "labelcolor": labelcolor,
+        }
+        add_vector(ax, dg1)
+        add_vector(ax, dg2)
+
+        if returnfig:
+            return fig, ax
+        else:
+            plt.show()
+            return
+
+    def show_bragg_indexing(
+        self,
+        ar,
+        bragg_directions,
+        voffset=5,
+        hoffset=0,
+        color="w",
+        size=20,
+        points=True,
+        pointcolor="r",
+        pointsize=50,
+        returnfig=False,
+        **kwargs,
+    ):
+        """
+        Shows an array with an overlay describing the Bragg directions
+
+        Accepts:
+            ar                (array) the image
+            bragg_directions  (PointList) the bragg scattering directions; must have coordinates
+                              'qx','qy','h', and 'k'. Optionally may also have 'l'.
+ """ + assert isinstance(bragg_directions, PointList) + for k in ("qx", "qy", "h", "k"): + assert k in bragg_directions.data.dtype.fields + + fig, ax = show(ar, returnfig=True, **kwargs) + d = { + "bragg_directions": bragg_directions, + "voffset": voffset, + "hoffset": hoffset, + "color": color, + "size": size, + "points": points, + "pointsize": pointsize, + "pointcolor": pointcolor, + } + add_bragg_index_labels(ax, d) + + if returnfig: + return fig, ax + else: + plt.show() + return + + def copy(self, name=None): + name = name if name is not None else self.name + "_copy" + strainmap_copy = StrainMap(self.braggvectors) + for attr in ( + "g", + "g0", + "g1", + "g2", + "calstate", + "bragg_directions", + "bragg_vectors_indexed", + "g1g2_map", + "strainmap_g1g2", + "strainmap_rotated", + ): + if hasattr(self, attr): + setattr(strainmap_copy, attr, getattr(self, attr)) + + for k in self.metadata.keys(): + strainmap_copy.metadata = self.metadata[k].copy() + return strainmap_copy + + # IO methods + + # read + @classmethod + def _get_constructor_args(cls, group): + """ + Returns a dictionary of args/values to pass to the class constructor + """ + ar_constr_args = RealSlice._get_constructor_args(group) + args = { + "data": ar_constr_args["data"], + "name": ar_constr_args["name"], + } + return args diff --git a/py4DSTEM/process/utils/elliptical_coords.py b/py4DSTEM/process/utils/elliptical_coords.py index 8bc80db79..97291bc20 100644 --- a/py4DSTEM/process/utils/elliptical_coords.py +++ b/py4DSTEM/process/utils/elliptical_coords.py @@ -151,7 +151,7 @@ def cartesian_to_polarelliptical_transform( * **pp**: *(2D array)* meshgrid of the phi coordinates """ if mask is None: - mask = np.ones_like(cartesianData, dtype=bool) + mask = np.ones_like(cartesianData.data, dtype=bool) assert ( cartesianData.shape == mask.shape ), "Mask and cartesian data array shapes must match." diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py index 29badd389..86257b4dc 100644 --- a/py4DSTEM/process/utils/utils.py +++ b/py4DSTEM/process/utils/utils.py @@ -59,7 +59,7 @@ def radial_reduction( def plot(img, title='Image', savePath=None, cmap='inferno', show=True, vmax=None, figsize=(10, 10), scale=None): fig, ax = plt.subplots(figsize=figsize) - im = ax.imshow(img, interpolation='nearest', cmap=plt.cm.get_cmap(cmap), vmax=vmax) + im = ax.imshow(img, interpolation='nearest', cmap=plt.get_cmap(cmap), vmax=vmax) divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) plt.colorbar(im, cax=cax) @@ -174,21 +174,24 @@ def make_Fourier_coords2D(Nx, Ny, pixelSize=1): qy, qx = np.meshgrid(qy, qx) return qx, qy - -def get_CoM(ar, device = "cpu"): +def get_CoM(ar, device="cpu", corner_centered=False): """ Finds and returns the center of mass of array ar. + If corner_centered is True, uses fftfreq for indices. 
""" if device == "cpu": xp = np - elif device == "gpu": xp = cp - ar = xp.asarray(ar) - + ar = xp.asarray(ar) nx, ny = ar.shape - ry, rx = xp.meshgrid(xp.arange(ny), xp.arange(nx)) + + if corner_centered: + ry, rx = xp.meshgrid(xp.fft.fftfreq(ny, 1 / ny), xp.fft.fftfreq(nx, 1 / nx)) + else: + ry, rx = xp.meshgrid(xp.arange(ny), xp.arange(nx)) + tot_intens = xp.sum(ar) xCoM = xp.sum(rx * ar) / tot_intens yCoM = xp.sum(ry * ar) / tot_intens @@ -633,7 +636,7 @@ def fourier_resample( #def plot(img, title='Image', savePath=None, cmap='inferno', show=True, vmax=None, # figsize=(10, 10), scale=None): # fig, ax = plt.subplots(figsize=figsize) -# im = ax.imshow(img, interpolation='nearest', cmap=plt.cm.get_cmap(cmap), vmax=vmax) +# im = ax.imshow(img, interpolation='nearest', cmap=plt.get_cmap(cmap), vmax=vmax) # divider = make_axes_locatable(ax) # cax = divider.append_axes("right", size="5%", pad=0.05) # plt.colorbar(im, cax=cax) diff --git a/py4DSTEM/process/virtualdiffraction.py b/py4DSTEM/process/virtualdiffraction.py deleted file mode 100644 index ba31b778a..000000000 --- a/py4DSTEM/process/virtualdiffraction.py +++ /dev/null @@ -1,223 +0,0 @@ -# Functions for generating diffraction images - -import numpy as np -from emdfile import tqdmnd - -def get_virtual_diffraction( - datacube, - method, - mode = None, - geometry = None, - calibrated = False, - shift_center = False, - verbose = True, - return_mask = False, -): - ''' - Function to calculate virtual diffraction - - Args: - datacube (Datacube) : datacube class object which stores 4D-dataset - needed for calculation - method (str) : defines method used for diffraction pattern, options are - 'mean', 'median', and 'max' - mode (str) : defines mode for selecting area in real space to use for - virtual diffraction. The default is None, which means no - geometry will be applied and the whole datacube will be used - for the calculation. Options: - - 'point' uses singular point as detector - - 'circle' or 'circular' uses round detector, like bright field - - 'annular' or 'annulus' uses annular detector, like dark field - - 'rectangle', 'square', 'rectangular', uses rectangular detector - - 'mask' flexible detector, any 2D array - geometry (variable) : valid entries are determined by the `mode`, values - in pixels argument, as follows. The default is None, which means no - geometry will be applied and the whole datacube will be used for the - calculation. If mode is None the geometry will not be applied. - - 'point': 2-tuple, (rx,ry), - qx and qy are each single float or int to define center - - 'circle' or 'circular': nested 2-tuple, ((rx,ry),radius), - qx, qy and radius, are each single float or int - - 'annular' or 'annulus': nested 2-tuple, ((rx,ry),(radius_i,radius_o)), - qx, qy, radius_i, and radius_o are each single float or integer - - 'rectangle', 'square', 'rectangular': 4-tuple, (xmin,xmax,ymin,ymax) - - `mask`: flexible detector, any boolean or floating point 2D array with - the same shape as datacube.Rshape - calibrated (bool): if True, geometry is specified in units of 'A' instead - of pixels. The datacube's calibrations must have its `"R_pixel_units"` - parameter set to "A". If mode is None the geometry and calibration will - not be applied. - shift_center (bool): if True, the difraction pattern is shifted to account - for beam shift or the changing of the origin through the scan. The - datacube's calibration['origin'] parameter must be set Only 'max' and - 'mean' supported for this option. 
- verbose (bool): if True, show progress bar - return_mask (bool): if False (default) returns a virtual image as usual. - If True, does *not* generate or return a virtual image, instead - returning the mask that would be used in virtual diffraction computation. - - Returns: - (2D array): the diffraction image - ''' - - assert method in ('max', 'median', 'mean'),\ - 'check doc strings for supported types' - - #create mask - if mode is not None: - use_all_points = False - - from py4DSTEM.process.virtualimage import make_detector - assert mode in ('point', 'circle', 'circular', 'annulus', 'annular', 'rectangle', 'square', 'rectangular', 'mask'),\ - 'check doc strings for supported modes' - g = geometry - - if calibrated == True: - assert datacube.calibration['R_pixel_units'] == 'A', \ - 'check datacube.calibration. datacube must be calibrated in A to use `calibrated=True`' - - unit_conversion = datacube.calibration['R_pixel_size'] - if mode == 'point': - g = (g[0]/unit_conversion, g[1]/unit_conversion) - if mode in('circle', 'circular'): - g = ((g[0][0]/unit_conversion, g[0][1]/unit_conversion), - (g[1]/unit_conversion)) - if mode in('annulus', 'annular'): - g = ((g[0][0]/unit_conversion, g[0][1]/unit_conversion), - (g[1][0]/unit_conversion, g[1][1]/unit_conversion)) - if mode in('rectangle', 'square', 'rectangular') : - g = (g[0]/unit_conversion, g[1]/unit_conversion, - g[2]/unit_conversion, g[3]/unit_conversion) - - # Get mask - mask = make_detector(datacube.Rshape, mode, g) - - # Determine if mask is boolean and if so, vectorize - if mask.dtype == bool: - mask_is_boolean = True - mask_indices = np.nonzero(mask) - else: - mask_is_boolean = False - - #if no mask - else: - mask = np.ones(datacube.Rshape, dtype=bool) - mask_indices = np.nonzero(mask) - use_all_points = True - mask_is_boolean = False - - # if return_mask is True, skip computation - if return_mask == True: - return mask - - # Calculate diffracton pattern... 
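# For orientation: the calibrated-units handling above reduces to dividing
# every geometry coordinate by the real-space pixel size. A standalone sketch
# (not part of this diff; `pixel_size_A` is an illustrative value) for the
# circular case:
pixel_size_A = 0.5                      # R_pixel_size, in angstroms per pixel
center_A, radius_A = (20.0, 30.0), 5.0  # ((rx, ry), radius), in angstroms
geometry_px = (
    (center_A[0] / pixel_size_A, center_A[1] / pixel_size_A),
    radius_A / pixel_size_A,
)                                       # -> ((40.0, 60.0), 10.0), in pixels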
- - # ...with no center shifting - if shift_center == False: - - # ...for the whole pattern - if use_all_points: - if method == 'mean': - virtual_diffraction = np.mean(datacube.data, axis=(0,1)) - elif method == 'max': - virtual_diffraction = np.max(datacube.data, axis=(0,1)) - else: - virtual_diffraction = np.median(datacube.data, axis=(0,1)) - - # ...for boolean masks - elif mask_is_boolean: - if method == 'mean': - virtual_diffraction = np.mean(datacube.data[mask_indices[0],mask_indices[1],:,:], axis=0) - elif method == 'max': - virtual_diffraction = np.max(datacube.data[mask_indices[0],mask_indices[1],:,:], axis=0) - else: - virtual_diffraction = np.median(datacube.data[mask_indices[0],mask_indices[1],:,:], axis=0) - - # ...for floating point masks - else: - if mask.dtype == 'complex': - virtual_diffraction = np.zeros(datacube.Qshape, dtype = 'complex') - else: - virtual_diffraction = np.zeros(datacube.Qshape) - for qx,qy in tqdmnd( - datacube.Q_Nx, - datacube.Q_Ny, - disable = not verbose, - ): - if method == 'mean': - virtual_diffraction[qx,qy] = np.sum( np.squeeze(datacube.data[:,:,qx,qy])*mask ) - elif method == 'max': - virtual_diffraction[qx,qy] = np.max( np.squeeze(datacube.data[:,:,qx,qy])*mask ) - elif method == 'median': - virtual_diffraction[qx,qy] = np.median( np.squeeze(datacube.data[:,:,qx,qy])*mask ) - - # norm by weighting term for means - if method == 'mean' and not use_all_points: - virtual_diffraction /= np.sum(mask) - - # ...with center shifting - else: - assert method in ('max', 'mean'),\ - "only 'mean' and 'max' are supported for center-shifted virtual diffraction" - - # Get calibration metadata - assert datacube.calibration.get_origin(), "origin needs to be calibrated" - x0, y0 = datacube.calibration.get_origin() - x0_mean, y0_mean = datacube.calibration.get_origin_mean() - - # get shifts - qx_shift = (x0_mean-x0).round().astype(int) - qy_shift = (y0_mean-y0).round().astype(int) - - - # compute... 
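# For orientation: the center-shifted branch below folds each rolled pattern
# into a running mean or max. The core update, as a standalone sketch (not
# part of this diff; `data`, `qx_shift`, `qy_shift` are placeholder arrays):
import numpy as np

def shifted_mean_dp(data, qx_shift, qy_shift):
    # data: 4D array (R_Nx, R_Ny, Q_Nx, Q_Ny); shifts: integer 2D arrays
    out = np.zeros(data.shape[2:])
    for rx in range(data.shape[0]):
        for ry in range(data.shape[1]):
            dp = np.roll(
                data[rx, ry],
                (qx_shift[rx, ry], qy_shift[rx, ry]),
                axis=(0, 1),
            )
            out += dp
    return out / (data.shape[0] * data.shape[1])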
- - # ...for boolean masks / whole datacubes - if mask_is_boolean or use_all_points: - virtual_diffraction = np.zeros(datacube.Qshape) - for rx,ry in zip(mask_indices[0],mask_indices[1]): - # get shifted DP - DP = np.roll( - datacube.data[rx,ry, :,:,], - (qx_shift[rx,ry], qy_shift[rx,ry]), - axis=(0,1), - ) - # compute - if method == 'mean': - virtual_diffraction += DP - elif method == 'max': - virtual_diffraction = np.maximum(virtual_diffraction, DP) - if method == 'mean': - virtual_diffraction /= len(mask_indices[0]) - - # ...for floating point masks - else: - if mask.dtype == 'complex': - virtual_diffraction = np.zeros(datacube.Qshape, dtype = 'complex') - else: - virtual_diffraction = np.zeros(datacube.Qshape) - for rx,ry in tqdmnd( - datacube.R_Nx, - datacube.R_Ny, - disable = not verbose, - ): - # get shifted DP - DP = np.roll( - datacube.data[rx,ry, :,:,], - (qx_shift[rx,ry], qy_shift[rx,ry]), - axis=(0,1), - ) - # compute - w = mask[rx,ry] - if w > 0: - if method == 'mean': - virtual_diffraction += DP*w - elif method == 'max': - virtual_diffraction = np.maximum(virtual_diffraction, DP*w) - if method == 'mean': - virtual_diffraction /= np.sum(mask) - - # return - return virtual_diffraction - - diff --git a/py4DSTEM/process/virtualimage.py b/py4DSTEM/process/virtualimage.py deleted file mode 100644 index 1c5e5e1da..000000000 --- a/py4DSTEM/process/virtualimage.py +++ /dev/null @@ -1,537 +0,0 @@ -# Functions for generating virtual images -import numpy as np -import dask.array as da -from emdfile import tqdmnd -from py4DSTEM.classes import Calibration - -def get_virtual_image( - datacube, - mode, - geometry, - centered = False, - calibrated = False, - shift_center = False, - verbose = True, - dask = False, - return_mask = False, - test_config = False -): - ''' - Function to calculate virtual image - - Args: - datacube (Datacube) : datacube class object which stores 4D-dataset - needed for calculation - mode (str) : defines geometry mode for calculating virtual image. - Options: - - 'point' uses singular point as detector - - 'circle' or 'circular' uses round detector, like bright field - - 'annular' or 'annulus' uses annular detector, like dark field - - 'rectangle', 'square', 'rectangular', uses rectangular detector - - 'mask' flexible detector, any 2D array - geometry (variable) : valid entries are determined by the `mode`, values - in pixels argument, as follows: - - 'point': 2-tuple, (qx,qy), - qx and qy are each single float or int to define center - - 'circle' or 'circular': nested 2-tuple, ((qx,qy),radius), - qx, qy and radius, are each single float or int - - 'annular' or 'annulus': nested 2-tuple, ((qx,qy), - (radius_i,radius_o)) - - 'rectangle', 'square', 'rectangular': a 4-tuple, - (xmin,xmax,ymin,ymax) - - `mask`: flexible detector, any boolean or floating point - 2D array with the same shape as datacube.Qshape - centered (bool): if False (default), the origin is in the upper left - corner. If True, the mean measured origin in the datacube - calibrations is set as center. The measured origin is set with - datacube.calibration.set_origin(). In this case, for example, a - centered bright field image could be defined by - geometry = ((0,0), R). For `mode="mask"`, has no effect. - calibrated (bool): if True, geometry is specified in units of 'A^-1' - instead of pixels. The datacube's calibrations must have its - `"Q_pixel_units"` parameter set to "A^-1". For `mode="mask"`, has - no effect. 
- shift_center (bool): if True, the mask is shifted at each real space - position to account for any shifting of the origin of the - diffraction images. The datacube's calibration['origin'] parameter - must be set (centered = True). The shift applied to each pattern is - the difference between the local origin position and the mean origin - position over all patterns, rounded to the nearest integer for speed. - verbose (bool): if True, show progress bar - dask (bool): if True, use dask arrays - return_mask (bool or tuple): if False (default) returns a virtual image - as usual. If True, does *not* generate or return a virtual image, - instead returning the mask that would be used in virtual image - computation for any call to this function where - `shift_center = False`. Otherwise, must be a 2-tuple of integers - corresponding to a scan position (rx,ry); in this case, returns the - mask that would be used for virtual image computation at this scan - position with `shift_center` set to `True`. - test_config: if True, returns the Boolean value of (`centered`, - `calibrated`,`shift_center`). Does not compute the virtual image. - - Returns: - (2D array) virtual image - ''' - - assert mode in ('point', 'circle', 'circular', 'annulus', 'annular', 'rectangle', 'square', 'rectangular', 'mask'),\ - 'check doc strings for supported modes' - if shift_center == True: - assert centered, "centered must be True if shift_center is True" - if test_config: - for x,y in zip(['centered','calibrated','shift_center'], - [centered,calibrated,shift_center]): - print(f"{x} = {y}") - - # Get geometry - g = get_calibrated_geometry( - datacube, - mode, - geometry, - centered, - calibrated - ) - - # Get mask - mask = make_detector(datacube.Qshape, mode, g) - # if return_mask is True, skip computation - if return_mask == True and shift_center == False: - return mask - - - # Calculate images - - # no center shifting - if shift_center == False: - # dask - if dask == True: - - # set up a generalized universal function for dask distribution - def _apply_mask_dask(datacube,mask): - virtual_image = np.sum(np.multiply(datacube.data,mask), dtype=np.float64) - apply_mask_dask = da.as_gufunc( - _apply_mask_dask,signature='(i,j),(i,j)->()', - output_dtypes=np.float64, - axes=[(2,3),(0,1),()], - vectorize=True - ) - - # compute - virtual_image = apply_mask_dask(datacube.data, mask) - - # non-dask - else: - - # compute - if mask.dtype == 'complex': - virtual_image = np.zeros(datacube.Rshape, dtype = 'complex') - else: - virtual_image = np.zeros(datacube.Rshape) - for rx,ry in tqdmnd( - datacube.R_Nx, - datacube.R_Ny, - disable = not verbose, - ): - virtual_image[rx,ry] = np.sum(datacube.data[rx,ry]*mask) - - # with center shifting - else: - - # get shifts - assert datacube.calibration.get_origin_shift(), "origin need to be calibrated" - qx_shift,qy_shift = datacube.calibration.get_origin_shift() - qx_shift = qx_shift.round().astype(int) - qy_shift = qy_shift.round().astype(int) - - # if return_mask is True, skip computation - if return_mask is not False: - try: - rx,ry = return_mask - except TypeError: - raise Exception("when `shift_center` is True, return_mask must be a 2-tuple of ints or False") - # get shifted mask - _mask = np.roll( - mask, - (qx_shift[rx,ry], qy_shift[rx,ry]), - axis=(0,1) - ) - return _mask - - # compute - if mask.dtype == 'complex': - virtual_image = np.zeros(datacube.Rshape, dtype = 'complex') - else: - virtual_image = np.zeros(datacube.Rshape) - - for rx,ry in tqdmnd( - datacube.R_Nx, - datacube.R_Ny, - 
disable = not verbose, - ): - # get shifted mask - _mask = np.roll( - mask, - (qx_shift[rx,ry], qy_shift[rx,ry]), - axis=(0,1) - ) - virtual_image[rx,ry] = np.sum(datacube.data[rx,ry]*_mask) - - return virtual_image - - -def get_calibrated_geometry( - calibration, - mode, - geometry, - centered, - calibrated - ): - """ - Determine the detector geometry in pixels, given some mode and geometry - in calibrated units, where the calibration state is specified by { - centered, calibrated} - - Args: - calibration (Calibration, DataCube, any object with a .calibration attr, - or None) Used to retrieve the center positions. If `None`, confirms that - centered and calibrated are False then passes - mode: see py4DSTEM.process.virtualimage.get_virtual_image - geometry: see py4DSTEM.process.virtualimage.get_virtual_image - centered: see py4DSTEM.process.virtualimage.get_virtual_image - calibrated: see py4DSTEM.process.virtualimage.get_virtual_image - - Returns: - (tuple) the geometry in detector pixels - """ - # Parse inputs - g = geometry - if calibration is None: - assert calibrated is False and centered is False - return g - elif isinstance(calibration, Calibration): - cal = calibration - else: - try: - cal = calibration.calibration - assert isinstance(cal, Calibration), "`calibration.calibration` must be a Calibration instance" - except AttributeError: - raise Exception("`calibration` must either be a Calibration instance or have a .calibration attribute") - - # Get calibration metadata - if centered: - assert cal.get_qx0_mean(), "origin needs to be calibrated" - x0_mean, y0_mean = cal.get_origin_mean() - - if calibrated: - assert cal['Q_pixel_units'] == 'A^-1', \ - 'check calibration - must be calibrated in A^-1 to use `calibrated=True`' - unit_conversion = cal.get_Q_pixel_size() - - - # Convert units into detector pixels - - # Shift center - if centered == True: - if mode == 'point': - g = (g[0] + x0_mean, g[1] + y0_mean) - if mode in('circle', 'circular', 'annulus', 'annular'): - g = ((g[0][0] + x0_mean, g[0][1] + y0_mean), g[1]) - if mode in('rectangle', 'square', 'rectangular') : - g = (g[0] + x0_mean, g[1] + x0_mean, g[2] + y0_mean, g[3] + y0_mean) - - # Scale by the detector pixel size - if calibrated == True: - if mode == 'point': - g = (g[0]/unit_conversion, g[1]/unit_conversion) - if mode in('circle', 'circular'): - g = ((g[0][0]/unit_conversion, g[0][1]/unit_conversion), - (g[1]/unit_conversion)) - if mode in('annulus', 'annular'): - g = ((g[0][0]/unit_conversion, g[0][1]/unit_conversion), - (g[1][0]/unit_conversion, g[1][1]/unit_conversion)) - if mode in('rectangle', 'square', 'rectangular') : - g = (g[0]/unit_conversion, g[1]/unit_conversion, - g[2]/unit_conversion, g[3]/unit_conversion) - - return g - - - -def make_detector( - shape, - mode, - geometry, -): - ''' - Function to return 2D mask - - Args: - shape (tuple) : defines shape of mask, for example (Q_Nx, Q_Ny) where Q_Nx and Q_Ny are mask sizes - mode (str) : defines geometry mode for calculating virtual image - options: - - 'point' uses singular point as detector - - 'circle' or 'circular' uses round detector, like bright field - - 'annular' or 'annulus' uses annular detector, like dark field - - 'rectangle', 'square', 'rectangular', uses rectangular detector - - 'mask' flexible detector, any boolean or floating point 2D array with - the same shape as datacube.Qshape or datacube.Rshape for virtual image - or diffraction image respectively - geometry (variable) : valid entries are determined by the `mode`, values in pixels - 
argument, as follows: - - 'point': 2-tuple, (qx,qy), - qx and qy are each single float or int to define center - - 'circle' or 'circular': nested 2-tuple, ((qx,qy),radius), - qx, qy and radius, are each single float or int - - 'annular' or 'annulus': nested 2-tuple, ((qx,qy),(radius_i,radius_o)), - qx, qy, radius_i, and radius_o are each single float or integer - - 'rectangle', 'square', 'rectangular': 4-tuple, (xmin,xmax,ymin,ymax) - - `mask`: flexible detector, any boolean or floating point 2D array with the - same shape as datacube.Qshape or datacube.Rshapefor virtual image - or diffraction image respectively - - Returns: - virtual detector in the form of a 2D mask (array) - ''' - g = geometry - - #point mask - if mode == 'point': - assert(isinstance(g,tuple) and len(g)==2), 'specify qx and qy as tuple (qx, qy)' - mask = np.zeros(shape, dtype=bool) - - qx = int(g[0]) - qy = int(g[1]) - - mask[qx,qy] = 1 - - #circular mask - if mode in('circle', 'circular'): - assert(isinstance(g,tuple) and len(g)==2 and len(g[0])==2 and isinstance(g[1],(float,int))), \ - 'specify qx, qy, radius_i as ((qx, qy), radius)' - - qxa, qya = np.indices(shape) - mask = (qxa - g[0][0]) ** 2 + (qya - g[0][1]) ** 2 < g[1] ** 2 - - #annular mask - if mode in('annulus', 'annular'): - assert(isinstance(g,tuple) and len(g)==2 and len(g[0])==2 and len(g[1])==2), \ - 'specify qx, qy, radius_i, radius_0 as ((qx, qy), (radius_i, radius_o))' - - assert g[1][1] > g[1][0], "Inner radius must be smaller than outer radius" - - qxa, qya = np.indices(shape) - mask1 = (qxa - g[0][0]) ** 2 + (qya - g[0][1]) ** 2 > g[1][0] ** 2 - mask2 = (qxa - g[0][0]) ** 2 + (qya - g[0][1]) ** 2 < g[1][1] ** 2 - mask = np.logical_and(mask1, mask2) - - #rectangle mask - if mode in('rectangle', 'square', 'rectangular') : - assert(isinstance(g,tuple) and len(g)==4), \ - 'specify x_min, x_max, y_min, y_max as (x_min, x_max, y_min, y_max)' - mask = np.zeros(shape, dtype=bool) - - xmin = int(np.round(g[0])) - xmax = int(np.round(g[1])) - ymin = int(np.round(g[2])) - ymax = int(np.round(g[3])) - - mask[xmin:xmax, ymin:ymax] = 1 - - #flexible mask - if mode == 'mask': - assert type(g) == np.ndarray, '`geometry` type should be `np.ndarray`' - assert (g.shape == shape), 'mask and diffraction pattern shapes do not match' - mask = g - return mask - -def make_bragg_mask( - Qshape, - g1, - g2, - radius, - origin, - max_q, - return_sum = True, - **kwargs, - ): - ''' - Creates and returns a mask consisting of circular disks - about the points of a 2D lattice. 
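# For orientation: make_bragg_mask below tiles disks over all lattice points
# h*g1 + k*g2 within a radius max_q of the origin. The loop bounds H and K it
# computes are just the largest indices that can still fall inside that
# radius; a sketch of the enumeration (not part of this diff):
import numpy as np

def lattice_points_within(g1, g2, max_q):
    g1, g2 = np.asarray(g1, dtype=float), np.asarray(g2, dtype=float)
    H = int(max_q / np.linalg.norm(g1)) + 1
    # perpendicular spacing of g2 off the g1 axis bounds the k range
    L2 = abs(g1[0] * g2[1] - g1[1] * g2[0]) / np.linalg.norm(g1)
    K = int(max_q / L2) + 1
    return [
        (h, k)
        for h in range(-H, H + 1)
        for k in range(-K, K + 1)
        if np.linalg.norm(h * g1 + k * g2) < max_q
    ]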
- - Args: - Qshape (2 tuple): the shape of diffraction space - g1,g2 (len 2 array or tuple): the lattice vectors - radius (number): the disk radius - origin (len 2 array or tuple): the origin - max_q (nuumber): the maxima distance to tile to - return_sum (bool): if False, return a 3D array, where each - slice contains a single disk; if False, return a single - 2D masks of all disks - - Returns: - (2 or 3D array) the mask - ''' - nas = np.asarray - g1,g2,origin = nas(g1),nas(g2),nas(origin) - - # Get N,M, the maximum indices to tile out to - L1 = np.sqrt(np.sum(g1**2)) - H = int(max_q/L1) + 1 - L2 = np.hypot(-g2[0]*g1[1],g2[1]*g1[0])/np.sqrt(np.sum(g1**2)) - K = int(max_q/L2) + 1 - - # Compute number of points - N = 0 - for h in range(-H,H+1): - for k in range(-K,K+1): - v = h*g1 + k*g2 - if np.sqrt(v.dot(v)) < max_q: - N += 1 - - #create mask - mask = np.zeros((Qshape[0], Qshape[1], N), dtype=bool) - N = 0 - for h in range(-H,H+1): - for k in range(-K,K+1): - v = h*g1 + k*g2 - if np.sqrt(v.dot(v)) < max_q: - center = origin + v - mask[:,:,N] = make_detector( - Qshape, - mode = 'circle', - geometry = (center, radius), - ) - N += 1 - - - if return_sum: - mask = np.sum(mask, axis = 2) - return mask - - - - - - - - - - - - - -def get_virtual_image_pointlistarray( - peaks, - mode = None, - geometry = None, - ): - """ - Make a virtual image from a pointlist array. - TODO - implement more virtual detectors. - - Args: - peaks (PointListArray): List of all peaks and intensities. - mode (str) : defines geometry mode for calculating virtual image. - Options: - - 'circular' or 'circle' uses round detector, like bright field - - 'annular' or 'annulus' uses annular detector, like dark field - geometry (variable) : valid entries are determined by the `mode`, values in pixels - argument, as follows: - - 'circle' or 'circular': nested 2-tuple, ((qx,qy),radius), - qx, qy and radius, are each single float or int - - 'annular' or 'annulus': nested 2-tuple, ((qx,qy),(radius_i,radius_o)), - qx, qy, radius_i, and radius_o are each single float or integer - - Note that (qx,qy) can be skipped, which assumes peaks centered at (0,0) - - Returns: - im_virtual (2D numpy array): the calculated virtual image - """ - - # Set geometry - if mode is None: - if geometry is None: - center = None - radial_range = np.array((0,np.inf)) - else: - if len(geometry[0]) == 0: - center = None - else: - center = np.array(geometry[0]) - if isinstance(geometry[1], int) or isinstance(geometry[1], float): - radial_range = np.array((0,geometry[1])) - elif len(geometry[1]) == 0: - radial_range = None - else: - radial_range = np.array(geometry[1]) - elif mode == 'circular' or mode == 'circle': - radial_range = np.array((0,geometry[1])) - if len(geometry[0]) == 0: - center = None - else: - center = np.array(geometry[0]) - elif mode == 'annular' or mode == 'annulus': - radial_range = np.array(geometry[1]) - if len(geometry[0]) == 0: - center = None - else: - center = np.array(geometry[0]) - - - - # init - im_virtual = np.zeros(peaks.shape) - - # Generate image - for rx,ry in tqdmnd(peaks.shape[0],peaks.shape[1]): - p = peaks.get_pointlist(rx,ry) - if p.data.shape[0] > 0: - if radial_range is None: - im_virtual[rx,ry] = np.sum(p.data['intensity']) - else: - if center is None: - qr = np.hypot(p.data['qx'],p.data['qy']) - else: - qr = np.hypot(p.data['qx'] - center[0],p.data['qy'] - center[1]) - sub = np.logical_and( - qr >= radial_range[0], - qr < radial_range[1]) - if np.sum(sub) > 0: - im_virtual[rx,ry] = np.sum(p.data['intensity'][sub]) - - 
return im_virtual - - -def get_virtual_image_braggvectors( - bragg_peaks, - mode = None, - geometry = None, - ): - ''' - Function to calculate virtual images from braggvectors / pointlist arrays. - TODO - implement these detectors for braggvectors - - Args: - bragg_peaks (BraggVectors) : BraggVectors class object which stores bragg peaks - mode (str) : defines geometry mode for calculating virtual image. - Options: - - 'circular' or 'circle' uses round detector, like bright field - - 'annular' or 'annulus' uses annular detector, like dark field - geometry (variable) : valid entries are determined by the `mode`, values in pixels - argument, as follows: - - 'circle' or 'circular': nested 2-tuple, ((qx,qy),radius), - qx, qy and radius, are each single float or int - - 'annular' or 'annulus': nested 2-tuple, ((qx,qy),(radius_i,radius_o)), - qx, qy, radius_i, and radius_o are each single float or integer - - Note that (qx,qy) can be skipped, which assumes peaks centered at (0,0) - - Returns: - im_virtual (2D numpy array): the calculated virtual image - ''' - - virtual_image = get_virtual_image_pointlistarray( - bragg_peaks.vectors, - mode = mode, - geometry = geometry, - ) - - return virtual_image diff --git a/py4DSTEM/utils/configuration_checker.py b/py4DSTEM/utils/configuration_checker.py index 859443723..0ccf633fc 100644 --- a/py4DSTEM/utils/configuration_checker.py +++ b/py4DSTEM/utils/configuration_checker.py @@ -15,7 +15,6 @@ 'gdown', 'h5py', 'ipyparallel', - 'ipywidgets', 'jax', 'matplotlib', 'mp_api', @@ -42,7 +41,6 @@ 'matplotlib', 'skimage', 'sklearn', - 'ipywidgets', 'tqdm', 'dill', 'gdown', diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py index 4009e43c9..9df5075b8 100644 --- a/py4DSTEM/version.py +++ b/py4DSTEM/version.py @@ -1,2 +1,2 @@ -__version__='0.14.2' +__version__='0.14.3' diff --git a/py4DSTEM/visualize/__init__.py b/py4DSTEM/visualize/__init__.py index 40fdedca6..d5c183ab5 100644 --- a/py4DSTEM/visualize/__init__.py +++ b/py4DSTEM/visualize/__init__.py @@ -3,5 +3,4 @@ from py4DSTEM.visualize.vis_RQ import * from py4DSTEM.visualize.vis_grid import * from py4DSTEM.visualize.vis_special import * -from py4DSTEM.visualize.virtualimage import * diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py index e0c87a427..7e7147a15 100644 --- a/py4DSTEM/visualize/overlay.py +++ b/py4DSTEM/visualize/overlay.py @@ -437,7 +437,7 @@ def add_bragg_index_labels(ax,d): Adds labels for indexed bragg directions to a plot, using the parameters in dict d. The dictionary d has required and optional parameters as follows: - braggdirections (req'd) (PointList) the Bragg directions. This PointList must have + bragg_directions (req'd) (PointList) the Bragg directions. 
This PointList must have the fields 'qx','qy','h', and 'k', and may optionally have 'l' voffset (number) vertical offset for the labels hoffset (number) horizontal offset for the labels @@ -450,12 +450,12 @@ def add_bragg_index_labels(ax,d): # handle inputs assert isinstance(ax,Axes) # bragg directions - assert('braggdirections' in d.keys()) - braggdirections = d['braggdirections'] - assert isinstance(braggdirections,PointList) + assert('bragg_directions' in d.keys()) + bragg_directions = d['bragg_directions'] + assert isinstance(bragg_directions,PointList) for k in ('qx','qy','h','k'): - assert k in braggdirections.data.dtype.fields - include_l = True if 'l' in braggdirections.data.dtype.fields else False + assert k in bragg_directions.data.dtype.fields + include_l = True if 'l' in bragg_directions.data.dtype.fields else False # offsets hoffset = d['hoffset'] if 'hoffset' in d.keys() else 0 voffset = d['voffset'] if 'voffset' in d.keys() else 5 @@ -474,20 +474,20 @@ def add_bragg_index_labels(ax,d): # add the points if points: - ax.scatter(braggdirections.data['qy'],braggdirections.data['qx'], + ax.scatter(bragg_directions.data['qy'],bragg_directions.data['qx'], color=pointcolor,s=pointsize) # add index labels - for i in range(braggdirections.length): - x,y = braggdirections.data['qx'][i],braggdirections.data['qy'][i] + for i in range(bragg_directions.length): + x,y = bragg_directions.data['qx'][i],bragg_directions.data['qy'][i] x -= voffset y += hoffset - h,k = braggdirections.data['h'][i],braggdirections.data['k'][i] + h,k = bragg_directions.data['h'][i],bragg_directions.data['k'][i] h = str(h) if h>=0 else r'$\overline{{{}}}$'.format(np.abs(h)) k = str(k) if k>=0 else r'$\overline{{{}}}$'.format(np.abs(k)) s = h+','+k if include_l: - l = braggdirections.data['l'][i] + l = bragg_directions.data['l'][i] l = str(l) if l>=0 else r'$\overline{{{}}}$'.format(np.abs(l)) s += l ax.text(y,x,s,color=color,size=size,ha='center',va='bottom') diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py index 45a5b7395..3b9d99e43 100644 --- a/py4DSTEM/visualize/show.py +++ b/py4DSTEM/visualize/show.py @@ -9,7 +9,7 @@ from math import log from copy import copy -from py4DSTEM.classes import ( +from py4DSTEM.data import ( Calibration, DiffractionSlice, RealSlice @@ -553,7 +553,7 @@ def show( # Create colormap with mask_color for bad values - cm = copy(plt.cm.get_cmap(cmap)) + cm = copy(plt.get_cmap(cmap)) if mask_color=='empty': cm.set_bad(alpha=0) else: diff --git a/py4DSTEM/visualize/virtualimage.py b/py4DSTEM/visualize/virtualimage.py deleted file mode 100644 index 87931bfba..000000000 --- a/py4DSTEM/visualize/virtualimage.py +++ /dev/null @@ -1,129 +0,0 @@ -import numpy as np - -from py4DSTEM.classes import Calibration, DataCube, DiffractionSlice -from py4DSTEM.visualize.show import show - - -def position_detector( - data, - mode, - geometry, - centered, - calibrated, - shift_center, - invert = False, - color = 'r', - alpha = 0.7, - **kwargs -): - """ - Display a diffraction space image with an overlaid mask representing - a virtual detector. - - Args: - data (DataCube, DiffractionSlice, array, tuple): - behavoir depends on the argument type: - DataCube - check to see if this datacube has a mean, max, - or median diffraction pattern, and if found, uses it - (order of preference as written here). If not found, - raises an exception. - DiffractionSlice - use the first slice - array - use this array. This mode only works when - centered, calibrated, and shift_center are False. 
- Otherwise, use the tuple entry (array, Calibration) - tuple - must be either: - - (DataCube, rx, ry) for rx,ry integers. - Use the diffraction pattern at this scan position. - `shift_center` is auto set to True in this mode. - - (array, Calibration) - mode: see py4DSTEM.process.get_virtual_image - geometry: see py4DSTEM.process.get_virtual_image - centered: see py4DSTEM.process.get_virtual_image - calibrated: see py4DSTEM.process.get_virtual_image - shift_center: see py4DSTEM.process.get_virtual_image; if True, `data` - should be a 3-tuple (DataCube, rx, ry) - invert: if True, invert the mask - **kwargs: all additional arguments are passed on to `show` - """ - # Parse data - if isinstance(data, DataCube): - cal = data.calibration - keys = ['dp_mean','dp_max','dp_median'] - for k in keys: - try: - image = data.tree(k) - break - except: - pass - else: - raise Exception("No mean, max, or median diffraction image found; try calling datacube.get_mean_dp() first") - elif isinstance(data, DiffractionSlice): - cal = data.calibration - try: - image = data[:,:,0] - except IndexError: - image = data[:,:] - elif isinstance(data, np.ndarray): - er = "centered and calibrated must be False to pass an uncalibrated array; set these to False or try using `data = (array, Calibration)`" - assert all([x is False for x in [centered,calibrated]]), er - image = data - cal = None - elif isinstance(data, tuple): - if len(data)==2: - image,cal = data - assert isinstance(image, np.ndarray) - assert isinstance(cal, Calibration) - elif len(data)==3: - data,rx,ry = data - image = data[rx,ry,:,:] - cal = data.calibration - else: - raise Exception(f"Invalid entry {data} for argument `data`") - - - # Get geometry - from py4DSTEM.process.virtualimage import get_calibrated_geometry - g = get_calibrated_geometry( - calibration = cal, - mode = mode, - geometry = geometry, - centered = centered, - calibrated = calibrated - ) - - # Get mask - from py4DSTEM.process.virtualimage import make_detector - mask = make_detector(image.shape, mode, g) - if not(invert): - mask = np.logical_not(mask) - - # Shift center - if shift_center: - try: - rx,ry - except NameError: - raise Exception("if `shift_center` is True then `data` must be the 3-tuple (DataCube,rx,ry)") - # get shifts - assert cal.get_origin_shift(), "origin shifts need to be calibrated" - qx_shift,qy_shift = cal.get_origin_shift() - qx_shift = int(np.round(qx_shift[rx,ry])) - qy_shift = int(np.round(qy_shift[rx,ry])) - mask = np.roll( - mask, - (qx_shift, qy_shift), - axis=(0,1) - ) - - # Display - - show( - image, - mask = mask, - mask_color = color, - mask_alpha = alpha, - **kwargs - ) - - return - - diff --git a/py4DSTEM/visualize/vis_RQ.py b/py4DSTEM/visualize/vis_RQ.py index 4e6f1d20a..72c50a396 100644 --- a/py4DSTEM/visualize/vis_RQ.py +++ b/py4DSTEM/visualize/vis_RQ.py @@ -3,7 +3,6 @@ from matplotlib.axes import Axes from py4DSTEM.visualize.show import show,show_points -from py4DSTEM.process.calibration.rotation import get_Qvector_from_Rvector,get_Rvector_from_Qvector @@ -79,6 +78,7 @@ def ax_addvector_RtoQ(ax,vx,vy,vlength,x0,y0,QR_rotation,width=1,color='r'): the counterclockwise rotation of real space with respect to diffraction space. In degrees. 
""" + from py4DSTEM.process.calibration.rotation import get_Qvector_from_Rvector _,_,vx,vy = get_Qvector_from_Rvector(vx,vy,QR_rotation) vx,vy = vx*vlength,vy*vlength ax.arrow(y0,x0,vy,vx,color=color,width=width,length_includes_head=True) @@ -102,6 +102,7 @@ def ax_addvector_QtoR(ax,vx,vy,vlength,x0,y0,QR_rotation,width=1,color='r'): the counterclockwise rotation of real space with respect to diffraction space. In degrees. """ + from py4DSTEM.process.calibration.rotation import get_Rvector_from_Qvector vx,vy,_,_ = get_Rvector_from_Qvector(vx,vy,QR_rotation) vx,vy = vx*vlength,vy*vlength ax.arrow(y0,x0,vy,vx,color=color,width=width,length_includes_head=True) @@ -278,6 +279,7 @@ def ax_addaxes_QtoR(ax,vx,vy,vlength,x0,y0,QR_rotation,width=1,color='r', the counterclockwise rotation of real space with respect to diffraction space. In degrees. """ + from py4DSTEM.process.calibration.rotation import get_Rvector_from_Qvector vx,vy,_,_ = get_Rvector_from_Qvector(vx,vy,QR_rotation) ax_addaxes(ax,vx,vy,vlength,x0,y0,width=width,color=color,labelaxes=labelaxes, labelsize=labelsize,labelcolor=labelcolor,righthandedcoords=True) @@ -304,6 +306,7 @@ def ax_addaxes_RtoQ(ax,vx,vy,vlength,x0,y0,QR_rotation,width=1,color='r', the counterclockwise rotation of real space with respect to diffraction space. In degrees. """ + from py4DSTEM.process.calibration.rotation import get_Qvector_from_Rvector _,_,vx,vy = get_Qvector_from_Rvector(vx,vy,QR_rotation) ax_addaxes(ax,vx,vy,vlength,x0,y0,width=width,color=color,labelaxes=labelaxes, labelsize=labelsize,labelcolor=labelcolor,righthandedcoords=True) diff --git a/py4DSTEM/visualize/vis_grid.py b/py4DSTEM/visualize/vis_grid.py index 8ca9c477d..48b5b158a 100644 --- a/py4DSTEM/visualize/vis_grid.py +++ b/py4DSTEM/visualize/vis_grid.py @@ -97,7 +97,7 @@ def show_image_grid( H,W, axsize=(6,6), returnfig=False, - figax = None, + figax = None, title = None, title_index = False, suptitle = None, diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index af8c074d0..43cf7fff8 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -6,10 +6,6 @@ from scipy.spatial import Voronoi from emdfile import PointList -from py4DSTEM.classes import DataCube,Calibration -from py4DSTEM.process.utils import get_voronoi_vertices,convert_ellipse_params -from py4DSTEM.process.calibration import double_sided_gaussian -from py4DSTEM.process.latticevectors import get_selected_lattice_vectors from py4DSTEM.visualize import show from py4DSTEM.visualize.overlay import ( add_pointlabels, @@ -24,6 +20,8 @@ + + def show_elliptical_fit(ar,fitradii,p_ellipse,fill=True, color_ann='y',color_ell='r',alpha_ann=0.2,alpha_ell=0.7, linewidth_ann=2,linewidth_ell=2,returnfig=False,**kwargs): @@ -93,6 +91,8 @@ def show_amorphous_ring_fit(dp,fitradii,p_dsg,N=12,cmap=('gray','gray'), ellipse (bool): if True, overlay an ellipse returnfig (bool): if True, returns the figure """ + from py4DSTEM.process.calibration import double_sided_gaussian + from py4DSTEM.process.utils import convert_ellipse_params assert(len(p_dsg)==11) assert(isinstance(N,(int,np.integer))) if isinstance(cmap,tuple): @@ -258,6 +258,7 @@ def show_voronoi(ar,x,y,color_points='r',color_lines='w',max_dist=None, """ words """ + from py4DSTEM.process.utils import get_voronoi_vertices Nx,Ny = ar.shape points = np.vstack((x,y)).T voronoi = Voronoi(points) @@ -589,100 +590,6 @@ def select_point(ar,x,y,i,color='lightblue',color_selected='r',size=20,returnfig return -def 
select_lattice_vectors(ar,gx,gy,i0,i1,i2, - c_indices='lightblue',c0='g',c1='r',c2='r',c_vectors='r',c_vectorlabels='w', - size_indices=20,width_vectors=1,size_vectorlabels=20, - figsize=(12,6),returnfig=False,**kwargs): - """ - This function accepts a set of reciprocal lattice points (gx,gy) and three indices - (i0,i1,i2). Using those indices as, respectively, the origin, the endpoint of g1, and - the endpoint of g2, this function computes the basis lattice vectors g1,g2, visualizes - them, and returns them. To compute these vectors without visualizing, use - latticevectors.get_selected_lattice_vectors(). - - Returns: - if returnfig==False: g1,g2 - if returnfig==True g1,g2,fig,ax - """ - # Make the figure - fig,(ax1,ax2) = plt.subplots(1,2,figsize=figsize) - show(ar,figax=(fig,ax1),**kwargs) - show(ar,figax=(fig,ax2),**kwargs) - - # Add indices to left panel - d = {'x':gx,'y':gy,'size':size_indices,'color':c_indices} - d0 = {'x':gx[i0],'y':gy[i0],'size':size_indices,'color':c0,'fontweight':'bold','labels':[str(i0)]} - d1 = {'x':gx[i1],'y':gy[i1],'size':size_indices,'color':c1,'fontweight':'bold','labels':[str(i1)]} - d2 = {'x':gx[i2],'y':gy[i2],'size':size_indices,'color':c2,'fontweight':'bold','labels':[str(i2)]} - add_pointlabels(ax1,d) - add_pointlabels(ax1,d0) - add_pointlabels(ax1,d1) - add_pointlabels(ax1,d2) - - # Compute vectors - g1,g2 = get_selected_lattice_vectors(gx,gy,i0,i1,i2) - - # Add vectors to right panel - dg1 = {'x0':gx[i0],'y0':gy[i0],'vx':g1[0],'vy':g1[1],'width':width_vectors, - 'color':c_vectors,'label':r'$g_1$','labelsize':size_vectorlabels,'labelcolor':c_vectorlabels} - dg2 = {'x0':gx[i0],'y0':gy[i0],'vx':g2[0],'vy':g2[1],'width':width_vectors, - 'color':c_vectors,'label':r'$g_2$','labelsize':size_vectorlabels,'labelcolor':c_vectorlabels} - add_vector(ax2,dg1) - add_vector(ax2,dg2) - - if returnfig: - return g1,g2,fig,(ax1,ax2) - else: - plt.show() - return g1,g2 - - -def show_lattice_vectors(ar,x0,y0,g1,g2,color='r',width=1,labelsize=20,labelcolor='w',returnfig=False,**kwargs): - """ Adds the vectors g1,g2 to an image, with tail positions at (x0,y0). g1 and g2 are 2-tuples (gx,gy). - """ - fig,ax = show(ar,returnfig=True,**kwargs) - - # Add vectors - dg1 = {'x0':x0,'y0':y0,'vx':g1[0],'vy':g1[1],'width':width, - 'color':color,'label':r'$g_1$','labelsize':labelsize,'labelcolor':labelcolor} - dg2 = {'x0':x0,'y0':y0,'vx':g2[0],'vy':g2[1],'width':width, - 'color':color,'label':r'$g_2$','labelsize':labelsize,'labelcolor':labelcolor} - add_vector(ax,dg1) - add_vector(ax,dg2) - - if returnfig: - return fig,ax - else: - plt.show() - return - - -def show_bragg_indexing(ar,braggdirections,voffset=5,hoffset=0,color='w',size=20, - points=True,pointcolor='r',pointsize=50,returnfig=False,**kwargs): - """ - Shows an array with an overlay describing the Bragg directions - - Accepts: - ar (arrray) the image - bragg_directions (PointList) the bragg scattering directions; must have coordinates - 'qx','qy','h', and 'k'. Optionally may also have 'l'. 
- """ - assert isinstance(braggdirections,PointList) - for k in ('qx','qy','h','k'): - assert k in braggdirections.data.dtype.fields - - fig,ax = show(ar,returnfig=True,**kwargs) - d = {'braggdirections':braggdirections,'voffset':voffset,'hoffset':hoffset,'color':color, - 'size':size,'points':points,'pointsize':pointsize,'pointcolor':pointcolor} - add_bragg_index_labels(ax,d) - - if returnfig: - return fig,ax - else: - plt.show() - return - - def show_max_peak_spacing(ar,spacing,braggdirections,color='g',lw=2,returnfig=False,**kwargs): """ Show a circle of radius `spacing` about each Bragg direction """ @@ -702,6 +609,8 @@ def show_origin_meas(data): Args: data (DataCube or Calibration or 2-tuple of arrays (qx0,qy0)) """ + from py4DSTEM.data import Calibration + from py4DSTEM.datacube import DataCube if isinstance(data,tuple): assert len(data)==2 qx,qy = data @@ -722,6 +631,8 @@ def show_origin_fit(data): data (DataCube or Calibration or (3,2)-tuple of arrays ((qx0_meas,qy0_meas),(qx0_fit,qy0_fit),(qx0_residuals,qy0_residuals)) """ + from py4DSTEM.data import Calibration + from py4DSTEM.datacube import DataCube if isinstance(data,tuple): assert len(data)==3 qx0_meas,qy_meas = data[0] @@ -764,6 +675,7 @@ def show_selected_dps(datacube,positions,im,bragg_pos=None, **kwargs (dict): arguments passed to visualize.show for the *diffraction patterns*. Default is `scaling='log'` """ + from py4DSTEM.datacube import DataCube assert isinstance(datacube,DataCube) N = len(positions) assert(all([len(x)==2 for x in positions])), "Improperly formated argument `positions`" diff --git a/setup.py b/setup.py index 0afcb8bba..b0c7fa081 100644 --- a/setup.py +++ b/setup.py @@ -21,26 +21,27 @@ author_email='ben.savitzky@gmail.com', license='GNU GPLv3', keywords="STEM 4DSTEM", - python_requires='>=3.9,<3.11', + python_requires='>=3.9,<3.12', install_requires=[ 'numpy >= 1.19', 'scipy >= 1.5.2', 'h5py >= 3.2.0', + 'hdf5plugin >= 4.1.3', 'ncempy >= 1.8.1', 'matplotlib >= 3.2.2', 'scikit-image >= 0.17.2', 'scikit-learn >= 0.23.2', - 'ipywidgets >= 7.6.3', + 'scikit-optimize >= 0.9.0', 'tqdm >= 4.46.1', 'dill >= 0.3.3', 'gdown >= 4.4.0', 'dask >= 2.3.0', 'distributed >= 2.3.0', - 'emdfile == 0.0.8', + 'emdfile >= 0.0.10', ], extras_require={ 'ipyparallel': ['ipyparallel >= 6.2.4', 'dill >= 0.3.3'], - 'cuda': ['cupy'], + 'cuda': ['cupy >= 10.0.0'], 'acom': ['pymatgen >= 2022', 'mp-api == 0.24.1'], 'aiml': ['tensorflow == 2.4.1','tensorflow-addons <= 0.14.0','crystal4D'], 'aiml-cuda': ['tensorflow == 2.4.1','tensorflow-addons <= 0.14.0','crystal4D','cupy'], diff --git a/test/gettestdata.py b/test/gettestdata.py new file mode 100644 index 000000000..a84e5b9b3 --- /dev/null +++ b/test/gettestdata.py @@ -0,0 +1,85 @@ +# A command line tool for downloading data to run the py4DSTEM test suite + + +import argparse +from os.path import join, exists +from os import makedirs + +from py4DSTEM import _TESTPATH as testpath +from py4DSTEM.io import gdrive_download as download + + + +# Make the argument parser +parser = argparse.ArgumentParser( + description = "A command line tool for downloading data to run the py4DSTEM test suite" +) + +# Set up data download options +data_options = [ + 'tutorials', + 'io', + 'basic', + 'strain', +] + +# Add arguments +parser.add_argument( + "data", + help = "which data to download.", + choices = data_options, +) +parser.add_argument( + "-o", "--overwrite", + help = "if turned on, overwrite files that are already present. 
Otherwise, skip these files.", + action = "store_true" +) +parser.add_argument( + "-v", "--verbose", + help = "turn on verbose output", + action = "store_true" +) + + + +# Get the command line arguments +args = parser.parse_args() + + +# Set up paths +if not exists(testpath): + makedirs(testpath) + + +# Set data collection key +if args.data == 'tutorials': + data = ['tutorials'] +elif args.data == 'io': + data = ['test_io','test_arina'] +elif args.data == 'basic': + data = ['small_datacube'] +elif args.data == 'strain': + data = ['strain'] +else: + raise Exception(f"invalid data choice, {args.data}") + +# Download data +for d in data: + download( + d, + destination = testpath, + overwrite = args.overwrite, + verbose = args.verbose + ) + +# Always download the basic datacube +if args.data != 'basic': + download( + 'small_datacube', + destination = testpath, + overwrite = args.overwrite, + verbose = args.verbose + ) + + + diff --git a/test/test_classes/test_braggvectors.py b/test/test_braggvectors.py similarity index 64% rename from test/test_classes/test_braggvectors.py rename to test/test_braggvectors.py index 901317c05..36ec138a0 100644 --- a/test/test_classes/test_braggvectors.py +++ b/test/test_braggvectors.py @@ -3,7 +3,7 @@ from os.path import join # set filepath -path = join(py4DSTEM._TESTPATH,"simulatedAuNanoplatelet_binned_v0_9.h5") +path = join(py4DSTEM._TESTPATH,"test_io/legacy_v0.9_simAuNanoplatelet_bin.h5") @@ -24,7 +24,7 @@ def setup_class(cls): mask = np.zeros(datacube.Rshape,dtype=bool) mask[28:33,14:19] = 1 probe = datacube.get_vacuum_probe( ROI=mask ) - alpha_pr,qx0_pr,qy0_pr = py4DSTEM.process.probe.get_probe_size( probe.probe ) + alpha_pr,qx0_pr,qy0_pr = py4DSTEM.process.calibration.get_probe_size( probe.probe ) probe.get_kernel( mode='sigmoid', origin=(qx0_pr,qy0_pr), @@ -45,6 +45,14 @@ def setup_class(cls): # 'CUDA': True, } + # find disks + cls.braggpeaks = datacube.find_Bragg_disks( + template = probe.kernel, + **cls.detect_params, + ) + + # set an arbitrary center for testing + cls.braggpeaks.calibration.set_origin((datacube.Qshape[0]/2,datacube.Qshape[1]/2)) @@ -53,7 +61,7 @@ def setup_class(cls): def test_BraggVectors_import(self): - from py4DSTEM.process.diskdetection import BraggVectors + from py4DSTEM.braggvectors import BraggVectors pass @@ -68,20 +76,19 @@ def test_disk_detection_selected_positions(self): **self.detect_params, ) - def test_disk_detection(self): - - braggpeaks = self.datacube.find_Bragg_disks( - template = self.probe.kernel, - **self.detect_params, - ) - + def test_BraggVectors(self): - print(braggpeaks) print() - print(braggpeaks.raw[0,0]) print() - print(braggpeaks.cal[0,0]) print() + print(self.braggpeaks) + print() + print(self.braggpeaks.raw[0,0]) + print() + print(self.braggpeaks.cal[0,0]) + print() + print(self.braggpeaks.get_vectors( + scan_x=5,scan_y=5, + center=True,ellipse=False,pixel=False,rotate=False + )) + diff --git a/test/test_calibration.py b/test/test_calibration.py new file mode 100644 index 000000000..80cb03493 --- /dev/null +++ b/test/test_calibration.py @@ -0,0 +1,74 @@ +import py4DSTEM +from py4DSTEM import Calibration +import numpy as np +from os import mkdir, remove, rmdir +from os.path import join, exists + +# set filepaths +path_datacube = join(py4DSTEM._TESTPATH, "small_datacube.dm4") +path_3Darray = join(py4DSTEM._TESTPATH, "test_io/small_dm3_3Dstack.dm3") + +path_out_dir = join(py4DSTEM._TESTPATH, "test_outputs") +path_out = join(path_out_dir, "test_calibration.h5") + + +class TestCalibration: + + # setup + + def setup_class(cls): + if not
exists(path_out_dir): + mkdir(path_out_dir) + + def teardown_class(cls): + if exists(path_out_dir): + rmdir(path_out_dir) + + def teardown_method(self): + if exists(path_out): + remove(path_out) + + + # test + + def test_imported_datacube_calibration(self): + + datacube = py4DSTEM.import_file(path_datacube) + + assert(hasattr(datacube,'calibration')) + assert(isinstance(datacube.calibration,Calibration)) + assert(hasattr(datacube,'root')) + assert(isinstance(datacube.root,py4DSTEM.Root)) + + + def test_instantiated_datacube_calibration(self): + + datacube = py4DSTEM.DataCube( + data = np.ones((4,8,128,128)) + ) + + assert(hasattr(datacube,'calibration')) + assert(isinstance(datacube.calibration,Calibration)) + assert(hasattr(datacube,'root')) + assert(isinstance(datacube.root,py4DSTEM.Root)) + + datacube.calibration.set_Q_pixel_size(10) + + py4DSTEM.save( + path_out, + datacube + ) + + new_datacube = py4DSTEM.read(path_out) + + assert(hasattr(new_datacube,'calibration')) + assert(isinstance(new_datacube.calibration,Calibration)) + assert(hasattr(new_datacube,'root')) + assert(isinstance(new_datacube.root,py4DSTEM.Root)) + + assert(new_datacube.calibration.get_Q_pixel_size() == 10) + + + + + diff --git a/test/test_classes/test_crystal.py b/test/test_crystal.py similarity index 100% rename from test/test_classes/test_crystal.py rename to test/test_crystal.py diff --git a/test/test_datacube.py b/test/test_datacube.py new file mode 100644 index 000000000..7e0f29188 --- /dev/null +++ b/test/test_datacube.py @@ -0,0 +1,37 @@ +import py4DSTEM +import numpy as np +from os.path import join + +# set filepath +path = py4DSTEM._TESTPATH + "/small_datacube.dm4" + + + +class TestDataCube: + + # setup/teardown + def setup_class(cls): + + # Read datacube + datacube = py4DSTEM.import_file(path) + cls.datacube = datacube + + # tests + + def test_binning_default_dtype(self): + + dtype = self.datacube.data.dtype + assert(dtype == np.uint16) + + self.datacube.bin_Q(2) + + assert(self.datacube.data.dtype == dtype) + + new_dtype = np.uint32 + self.datacube.bin_Q(2, dtype=new_dtype) + + assert(self.datacube.data.dtype == new_dtype) + assert(self.datacube.data.dtype != dtype) + + pass + diff --git a/test/test_fileimport/test_dm.py b/test/test_fileimport/test_dm.py deleted file mode 100644 index 18bab1fbd..000000000 --- a/test/test_fileimport/test_dm.py +++ /dev/null @@ -1,18 +0,0 @@ -import py4DSTEM -import emdfile -from os.path import join - - -# Set filepaths -filepath_dm = join(py4DSTEM._TESTPATH, "small_dm3.dm3") - - -def test_dmfile_3Darray(): - data = py4DSTEM.import_file( filepath_dm ) - assert isinstance(data, emdfile.Array) - - -# TODO -# def test_dmfile_4Darray(): -# def test_dmfile_multiple_datablocks(): - diff --git a/test/test_misc.py b/test/test_misc.py new file mode 100644 index 000000000..275eb767b --- /dev/null +++ b/test/test_misc.py @@ -0,0 +1,32 @@ +import py4DSTEM +import numpy as np + + +def test_attach(): + """ tests to make sure Data.attach handles metadata merging correctly + """ + + x = py4DSTEM.DiffractionSlice(np.ones((5,5)), name='x') + y = py4DSTEM.DiffractionSlice(np.ones((5,5)), name='y') + + + x.calibration.set_Q_pixel_size(50) + y.calibration.set_Q_pixel_size(2) + + x.attach(y) + + assert('y' in x.treekeys) + assert(x.calibration.get_Q_pixel_size() == 50) + + +def test_datacube_copy(): + """ tests datacube.copy() + """ + x = py4DSTEM.DataCube(data=np.zeros((3,3,4,4))) + y = x.copy() + assert(isinstance(y,py4DSTEM.DataCube)) + + + + + diff --git 
a/test/test_native_io/test_calibration_io.py b/test/test_native_io/test_calibration_io.py new file mode 100644 index 000000000..a753c7c81 --- /dev/null +++ b/test/test_native_io/test_calibration_io.py @@ -0,0 +1,44 @@ +import py4DSTEM +import numpy as np +from os.path import join + +# set filepath +#path = join(py4DSTEM._TESTPATH, "filename") + + + +# class TestCalibrationIO: +# +# +# +# def test_datacube_cal_io(self): +# # TODO +# # make a datacube +# # modify its calibration +# # save +# # load the datacube +# # check its calibration +# assert 0 +# pass +# +# +# def test_datacube_child_node(self): +# # TODO +# # make a datacube +# # make a child node +# # confirm calibrations are the same +# # modify the calibration +# # save +# # load the datacube +# # check its calibration +# # load just the child node +# # check its calibration +# assert 0 +# pass + + + + + + + diff --git a/test/test_native_readwrite/test_realslice_read.py b/test/test_native_io/test_realslice_read.py similarity index 74% rename from test/test_native_readwrite/test_realslice_read.py rename to test/test_native_io/test_realslice_read.py index a1a8f09e7..8fb577ad8 100644 --- a/test/test_native_readwrite/test_realslice_read.py +++ b/test/test_native_io/test_realslice_read.py @@ -6,7 +6,7 @@ # Set filepaths -filepath = join(py4DSTEM._TESTPATH, "YanAiming_bilayer_v01.h5") +filepath = join(py4DSTEM._TESTPATH, "test_io/test_realslice_io.h5") def test_read_realslice(): realslice = py4DSTEM.read(filepath, datapath='4DSTEM/Fit Data') diff --git a/test/test_native_io/test_single_object_io.py b/test/test_native_io/test_single_object_io.py new file mode 100644 index 000000000..d1a826c5f --- /dev/null +++ b/test/test_native_io/test_single_object_io.py @@ -0,0 +1,256 @@ +import numpy as np +from os.path import join,exists +from os import remove +from numpy import array_equal + +import py4DSTEM +from py4DSTEM import save,read +import emdfile as emd + +from py4DSTEM import ( + Calibration, + DiffractionSlice, + RealSlice, + QPoints, + DataCube, + VirtualImage, + VirtualDiffraction, + BraggVectors, + Probe +) + +# Set paths +dirpath = py4DSTEM._TESTPATH +path_dm3 = join(dirpath,"test_io/small_dm3_3Dstack.dm3") +path_h5 = join(dirpath,"test.h5") + + +class TestDataCubeIO(): + + def test_datacube_instantiation(self): + """ + Instantiate a datacube and apply basic calibrations + """ + datacube = DataCube( + data = np.arange(np.prod((4,5,6,7))).reshape((4,5,6,7)) + ) + # calibration + datacube.calibration.set_Q_pixel_size(0.062) + datacube.calibration.set_Q_pixel_units("A^-1") + datacube.calibration.set_R_pixel_size(2.8) + datacube.calibration.set_R_pixel_units("nm") + + return datacube + + def test_datacube_io(self): + """ + Instantiate, save, then read a datacube, and + compare its contents before/after + """ + datacube = self.test_datacube_instantiation() + + assert(isinstance(datacube,DataCube)) + # test dim vectors + assert(datacube.dim_names[0] == 'Rx') + assert(datacube.dim_names[1] == 'Ry') + assert(datacube.dim_names[2] == 'Qx') + assert(datacube.dim_names[3] == 'Qy') + assert(datacube.dim_units[0] == 'nm') + assert(datacube.dim_units[1] == 'nm') + assert(datacube.dim_units[2] == 'A^-1') + assert(datacube.dim_units[3] == 'A^-1') + assert(datacube.dims[0][1] == 2.8) + assert(datacube.dims[2][1] == 0.062) + # check the calibrations + assert(datacube.calibration.get_Q_pixel_size() == 0.062) + assert(datacube.calibration.get_Q_pixel_units() == "A^-1") + # save and read + save(path_h5,datacube,mode='o') + new_datacube = read(path_h5) + # 
check it's the same + assert(isinstance(new_datacube,DataCube)) + assert(array_equal(datacube.data,new_datacube.data)) + assert(new_datacube.calibration.get_Q_pixel_size() == 0.062) + assert(new_datacube.calibration.get_Q_pixel_units() == "A^-1") + assert(new_datacube.dims[0][1] == 2.8) + assert(new_datacube.dims[2][1] == 0.062) + + + +class TestBraggVectorsIO(): + + def test_braggvectors_instantiation(self): + """ + Instantiate a braggvectors instance + """ + braggvectors = BraggVectors( + Rshape = (5,6), + Qshape = (7,8) + ) + for x in range(braggvectors.Rshape[0]): + for y in range(braggvectors.Rshape[1]): + L = int(4 * (np.sin(x*y)+1)) + braggvectors._v_uncal[x,y].add( + np.ones(L,dtype=braggvectors._v_uncal.dtype) + ) + return braggvectors + + + def test_braggvectors_io(self): + """ Save then read a BraggVectors instance, and compare contents before/after + """ + braggvectors = self.test_braggvectors_instantiation() + + assert(isinstance(braggvectors,BraggVectors)) + # save then read + save(path_h5,braggvectors,mode='o') + new_braggvectors = read(path_h5) + # check it's the same + assert(isinstance(new_braggvectors,BraggVectors)) + assert(new_braggvectors is not braggvectors) + for x in range(new_braggvectors.shape[0]): + for y in range(new_braggvectors.shape[1]): + assert(array_equal( + new_braggvectors._v_uncal[x,y].data, + braggvectors._v_uncal[x,y].data)) + + +class TestSlices: + + + # test instantiation + + def test_diffractionslice_instantiation(self): + diffractionslice = DiffractionSlice( + data = np.arange(np.prod((4,8,2))).reshape((4,8,2)), + slicelabels = ['a','b'] + ) + return diffractionslice + + def test_realslice_instantiation(self): + realslice = RealSlice( + data = np.arange(np.prod((8,4,2))).reshape((8,4,2)), + slicelabels = ['x','y'] + ) + return realslice + + def test_virtualdiffraction_instantiation(self): + virtualdiffraction = VirtualDiffraction( + data = np.arange(np.prod((8,4,2))).reshape((8,4,2)), + ) + return virtualdiffraction + + def test_virtualimage_instantiation(self): + virtualimage = VirtualImage( + data = np.arange(np.prod((8,4,2))).reshape((8,4,2)), + ) + return virtualimage + + def test_probe_instantiation(self): + probe = Probe( + data = np.arange(8*12).reshape((8,12)) + ) + # add a kernel + probe.kernel = np.ones_like(probe.probe) + # return + return probe + + + # test io + + def test_diffractionslice_io(self): + """ test diffractionslice io + """ + diffractionslice = self.test_diffractionslice_instantiation() + assert(isinstance(diffractionslice,DiffractionSlice)) + # save and read + save(path_h5,diffractionslice,mode='o') + new_diffractionslice = read(path_h5) + # check it's the same + assert(isinstance(new_diffractionslice,DiffractionSlice)) + assert(array_equal(diffractionslice.data,new_diffractionslice.data)) + assert(diffractionslice.slicelabels == new_diffractionslice.slicelabels) + + + def test_realslice_io(self): + """ test realslice io + """ + realslice = self.test_realslice_instantiation() + assert(isinstance(realslice,RealSlice)) + # save and read + save(path_h5,realslice,mode='o') + rs = read(path_h5) + # check it's the same + assert(isinstance(rs,RealSlice)) + assert(array_equal(realslice.data,rs.data)) + assert(rs.slicelabels == realslice.slicelabels) + + def test_virtualdiffraction_io(self): + """ test virtualdiffraction io + """ + virtualdiffraction = self.test_virtualdiffraction_instantiation() + assert(isinstance(virtualdiffraction,VirtualDiffraction)) + # save and read + save(path_h5,virtualdiffraction,mode='o') + vd = 
read(path_h5) + # check it's the same + assert(isinstance(vd,VirtualDiffraction)) + assert(array_equal(vd.data,virtualdiffraction.data)) + pass + + def test_virtualimage_io(self): + """ test virtualimage io + """ + virtualimage = self.test_virtualimage_instantiation() + assert(isinstance(virtualimage,VirtualImage)) + # save and read + save(path_h5,virtualimage,mode='o') + virtIm = read(path_h5) + # check it's the same + assert(isinstance(virtIm,VirtualImage)) + assert(array_equal(virtualimage.data,virtIm.data)) + pass + + + def test_probe1_io(self): + """ test probe io + """ + probe0 = self.test_probe_instantiation() + assert(isinstance(probe0,Probe)) + # save and read + save(path_h5,probe0,mode='o') + probe = read(path_h5) + # check it's the same + assert(isinstance(probe,Probe)) + assert(array_equal(probe0.data,probe.data)) + pass + + + + + +class TestPoints: + + def test_qpoints_instantiation(self): + qpoints = QPoints( + data = np.ones(10, + dtype = [('qx',float),('qy',float),('intensity',float)] + ) + ) + return qpoints + + + def test_qpoints_io(self): + """ test qpoints io + """ + qpoints0 = self.test_qpoints_instantiation() + assert(isinstance(qpoints0,QPoints)) + # save and read + save(path_h5,qpoints0,mode='o') + qpoints = read(path_h5) + # check it's the same + assert(isinstance(qpoints,QPoints)) + assert(array_equal(qpoints0.data,qpoints.data)) + pass + + diff --git a/test/test_native_readwrite/test_v0_13.py b/test/test_native_io/test_v0_13.py similarity index 90% rename from test/test_native_readwrite/test_v0_13.py rename to test/test_native_io/test_v0_13.py index b3fa0328d..92ab41ca2 100644 --- a/test/test_native_readwrite/test_v0_13.py +++ b/test/test_native_io/test_v0_13.py @@ -3,7 +3,7 @@ # Set filepaths -filepath = join(_TESTPATH, "v13_sample.h5") +filepath = join(_TESTPATH, "test_io/legacy_v0.13.h5") diff --git a/test/test_native_readwrite/test_v0_14.py b/test/test_native_io/test_v0_14.py similarity index 95% rename from test/test_native_readwrite/test_v0_14.py rename to test/test_native_io/test_v0_14.py index c6ba27ea8..33230a816 100644 --- a/test/test_native_readwrite/test_v0_14.py +++ b/test/test_native_io/test_v0_14.py @@ -3,7 +3,7 @@ -path = join(py4DSTEM._TESTPATH,'test_v14_sample.h5') +path = join(py4DSTEM._TESTPATH,'test_io/legacy_v0.14.h5') @@ -14,7 +14,7 @@ def _make_v14_test_file(): assert(py4DSTEM.__version__.split('.')[1]=='14'), 'no!' 
# Set filepaths - filepath_data = join(py4DSTEM._TESTPATH,"simulatedAuNanoplatelet_binned_v0_9.h5") + filepath_data = join(py4DSTEM._TESTPATH,"test_io/legacy_v0.9_simAuNanoplatelet_bin.h5") # Read sim Au datacube datacube = py4DSTEM.io.read( diff --git a/test/test_native_readwrite/test_v0_9.py b/test/test_native_io/test_v0_9.py similarity index 77% rename from test/test_native_readwrite/test_v0_9.py rename to test/test_native_io/test_v0_9.py index b3424efa4..8024f646c 100644 --- a/test/test_native_readwrite/test_v0_9.py +++ b/test/test_native_io/test_v0_9.py @@ -1,7 +1,7 @@ from py4DSTEM import read, DataCube, _TESTPATH from os.path import join -path = join(_TESTPATH, 'simulatedAuNanoplatelet_binned_v0_9.h5') +path = join(_TESTPATH,"test_io/legacy_v0.9_simAuNanoplatelet_bin.h5") def test_read_v0_9_noID(): diff --git a/test/test_native_readwrite/test_single_object_io.py b/test/test_native_readwrite/test_single_object_io.py deleted file mode 100644 index 22ee2d0df..000000000 --- a/test/test_native_readwrite/test_single_object_io.py +++ /dev/null @@ -1,316 +0,0 @@ -import numpy as np -from os.path import join,exists -from os import remove -from numpy import array_equal - -import py4DSTEM -from py4DSTEM import save,read -import emdfile as emd - -from py4DSTEM.classes import ( - DataCube, - BraggVectors, - DiffractionSlice, - VirtualDiffraction, - Probe, - RealSlice, - VirtualImage, - QPoints, - Calibration -) - -# Set paths -dirpath = py4DSTEM._TESTPATH -path_dm3 = join(dirpath,"small_dm3.dm3") -path_h5 = join(dirpath,"test.h5") - - - - -class TestSingleDataNodeIO: - - ## Setup and teardown - - @classmethod - def setup_class(cls): - cls._clear_files(cls) - cls._make_data(cls) - - @classmethod - def teardown_class(cls): - pass - - def setup_method(self, method): - pass - - def teardown_method(self, method): - self._clear_files() - - def _make_data(self): - """ Make - - a datacube - - a braggvectors instance with only v_uncal - - a braggvectors instance with both PLAs - - a diffractionslice - - a realslice - - a probe, with no kernel - - a probe, with a kernel - - a qpoints instance - - a virtualdiffraction instance - - a virtualimage instance - """ - # datacube - self.datacube = DataCube( - data = np.arange(np.prod((4,5,6,7))).reshape((4,5,6,7)) - ) - # calibration - self.datacube.calibration.set_Q_pixel_size(0.062) - self.datacube.calibration.set_Q_pixel_units("A^-1") - self.datacube.calibration.set_R_pixel_size(2.8) - self.datacube.calibration.set_R_pixel_units("nm") - # braggvectors - self.braggvectors = BraggVectors( - Rshape = (5,6), - Qshape = (7,8) - ) - for x in range(self.braggvectors.Rshape[0]): - for y in range(self.braggvectors.Rshape[1]): - L = int(4 * (np.sin(x*y)+1)) - self.braggvectors.vectors_uncal[x,y].add( - np.ones(L,dtype=self.braggvectors.vectors_uncal.dtype) - ) - # braggvectors 2 - self.braggvectors2 = BraggVectors( - Rshape = (5,6), - Qshape = (7,8) - ) - for x in range(self.braggvectors2.Rshape[0]): - for y in range(self.braggvectors2.Rshape[1]): - L = int(4 * (np.sin(x*y)+1)) - self.braggvectors2.vectors_uncal[x,y].add( - np.ones(L,dtype=self.braggvectors2.vectors_uncal.dtype) - ) - self.braggvectors2._v_cal = self.braggvectors2._v_uncal.copy(name='_v_uncal') - # diffractionslice - self.diffractionslice = DiffractionSlice( - data = np.arange(np.prod((4,8,2))).reshape((4,8,2)), - slicelabels = ['a','b'] - ) - # realslice - self.realslice = RealSlice( - data = np.arange(np.prod((8,4,2))).reshape((8,4,2)), - slicelabels = ['x','y'] - ) - # virtualdiffraction instance - 
self.virtualdiffraction = VirtualDiffraction( - data = np.arange(np.prod((8,4,2))).reshape((8,4,2)), - ) - # virtualimage instance - self.virtualimage = VirtualImage( - data = np.arange(np.prod((8,4,2))).reshape((8,4,2)), - ) - # probe, with no kernel - self.probe1 = Probe( - data = np.arange(8*12).reshape((8,12)) - ) - # probe, with a kernel - self.probe2 = Probe( - data = np.arange(8*12*2).reshape((8,12,2)) - ) - # qpoints instance - self.qpoints = QPoints( - data = np.ones(10, - dtype = [('qx',float),('qy',float),('intensity',float)] - ) - ) - - - - - - - - - - def _clear_files(self): - """ - Delete h5 files which this test suite wrote - """ - paths = [ - path_h5 - ] - for p in paths: - if exists(p): - remove(p) - - - - - - ## Tests - - def test_datacube(self): - """ Save then read a datacube, and compare its contents before/after - """ - assert(isinstance(self.datacube,DataCube)) - # test dim vectors - assert(self.datacube.dim_names[0] == 'Rx') - assert(self.datacube.dim_names[1] == 'Ry') - assert(self.datacube.dim_names[2] == 'Qx') - assert(self.datacube.dim_names[3] == 'Qy') - assert(self.datacube.dim_units[0] == 'nm') - assert(self.datacube.dim_units[1] == 'nm') - assert(self.datacube.dim_units[2] == 'A^-1') - assert(self.datacube.dim_units[3] == 'A^-1') - assert(self.datacube.dims[0][1] == 2.8) - assert(self.datacube.dims[2][1] == 0.062) - # check the calibrations - assert(self.datacube.calibration.get_Q_pixel_size() == 0.062) - assert(self.datacube.calibration.get_Q_pixel_units() == "A^-1") - # save and read - save(path_h5,self.datacube) - new_datacube = read(path_h5) - # check it's the same - assert(isinstance(new_datacube,DataCube)) - assert(array_equal(self.datacube.data,new_datacube.data)) - assert(new_datacube.calibration.get_Q_pixel_size() == 0.062) - assert(new_datacube.calibration.get_Q_pixel_units() == "A^-1") - assert(new_datacube.dims[0][1] == 2.8) - assert(new_datacube.dims[2][1] == 0.062) - - - - - def test_braggvectors(self): - """ Save then read a BraggVectors instance, and compare contents before/after - """ - assert(isinstance(self.braggvectors,BraggVectors)) - # save then read - save(path_h5,self.braggvectors) - new_braggvectors = read(path_h5) - # check it's the same - assert(isinstance(new_braggvectors,BraggVectors)) - assert(new_braggvectors is not self.braggvectors) - for x in range(new_braggvectors.shape[0]): - for y in range(new_braggvectors.shape[1]): - assert(array_equal( - new_braggvectors.v_uncal[x,y].data, - self.braggvectors.v_uncal[x,y].data)) - # check that _v_cal isn't there - assert(not(hasattr(new_braggvectors,'_v_cal'))) - - - def test_braggvectors2(self): - """ Save then read a BraggVectors instance, and compare contents before/after - """ - assert(isinstance(self.braggvectors2,BraggVectors)) - # save then read - save(path_h5,self.braggvectors2) - new_braggvectors = read(path_h5) - # check it's the same - assert(isinstance(new_braggvectors,BraggVectors)) - assert(new_braggvectors is not self.braggvectors2) - for x in range(new_braggvectors.shape[0]): - for y in range(new_braggvectors.shape[1]): - assert(array_equal( - new_braggvectors.v_uncal[x,y].data, - self.braggvectors2.v_uncal[x,y].data)) - # check cal vectors are there - assert(hasattr(new_braggvectors,'_v_cal')) - # check it's the same - for x in range(new_braggvectors.shape[0]): - for y in range(new_braggvectors.shape[1]): - assert(array_equal( - new_braggvectors.v[x,y].data, - self.braggvectors2.v[x,y].data)) - - - def test_diffractionslice(self): - """ test diffractionslice io - """ - 
assert(isinstance(self.diffractionslice,DiffractionSlice)) - # save and read - save(path_h5,self.diffractionslice) - diffractionslice = read(path_h5) - # check it's the same - assert(isinstance(diffractionslice,DiffractionSlice)) - assert(array_equal(self.diffractionslice.data,diffractionslice.data)) - assert(diffractionslice.slicelabels == self.diffractionslice.slicelabels) - - - def test_realslice(self): - """ test realslice io - """ - assert(isinstance(self.realslice,RealSlice)) - # save and read - save(path_h5,self.realslice) - rs = read(path_h5) - # check it's the same - assert(isinstance(rs,RealSlice)) - assert(array_equal(self.realslice.data,rs.data)) - assert(rs.slicelabels == self.realslice.slicelabels) - - def test_virtualdiffraction(self): - """ test virtualdiffraction io - """ - assert(isinstance(self.virtualdiffraction,VirtualDiffraction)) - # save and read - save(path_h5,self.virtualdiffraction) - virtualdiffraction = read(path_h5) - # check it's the same - assert(isinstance(virtualdiffraction,VirtualDiffraction)) - assert(array_equal(self.virtualdiffraction.data,virtualdiffraction.data)) - pass - - def test_virtualimage(self): - """ test virtualimage io - """ - assert(isinstance(self.virtualimage,VirtualImage)) - # save and read - save(path_h5,self.virtualimage) - virtIm = read(path_h5) - # check it's the same - assert(isinstance(virtIm,VirtualImage)) - assert(array_equal(self.virtualimage.data,virtIm.data)) - pass - - def test_probe1(self): - """ test probe io - """ - assert(isinstance(self.probe1,Probe)) - # check for an empty kernel array - assert(array_equal(self.probe1.kernel,np.zeros_like(self.probe1.probe))) - # save and read - save(path_h5,self.probe1) - probe = read(path_h5) - # check it's the same - assert(isinstance(probe,Probe)) - assert(array_equal(self.probe1.data,probe.data)) - pass - - def test_probe2(self): - """ test probe io - """ - assert(isinstance(self.probe2,Probe)) - # save and read - save(path_h5,self.probe2) - probe = read(path_h5) - # check it's the same - assert(isinstance(probe,Probe)) - assert(array_equal(self.probe2.data,probe.data)) - pass - - - def test_qpoints(self): - """ test qpoints io - """ - assert(isinstance(self.qpoints,QPoints)) - # save and read - save(path_h5,self.qpoints) - qpoints = read(path_h5) - # check it's the same - assert(isinstance(qpoints,QPoints)) - assert(array_equal(self.qpoints.data,qpoints.data)) - pass - - diff --git a/test/test_nonnative_io/test_arina.py b/test/test_nonnative_io/test_arina.py new file mode 100644 index 000000000..c27cb8ef5 --- /dev/null +++ b/test/test_nonnative_io/test_arina.py @@ -0,0 +1,19 @@ +import py4DSTEM +import emdfile +from os.path import join + + +# Set filepaths +filepath = join(py4DSTEM._TESTPATH, "test_arina/STO_STEM_bench_20us_master.h5") + + +def test_read_arina(): + + # read + data = py4DSTEM.import_file( filepath ) + + # check imported data + assert isinstance(data, emdfile.Array) + assert isinstance(data, py4DSTEM.DataCube) + + diff --git a/test/test_nonnative_io/test_dm.py b/test/test_nonnative_io/test_dm.py new file mode 100644 index 000000000..a3d5aa7b0 --- /dev/null +++ b/test/test_nonnative_io/test_dm.py @@ -0,0 +1,24 @@ +import py4DSTEM +import emdfile +from os.path import join + + +# Set filepaths +filepath_dm4_datacube = join(py4DSTEM._TESTPATH, "small_datacube.dm4") +filepath_dm3_3Dstack = join(py4DSTEM._TESTPATH, "test_io/small_dm3_3Dstack.dm3") + + +def test_dmfile_datacube(): + data = py4DSTEM.import_file( filepath_dm4_datacube ) + assert isinstance(data, 
emdfile.Array) + assert isinstance(data, py4DSTEM.DataCube) + +def test_dmfile_3Darray(): + data = py4DSTEM.import_file( filepath_dm3_3Dstack ) + assert isinstance(data, emdfile.Array) + + +# TODO +# def test_dmfile_multiple_datablocks(): +# def test_dmfile_2Darray + diff --git a/test/test_probe.py b/test/test_probe.py new file mode 100644 index 000000000..ad10c3100 --- /dev/null +++ b/test/test_probe.py @@ -0,0 +1,65 @@ +import py4DSTEM +from py4DSTEM import Probe +import numpy as np +from os.path import join + +# set filepath +path = py4DSTEM._TESTPATH + "/small_datacube.dm4" + + + +class TestProbe: + + # setup/teardown + def setup_class(cls): + + # Read datacube + datacube = py4DSTEM.import_file(path) + cls.datacube = datacube + + # tests + + def test_probe_gen_from_dp(self): + p = Probe.from_vacuum_data( + self.datacube[0,0] + ) + assert(isinstance(p,Probe)) + pass + + def test_probe_gen_from_stack(self): + # get a 3D stack + x,y = np.zeros(10).astype(int),np.arange(10).astype(int) + data = self.datacube.data[x,y,:,:] + # get the probe + p = Probe.from_vacuum_data( + data + ) + assert(isinstance(p,Probe)) + pass + + def test_probe_gen_from_datacube_ROI_1(self): + ROI = np.zeros(self.datacube.Rshape,dtype=bool) + ROI[3:7,5:10] = True + p = self.datacube.get_vacuum_probe( ROI ) + assert(isinstance(p,Probe)) + + self.datacube.tree() + self.datacube.tree(True) + _p = self.datacube.tree('probe') + print(_p) + + assert(p is self.datacube.tree('probe')) + pass + + def test_probe_gen_from_datacube_ROI_2(self): + ROI = (3,7,5,10) + p = self.datacube.get_vacuum_probe( ROI ) + assert(isinstance(p,Probe)) + assert(p is self.datacube.tree('probe')) + pass + + + + + + diff --git a/test/test_strain.py b/test/test_strain.py new file mode 100644 index 000000000..bc9b8b58c --- /dev/null +++ b/test/test_strain.py @@ -0,0 +1,33 @@ +import py4DSTEM +from py4DSTEM import StrainMap +from os.path import join +from numpy import zeros + + +# set filepath +path = join(py4DSTEM._TESTPATH,"strain/downsample_Si_SiGe_analysis_braggdisks_cal.h5") + + +class TestStrainMap: + + # setup/teardown + def setup_class(cls): + + # Read braggpeaks + # origin is calibrated + cls.braggpeaks = py4DSTEM.io.read( path ) + + + # tests + + def test_strainmap_instantiation(self): + + strainmap = StrainMap( + braggvectors = self.braggpeaks, + ) + + assert(isinstance(strainmap, StrainMap)) + assert(strainmap.calibration is not None) + assert(strainmap.calibration is strainmap.braggvectors.calibration) + + diff --git a/test/test_workflow/test_basics.py b/test/test_workflow/test_basics.py index 083b15fcb..cb30f5d6c 100644 --- a/test/test_workflow/test_basics.py +++ b/test/test_workflow/test_basics.py @@ -2,7 +2,7 @@ from os.path import join # set filepath -path = join(py4DSTEM._TESTPATH,"simulatedAuNanoplatelet_binned_v0_9.h5") +path = join(py4DSTEM._TESTPATH,"test_io/legacy_v0.9_simAuNanoplatelet_bin.h5") diff --git a/test/test_workflow/test_disk_detection_basic.py b/test/test_workflow/test_disk_detection_basic.py index 18296d0a7..60a5e2aa1 100644 --- a/test/test_workflow/test_disk_detection_basic.py +++ b/test/test_workflow/test_disk_detection_basic.py @@ -3,7 +3,7 @@ from numpy import zeros # set filepath -path = join(py4DSTEM._TESTPATH,"simulatedAuNanoplatelet_binned_v0_9.h5") +path = join(py4DSTEM._TESTPATH,"test_io/legacy_v0.9_simAuNanoplatelet_bin.h5") @@ -23,7 +23,7 @@ def setup_class(cls): mask = zeros(datacube.Rshape,dtype=bool) mask[28:33,14:19] = 1 probe = datacube.get_vacuum_probe( ROI=mask ) - alpha_pr,qx0_pr,qy0_pr = 
py4DSTEM.process.probe.get_probe_size( probe.probe ) + alpha_pr,qx0_pr,qy0_pr = py4DSTEM.process.calibration.get_probe_size( probe.probe ) probe.get_kernel( mode='sigmoid', origin=(qx0_pr,qy0_pr), diff --git a/test/test_workflow/test_disk_detection_with_calibration.py b/test/test_workflow/test_disk_detection_with_calibration.py index 83e52050f..b93e2d00b 100644 --- a/test/test_workflow/test_disk_detection_with_calibration.py +++ b/test/test_workflow/test_disk_detection_with_calibration.py @@ -3,7 +3,7 @@ from numpy import zeros # set filepath -path = join(py4DSTEM._TESTPATH,"simulatedAuNanoplatelet_binned_v0_9.h5") +path = join(py4DSTEM._TESTPATH,"test_io/legacy_v0.9_simAuNanoplatelet_bin.h5") @@ -23,7 +23,7 @@ def setup_class(cls): mask = zeros(datacube.Rshape,dtype=bool) mask[28:33,14:19] = 1 probe = datacube.get_vacuum_probe( ROI=mask ) - alpha_pr,qx0_pr,qy0_pr = py4DSTEM.process.probe.get_probe_size( probe.probe ) + alpha_pr,qx0_pr,qy0_pr = py4DSTEM.process.calibration.get_probe_size( probe.probe ) probe.get_kernel( mode='sigmoid', origin=(qx0_pr,qy0_pr), diff --git a/test/unit_test_data/crystal/LCO.cif b/test/unit_test_data/crystal/LCO.cif deleted file mode 100644 index 8fe2b777e..000000000 --- a/test/unit_test_data/crystal/LCO.cif +++ /dev/null @@ -1,96 +0,0 @@ -#------------------------------------------------------------------------------ -#$Date: 2016-02-13 21:28:24 +0200 (Sat, 13 Feb 2016) $ -#$Revision: 176429 $ -#$URL: svn://www.crystallography.net/cod/cif/1/53/38/1533825.cif $ -#------------------------------------------------------------------------------ -# -# This file is available in the Crystallography Open Database (COD), -# http://www.crystallography.net/ -# -# All data on this site have been placed in the public domain by the -# contributors. 
-# -data_1533825 -loop_ -_publ_author_name -'Hong Jin K' -'Oh Seung M' -_publ_section_title -; - Crystal structure and electrochemical performance of Li Ni1-x Cox O2 (0 - <= x < 1.0) according to Co substitution -; -_journal_name_full -'Journal of the Korean Electrochemical Society' -_journal_page_first 1 -_journal_page_last 5 -_journal_volume 6 -_journal_year 2003 -_chemical_formula_sum 'Co Li O2' -_chemical_name_systematic 'Li (Co O2)' -_space_group_IT_number 166 -_symmetry_space_group_name_Hall '-R 3 2"' -_symmetry_space_group_name_H-M 'R -3 m :H' -_cell_angle_alpha 90 -_cell_angle_beta 90 -_cell_angle_gamma 120 -_cell_formula_units_Z 3 -_cell_length_a 2.92 -_cell_length_b 2.92 -_cell_length_c 14.25 -_cell_volume 96.337 -_citation_journal_id_ASTM JKESFC -_cod_data_source_file HongJinK_JKESFC_2003_777.cif -_cod_data_source_block Co1Li1O2 -_cod_original_cell_volume 96.33704 -_cod_original_formula_sum 'Co1 Li1 O2' -_cod_database_code 1533825 -loop_ -_symmetry_equiv_pos_as_xyz -x,y,z --y,x-y,z --x+y,-x,z -y,x,-z --x,-x+y,-z -x-y,-y,-z --x,-y,-z -y,-x+y,-z -x-y,x,-z --y,-x,z -x,x-y,z --x+y,y,z -x+2/3,y+1/3,z+1/3 --y+2/3,x-y+1/3,z+1/3 --x+y+2/3,-x+1/3,z+1/3 -y+2/3,x+1/3,-z+1/3 --x+2/3,-x+y+1/3,-z+1/3 -x-y+2/3,-y+1/3,-z+1/3 --x+2/3,-y+1/3,-z+1/3 -y+2/3,-x+y+1/3,-z+1/3 -x-y+2/3,x+1/3,-z+1/3 --y+2/3,-x+1/3,z+1/3 -x+2/3,x-y+1/3,z+1/3 --x+y+2/3,y+1/3,z+1/3 -x+1/3,y+2/3,z+2/3 --y+1/3,x-y+2/3,z+2/3 --x+y+1/3,-x+2/3,z+2/3 -y+1/3,x+2/3,-z+2/3 --x+1/3,-x+y+2/3,-z+2/3 -x-y+1/3,-y+2/3,-z+2/3 --x+1/3,-y+2/3,-z+2/3 -y+1/3,-x+y+2/3,-z+2/3 -x-y+1/3,x+2/3,-z+2/3 --y+1/3,-x+2/3,z+2/3 -x+1/3,x-y+2/3,z+2/3 --x+y+1/3,y+2/3,z+2/3 -loop_ -_atom_site_label -_atom_site_type_symbol -_atom_site_fract_x -_atom_site_fract_y -_atom_site_fract_z -_atom_site_occupancy -_atom_site_U_iso_or_equiv -O1 O-2 0 0 0.2268 1 0.0 -Co1 Co+3 0 0 0 1 0.0 -Li1 Li+1 0 0 0.5 1 0.0 diff --git a/test/unit_test_data/crystal/Li2MnO3.cif b/test/unit_test_data/crystal/Li2MnO3.cif deleted file mode 100644 index 7a74b4fae..000000000 --- a/test/unit_test_data/crystal/Li2MnO3.cif +++ /dev/null @@ -1,79 +0,0 @@ -#------------------------------------------------------------------------------ -#$Date: 2015-01-27 21:58:39 +0200 (Tue, 27 Jan 2015) $ -#$Revision: 130149 $ -#$URL: svn://www.crystallography.net/cod/cif/1/00/83/1008373.cif $ -#------------------------------------------------------------------------------ -# -# This file is available in the Crystallography Open Database (COD), -# http://www.crystallography.net/ -# -# All data on this site have been placed in the public domain by the -# contributors. 
-# -data_1008373 -loop_ -_publ_author_name -'Strobel, P' -'Lambert-Andron, B' -_publ_section_title -; -Crystallographic and Magnetic Structure of Li~2~ Mn O~3~ -; -_journal_coden_ASTM JSSCBI -_journal_name_full 'Journal of Solid State Chemistry' -_journal_page_first 90 -_journal_page_last 98 -_journal_paper_doi 10.1016/0022-4596(88)90305-2 -_journal_volume 75 -_journal_year 1988 -_chemical_formula_structural 'Li2 Mn O3' -_chemical_formula_sum 'Li2 Mn O3' -_chemical_name_systematic 'Dilithium manganese(IV) trioxide' -_space_group_IT_number 12 -_symmetry_cell_setting monoclinic -_symmetry_Int_Tables_number 12 -_symmetry_space_group_name_Hall '-C 2y' -_symmetry_space_group_name_H-M 'C 1 2/m 1' -_cell_angle_alpha 90 -_cell_angle_beta 109.46(3) -_cell_angle_gamma 90 -_cell_formula_units_Z 4 -_cell_length_a 5.1 -_cell_length_b 8.5 -_cell_length_c 5.1 -_cell_volume 199.8 -_refine_ls_R_factor_all 0.02 -_cod_database_code 1008373 -loop_ -_symmetry_equiv_pos_as_xyz -x,y,z -x,-y,z --x,-y,-z --x,y,-z -1/2+x,1/2+y,z -1/2+x,1/2-y,z -1/2-x,1/2-y,-z -1/2-x,1/2+y,-z -loop_ -_atom_site_label -_atom_site_type_symbol -_atom_site_symmetry_multiplicity -_atom_site_Wyckoff_symbol -_atom_site_fract_x -_atom_site_fract_y -_atom_site_fract_z -_atom_site_occupancy -_atom_site_attached_hydrogens -_atom_site_calc_flag -Mn1 Mn4+ 4 g 0. 0.16708(2) 0. 1. 0 d -Li1 Li1+ 2 b 0. 0.5 0. 1. 0 d -Li2 Li1+ 2 c 0. 0. 0.5 1. 0 d -Li3 Li1+ 4 h 0. 0.6606(3) 0.5 1. 0 d -O1 O2- 4 i 0.2189(2) 0. 0.2273(2) 1. 0 d -O2 O2- 8 j 0.2540(1) 0.32119(7) 0.2233(1) 1. 0 d -loop_ -_atom_type_symbol -_atom_type_oxidation_number -Mn4+ 4.000 -Li1+ 1.000 -O2- -2.000 diff --git a/test/unit_test_data/crystal/LiMn2O4.cif b/test/unit_test_data/crystal/LiMn2O4.cif deleted file mode 100644 index 81f8db377..000000000 --- a/test/unit_test_data/crystal/LiMn2O4.cif +++ /dev/null @@ -1,277 +0,0 @@ -#------------------------------------------------------------------------------ -#$Date: 2016-02-20 20:10:49 +0200 (Sat, 20 Feb 2016) $ -#$Revision: 176788 $ -#$URL: svn://www.crystallography.net/cod/cif/1/51/39/1513962.cif $ -#------------------------------------------------------------------------------ -# -# This file is available in the Crystallography Open Database (COD), -# http://www.crystallography.net/ -# -# All data on this site have been placed in the public domain by the -# contributors. -# -data_1513962 -loop_ -_publ_author_name -'Mosbah, A.' -'Verbaere, A.' -'Tournoux, M.' -_publ_section_title -; - Phases Lix Mn O2 lambda rattachees au type spinelle -; -_journal_coden_ASTM MRBUAC -_journal_name_full 'Materials Research Bulletin' -_journal_page_first 1375 -_journal_page_last 1381 -_journal_year 1983 -_chemical_formula_structural 'Li Mn2 O4' -_chemical_formula_sum 'Li Mn2 O4' -_chemical_name_systematic 'Lithium Dimanganese Tetraoxide' -_space_group_IT_number 227 -_symmetry_Int_Tables_number 227 -_symmetry_space_group_name_Hall '-F 4vw 2vw 3' -_symmetry_space_group_name_H-M 'F d -3 m :2' -_audit_creation_date 1999-11-30 -_audit_update_record 2005-10-01 -_cell_angle_alpha 90. -_cell_angle_beta 90. -_cell_angle_gamma 90. 
-_cell_formula_units_Z 8 -_cell_length_a 8.245(1) -_cell_length_b 8.245(1) -_cell_length_c 8.245(1) -_cell_volume 560.50(12) -_refine_ls_R_factor_all 0.036 -_cod_data_source_file 'data_LiMn2O4melanie040485.cif' -_cod_data_source_block -/var/www/cod/tmp/uploads/1395678109.95-D290F29E9218B844.cif -_cod_original_cell_volume 560.5 -_cod_original_sg_symbol_H-M 'F d -3 m Z' -_cod_original_formula_sum 'Li1 Mn2 O4' -_cod_database_code 1513962 -loop_ -_symmetry_equiv_pos_site_id -_symmetry_equiv_pos_as_xyz -1 '-z, y+3/4, x+3/4' -2 'z+3/4, -y, x+3/4' -3 'z+3/4, y+3/4, -x' -4 '-z, -y, -x' -5 'y+3/4, x+3/4, -z' -6 '-y, x+3/4, z+3/4' -7 'y+3/4, -x, z+3/4' -8 '-y, -x, -z' -9 'x+3/4, -z, y+3/4' -10 'x+3/4, z+3/4, -y' -11 '-x, z+3/4, y+3/4' -12 '-x, -z, -y' -13 '-z, x+3/4, y+3/4' -14 'z+3/4, x+3/4, -y' -15 'z+3/4, -x, y+3/4' -16 '-z, -x, -y' -17 'y+3/4, -z, x+3/4' -18 '-y, z+3/4, x+3/4' -19 'y+3/4, z+3/4, -x' -20 '-y, -z, -x' -21 'x+3/4, y+3/4, -z' -22 'x+3/4, -y, z+3/4' -23 '-x, y+3/4, z+3/4' -24 '-x, -y, -z' -25 'z, -y+1/4, -x+1/4' -26 '-z+1/4, y, -x+1/4' -27 '-z+1/4, -y+1/4, x' -28 'z, y, x' -29 '-y+1/4, -x+1/4, z' -30 'y, -x+1/4, -z+1/4' -31 '-y+1/4, x, -z+1/4' -32 'y, x, z' -33 '-x+1/4, z, -y+1/4' -34 '-x+1/4, -z+1/4, y' -35 'x, -z+1/4, -y+1/4' -36 'x, z, y' -37 'z, -x+1/4, -y+1/4' -38 '-z+1/4, -x+1/4, y' -39 '-z+1/4, x, -y+1/4' -40 'z, x, y' -41 '-y+1/4, z, -x+1/4' -42 'y, -z+1/4, -x+1/4' -43 '-y+1/4, -z+1/4, x' -44 'y, z, x' -45 '-x+1/4, -y+1/4, z' -46 '-x+1/4, y, -z+1/4' -47 'x, -y+1/4, -z+1/4' -48 'x, y, z' -49 '-z, y+1/4, x+1/4' -50 '-z+1/2, y+3/4, x+1/4' -51 '-z+1/2, y+1/4, x+3/4' -52 'z+3/4, -y+1/2, x+1/4' -53 'z+1/4, -y, x+1/4' -54 'z+1/4, -y+1/2, x+3/4' -55 'z+3/4, y+1/4, -x+1/2' -56 'z+1/4, y+3/4, -x+1/2' -57 'z+1/4, y+1/4, -x' -58 '-z, -y+1/2, -x+1/2' -59 '-z+1/2, -y, -x+1/2' -60 '-z+1/2, -y+1/2, -x' -61 'y+3/4, x+1/4, -z+1/2' -62 'y+1/4, x+3/4, -z+1/2' -63 'y+1/4, x+1/4, -z' -64 '-y, x+1/4, z+1/4' -65 '-y+1/2, x+3/4, z+1/4' -66 '-y+1/2, x+1/4, z+3/4' -67 'y+3/4, -x+1/2, z+1/4' -68 'y+1/4, -x, z+1/4' -69 'y+1/4, -x+1/2, z+3/4' -70 '-y, -x+1/2, -z+1/2' -71 '-y+1/2, -x, -z+1/2' -72 '-y+1/2, -x+1/2, -z' -73 'x+3/4, -z+1/2, y+1/4' -74 'x+1/4, -z, y+1/4' -75 'x+1/4, -z+1/2, y+3/4' -76 'x+3/4, z+1/4, -y+1/2' -77 'x+1/4, z+3/4, -y+1/2' -78 'x+1/4, z+1/4, -y' -79 '-x, z+1/4, y+1/4' -80 '-x+1/2, z+3/4, y+1/4' -81 '-x+1/2, z+1/4, y+3/4' -82 '-x, -z+1/2, -y+1/2' -83 '-x+1/2, -z, -y+1/2' -84 '-x+1/2, -z+1/2, -y' -85 '-z, x+1/4, y+1/4' -86 '-z+1/2, x+3/4, y+1/4' -87 '-z+1/2, x+1/4, y+3/4' -88 'z+3/4, x+1/4, -y+1/2' -89 'z+1/4, x+3/4, -y+1/2' -90 'z+1/4, x+1/4, -y' -91 'z+3/4, -x+1/2, y+1/4' -92 'z+1/4, -x, y+1/4' -93 'z+1/4, -x+1/2, y+3/4' -94 '-z, -x+1/2, -y+1/2' -95 '-z+1/2, -x, -y+1/2' -96 '-z+1/2, -x+1/2, -y' -97 'y+3/4, -z+1/2, x+1/4' -98 'y+1/4, -z, x+1/4' -99 'y+1/4, -z+1/2, x+3/4' -100 '-y, z+1/4, x+1/4' -101 '-y+1/2, z+3/4, x+1/4' -102 '-y+1/2, z+1/4, x+3/4' -103 'y+3/4, z+1/4, -x+1/2' -104 'y+1/4, z+3/4, -x+1/2' -105 'y+1/4, z+1/4, -x' -106 '-y, -z+1/2, -x+1/2' -107 '-y+1/2, -z, -x+1/2' -108 '-y+1/2, -z+1/2, -x' -109 'x+3/4, y+1/4, -z+1/2' -110 'x+1/4, y+3/4, -z+1/2' -111 'x+1/4, y+1/4, -z' -112 'x+3/4, -y+1/2, z+1/4' -113 'x+1/4, -y, z+1/4' -114 'x+1/4, -y+1/2, z+3/4' -115 '-x, y+1/4, z+1/4' -116 '-x+1/2, y+3/4, z+1/4' -117 '-x+1/2, y+1/4, z+3/4' -118 '-x, -y+1/2, -z+1/2' -119 '-x+1/2, -y, -z+1/2' -120 '-x+1/2, -y+1/2, -z' -121 'z, -y+3/4, -x+3/4' -122 'z+1/2, -y+1/4, -x+3/4' -123 'z+1/2, -y+3/4, -x+1/4' -124 '-z+1/4, y+1/2, -x+3/4' -125 '-z+3/4, y, -x+3/4' -126 '-z+3/4, y+1/2, -x+1/4' -127 
'-z+1/4, -y+3/4, x+1/2' -128 '-z+3/4, -y+1/4, x+1/2' -129 '-z+3/4, -y+3/4, x' -130 'z, y+1/2, x+1/2' -131 'z+1/2, y, x+1/2' -132 'z+1/2, y+1/2, x' -133 '-y+1/4, -x+3/4, z+1/2' -134 '-y+3/4, -x+1/4, z+1/2' -135 '-y+3/4, -x+3/4, z' -136 'y, -x+3/4, -z+3/4' -137 'y+1/2, -x+1/4, -z+3/4' -138 'y+1/2, -x+3/4, -z+1/4' -139 '-y+1/4, x+1/2, -z+3/4' -140 '-y+3/4, x, -z+3/4' -141 '-y+3/4, x+1/2, -z+1/4' -142 'y, x+1/2, z+1/2' -143 'y+1/2, x, z+1/2' -144 'y+1/2, x+1/2, z' -145 '-x+1/4, z+1/2, -y+3/4' -146 '-x+3/4, z, -y+3/4' -147 '-x+3/4, z+1/2, -y+1/4' -148 '-x+1/4, -z+3/4, y+1/2' -149 '-x+3/4, -z+1/4, y+1/2' -150 '-x+3/4, -z+3/4, y' -151 'x, -z+3/4, -y+3/4' -152 'x+1/2, -z+1/4, -y+3/4' -153 'x+1/2, -z+3/4, -y+1/4' -154 'x, z+1/2, y+1/2' -155 'x+1/2, z, y+1/2' -156 'x+1/2, z+1/2, y' -157 'z, -x+3/4, -y+3/4' -158 'z+1/2, -x+1/4, -y+3/4' -159 'z+1/2, -x+3/4, -y+1/4' -160 '-z+1/4, -x+3/4, y+1/2' -161 '-z+3/4, -x+1/4, y+1/2' -162 '-z+3/4, -x+3/4, y' -163 '-z+1/4, x+1/2, -y+3/4' -164 '-z+3/4, x, -y+3/4' -165 '-z+3/4, x+1/2, -y+1/4' -166 'z, x+1/2, y+1/2' -167 'z+1/2, x, y+1/2' -168 'z+1/2, x+1/2, y' -169 '-y+1/4, z+1/2, -x+3/4' -170 '-y+3/4, z, -x+3/4' -171 '-y+3/4, z+1/2, -x+1/4' -172 'y, -z+3/4, -x+3/4' -173 'y+1/2, -z+1/4, -x+3/4' -174 'y+1/2, -z+3/4, -x+1/4' -175 '-y+1/4, -z+3/4, x+1/2' -176 '-y+3/4, -z+1/4, x+1/2' -177 '-y+3/4, -z+3/4, x' -178 'y, z+1/2, x+1/2' -179 'y+1/2, z, x+1/2' -180 'y+1/2, z+1/2, x' -181 '-x+1/4, -y+3/4, z+1/2' -182 '-x+3/4, -y+1/4, z+1/2' -183 '-x+3/4, -y+3/4, z' -184 '-x+1/4, y+1/2, -z+3/4' -185 '-x+3/4, y, -z+3/4' -186 '-x+3/4, y+1/2, -z+1/4' -187 'x, -y+3/4, -z+3/4' -188 'x+1/2, -y+1/4, -z+3/4' -189 'x+1/2, -y+3/4, -z+1/4' -190 'x, y+1/2, z+1/2' -191 'x+1/2, y, z+1/2' -192 'x+1/2, y+1/2, z' -loop_ -_atom_site_label -_atom_site_type_symbol -_atom_site_symmetry_multiplicity -_atom_site_Wyckoff_symbol -_atom_site_fract_x -_atom_site_fract_y -_atom_site_fract_z -_atom_site_occupancy -_atom_site_attached_hydrogens -_atom_site_B_iso_or_equiv -Li1 Li1+ 8 a 0.125 0.125 0.125 1. 0 0.5 -Mn1 Mn3.5+ 16 d 0.5 0.5 0.5 1. 0 0.3(2) -O1 O2- 32 e 0.261(2) 0.261(2) 0.261(2) 1. 0 0.4(2) -loop_ -_atom_type_symbol -_atom_type_oxidation_number -Li1+ 1 -Mn3.5+ 3.5 -O2- -2 -loop_ -_citation_id -_citation_journal_full -_citation_year -_citation_journal_volume -_citation_page_first -_citation_page_last -_citation_journal_id_ASTM -primary 'Materials Research Bulletin' 1983 18 1375 1381 MRBUAC -2 'Golden Book of Phase Transitions, Wroclaw' 2002 1 1 123 GBOPT5 diff --git a/test/unit_test_data/crystal/braggdisks_cali.h5 b/test/unit_test_data/crystal/braggdisks_cali.h5 deleted file mode 100644 index 066474fad..000000000 Binary files a/test/unit_test_data/crystal/braggdisks_cali.h5 and /dev/null differ
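
A note on the Bragg virtual imaging changes in this diff: the module-level get_virtual_image_braggvectors helper is removed in favor of the get_virtual_image method on BraggVectors, and disk detection moves into py4DSTEM.braggvectors. A minimal sketch of the post-refactor flow follows; the file path, ROI bounds, kernel radii, and detector geometry are illustrative placeholders, not values taken from this diff.

import numpy as np
import py4DSTEM

# load a 4D datacube and build a probe kernel (path and ROI are placeholders)
datacube = py4DSTEM.import_file("path/to/small_datacube.dm4")
ROI = np.zeros(datacube.Rshape, dtype=bool)
ROI[3:7, 5:10] = True
probe = datacube.get_vacuum_probe(ROI)
alpha, qx0, qy0 = py4DSTEM.process.calibration.get_probe_size(probe.probe)
# the `radii` argument is an assumption; this diff only shows mode and origin
probe.get_kernel(mode='sigmoid', origin=(qx0, qy0), radii=(alpha, 2 * alpha))

# disk detection returns a BraggVectors instance
braggvectors = datacube.find_Bragg_disks(template=probe.kernel)
# an origin calibration is needed before centered virtual imaging
braggvectors.calibration.set_origin((datacube.Qshape[0] / 2, datacube.Qshape[1] / 2))

# annular Bragg virtual image; geometry is ((qx,qy),(radius_i,radius_o)), and
# (qx,qy) may be skipped to assume peaks centered at (0,0)
im = braggvectors.get_virtual_image(mode='annulus', geometry=((0, 0), (2, 6)))
py4DSTEM.show(im)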
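
The one-line change in show.py, swapping plt.cm.get_cmap for plt.get_cmap, tracks Matplotlib's deprecation of the former in version 3.7. The copy-then-set_bad pattern it feeds is sketched below with a placeholder colormap name.

from copy import copy
import matplotlib.pyplot as plt

# copy the registered colormap so set_bad does not mutate the shared instance
cm = copy(plt.get_cmap('gray'))
cm.set_bad(alpha=0)  # masked pixels render fully transparent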
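
The new test/gettestdata.py script is meant to be run from the command line before the test suite. Hypothetical invocations, assuming the repository root as the working directory:

# fetch just the basic test datacube
python test/gettestdata.py basic
# fetch the io test files (small_datacube is always fetched as well);
# re-download files already present and print progress
python test/gettestdata.py io --overwrite --verbose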
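
The save/read round trip that the new test_single_object_io.py exercises class by class reduces to one pattern. A condensed sketch, using an illustrative output path:

import numpy as np
import py4DSTEM

# build and calibrate a small synthetic datacube
datacube = py4DSTEM.DataCube(data=np.arange(4 * 5 * 6 * 7).reshape((4, 5, 6, 7)))
datacube.calibration.set_Q_pixel_size(0.062)
datacube.calibration.set_Q_pixel_units("A^-1")

# mode='o' overwrites an existing file, as in the tests above
py4DSTEM.save("test.h5", datacube, mode='o')
datacube2 = py4DSTEM.read("test.h5")

# data and calibration survive the round trip
assert np.array_equal(datacube.data, datacube2.data)
assert datacube2.calibration.get_Q_pixel_size() == 0.062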