Skip to content

Commit

Permalink
Add tof spotfinding (dials#2602)
Browse files Browse the repository at this point in the history
* Enabled spot finding for time of flight data
  • Loading branch information
toastisme committed Mar 1, 2024
1 parent 6372cee commit 0f4a435
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 7 deletions.
1 change: 1 addition & 0 deletions newsfragments/2602.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `TOFSpotFinder` to tailor default params to time of flight experiments and add additional reflection table data.
57 changes: 56 additions & 1 deletion src/dials/algorithms/centroid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
from __future__ import annotations

from dxtbx.model.tof_helpers import frame_to_tof_interpolator
from scitbx.array_family import flex


def centroid_px_to_mm(detector, scan, position, variance, sd_error):
"""Convenience function to calculate centroid in mm/rad from px"""

# Get the pixel to millimeter function
assert len(detector) == 1
if scan is not None and scan.has_property("time_of_flight"):
return tof_centroid_px_to_mm_panel(
detector[0], scan, position, variance, sd_error
)
return centroid_px_to_mm_panel(detector[0], scan, position, variance, sd_error)


def centroid_px_to_mm_panel(panel, scan, position, variance, sd_error):
"""Convenience function to calculate centroid in mm/rad from px"""
# Get the pixel to millimeter function

if scan is not None and scan.has_property("time_of_flight"):
return tof_centroid_px_to_mm_panel(panel, scan, position, variance, sd_error)

pixel_size = panel.get_pixel_size()
if scan is None:
oscillation = (0, 0)
Expand All @@ -38,7 +49,6 @@ def centroid_px_to_mm_panel(panel, scan, position, variance, sd_error):
sd_error_mm = [sde * s for sde, s in zip(sd_error, scale2)]

else:
from scitbx.array_family import flex

# Convert Pixel coordinate into mm/rad
x, y, z = position.parts()
Expand All @@ -61,3 +71,48 @@ def centroid_px_to_mm_panel(panel, scan, position, variance, sd_error):

# Return the stuff in mm/rad
return position_mm, variance_mm, sd_error_mm


def tof_centroid_px_to_mm_panel(panel, scan, position, variance, sd_error):
"""Convenience function to calculate centroid in mm/tof from px"""
assert scan is not None and scan.has_property("time_of_flight")

pixel_size = panel.get_pixel_size()
tof = scan.get_property("time_of_flight") # (usec)
frames = [i + 1 for i in range(len(tof))]
frame_to_tof = frame_to_tof_interpolator(frames, tof)

if isinstance(position, tuple):
# Convert Pixel coordinate into mm/rad
x, y, z = position
xy_mm = panel.pixel_to_millimeter((x, y))
z_tof = flex.double(frame_to_tof(z))
scale = pixel_size + (z_tof,)
scale2 = tuple(s * s for s in scale)

# Set the position, variance and squared width in mm/tof
# N.B assuming locally flat pixel to millimeter transform
# for variance calculation.
position_mm = xy_mm + (z_tof,)
variance_mm = [var * s for var, s in zip(variance, scale2)]
sd_error_mm = [sde * s for sde, s in zip(sd_error, scale2)]

else:

# Convert Pixel coordinate into mm/tof
x, y, z = position.parts()
xy_mm = panel.pixel_to_millimeter(flex.vec2_double(x, y))
z_tof = flex.double(frame_to_tof(z))
scale = tuple(s * s for s in pixel_size)

# Set the position, variance and squared width in mm/tof
# N.B assuming locally flat pixel to millimeter transform
# for variance calculation.
x_mm, y_mm = xy_mm.parts()
position_mm = flex.vec3_double(x_mm, y_mm, z_tof)
v0, v1, v2 = variance.parts()
variance_mm = flex.vec3_double(v0 * scale[0], v1 * scale[1], v2 * z_tof * z_tof)
s0, s1, s2 = sd_error.parts()
sd_error_mm = flex.vec3_double(s0 * scale[0], s1 * scale[1], s2 * z_tof * z_tof)

return position_mm, variance_mm, sd_error_mm
47 changes: 43 additions & 4 deletions src/dials/algorithms/spot_finding/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import dials.extensions
import dials.util.masking
from dials.algorithms.background.simple import Linear2dModeller
from dials.algorithms.spot_finding.finder import SpotFinder
from dials.algorithms.spot_finding.finder import SpotFinder, TOFSpotFinder
from dials.array_family import flex

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -438,9 +438,6 @@ def from_parameters(params=None, experiments=None, is_stills=False):
mask = SpotFinderFactory.load_image(params.spotfinder.lookup.mask)
params.spotfinder.lookup.mask = mask

# Configure the filter options
filter_spots = SpotFinderFactory.configure_filter(params)

# Create the threshold strategy
threshold_function = SpotFinderFactory.configure_threshold(params)

Expand All @@ -453,6 +450,48 @@ def from_parameters(params=None, experiments=None, is_stills=False):
params.spotfinder.mp.method = None

# Setup the spot finder
contains_tof_experiments = False
for experiment in experiments:
if experiment.scan is None:
continue
if experiment.scan.has_property("time_of_flight"):
contains_tof_experiments = True
elif contains_tof_experiments:
raise RuntimeError("All experiment scans must contain time_of_flight")

if contains_tof_experiments:

# ToF spots from spallation sources typically have elongated tails
if params.spotfinder.filter.max_separation < 6:
# Based on ISISSXD data
# https://zenodo.org/records/4415768
logger.info("Increasing max allowed peak-centroid distance to 6px")
params.spotfinder.filter.max_separation = 6
filter_spots = SpotFinderFactory.configure_filter(params)

return TOFSpotFinder(
experiments=experiments,
threshold_function=threshold_function,
mask=params.spotfinder.lookup.mask,
filter_spots=filter_spots,
scan_range=params.spotfinder.scan_range,
write_hot_mask=params.spotfinder.write_hot_mask,
hot_mask_prefix=params.spotfinder.hot_mask_prefix,
mp_method=params.spotfinder.mp.method,
mp_nproc=params.spotfinder.mp.nproc,
mp_njobs=params.spotfinder.mp.njobs,
mp_chunksize=params.spotfinder.mp.chunksize,
max_strong_pixel_fraction=params.spotfinder.filter.max_strong_pixel_fraction,
compute_mean_background=params.spotfinder.compute_mean_background,
region_of_interest=params.spotfinder.region_of_interest,
mask_generator=mask_generator,
min_spot_size=params.spotfinder.filter.min_spot_size,
max_spot_size=params.spotfinder.filter.max_spot_size,
min_chunksize=params.spotfinder.mp.min_chunksize,
)

filter_spots = SpotFinderFactory.configure_filter(params)

return SpotFinder(
threshold_function=threshold_function,
mask=params.spotfinder.lookup.mask,
Expand Down
113 changes: 111 additions & 2 deletions src/dials/algorithms/spot_finding/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dxtbx.format.image import ImageBool
from dxtbx.imageset import ImageSequence, ImageSet
from dxtbx.model import ExperimentList
from dxtbx.model.tof_helpers import wavelength_from_tof

from dials.array_family import flex
from dials.model.data import PixelList, PixelListLabeller
Expand Down Expand Up @@ -747,10 +748,10 @@ def find_spots(self, experiments: ExperimentList) -> flex.reflection_table:
flex.size_t_range(len(reflections)), reflections.flags.strong
)

# Check for overloads
reflections.is_overloaded(experiments)

# Return the reflections
reflections = self._post_process(reflections)

return reflections

def _find_spots_in_imageset(self, imageset):
Expand Down Expand Up @@ -855,3 +856,111 @@ def _create_hot_mask(self, imageset, hot_pixels):

# Return the hot mask
return hot_mask

def _post_process(self, reflections):
return reflections


class TOFSpotFinder(SpotFinder):
"""
Class to do spot finding tailored to time of flight experiments
"""

def __init__(
self,
experiments,
threshold_function=None,
mask=None,
region_of_interest=None,
max_strong_pixel_fraction=0.1,
compute_mean_background=False,
mp_method=None,
mp_nproc=1,
mp_njobs=1,
mp_chunksize=1,
mask_generator=None,
filter_spots=None,
scan_range=None,
write_hot_mask=True,
hot_mask_prefix="hot_mask",
min_spot_size=1,
max_spot_size=20,
min_chunksize=50,
):

super().__init__(
threshold_function=threshold_function,
mask=mask,
region_of_interest=region_of_interest,
max_strong_pixel_fraction=max_strong_pixel_fraction,
compute_mean_background=compute_mean_background,
mp_method=mp_method,
mp_nproc=mp_nproc,
mp_njobs=mp_njobs,
mp_chunksize=mp_chunksize,
mask_generator=mask_generator,
filter_spots=filter_spots,
scan_range=scan_range,
write_hot_mask=write_hot_mask,
hot_mask_prefix=hot_mask_prefix,
min_spot_size=min_spot_size,
max_spot_size=max_spot_size,
no_shoeboxes_2d=False,
min_chunksize=min_chunksize,
is_stills=False,
)

self.experiments = experiments

def _correct_centroid_tof(self, reflections):

"""
Sets the centroid of the spot to the peak position along the
time of flight, as this tends to more accurately represent the true
centroid for spallation sources.
"""

x, y, tof = reflections["xyzobs.px.value"].parts()
peak_x, peak_y, peak_tof = reflections["shoebox"].peak_coordinates().parts()
reflections["xyzobs.px.value"] = flex.vec3_double(x, y, peak_tof)

return reflections

def _post_process(self, reflections):

reflections = self._correct_centroid_tof(reflections)

n_rows = reflections.nrows()
panel_numbers = flex.size_t(reflections["panel"])
reflections["L1"] = flex.double(n_rows)
reflections["wavelength"] = flex.double(n_rows)
reflections["s0"] = flex.vec3_double(n_rows)
reflections.centroid_px_to_mm(self.experiments)

for i, expt in enumerate(self.experiments):
if "imageset_id" in reflections:
sel_expt = reflections["imageset_id"] == i
else:
sel_expt = reflections["id"] == i

L0 = expt.beam.get_sample_to_source_distance() * 10**-3 # (m)
unit_s0 = expt.beam.get_unit_s0()

for i_panel in range(len(expt.detector)):

sel = sel_expt & (panel_numbers == i_panel)
x, y, tof = reflections["xyzobs.mm.value"].select(sel).parts()
px, py, frame = reflections["xyzobs.px.value"].select(sel).parts()
s1 = expt.detector[i_panel].get_lab_coord(flex.vec2_double(x, y))
L1 = s1.norms()
wavelengths = wavelength_from_tof(L0 + L1 * 10**-3, tof * 10**-6)
s0s = flex.vec3_double(
unit_s0[0] / wavelengths,
unit_s0[1] / wavelengths,
unit_s0[2] / wavelengths,
)

reflections["wavelength"].set_selected(sel, wavelengths)
reflections["s0"].set_selected(sel, s0s)
reflections["L1"].set_selected(sel, L1)
return reflections

0 comments on commit 0f4a435

Please sign in to comment.