diff --git a/newsfragments/XXX.misc b/newsfragments/XXX.misc new file mode 100644 index 000000000..8bc4cb074 --- /dev/null +++ b/newsfragments/XXX.misc @@ -0,0 +1 @@ +Add `ExperimentList.all_same_type()` to quickly check if experiments are all the same type. diff --git a/src/dxtbx/model/__init__.py b/src/dxtbx/model/__init__.py index c23bf5747..2608a12ce 100644 --- a/src/dxtbx/model/__init__.py +++ b/src/dxtbx/model/__init__.py @@ -619,6 +619,16 @@ def all_laue(self): """Check if all the experiments are Laue experiments""" return all(exp.get_type() == ExperimentType.LAUE for exp in self) + def all_same_type(self): + """Check if all experiments are the same type""" + if len(self) <= 1: + return True + expt_type = self[0].get_type() + for i in range(1, len(self)): + if self[i].get_type() != expt_type: + return False + return True + def to_dict(self): """Serialize the experiment list to dictionary.""" diff --git a/tests/model/test_experiment_list.py b/tests/model/test_experiment_list.py index 47c324e29..2a19e2ec5 100644 --- a/tests/model/test_experiment_list.py +++ b/tests/model/test_experiment_list.py @@ -1205,15 +1205,19 @@ def test_from_templates(dials_data): def test_experiment_list_all(): experiments = ExperimentList() + assert experiments.all_same_type() for i in range(3): experiments.append(Experiment()) assert experiments.all_stills() + assert experiments.all_same_type() experiments[0].goniometer = Goniometer() assert experiments.all_stills() + assert experiments.all_same_type() experiments[1].goniometer = Goniometer() experiments[2].goniometer = Goniometer() assert experiments.all_stills() + assert experiments.all_same_type() experiments[0].beam = BeamFactory.make_polychromatic_beam( direction=(0, 0, -1), @@ -1222,6 +1226,7 @@ def test_experiment_list_all(): wavelength_range=(1, 10), ) assert not experiments.all_stills() + assert not experiments.all_same_type() experiments[1].beam = BeamFactory.make_polychromatic_beam( direction=(0, 0, -1), sample_to_source_distance=(100), @@ -1235,23 +1240,29 @@ def test_experiment_list_all(): wavelength_range=(1, 10), ) assert experiments.all_laue() + assert experiments.all_same_type() experiments[0].beam = Beam() assert not experiments.all_laue() + assert not experiments.all_same_type() experiments[1].beam = Beam() experiments[2].beam = Beam() assert experiments.all_stills() + assert experiments.all_same_type() experiments[0].scan = Scan((1, 1000), (0, 0.05)) assert not experiments.all_stills() + assert not experiments.all_same_type() experiments[1].scan = Scan((1, 1000), (0, 0.05)) experiments[2].scan = Scan((1, 1000), (0, 0.05)) assert experiments.all_rotations() + assert experiments.all_same_type() experiments[0].scan = ScanFactory.make_scan_from_properties( (1, 10), properties={"time_of_flight": list(range(10))} ) assert not experiments.all_rotations() + assert not experiments.all_same_type() experiments[1].scan = ScanFactory.make_scan_from_properties( (1, 10), properties={"time_of_flight": list(range(10))} ) @@ -1259,11 +1270,13 @@ def test_experiment_list_all(): (1, 10), properties={"time_of_flight": list(range(10))} ) assert experiments.all_tof() + assert experiments.all_same_type() experiments[0].scan = ScanFactory.make_scan_from_properties( (1, 10), properties={"other_property": list(range(10))} ) assert not experiments.all_tof() + assert not experiments.all_same_type() experiments[1].scan = ScanFactory.make_scan_from_properties( (1, 10), properties={"other_property": list(range(10))} ) @@ -1271,3 +1284,4 @@ def test_experiment_list_all(): (1, 10), properties={"other_property": list(range(10))} ) assert experiments.all_stills() + assert experiments.all_same_type()