From c90d6fc97c356431a9229b89e1f09683338db1df Mon Sep 17 00:00:00 2001 From: davidmcdonagh Date: Tue, 6 Feb 2024 11:50:35 +0000 Subject: [PATCH] Added _post_process method to SpotFinder. Added TOFSpotFinder. Added tof_centroid_px_to_mm_panel. --- src/dials/algorithms/centroid/__init__.py | 59 +++++++++++- src/dials/algorithms/spot_finding/factory.py | 45 ++++++++- src/dials/algorithms/spot_finding/finder.py | 97 +++++++++++++++++++- 3 files changed, 194 insertions(+), 7 deletions(-) diff --git a/src/dials/algorithms/centroid/__init__.py b/src/dials/algorithms/centroid/__init__.py index 1ddb10cacc..4ce44936c9 100644 --- a/src/dials/algorithms/centroid/__init__.py +++ b/src/dials/algorithms/centroid/__init__.py @@ -1,17 +1,28 @@ from __future__ import annotations +from dxtbx.model.tof_helpers import frame_to_tof_interpolator +from scitbx.array_family import flex + def centroid_px_to_mm(detector, scan, position, variance, sd_error): """Convenience function to calculate centroid in mm/rad from px""" # Get the pixel to millimeter function assert len(detector) == 1 + if scan.has_property("time_of_flight"): + return tof_centroid_px_to_mm_panel( + detector[0], scan, position, variance, sd_error + ) return centroid_px_to_mm_panel(detector[0], scan, position, variance, sd_error) def centroid_px_to_mm_panel(panel, scan, position, variance, sd_error): """Convenience function to calculate centroid in mm/rad from px""" # Get the pixel to millimeter function + + if scan.has_property("time_of_flight"): + return tof_centroid_px_to_mm_panel(panel, scan, position, variance, sd_error) + pixel_size = panel.get_pixel_size() if scan is None: oscillation = (0, 0) @@ -38,7 +49,6 @@ def centroid_px_to_mm_panel(panel, scan, position, variance, sd_error): sd_error_mm = [sde * s for sde, s in zip(sd_error, scale2)] else: - from scitbx.array_family import flex # Convert Pixel coordinate into mm/rad x, y, z = position.parts() @@ -61,3 +71,50 @@ def centroid_px_to_mm_panel(panel, scan, position, variance, sd_error): # Return the stuff in mm/rad return position_mm, variance_mm, sd_error_mm + + +def tof_centroid_px_to_mm_panel(panel, scan, position, variance, sd_error): + """Convenience function to calculate centroid in mm/tof from px""" + assert scan.has_property("time_of_flight") + + pixel_size = panel.get_pixel_size() + tof = scan.get_property("time_of_flight") # (usec) + frames = [i + 1 for i in range(len(tof))] + frame_to_tof = frame_to_tof_interpolator(frames, tof) + scale = pixel_size + (tof[0],) + scale2 = tuple(s * s for s in scale) + + if isinstance(position, tuple): + # Convert Pixel coordinate into mm/rad + x, y, z = position + xy_mm = panel.pixel_to_millimeter((x, y)) + z_tof = flex.double(frame_to_tof(z)) + scale = pixel_size + (z_tof,) + scale2 = tuple(s * s for s in scale) + + # Set the position, variance and squared width in mm/tof + # N.B assuming locally flat pixel to millimeter transform + # for variance calculation. + position_mm = xy_mm + (z_tof,) + variance_mm = [var * s for var, s in zip(variance, scale2)] + sd_error_mm = [sde * s for sde, s in zip(sd_error, scale2)] + + else: + + # Convert Pixel coordinate into mm/tof + x, y, z = position.parts() + xy_mm = panel.pixel_to_millimeter(flex.vec2_double(x, y)) + z_tof = flex.double(frame_to_tof(z)) + scale = tuple(s * s for s in pixel_size) + + # Set the position, variance and squared width in mm/tof + # N.B assuming locally flat pixel to millimeter transform + # for variance calculation. + x_mm, y_mm = xy_mm.parts() + position_mm = flex.vec3_double(x_mm, y_mm, z_tof) + v0, v1, v2 = variance.parts() + variance_mm = flex.vec3_double(v0 * scale[0], v1 * scale[1], v2 * z_tof * z_tof) + s0, s1, s2 = sd_error.parts() + sd_error_mm = flex.vec3_double(s0 * scale[0], s1 * scale[1], s2 * z_tof * z_tof) + + return position_mm, variance_mm, sd_error_mm diff --git a/src/dials/algorithms/spot_finding/factory.py b/src/dials/algorithms/spot_finding/factory.py index f3a3234cce..7cf8400afc 100644 --- a/src/dials/algorithms/spot_finding/factory.py +++ b/src/dials/algorithms/spot_finding/factory.py @@ -13,7 +13,7 @@ import dials.extensions import dials.util.masking from dials.algorithms.background.simple import Linear2dModeller -from dials.algorithms.spot_finding.finder import SpotFinder +from dials.algorithms.spot_finding.finder import SpotFinder, TOFSpotFinder from dials.array_family import flex logger = logging.getLogger(__name__) @@ -438,9 +438,6 @@ def from_parameters(params=None, experiments=None, is_stills=False): mask = SpotFinderFactory.load_image(params.spotfinder.lookup.mask) params.spotfinder.lookup.mask = mask - # Configure the filter options - filter_spots = SpotFinderFactory.configure_filter(params) - # Create the threshold strategy threshold_function = SpotFinderFactory.configure_threshold(params) @@ -453,6 +450,46 @@ def from_parameters(params=None, experiments=None, is_stills=False): params.spotfinder.mp.method = None # Setup the spot finder + contains_tof_experiments = False + for experiment in experiments: + if experiment.scan.has_property("time_of_flight"): + contains_tof_experiments = True + elif contains_tof_experiments: + raise RuntimeError("All experiment scans must contain time_of_flight") + + if contains_tof_experiments: + + # ToF spots from spallation sources typically have elongated tails + if params.spotfinder.filter.max_separation < 6: + # Based on ISISSXD data + # https://zenodo.org/records/4415768 + logger.info("Increasing max allowed peak-centroid distance to 6px") + params.spotfinder.filter.max_separation = 6 + filter_spots = SpotFinderFactory.configure_filter(params) + + return TOFSpotFinder( + experiments=experiments, + threshold_function=threshold_function, + mask=params.spotfinder.lookup.mask, + filter_spots=filter_spots, + scan_range=params.spotfinder.scan_range, + write_hot_mask=params.spotfinder.write_hot_mask, + hot_mask_prefix=params.spotfinder.hot_mask_prefix, + mp_method=params.spotfinder.mp.method, + mp_nproc=params.spotfinder.mp.nproc, + mp_njobs=params.spotfinder.mp.njobs, + mp_chunksize=params.spotfinder.mp.chunksize, + max_strong_pixel_fraction=params.spotfinder.filter.max_strong_pixel_fraction, + compute_mean_background=params.spotfinder.compute_mean_background, + region_of_interest=params.spotfinder.region_of_interest, + mask_generator=mask_generator, + min_spot_size=params.spotfinder.filter.min_spot_size, + max_spot_size=params.spotfinder.filter.max_spot_size, + min_chunksize=params.spotfinder.mp.min_chunksize, + ) + + filter_spots = SpotFinderFactory.configure_filter(params) + return SpotFinder( threshold_function=threshold_function, mask=params.spotfinder.lookup.mask, diff --git a/src/dials/algorithms/spot_finding/finder.py b/src/dials/algorithms/spot_finding/finder.py index e02afd40fc..c2a284b860 100644 --- a/src/dials/algorithms/spot_finding/finder.py +++ b/src/dials/algorithms/spot_finding/finder.py @@ -13,6 +13,7 @@ from dxtbx.format.image import ImageBool from dxtbx.imageset import ImageSequence, ImageSet from dxtbx.model import ExperimentList +from dxtbx.model.tof_helpers import wavelength_from_tof from dials.array_family import flex from dials.model.data import PixelList, PixelListLabeller @@ -746,10 +747,10 @@ def find_spots(self, experiments: ExperimentList) -> flex.reflection_table: flex.size_t_range(len(reflections)), reflections.flags.strong ) - # Check for overloads reflections.is_overloaded(experiments) - # Return the reflections + reflections = self._post_process(reflections) + return reflections def _find_spots_in_imageset(self, imageset): @@ -854,3 +855,95 @@ def _create_hot_mask(self, imageset, hot_pixels): # Return the hot mask return hot_mask + + def _post_process(self, reflections): + return reflections + + +class TOFSpotFinder(SpotFinder): + """ + Class to do spot finding tailored to time of flight experiments + """ + + def __init__( + self, + experiments, + threshold_function=None, + mask=None, + region_of_interest=None, + max_strong_pixel_fraction=0.1, + compute_mean_background=False, + mp_method=None, + mp_nproc=1, + mp_njobs=1, + mp_chunksize=1, + mask_generator=None, + filter_spots=None, + scan_range=None, + write_hot_mask=True, + hot_mask_prefix="hot_mask", + min_spot_size=1, + max_spot_size=20, + min_chunksize=50, + ): + + super().__init__( + threshold_function=threshold_function, + mask=mask, + region_of_interest=region_of_interest, + max_strong_pixel_fraction=max_strong_pixel_fraction, + compute_mean_background=compute_mean_background, + mp_method=mp_method, + mp_nproc=mp_nproc, + mp_njobs=mp_njobs, + mp_chunksize=mp_chunksize, + mask_generator=mask_generator, + filter_spots=filter_spots, + scan_range=scan_range, + write_hot_mask=write_hot_mask, + hot_mask_prefix=hot_mask_prefix, + min_spot_size=min_spot_size, + max_spot_size=max_spot_size, + no_shoeboxes_2d=False, + min_chunksize=min_chunksize, + is_stills=False, + ) + + self.experiments = experiments + + def _post_process(self, reflections): + + n_rows = reflections.nrows() + panel_numbers = flex.size_t(reflections["panel"]) + reflections["L1"] = flex.double(n_rows) + reflections["wavelength"] = flex.double(n_rows) + reflections["s0"] = flex.vec3_double(n_rows) + reflections.centroid_px_to_mm(self.experiments) + + for i, expt in enumerate(self.experiments): + if "imageset_id" in reflections: + sel_expt = reflections["imageset_id"] == i + else: + sel_expt = reflections["id"] == i + + L0 = expt.beam.get_sample_to_source_distance() * 10**-3 # (m) + unit_s0 = expt.beam.get_unit_s0() + + for i_panel in range(len(expt.detector)): + + sel = sel_expt & (panel_numbers == i_panel) + x, y, tof = reflections["xyzobs.mm.value"].select(sel).parts() + px, py, frame = reflections["xyzobs.px.value"].select(sel).parts() + s1 = expt.detector[i_panel].get_lab_coord(flex.vec2_double(x, y)) + L1 = s1.norms() + wavelengths = wavelength_from_tof(L0 + L1 * 10**-3, tof * 10**-6) + s0s = flex.vec3_double( + unit_s0[0] / wavelengths, + unit_s0[1] / wavelengths, + unit_s0[2] / wavelengths, + ) + + reflections["wavelength"].set_selected(sel, wavelengths) + reflections["s0"].set_selected(sel, s0s) + reflections["L1"].set_selected(sel, L1) + return reflections