Skip to content

Commit

Permalink
Added LauePredictionParameterisation
Browse files Browse the repository at this point in the history
  • Loading branch information
toastisme committed Mar 20, 2024
1 parent 8914057 commit 89fa4bf
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/dials/algorithms/refinement/parameterisation/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
phil_str as sv_phil_str,
)
from dials.algorithms.refinement.refinement_helpers import string_sel
from dials.algorithms.refinement.reflection_mamager import LaueReflectionManager
from dials.algorithms.refinement.restraints.restraints_parameterisation import (
uc_phil_str as uc_restraints_phil_str,
)
Expand All @@ -31,6 +32,7 @@
)
from .goniometer_parameters import GoniometerParameterisation
from .prediction_parameters import (
LauePredictionParameterisation,
XYPhiPredictionParameterisation,
XYPhiPredictionParameterisationSparse,
)
Expand Down Expand Up @@ -823,6 +825,16 @@ def build_prediction_parameterisation(
det_params = _parameterise_detectors(options, experiments, analysis)
gon_params = _parameterise_goniometers(options, experiments, analysis)

beam_params = []

if isinstance(reflection_manager, LaueReflectionManager):
PredParam = LauePredictionParameterisation
return PredParam(
experiments, det_params, beam_params, xl_ori_params, xl_uc_params
)

beam_params = _parameterise_beams(options, experiments, analysis)

# Build the prediction equation parameterisation
if do_stills: # doing stills
if options.sparse:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -967,3 +967,104 @@ class XYPhiPredictionParameterisationSparse(
@staticmethod
def _extend_gradient_vectors(results, m, n, keys=("dX_dp", "dY_dp", "dZ_dp")):
return SparseGradientVectorMixin._extend_gradient_vectors(results, m, n, keys)


class LauePredictionParameterisation(PredictionParameterisation):

"""A basic extension to PredictionParameterisation for ToF data,
where only panel positions are considered."""

_grad_names = ("dX_dp", "dY_dp", "dwavelength_dp")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
return

def _local_setup(self, reflections):
self._wavelength = reflections["wavelength_cal"]
self._r = self._UB * self._h
self._s0 = reflections["s0_cal"]
return

def _xl_derivatives(self, isel, derivatives, b_matrix, parameterisation=None):
"""helper function to extend the derivatives lists by derivatives of
generic parameterisations."""

# Get required data
h = self._h.select(isel)
if b_matrix:
B = self._B.select(isel)
else:
U = self._U.select(isel)
D = self._D.select(isel)
s1 = self._s1.select(isel)
s0 = self._s0.select(isel)
wavelength = self._wavelength.select(isel)

if derivatives is None:
# get derivatives of the B/U matrix wrt the parameters
derivatives = [
None if der is None else flex.mat3_double(len(isel), der.elems)
for der in parameterisation.get_ds_dp(use_none_as_null=True)
]

dpv_dp = []
dwavelength_dp = []

# loop through the parameters
for der in derivatives:
if der is None:
dpv_dp.append(None)
dwavelength_dp.append(None)
continue

# calculate the derivative of r for this parameter
if b_matrix:
dr = der * B * h
else:
dr = U * der * h

dwavelength = (-wavelength) * (dr.dot(s1)) / (s0.dot(s0))
dwavelength_dp.append(dwavelength)
# calculate the derivative of pv for this parameter
dpv_dp.append(D * (dr + (s0 / wavelength) * dwavelength))

return dpv_dp, dwavelength_dp

def _xl_orientation_derivatives(
self, isel, parameterisation=None, dU_dxlo_p=None, reflections=None
):
"""helper function to extend the derivatives lists by derivatives of the
crystal orientation parameterisations"""
return self._xl_derivatives(
isel, dU_dxlo_p, b_matrix=True, parameterisation=parameterisation
)

def _xl_unit_cell_derivatives(
self, isel, parameterisation=None, dB_dxluc_p=None, reflections=None
):
"""helper function to extend the derivatives lists by
derivatives of the crystal unit cell parameterisations"""
return self._xl_derivatives(
isel, dB_dxluc_p, b_matrix=False, parameterisation=parameterisation
)

@staticmethod
def _calc_dX_dp_and_dY_dp_from_dpv_dp(w_inv, u_w_inv, v_w_inv, dpv_dp):
"""helper function to calculate positional derivatives from
dpv_dp using the quotient rule"""

dX_dp = []
dY_dp = []

for der in dpv_dp:
if der is None:
dX_dp.append(None)
dY_dp.append(None)
else:
du_dp, dv_dp, dw_dp = der.parts()

dX_dp.append(w_inv * (du_dp - dw_dp * u_w_inv))
dY_dp.append(w_inv * (dv_dp - dw_dp * v_w_inv))

return dX_dp, dY_dp
1 change: 1 addition & 0 deletions src/dials/algorithms/refinement/refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def _build_reflection_manager_and_predictor(cls, params, reflections, experiment
obs["x_resid"] = x_calc - x_obs
obs["y_resid"] = y_calc - y_obs
obs["phi_resid"] = phi_calc - phi_obs
refman.update_residuals()

# determine whether to do basic centroid analysis to automatically
# determine outlier rejection block
Expand Down
7 changes: 7 additions & 0 deletions src/dials/algorithms/refinement/reflection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,13 @@ def filter_obs(self, sel):
self._reflections = self._reflections.select(sel)
return self._reflections

def update_residuals(self):
x_obs, y_obs, phi_obs = self._reflections["xyzobs.mm.value"].parts()
x_calc, y_calc, phi_calc = self._reflections["xyzcal.mm"].parts()
self._reflections["x_resid"] = x_calc - x_obs
self._reflections["y_resid"] = y_calc - y_obs
self._reflections["phi_resid"] = phi_calc - phi_obs


class StillsReflectionManager(ReflectionManager):
"""Overloads for a Reflection Manager that does not exclude
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace dials { namespace algorithms { namespace boost_python {
void export_rotation_angles();
void export_ray_predictor();
void export_scan_varying_ray_predictor();
void export_laue_ray_predictor();
void export_stills_ray_predictor();
void export_ray_intersection();
void export_reflection_predictor();
Expand All @@ -40,6 +41,7 @@ namespace dials { namespace algorithms { namespace boost_python {
export_ray_predictor();
export_scan_varying_ray_predictor();
export_stills_ray_predictor();
export_laue_ray_predictor();
export_ray_intersection();
export_reflection_predictor();
export_pixel_labeller();
Expand Down
10 changes: 10 additions & 0 deletions src/dials/algorithms/spot_prediction/reflection_predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ namespace dials { namespace algorithms {
}
};

struct laue_prediction_data : prediction_data {
af::shared<double> wavelength_cal;
af::shared<vec3<double> > s0_cal;

laue_prediction_data(af::reflection_table &table) : prediction_data(table) {
wavelength_cal = table.get<double>("wavelength_cal");
s0_cal = table.get<vec3<double> >("s0_cal");
}
};

/**
* A reflection predictor for scan static prediction.
*/
Expand Down

0 comments on commit 89fa4bf

Please sign in to comment.