Skip to content

Commit

Permalink
Add LaueExperimentPredictor, TOFExperimentPredictor, LaueRayPredictor
Browse files Browse the repository at this point in the history
  • Loading branch information
toastisme committed Mar 20, 2024
1 parent b85fec4 commit 8914057
Show file tree
Hide file tree
Showing 5 changed files with 490 additions and 4 deletions.
70 changes: 68 additions & 2 deletions src/dials/algorithms/refinement/prediction/managed_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@

from math import pi

from dxtbx.model import tof_helpers
from scitbx.array_family import flex

from dials.algorithms.spot_prediction import ScanStaticRayPredictor
from dials.algorithms.spot_prediction import (
LaueReflectionPredictor,
ScanStaticRayPredictor,
)
from dials.algorithms.spot_prediction import ScanStaticReflectionPredictor as sc
from dials.algorithms.spot_prediction import ScanVaryingReflectionPredictor as sv
from dials.algorithms.spot_prediction import StillsReflectionPredictor as st
Expand Down Expand Up @@ -83,6 +87,7 @@ def __call__(self, reflections):
refs = reflections.select(sel)

self._predict_one_experiment(e, refs)
refs = self._post_predict_one_experiment(e, refs)

# write predictions back to overall reflections
reflections.set_selected(sel, refs)
Expand All @@ -95,6 +100,9 @@ def _predict_one_experiment(self, experiment, reflections):

raise NotImplementedError()

def _post_predict_one_experiment(self, experiment, reflections):
return reflections

def _post_prediction(self, reflections):
"""Perform tasks on the whole reflection list after prediction before
returning."""
Expand Down Expand Up @@ -165,6 +173,50 @@ def _predict_one_experiment(self, experiment, reflections):
predictor.for_reflection_table(reflections, UB)


class LaueExperimentsPredictor(ExperimentsPredictor):
def _predict_one_experiment(self, experiment, reflections):

min_s0_idx = min(
range(len(reflections["wavelength"])),
key=reflections["wavelength"].__getitem__,
)

if "s0" not in reflections:
unit_s0 = experiment.beam.get_unit_s0()
wl = reflections["wavelength"][min_s0_idx]
min_s0 = (unit_s0[0] / wl, unit_s0[1] / wl, unit_s0[2] / wl)
else:
min_s0 = reflections["s0"][min_s0_idx]

dmin = experiment.detector.get_max_resolution(min_s0)
predictor = LaueReflectionPredictor(experiment, dmin)
UB = experiment.crystal.get_A()
predictor.for_reflection_table(reflections, UB)


class TOFExperimentsPredictor(LaueExperimentsPredictor):
def _post_predict_one_experiment(self, experiment, reflections):

# Add ToF to xyzcal.mm
wavelength_cal = reflections["wavelength_cal"]
tof_cal = tof_helpers.tof_from_wavelength(wavelength_cal) # (s)
x, y, z = reflections["xyzcal.mm"].parts()
reflections["xyzcal.mm"] = flex.vec3_double(x, y, tof_cal)
tof_cal = tof_cal * 1e6 # (usec)

# Add frame to xyzcal.px
expt_tof = experiment.scan.get_property("time_of_flight") # (usec)
frames = [i + 1 for i in range(len(expt_tof))]
tof_to_frame = tof_helpers.tof_to_frame_interpolator(expt_tof, frames)
reflection_frames = flex.double(tof_to_frame(tof_cal))
px, py, pz = reflections["xyzcal.px"].parts()
reflections["xyzcal.px"] = flex.vec3_double(px, py, reflection_frames)
if "xyzobs.mm.value" in reflections:
reflections = self._match_full_turns(reflections)

return reflections


class ExperimentsPredictorFactory:
@staticmethod
def from_experiments(experiments, force_stills=False, spherical_relp=False):
Expand All @@ -180,7 +232,21 @@ def from_experiments(experiments, force_stills=False, spherical_relp=False):
if force_stills:
predictor = StillsExperimentsPredictor(experiments)
predictor.spherical_relp_model = spherical_relp

else:
predictor = ScansExperimentsPredictor(experiments)

all_tof_experiments = False
for expt in experiments:
if expt.scan is not None and expt.scan.has_property("time_of_flight"):
all_tof_experiments = True
elif all_tof_experiments:
raise ValueError(
"Cannot find max cell for ToF and non-ToF experiments at the same time"
)

if all_tof_experiments:
predictor = TOFExperimentsPredictor(experiments)
else:
predictor = ScansExperimentsPredictor(experiments)

return predictor
14 changes: 14 additions & 0 deletions src/dials/algorithms/spot_prediction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"StillsDeltaPsiReflectionPredictor",
"StillsRayPredictor",
"StillsReflectionPredictor",
"LaueReflectionPredictor",
]


Expand Down Expand Up @@ -155,3 +156,16 @@ def StillsReflectionPredictor(experiment, dmin=None, spherical_relp=False, **kwa
experiment.crystal.get_space_group().type(),
dmin,
)


def LaueReflectionPredictor(experiment, dmin: float):

return dials_algorithms_spot_prediction_ext.LaueReflectionPredictor(
experiment.beam,
experiment.detector,
experiment.goniometer,
experiment.crystal.get_A(),
experiment.crystal.get_unit_cell(),
experiment.crystal.get_space_group().type(),
dmin,
)
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,45 @@ namespace dials { namespace algorithms { namespace boost_python {
.def("for_reflection_table", &Predictor::for_reflection_table_with_individual_ub);
}

void export_laue_reflection_predictor() {
typedef LaueReflectionPredictor Predictor;

af::reflection_table (Predictor::*predict_all)() const = &Predictor::operator();

af::reflection_table (Predictor::*predict_observed)(
const af::const_ref<cctbx::miller::index<> >&) = &Predictor::operator();

af::reflection_table (Predictor::*predict_observed_with_panel)(
const af::const_ref<cctbx::miller::index<> >&, std::size_t) =
&Predictor::operator();

af::reflection_table (Predictor::*predict_observed_with_panel_list)(
const af::const_ref<cctbx::miller::index<> >&,
const af::const_ref<std::size_t>&) = &Predictor::operator();

class_<Predictor>("LaueReflectionPredictor", no_init)
.def(init<const vec3<double>&,
const Detector&,
mat3<double>,
const cctbx::uctbx::unit_cell&,
const cctbx::sgtbx::space_group_type&,
const double&>())
.def("__call__", predict_all)
.def("for_ub", &Predictor::for_ub)
.def("__call__", predict_observed)
.def("__call__", predict_observed_with_panel)
.def("__call__", predict_observed_with_panel_list)
.def("for_reflection_table", &Predictor::for_reflection_table)
.def("for_reflection_table", &Predictor::for_reflection_table_with_individual_ub);
}

void export_reflection_predictor() {
export_scan_static_reflection_predictor();
export_scan_varying_reflection_predictor();
export_stills_delta_psi_reflection_predictor();
export_nave_stills_reflection_predictor();
export_spherical_relp_stills_reflection_predictor();
export_laue_reflection_predictor();
}

}}} // namespace dials::algorithms::boost_python
61 changes: 59 additions & 2 deletions src/dials/algorithms/spot_prediction/ray_predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ namespace dials { namespace algorithms {
vec2<double> phi;
try {
phi = calculate_rotation_angles_(pstar0);
} catch (error const&) {
} catch (error const &) {
return rays;
}

Expand Down Expand Up @@ -124,7 +124,7 @@ namespace dials { namespace algorithms {
vec2<double> phi;
try {
phi = calculate_rotation_angles_(pstar0);
} catch (error const&) {
} catch (error const &) {
return rays;
}

Expand Down Expand Up @@ -250,6 +250,63 @@ namespace dials { namespace algorithms {
vec3<double> s0_m2_plane;
};

/**
* Class to predict s1 rays for Laue data
*/
class LaueRayPredictor {
public:
typedef cctbx::miller::index<> miller_index;

LaueRayPredictor(const vec3<double> unit_s0,
mat3<double> fixed_rotation,
mat3<double> setting_rotation)
: unit_s0_(unit_s0),
fixed_rotation_(fixed_rotation),
setting_rotation_(setting_rotation)

{
DIALS_ASSERT(unit_s0_.length() > 0.0);
}

/**
* For a given miller index and UB matrix, calculates the predicted s1 ray.
* The LaueRayPredictor wavelength and s0 variables are updated during the
* calculation, so that they can be monitored for convergence.
* @param h The miller index
* @param ub The UB matrix
* @returns Ray
*/
Ray operator()(const miller_index &h, const mat3<double> &ub) {
// Calculate the reciprocal lattice vector
vec3<double> q = setting_rotation_ * fixed_rotation_ * ub * h;

// Calculate the wavelength required to meet the diffraction condition
// (starting from q.q + 2q.s0 = 0)
wavelength_ = -2 * ((unit_s0_ * q) / (q * q));
s0_ = unit_s0_ / wavelength_;
DIALS_ASSERT(s0_.length() > 0);

// Calculate the Ray (default zero angle and 'entering' as false)
vec3<double> s1 = s0_ + q;
return Ray(s1, 0.0, false);
}

double get_wavelength() const {
return wavelength_;
}

vec3<double> get_s0() const {
return s0_;
}

private:
const vec3<double> unit_s0_;
double wavelength_;
vec3<double> s0_;
mat3<double> fixed_rotation_;
mat3<double> setting_rotation_;
};

}} // namespace dials::algorithms

#endif // DIALS_ALGORITHMS_SPOT_PREDICTION_RAY_PREDICTOR_H
Loading

0 comments on commit 8914057

Please sign in to comment.