diff --git a/src/dials/algorithms/refinement/prediction/managed_predictors.py b/src/dials/algorithms/refinement/prediction/managed_predictors.py index 97833e73f4..eaed3826eb 100644 --- a/src/dials/algorithms/refinement/prediction/managed_predictors.py +++ b/src/dials/algorithms/refinement/prediction/managed_predictors.py @@ -11,9 +11,13 @@ from math import pi +from dxtbx.model import tof_helpers from scitbx.array_family import flex -from dials.algorithms.spot_prediction import ScanStaticRayPredictor +from dials.algorithms.spot_prediction import ( + LaueReflectionPredictor, + ScanStaticRayPredictor, +) from dials.algorithms.spot_prediction import ScanStaticReflectionPredictor as sc from dials.algorithms.spot_prediction import ScanVaryingReflectionPredictor as sv from dials.algorithms.spot_prediction import StillsReflectionPredictor as st @@ -83,6 +87,7 @@ def __call__(self, reflections): refs = reflections.select(sel) self._predict_one_experiment(e, refs) + refs = self._post_predict_one_experiment(e, refs) # write predictions back to overall reflections reflections.set_selected(sel, refs) @@ -95,6 +100,9 @@ def _predict_one_experiment(self, experiment, reflections): raise NotImplementedError() + def _post_predict_one_experiment(self, experiment, reflections): + return reflections + def _post_prediction(self, reflections): """Perform tasks on the whole reflection list after prediction before returning.""" @@ -165,6 +173,50 @@ def _predict_one_experiment(self, experiment, reflections): predictor.for_reflection_table(reflections, UB) +class LaueExperimentsPredictor(ExperimentsPredictor): + def _predict_one_experiment(self, experiment, reflections): + + min_s0_idx = min( + range(len(reflections["wavelength"])), + key=reflections["wavelength"].__getitem__, + ) + + if "s0" not in reflections: + unit_s0 = experiment.beam.get_unit_s0() + wl = reflections["wavelength"][min_s0_idx] + min_s0 = (unit_s0[0] / wl, unit_s0[1] / wl, unit_s0[2] / wl) + else: + min_s0 = reflections["s0"][min_s0_idx] + + dmin = experiment.detector.get_max_resolution(min_s0) + predictor = LaueReflectionPredictor(experiment, dmin) + UB = experiment.crystal.get_A() + predictor.for_reflection_table(reflections, UB) + + +class TOFExperimentsPredictor(LaueExperimentsPredictor): + def _post_predict_one_experiment(self, experiment, reflections): + + # Add ToF to xyzcal.mm + wavelength_cal = reflections["wavelength_cal"] + tof_cal = tof_helpers.tof_from_wavelength(wavelength_cal) # (s) + x, y, z = reflections["xyzcal.mm"].parts() + reflections["xyzcal.mm"] = flex.vec3_double(x, y, tof_cal) + tof_cal = tof_cal * 1e6 # (usec) + + # Add frame to xyzcal.px + expt_tof = experiment.scan.get_property("time_of_flight") # (usec) + frames = [i + 1 for i in range(len(expt_tof))] + tof_to_frame = tof_helpers.tof_to_frame_interpolator(expt_tof, frames) + reflection_frames = flex.double(tof_to_frame(tof_cal)) + px, py, pz = reflections["xyzcal.px"].parts() + reflections["xyzcal.px"] = flex.vec3_double(px, py, reflection_frames) + if "xyzobs.mm.value" in reflections: + reflections = self._match_full_turns(reflections) + + return reflections + + class ExperimentsPredictorFactory: @staticmethod def from_experiments(experiments, force_stills=False, spherical_relp=False): @@ -180,7 +232,21 @@ def from_experiments(experiments, force_stills=False, spherical_relp=False): if force_stills: predictor = StillsExperimentsPredictor(experiments) predictor.spherical_relp_model = spherical_relp + else: - predictor = ScansExperimentsPredictor(experiments) + + all_tof_experiments = False + for expt in experiments: + if expt.scan is not None and expt.scan.has_property("time_of_flight"): + all_tof_experiments = True + elif all_tof_experiments: + raise ValueError( + "Cannot find max cell for ToF and non-ToF experiments at the same time" + ) + + if all_tof_experiments: + predictor = TOFExperimentsPredictor(experiments) + else: + predictor = ScansExperimentsPredictor(experiments) return predictor diff --git a/src/dials/algorithms/spot_prediction/__init__.py b/src/dials/algorithms/spot_prediction/__init__.py index b3e0b35b4c..2f3a82a242 100644 --- a/src/dials/algorithms/spot_prediction/__init__.py +++ b/src/dials/algorithms/spot_prediction/__init__.py @@ -32,6 +32,7 @@ "StillsDeltaPsiReflectionPredictor", "StillsRayPredictor", "StillsReflectionPredictor", + "LaueReflectionPredictor", ] @@ -155,3 +156,16 @@ def StillsReflectionPredictor(experiment, dmin=None, spherical_relp=False, **kwa experiment.crystal.get_space_group().type(), dmin, ) + + +def LaueReflectionPredictor(experiment, dmin: float): + + return dials_algorithms_spot_prediction_ext.LaueReflectionPredictor( + experiment.beam, + experiment.detector, + experiment.goniometer, + experiment.crystal.get_A(), + experiment.crystal.get_unit_cell(), + experiment.crystal.get_space_group().type(), + dmin, + ) diff --git a/src/dials/algorithms/spot_prediction/boost_python/reflection_predictor.cc b/src/dials/algorithms/spot_prediction/boost_python/reflection_predictor.cc index 3d7c481f4c..d5699bb3db 100644 --- a/src/dials/algorithms/spot_prediction/boost_python/reflection_predictor.cc +++ b/src/dials/algorithms/spot_prediction/boost_python/reflection_predictor.cc @@ -156,12 +156,45 @@ namespace dials { namespace algorithms { namespace boost_python { .def("for_reflection_table", &Predictor::for_reflection_table_with_individual_ub); } + void export_laue_reflection_predictor() { + typedef LaueReflectionPredictor Predictor; + + af::reflection_table (Predictor::*predict_all)() const = &Predictor::operator(); + + af::reflection_table (Predictor::*predict_observed)( + const af::const_ref >&) = &Predictor::operator(); + + af::reflection_table (Predictor::*predict_observed_with_panel)( + const af::const_ref >&, std::size_t) = + &Predictor::operator(); + + af::reflection_table (Predictor::*predict_observed_with_panel_list)( + const af::const_ref >&, + const af::const_ref&) = &Predictor::operator(); + + class_("LaueReflectionPredictor", no_init) + .def(init&, + const Detector&, + mat3, + const cctbx::uctbx::unit_cell&, + const cctbx::sgtbx::space_group_type&, + const double&>()) + .def("__call__", predict_all) + .def("for_ub", &Predictor::for_ub) + .def("__call__", predict_observed) + .def("__call__", predict_observed_with_panel) + .def("__call__", predict_observed_with_panel_list) + .def("for_reflection_table", &Predictor::for_reflection_table) + .def("for_reflection_table", &Predictor::for_reflection_table_with_individual_ub); + } + void export_reflection_predictor() { export_scan_static_reflection_predictor(); export_scan_varying_reflection_predictor(); export_stills_delta_psi_reflection_predictor(); export_nave_stills_reflection_predictor(); export_spherical_relp_stills_reflection_predictor(); + export_laue_reflection_predictor(); } }}} // namespace dials::algorithms::boost_python diff --git a/src/dials/algorithms/spot_prediction/ray_predictor.h b/src/dials/algorithms/spot_prediction/ray_predictor.h index 30d2e307f3..9a225f3251 100644 --- a/src/dials/algorithms/spot_prediction/ray_predictor.h +++ b/src/dials/algorithms/spot_prediction/ray_predictor.h @@ -88,7 +88,7 @@ namespace dials { namespace algorithms { vec2 phi; try { phi = calculate_rotation_angles_(pstar0); - } catch (error const&) { + } catch (error const &) { return rays; } @@ -124,7 +124,7 @@ namespace dials { namespace algorithms { vec2 phi; try { phi = calculate_rotation_angles_(pstar0); - } catch (error const&) { + } catch (error const &) { return rays; } @@ -250,6 +250,63 @@ namespace dials { namespace algorithms { vec3 s0_m2_plane; }; + /** + * Class to predict s1 rays for Laue data + */ + class LaueRayPredictor { + public: + typedef cctbx::miller::index<> miller_index; + + LaueRayPredictor(const vec3 unit_s0, + mat3 fixed_rotation, + mat3 setting_rotation) + : unit_s0_(unit_s0), + fixed_rotation_(fixed_rotation), + setting_rotation_(setting_rotation) + + { + DIALS_ASSERT(unit_s0_.length() > 0.0); + } + + /** + * For a given miller index and UB matrix, calculates the predicted s1 ray. + * The LaueRayPredictor wavelength and s0 variables are updated during the + * calculation, so that they can be monitored for convergence. + * @param h The miller index + * @param ub The UB matrix + * @returns Ray + */ + Ray operator()(const miller_index &h, const mat3 &ub) { + // Calculate the reciprocal lattice vector + vec3 q = setting_rotation_ * fixed_rotation_ * ub * h; + + // Calculate the wavelength required to meet the diffraction condition + // (starting from q.q + 2q.s0 = 0) + wavelength_ = -2 * ((unit_s0_ * q) / (q * q)); + s0_ = unit_s0_ / wavelength_; + DIALS_ASSERT(s0_.length() > 0); + + // Calculate the Ray (default zero angle and 'entering' as false) + vec3 s1 = s0_ + q; + return Ray(s1, 0.0, false); + } + + double get_wavelength() const { + return wavelength_; + } + + vec3 get_s0() const { + return s0_; + } + + private: + const vec3 unit_s0_; + double wavelength_; + vec3 s0_; + mat3 fixed_rotation_; + mat3 setting_rotation_; + }; + }} // namespace dials::algorithms #endif // DIALS_ALGORITHMS_SPOT_PREDICTION_RAY_PREDICTOR_H diff --git a/src/dials/algorithms/spot_prediction/reflection_predictor.h b/src/dials/algorithms/spot_prediction/reflection_predictor.h index 88026ea313..aae3cdfa90 100644 --- a/src/dials/algorithms/spot_prediction/reflection_predictor.h +++ b/src/dials/algorithms/spot_prediction/reflection_predictor.h @@ -1309,6 +1309,322 @@ namespace dials { namespace algorithms { SphericalRelpStillsRayPredictor spherical_relp_predict_ray_; }; + /** + * A class to do Laue reflection prediction. + * Uses LaueRayPredictor to make predictions, and adds additional + * wavelenegth_cal and s0_cal columns to the predicted reflection table. + */ + class LaueReflectionPredictor { + public: + typedef cctbx::miller::index<> miller_index; + + /** + * Initialise the predictor + */ + LaueReflectionPredictor(const PolychromaticBeam &beam, + const Detector &detector, + const Goniometer &goniometer, + mat3 ub, + const cctbx::uctbx::unit_cell &unit_cell, + const cctbx::sgtbx::space_group_type &space_group_type, + const double &dmin) + : beam_(beam), + detector_(detector), + goniometer_(goniometer), + ub_(ub), + unit_cell_(unit_cell), + space_group_type_(space_group_type), + dmin_(dmin), + predict_ray_(beam.get_unit_s0(), + goniometer.get_fixed_rotation(), + goniometer.get_setting_rotation()) {} + + /** + * Predict all reflection. + * @returns reflection table. + */ + af::reflection_table operator()() const { + throw DIALS_ERROR("Not implemented"); + return af::reflection_table(); + } + + af::reflection_table all_reflections_for_asu(Goniometer goniometer, double phi) { + mat3 fixed_rotation = goniometer.get_fixed_rotation(); + mat3 setting_rotation = goniometer.get_setting_rotation(); + vec3 rotation_axis = goniometer.get_rotation_axis(); + mat3 rotation = + scitbx::math::r3_rotation::axis_and_angle_as_matrix(rotation_axis, phi); + vec3 unit_s0 = beam_.get_unit_s0(); + vec2 wavelength_range = beam_.get_wavelength_range(); + + cctbx::miller::index_generator indices = + cctbx::miller::index_generator(unit_cell_, space_group_type_, false, dmin_); + + af::shared indices_arr = indices.to_array(); + + af::reflection_table table; + af::shared wavelength_column; + table["wavelength_cal"] = wavelength_column; + af::shared > s0_column; + table["s0_cal"] = s0_column; + laue_prediction_data predictions(table); + + for (std::size_t i = 0; i < indices_arr.size(); ++i) { + miller_index h = indices_arr[i]; + + vec3 q = setting_rotation * rotation * fixed_rotation * ub_ * h; + + // Calculate the wavelength required to meet the diffraction condition + double wavelength = -2 * ((unit_s0 * q) / (q * q)); + if (wavelength < wavelength_range[0] || wavelength > wavelength_range[1]) { + continue; + } + vec3 s0 = unit_s0 / wavelength; + DIALS_ASSERT(s0.length() > 0); + + // Calculate the Ray (default zero angle and 'entering' as false) + vec3 s1 = s0 + q; + + int panel = detector_.get_panel_intersection(s1); + if (panel == -1) { + continue; + } + + Detector::coord_type coord; + coord.first = panel; + coord.second = detector_[panel].get_ray_intersection(s1); + vec2 mm = coord.second; + vec2 px = detector_[panel].millimeter_to_pixel(mm); + + // Add the reflections to the table + predictions.hkl.push_back(h); + predictions.enter.push_back(false); + predictions.s1.push_back(s1); + predictions.xyz_mm.push_back(vec3(mm[0], mm[1], 0.0)); + predictions.xyz_px.push_back(vec3(px[0], px[1], 0.0)); + predictions.panel.push_back(panel); + predictions.flags.push_back(af::Predicted); + predictions.wavelength_cal.push_back(wavelength); + predictions.s0_cal.push_back(s0); + } + + // Return the reflection table + return table; + } + + /** + * Predict reflections for UB. Also filters based on ewald sphere proximity. + * @param ub The UB matrix + * @returns A reflection table. + */ + af::reflection_table for_ub(const mat3 &ub) { + // Create the reflection table and the local container + af::reflection_table table; + laue_prediction_data predictions(table); + + // Create the index generate and loop through the indices. For each index, + // predict the rays and append to the reflection table + IndexGenerator indices(unit_cell_, space_group_type_, dmin_); + for (;;) { + miller_index h = indices.next(); + if (h.is_zero()) { + break; + } + + Ray ray; + ray = predict_ray_(h, ub); + append_for_index(predictions, ub, h); + } + + // Return the reflection table + return table; + } + + /** + * Predict the reflections with given Miller indices. + * @param h The miller index + * @returns The reflection table + */ + af::reflection_table operator()(const af::const_ref &h) { + af::reflection_table table; + laue_prediction_data predictions(table); + for (std::size_t i = 0; i < h.size(); ++i) { + append_for_index(predictions, ub_, h[i]); + } + return table; + } + + /** + * Predict for given Miller indices on a single panel. + * @param h The array of Miller indices + * @param panel The panel index + * @returns The reflection table + */ + af::reflection_table operator()(const af::const_ref &h, + std::size_t panel) { + af::shared panels(h.size(), panel); + return (*this)(h, panels.const_ref()); + } + + /** + * Predict for given Miller indices for specific panels. + * @param h The array of Miller indices + * @param panel The array of panel indices + * @returns The reflection table + */ + af::reflection_table operator()(const af::const_ref &h, + const af::const_ref &panel) { + DIALS_ASSERT(h.size() == panel.size()); + af::reflection_table table; + laue_prediction_data predictions(table); + for (std::size_t i = 0; i < h.size(); ++i) { + append_for_index(predictions, ub_, h[i], (int)panel[i]); + } + return table; + } + + /** + * Predict reflections for specific Miller indices, panels and individual + * UB matrices + * @param h The array of miller indices + * @param panel The array of panels + * @param ub The array of setting matrices + * @returns A reflection table. + */ + af::reflection_table for_hkl_with_individual_ub( + const af::const_ref &h, + const af::const_ref &panel, + const af::const_ref > &ub) { + DIALS_ASSERT(ub.size() == h.size()); + DIALS_ASSERT(ub.size() == panel.size()); + af::reflection_table table; + af::shared wavelength_column; + table["wavelength_cal"] = wavelength_column; + af::shared > s0_column; + table["s0_cal"] = s0_column; + laue_prediction_data predictions(table); + for (std::size_t i = 0; i < h.size(); ++i) { + append_for_index(predictions, ub[i], h[i], panel[i]); + } + DIALS_ASSERT(table.nrows() == h.size()); + return table; + } + + /** + * Predict reflections and add to the entries in the table for a single UB + * matrix + * @param table The reflection table + * @param ub The ub matrix + */ + void for_reflection_table(af::reflection_table table, const mat3 &ub) { + af::shared > uba(table.nrows(), ub); + for_reflection_table_with_individual_ub(table, uba.const_ref()); + } + + /** + * Predict reflections and add to the entries in the table for an array of + * UB matrices + * @param table The reflection table + */ + void for_reflection_table_with_individual_ub( + af::reflection_table table, + const af::const_ref > &ub) { + DIALS_ASSERT(ub.size() == table.nrows()); + af::reflection_table new_table = + for_hkl_with_individual_ub(table["miller_index"], table["panel"], ub); + DIALS_ASSERT(new_table.nrows() == table.nrows()); + table["miller_index"] = new_table["miller_index"]; + table["panel"] = new_table["panel"]; + table["s1"] = new_table["s1"]; + table["xyzcal.px"] = new_table["xyzcal.px"]; + table["xyzcal.mm"] = new_table["xyzcal.mm"]; + table["wavelength_cal"] = new_table["wavelength_cal"]; + table["s0_cal"] = new_table["s0_cal"]; + + af::shared flags = table["flags"]; + af::shared new_flags = new_table["flags"]; + for (std::size_t i = 0; i < flags.size(); ++i) { + flags[i] &= ~af::Predicted; + flags[i] |= new_flags[i]; + } + DIALS_ASSERT(table.is_consistent()); + } + + protected: + /** + * Predict for the given Miller index, UB matrix and panel number + * @param p The reflection data + * @param ub The UB matrix + * @param h The miller index + * @param panel The panel index + */ + virtual void append_for_index(laue_prediction_data &p, + const mat3 ub, + const miller_index &h, + int panel = -1) { + Ray ray; + ray = predict_ray_(h, ub); + double wavelength = predict_ray_.get_wavelength(); + vec3 s0 = predict_ray_.get_s0(); + append_for_ray(p, h, ray, panel, wavelength, s0); + } + + void append_for_ray(laue_prediction_data &p, + const miller_index &h, + const Ray &ray, + int panel, + double wavelength, + vec3 s0) const { + try { + // Get the impact on the detector + Detector::coord_type impact = get_ray_intersection(ray.s1, panel); + std::size_t panel = impact.first; + vec2 mm = impact.second; + vec2 px = detector_[panel].millimeter_to_pixel(mm); + + // Add the reflections to the table + p.hkl.push_back(h); + p.enter.push_back(ray.entering); + p.s1.push_back(ray.s1); + p.xyz_mm.push_back(vec3(mm[0], mm[1], 0.0)); + p.xyz_px.push_back(vec3(px[0], px[1], 0.0)); + p.panel.push_back(panel); + p.flags.push_back(af::Predicted); + p.wavelength_cal.push_back(wavelength); + p.s0_cal.push_back(s0); + + } catch (dxtbx::error const &) { + // do nothing + } + } + + private: + /** + * Helper function to do ray intersection with/without panel set. + */ + Detector::coord_type get_ray_intersection(vec3 s1, int panel) const { + Detector::coord_type coord; + if (panel < 0) { + coord = detector_.get_ray_intersection(s1); + } else { + coord.first = panel; + coord.second = detector_[panel].get_ray_intersection(s1); + } + return coord; + } + + protected: + PolyBeam beam_; + Detector detector_; + Goniometer goniometer_; + Scan scan_; + mat3 ub_; + cctbx::uctbx::unit_cell unit_cell_; + cctbx::sgtbx::space_group_type space_group_type_; + const double dmin_; + LaueRayPredictor predict_ray_; + }; + }} // namespace dials::algorithms #endif // DIALS_ALGORITHMS_SPOT_PREDICTION_REFLECTION_PREDICTOR_H