Skip to content

Commit

Permalink
Add ExperimentList.all_same_type().
Browse files Browse the repository at this point in the history
  • Loading branch information
toastisme committed Sep 17, 2024
1 parent 9e26747 commit 8ce275c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 0 deletions.
1 change: 1 addition & 0 deletions newsfragments/XXX.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `ExperimentList.all_same_type()` to quickly check if experiments are all the same type.
10 changes: 10 additions & 0 deletions src/dxtbx/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,16 @@ def all_laue(self):
"""Check if all the experiments are Laue experiments"""
return all(exp.get_type() == ExperimentType.LAUE for exp in self)

def all_same_type(self):
"""Check if all experiments are the same type"""
if len(self) <= 1:
return True
expt_type = self[0].get_type()
for i in range(1, len(self)):
if self[i].get_type() != expt_type:
return False
return True

def to_dict(self):
"""Serialize the experiment list to dictionary."""

Expand Down
14 changes: 14 additions & 0 deletions tests/model/test_experiment_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,15 +1205,19 @@ def test_from_templates(dials_data):

def test_experiment_list_all():
experiments = ExperimentList()
assert experiments.all_same_type()
for i in range(3):
experiments.append(Experiment())

assert experiments.all_stills()
assert experiments.all_same_type()
experiments[0].goniometer = Goniometer()
assert experiments.all_stills()
assert experiments.all_same_type()
experiments[1].goniometer = Goniometer()
experiments[2].goniometer = Goniometer()
assert experiments.all_stills()
assert experiments.all_same_type()

experiments[0].beam = BeamFactory.make_polychromatic_beam(
direction=(0, 0, -1),
Expand All @@ -1222,6 +1226,7 @@ def test_experiment_list_all():
wavelength_range=(1, 10),
)
assert not experiments.all_stills()
assert not experiments.all_same_type()
experiments[1].beam = BeamFactory.make_polychromatic_beam(
direction=(0, 0, -1),
sample_to_source_distance=(100),
Expand All @@ -1235,39 +1240,48 @@ def test_experiment_list_all():
wavelength_range=(1, 10),
)
assert experiments.all_laue()
assert experiments.all_same_type()

experiments[0].beam = Beam()
assert not experiments.all_laue()
assert not experiments.all_same_type()
experiments[1].beam = Beam()
experiments[2].beam = Beam()
assert experiments.all_stills()
assert experiments.all_same_type()

experiments[0].scan = Scan((1, 1000), (0, 0.05))
assert not experiments.all_stills()
assert not experiments.all_same_type()
experiments[1].scan = Scan((1, 1000), (0, 0.05))
experiments[2].scan = Scan((1, 1000), (0, 0.05))
assert experiments.all_rotations()
assert experiments.all_same_type()

experiments[0].scan = ScanFactory.make_scan_from_properties(
(1, 10), properties={"time_of_flight": list(range(10))}
)
assert not experiments.all_rotations()
assert not experiments.all_same_type()
experiments[1].scan = ScanFactory.make_scan_from_properties(
(1, 10), properties={"time_of_flight": list(range(10))}
)
experiments[2].scan = ScanFactory.make_scan_from_properties(
(1, 10), properties={"time_of_flight": list(range(10))}
)
assert experiments.all_tof()
assert experiments.all_same_type()

experiments[0].scan = ScanFactory.make_scan_from_properties(
(1, 10), properties={"other_property": list(range(10))}
)
assert not experiments.all_tof()
assert not experiments.all_same_type()
experiments[1].scan = ScanFactory.make_scan_from_properties(
(1, 10), properties={"other_property": list(range(10))}
)
experiments[2].scan = ScanFactory.make_scan_from_properties(
(1, 10), properties={"other_property": list(range(10))}
)
assert experiments.all_stills()
assert experiments.all_same_type()

0 comments on commit 8ce275c

Please sign in to comment.