From 8e35c8dee8243e2c4eef16ffca717cfb7a7909d2 Mon Sep 17 00:00:00 2001
From: davidmcdonagh
Date: Thu, 14 Dec 2023 15:10:02 +0000
Subject: [PATCH] Enable MTZ export for ToF data

---
 src/dials/command_line/export.py |  48 +++--
 src/dials/util/export_mtz.py     | 337 ++++++++++++++++++++++++++++++-
 2 files changed, 364 insertions(+), 21 deletions(-)

diff --git a/src/dials/command_line/export.py b/src/dials/command_line/export.py
index 007e58fc1b..0e3e9e304d 100644
--- a/src/dials/command_line/export.py
+++ b/src/dials/command_line/export.py
@@ -333,7 +333,7 @@ def export_mtz(params, experiments, reflections):

     _check_input(experiments, reflections)

-    from dials.util.export_mtz import export_mtz
+    from dials.util.export_mtz import export_mtz, export_mtz_tof

     # Handle case where user has passed data before integration
     if (
@@ -359,21 +359,37 @@ def export_mtz(params, experiments, reflections):
             "Data appears to be unscaled, setting mtz.hklout = 'integrated.mtz'"
         )

-    m = export_mtz(
-        reflection_table,
-        experiments,
-        intensity_choice=params.intensity,
-        filename=filename,
-        best_unit_cell=params.mtz.best_unit_cell,
-        partiality_threshold=params.mtz.partiality_threshold,
-        combine_partials=params.mtz.combine_partials,
-        min_isigi=params.mtz.min_isigi,
-        filter_ice_rings=params.mtz.filter_ice_rings,
-        d_min=params.mtz.d_min,
-        force_static_model=params.mtz.force_static_model,
-        crystal_name=params.mtz.crystal_name,
-        project_name=params.mtz.project_name,
-    )
+    if experiments.all_tof_experiments():
+        m = export_mtz_tof(
+            reflection_table,
+            experiments,
+            intensity_choice=params.intensity,
+            filename=filename,
+            best_unit_cell=params.mtz.best_unit_cell,
+            partiality_threshold=params.mtz.partiality_threshold,
+            combine_partials=params.mtz.combine_partials,
+            min_isigi=params.mtz.min_isigi,
+            filter_ice_rings=params.mtz.filter_ice_rings,
+            d_min=params.mtz.d_min,
+            crystal_name=params.mtz.crystal_name,
+            project_name=params.mtz.project_name,
+        )
+    else:
+        m = export_mtz(
+            reflection_table,
+            experiments,
+            intensity_choice=params.intensity,
+            filename=filename,
+            best_unit_cell=params.mtz.best_unit_cell,
+            partiality_threshold=params.mtz.partiality_threshold,
+            combine_partials=params.mtz.combine_partials,
+            min_isigi=params.mtz.min_isigi,
+            filter_ice_rings=params.mtz.filter_ice_rings,
+            d_min=params.mtz.d_min,
+            force_static_model=params.mtz.force_static_model,
+            crystal_name=params.mtz.crystal_name,
+            project_name=params.mtz.project_name,
+        )

     summary = StringIO()
     m.show_summary(out=summary)
diff --git a/src/dials/util/export_mtz.py b/src/dials/util/export_mtz.py
index 9536a5bf32..5be2ce1b69 100644
--- a/src/dials/util/export_mtz.py
+++ b/src/dials/util/export_mtz.py
@@ -7,7 +7,7 @@

 import numpy as np

-from cctbx import uctbx
+from dxtbx.model import Scan
 from iotbx import mtz
 from libtbx import env
 from rstbx.cftbx.coordinate_frame_helpers import align_reference_frame
@@ -169,14 +169,13 @@ def add_batch_list(
             UBlab = S * F * matrix.sqr(experiment.crystal.get_A())

             axis = matrix.col(experiment.goniometer.get_rotation_axis())
-            axis_datum = matrix.col(experiment.goniometer.get_rotation_axis_datum())
         else:
             UBlab = matrix.sqr(experiment.crystal.get_A())

         i0 = image_range[0]
         for i in range(n_batches):
-            if experiment.sequence:
+            if isinstance(experiment.sequence, Scan):
                 phi_start[i], phi_range[i] = experiment.sequence.get_image_oscillation(
                     i + i0
                 )
@@ -190,8 +189,7 @@
                 # Get the index of the image in the sequence e.g. first => 0, second => 1
                 image_index = i + i0 - experiment.sequence.get_image_range()[0]
-                _unit_cell = experiment.crystal.get_unit_cell_at_scan_point(image_index)
-                _U = matrix.sqr(experiment.crystal.get_U_at_scan_point(image_index))
+                unit_cell = experiment.crystal.get_unit_cell_at_scan_point(image_index)
             else:
                 unit_cell = experiment.crystal.get_unit_cell()
                 _UBlab = UBlab
@@ -248,6 +246,159 @@
             source,
         )

+    def write_columns_tof(self, reflection_table):
+        """Write the column definitions AND data to the current dataset."""
+
+        # now create the actual data structures - first keep track of the columns
+
+        # H K L M/ISYM BATCH I SIGI IPR SIGIPR FRACTIONCALC XDET YDET ROT WIDTH
+        # LP MPART FLAG BGPKRATIOS
+
+        # gather the required information for the reflection file
+
+        nref = len(reflection_table["miller_index"])
+        assert nref
+        xdet, ydet, _ = [
+            flex.double(x) for x in reflection_table["xyzobs.px.value"].parts()
+        ]
+
+        # now add column information...
+
+        # FIXME add DIALS_FLAG which can include e.g. whether the reflection was partial
+
+        type_table = {
+            "H": "H",
+            "K": "H",
+            "L": "H",
+            "I": "J",
+            "SIGI": "Q",
+            "IPR": "J",
+            "SIGIPR": "Q",
+            "BG": "R",
+            "SIGBG": "R",
+            "XDET": "R",
+            "YDET": "R",
+            "BATCH": "B",
+            "BGPKRATIOS": "R",
+            "WIDTH": "R",
+            "MPART": "I",
+            "M_ISYM": "Y",
+            "FLAG": "I",
+            "LP": "R",
+            "FRACTIONCALC": "R",
+            "ROT": "R",
+            "QE": "R",
+            "LAMBDA": "R",
+        }
+
+        # derive index columns from original indices with
+        #
+        # from m.replace_original_index_miller_indices
+        #
+        # so all that is needed now is to make space for the reflections - fill with
+        # zeros...
+
+        self.mtz_file.adjust_column_array_sizes(nref)
+        self.mtz_file.set_n_reflections(nref)
+        dataset = self.current_dataset
+
+        # assign H, K, L, M_ISYM space
+        for column in "H", "K", "L", "M_ISYM":
+            dataset.add_column(column, type_table[column]).set_values(
+                flex.double(nref, 0.0).as_float()
+            )
+
+        self.mtz_file.replace_original_index_miller_indices(
+            reflection_table["miller_index"]
+        )
+
+        dataset.add_column("BATCH", type_table["BATCH"]).set_values(
+            reflection_table["batch"].as_double().as_float()
+        )
+
+        # if intensity values used in scaling exist, then just export these as I, SIGI
+        if "intensity.scale.value" in reflection_table:
+            I_scaling = reflection_table["intensity.scale.value"]
+            V_scaling = reflection_table["intensity.scale.variance"]
+            # Trap negative variances
+            assert V_scaling.all_gt(0)
+            dataset.add_column("I", type_table["I"]).set_values(I_scaling.as_float())
+            dataset.add_column("SIGI", type_table["SIGI"]).set_values(
+                flex.sqrt(V_scaling).as_float()
+            )
+            dataset.add_column("SCALEUSED", "R").set_values(
+                reflection_table["inverse_scale_factor"].as_float()
+            )
+            dataset.add_column("SIGSCALEUSED", "R").set_values(
+                flex.sqrt(reflection_table["inverse_scale_factor_variance"]).as_float()
+            )
+        else:
+            if "intensity.prf.value" in reflection_table:
+                if "intensity.sum.value" in reflection_table:
+                    col_names = ("IPR", "SIGIPR")
+                else:
+                    col_names = ("I", "SIGI")
+                I_profile = reflection_table["intensity.prf.value"]
+                V_profile = reflection_table["intensity.prf.variance"]
+                # Trap negative variances
+                assert V_profile.all_gt(0)
+                dataset.add_column(col_names[0], type_table["I"]).set_values(
+                    I_profile.as_float()
+                )
+                dataset.add_column(col_names[1], type_table["SIGI"]).set_values(
+                    flex.sqrt(V_profile).as_float()
+                )
+            if "intensity.sum.value" in reflection_table:
+                I_sum = reflection_table["intensity.sum.value"]
+                V_sum = reflection_table["intensity.sum.variance"]
+                # Trap negative variances
+                assert V_sum.all_gt(0)
+                dataset.add_column("I", type_table["I"]).set_values(I_sum.as_float())
+                dataset.add_column("SIGI", type_table["SIGI"]).set_values(
+                    flex.sqrt(V_sum).as_float()
+                )
+        if (
+            "background.sum.value" in reflection_table
+            and "background.sum.variance" in reflection_table
+        ):
+            bg = reflection_table["background.sum.value"]
+            varbg = reflection_table["background.sum.variance"]
+            assert (varbg >= 0).count(False) == 0
+            sigbg = flex.sqrt(varbg)
+            dataset.add_column("BG", type_table["BG"]).set_values(bg.as_float())
+            dataset.add_column("SIGBG", type_table["SIGBG"]).set_values(
+                sigbg.as_float()
+            )
+
+        dataset.add_column("FRACTIONCALC", type_table["FRACTIONCALC"]).set_values(
+            reflection_table["fractioncalc"].as_float()
+        )
+        dataset.add_column("LAMBDA", type_table["LAMBDA"]).set_values(
+            reflection_table["wavelength_cal"].as_float()
+        )
+
+        dataset.add_column("XDET", type_table["XDET"]).set_values(xdet.as_float())
+        dataset.add_column("YDET", type_table["YDET"]).set_values(ydet.as_float())
+        dataset.add_column("ROT", type_table["ROT"]).set_values(
+            reflection_table["ROT"].as_float()
+        )
+        if "lp" in reflection_table:
+            dataset.add_column("LP", type_table["LP"]).set_values(
+                reflection_table["lp"].as_float()
+            )
+        if "qe" in reflection_table:
+            dataset.add_column("QE", type_table["QE"]).set_values(
+                reflection_table["qe"].as_float()
+            )
+        elif "dqe" in reflection_table:
+            dataset.add_column("QE", type_table["QE"]).set_values(
+                reflection_table["dqe"].as_float()
+            )
+        else:
+            dataset.add_column("QE", type_table["QE"]).set_values(
+                flex.double(nref, 1.0).as_float()
+            )
+
     def write_columns(self, reflection_table):
         """Write the column definitions AND data to the current dataset."""

@@ -398,6 +549,182 @@ def write_columns(self, reflection_table):
         )


+def export_mtz_tof(
+    reflection_table,
+    experiment_list,
+    intensity_choice,
+    filename,
+    best_unit_cell=None,
+    partiality_threshold=0.4,
+    combine_partials=True,
+    min_isigi=-5,
+    filter_ice_rings=False,
+    d_min=None,
+    crystal_name=None,
+    project_name=None,
+):
+    """Export data from reflection_table corresponding to experiment_list to the
+    MTZ file given by filename."""
+
+    # First get the experiment identifier information out of the data
+    expids_in_table = reflection_table.experiment_identifiers()
+    if not list(expids_in_table.keys()):
+        reflection_tables = parse_multiple_datasets([reflection_table])
+        experiment_list, refl_list = assign_unique_identifiers(
+            experiment_list, reflection_tables
+        )
+        reflection_table = flex.reflection_table()
+        for reflections in refl_list:
+            reflection_table.extend(reflections)
+        expids_in_table = reflection_table.experiment_identifiers()
+    reflection_table.assert_experiment_identifiers_are_consistent(experiment_list)
+    expids_in_list = list(experiment_list.identifiers())
+
+    # Convert geometry to the Cambridge frame
+    experiment_list = convert_to_cambridge(experiment_list)
+
+    # Convert experiment_list to a real python list or else identity assumptions
+    # fail like:
+    #   assert experiment_list[0] is experiment_list[0]
+    # And assumptions about added attributes break
+    experiment_list = list(experiment_list)
+
+    # Validate multi-experiment assumptions
+    if len(experiment_list) > 1:
+        # All experiments should match crystals, or else we need multiple crystals/datasets
+        if not all(
+            x.crystal == experiment_list[0].crystal for x in experiment_list[1:]
+        ):
+            logger.warning(
+                "Experiment crystals differ. Using first experiment crystal for file-level data."
+            )
+
+        # At least, all experiments must have the same space group
+        if len({x.crystal.get_space_group().make_tidy() for x in experiment_list}) != 1:
+            raise ValueError("Experiments do not have a unique space group")
+
+    # Also, this only works correctly with one panel (for the moment)
+    if any(len(experiment.detector) != 1 for experiment in experiment_list):
+        logger.warning("Ignoring multiple panels in output MTZ")
+
+    if best_unit_cell is None:
+        best_unit_cell = determine_best_unit_cell(experiment_list)
+    reflection_table["d"] = best_unit_cell.d(reflection_table["miller_index"])
+
+    # Clean up the data with the passed in options
+    reflection_table = filter_reflection_table(
+        reflection_table,
+        intensity_choice=intensity_choice,
+        partiality_threshold=partiality_threshold,
+        combine_partials=combine_partials,
+        min_isigi=min_isigi,
+        filter_ice_rings=filter_ice_rings,
+        d_min=d_min,
+    )
+
+    # get batch offsets and image ranges - even for scanless experiments
+    batch_offsets = [
+        expt.sequence.get_batch_offset()
+        for expt in experiment_list
+        if expt.sequence is not None
+    ]
+    unique_offsets = set(batch_offsets)
+    if len(unique_offsets) <= 1:
+        logger.debug("Calculating new batches")
+        batch_offsets = calculate_batch_offsets(experiment_list)
+        batch_starts = [
+            e.sequence.get_image_range()[0] if e.sequence else 0
+            for e in experiment_list
+        ]
+        effective_offsets = [o + s for o, s in zip(batch_offsets, batch_starts)]
+        unique_offsets = set(effective_offsets)
+    else:
+        logger.debug("Keeping existing batches")
+    image_ranges = get_image_ranges(experiment_list)
+    if len(unique_offsets) != len(batch_offsets):
+        raise ValueError(
+            "Duplicate batch offsets detected: %s"
+            % ", ".join(
+                str(item) for item, count in Counter(batch_offsets).items() if count > 1
+            )
+        )
+
+    # Create the mtz file
+    mtz_writer = UnmergedMTZWriter(experiment_list[0].crystal.get_space_group())
+
+    # FIXME TODO for more than one experiment into an MTZ file:
+    #
+    # - add an epoch (or recover an epoch) from the scan and add this as an extra
+    #   column to the MTZ file for scaling, so we know that the two lattices were
+    #   integrated at the same time
+    # ✓ decide a sensible BATCH increment to apply to the BATCH value between
+    #   experiments and add this
+
+    for id_ in expids_in_table.keys():
+        # Grab our subset of the data
+        loc = expids_in_list.index(
+            expids_in_table[id_]
+        )  # get strid and use to find loc in list
+        experiment = experiment_list[loc]
+        wavelength = -1
+        dataset_id = 1
+        reflections = reflection_table.select(reflection_table["id"] == id_)
+        batch_offset = batch_offsets[loc]
+        image_range = image_ranges[loc]
+        reflections = assign_batches_to_reflections([reflections], [batch_offset])[0]
+        experiment.data = dict(reflections)
+
+        mtz_writer.add_batch_list(
+            image_range,
+            experiment,
+            wavelength,
+            dataset_id,
+            batch_offset=batch_offset,
+            force_static_model=True,
+        )
+
+        # Create the batch offset array. This gives us an experiment (id)-dependent
+        # batch offset to calculate the correct batch from image number.
+        experiment.data["batch_offset"] = flex.int(
+            len(experiment.data["id"]), batch_offset
+        )
+
+        # Set the ROT column for this experiment from the calculated frame (z) coordinate
+        _, _, z = experiment.data["xyzcal.px"].parts()
+        experiment.data["ROT"] = z
+
+    mtz_writer.add_crystal(
+        crystal_name=crystal_name,
+        project_name=project_name,
+        unit_cell=best_unit_cell,
+    )
+
+    mtz_writer.add_empty_dataset(wavelength)
+
+    # Combine all of the experiment data columns before writing
+    combined_data = {k: v.deep_copy() for k, v in experiment_list[0].data.items()}
+    for experiment in experiment_list[1:]:
+        for k, v in experiment.data.items():
+            combined_data[k].extend(v)
+    # ALL columns must be the same length
+    assert len({len(v) for v in combined_data.values()}) == 1, "Column length mismatch"
+    assert len(combined_data["id"]) == len(
+        reflection_table["id"]
+    ), "Lost rows in split/combine"
+
+    # Write all the data and columns to the mtz file
+    mtz_writer.write_columns_tof(combined_data)
+
+    logger.info(
+        "Saving %s integrated reflections to %s", len(combined_data["id"]), filename
+    )
+    mtz_file = mtz_writer.mtz_file
+    mtz_file.write(filename)

+    return mtz_file
+
+
 def export_mtz(
     reflection_table,
     experiment_list,
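Note on usage: with this patch applied, dials.export routes to export_mtz_tof whenever every input experiment is a ToF experiment, gated by experiments.all_tof_experiments() as shown above, and otherwise falls back to the existing export_mtz. Below is a minimal sketch of driving the new code path directly from Python; the input file names and the intensity_choice value are illustrative assumptions, not fixed by this patch:

    from dxtbx.model.experiment_list import ExperimentListFactory

    from dials.array_family import flex
    from dials.util.export_mtz import export_mtz_tof

    # Hypothetical file names for integrated ToF data
    experiments = ExperimentListFactory.from_json_file("integrated.expt")
    reflections = flex.reflection_table.from_file("integrated.refl")

    # The same gate used in dials.command_line.export above
    if experiments.all_tof_experiments():
        mtz_file = export_mtz_tof(
            reflections,
            experiments,
            intensity_choice=["profile", "sum"],  # assumed; mirrors params.intensity
            filename="integrated.mtz",
        )

Two design points visible in the diff: export_mtz_tof deliberately takes no force_static_model argument and always calls add_batch_list with force_static_model=True; and because ToF data are polychromatic, the dataset wavelength is set to the placeholder -1 while a per-reflection LAMBDA column is written from wavelength_cal in write_columns_tof.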