From 917d4068edb1739630f14fd5455363b153865228 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 1 Oct 2024 22:56:21 +0200 Subject: [PATCH] fix(datasets) Fix pathological partitioner on string labels (#4253) --- .../partitioner/pathological_partitioner.py | 2 +- .../pathological_partitioner_test.py | 26 ++++++++++++++++--- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner.py b/datasets/flwr_datasets/partitioner/pathological_partitioner.py index 350383f344e7..d114ccbda02f 100644 --- a/datasets/flwr_datasets/partitioner/pathological_partitioner.py +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner.py @@ -225,7 +225,7 @@ def _determine_partition_id_to_unique_labels(self) -> None: if self._class_assignment_mode == "first-deterministic": # if self._first_class_deterministic_assignment: for partition_id in range(self._num_partitions): - label = partition_id % num_unique_classes + label = self._unique_labels[partition_id % num_unique_classes] self._partition_id_to_unique_labels[partition_id].append(label) while ( diff --git a/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py index 18707a56bd98..5a3b13bb1436 100644 --- a/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/pathological_partitioner_test.py @@ -18,7 +18,7 @@ import unittest import numpy as np -from parameterized import parameterized +from parameterized import parameterized, parameterized_class import datasets from datasets import Dataset @@ -26,7 +26,10 @@ def _dummy_dataset_setup( - num_samples: int, partition_by: str, num_unique_classes: int + num_samples: int, + partition_by: str, + num_unique_classes: int, + string_partition_by: bool = False, ) -> Dataset: """Create a dummy dataset for testing.""" data = { @@ -35,6 +38,8 @@ def _dummy_dataset_setup( )[:num_samples], "features": np.random.randn(num_samples), } + if string_partition_by: + data[partition_by] = data[partition_by].astype(str) return Dataset.from_dict(data) @@ -51,6 +56,7 @@ def _dummy_heterogeneous_dataset_setup( return Dataset.from_dict(data) +@parameterized_class(("string_partition_by",), [(False,), (True,)]) class TestClassConstrainedPartitioner(unittest.TestCase): """Unit tests for PathologicalPartitioner.""" @@ -94,7 +100,8 @@ def test_first_class_deterministic_assignment(self) -> None: Test if all the classes are used (which has to be the case, given num_partitions >= than the number of unique classes). """ - dataset = _dummy_dataset_setup(100, "labels", 10) + partition_by = "labels" + dataset = _dummy_dataset_setup(100, partition_by, 10) partitioner = PathologicalPartitioner( num_partitions=10, partition_by="labels", @@ -103,7 +110,12 @@ def test_first_class_deterministic_assignment(self) -> None: ) partitioner.dataset = dataset partitioner.load_partition(0) - expected_classes = set(range(10)) + expected_classes = set( + range(10) + # pylint: disable=unsubscriptable-object + if isinstance(dataset[partition_by][0], int) + else [str(i) for i in range(10)] + ) actual_classes = set() for pid in range(10): partition = partitioner.load_partition(pid) @@ -141,6 +153,9 @@ def test_deterministic_class_assignment( for i in range(num_classes_per_partition) ] ) + # pylint: disable=unsubscriptable-object + if isinstance(dataset["labels"][0], str): + expected_labels = [str(label) for label in expected_labels] actual_labels = sorted(np.unique(partition["labels"])) self.assertTrue( np.array_equal(expected_labels, actual_labels), @@ -166,6 +181,9 @@ def test_too_many_partitions_for_a_class( "labels": np.array([num_unique_classes - 1] * (num_samples // 2)), "features": np.random.randn(num_samples // 2), } + # pylint: disable=unsubscriptable-object + if isinstance(dataset_1["labels"][0], str): + data["labels"] = data["labels"].astype(str) dataset_2 = Dataset.from_dict(data) dataset = datasets.concatenate_datasets([dataset_1, dataset_2])