Skip to content

Commit

Permalink
feat(framework): Add external flower partitioner support (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
wittenator authored Dec 20, 2024
1 parent 7070f61 commit bb2fcde
Show file tree
Hide file tree
Showing 8 changed files with 506 additions and 177 deletions.
586 changes: 409 additions & 177 deletions .env/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions .env/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ ray = { extras = ["default"], version = "2.36.1" }
tensorboard = "^2.17.1"
cvxpy = "^1.5.1"
hydra-core = "^1.3.2"
flwr-datasets = "^0.4.0"
statsmodels = "^0.14.4"
pytorch-minimize = "^0.0.2"

Expand Down
1 change: 1 addition & 0 deletions .env/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ ray[default]~=2.38.0
tensorboard~=2.17.1
cvxpy~=1.5.1
hydra-core~=1.3.2
flwr-datasets~=0.4.0
statsmodels~=0.14.4
pytorch-minimize~=0.0.2
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ out
multirun

# datasets

data/cifar10
data/cifar100
data/mnist
Expand Down
8 changes: 8 additions & 0 deletions data/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ python generate_data.py -d cifar10 -sm 1 -cn 20
```
<img src="../.github/images/distributions/semantic.png" alt="Image" width="350"/>

## Flower Partitioner

This benchmark also supports external partitioners provided by [flwr_datasets](https://flower.ai/docs/datasets/), enabling comparison with the built-in partitioning schemes as well as additional schemes that exist in flwr_datasets. To use flwr partitioners, you need to specify the class path of the partitioner you want to use and all of its parameters in a separate dictionary. This is how you would use the DirichletPartitioner from flwr:
Attention: To use flwr's partitioners, a mock dataset with a column called "label" is created internally. If the partitioning scheme depends on label information, pass "label" as the label column.
```shell
python generate_data.py -d cifar10 -cn 10 -fpc "flwr_datasets.partitioner.DirichletPartitioner" -fpk '{"alpha": 100.0, "partition_by": "label"}'
```

# Usage 🚀

## Synthetic Dataset in FedProx
Expand Down
18 changes: 18 additions & 0 deletions data/utils/process.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import json
import os
from argparse import Namespace
Expand Down Expand Up @@ -551,3 +552,20 @@ def plot_distribution(client_num: int, label_counts: np.ndarray, save_path: str)
ax.spines["top"].set_visible(False)
ax.legend(bbox_to_anchor=(1.2, 1))
plt.savefig(save_path, bbox_inches="tight")

def class_from_string(class_string: str) -> type:
    """Dynamically load a class (or any module attribute) from its dotted path.

    Args:
        class_string (str): Fully qualified name of the class, including the
            module path, e.g. ``"path.to.module.ClassName"``.

    Returns:
        type: The attribute named by the final path component.

    Raises:
        ValueError: If ``class_string`` contains no module part (no dot).
        ModuleNotFoundError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with that name.

    Example:
        class_from_string('collections.Counter') returns the Counter class
        from the 'collections' module.
    """
    # Split only at the last dot so class paths like "a.b.C" resolve module
    # "a.b" and attribute "C" without re-splitting the string twice.
    module_path, _, class_name = class_string.rpartition(".")
    if not module_path:
        # import_module("") would raise an opaque "Empty module name" error;
        # fail with a message that points at the malformed input instead.
        raise ValueError(
            f"Expected a dotted path like 'module.ClassName', got {class_string!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
52 changes: 52 additions & 0 deletions data/utils/schemes/flower.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

from collections import Counter
import json
import numpy as np
import datasets

from data.utils.process import class_from_string


def flower_partition(
    targets: np.ndarray,
    target_indices: np.ndarray,
    label_set: set,
    client_num: int,
    flower_partitioner_class: str,
    flower_partitioner_kwargs: str,
    partition: dict,
    stats: dict,
):
    """Partition data across clients with an external flwr_datasets partitioner.

    Args:
        targets: Label array aligned with ``target_indices``.
        target_indices: Index array of the samples under consideration.
        label_set: Labels to keep; samples with other labels are dropped.
        client_num: Number of client partitions to create.
        flower_partitioner_class: Dotted path of the flwr partitioner class,
            e.g. ``"flwr_datasets.partitioner.DirichletPartitioner"``.
        flower_partitioner_kwargs: JSON string of extra keyword arguments for
            the partitioner's constructor.
        partition: Output dict; ``partition["data_indices"][i]`` is filled with
            client *i*'s sample indices (mutated in place).
        stats: Output dict; per-client sample counts and label histograms plus
            a ``"samples_per_client"`` summary (mutated in place).
    """
    # Keep only the positions whose label is in label_set.
    # NOTE(review): these are positions into the *filtered* arrays, not the
    # original index values carried by target_indices — confirm callers
    # expect positional indices here.
    kept_positions = [
        pos for pos in range(len(target_indices)) if targets[pos] in label_set
    ]
    targets = targets[kept_positions]

    # Build a mock Hugging Face dataset: flwr partitioners consume a
    # datasets.Dataset and may partition by the "label" column.
    mock_dataset = datasets.Dataset.from_dict(
        {"data_indices": kept_positions, "label": targets}
    )

    # Resolve and instantiate the requested partitioner.
    partitioner_cls = class_from_string(flower_partitioner_class)
    extra_kwargs = json.loads(flower_partitioner_kwargs)
    partitioner = partitioner_cls(num_partitions=client_num, **extra_kwargs)
    partitioner.dataset = mock_dataset

    # Collect each client's shard: indices, size, and label histogram.
    shard_sizes = []
    for client_id in range(client_num):
        shard = partitioner.load_partition(client_id)
        shard_indices = shard["data_indices"]
        partition["data_indices"][client_id] = shard_indices
        stats[client_id] = {
            "x": len(shard_indices),
            "y": dict(Counter(targets[shard_indices].tolist())),
        }
        shard_sizes.append(len(shard))

    size_arr = np.array(shard_sizes)
    # NOTE(review): the "std" key holds the mean — looks like it mirrors the
    # stats layout of the other partitioning schemes; verify the key names
    # are intentional before renaming.
    stats["samples_per_client"] = {
        "std": size_arr.mean().item(),
        "stddev": size_arr.std().item(),
    }
16 changes: 16 additions & 0 deletions generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import numpy as np

from data.utils.schemes.flower import flower_partition
from src.utils.tools import fix_random_seed
from data.utils.process import (
exclude_domain,
Expand Down Expand Up @@ -153,6 +154,17 @@ def main(args):
partition=partition,
stats=stats,
)
elif args.flower_partitioner_class != "":
flower_partition(
targets=targets[target_indices],
target_indices=target_indices,
label_set=valid_label_set,
client_num=client_num,
flower_partitioner_class=args.flower_partitioner_class,
flower_partitioner_kwargs=args.flower_partitioner_kwargs,
partition=partition,
stats=stats,
)
elif args.dataset in ["domain"] and args.ood_domains is None:
with open(dataset_root / "original_partition.pkl", "rb") as f:
partition = {}
Expand Down Expand Up @@ -356,6 +368,10 @@ def _idx_2_domain_label(index):
parser.add_argument("-a", "--alpha", type=float, default=0)
parser.add_argument("-ms", "--min_samples_per_client", type=int, default=10)

# Flower partitioner
parser.add_argument("-fpc", "--flower_partitioner_class", type=str, default="")
parser.add_argument("-fpk", "--flower_partitioner_kwargs", type=str, default="{}")

# For synthetic data only
parser.add_argument("--gamma", type=float, default=0.5)
parser.add_argument("--beta", type=float, default=0.5)
Expand Down

0 comments on commit bb2fcde

Please sign in to comment.