Skip to content

Commit

Permalink
fix(datasets): Fix pathological partitioner on string labels (#4253)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak authored Oct 1, 2024
1 parent 1a4da37 commit 917d406
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _determine_partition_id_to_unique_labels(self) -> None:
if self._class_assignment_mode == "first-deterministic":
# if self._first_class_deterministic_assignment:
for partition_id in range(self._num_partitions):
label = partition_id % num_unique_classes
label = self._unique_labels[partition_id % num_unique_classes]
self._partition_id_to_unique_labels[partition_id].append(label)

while (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@
import unittest

import numpy as np
from parameterized import parameterized
from parameterized import parameterized, parameterized_class

import datasets
from datasets import Dataset
from flwr_datasets.partitioner.pathological_partitioner import PathologicalPartitioner


def _dummy_dataset_setup(
num_samples: int, partition_by: str, num_unique_classes: int
num_samples: int,
partition_by: str,
num_unique_classes: int,
string_partition_by: bool = False,
) -> Dataset:
"""Create a dummy dataset for testing."""
data = {
Expand All @@ -35,6 +38,8 @@ def _dummy_dataset_setup(
)[:num_samples],
"features": np.random.randn(num_samples),
}
if string_partition_by:
data[partition_by] = data[partition_by].astype(str)
return Dataset.from_dict(data)


Expand All @@ -51,6 +56,7 @@ def _dummy_heterogeneous_dataset_setup(
return Dataset.from_dict(data)


@parameterized_class(("string_partition_by",), [(False,), (True,)])
class TestClassConstrainedPartitioner(unittest.TestCase):
"""Unit tests for PathologicalPartitioner."""

Expand Down Expand Up @@ -94,7 +100,8 @@ def test_first_class_deterministic_assignment(self) -> None:
Test if all the classes are used (which has to be the case, given num_partitions
>= than the number of unique classes).
"""
dataset = _dummy_dataset_setup(100, "labels", 10)
partition_by = "labels"
dataset = _dummy_dataset_setup(100, partition_by, 10)
partitioner = PathologicalPartitioner(
num_partitions=10,
partition_by="labels",
Expand All @@ -103,7 +110,12 @@ def test_first_class_deterministic_assignment(self) -> None:
)
partitioner.dataset = dataset
partitioner.load_partition(0)
expected_classes = set(range(10))
expected_classes = set(
range(10)
# pylint: disable=unsubscriptable-object
if isinstance(dataset[partition_by][0], int)
else [str(i) for i in range(10)]
)
actual_classes = set()
for pid in range(10):
partition = partitioner.load_partition(pid)
Expand Down Expand Up @@ -141,6 +153,9 @@ def test_deterministic_class_assignment(
for i in range(num_classes_per_partition)
]
)
# pylint: disable=unsubscriptable-object
if isinstance(dataset["labels"][0], str):
expected_labels = [str(label) for label in expected_labels]
actual_labels = sorted(np.unique(partition["labels"]))
self.assertTrue(
np.array_equal(expected_labels, actual_labels),
Expand All @@ -166,6 +181,9 @@ def test_too_many_partitions_for_a_class(
"labels": np.array([num_unique_classes - 1] * (num_samples // 2)),
"features": np.random.randn(num_samples // 2),
}
# pylint: disable=unsubscriptable-object
if isinstance(dataset_1["labels"][0], str):
data["labels"] = data["labels"].astype(str)
dataset_2 = Dataset.from_dict(data)
dataset = datasets.concatenate_datasets([dataset_1, dataset_2])

Expand Down

0 comments on commit 917d406

Please sign in to comment.