From e684c41f38405a23c217c250de76066ee509f3b3 Mon Sep 17 00:00:00 2001 From: David McDonagh <60879630+toastisme@users.noreply.github.com> Date: Tue, 20 Feb 2024 11:25:58 +0000 Subject: [PATCH] Scan from dict fix (#693) * Fixed bug in `Scan.from_dict` where some properties were not correctly parsed. --- newsfragments/688.bugfix | 1 + src/dxtbx/model/boost_python/scan.cc | 11 ++- src/dxtbx/model/scan.py | 132 +++++++++++++++++++++------ tests/model/test_scan.py | 21 +++++ 4 files changed, 135 insertions(+), 30 deletions(-) create mode 100644 newsfragments/688.bugfix diff --git a/newsfragments/688.bugfix b/newsfragments/688.bugfix new file mode 100644 index 000000000..96cdd168b --- /dev/null +++ b/newsfragments/688.bugfix @@ -0,0 +1 @@ +Fixed `Scan.from_dict` bug where some properties were not correctly parsed. diff --git a/src/dxtbx/model/boost_python/scan.cc b/src/dxtbx/model/boost_python/scan.cc index 8a442d9f4..3c014e48f 100644 --- a/src/dxtbx/model/boost_python/scan.cc +++ b/src/dxtbx/model/boost_python/scan.cc @@ -76,7 +76,7 @@ namespace dxtbx { namespace model { namespace boost_python { value[0].attr("__class__").attr("__name__")); // Handled explicitly as it is in deg when serialised but rad in code - if (key == "oscillation") { + if (key == "oscillation" || key == "oscillation_width") { DXTBX_ASSERT(obj_type == "float"); scitbx::af::shared osc = boost::python::extract >(value); @@ -215,7 +215,14 @@ namespace dxtbx { namespace model { namespace boost_python { dxtbx::af::flex_table_suite::column_to_object_visitor visitor; for (const_iterator it = properties.begin(); it != properties.end(); ++it) { - if (it->first == "oscillation") { // Handled explicitly due to unit conversion + if (it->first + == "oscillation_width") { // Handled explicitly due to unit conversion + vec2 osc_deg = obj.get_oscillation_in_deg(); + boost::python::list lst = boost::python::list(); + lst.append(osc_deg[1]); + properties_dict[it->first] = lst; + } else if (it->first + == "oscillation") { // Handled explicitly due to unit conversion properties_dict[it->first] = boost::python::tuple(obj.get_oscillation_arr_in_deg()); } else { diff --git a/src/dxtbx/model/scan.py b/src/dxtbx/model/scan.py index 40cf591c5..e63f2b025 100644 --- a/src/dxtbx/model/scan.py +++ b/src/dxtbx/model/scan.py @@ -113,32 +113,99 @@ def from_dict(d, t=None): The scan model """ - def convert_oscillation_to_vec2(properties_dict): + def add_properties_table(scan_dict, num_images): """ - If oscillation is in properties_dict, - shared is converted to vec2 and - oscillation_width is removed (if present) to ensure - it is replaced correctly if updating t dict from d dict + Handles legacy case before Scan had a properties table. + Moves oscillation, epochs, and exposure times to a properties + table and adds this to scan_dict. """ - if "oscillation" not in properties_dict: - assert "oscillation_width" not in properties_dict - return properties_dict - if "oscillation_width" in properties_dict: - assert "oscillation" in properties_dict - properties_dict["oscillation"] = ( - properties_dict["oscillation"][0], - properties_dict["oscillation_width"][0], - ) - del properties_dict["oscillation_width"] - return properties_dict - properties_dict["oscillation"] = ( - properties_dict["oscillation"][0], - properties_dict["oscillation"][1] - properties_dict["oscillation"][0], + properties = {} + if scan_dict: + if "oscillation" in scan_dict: + if num_images == 1: + properties["oscillation_width"] = [scan_dict["oscillation"][1]] + properties["oscillation"] = [scan_dict["oscillation"][0]] + + else: + osc = scan_dict["oscillation"] + properties["oscillation"] = [ + osc[0] + (osc[1] - osc[0]) * i for i in range(num_images) + ] + del scan_dict["oscillation"] + if "exposure_time" in scan_dict: + properties["exposure_time"] = scan_dict["exposure_time"] + del scan_dict["exposure_time"] + if "epochs" in scan_dict: + properties["epochs"] = scan_dict["epochs"] + del scan_dict["epochs"] + scan_dict["properties"] = make_properties_table_consistent( + properties, num_images ) + return scan_dict + + def make_properties_table_consistent(properties, num_images): - return properties_dict + """ + Handles legacy case before Scan had a properties table. + Ensures oscillation, epochs, and exposure times have the same length. + """ + + if not properties: + return properties + + property_length = len(next(iter(properties.values()))) + all_same_length = all( + len(lst) == property_length for lst in properties.values() + ) + if all_same_length and property_length == num_images: + return properties + + if "oscillation" in properties: + assert len(properties["oscillation"]) > 0 + + if num_images == 1 and "oscillation_width" not in properties: + assert len(properties["oscillation"]) > 1 + properties["oscillation_width"] = [properties["oscillation"][1]] + properties["oscillation"] = [properties["oscillation"][0]] + elif num_images > 1: + osc_0 = properties["oscillation"][0] + if "oscillation_width" in properties: + osc_1 = properties["oscillation_width"][0] + del properties["oscillation_width"] + else: + assert len(properties["oscillation"]) > 1 + osc_1 = ( + properties["oscillation"][1] - properties["oscillation"][0] + ) + properties["oscillation"] = [ + osc_0 + osc_1 * i for i in range(num_images) + ] + + if "exposure_time" in properties: + assert len(properties["exposure_time"]) > 0 + + # Assume same exposure time for each image + properties["exposure_time"] = [ + properties["exposure_time"][0] for i in range(num_images) + ] + + if "epochs" in properties: + assert len(properties["epochs"]) > 0 + # If 1 epoch, assume increasing by epochs[0] + # Else assume increasing as epochs[1] - epochs[0] + if len(properties["epochs"]) == 1: + properties["epochs"] = [ + properties["epochs"][0] + properties["epochs"][0] * i + for i in range(num_images) + ] + else: + diff = properties["epochs"][1] - properties["epochs"][0] + properties["epochs"] = [ + properties["epochs"][0] + i * diff for i in range(num_images) + ] + return properties if d is None and t is None: return None @@ -146,26 +213,35 @@ def convert_oscillation_to_vec2(properties_dict): # Accounting for legacy cases where t or d does not # contain properties dict + num_images = None + if "image_range" in d: + num_images = 1 + d["image_range"][1] - d["image_range"][0] + elif "image_range" in joint: + num_images = 1 + joint["image_range"][1] - joint["image_range"][0] + if "properties" in joint and "properties" in d: properties = t["properties"].copy() properties.update(d["properties"]) joint.update(d) joint["properties"] = properties elif "properties" in d: + joint = add_properties_table(joint, num_images) d_copy = d.copy() - d_copy["properties"] = convert_oscillation_to_vec2( - d_copy["properties"].copy() + joint["properties"].update(d_copy["properties"]) + joint["properties"] = make_properties_table_consistent( + joint["properties"], num_images ) - joint.update(**d_copy["properties"]) del d_copy["properties"] joint.update(d_copy) elif "properties" in joint: - joint["properties"] = convert_oscillation_to_vec2( - joint["properties"].copy() + d = add_properties_table(d, num_images) + d_copy = d.copy() + joint["properties"].update(d_copy["properties"]) + joint["properties"] = make_properties_table_consistent( + joint["properties"], num_images ) - joint.update(**joint["properties"]) - del joint["properties"] - joint.update(d) + del d_copy["properties"] + joint.update(d_copy) else: joint.update(d) diff --git a/tests/model/test_scan.py b/tests/model/test_scan.py index 76dc3154a..3a17c4e59 100644 --- a/tests/model/test_scan.py +++ b/tests/model/test_scan.py @@ -471,3 +471,24 @@ def test_scan_is_still(): (1, 10), properties={"other_property": list(range(10))} ) assert scan.is_still() + +def test_scan_properties_from_dict(): + image_range = (1, 10) + properties = {"test": list(range(10))} + scan = ScanFactory.make_scan_from_properties(image_range, properties) + assert scan == ScanFactory.from_dict(scan.to_dict()) + + image_range = (1, 1) + properties = {"oscillation": [1.0], "oscillation_width": [0.5]} + scan = ScanFactory.make_scan_from_properties(image_range, properties) + assert scan == ScanFactory.from_dict(scan.to_dict()) + + scan = ScanFactory.from_dict( + { + "oscillation": [1.0, 0.5], + "image_range": [1, 1], + "exposure_time": [0.5], + "epochs": [1], + } + ) + assert scan == ScanFactory.from_dict(scan.to_dict())