Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scan from dict fix #693

Merged
merged 8 commits into from
Feb 20, 2024
1 change: 1 addition & 0 deletions newsfragments/688.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed `Scan.from_dict` bug where some properties were not correctly parsed.
11 changes: 9 additions & 2 deletions src/dxtbx/model/boost_python/scan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ namespace dxtbx { namespace model { namespace boost_python {
value[0].attr("__class__").attr("__name__"));

// Handled explicitly as it is in deg when serialised but rad in code
if (key == "oscillation") {
if (key == "oscillation" || key == "oscillation_width") {
DXTBX_ASSERT(obj_type == "float");
scitbx::af::shared<double> osc =
boost::python::extract<scitbx::af::shared<double> >(value);
Expand Down Expand Up @@ -215,7 +215,14 @@ namespace dxtbx { namespace model { namespace boost_python {
dxtbx::af::flex_table_suite::column_to_object_visitor visitor;

for (const_iterator it = properties.begin(); it != properties.end(); ++it) {
if (it->first == "oscillation") { // Handled explicitly due to unit conversion
if (it->first
== "oscillation_width") { // Handled explicitly due to unit conversion
vec2<double> osc_deg = obj.get_oscillation_in_deg();
boost::python::list lst = boost::python::list();
lst.append(osc_deg[1]);
properties_dict[it->first] = lst;
} else if (it->first
== "oscillation") { // Handled explicitly due to unit conversion
properties_dict[it->first] =
boost::python::tuple(obj.get_oscillation_arr_in_deg());
} else {
Expand Down
132 changes: 104 additions & 28 deletions src/dxtbx/model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,59 +113,135 @@
The scan model
"""

def convert_oscillation_to_vec2(properties_dict):
def add_properties_table(scan_dict, num_images):

"""
If oscillation is in properties_dict,
shared<double> is converted to vec2<double> and
oscillation_width is removed (if present) to ensure
it is replaced correctly if updating t dict from d dict
Handles legacy case before Scan had a properties table.
Moves oscillation, epochs, and exposure times to a properties
table and adds this to scan_dict.
"""

if "oscillation" not in properties_dict:
assert "oscillation_width" not in properties_dict
return properties_dict
if "oscillation_width" in properties_dict:
assert "oscillation" in properties_dict
properties_dict["oscillation"] = (
properties_dict["oscillation"][0],
properties_dict["oscillation_width"][0],
)
del properties_dict["oscillation_width"]
return properties_dict
properties_dict["oscillation"] = (
properties_dict["oscillation"][0],
properties_dict["oscillation"][1] - properties_dict["oscillation"][0],
properties = {}
if scan_dict:
if "oscillation" in scan_dict:
if num_images == 1:
properties["oscillation_width"] = [scan_dict["oscillation"][1]]
properties["oscillation"] = [scan_dict["oscillation"][0]]

Check warning on line 129 in src/dxtbx/model/scan.py

View check run for this annotation

Codecov / codecov/patch

src/dxtbx/model/scan.py#L128-L129

Added lines #L128 - L129 were not covered by tests

else:
osc = scan_dict["oscillation"]

Check warning on line 132 in src/dxtbx/model/scan.py

View check run for this annotation

Codecov / codecov/patch

src/dxtbx/model/scan.py#L132

Added line #L132 was not covered by tests
properties["oscillation"] = [
osc[0] + (osc[1] - osc[0]) * i for i in range(num_images)
]
del scan_dict["oscillation"]

Check warning on line 136 in src/dxtbx/model/scan.py

View check run for this annotation

Codecov / codecov/patch

src/dxtbx/model/scan.py#L136

Added line #L136 was not covered by tests
if "exposure_time" in scan_dict:
properties["exposure_time"] = scan_dict["exposure_time"]
del scan_dict["exposure_time"]
if "epochs" in scan_dict:
properties["epochs"] = scan_dict["epochs"]
del scan_dict["epochs"]
scan_dict["properties"] = make_properties_table_consistent(
properties, num_images
)
return scan_dict

def make_properties_table_consistent(properties, num_images):

return properties_dict
"""
Handles legacy case before Scan had a properties table.
Ensures oscillation, epochs, and exposure times have the same length.
"""

if not properties:
return properties

property_length = len(next(iter(properties.values())))
all_same_length = all(
len(lst) == property_length for lst in properties.values()
)
if all_same_length and property_length == num_images:
return properties

if "oscillation" in properties:
assert len(properties["oscillation"]) > 0

if num_images == 1 and "oscillation_width" not in properties:
assert len(properties["oscillation"]) > 1
properties["oscillation_width"] = [properties["oscillation"][1]]
properties["oscillation"] = [properties["oscillation"][0]]

Check warning on line 171 in src/dxtbx/model/scan.py

View check run for this annotation

Codecov / codecov/patch

src/dxtbx/model/scan.py#L169-L171

Added lines #L169 - L171 were not covered by tests
elif num_images > 1:
osc_0 = properties["oscillation"][0]
if "oscillation_width" in properties:
osc_1 = properties["oscillation_width"][0]
del properties["oscillation_width"]

Check warning on line 176 in src/dxtbx/model/scan.py

View check run for this annotation

Codecov / codecov/patch

src/dxtbx/model/scan.py#L175-L176

Added lines #L175 - L176 were not covered by tests
else:
assert len(properties["oscillation"]) > 1
osc_1 = (
properties["oscillation"][1] - properties["oscillation"][0]
)
properties["oscillation"] = [
osc_0 + osc_1 * i for i in range(num_images)
]

if "exposure_time" in properties:
assert len(properties["exposure_time"]) > 0

# Assume same exposure time for each image
properties["exposure_time"] = [
properties["exposure_time"][0] for i in range(num_images)
]

if "epochs" in properties:
assert len(properties["epochs"]) > 0
# If 1 epoch, assume increasing by epochs[0]
# Else assume increasing as epochs[1] - epochs[0]
if len(properties["epochs"]) == 1:
properties["epochs"] = [
properties["epochs"][0] + properties["epochs"][0] * i
for i in range(num_images)
]
else:
diff = properties["epochs"][1] - properties["epochs"][0]
properties["epochs"] = [
properties["epochs"][0] + i * diff for i in range(num_images)
]
return properties

if d is None and t is None:
return None
joint = t.copy() if t else {}

# Accounting for legacy cases where t or d does not
# contain properties dict
num_images = None
if "image_range" in d:
num_images = 1 + d["image_range"][1] - d["image_range"][0]
elif "image_range" in joint:
num_images = 1 + joint["image_range"][1] - joint["image_range"][0]

if "properties" in joint and "properties" in d:
properties = t["properties"].copy()
properties.update(d["properties"])
joint.update(d)
joint["properties"] = properties
elif "properties" in d:
joint = add_properties_table(joint, num_images)
d_copy = d.copy()
d_copy["properties"] = convert_oscillation_to_vec2(
d_copy["properties"].copy()
joint["properties"].update(d_copy["properties"])
joint["properties"] = make_properties_table_consistent(
joint["properties"], num_images
)
joint.update(**d_copy["properties"])
del d_copy["properties"]
joint.update(d_copy)
elif "properties" in joint:
joint["properties"] = convert_oscillation_to_vec2(
joint["properties"].copy()
d = add_properties_table(d, num_images)
d_copy = d.copy()
joint["properties"].update(d_copy["properties"])
joint["properties"] = make_properties_table_consistent(
joint["properties"], num_images
)
joint.update(**joint["properties"])
del joint["properties"]
joint.update(d)
del d_copy["properties"]
joint.update(d_copy)
else:
joint.update(d)

Expand Down
21 changes: 21 additions & 0 deletions tests/model/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,24 @@ def test_scan_is_still():
(1, 10), properties={"other_property": list(range(10))}
)
assert scan.is_still()

def test_scan_properties_from_dict():
image_range = (1, 10)
properties = {"test": list(range(10))}
scan = ScanFactory.make_scan_from_properties(image_range, properties)
assert scan == ScanFactory.from_dict(scan.to_dict())

image_range = (1, 1)
properties = {"oscillation": [1.0], "oscillation_width": [0.5]}
scan = ScanFactory.make_scan_from_properties(image_range, properties)
assert scan == ScanFactory.from_dict(scan.to_dict())

scan = ScanFactory.from_dict(
{
"oscillation": [1.0, 0.5],
"image_range": [1, 1],
"exposure_time": [0.5],
"epochs": [1],
}
)
assert scan == ScanFactory.from_dict(scan.to_dict())
Loading