
Commit

Refactored scaling corrections.
toastisme committed Jul 23, 2024
1 parent 827ea15 commit 9a1de20
Showing 6 changed files with 178 additions and 51 deletions.
14 changes: 12 additions & 2 deletions src/dials/algorithms/indexing/indexer.py
@@ -995,12 +995,22 @@ def index_reflections(self, experiments, reflections):
def refine(self, experiments, reflections):
from dials.algorithms.indexing.refinement import refine

properties_to_save = [
"xyzcal.mm",
"entering",
"wavelength_cal",
"s0_cal",
"tof_cal",
]

refiner, refined, outliers = refine(self.all_params, reflections, experiments)
if outliers is not None:
reflections["id"].set_selected(outliers, -1)
predicted = refiner.predict_for_indexed()
reflections["xyzcal.mm"] = predicted["xyzcal.mm"]
reflections["entering"] = predicted["entering"]
for i in properties_to_save:
if i in predicted:
reflections[i] = predicted[i]

reflections.unset_flags(
flex.bool(len(reflections), True), reflections.flags.centroid_outlier
)
@@ -280,6 +280,9 @@ namespace dials { namespace algorithms { namespace boost_python {
}
};

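  // Function-pointer aliases that pick out each test_func overload so both can be exported to Python below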
void (*test_func_a)(double) = &test_func;
void (*test_func_b)(double, int) = &test_func;

BOOST_PYTHON_MODULE(dials_algorithms_integration_integrator_ext) {
class_<GroupList::Group>("Group", no_init)
.def("index", &GroupList::Group::index)
@@ -375,9 +378,16 @@ namespace dials { namespace algorithms { namespace boost_python {
arg("apply_lorentz_correction"),
arg("apply_spherical_absorption_correction")));

def("tof_extract_shoeboxes_to_reflection_table_no_corrections",
&tof_extract_shoeboxes_to_reflection_table_no_corrections,
(arg("reflection_table"), arg("experiment"), arg("data")));

def("tof_calculate_shoebox_foreground",
&tof_calculate_shoebox_foreground,
(arg("reflection_table"), arg("experiment"), arg("foreground_radius")));

def("test_func", test_func_a);
def("test_func", test_func_b);
}

}}} // namespace dials::algorithms::boost_python
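
For orientation, here is a minimal Python sketch of calling the bindings registered above; reflections, experiment and data are placeholder names standing in for an existing reflection table, Experiment and ImageSequence:

    from dials_algorithms_integration_integrator_ext import (
        tof_calculate_shoebox_foreground,
        tof_extract_shoeboxes_to_reflection_table_no_corrections,
    )

    # Extract raw (uncorrected) shoebox data, then flag foreground pixels using a 0.5 radius.
    tof_extract_shoeboxes_to_reflection_table_no_corrections(reflections, experiment, data)
    tof_calculate_shoebox_foreground(reflections, experiment, 0.5)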
9 changes: 8 additions & 1 deletion src/dials/algorithms/scaling/CMakeLists.txt
@@ -4,4 +4,11 @@ Python_add_library(
boost_python/scaling_helper.cc
boost_python/scaling_ext.cc
)
target_link_libraries( dials_scaling_ext PUBLIC CCTBX::cctbx Boost::python )
Python_add_library(
dials_tof_scaling_ext
MODULE
tof/boost_python/tof_scaling.cc
)

target_link_libraries( dials_scaling_ext PUBLIC CCTBX::cctbx Boost::python )
target_link_libraries( dials_tof_scaling_ext PUBLIC CCTBX::cctbx Boost::python )
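
The additional Python_add_library target above builds a separate dials_tof_scaling_ext module next to the existing dials_scaling_ext. As a rough sanity check (assuming the build completed), the new module should import and expose the names that tof_integrate.py below pulls from it:

    # Quick import check for the new extension module.
    import dials_scaling_ext
    from dials_tof_scaling_ext import (
        TOFCorrectionsData,
        tof_extract_shoeboxes_to_reflection_table,
    )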
52 changes: 51 additions & 1 deletion src/dials/algorithms/scaling/tof_scaling_corrections.h
@@ -250,6 +250,48 @@ namespace dials { namespace algorithms {
incident_absorption_x_section(incident_absorption_x_section),
incident_number_density(incident_number_density) {}
};
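  // Extracts shoebox pixel data for each reflection directly from the image
  // sequence, with no Lorentz, spectrum or absorption corrections applied.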
void tof_extract_shoeboxes_to_reflection_table_no_corrections(
af::reflection_table &reflection_table,
Experiment &experiment,
ImageSequence &data) {
Detector detector = *experiment.get_detector();
Scan scan = *experiment.get_scan();

std::shared_ptr<dxtbx::model::BeamBase> beam_ptr = experiment.get_beam();
std::shared_ptr<PolychromaticBeam> beam =
std::dynamic_pointer_cast<PolychromaticBeam>(beam_ptr);
DIALS_ASSERT(beam != nullptr);

vec3<double> unit_s0 = beam->get_unit_s0();
double sample_to_source_distance = beam->get_sample_to_source_distance();

scitbx::af::shared<double> img_tof = scan.get_property<double>("time_of_flight");

int n_panels = detector.size();
int num_images = data.size();
vec2<std::size_t> image_size = detector[0].get_image_size();
DIALS_ASSERT(num_images == img_tof.size());

ShoeboxProcessor shoebox_processor(
reflection_table, n_panels, 0, num_images, false);

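    // Walk the image sequence, handing each frame's per-panel data and mask to the shoebox processor.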
for (std::size_t img_num = 0; img_num < num_images; ++img_num) {
dxtbx::format::Image<double> img = data.get_corrected_data(img_num);
dxtbx::format::Image<bool> mask = data.get_mask(img_num);

af::shared<scitbx::af::versa<double, scitbx::af::c_grid<2> > > output_data(
n_panels);
af::shared<scitbx::af::versa<bool, scitbx::af::c_grid<2> > > output_mask(
n_panels);

for (std::size_t i = 0; i < output_data.size(); ++i) {
output_data[i] = img.tile(i).data();
output_mask[i] = mask.tile(i).data();
}
shoebox_processor.next_data_only(
model::Image<double>(output_data.const_ref(), output_mask.const_ref()));
}
}

void tof_extract_shoeboxes_to_reflection_table(
af::reflection_table &reflection_table,
@@ -578,6 +620,14 @@
}
}

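  // Overloaded helpers exported via the test_func_a/test_func_b pointers in the
  // Boost.Python module; they only print their arguments to stdout.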
void test_func(double test_double) {
std::cout << "test func " << test_double << "\n";
}

void test_func(double test_double, int test_int) {
std::cout << "test_func " << test_double << " test_int " << test_int << "\n";
}

}} // namespace dials::algorithms

#endif /* DIALS_ALGORITHMS_SCALING_TOF_SCALING_CORRECTIONS_H */
@@ -29,6 +29,7 @@
#include <dials/algorithms/spot_prediction/stills_ray_predictor.h>
#include <dials/algorithms/spot_prediction/ray_intersection.h>
#include <cctbx/miller/index_generator.h>
#include <iostream>

namespace dials { namespace algorithms {

143 changes: 96 additions & 47 deletions src/dials/command_line/tof_integrate.py
@@ -25,9 +25,9 @@
from dials.util.options import ArgumentParser, reflections_and_experiments_from_files
from dials.util.phil import parse
from dials.util.version import dials_version
from dials_algorithms_integration_integrator_ext import (
from dials_algorithms_integration_integrator_ext import tof_calculate_shoebox_foreground
from dials_tof_scaling_ext import (
TOFCorrectionsData,
tof_calculate_shoebox_foreground,
tof_extract_shoeboxes_to_reflection_table,
)

@@ -73,9 +73,6 @@
lorentz = True
.type = bool
.help = "Apply the Lorentz correction to target spectrum."
spherical_absorption = True
.type = bool
.help = "Apply a spherical absorption correction."
}
incident_spectrum{
sample_number_density = 0.0722
@@ -241,7 +238,6 @@ def join_reflections(list_of_reflections):


def run():

"""
Input setup
"""
@@ -278,6 +274,37 @@ def run():
output_reflections_as_hkl(integrated_reflections, params.output.hkl)


def applying_spherical_absorption_correction(params):
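    # The correction is applied only when every target_spectrum parameter is set;
    # a partially specified set of parameters is treated as a configuration error.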
all_params_present = True
some_params_present = False
for i in dir(params.target_spectrum):
if i.startswith("__"):
continue
if getattr(params.target_spectrum, i) is not None:
some_params_present = True
else:
all_params_present = False
if some_params_present and not all_params_present:
raise ValueError(
"Trying to apply spherical absorption correction but some corrections are None."
)
return all_params_present


def applying_incident_and_empty_runs(params):
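    # Incident and empty runs must be supplied together; returns True when both are present.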
if params.input.incident_run is not None:
assert (
params.input.empty_run is not None
), "Incident run given without empty run."
return True
elif params.input.empty_run is not None:
assert (
params.input.incident_run is not None
), "Empty run given without incident run."
return True
return False


def run_integrate(params, experiments, reflections):
nproc = params.mp.nproc
if nproc is libtbx.Auto:
@@ -385,54 +412,76 @@ def run_integrate(params, experiments, reflections):
)

experiment_cls = experiments[0].imageset.get_format_class()
incident_fmt_class = experiment_cls.get_instance(params.input.incident_run)
empty_fmt_class = experiment_cls.get_instance(params.input.empty_run)

incident_data = experiment_cls(params.input.incident_run).get_imageset(
params.input.incident_run
)
empty_data = experiment_cls(params.input.empty_run).get_imageset(
params.input.empty_run
)
incident_proton_charge = incident_fmt_class.get_proton_charge()
empty_proton_charge = empty_fmt_class.get_proton_charge()

predicted_reflections.map_centroids_to_reciprocal_space(
experiments, calculated=True
)

for expt in experiments:

expt_proton_charge = experiment_cls.get_instance(
expt.imageset.paths()[0], **expt.imageset.data().get_params()
).get_proton_charge()
expt_data = expt.imageset
corrections_data = TOFCorrectionsData(
expt_proton_charge,
incident_proton_charge,
empty_proton_charge,
params.target_spectrum.sample_radius,
params.target_spectrum.scattering_x_section,
params.target_spectrum.absorption_x_section,
params.target_spectrum.sample_number_density,
params.incident_spectrum.sample_radius,
params.incident_spectrum.scattering_x_section,
params.incident_spectrum.absorption_x_section,
params.incident_spectrum.sample_number_density,
)
if applying_incident_and_empty_runs(params):
incident_fmt_class = experiment_cls.get_instance(params.input.incident_run)
empty_fmt_class = experiment_cls.get_instance(params.input.empty_run)

tof_extract_shoeboxes_to_reflection_table(
predicted_reflections,
expt,
expt_data,
incident_data,
empty_data,
corrections_data,
params.corrections.lorentz,
params.corrections.spherical_absorption,
incident_data = experiment_cls(params.input.incident_run).get_imageset(
params.input.incident_run
)
empty_data = experiment_cls(params.input.empty_run).get_imageset(
params.input.empty_run
)
incident_proton_charge = incident_fmt_class.get_proton_charge()
empty_proton_charge = empty_fmt_class.get_proton_charge()

for expt in experiments:
expt_data = expt.imageset
expt_proton_charge = experiment_cls.get_instance(
expt.imageset.paths()[0], **expt.imageset.data().get_params()
).get_proton_charge()

if applying_spherical_absorption_correction(params):
corrections_data = TOFCorrectionsData(
expt_proton_charge,
incident_proton_charge,
empty_proton_charge,
params.target_spectrum.sample_radius,
params.target_spectrum.scattering_x_section,
params.target_spectrum.absorption_x_section,
params.target_spectrum.sample_number_density,
params.incident_spectrum.sample_radius,
params.incident_spectrum.scattering_x_section,
params.incident_spectrum.absorption_x_section,
params.incident_spectrum.sample_number_density,
)

tof_extract_shoeboxes_to_reflection_table(
predicted_reflections,
expt,
expt_data,
incident_data,
empty_data,
corrections_data,
params.corrections.lorentz,
)
else:
tof_extract_shoeboxes_to_reflection_table(
predicted_reflections,
expt,
expt_data,
incident_data,
empty_data,
expt_proton_charge,
incident_proton_charge,
empty_proton_charge,
params.corrections.lorentz,
)
else:
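        # No incident or empty run data: extract shoeboxes without spectrum normalisation or absorption corrections.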
for expt in experiments:
expt_data = expt.imageset
tof_extract_shoeboxes_to_reflection_table(
predicted_reflections,
expt,
expt_data,
params.corrections.lorentz,
)

tof_calculate_shoebox_foreground(predicted_reflections, expt, 0.5)
predicted_reflections.is_overloaded(experiments)
predicted_reflections.contains_invalid_pixels()
predicted_reflections["partiality"] = flex.double(len(predicted_reflections), 1.0)
