diff --git a/datasets/docs/source/how-to-install-flwr-datasets.rst b/datasets/docs/source/how-to-install-flwr-datasets.rst
index 3f79daceb753..5f89261f0b29 100644
--- a/datasets/docs/source/how-to-install-flwr-datasets.rst
+++ b/datasets/docs/source/how-to-install-flwr-datasets.rst
@@ -4,7 +4,7 @@ Installation
 Python Version
 --------------
 
-Flower Datasets requires `Python 3.8 `_ or above.
+Flower Datasets requires `Python 3.9 `_ or above.
 
 
 Install stable release (pip)
@@ -20,14 +20,41 @@ For vision datasets (e.g. MNIST, CIFAR10) ``flwr-datasets`` should be installed
 
 .. code-block:: bash
 
-  python -m pip install flwr_datasets[vision]
+  python -m pip install "flwr-datasets[vision]"
 
 For audio datasets (e.g. Speech Command) ``flwr-datasets`` should be installed with the ``audio`` extra
 
 .. code-block:: bash
 
-  python -m pip install flwr_datasets[audio]
+  python -m pip install "flwr-datasets[audio]"
 
+Install directly from GitHub (pip)
+----------------------------------
+
+Installing Flower Datasets directly from GitHub ensures you have access to the most up-to-date version.
+If you encounter any issues or bugs, you may be directed to a specific branch containing a fix before
+it becomes part of an official release.
+
+.. code-block:: bash
+
+  python -m pip install "flwr-datasets@git+https://github.com/adap/flower.git"\
+  "@TYPE-HERE-BRANCH-NAME#subdirectory=datasets"
+
+As before, you can specify the ``vision`` or ``audio`` extra after the package name.
+
+.. code-block:: bash
+
+  python -m pip install "flwr-datasets[vision]@git+https://github.com/adap/flower.git"\
+  "@TYPE-HERE-BRANCH-NAME#subdirectory=datasets"
+
+For example, for the ``main`` branch:
+
+.. code-block:: bash
+
+  python -m pip install "flwr-datasets@git+https://github.com/adap/flower.git"\
+  "@main#subdirectory=datasets"
+
+Since ``flwr-datasets`` is part of the Flower repository, the ``subdirectory`` parameter (at the end of the URL) specifies the package location within the GitHub repo.
 
 Verify installation
 -------------------
@@ -38,7 +65,7 @@ The following command can be used to verify if Flower Datasets was successfully
 
   python -c "import flwr_datasets;print(flwr_datasets.__version__)"
 
-If everything worked, it should print the version of Flower Datasets to the command line:
+If everything works, it should print the version of Flower Datasets to the command line:
 
 .. code-block:: none
diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py
index a14efa1cc905..8770d5b8b76e 100644
--- a/datasets/flwr_datasets/partitioner/__init__.py
+++ b/datasets/flwr_datasets/partitioner/__init__.py
@@ -29,6 +29,7 @@
 from .shard_partitioner import ShardPartitioner
 from .size_partitioner import SizePartitioner
 from .square_partitioner import SquarePartitioner
+from .vertical_size_partitioner import VerticalSizePartitioner
 
 __all__ = [
     "DirichletPartitioner",
@@ -45,4 +46,5 @@
     "ShardPartitioner",
     "SizePartitioner",
     "SquarePartitioner",
+    "VerticalSizePartitioner",
 ]
diff --git a/datasets/flwr_datasets/partitioner/vertical_partitioner_utils.py b/datasets/flwr_datasets/partitioner/vertical_partitioner_utils.py
new file mode 100644
index 000000000000..e9e7e3855ef4
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/vertical_partitioner_utils.py
@@ -0,0 +1,103 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""VerticalPartitioner utilities."""
+# flake8: noqa: E501
+# pylint: disable=C0301
+from typing import Any, Literal, Union
+
+
+def _list_split(lst: list[Any], num_sublists: int) -> list[list[Any]]:
+    """Split a list into n nearly equal-sized sublists.
+
+    For example, a 5-element list split into 2 sublists yields sublists of
+    sizes 3 and 2 (earlier sublists receive the remainder elements).
+
+    Parameters
+    ----------
+    lst : list[Any]
+        The list to split.
+    num_sublists : int
+        Number of sublists to create.
+
+    Returns
+    -------
+    sublists : list[list[Any]]
+        A list containing num_sublists sublists.
+    """
+    if num_sublists <= 0:
+        raise ValueError("Number of splits must be greater than 0")
+    chunk_size, remainder = divmod(len(lst), num_sublists)
+    sublists = []
+    start_index = 0
+    for i in range(num_sublists):
+        end_index = start_index + chunk_size
+        if i < remainder:
+            end_index += 1
+        sublists.append(lst[start_index:end_index])
+        start_index = end_index
+    return sublists
+
+
+def _add_active_party_columns(
+    active_party_columns: list[str],
+    active_party_columns_mode: Union[
+        Literal[
+            "add_to_first",
+            "add_to_last",
+            "create_as_first",
+            "create_as_last",
+            "add_to_all",
+        ],
+        int,
+    ],
+    partition_columns: list[list[str]],
+) -> list[list[str]]:
+    """Add active party columns to the partition columns based on the mode.
+
+    Parameters
+    ----------
+    active_party_columns : list[str]
+        List of active party columns.
+    active_party_columns_mode : Union[Literal["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int]
+        Mode to add active party columns to partition columns.
+    partition_columns : list[list[str]]
+        List of column names for each partition.
+
+    Returns
+    -------
+    partition_columns: list[list[str]]
+        List of partition columns after the modification.
+    """
+    if isinstance(active_party_columns_mode, int):
+        partition_id = active_party_columns_mode
+        if partition_id < 0 or partition_id >= len(partition_columns):
+            raise ValueError(
+                f"Invalid partition index {partition_id} for "
+                "active_party_columns_mode. Must be in the range "
+                f"[0, {len(partition_columns) - 1}]."
+            )
+        for column in active_party_columns:
+            partition_columns[partition_id].append(column)
+    else:
+        if active_party_columns_mode == "add_to_first":
+            for column in active_party_columns:
+                partition_columns[0].append(column)
+        elif active_party_columns_mode == "add_to_last":
+            for column in active_party_columns:
+                partition_columns[-1].append(column)
+        elif active_party_columns_mode == "create_as_first":
+            partition_columns.insert(0, active_party_columns)
+        elif active_party_columns_mode == "create_as_last":
+            partition_columns.append(active_party_columns)
+        elif active_party_columns_mode == "add_to_all":
+            for column in active_party_columns:
+                for partition in partition_columns:
+                    partition.append(column)
+    return partition_columns
diff --git a/datasets/flwr_datasets/partitioner/vertical_partitioner_utils_test.py b/datasets/flwr_datasets/partitioner/vertical_partitioner_utils_test.py
new file mode 100644
index 000000000000..f85d027fe444
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/vertical_partitioner_utils_test.py
@@ -0,0 +1,144 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for vertical partitioner utilities."""
+import unittest
+from typing import Any, Literal
+
+from flwr_datasets.partitioner.vertical_partitioner_utils import (
+    _add_active_party_columns,
+    _list_split,
+)
+
+
+class TestVerticalPartitionerUtils(unittest.TestCase):
+    """Tests for _list_split and _add_active_party_columns utilities."""
+
+    def test_list_split_basic_splitting(self) -> None:
+        """Check equal splitting with divisible lengths."""
+        lst = [1, 2, 3, 4, 5, 6]
+        result = _list_split(lst, 3)
+        expected = [[1, 2], [3, 4], [5, 6]]
+        self.assertEqual(result, expected)
+
+    def test_list_split_uneven_splitting(self) -> None:
+        """Check uneven splitting with non-divisible lengths."""
+        lst = [10, 20, 30, 40, 50]
+        result = _list_split(lst, 2)
+        expected = [[10, 20, 30], [40, 50]]
+        self.assertEqual(result, expected)
+
+    def test_list_split_single_sublist(self) -> None:
+        """Check that single sublist returns the full list."""
+        lst = [1, 2, 3]
+        result = _list_split(lst, 1)
+        expected = [[1, 2, 3]]
+        self.assertEqual(result, expected)
+
+    def test_list_split_more_sublists_than_elements(self) -> None:
+        """Check extra sublists are empty when count exceeds length."""
+        lst = [42]
+        result = _list_split(lst, 3)
+        expected = [[42], [], []]
+        self.assertEqual(result, expected)
+
+    def test_list_split_empty_list(self) -> None:
+        """Check splitting empty list produces empty sublists."""
+        lst: list[Any] = []
+        result = _list_split(lst, 3)
+        expected: list[list[Any]] = [[], [], []]
+        self.assertEqual(result, expected)
+
+    def test_list_split_invalid_num_sublists(self) -> None:
+        """Check ValueError when sublist count is zero or negative."""
+        lst = [1, 2, 3]
+        with self.assertRaises(ValueError):
+            _list_split(lst, 0)
+
+    def test_add_to_first(self) -> None:
+        """Check adding active cols to the first partition."""
+        partition_columns = [["col1", "col2"], ["col3"], ["col4"]]
+        active_party_columns = ["active1", "active2"]
+        mode: Literal["add_to_first"] = "add_to_first"
+        result = _add_active_party_columns(
+            active_party_columns, mode, partition_columns
+        )
+        self.assertEqual(
+            result, [["col1", "col2", "active1", "active2"], ["col3"], ["col4"]]
+        )
+
+    def test_add_to_last(self) -> None:
+        """Check adding active cols to the last partition."""
+        partition_columns = [["col1", "col2"], ["col3"], ["col4"]]
+        active_party_columns = ["active"]
+        mode: Literal["add_to_last"] = "add_to_last"
+        result = _add_active_party_columns(
+            active_party_columns, mode, partition_columns
+        )
+        self.assertEqual(result, [["col1", "col2"], ["col3"], ["col4", "active"]])
+
+    def test_create_as_first(self) -> None:
+        """Check creating a new first partition for active cols."""
+        partition_columns = [["col1"], ["col2"]]
+        active_party_columns = ["active1", "active2"]
+        mode: Literal["create_as_first"] = "create_as_first"
+        result = _add_active_party_columns(
+            active_party_columns, mode, partition_columns
+        )
+        self.assertEqual(result, [["active1", "active2"], ["col1"], ["col2"]])
+
+    def test_create_as_last(self) -> None:
+        """Check creating a new last partition for active cols."""
+        partition_columns = [["col1"], ["col2"]]
+        active_party_columns = ["active1", "active2"]
+        mode: Literal["create_as_last"] = "create_as_last"
+        result = _add_active_party_columns(
+            active_party_columns, mode, partition_columns
+        )
+        self.assertEqual(result, [["col1"], ["col2"], ["active1", "active2"]])
+
+    def test_add_to_all(self) -> None:
+        """Check adding active cols to all partitions."""
+        partition_columns = [["col1"], ["col2", "col3"], ["col4"]]
+        active_party_columns = ["active"]
+        mode: Literal["add_to_all"] = "add_to_all"
+        result = _add_active_party_columns(
+            active_party_columns, mode, partition_columns
+        )
+        self.assertEqual(
+            result, [["col1", "active"], ["col2", "col3", "active"], ["col4", "active"]]
+        )
+
+    def test_add_to_specific_partition_valid_index(self) -> None:
+        """Check adding active cols to a specific valid partition."""
+        partition_columns = [["col1"], ["col2"], ["col3"]]
+        active_party_columns = ["active1", "active2"]
+        mode: int = 1
+        result = _add_active_party_columns(
+            active_party_columns, mode, partition_columns
+        )
+        self.assertEqual(result, [["col1"], ["col2", "active1", "active2"], ["col3"]])
+
+    def test_add_to_specific_partition_invalid_index(self) -> None:
+        """Check ValueError when partition index is invalid."""
+        partition_columns = [["col1"], ["col2"]]
+        active_party_columns = ["active"]
+        mode: int = 5
+        with self.assertRaises(ValueError) as context:
+            _add_active_party_columns(active_party_columns, mode, partition_columns)
+        self.assertIn("Invalid partition index", str(context.exception))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py b/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py
new file mode 100644
index 000000000000..462a76a2e3f5
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py
@@ -0,0 +1,312 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""VerticalSizePartitioner class."""
+# flake8: noqa: E501
+# pylint: disable=C0301, R0902, R0913
+from math import floor
+from typing import Literal, Optional, Union, cast
+
+import numpy as np
+
+import datasets
+from flwr_datasets.partitioner.partitioner import Partitioner
+from flwr_datasets.partitioner.vertical_partitioner_utils import (
+    _add_active_party_columns,
+)
+
+
+class VerticalSizePartitioner(Partitioner):
+    """Creates vertical partitions by splitting features (columns) based on sizes.
+
+    The sizes refer to the number of columns after the `drop_columns` are
+    dropped. `shared_columns` and `active_party_column` are excluded and
+    added only after the size-based division.
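+
+    For example, with 10 dataset columns, two `drop_columns`, one
+    `shared_columns` entry and a single active party column, the integer
+    `partition_sizes` must sum to 10 - 2 - 1 - 1 = 6.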
+
+    Enables selection of "active party" column(s) and placement into
+    a specific partition or creation of a new partition just for it.
+    Also enables dropping columns and sharing specified columns across
+    all partitions.
+
+    Parameters
+    ----------
+    partition_sizes : Union[list[int], list[float]]
+        A list where each value represents the size of a partition.
+        list[int] -> each value represents an absolute number of columns. Size zero is
+        allowed and will result in an empty partition if no shared columns are present.
+        list[float] -> each value represents a fraction of the total number of columns.
+        Note that these values apply to the columns without `active_party_columns` and
+        `shared_columns`, which are additionally included in the partition(s) afterwards.
+        `drop_columns` are also not counted toward the partition sizes.
+        In case of list[int]: sum(partition_sizes) == len(columns) - len(drop_columns) -
+        len(shared_columns) - len(active_party_columns)
+    active_party_column : Optional[Union[str, list[str]]]
+        Column(s) (typically representing labels) associated with the
+        "active party" (which can be the server).
+    active_party_columns_mode : Union[Literal["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int]
+        Determines how to assign the active party columns:
+
+        - `"add_to_first"`: Append active party columns to the first partition.
+        - `"add_to_last"`: Append active party columns to the last partition.
+        - `"create_as_first"`: Create a new partition at the start containing only these columns.
+        - `"create_as_last"`: Create a new partition at the end containing only these columns.
+        - `"add_to_all"`: Append active party columns to all partitions.
+        - int: Append active party columns to the specified partition index.
+    drop_columns : Optional[list[str]]
+        Columns to remove entirely from the dataset before partitioning.
+    shared_columns : Optional[list[str]]
+        Columns to duplicate into every partition after initial partitioning.
+    shuffle : bool
+        Whether to shuffle the order of columns before partitioning.
+    seed : Optional[int]
+        Random seed for shuffling columns. Has no effect if `shuffle=False`.
+
+    Examples
+    --------
+    >>> from flwr_datasets import FederatedDataset
+    >>> from flwr_datasets.partitioner import VerticalSizePartitioner
+    >>>
+    >>> partitioner = VerticalSizePartitioner(
+    ...     partition_sizes=[8, 4, 2],
+    ...     active_party_column="income",
+    ...     active_party_columns_mode="create_as_last"
+    ... )
+    >>> fds = FederatedDataset(
+    ...     dataset="scikit-learn/adult-census-income",
+    ...     partitioners={"train": partitioner}
+    ... )
+    >>> partitions = [fds.load_partition(i) for i in range(fds.partitioners["train"].num_partitions)]
+    >>> print([partition.column_names for partition in partitions])
+    """
+
+    def __init__(
+        self,
+        partition_sizes: Union[list[int], list[float]],
+        active_party_column: Optional[Union[str, list[str]]] = None,
+        active_party_columns_mode: Union[
+            Literal[
+                "add_to_first",
+                "add_to_last",
+                "create_as_first",
+                "create_as_last",
+                "add_to_all",
+            ],
+            int,
+        ] = "add_to_last",
+        drop_columns: Optional[list[str]] = None,
+        shared_columns: Optional[list[str]] = None,
+        shuffle: bool = True,
+        seed: Optional[int] = 42,
+    ) -> None:
+        super().__init__()
+
+        self._partition_sizes = partition_sizes
+        self._active_party_columns = self._init_active_party_column(active_party_column)
+        self._active_party_columns_mode = active_party_columns_mode
+        self._drop_columns = drop_columns or []
+        self._shared_columns = shared_columns or []
+        self._shuffle = shuffle
+        self._seed = seed
+        self._rng = np.random.default_rng(seed=self._seed)
+
+        self._partition_columns: Optional[list[list[str]]] = None
+        self._partitions_determined = False
+
+        self._validate_parameters_in_init()
+
+    def _determine_partitions_if_needed(self) -> None:
+        if self._partitions_determined:
+            return
+
+        if self.dataset is None:
+            raise ValueError("No dataset is set for this partitioner.")
+
+        all_columns = list(self.dataset.column_names)
+        self._validate_parameters_while_partitioning(
+            all_columns, self._shared_columns, self._active_party_columns
+        )
+        columns = [column for column in all_columns if column not in self._drop_columns]
+        columns = [column for column in columns if column not in self._shared_columns]
+        columns = [
+            column for column in columns if column not in self._active_party_columns
+        ]
+
+        if self._shuffle:
+            self._rng.shuffle(columns)
+        if all(isinstance(fraction, float) for fraction in self._partition_sizes):
+            partition_columns = _fraction_split(
+                columns, cast(list[float], self._partition_sizes)
+            )
+        else:
+            partition_columns = _count_split(
+                columns, cast(list[int], self._partition_sizes)
+            )
+
+        partition_columns = _add_active_party_columns(
+            self._active_party_columns,
+            self._active_party_columns_mode,
+            partition_columns,
+        )
+
+        # Add shared columns to all partitions
+        for partition in partition_columns:
+            for column in self._shared_columns:
+                partition.append(column)
+
+        self._partition_columns = partition_columns
+        self._partitions_determined = True
+
+    def load_partition(self, partition_id: int) -> datasets.Dataset:
+        """Load a partition based on the partition index.
+
+        Parameters
+        ----------
+        partition_id : int
+            The index that corresponds to the requested partition.
+
+        Returns
+        -------
+        dataset_partition : Dataset
+            Single partition of a dataset.
+        """
+        self._determine_partitions_if_needed()
+        assert self._partition_columns is not None
+        if partition_id < 0 or partition_id >= len(self._partition_columns):
+            raise IndexError(
+                f"partition_id: {partition_id} out of range <0, {self.num_partitions - 1}>."
+            )
+        columns = self._partition_columns[partition_id]
+        return self.dataset.select_columns(columns)
+
+    @property
+    def num_partitions(self) -> int:
+        """Number of partitions."""
+        self._determine_partitions_if_needed()
+        assert self._partition_columns is not None
+        return len(self._partition_columns)
+
+    def _validate_parameters_in_init(self) -> None:
+        if not isinstance(self._partition_sizes, list):
+            raise ValueError("partition_sizes must be a list.")
+        if all(isinstance(fraction, float) for fraction in self._partition_sizes):
+            fraction_sum = sum(self._partition_sizes)
+            if fraction_sum != 1.0:
+                raise ValueError("Float ratios in `partition_sizes` must sum to 1.0.")
+            if any(
+                fraction < 0.0 or fraction > 1.0 for fraction in self._partition_sizes
+            ):
+                raise ValueError(
+                    "All floats in `partition_sizes` must be >= 0.0 and <= 1.0."
+                )
+        elif all(
+            isinstance(column_count, int) for column_count in self._partition_sizes
+        ):
+            if any(column_count < 0 for column_count in self._partition_sizes):
+                raise ValueError("All integers in `partition_sizes` must be >= 0.")
+        else:
+            raise ValueError("`partition_sizes` list must be all floats or all ints.")
+
+        # Validate columns lists
+        for parameter_name, parameter_list in [
+            ("drop_columns", self._drop_columns),
+            ("shared_columns", self._shared_columns),
+            ("active_party_columns", self._active_party_columns),
+        ]:
+            if not all(isinstance(column, str) for column in parameter_list):
+                raise ValueError(f"All entries in {parameter_name} must be strings.")
+
+        valid_modes = {
+            "add_to_first",
+            "add_to_last",
+            "create_as_first",
+            "create_as_last",
+            "add_to_all",
+        }
+        if not (
+            isinstance(self._active_party_columns_mode, int)
+            or self._active_party_columns_mode in valid_modes
+        ):
+            raise ValueError(
+                "active_party_columns_mode must be an int or one of "
+                "'add_to_first', 'add_to_last', 'create_as_first', 'create_as_last', "
+                "'add_to_all'."
+            )
+
+    def _validate_parameters_while_partitioning(
+        self,
+        all_columns: list[str],
+        shared_columns: list[str],
+        active_party_columns: list[str],
+    ) -> None:
+        # Shared columns existence check
+        for column in shared_columns:
+            if column not in all_columns:
+                raise ValueError(f"Shared column '{column}' not found in the dataset.")
+        # Active party columns existence check
+        for column in active_party_columns:
+            if column not in all_columns:
+                raise ValueError(
+                    f"Active party column '{column}' not found in the dataset."
+                )
+        num_columns = len(all_columns)
+        num_cols_unused_in_core_div = 0
+        if self._active_party_columns is not None:
+            num_cols_unused_in_core_div += len(self._active_party_columns)
+        if self._shared_columns is not None:
+            num_cols_unused_in_core_div += len(self._shared_columns)
+        if self._drop_columns is not None:
+            num_cols_unused_in_core_div += len(self._drop_columns)
+        num_core_div_columns = num_columns - num_cols_unused_in_core_div
+        if all(isinstance(size, int) for size in self._partition_sizes):
+            if sum(self._partition_sizes) != num_core_div_columns:
+                raise ValueError(
+                    "Sum of partition sizes must equal the total number of columns "
+                    "used in the division. Note that shared_columns, drop_columns "
+                    "and active_party_columns are not included in the division."
+                )
+
+    def _init_active_party_column(
+        self, active_party_column: Optional[Union[str, list[str]]]
+    ) -> list[str]:
+        if active_party_column is None:
+            return []
+        if isinstance(active_party_column, str):
+            return [active_party_column]
+        if isinstance(active_party_column, list):
+            return active_party_column
+        raise ValueError("active_party_column must be a string or a list of strings.")
+
+
+def _count_split(columns: list[str], counts: list[int]) -> list[list[str]]:
+    partition_columns = []
+    start = 0
+    for count in counts:
+        end = start + count
+        partition_columns.append(columns[start:end])
+        start = end
+    return partition_columns
+
+
+def _fraction_split(columns: list[str], fractions: list[float]) -> list[list[str]]:
+    num_columns = len(columns)
+    partitions = []
+    cumulative = 0
+    for index, fraction in enumerate(fractions):
+        count = int(floor(fraction * num_columns))
+        if index == len(fractions) - 1:
+            # Last partition takes the remainder
+            count = num_columns - cumulative
+        partitions.append(columns[cumulative : cumulative + count])
+        cumulative += count
+    return partitions
diff --git a/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py b/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py
new file mode 100644
index 000000000000..d2c483c2be88
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py
@@ -0,0 +1,206 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""VerticalSizePartitioner class tests."""
+# mypy: disable-error-code=arg-type
+# pylint: disable=R0902, R0913
+import unittest
+
+import numpy as np
+
+from datasets import Dataset
+from flwr_datasets.partitioner.vertical_size_partitioner import VerticalSizePartitioner
+
+
+def _create_dummy_dataset(column_names: list[str], num_rows: int = 100) -> Dataset:
+    """Create a dataset with random integer data."""
+    rng = np.random.default_rng(seed=42)
+    data = {col: rng.integers(0, 100, size=num_rows).tolist() for col in column_names}
+    return Dataset.from_dict(data)
+
+
+class TestVerticalSizePartitioner(unittest.TestCase):
+    """Tests for VerticalSizePartitioner."""
+
+    def test_init_invalid_partition_sizes_type(self) -> None:
+        """Check ValueError if partition_sizes is not a list."""
+        with self.assertRaises(ValueError):
+            VerticalSizePartitioner(partition_sizes="not_a_list")
+
+    def test_init_mixed_partition_sizes_types(self) -> None:
+        """Check ValueError if partition_sizes mix int and float."""
+        with self.assertRaises(ValueError):
+            VerticalSizePartitioner(partition_sizes=[0.5, 1])
+
+    def test_init_float_partitions_sum_not_one(self) -> None:
+        """Check ValueError if float partitions do not sum to 1."""
+        with self.assertRaises(ValueError):
+            VerticalSizePartitioner(partition_sizes=[0.3, 0.3])
+
+    def test_init_float_partitions_out_of_range(self) -> None:
+        """Check ValueError if any float partition <0 or >1."""
+        with self.assertRaises(ValueError):
+            VerticalSizePartitioner(partition_sizes=[-0.5, 1.5])
+
+    def test_init_int_partitions_negative(self) -> None:
+        """Check ValueError if any int partition size is negative."""
+        with self.assertRaises(ValueError):
+            VerticalSizePartitioner(partition_sizes=[5, -1])
+
+    def test_init_invalid_mode(self) -> None:
+        """Check ValueError if active_party_columns_mode is invalid."""
+        with self.assertRaises(ValueError):
+            VerticalSizePartitioner(
+                partition_sizes=[2, 2], active_party_columns_mode="invalid"
+            )
+
+    def test_init_active_party_column_invalid_type(self) -> None:
+        """Check ValueError if active_party_column is not str/list."""
+        with self.assertRaises(ValueError):
+            VerticalSizePartitioner(partition_sizes=[2, 2], active_party_column=123)
+
+    def test_partitioning_with_int_sizes(self) -> None:
+        """Check correct partitioning with integer sizes."""
+        columns = ["f1", "f2", "f3", "f4", "f5"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(partition_sizes=[2, 3], shuffle=False)
+        partitioner.dataset = dataset
+        p0 = partitioner.load_partition(0)
+        p1 = partitioner.load_partition(1)
+        self.assertEqual(len(p0.column_names), 2)
+        self.assertEqual(len(p1.column_names), 3)
+
+    def test_partitioning_with_fraction_sizes(self) -> None:
+        """Check correct partitioning with fraction sizes."""
+        columns = ["f1", "f2", "f3", "f4"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(partition_sizes=[0.5, 0.5], shuffle=False)
+        partitioner.dataset = dataset
+        p0 = partitioner.load_partition(0)
+        p1 = partitioner.load_partition(1)
+        self.assertEqual(len(p0.column_names), 2)
+        self.assertEqual(len(p1.column_names), 2)
+
+    def test_partitioning_with_drop_columns(self) -> None:
+        """Check dropping specified columns before partitioning."""
+        columns = ["f1", "drop_me", "f2", "f3"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(
+            partition_sizes=[2, 1], drop_columns=["drop_me"], shuffle=False
+        )
+        partitioner.dataset = dataset
+        p0 = partitioner.load_partition(0)
+        p1 = partitioner.load_partition(1)
+        all_cols = p0.column_names + p1.column_names
+        self.assertNotIn("drop_me", all_cols)
+
+    def test_partitioning_with_shared_columns(self) -> None:
+        """Check shared columns added to every partition."""
+        columns = ["f1", "f2", "shared"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(
+            partition_sizes=[1, 1], shared_columns=["shared"], shuffle=False
+        )
+        partitioner.dataset = dataset
+        p0 = partitioner.load_partition(0)
+        p1 = partitioner.load_partition(1)
+        self.assertIn("shared", p0.column_names)
+        self.assertIn("shared", p1.column_names)
+
+    def test_partitioning_with_active_party_add_to_last(self) -> None:
+        """Check active party columns added to the last partition."""
+        columns = ["f1", "f2", "label"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(
+            partition_sizes=[2],
+            active_party_column="label",
+            active_party_columns_mode="add_to_last",
+            shuffle=False,
+        )
+        partitioner.dataset = dataset
+        p0 = partitioner.load_partition(0)
+        self.assertIn("label", p0.column_names)
+
+    def test_partitioning_with_active_party_create_as_first(self) -> None:
+        """Check creating a new first partition for active party cols."""
+        columns = ["f1", "f2", "label"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(
+            partition_sizes=[2],
+            active_party_column="label",
+            active_party_columns_mode="create_as_first",
+            shuffle=False,
+        )
+        partitioner.dataset = dataset
+        self.assertEqual(partitioner.num_partitions, 2)
+        p0 = partitioner.load_partition(0)
+        p1 = partitioner.load_partition(1)
+        self.assertEqual(p0.column_names, ["label"])
+        self.assertIn("f1", p1.column_names)
+        self.assertIn("f2", p1.column_names)
+
+    def test_partitioning_with_nonexistent_shared_column(self) -> None:
+        """Check ValueError if shared column does not exist."""
+        columns = ["f1", "f2"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(
+            partition_sizes=[1], shared_columns=["nonexistent"], shuffle=False
+        )
+        partitioner.dataset = dataset
+        with self.assertRaises(ValueError):
+            partitioner.load_partition(0)
+
+    def test_partitioning_with_nonexistent_active_party_column(self) -> None:
+        """Check ValueError if active party column does not exist."""
+        columns = ["f1", "f2"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(
+            partition_sizes=[1], active_party_column="missing_label", shuffle=False
+        )
+        partitioner.dataset = dataset
+        with self.assertRaises(ValueError):
+            partitioner.load_partition(0)
+
+    def test_sum_of_int_partition_sizes_exceeds_num_columns(self) -> None:
+        """Check ValueError if sum of int sizes > total columns."""
+        columns = ["f1", "f2"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(partition_sizes=[3], shuffle=False)
+        partitioner.dataset = dataset
+        with self.assertRaises(ValueError):
+            partitioner.load_partition(0)
+
+    def test_sum_of_int_partition_sizes_indirectly_exceeds_num_columns(self) -> None:
+        """Check ValueError if sum of int sizes > columns left after drops."""
+        columns = ["f1", "f2", "f3"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(
+            partition_sizes=[1, 1], drop_columns=["f3", "f2"], shuffle=False
+        )
+        partitioner.dataset = dataset
+        with self.assertRaises(ValueError):
+            partitioner.load_partition(0)
+
+    def
 test_sum_of_int_partition_sizes_is_smaller_than_num_columns(self) -> None:
+        """Check ValueError if sum of int sizes < total columns."""
+        columns = ["f1", "f2", "f3"]
+        dataset = _create_dummy_dataset(columns)
+        partitioner = VerticalSizePartitioner(partition_sizes=[2], shuffle=False)
+        partitioner.dataset = dataset
+        with self.assertRaises(ValueError):
+            partitioner.load_partition(0)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/datasets/pyproject.toml b/datasets/pyproject.toml
index 2d699c5e901b..497f89a2f7ca 100644
--- a/datasets/pyproject.toml
+++ b/datasets/pyproject.toml
@@ -34,6 +34,8 @@ classifiers = [
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
     "Programming Language :: Python :: Implementation :: CPython",
     "Topic :: Scientific/Engineering",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
diff --git a/src/py/flwr/cli/new/templates/app/.gitignore.tpl b/src/py/flwr/cli/new/templates/app/.gitignore.tpl
index 68bc17f9ff21..f791a9b679d8 100644
--- a/src/py/flwr/cli/new/templates/app/.gitignore.tpl
+++ b/src/py/flwr/cli/new/templates/app/.gitignore.tpl
@@ -3,6 +3,9 @@ __pycache__/
 *.py[cod]
 *$py.class
 
+# Flower directory
+.flwr
+
 # C extensions
 *.so
diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py
index 8cb89255ed40..e01a0439c9da 100644
--- a/src/py/flwr/cli/utils.py
+++ b/src/py/flwr/cli/utils.py
@@ -22,7 +22,7 @@
 from contextlib import contextmanager
 from logging import DEBUG
 from pathlib import Path
-from typing import Any, Callable, Optional, cast
+from typing import Any, Callable, Optional, Union, cast
 
 import grpc
 import typer
@@ -148,23 +148,69 @@ def sanitize_project_name(name: str) -> str:
     return sanitized_name
 
 
-def get_sha256_hash(file_path: Path) -> str:
-    """Calculate the SHA-256 hash of a file."""
+def get_sha256_hash(file_path_or_int: Union[Path, int]) -> str:
+    """Calculate the SHA-256 hash of a file or an integer."""
     sha256 = hashlib.sha256()
-    with open(file_path, "rb") as f:
-        while True:
-            data = f.read(65536)  # Read in 64kB blocks
-            if not data:
-                break
-            sha256.update(data)
+    if isinstance(file_path_or_int, Path):
+        with open(file_path_or_int, "rb") as f:
+            while True:
+                data = f.read(65536)  # Read in 64kB blocks
+                if not data:
+                    break
+                sha256.update(data)
+    elif isinstance(file_path_or_int, int):
+        sha256.update(str(file_path_or_int).encode())
     return sha256.hexdigest()
 
 
 def get_user_auth_config_path(root_dir: Path, federation: str) -> Path:
-    """Return the path to the user auth config file."""
+    """Return the path to the user auth config file.
+
+    Additionally, a `.gitignore` file will be created in the Flower directory so
+    that the `.credentials` folder is excluded from git. If the `.gitignore` file
+    already exists but does not contain the `.credentials` entry, a warning will
+    be displayed.
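+
+    For example, for a federation named `my-federation`, the returned path is
+    `<root_dir>/.flwr/.credentials/my-federation.json`.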
+ """ # Locate the credentials directory - credentials_dir = root_dir.absolute() / FLWR_DIR / CREDENTIALS_DIR + abs_flwr_dir = root_dir.absolute() / FLWR_DIR + credentials_dir = abs_flwr_dir / CREDENTIALS_DIR credentials_dir.mkdir(parents=True, exist_ok=True) + + # Determine the absolute path of the Flower directory for .gitignore + gitignore_path = abs_flwr_dir / ".gitignore" + credential_entry = CREDENTIALS_DIR + + try: + if gitignore_path.exists(): + with open(gitignore_path, encoding="utf-8") as gitignore_file: + lines = gitignore_file.read().splitlines() + + # Warn if .credentials is not already in .gitignore + if credential_entry not in lines: + typer.secho( + f"`.gitignore` exists, but `{credential_entry}` entry not found. " + "Consider adding it to your `.gitignore` to exclude Flower " + "credentials from git.", + fg=typer.colors.YELLOW, + bold=True, + ) + else: + typer.secho( + f"Creating a new `.gitignore` with `{credential_entry}` entry...", + fg=typer.colors.BLUE, + ) + # Create a new .gitignore with .credentials + with open(gitignore_path, "w", encoding="utf-8") as gitignore_file: + gitignore_file.write(f"{credential_entry}\n") + except Exception as err: + typer.secho( + "❌ An error occurred while handling `.gitignore.` " + f"Please check the permissions of `{gitignore_path}` and try again.", + fg=typer.colors.RED, + bold=True, + ) + raise typer.Exit(code=1) from err + return credentials_dir / f"{federation}.json" diff --git a/src/py/flwr/cli/utils_test.py b/src/py/flwr/cli/utils_test.py new file mode 100644 index 000000000000..e722dee70c3c --- /dev/null +++ b/src/py/flwr/cli/utils_test.py @@ -0,0 +1,109 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Tests for Flower command line interface utils."""
+
+
+import hashlib
+import os
+import tempfile
+import unittest
+from pathlib import Path
+
+from flwr.cli.utils import get_sha256_hash
+
+
+class TestGetSHA256Hash(unittest.TestCase):
+    """Unit tests for `get_sha256_hash` function."""
+
+    def test_hash_with_integer(self) -> None:
+        """Test the SHA-256 hash calculation when input is an integer."""
+        # Prepare
+        test_int = 13413
+        expected_hash = hashlib.sha256(str(test_int).encode()).hexdigest()
+
+        # Execute
+        result = get_sha256_hash(test_int)
+
+        # Assert
+        self.assertEqual(result, expected_hash)
+
+    def test_hash_with_file(self) -> None:
+        """Test the SHA-256 hash calculation when input is a file path."""
+        # Prepare - Create a temporary file with known content
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            temp_file.write(b"Test content for SHA-256 hashing.")
+            temp_file_path = Path(temp_file.name)
+
+        try:
+            # Execute
+            sha256 = hashlib.sha256()
+            with open(temp_file_path, "rb") as f:
+                while True:
+                    data = f.read(65536)
+                    if not data:
+                        break
+                    sha256.update(data)
+            expected_hash = sha256.hexdigest()
+
+            result = get_sha256_hash(temp_file_path)
+
+            # Assert
+            self.assertEqual(result, expected_hash)
+        finally:
+            # Clean up the temporary file
+            os.remove(temp_file_path)
+
+    def test_empty_file(self) -> None:
+        """Test the SHA-256 hash calculation for an empty file."""
+        # Prepare
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            temp_file_path = Path(temp_file.name)
+
+        try:
+            # Execute
+            expected_hash = hashlib.sha256(b"").hexdigest()
+            result = get_sha256_hash(temp_file_path)
+
+            # Assert
+            self.assertEqual(result, expected_hash)
+        finally:
+            os.remove(temp_file_path)
+
+    def test_large_file(self) -> None:
+        """Test the SHA-256 hash calculation for a large file."""
+        # Prepare - Generate large data (e.g., 10 MB)
+        large_data = b"a" * (10 * 1024 * 1024)  # 10 MB of 'a's
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            temp_file.write(large_data)
+            temp_file_path = Path(temp_file.name)
+
+        try:
+            expected_hash = hashlib.sha256(large_data).hexdigest()
+            # Execute
+            result = get_sha256_hash(temp_file_path)
+
+            # Assert
+            self.assertEqual(result, expected_hash)
+        finally:
+            os.remove(temp_file_path)
+
+    def test_nonexistent_file(self) -> None:
+        """Test the SHA-256 hash calculation when the file does not exist."""
+        # Prepare
+        nonexistent_path = Path("/path/to/nonexistent/file.txt")
+
+        # Execute & assert
+        with self.assertRaises(FileNotFoundError):
+            get_sha256_hash(nonexistent_path)
diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py
index 91514f845651..b5c7ae95e224 100644
--- a/src/py/flwr/server/app.py
+++ b/src/py/flwr/server/app.py
@@ -374,6 +374,7 @@ def run_superlink() -> None:
             server_public_key,
         ) = maybe_keys
         state = state_factory.state()
+        state.clear_supernode_auth_keys_and_credentials()
         state.store_node_public_keys(node_public_keys)
         state.store_server_private_public_key(
             private_key_to_bytes(server_private_key),
diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
index f26bb11a4bdb..d22072b41621 100644
--- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
+++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
@@ -430,6 +430,13 @@ def get_server_public_key(self) -> Optional[bytes]:
         """Retrieve `server_public_key` in urlsafe bytes."""
         return self.server_public_key
 
+    def clear_supernode_auth_keys_and_credentials(self) -> None:
+        """Clear stored `node_public_keys` and credentials in the link state if any."""
+        with self.lock:
+            self.server_private_key = None
+            self.server_public_key = None
+            self.node_public_keys.clear()
+
     def store_node_public_keys(self, public_keys: set[bytes]) -> None:
         """Store a set of `node_public_keys` in the link state."""
         with self.lock:
diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py
index ae9d1710f069..4f3c16a5460a 100644
--- a/src/py/flwr/server/superlink/linkstate/linkstate.py
+++ b/src/py/flwr/server/superlink/linkstate/linkstate.py
@@ -284,6 +284,10 @@ def get_server_private_key(self) -> Optional[bytes]:
     def get_server_public_key(self) -> Optional[bytes]:
         """Retrieve `server_public_key` in urlsafe bytes."""
 
+    @abc.abstractmethod
+    def clear_supernode_auth_keys_and_credentials(self) -> None:
+        """Clear stored `node_public_keys` and credentials in the link state if any."""
+
     @abc.abstractmethod
     def store_node_public_keys(self, public_keys: set[bytes]) -> None:
         """Store a set of `node_public_keys` in the link state."""
diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py
index 93f5d94daef7..3edaf72ec20c 100644
--- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py
+++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py
@@ -820,6 +820,29 @@ def test_store_server_private_public_key_twice(self) -> None:
             new_private_key_bytes, new_public_key_bytes
         )
 
+    def test_clear_supernode_auth_keys_and_credentials(self) -> None:
+        """Test clear_supernode_auth_keys_and_credentials from linkstate."""
+        # Prepare
+        state: LinkState = self.state_factory()
+        key_pairs = [generate_key_pairs() for _ in range(3)]
+        public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs}
+
+        # Execute (store)
+        state.store_node_public_keys(public_keys)
+        private_key, public_key = generate_key_pairs()
+        private_key_bytes = private_key_to_bytes(private_key)
+        public_key_bytes = public_key_to_bytes(public_key)
+        state.store_server_private_public_key(private_key_bytes, public_key_bytes)
+
+        # Execute (clear)
+        state.clear_supernode_auth_keys_and_credentials()
+        node_public_keys = state.get_node_public_keys()
+
+        # Assert
+        assert node_public_keys == set()
+        assert state.get_server_private_key() is None
+        assert state.get_server_public_key() is None
+
     def test_node_public_keys(self) -> None:
         """Test store_node_public_keys and get_node_public_keys from state."""
         # Prepare
diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
index 99334233319d..e8311dfaac5e 100644
--- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
+++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py
@@ -818,6 +818,12 @@ def get_server_public_key(self) -> Optional[bytes]:
             public_key = None
         return public_key
 
+    def clear_supernode_auth_keys_and_credentials(self) -> None:
+        """Clear stored `node_public_keys` and credentials in the link state if any."""
+        queries = ["DELETE FROM public_key;", "DELETE FROM credential;"]
+        for query in queries:
+            self.query(query)
+
     def store_node_public_keys(self, public_keys: set[bytes]) -> None:
         """Store a set of `node_public_keys` in the link state."""
         query = "INSERT INTO public_key (public_key) VALUES (?)"