diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 72ea54773564..8659aa03313b 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -128,6 +128,7 @@ def __init__( self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners ) + self._check_partitioners_correctness() self._shuffle = shuffle self._seed = seed # _dataset is prepared lazily on the first call to `load_partition` @@ -336,3 +337,20 @@ def _check_if_no_split_keyword_possible(self) -> None: "Please set the `split` argument. You can only omit the split keyword " "if there is exactly one partitioner specified." ) + + def _check_partitioners_correctness(self) -> None: + """Check if the partitioners are correctly specified. + + Check if each partitioner is a different Python object. Using the same + partitioner for different splits is not allowed. + """ + partitioners_keys = list(self._partitioners.keys()) + for i, first_split in enumerate(partitioners_keys): + for j in range(i + 1, len(partitioners_keys)): + second_split = partitioners_keys[j] + if self._partitioners[first_split] is self._partitioners[second_split]: + raise ValueError( + f"The same partitioner object is used for multiple splits: " + f"('{first_split}', '{second_split}'). " + "Each partitioner should be a separate object." + ) diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index bbdfa42292c2..6c12ee0e2e1a 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -32,6 +32,7 @@ _load_mocked_dataset_dict_by_partial_download, ) from flwr_datasets.partitioner import IidPartitioner, NaturalIdPartitioner, Partitioner +from flwr_datasets.preprocessor.divider import Divider mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"] @@ -568,6 +569,57 @@ def test_use_load_dataset_kwargs(self) -> None: with self.assertRaises(ValueError): _ = fds.load_partition(0) + def test_incorrect_two_partitioners(self) -> None: + """Test if the method raises ValueError with incorrect partitioners.""" + partitioner = IidPartitioner(num_partitions=10) + partitioners: dict[str, Union[Partitioner, int]] = { + "train": partitioner, + "test": partitioner, + } + first_split = "train" + second_split = "test" + with self.assertRaises(ValueError) as context: + FederatedDataset( + dataset="mnist", + partitioners=partitioners, + ) + self.assertIn( + f"The same partitioner object is used for multiple splits: " + f"('{first_split}', '{second_split}'). " + "Each partitioner should be a separate object.", + str(context.exception), + ) + + def test_incorrect_three_partitioners(self) -> None: + """Test if the method raises ValueError with incorrect partitioners.""" + partitioner = IidPartitioner(num_partitions=10) + partitioners: dict[str, Union[int, Partitioner]] = { + "train1": partitioner, + "train2": 10, + "test": partitioner, + } + divider = Divider( + divide_config={ + "train1": 0.5, + "train2": 0.5, + }, + divide_split="train", + ) + + with self.assertRaises( + ValueError, + ) as context: + + FederatedDataset( + dataset="mnist", partitioners=partitioners, preprocessor=divider + ) + + self.assertIn( + "The same partitioner object is used for multiple splits: " + "('train1', 'test'). Each partitioner should be a separate object.", + str(context.exception), + ) + def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool: """Check if two Datasets have the same values."""