From 72263417ebf43c0be410dbe15866b02e0a99916c Mon Sep 17 00:00:00 2001 From: davidmcdonagh Date: Wed, 31 Jul 2024 22:32:05 +0100 Subject: [PATCH] Added tof integration methods. --- .../profile_model/gaussian_rs/__init__.py | 4 + .../gaussian_rs/bbox_calculator.h | 133 ++++++++++++++++++ .../boost_python/gaussian_rs_ext.cc | 30 ++++ .../gaussian_rs/coordinate_system.h | 123 ++++++++++++++++ .../profile_model/gaussian_rs/model.py | 18 ++- .../scaling/tof/boost_python/tof_scaling.cc | 15 ++ .../algorithms/spot_prediction/__init__.py | 103 ++++++++++++++ .../spot_prediction/reflection_predictor.py | 10 ++ src/dials/array_family/flex_ext.py | 2 +- 9 files changed, 433 insertions(+), 5 deletions(-) diff --git a/src/dials/algorithms/profile_model/gaussian_rs/__init__.py b/src/dials/algorithms/profile_model/gaussian_rs/__init__.py index 1002d80da4..deef97b91d 100644 --- a/src/dials/algorithms/profile_model/gaussian_rs/__init__.py +++ b/src/dials/algorithms/profile_model/gaussian_rs/__init__.py @@ -6,6 +6,7 @@ BBoxCalculator2D, BBoxCalculator3D, BBoxCalculatorIface, + BBoxCalculatorTOF, BBoxMultiCalculator, CoordinateSystem, CoordinateSystem2d, @@ -27,6 +28,7 @@ "BBoxCalculator", "BBoxCalculator2D", "BBoxCalculator3D", + "BBoxCalculatorTOF", "BBoxCalculatorIface", "BBoxMultiCalculator", "CoordinateSystem", @@ -54,6 +56,8 @@ def BBoxCalculator(crystal, beam, detector, goniometer, scan, delta_b, delta_m): """Return the relevant bbox calculator.""" if goniometer is None or scan is None or scan.is_still(): algorithm = BBoxCalculator2D(beam, detector, delta_b, delta_m) + elif scan.has_property("time_of_flight"): + algorithm = BBoxCalculatorTOF(beam, detector, scan, delta_b, delta_m) else: algorithm = BBoxCalculator3D(beam, detector, goniometer, scan, delta_b, delta_m) return algorithm diff --git a/src/dials/algorithms/profile_model/gaussian_rs/bbox_calculator.h b/src/dials/algorithms/profile_model/gaussian_rs/bbox_calculator.h index f6d8fe1486..3087950fb9 100644 --- a/src/dials/algorithms/profile_model/gaussian_rs/bbox_calculator.h +++ b/src/dials/algorithms/profile_model/gaussian_rs/bbox_calculator.h @@ -33,6 +33,7 @@ namespace dials { using dxtbx::model::BeamBase; using dxtbx::model::Detector; using dxtbx::model::Goniometer; + using dxtbx::model::PolychromaticBeam; using dxtbx::model::Scan; using scitbx::vec2; using scitbx::vec3; @@ -394,6 +395,138 @@ namespace dials { private: std::vector > compute_; }; + class BBoxCalculatorTOF { + public: + /** + * Initialise the bounding box calculation. + * @param beam The beam parameters + * @param detector The detector parameters + * @param delta_divergence The xds delta_divergence parameter + * @param delta_mosaicity The xds delta_mosaicity parameter + */ + BBoxCalculatorTOF(const PolychromaticBeam &beam, + const Detector &detector, + const Scan &scan, + double delta_divergence, + double delta_mosaicity) + : detector_(detector), + scan_(scan), + beam_(beam), + delta_divergence_(delta_divergence), + delta_mosaicity_(delta_mosaicity) { + DIALS_ASSERT(delta_divergence > 0.0); + DIALS_ASSERT(delta_mosaicity >= 0.0); + } + + /** + * Calculate the bbox on the detector image volume for the reflection. + * + * The roi is calculated using the parameters delta_divergence and + * delta_mosaicity. The reflection mask comprises all pixels where: + * |e1| <= delta_d, |e2| <= delta_d, |e3| <= delta_m + * + * We transform the coordinates of the box + * (-delta_d, -delta_d, 0) + * (+delta_d, -delta_d, 0) + * (-delta_d, +delta_d, 0) + * (+delta_d, +delta_d, 0) + * + * to the detector image volume and return the minimum and maximum values + * for the x, y, z image volume coordinates. + * + * @param s1 The diffracted beam vector + * @param frame The predicted frame number + * @returns A 6 element array: (minx, maxx, miny, maxy, minz, maxz) + */ + virtual int6 single(vec3 s0, + vec3 s1, + double frame, + double L1, + std::size_t panel) const { + // Ensure our values are ok + DIALS_ASSERT(s1.length_sq() > 0); + + // Create the coordinate system for the reflection + CoordinateSystemTOF xcs(s0, s1, L1); + + // Get the divergence and mosaicity for this point + double delta_d = delta_divergence_; + double delta_m = delta_mosaicity_; + + // Calculate the beam vectors at the following xds coordinates: + // (-delta_d, -delta_d, 0) + // (+delta_d, -delta_d, 0) + // (-delta_d, +delta_d, 0) + // (+delta_d, +delta_d, 0) + double point = delta_d; + double3 sdash1 = xcs.to_beam_vector(double2(-point, -point)); + double3 sdash2 = xcs.to_beam_vector(double2(+point, -point)); + double3 sdash3 = xcs.to_beam_vector(double2(-point, +point)); + double3 sdash4 = xcs.to_beam_vector(double2(+point, +point)); + + // Get the detector coordinates (px) at the ray intersections + double2 xy1 = detector_[panel].get_ray_intersection_px(sdash1); + double2 xy2 = detector_[panel].get_ray_intersection_px(sdash2); + double2 xy3 = detector_[panel].get_ray_intersection_px(sdash3); + double2 xy4 = detector_[panel].get_ray_intersection_px(sdash4); + + // Return the roi in the following form: + // (minx, maxx, miny, maxy, minz, maxz) + // Min's are rounded down to the nearest integer, Max's are rounded up + double4 x(xy1[0], xy2[0], xy3[0], xy4[0]); + double4 y(xy1[1], xy2[1], xy3[1], xy4[1]); + + int x0 = (int)floor(min(x)); + int x1 = (int)ceil(max(x)); + int y0 = (int)floor(min(y)); + int y1 = (int)ceil(max(y)); + + double z0 = frame - delta_m * .5; + if (z0 < 0) { + z0 = 0; + } + double max_z = scan_.get_array_range()[1]; + double z1 = frame + delta_m; + if (z1 > max_z) { + z1 = max_z; + } + + int6 bbox(x0, x1, y0, y1, z0, z1); + DIALS_ASSERT(bbox[1] > bbox[0]); + DIALS_ASSERT(bbox[3] > bbox[2]); + DIALS_ASSERT(bbox[5] > bbox[4]); + return bbox; + } + + /** + * Calculate the rois for an array of reflections given by the array of + * diffracted beam vectors and rotation angles. + * @param s1 The array of diffracted beam vectors + * @param phi The array of rotation angles. + */ + virtual af::shared array(const af::const_ref > &s0, + const af::const_ref > &s1, + const af::const_ref &frame, + const af::const_ref &L1, + const af::const_ref &panel) const { + DIALS_ASSERT(s1.size() == frame.size()); + DIALS_ASSERT(s1.size() == panel.size()); + DIALS_ASSERT(s0.size() == frame.size()); + DIALS_ASSERT(s0.size() == panel.size()); + af::shared result(s1.size(), af::init_functor_null()); + for (std::size_t i = 0; i < s1.size(); ++i) { + result[i] = single(s0[i], s1[i], frame[i], L1[i], panel[i]); + } + return result; + } + + private: + Detector detector_; + Scan scan_; + PolychromaticBeam beam_; + double delta_divergence_; + double delta_mosaicity_; + }; }}}} // namespace dials::algorithms::profile_model::gaussian_rs diff --git a/src/dials/algorithms/profile_model/gaussian_rs/boost_python/gaussian_rs_ext.cc b/src/dials/algorithms/profile_model/gaussian_rs/boost_python/gaussian_rs_ext.cc index 2568531f27..ab4f32f867 100644 --- a/src/dials/algorithms/profile_model/gaussian_rs/boost_python/gaussian_rs_ext.cc +++ b/src/dials/algorithms/profile_model/gaussian_rs/boost_python/gaussian_rs_ext.cc @@ -204,6 +204,36 @@ namespace dials { arg("scan"), arg("delta_divergence"), arg("delta_mosaicity")))); + class_("BBoxCalculatorTOF", no_init) + .def( + init( + (arg("beam"), + arg("detector"), + arg("scan"), + arg("delta_divergence"), + arg("delta_mosaicity")))) + .def("__call__", + &BBoxCalculatorTOF::single, + (arg("s0"), arg("s1"), arg("frame"), arg("L1"), arg("panel"))) + .def("__call__", + &BBoxCalculatorTOF::array, + (arg("s0"), arg("s1"), arg("frame"), arg("L1"), arg("panel"))); + + class_("BBoxMultiCalculator") + .def("append", &BBoxMultiCalculator::push_back) + .def("__len__", &BBoxMultiCalculator::size) + .def("__call__", &BBoxMultiCalculator::operator()); + + class_("MaskCalculatorIface", no_init) + .def("__call__", + &MaskCalculatorIface::single, + (arg("shoebox"), arg("s1"), arg("frame"), arg("panel"))) + .def("__call__", + &MaskCalculatorIface::array, + (arg("shoebox"), arg("s1"), arg("frame"), arg("panel"))) + .def("__call__", + &MaskCalculatorIface::volume, + (arg("volume"), arg("bbox"), arg("s1"), arg("frame"), arg("panel"))); class_ >("BBoxCalculator2D", no_init) .def(init( diff --git a/src/dials/algorithms/profile_model/gaussian_rs/coordinate_system.h b/src/dials/algorithms/profile_model/gaussian_rs/coordinate_system.h index 99f3c1194c..0d26fb0f61 100644 --- a/src/dials/algorithms/profile_model/gaussian_rs/coordinate_system.h +++ b/src/dials/algorithms/profile_model/gaussian_rs/coordinate_system.h @@ -374,6 +374,129 @@ namespace dials { double zeta_; }; + class CoordinateSystemTOF { + public: + /** + * Initialise coordinate system. s0 should be the same length as s1. + * These quantities are not checked because this class will be created for + * each reflection and we want to maximize performance. + * @param s0 The incident beam vector + * @param s1 The diffracted beam vector + */ + CoordinateSystemTOF(vec3 s0, vec3 s1, double L1) + : s0_(s0), + s1_(s1), + p_star_(s1 - s0), + L1_(L1), + e1_(s1.cross(s0).normalize()), + e2_(s1.cross(e1_).normalize()), + e3_((s1 + s0).normalize()) {} + + vec3 s0() const { + return s0_; + } + vec3 s1() const { + return s1_; + } + double L1() const { + return L1_; + } + vec3 p_star() const { + return p_star_; + } + vec3 e1_axis() const { + return e1_; + } + vec3 e2_axis() const { + return e2_; + } + vec3 e3_axis() const { + return e3_; + } + + /** + * Transform the beam vector to the reciprocal space coordinate system. + * @param s_dash The beam vector + * @param s0_dash The incident beam vector + * @returns The e1, e2, e3 coordinates + */ + vec2 from_beam_vector(const vec3 &s_dash) const { + double s1_length = s1_.length(); + double s0_length = s0_.length(); + DIALS_ASSERT(s1_length > 0); + DIALS_ASSERT(s0_length > 0); + // vec3 p_star0 = s_dash-s0_dash; + vec3 e1 = e1_ / s1_length; + vec3 e2 = e2_ / s1_length; + // vec3 e3 = (s1_+s0_)/(s1_length + s0_length); + /* + return vec3( + e1 * (s_dash - s1_), + e2 * (s_dash - s1_), + e3 * (p_star0 - p_star_)/p_star_.length()); + */ + return vec2(e1 * (s_dash - s1_), e2 * (s_dash - s1_)); + } + + /** + * Transform the reciprocal space coordinate to get the beam vector. + * @param c12 The e1 and e2 coordinates. + * @returns The beam vector + */ + vec3 to_beam_vector(const vec2 &c12) const { + double radius = s1_.length(); + DIALS_ASSERT(radius > 0); + vec3 scaled_e1 = e1_ * radius; + vec3 scaled_e2 = e2_ * radius; + vec3 normalized_s1 = s1_ / radius; + + vec3 p = c12[0] * scaled_e1 + c12[1] * scaled_e2; + double b = radius * radius - p.length_sq(); + DIALS_ASSERT(b >= 0); + double d = -(normalized_s1 * p) + std::sqrt(b); + return p + d * normalized_s1; + } + + /** + * @param c3 The XDS e3 coordinate + * @param s_dash The beam vector from the e1 and e2 coordinates + * @returns The wavelength of the e3 coordinate (s) + * + * Solved be rearranging c3 = e3(p_star0 - p_star) / s1_length, + * noting that p_star0 = s_dash - s0_dash, + * and s0_dash = unit_s0/wavelength_dash + */ + double to_wavelength(double c3, vec3 s_dash) const { + double p_star_length = p_star_.length(); + DIALS_ASSERT(p_star_length > 0); + vec3 unit_s0 = s0_.normalize(); + return (e3_ * unit_s0) / (e3_ * s_dash - e3_ * p_star_ - c3 * p_star_length); + } + + /** + * Transform the rotation angle to the reciprocal space coordinate system + * @param phi_dash The rotation angle + * @returns The e3 coordinate. + */ + double from_wavelength(double wavelength) const { + double p_star_length = p_star_.length(); + DIALS_ASSERT(p_star_length > 0); + vec3 scaled_e3 = e3_ / p_star_length; + vec3 s0 = s0_.normalize() / wavelength; + vec3 p_star0 = s1_ - s0; + return scaled_e3 * (p_star0 - p_star_); + } + + private: + vec3 s0_; + vec3 s1_; + vec3 p_star_; + vec3 e1_; + vec3 e2_; + vec3 e3_; + double L1_; + }; + }}}} // namespace dials::algorithms::profile_model::gaussian_rs #endif // DIALS_ALGORITHMS_PROFILE_MODEL_GAUSSIAN_RS_COORDINATE_SYSTEM_H diff --git a/src/dials/algorithms/profile_model/gaussian_rs/model.py b/src/dials/algorithms/profile_model/gaussian_rs/model.py index 710e974929..8dc1cbeb71 100644 --- a/src/dials/algorithms/profile_model/gaussian_rs/model.py +++ b/src/dials/algorithms/profile_model/gaussian_rs/model.py @@ -514,10 +514,20 @@ def compute_bbox( crystal, beam, detector, goniometer, scan, delta_b, delta_m ) - # Calculate the bounding boxes of all the reflections - bbox = calculate( - reflections["s1"], reflections["xyzcal.px"].parts()[2], reflections["panel"] - ) + if scan.has_property("time_of_flight"): + bbox = calculate( + reflections["s0_cal"], + reflections["s1"], + reflections["xyzcal.px"].parts()[2], + reflections["L1"], + reflections["panel"], + ) + else: + bbox = calculate( + reflections["s1"], + reflections["xyzcal.px"].parts()[2], + reflections["panel"], + ) # Return the bounding boxes return bbox diff --git a/src/dials/algorithms/scaling/tof/boost_python/tof_scaling.cc b/src/dials/algorithms/scaling/tof/boost_python/tof_scaling.cc index c192ccbd91..e7ecb55f3f 100644 --- a/src/dials/algorithms/scaling/tof/boost_python/tof_scaling.cc +++ b/src/dials/algorithms/scaling/tof/boost_python/tof_scaling.cc @@ -1,6 +1,7 @@ #include #include #include +#include namespace dials_scaling { namespace boost_python { @@ -43,6 +44,20 @@ namespace dials_scaling { namespace boost_python { def("tof_extract_shoeboxes_to_reflection_table", extract_shoeboxes1); def("tof_extract_shoeboxes_to_reflection_table", extract_shoeboxes2); def("tof_extract_shoeboxes_to_reflection_table", extract_shoeboxes3); + def("tof_calculate_shoebox_mask", + &dials::algorithms::tof_calculate_shoebox_mask, + (arg("reflection_table"), arg("experiment"))); + def("tof_calculate_shoebox_foreground", + &dials::algorithms::tof_calculate_shoebox_foreground, + (arg("reflection_table"), arg("experiment"), arg("foreground_radius"))); + def("get_asu_reflections", + &dials::algorithms::get_asu_reflections, + (arg("indices"), + arg("predicted_indices"), + arg("wavelengths"), + arg("predicted_wavelengths"), + arg("asu_reflection"), + arg("space_group"))); } }} // namespace dials_scaling::boost_python diff --git a/src/dials/algorithms/spot_prediction/__init__.py b/src/dials/algorithms/spot_prediction/__init__.py index 8b4cb9cd92..0d46fccfac 100644 --- a/src/dials/algorithms/spot_prediction/__init__.py +++ b/src/dials/algorithms/spot_prediction/__init__.py @@ -1,6 +1,10 @@ from __future__ import annotations +from dxtbx import flumpy +from dxtbx.model import tof_helpers + import dials_algorithms_spot_prediction_ext +from dials.array_family import flex from dials_algorithms_spot_prediction_ext import ( IndexGenerator, LaueRayPredictor, @@ -171,3 +175,102 @@ def LaueReflectionPredictor(experiment, dmin: float): experiment.crystal.get_space_group().type(), dmin, ) + + +class TOFReflectionPredictor: + def __init__(self, experiment, dmin): + self.experiment = experiment + self.dmin = dmin + self.predictor = dials_algorithms_spot_prediction_ext.LaueReflectionPredictor( + experiment.beam, + experiment.detector, + experiment.goniometer, + experiment.crystal.get_A(), + experiment.crystal.get_unit_cell(), + experiment.crystal.get_space_group().type(), + dmin, + ) + + def post_prediction(self, reflections): + + if "tof_cal" not in reflections: + reflections["tof_cal"] = flex.double(reflections.nrows()) + if "L1" not in reflections: + reflections["L1"] = flex.double(reflections.nrows()) + + tof_cal = flex.double(reflections.nrows()) + L1 = flex.double(reflections.nrows()) + L0 = self.experiment.beam.get_sample_to_source_distance() * 10**-3 # (m) + + panel_numbers = flex.size_t(reflections["panel"]) + expt = self.experiment + + for i_panel in range(len(expt.detector)): + sel = panel_numbers == i_panel + expt_reflections = reflections.select(sel) + x, y, _ = expt_reflections["xyzcal.mm"].parts() + s1 = expt.detector[i_panel].get_lab_coord(flex.vec2_double(x, y)) + expt_L1 = s1.norms() + expt_tof_cal = flex.double(expt_reflections.nrows()) + + for idx in range(len(expt_reflections)): + wavelength = expt_reflections[idx]["wavelength_cal"] + tof = tof_helpers.tof_from_wavelength( + wavelength, L0 + expt_L1[idx] * 10**-3 + ) + expt_tof_cal[idx] = tof + tof_cal.set_selected(sel, expt_tof_cal) + L1.set_selected(sel, expt_L1) + + reflections["tof_cal"] = tof_cal + reflections["L1"] = L1 + + # Filter out predicted reflections outside of experiment range + wavelength_range = expt.beam.get_wavelength_range() + sel = reflections["wavelength_cal"] >= wavelength_range[0] + reflections = reflections.select(sel) + sel = reflections["wavelength_cal"] <= wavelength_range[1] + reflections = reflections.select(sel) + + return reflections + + def for_ub(self, ub): + + reflection_table = self.predictor.for_ub(ub) + reflection_table = self.post_prediction(reflection_table) + + interpolation_tof = self.experiment.scan.get_property("time_of_flight") + interpolation_frames = list(range(len(interpolation_tof))) + tof_to_frame = tof_helpers.tof_to_frame_interpolator( + interpolation_tof, interpolation_frames + ) + L0 = self.experiment.beam.get_sample_to_source_distance() * 10**-3 # (m) + + reflection_tof = ( + tof_helpers.tof_from_wavelength( + reflection_table["wavelength_cal"], + L0 + reflection_table["L1"] * 10**-3, + ) + * 10**6 + ) + + reflection_table = reflection_table.select( + (reflection_tof > min(interpolation_tof)) + & (reflection_tof < max(interpolation_tof)) + ) + + reflection_tof = reflection_tof.select( + (reflection_tof > min(interpolation_tof)) + & (reflection_tof < max(interpolation_tof)) + ) + reflection_frames = flumpy.from_numpy(tof_to_frame(reflection_tof)) + x, y, _ = reflection_table["xyzcal.px"].parts() + reflection_table["xyzcal.px"] = flex.vec3_double(x, y, reflection_frames) + + return reflection_table + + def for_reflection_table(self, reflections, UB): + return self.predictor.for_reflection_table(reflections, UB) + + def all_reflections_for_asu(self, phi): + return self.predictor.all_reflections_for_asu(float(phi)) diff --git a/src/dials/algorithms/spot_prediction/reflection_predictor.py b/src/dials/algorithms/spot_prediction/reflection_predictor.py index 95058000ec..7c4a5bc9a0 100644 --- a/src/dials/algorithms/spot_prediction/reflection_predictor.py +++ b/src/dials/algorithms/spot_prediction/reflection_predictor.py @@ -4,6 +4,7 @@ logger = logging.getLogger(__name__) +from dxtbx.model import ExperimentType from libtbx.phil import parse from dials.util import Sorry @@ -62,6 +63,7 @@ def __init__( ScanStaticReflectionPredictor, ScanVaryingReflectionPredictor, StillsReflectionPredictor, + TOFReflectionPredictor, ) from dials.array_family import flex @@ -79,6 +81,14 @@ def __call__(self): result.del_selected(mask) return result + if experiment.get_type() == ExperimentType.TOF: + predictor = TOFReflectionPredictor(experiment=experiment, dmin=dmin) + predict = Predictor( + "ToF prediction", + lambda: predictor.for_ub(experiment.crystal.get_A()), + ) + self._predict = predict + return # Check prediction to maximum resolution is possible wl = experiment.beam.get_wavelength() if dmin is not None and dmin < 0.5 * wl: diff --git a/src/dials/array_family/flex_ext.py b/src/dials/array_family/flex_ext.py index 72a02daecc..623feaed6f 100644 --- a/src/dials/array_family/flex_ext.py +++ b/src/dials/array_family/flex_ext.py @@ -629,7 +629,7 @@ def __init__(self): distance = cctbx.array_family.flex.sqrt( cctbx.array_family.flex.pow2(x1 - x2) + cctbx.array_family.flex.pow2(y1 - y2) - + cctbx.array_family.flex.pow2(z1 - z2) + + cctbx.array_family.flex.pow2(z1 - z1) ) mask = distance < 2 logger.info(" %d reflections matched", len(o2))