Skip to content

Commit

Permalink
Added _post_process method to SpotFinder. Added TOFSpotFinder. Added …
Browse files Browse the repository at this point in the history
…tof_centroid_px_to_mm_panel.
  • Loading branch information
toastisme committed Feb 6, 2024
1 parent bf73681 commit c90d6fc
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 7 deletions.
59 changes: 58 additions & 1 deletion src/dials/algorithms/centroid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
from __future__ import annotations

from dxtbx.model.tof_helpers import frame_to_tof_interpolator
from scitbx.array_family import flex


def centroid_px_to_mm(detector, scan, position, variance, sd_error):
"""Convenience function to calculate centroid in mm/rad from px"""

# Get the pixel to millimeter function
assert len(detector) == 1
if scan.has_property("time_of_flight"):
return tof_centroid_px_to_mm_panel(
detector[0], scan, position, variance, sd_error
)
return centroid_px_to_mm_panel(detector[0], scan, position, variance, sd_error)


def centroid_px_to_mm_panel(panel, scan, position, variance, sd_error):
"""Convenience function to calculate centroid in mm/rad from px"""
# Get the pixel to millimeter function

if scan.has_property("time_of_flight"):
return tof_centroid_px_to_mm_panel(panel, scan, position, variance, sd_error)

pixel_size = panel.get_pixel_size()
if scan is None:
oscillation = (0, 0)
Expand All @@ -38,7 +49,6 @@ def centroid_px_to_mm_panel(panel, scan, position, variance, sd_error):
sd_error_mm = [sde * s for sde, s in zip(sd_error, scale2)]

else:
from scitbx.array_family import flex

# Convert Pixel coordinate into mm/rad
x, y, z = position.parts()
Expand All @@ -61,3 +71,50 @@ def centroid_px_to_mm_panel(panel, scan, position, variance, sd_error):

# Return the stuff in mm/rad
return position_mm, variance_mm, sd_error_mm


def tof_centroid_px_to_mm_panel(panel, scan, position, variance, sd_error):
"""Convenience function to calculate centroid in mm/tof from px"""
assert scan.has_property("time_of_flight")

pixel_size = panel.get_pixel_size()
tof = scan.get_property("time_of_flight") # (usec)
frames = [i + 1 for i in range(len(tof))]
frame_to_tof = frame_to_tof_interpolator(frames, tof)
scale = pixel_size + (tof[0],)
scale2 = tuple(s * s for s in scale)

if isinstance(position, tuple):
# Convert Pixel coordinate into mm/rad
x, y, z = position
xy_mm = panel.pixel_to_millimeter((x, y))
z_tof = flex.double(frame_to_tof(z))
scale = pixel_size + (z_tof,)
scale2 = tuple(s * s for s in scale)

# Set the position, variance and squared width in mm/tof
# N.B assuming locally flat pixel to millimeter transform
# for variance calculation.
position_mm = xy_mm + (z_tof,)
variance_mm = [var * s for var, s in zip(variance, scale2)]
sd_error_mm = [sde * s for sde, s in zip(sd_error, scale2)]

else:

# Convert Pixel coordinate into mm/tof
x, y, z = position.parts()
xy_mm = panel.pixel_to_millimeter(flex.vec2_double(x, y))
z_tof = flex.double(frame_to_tof(z))
scale = tuple(s * s for s in pixel_size)

# Set the position, variance and squared width in mm/tof
# N.B assuming locally flat pixel to millimeter transform
# for variance calculation.
x_mm, y_mm = xy_mm.parts()
position_mm = flex.vec3_double(x_mm, y_mm, z_tof)
v0, v1, v2 = variance.parts()
variance_mm = flex.vec3_double(v0 * scale[0], v1 * scale[1], v2 * z_tof * z_tof)
s0, s1, s2 = sd_error.parts()
sd_error_mm = flex.vec3_double(s0 * scale[0], s1 * scale[1], s2 * z_tof * z_tof)

return position_mm, variance_mm, sd_error_mm
45 changes: 41 additions & 4 deletions src/dials/algorithms/spot_finding/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import dials.extensions
import dials.util.masking
from dials.algorithms.background.simple import Linear2dModeller
from dials.algorithms.spot_finding.finder import SpotFinder
from dials.algorithms.spot_finding.finder import SpotFinder, TOFSpotFinder
from dials.array_family import flex

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -438,9 +438,6 @@ def from_parameters(params=None, experiments=None, is_stills=False):
mask = SpotFinderFactory.load_image(params.spotfinder.lookup.mask)
params.spotfinder.lookup.mask = mask

# Configure the filter options
filter_spots = SpotFinderFactory.configure_filter(params)

# Create the threshold strategy
threshold_function = SpotFinderFactory.configure_threshold(params)

Expand All @@ -453,6 +450,46 @@ def from_parameters(params=None, experiments=None, is_stills=False):
params.spotfinder.mp.method = None

# Setup the spot finder
contains_tof_experiments = False
for experiment in experiments:
if experiment.scan.has_property("time_of_flight"):
contains_tof_experiments = True
elif contains_tof_experiments:
raise RuntimeError("All experiment scans must contain time_of_flight")

if contains_tof_experiments:

# ToF spots from spallation sources typically have elongated tails
if params.spotfinder.filter.max_separation < 6:
# Based on ISISSXD data
# https://zenodo.org/records/4415768
logger.info("Increasing max allowed peak-centroid distance to 6px")
params.spotfinder.filter.max_separation = 6
filter_spots = SpotFinderFactory.configure_filter(params)

return TOFSpotFinder(
experiments=experiments,
threshold_function=threshold_function,
mask=params.spotfinder.lookup.mask,
filter_spots=filter_spots,
scan_range=params.spotfinder.scan_range,
write_hot_mask=params.spotfinder.write_hot_mask,
hot_mask_prefix=params.spotfinder.hot_mask_prefix,
mp_method=params.spotfinder.mp.method,
mp_nproc=params.spotfinder.mp.nproc,
mp_njobs=params.spotfinder.mp.njobs,
mp_chunksize=params.spotfinder.mp.chunksize,
max_strong_pixel_fraction=params.spotfinder.filter.max_strong_pixel_fraction,
compute_mean_background=params.spotfinder.compute_mean_background,
region_of_interest=params.spotfinder.region_of_interest,
mask_generator=mask_generator,
min_spot_size=params.spotfinder.filter.min_spot_size,
max_spot_size=params.spotfinder.filter.max_spot_size,
min_chunksize=params.spotfinder.mp.min_chunksize,
)

filter_spots = SpotFinderFactory.configure_filter(params)

return SpotFinder(
threshold_function=threshold_function,
mask=params.spotfinder.lookup.mask,
Expand Down
97 changes: 95 additions & 2 deletions src/dials/algorithms/spot_finding/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dxtbx.format.image import ImageBool
from dxtbx.imageset import ImageSequence, ImageSet
from dxtbx.model import ExperimentList
from dxtbx.model.tof_helpers import wavelength_from_tof

from dials.array_family import flex
from dials.model.data import PixelList, PixelListLabeller
Expand Down Expand Up @@ -746,10 +747,10 @@ def find_spots(self, experiments: ExperimentList) -> flex.reflection_table:
flex.size_t_range(len(reflections)), reflections.flags.strong
)

# Check for overloads
reflections.is_overloaded(experiments)

# Return the reflections
reflections = self._post_process(reflections)

return reflections

def _find_spots_in_imageset(self, imageset):
Expand Down Expand Up @@ -854,3 +855,95 @@ def _create_hot_mask(self, imageset, hot_pixels):

# Return the hot mask
return hot_mask

def _post_process(self, reflections):
return reflections


class TOFSpotFinder(SpotFinder):
"""
Class to do spot finding tailored to time of flight experiments
"""

def __init__(
self,
experiments,
threshold_function=None,
mask=None,
region_of_interest=None,
max_strong_pixel_fraction=0.1,
compute_mean_background=False,
mp_method=None,
mp_nproc=1,
mp_njobs=1,
mp_chunksize=1,
mask_generator=None,
filter_spots=None,
scan_range=None,
write_hot_mask=True,
hot_mask_prefix="hot_mask",
min_spot_size=1,
max_spot_size=20,
min_chunksize=50,
):

super().__init__(
threshold_function=threshold_function,
mask=mask,
region_of_interest=region_of_interest,
max_strong_pixel_fraction=max_strong_pixel_fraction,
compute_mean_background=compute_mean_background,
mp_method=mp_method,
mp_nproc=mp_nproc,
mp_njobs=mp_njobs,
mp_chunksize=mp_chunksize,
mask_generator=mask_generator,
filter_spots=filter_spots,
scan_range=scan_range,
write_hot_mask=write_hot_mask,
hot_mask_prefix=hot_mask_prefix,
min_spot_size=min_spot_size,
max_spot_size=max_spot_size,
no_shoeboxes_2d=False,
min_chunksize=min_chunksize,
is_stills=False,
)

self.experiments = experiments

def _post_process(self, reflections):

n_rows = reflections.nrows()
panel_numbers = flex.size_t(reflections["panel"])
reflections["L1"] = flex.double(n_rows)
reflections["wavelength"] = flex.double(n_rows)
reflections["s0"] = flex.vec3_double(n_rows)
reflections.centroid_px_to_mm(self.experiments)

for i, expt in enumerate(self.experiments):
if "imageset_id" in reflections:
sel_expt = reflections["imageset_id"] == i
else:
sel_expt = reflections["id"] == i

L0 = expt.beam.get_sample_to_source_distance() * 10**-3 # (m)
unit_s0 = expt.beam.get_unit_s0()

for i_panel in range(len(expt.detector)):

sel = sel_expt & (panel_numbers == i_panel)
x, y, tof = reflections["xyzobs.mm.value"].select(sel).parts()
px, py, frame = reflections["xyzobs.px.value"].select(sel).parts()
s1 = expt.detector[i_panel].get_lab_coord(flex.vec2_double(x, y))
L1 = s1.norms()
wavelengths = wavelength_from_tof(L0 + L1 * 10**-3, tof * 10**-6)
s0s = flex.vec3_double(
unit_s0[0] / wavelengths,
unit_s0[1] / wavelengths,
unit_s0[2] / wavelengths,
)

reflections["wavelength"].set_selected(sel, wavelengths)
reflections["s0"].set_selected(sel, s0s)
reflections["L1"].set_selected(sel, L1)
return reflections

0 comments on commit c90d6fc

Please sign in to comment.