diff --git a/datasets/docs/source/how-to-install-flwr-datasets.rst b/datasets/docs/source/how-to-install-flwr-datasets.rst
index 3f79daceb753..5f89261f0b29 100644
--- a/datasets/docs/source/how-to-install-flwr-datasets.rst
+++ b/datasets/docs/source/how-to-install-flwr-datasets.rst
@@ -4,7 +4,7 @@ Installation
Python Version
--------------
-Flower Datasets requires `Python 3.8 <https://docs.python.org/3.8/>`_ or above.
+Flower Datasets requires `Python 3.9 <https://docs.python.org/3.9/>`_ or above.
Install stable release (pip)
@@ -20,14 +20,41 @@ For vision datasets (e.g. MNIST, CIFAR10) ``flwr-datasets`` should be installed
.. code-block:: bash
- python -m pip install flwr_datasets[vision]
+ python -m pip install "flwr-datasets[vision]"
For audio datasets (e.g. Speech Command) ``flwr-datasets`` should be installed with the ``audio`` extra
.. code-block:: bash
- python -m pip install flwr_datasets[audio]
+ python -m pip install "flwr-datasets[audio]"
+Install directly from GitHub (pip)
+----------------------------------
+
+Installing Flower Datasets directly from GitHub ensures you have access to the most up-to-date version.
+If you encounter any issues or bugs, you may be directed to a specific branch containing a fix before
+it becomes part of an official release.
+
+.. code-block:: bash
+
+    python -m pip install "flwr-datasets@git+https://github.com/adap/flower.git@TYPE-HERE-BRANCH-NAME#subdirectory=datasets"
+
+As before, you can specify the ``vision`` or ``audio`` extra after the name of the package.
+
+.. code-block:: bash
+
+    python -m pip install "flwr-datasets[vision]@git+https://github.com/adap/flower.git@TYPE-HERE-BRANCH-NAME#subdirectory=datasets"
+
+For example, for the ``main`` branch:
+
+.. code-block:: bash
+
+    python -m pip install "flwr-datasets@git+https://github.com/adap/flower.git@main#subdirectory=datasets"
+
+Since ``flwr-datasets`` is part of the Flower repository, the ``subdirectory`` parameter (at the end of the URL) specifies the package location within the GitHub repo.
Verify installation
-------------------
@@ -38,7 +65,7 @@ The following command can be used to verify if Flower Datasets was successfully
python -c "import flwr_datasets;print(flwr_datasets.__version__)"
-If everything worked, it should print the version of Flower Datasets to the command line:
+If everything works, it should print the version of Flower Datasets to the command line:
.. code-block:: none
diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py
index a14efa1cc905..8770d5b8b76e 100644
--- a/datasets/flwr_datasets/partitioner/__init__.py
+++ b/datasets/flwr_datasets/partitioner/__init__.py
@@ -29,6 +29,7 @@
from .shard_partitioner import ShardPartitioner
from .size_partitioner import SizePartitioner
from .square_partitioner import SquarePartitioner
+from .vertical_size_partitioner import VerticalSizePartitioner
__all__ = [
"DirichletPartitioner",
@@ -45,4 +46,5 @@
"ShardPartitioner",
"SizePartitioner",
"SquarePartitioner",
+ "VerticalSizePartitioner",
]
diff --git a/datasets/flwr_datasets/partitioner/vertical_partitioner_utils.py b/datasets/flwr_datasets/partitioner/vertical_partitioner_utils.py
new file mode 100644
index 000000000000..e9e7e3855ef4
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/vertical_partitioner_utils.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""VerticalPartitioner utils.py."""
+# flake8: noqa: E501
+# pylint: disable=C0301
+from typing import Any, Literal, Union
+
+
+def _list_split(lst: list[Any], num_sublists: int) -> list[list[Any]]:
+    """Split a list into `num_sublists` nearly equal-sized sublists.
+
+ Parameters
+ ----------
+ lst : list[Any]
+ The list to split.
+ num_sublists : int
+ Number of sublists to create.
+
+ Returns
+ -------
+    sublists : list[list[Any]]
+ A list containing num_sublists sublists.
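+
+    Examples
+    --------
+    >>> _list_split([1, 2, 3, 4, 5], 2)
+    [[1, 2, 3], [4, 5]]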
+ """
+ if num_sublists <= 0:
+ raise ValueError("Number of splits must be greater than 0")
+ chunk_size, remainder = divmod(len(lst), num_sublists)
+ sublists = []
+ start_index = 0
+ for i in range(num_sublists):
+ end_index = start_index + chunk_size
+ if i < remainder:
+ end_index += 1
+ sublists.append(lst[start_index:end_index])
+ start_index = end_index
+ return sublists
+
+
+def _add_active_party_columns(
+ active_party_columns: list[str],
+ active_party_columns_mode: Union[
+ Literal[
+ "add_to_first",
+ "add_to_last",
+ "create_as_first",
+ "create_as_last",
+ "add_to_all",
+ ],
+ int,
+ ],
+ partition_columns: list[list[str]],
+) -> list[list[str]]:
+ """Add active party columns to the partition columns based on the mode.
+
+ Parameters
+ ----------
+ active_party_columns : list[str]
+ List of active party columns.
+ active_party_columns_mode : Union[Literal["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int]
+        Mode to add active party columns to partition columns.
+    partition_columns : list[list[str]]
+        List of lists of column names, one inner list per partition.
+
+ Returns
+ -------
+    partition_columns : list[list[str]]
+        List of partition columns after the modification.
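+
+    Examples
+    --------
+    >>> _add_active_party_columns(["label"], "add_to_first", [["a"], ["b"]])
+    [['a', 'label'], ['b']]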
+ """
+ if isinstance(active_party_columns_mode, int):
+ partition_id = active_party_columns_mode
+ if partition_id < 0 or partition_id >= len(partition_columns):
+ raise ValueError(
+                f"Invalid partition index {partition_id} for active_party_columns_mode. "
+                f"Must be in the range [0, {len(partition_columns) - 1}]."
+ )
+ for column in active_party_columns:
+ partition_columns[partition_id].append(column)
+ else:
+ if active_party_columns_mode == "add_to_first":
+ for column in active_party_columns:
+ partition_columns[0].append(column)
+ elif active_party_columns_mode == "add_to_last":
+ for column in active_party_columns:
+ partition_columns[-1].append(column)
+ elif active_party_columns_mode == "create_as_first":
+ partition_columns.insert(0, active_party_columns)
+ elif active_party_columns_mode == "create_as_last":
+ partition_columns.append(active_party_columns)
+ elif active_party_columns_mode == "add_to_all":
+ for column in active_party_columns:
+ for partition in partition_columns:
+ partition.append(column)
+ return partition_columns
diff --git a/datasets/flwr_datasets/partitioner/vertical_partitioner_utils_test.py b/datasets/flwr_datasets/partitioner/vertical_partitioner_utils_test.py
new file mode 100644
index 000000000000..f85d027fe444
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/vertical_partitioner_utils_test.py
@@ -0,0 +1,144 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for vertical partitioner utilities."""
+import unittest
+from typing import Any, Literal
+
+from flwr_datasets.partitioner.vertical_partitioner_utils import (
+ _add_active_party_columns,
+ _list_split,
+)
+
+
+class TestVerticalPartitionerUtils(unittest.TestCase):
+ """Tests for _list_split and _add_active_party_columns utilities."""
+
+ def test_list_split_basic_splitting(self) -> None:
+ """Check equal splitting with divisible lengths."""
+ lst = [1, 2, 3, 4, 5, 6]
+ result = _list_split(lst, 3)
+ expected = [[1, 2], [3, 4], [5, 6]]
+ self.assertEqual(result, expected)
+
+ def test_list_split_uneven_splitting(self) -> None:
+ """Check uneven splitting with non-divisible lengths."""
+ lst = [10, 20, 30, 40, 50]
+ result = _list_split(lst, 2)
+ expected = [[10, 20, 30], [40, 50]]
+ self.assertEqual(result, expected)
+
+ def test_list_split_single_sublist(self) -> None:
+ """Check that single sublist returns the full list."""
+ lst = [1, 2, 3]
+ result = _list_split(lst, 1)
+ expected = [[1, 2, 3]]
+ self.assertEqual(result, expected)
+
+ def test_list_split_more_sublists_than_elements(self) -> None:
+ """Check extra sublists are empty when count exceeds length."""
+ lst = [42]
+ result = _list_split(lst, 3)
+ expected = [[42], [], []]
+ self.assertEqual(result, expected)
+
+ def test_list_split_empty_list(self) -> None:
+ """Check splitting empty list produces empty sublists."""
+ lst: list[Any] = []
+ result = _list_split(lst, 3)
+ expected: list[list[Any]] = [[], [], []]
+ self.assertEqual(result, expected)
+
+ def test_list_split_invalid_num_sublists(self) -> None:
+ """Check ValueError when sublist count is zero or negative."""
+ lst = [1, 2, 3]
+ with self.assertRaises(ValueError):
+ _list_split(lst, 0)
+
+ def test_add_to_first(self) -> None:
+ """Check adding active cols to the first partition."""
+ partition_columns = [["col1", "col2"], ["col3"], ["col4"]]
+ active_party_columns = ["active1", "active2"]
+ mode: Literal["add_to_first"] = "add_to_first"
+ result = _add_active_party_columns(
+ active_party_columns, mode, partition_columns
+ )
+ self.assertEqual(
+ result, [["col1", "col2", "active1", "active2"], ["col3"], ["col4"]]
+ )
+
+ def test_add_to_last(self) -> None:
+ """Check adding active cols to the last partition."""
+ partition_columns = [["col1", "col2"], ["col3"], ["col4"]]
+ active_party_columns = ["active"]
+ mode: Literal["add_to_last"] = "add_to_last"
+ result = _add_active_party_columns(
+ active_party_columns, mode, partition_columns
+ )
+ self.assertEqual(result, [["col1", "col2"], ["col3"], ["col4", "active"]])
+
+ def test_create_as_first(self) -> None:
+ """Check creating a new first partition for active cols."""
+ partition_columns = [["col1"], ["col2"]]
+ active_party_columns = ["active1", "active2"]
+ mode: Literal["create_as_first"] = "create_as_first"
+ result = _add_active_party_columns(
+ active_party_columns, mode, partition_columns
+ )
+ self.assertEqual(result, [["active1", "active2"], ["col1"], ["col2"]])
+
+ def test_create_as_last(self) -> None:
+ """Check creating a new last partition for active cols."""
+ partition_columns = [["col1"], ["col2"]]
+ active_party_columns = ["active1", "active2"]
+ mode: Literal["create_as_last"] = "create_as_last"
+ result = _add_active_party_columns(
+ active_party_columns, mode, partition_columns
+ )
+ self.assertEqual(result, [["col1"], ["col2"], ["active1", "active2"]])
+
+ def test_add_to_all(self) -> None:
+ """Check adding active cols to all partitions."""
+ partition_columns = [["col1"], ["col2", "col3"], ["col4"]]
+ active_party_columns = ["active"]
+ mode: Literal["add_to_all"] = "add_to_all"
+ result = _add_active_party_columns(
+ active_party_columns, mode, partition_columns
+ )
+ self.assertEqual(
+ result, [["col1", "active"], ["col2", "col3", "active"], ["col4", "active"]]
+ )
+
+ def test_add_to_specific_partition_valid_index(self) -> None:
+ """Check adding active cols to a specific valid partition."""
+ partition_columns = [["col1"], ["col2"], ["col3"]]
+ active_party_columns = ["active1", "active2"]
+ mode: int = 1
+ result = _add_active_party_columns(
+ active_party_columns, mode, partition_columns
+ )
+ self.assertEqual(result, [["col1"], ["col2", "active1", "active2"], ["col3"]])
+
+ def test_add_to_specific_partition_invalid_index(self) -> None:
+ """Check ValueError when partition index is invalid."""
+ partition_columns = [["col1"], ["col2"]]
+ active_party_columns = ["active"]
+ mode: int = 5
+ with self.assertRaises(ValueError) as context:
+ _add_active_party_columns(active_party_columns, mode, partition_columns)
+ self.assertIn("Invalid partition index", str(context.exception))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py b/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py
new file mode 100644
index 000000000000..462a76a2e3f5
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py
@@ -0,0 +1,312 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""VerticalSizePartitioner class."""
+# flake8: noqa: E501
+# pylint: disable=C0301, R0902, R0913
+from math import floor, isclose
+from typing import Literal, Optional, Union, cast
+
+import numpy as np
+
+import datasets
+from flwr_datasets.partitioner.partitioner import Partitioner
+from flwr_datasets.partitioner.vertical_partitioner_utils import (
+ _add_active_party_columns,
+)
+
+
+class VerticalSizePartitioner(Partitioner):
+    """Create vertical partitions by splitting features (columns) based on sizes.
+
+ The sizes refer to the number of columns after the `drop_columns` are
+ dropped. `shared_columns` and `active_party_column` are excluded and
+ added only after the size-based division.
+    Enables selection of "active party" column(s) and their placement in
+    a specific partition, or the creation of a new partition just for them.
+    Also enables dropping columns and sharing specified columns across
+ all partitions.
+
+ Parameters
+ ----------
+ partition_sizes : Union[list[int], list[float]]
+ A list where each value represents the size of a partition.
+        list[int] -> each value represents an absolute number of columns. Size zero is
+        allowed and will result in an empty partition if no shared columns are present.
+        list[float] -> each value represents a fraction of the total number of columns.
+        Note that these sizes apply to the columns remaining after `drop_columns`,
+        `shared_columns`, and `active_party_columns` are excluded; the latter two are
+        added to the partition(s) afterwards and are not counted toward the sizes.
+        In the case of list[int]: sum(partition_sizes) == len(columns) - len(drop_columns) -
+        len(shared_columns) - len(active_party_columns)
+ active_party_column : Optional[Union[str, list[str]]]
+ Column(s) (typically representing labels) associated with the
+ "active party" (which can be the server).
+    active_party_columns_mode : Union[Literal["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int]
+ Determines how to assign the active party columns:
+
+ - `"add_to_first"`: Append active party columns to the first partition.
+ - `"add_to_last"`: Append active party columns to the last partition.
+ - `"create_as_first"`: Create a new partition at the start containing only these columns.
+ - `"create_as_last"`: Create a new partition at the end containing only these columns.
+ - `"add_to_all"`: Append active party columns to all partitions.
+ - int: Append active party columns to the specified partition index.
+ drop_columns : Optional[list[str]]
+ Columns to remove entirely from the dataset before partitioning.
+ shared_columns : Optional[list[str]]
+ Columns to duplicate into every partition after initial partitioning.
+ shuffle : bool
+ Whether to shuffle the order of columns before partitioning.
+ seed : Optional[int]
+ Random seed for shuffling columns. Has no effect if `shuffle=False`.
+
+ Examples
+ --------
+ >>> from flwr_datasets import FederatedDataset
+ >>> from flwr_datasets.partitioner import VerticalSizePartitioner
+ >>>
+ >>> partitioner = VerticalSizePartitioner(
+ ... partition_sizes=[8, 4, 2],
+ ... active_party_column="income",
+ ... active_party_columns_mode="create_as_last"
+ ... )
+ >>> fds = FederatedDataset(
+ ... dataset="scikit-learn/adult-census-income",
+ ... partitioners={"train": partitioner}
+ ... )
+ >>> partitions = [fds.load_partition(i) for i in range(fds.partitioners["train"].num_partitions)]
+ >>> print([partition.column_names for partition in partitions])
+ """
+
+ def __init__(
+ self,
+ partition_sizes: Union[list[int], list[float]],
+ active_party_column: Optional[Union[str, list[str]]] = None,
+ active_party_columns_mode: Union[
+ Literal[
+ "add_to_first",
+ "add_to_last",
+ "create_as_first",
+ "create_as_last",
+ "add_to_all",
+ ],
+ int,
+ ] = "add_to_last",
+ drop_columns: Optional[list[str]] = None,
+ shared_columns: Optional[list[str]] = None,
+ shuffle: bool = True,
+ seed: Optional[int] = 42,
+ ) -> None:
+ super().__init__()
+
+ self._partition_sizes = partition_sizes
+ self._active_party_columns = self._init_active_party_column(active_party_column)
+ self._active_party_columns_mode = active_party_columns_mode
+ self._drop_columns = drop_columns or []
+ self._shared_columns = shared_columns or []
+ self._shuffle = shuffle
+ self._seed = seed
+ self._rng = np.random.default_rng(seed=self._seed)
+
+ self._partition_columns: Optional[list[list[str]]] = None
+ self._partitions_determined = False
+
+ self._validate_parameters_in_init()
+
+ def _determine_partitions_if_needed(self) -> None:
+ if self._partitions_determined:
+ return
+
+ if self.dataset is None:
+ raise ValueError("No dataset is set for this partitioner.")
+
+ all_columns = list(self.dataset.column_names)
+ self._validate_parameters_while_partitioning(
+ all_columns, self._shared_columns, self._active_party_columns
+ )
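+        # Remove dropped, shared, and active party columns before the size-based split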
+ columns = [column for column in all_columns if column not in self._drop_columns]
+ columns = [column for column in columns if column not in self._shared_columns]
+ columns = [
+ column for column in columns if column not in self._active_party_columns
+ ]
+
+ if self._shuffle:
+ self._rng.shuffle(columns)
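+        # Split the remaining columns by fractions (all floats) or counts (all ints)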
+ if all(isinstance(fraction, float) for fraction in self._partition_sizes):
+ partition_columns = _fraction_split(
+ columns, cast(list[float], self._partition_sizes)
+ )
+ else:
+ partition_columns = _count_split(
+ columns, cast(list[int], self._partition_sizes)
+ )
+
+ partition_columns = _add_active_party_columns(
+ self._active_party_columns,
+ self._active_party_columns_mode,
+ partition_columns,
+ )
+
+ # Add shared columns to all partitions
+ for partition in partition_columns:
+ for column in self._shared_columns:
+ partition.append(column)
+
+ self._partition_columns = partition_columns
+ self._partitions_determined = True
+
+ def load_partition(self, partition_id: int) -> datasets.Dataset:
+ """Load a partition based on the partition index.
+
+ Parameters
+ ----------
+ partition_id : int
+ The index that corresponds to the requested partition.
+
+ Returns
+ -------
+ dataset_partition : Dataset
+ Single partition of a dataset.
+ """
+ self._determine_partitions_if_needed()
+ assert self._partition_columns is not None
+ if partition_id < 0 or partition_id >= len(self._partition_columns):
+ raise IndexError(
+ f"partition_id: {partition_id} out of range <0, {self.num_partitions - 1}>."
+ )
+ columns = self._partition_columns[partition_id]
+ return self.dataset.select_columns(columns)
+
+ @property
+ def num_partitions(self) -> int:
+ """Number of partitions."""
+ self._determine_partitions_if_needed()
+ assert self._partition_columns is not None
+ return len(self._partition_columns)
+
+ def _validate_parameters_in_init(self) -> None:
+ if not isinstance(self._partition_sizes, list):
+ raise ValueError("partition_sizes must be a list.")
+ if all(isinstance(fraction, float) for fraction in self._partition_sizes):
+ fraction_sum = sum(self._partition_sizes)
+            # Compare with a tolerance to avoid floating-point rounding issues
+            if not isclose(fraction_sum, 1.0):
+ raise ValueError("Float ratios in `partition_sizes` must sum to 1.0.")
+ if any(
+ fraction < 0.0 or fraction > 1.0 for fraction in self._partition_sizes
+ ):
+ raise ValueError(
+ "All floats in `partition_sizes` must be >= 0.0 and <= 1.0."
+ )
+        elif all(
+            isinstance(column_count, int) for column_count in self._partition_sizes
+        ):
+            if any(column_count < 0 for column_count in self._partition_sizes):
+ raise ValueError("All integers in `partition_sizes` must be >= 0.")
+ else:
+ raise ValueError("`partition_sizes` list must be all floats or all ints.")
+
+ # Validate columns lists
+ for parameter_name, parameter_list in [
+ ("drop_columns", self._drop_columns),
+ ("shared_columns", self._shared_columns),
+ ("active_party_columns", self._active_party_columns),
+ ]:
+ if not all(isinstance(column, str) for column in parameter_list):
+ raise ValueError(f"All entries in {parameter_name} must be strings.")
+
+ valid_modes = {
+ "add_to_first",
+ "add_to_last",
+ "create_as_first",
+ "create_as_last",
+ "add_to_all",
+ }
+ if not (
+ isinstance(self._active_party_columns_mode, int)
+ or self._active_party_columns_mode in valid_modes
+ ):
+ raise ValueError(
+ "active_party_columns_mode must be an int or one of "
+ "'add_to_first', 'add_to_last', 'create_as_first', 'create_as_last', "
+ "'add_to_all'."
+ )
+
+ def _validate_parameters_while_partitioning(
+ self,
+ all_columns: list[str],
+ shared_columns: list[str],
+ active_party_columns: list[str],
+ ) -> None:
+        # Shared columns existence check
+ for column in shared_columns:
+ if column not in all_columns:
+ raise ValueError(f"Shared column '{column}' not found in the dataset.")
+ # Active party columns existence check
+ for column in active_party_columns:
+ if column not in all_columns:
+ raise ValueError(
+ f"Active party column '{column}' not found in the dataset."
+ )
+ num_columns = len(all_columns)
+ num_cols_unused_in_core_div = 0
+ if self._active_party_columns is not None:
+ num_cols_unused_in_core_div += len(self._active_party_columns)
+ if self._shared_columns is not None:
+ num_cols_unused_in_core_div += len(self._shared_columns)
+ if self._drop_columns is not None:
+ num_cols_unused_in_core_div += len(self._drop_columns)
+ num_core_div_columns = num_columns - num_cols_unused_in_core_div
+ if all(isinstance(size, int) for size in self._partition_sizes):
+ if sum(self._partition_sizes) != num_core_div_columns:
+ raise ValueError(
+                    "The sum of partition sizes must equal the number of columns "
+                    "used in the division. Note that shared_columns, drop_columns, and "
+                    "active_party_columns are not included in the division."
+ )
+
+ def _init_active_party_column(
+ self, active_party_column: Optional[Union[str, list[str]]]
+ ) -> list[str]:
+ if active_party_column is None:
+ return []
+ if isinstance(active_party_column, str):
+ return [active_party_column]
+ if isinstance(active_party_column, list):
+ return active_party_column
+ raise ValueError("active_party_column must be a string or a list of strings.")
+
+
+def _count_split(columns: list[str], counts: list[int]) -> list[list[str]]:
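+    """Split columns into consecutive chunks of the sizes given in counts.
+
+    Examples
+    --------
+    >>> _count_split(["a", "b", "c"], [2, 1])
+    [['a', 'b'], ['c']]
+    """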
+ partition_columns = []
+ start = 0
+ for count in counts:
+ end = start + count
+ partition_columns.append(columns[start:end])
+ start = end
+ return partition_columns
+
+
+def _fraction_split(columns: list[str], fractions: list[float]) -> list[list[str]]:
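+    """Split columns into chunks proportional to fractions.
+
+    The last partition takes all remaining columns.
+
+    Examples
+    --------
+    >>> _fraction_split(["a", "b", "c", "d"], [0.5, 0.5])
+    [['a', 'b'], ['c', 'd']]
+    """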
+ num_columns = len(columns)
+ partitions = []
+ cumulative = 0
+ for index, fraction in enumerate(fractions):
+ count = int(floor(fraction * num_columns))
+ if index == len(fractions) - 1:
+ # Last partition takes the remainder
+ count = num_columns - cumulative
+ partitions.append(columns[cumulative : cumulative + count])
+ cumulative += count
+ return partitions
diff --git a/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py b/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py
new file mode 100644
index 000000000000..d2c483c2be88
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py
@@ -0,0 +1,206 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""VerticalSizePartitioner class tests."""
+# mypy: disable-error-code=arg-type
+# pylint: disable=R0902, R0913
+import unittest
+
+import numpy as np
+
+from datasets import Dataset
+from flwr_datasets.partitioner.vertical_size_partitioner import VerticalSizePartitioner
+
+
+def _create_dummy_dataset(column_names: list[str], num_rows: int = 100) -> Dataset:
+ """Create a dataset with random integer data."""
+ rng = np.random.default_rng(seed=42)
+ data = {col: rng.integers(0, 100, size=num_rows).tolist() for col in column_names}
+ return Dataset.from_dict(data)
+
+
+class TestVerticalSizePartitioner(unittest.TestCase):
+ """Tests for VerticalSizePartitioner."""
+
+ def test_init_invalid_partition_sizes_type(self) -> None:
+ """Check ValueError if partition_sizes is not a list."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes="not_a_list")
+
+ def test_init_mixed_partition_sizes_types(self) -> None:
+ """Check ValueError if partition_sizes mix int and float."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes=[0.5, 1])
+
+ def test_init_float_partitions_sum_not_one(self) -> None:
+ """Check ValueError if float partitions do not sum to 1."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes=[0.3, 0.3])
+
+ def test_init_float_partitions_out_of_range(self) -> None:
+ """Check ValueError if any float partition <0 or >1."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes=[-0.5, 1.5])
+
+ def test_init_int_partitions_negative(self) -> None:
+ """Check ValueError if any int partition size is negative."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes=[5, -1])
+
+ def test_init_invalid_mode(self) -> None:
+ """Check ValueError if active_party_columns_mode is invalid."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(
+ partition_sizes=[2, 2], active_party_columns_mode="invalid"
+ )
+
+ def test_init_active_party_column_invalid_type(self) -> None:
+ """Check ValueError if active_party_column is not str/list."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes=[2, 2], active_party_column=123)
+
+ def test_partitioning_with_int_sizes(self) -> None:
+ """Check correct partitioning with integer sizes."""
+ columns = ["f1", "f2", "f3", "f4", "f5"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(partition_sizes=[2, 3], shuffle=False)
+ partitioner.dataset = dataset
+ p0 = partitioner.load_partition(0)
+ p1 = partitioner.load_partition(1)
+ self.assertEqual(len(p0.column_names), 2)
+ self.assertEqual(len(p1.column_names), 3)
+
+ def test_partitioning_with_fraction_sizes(self) -> None:
+ """Check correct partitioning with fraction sizes."""
+ columns = ["f1", "f2", "f3", "f4"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(partition_sizes=[0.5, 0.5], shuffle=False)
+ partitioner.dataset = dataset
+ p0 = partitioner.load_partition(0)
+ p1 = partitioner.load_partition(1)
+ self.assertEqual(len(p0.column_names), 2)
+ self.assertEqual(len(p1.column_names), 2)
+
+ def test_partitioning_with_drop_columns(self) -> None:
+ """Check dropping specified columns before partitioning."""
+ columns = ["f1", "drop_me", "f2", "f3"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[2, 1], drop_columns=["drop_me"], shuffle=False
+ )
+ partitioner.dataset = dataset
+ p0 = partitioner.load_partition(0)
+ p1 = partitioner.load_partition(1)
+ all_cols = p0.column_names + p1.column_names
+ self.assertNotIn("drop_me", all_cols)
+
+ def test_partitioning_with_shared_columns(self) -> None:
+ """Check shared columns added to every partition."""
+ columns = ["f1", "f2", "shared"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[1, 1], shared_columns=["shared"], shuffle=False
+ )
+ partitioner.dataset = dataset
+ p0 = partitioner.load_partition(0)
+ p1 = partitioner.load_partition(1)
+ self.assertIn("shared", p0.column_names)
+ self.assertIn("shared", p1.column_names)
+
+ def test_partitioning_with_active_party_add_to_last(self) -> None:
+ """Check active party columns added to the last partition."""
+ columns = ["f1", "f2", "label"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[2],
+ active_party_column="label",
+ active_party_columns_mode="add_to_last",
+ shuffle=False,
+ )
+ partitioner.dataset = dataset
+ p0 = partitioner.load_partition(0)
+ self.assertIn("label", p0.column_names)
+
+ def test_partitioning_with_active_party_create_as_first(self) -> None:
+ """Check creating a new first partition for active party cols."""
+ columns = ["f1", "f2", "label"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[2],
+ active_party_column="label",
+ active_party_columns_mode="create_as_first",
+ shuffle=False,
+ )
+ partitioner.dataset = dataset
+ self.assertEqual(partitioner.num_partitions, 2)
+ p0 = partitioner.load_partition(0)
+ p1 = partitioner.load_partition(1)
+ self.assertEqual(p0.column_names, ["label"])
+ self.assertIn("f1", p1.column_names)
+ self.assertIn("f2", p1.column_names)
+
+ def test_partitioning_with_nonexistent_shared_column(self) -> None:
+ """Check ValueError if shared column does not exist."""
+ columns = ["f1", "f2"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[1], shared_columns=["nonexistent"], shuffle=False
+ )
+ partitioner.dataset = dataset
+ with self.assertRaises(ValueError):
+ partitioner.load_partition(0)
+
+ def test_partitioning_with_nonexistent_active_party_column(self) -> None:
+ """Check ValueError if active party column does not exist."""
+ columns = ["f1", "f2"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[1], active_party_column="missing_label", shuffle=False
+ )
+ partitioner.dataset = dataset
+ with self.assertRaises(ValueError):
+ partitioner.load_partition(0)
+
+ def test_sum_of_int_partition_sizes_exceeds_num_columns(self) -> None:
+ """Check ValueError if sum of int sizes > total columns."""
+ columns = ["f1", "f2"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(partition_sizes=[3], shuffle=False)
+ partitioner.dataset = dataset
+ with self.assertRaises(ValueError):
+ partitioner.load_partition(0)
+
+ def test_sum_of_int_partition_sizes_indirectly_exceeds_num_columns(self) -> None:
+        """Check ValueError if sum of int sizes > columns left after drop_columns."""
+ columns = ["f1", "f2", "f3"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[1, 1], drop_columns=["f3", "f2"], shuffle=False
+ )
+ partitioner.dataset = dataset
+ with self.assertRaises(ValueError):
+ partitioner.load_partition(0)
+
+ def test_sum_of_int_partition_sizes_is_smaller_than_num_columns(self) -> None:
+ """Check ValueError if sum of int sizes < total columns."""
+ columns = ["f1", "f2", "f3"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(partition_sizes=[2], shuffle=False)
+ partitioner.dataset = dataset
+ with self.assertRaises(ValueError):
+ partitioner.load_partition(0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/datasets/pyproject.toml b/datasets/pyproject.toml
index 2d699c5e901b..497f89a2f7ca 100644
--- a/datasets/pyproject.toml
+++ b/datasets/pyproject.toml
@@ -34,6 +34,8 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
diff --git a/src/py/flwr/cli/new/templates/app/.gitignore.tpl b/src/py/flwr/cli/new/templates/app/.gitignore.tpl
index 68bc17f9ff21..f791a9b679d8 100644
--- a/src/py/flwr/cli/new/templates/app/.gitignore.tpl
+++ b/src/py/flwr/cli/new/templates/app/.gitignore.tpl
@@ -3,6 +3,9 @@ __pycache__/
*.py[cod]
*$py.class
+# Flower directory
+.flwr
+
# C extensions
*.so
diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py
index 8cb89255ed40..e01a0439c9da 100644
--- a/src/py/flwr/cli/utils.py
+++ b/src/py/flwr/cli/utils.py
@@ -22,7 +22,7 @@
from contextlib import contextmanager
from logging import DEBUG
from pathlib import Path
-from typing import Any, Callable, Optional, cast
+from typing import Any, Callable, Optional, Union, cast
import grpc
import typer
@@ -148,23 +148,69 @@ def sanitize_project_name(name: str) -> str:
return sanitized_name
-def get_sha256_hash(file_path: Path) -> str:
+def get_sha256_hash(file_path_or_int: Union[Path, int]) -> str:
"""Calculate the SHA-256 hash of a file."""
sha256 = hashlib.sha256()
- with open(file_path, "rb") as f:
- while True:
- data = f.read(65536) # Read in 64kB blocks
- if not data:
- break
- sha256.update(data)
+ if isinstance(file_path_or_int, Path):
+ with open(file_path_or_int, "rb") as f:
+ while True:
+ data = f.read(65536) # Read in 64kB blocks
+ if not data:
+ break
+ sha256.update(data)
+ elif isinstance(file_path_or_int, int):
+ sha256.update(str(file_path_or_int).encode())
return sha256.hexdigest()
def get_user_auth_config_path(root_dir: Path, federation: str) -> Path:
- """Return the path to the user auth config file."""
+ """Return the path to the user auth config file.
+
+    Additionally, a `.gitignore` file will be created in the Flower directory so
+    that the `.credentials` folder is excluded from git. If a `.gitignore` file
+    already exists but does not contain the `.credentials` entry, a warning will
+    be displayed.
+ """
# Locate the credentials directory
- credentials_dir = root_dir.absolute() / FLWR_DIR / CREDENTIALS_DIR
+ abs_flwr_dir = root_dir.absolute() / FLWR_DIR
+ credentials_dir = abs_flwr_dir / CREDENTIALS_DIR
credentials_dir.mkdir(parents=True, exist_ok=True)
+
+ # Determine the absolute path of the Flower directory for .gitignore
+ gitignore_path = abs_flwr_dir / ".gitignore"
+ credential_entry = CREDENTIALS_DIR
+
+ try:
+ if gitignore_path.exists():
+ with open(gitignore_path, encoding="utf-8") as gitignore_file:
+ lines = gitignore_file.read().splitlines()
+
+ # Warn if .credentials is not already in .gitignore
+ if credential_entry not in lines:
+ typer.secho(
+ f"`.gitignore` exists, but `{credential_entry}` entry not found. "
+ "Consider adding it to your `.gitignore` to exclude Flower "
+ "credentials from git.",
+ fg=typer.colors.YELLOW,
+ bold=True,
+ )
+ else:
+ typer.secho(
+ f"Creating a new `.gitignore` with `{credential_entry}` entry...",
+ fg=typer.colors.BLUE,
+ )
+ # Create a new .gitignore with .credentials
+ with open(gitignore_path, "w", encoding="utf-8") as gitignore_file:
+ gitignore_file.write(f"{credential_entry}\n")
+ except Exception as err:
+ typer.secho(
+            "❌ An error occurred while handling `.gitignore`. "
+ f"Please check the permissions of `{gitignore_path}` and try again.",
+ fg=typer.colors.RED,
+ bold=True,
+ )
+ raise typer.Exit(code=1) from err
+
return credentials_dir / f"{federation}.json"
diff --git a/src/py/flwr/cli/utils_test.py b/src/py/flwr/cli/utils_test.py
new file mode 100644
index 000000000000..e722dee70c3c
--- /dev/null
+++ b/src/py/flwr/cli/utils_test.py
@@ -0,0 +1,109 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for Flower command line interface utils."""
+
+
+import hashlib
+import os
+import tempfile
+import unittest
+from pathlib import Path
+
+from flwr.cli.utils import get_sha256_hash
+
+
+class TestGetSHA256Hash(unittest.TestCase):
+ """Unit tests for `get_sha256_hash` function."""
+
+ def test_hash_with_integer(self) -> None:
+ """Test the SHA-256 hash calculation when input is an integer."""
+ # Prepare
+ test_int = 13413
+ expected_hash = hashlib.sha256(str(test_int).encode()).hexdigest()
+
+ # Execute
+ result = get_sha256_hash(test_int)
+
+ # Assert
+ self.assertEqual(result, expected_hash)
+
+ def test_hash_with_file(self) -> None:
+ """Test the SHA-256 hash calculation when input is a file path."""
+ # Prepare - Create a temporary file with known content
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+ temp_file.write(b"Test content for SHA-256 hashing.")
+ temp_file_path = Path(temp_file.name)
+
+ try:
+ # Execute
+ sha256 = hashlib.sha256()
+ with open(temp_file_path, "rb") as f:
+ while True:
+ data = f.read(65536)
+ if not data:
+ break
+ sha256.update(data)
+ expected_hash = sha256.hexdigest()
+
+ result = get_sha256_hash(temp_file_path)
+
+ # Assert
+ self.assertEqual(result, expected_hash)
+ finally:
+ # Clean up the temporary file
+ os.remove(temp_file_path)
+
+ def test_empty_file(self) -> None:
+ """Test the SHA-256 hash calculation for an empty file."""
+ # Prepare
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+ temp_file_path = Path(temp_file.name)
+
+ try:
+ # Execute
+ expected_hash = hashlib.sha256(b"").hexdigest()
+ result = get_sha256_hash(temp_file_path)
+
+ # Assert
+ self.assertEqual(result, expected_hash)
+ finally:
+ os.remove(temp_file_path)
+
+ def test_large_file(self) -> None:
+ """Test the SHA-256 hash calculation for a large file."""
+ # Prepare - Generate large data (e.g., 10 MB)
+ large_data = b"a" * (10 * 1024 * 1024) # 10 MB of 'a's
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+ temp_file.write(large_data)
+ temp_file_path = Path(temp_file.name)
+
+ try:
+ expected_hash = hashlib.sha256(large_data).hexdigest()
+ # Execute
+ result = get_sha256_hash(temp_file_path)
+
+ # Assert
+ self.assertEqual(result, expected_hash)
+ finally:
+ os.remove(temp_file_path)
+
+ def test_nonexistent_file(self) -> None:
+ """Test the SHA-256 hash calculation when the file does not exist."""
+ # Prepare
+ nonexistent_path = Path("/path/to/nonexistent/file.txt")
+
+ # Execute & assert
+ with self.assertRaises(FileNotFoundError):
+ get_sha256_hash(nonexistent_path)
diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py
index 91514f845651..b5c7ae95e224 100644
--- a/src/py/flwr/server/app.py
+++ b/src/py/flwr/server/app.py
@@ -374,6 +374,7 @@ def run_superlink() -> None:
server_public_key,
) = maybe_keys
state = state_factory.state()
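+        # Remove previously stored keys and credentials before storing the new ones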
+ state.clear_supernode_auth_keys_and_credentials()
state.store_node_public_keys(node_public_keys)
state.store_server_private_public_key(
private_key_to_bytes(server_private_key),
diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
index f26bb11a4bdb..d22072b41621 100644
--- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
+++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
@@ -430,6 +430,13 @@ def get_server_public_key(self) -> Optional[bytes]:
"""Retrieve `server_public_key` in urlsafe bytes."""
return self.server_public_key
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
+ with self.lock:
+ self.server_private_key = None
+ self.server_public_key = None
+ self.node_public_keys.clear()
+
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
"""Store a set of `node_public_keys` in the link state."""
with self.lock:
diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py
index ae9d1710f069..4f3c16a5460a 100644
--- a/src/py/flwr/server/superlink/linkstate/linkstate.py
+++ b/src/py/flwr/server/superlink/linkstate/linkstate.py
@@ -284,6 +284,10 @@ def get_server_private_key(self) -> Optional[bytes]:
def get_server_public_key(self) -> Optional[bytes]:
"""Retrieve `server_public_key` in urlsafe bytes."""
+ @abc.abstractmethod
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
+
@abc.abstractmethod
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
"""Store a set of `node_public_keys` in the link state."""
diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py
index 93f5d94daef7..3edaf72ec20c 100644
--- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py
+++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py
@@ -820,6 +820,29 @@ def test_store_server_private_public_key_twice(self) -> None:
new_private_key_bytes, new_public_key_bytes
)
+ def test_clear_supernode_auth_keys_and_credentials(self) -> None:
+ """Test clear_supernode_auth_keys_and_credentials from linkstate."""
+ # Prepare
+ state: LinkState = self.state_factory()
+ key_pairs = [generate_key_pairs() for _ in range(3)]
+ public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs}
+
+ # Execute (store)
+ state.store_node_public_keys(public_keys)
+ private_key, public_key = generate_key_pairs()
+ private_key_bytes = private_key_to_bytes(private_key)
+ public_key_bytes = public_key_to_bytes(public_key)
+ state.store_server_private_public_key(private_key_bytes, public_key_bytes)
+
+ # Execute (clear)
+ state.clear_supernode_auth_keys_and_credentials()
+ node_public_keys = state.get_node_public_keys()
+
+ # Assert
+ assert node_public_keys == set()
+ assert state.get_server_private_key() is None
+ assert state.get_server_public_key() is None
+
def test_node_public_keys(self) -> None:
"""Test store_node_public_keys and get_node_public_keys from state."""
# Prepare
diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
index 99334233319d..e8311dfaac5e 100644
--- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
+++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
@@ -818,6 +818,12 @@ def get_server_public_key(self) -> Optional[bytes]:
public_key = None
return public_key
+ def clear_supernode_auth_keys_and_credentials(self) -> None:
+ """Clear stored `node_public_keys` and credentials in the link state if any."""
+ queries = ["DELETE FROM public_key;", "DELETE FROM credential;"]
+ for query in queries:
+ self.query(query)
+
def store_node_public_keys(self, public_keys: set[bytes]) -> None:
"""Store a set of `node_public_keys` in the link state."""
query = "INSERT INTO public_key (public_key) VALUES (?)"