Merge pull request #142 from KarhouTam/dev
Periodic update
Showing 45 changed files with 1,545 additions and 469 deletions.
The first hunk adds a `# datasets` section comment above the dataset paths:

```diff
@@ -15,6 +15,7 @@ out
 multirun
 
+# datasets
 data/cifar10
 data/cifar100
 data/mnist
```
The next hunk (@@ -1,68 +1,81 @@) reworks the `dirichlet` partitioning function: `least_samples` is renamed to `min_samples_per_client`, `indices_4_labels` to `indices_per_label`, the loose `set`/`dict` annotations become `Set[int]` and `Dict[...]` types, explanatory comments are added throughout, and the `samples_per_client` summary now stores the mean under `"mean"` instead of the mislabeled `"std"` key. The function after the change:

```python
from collections import Counter
from typing import Any, Dict, Set

import numpy as np


def dirichlet(
    targets: np.ndarray,
    target_indices: np.ndarray,
    label_set: Set[int],
    client_num: int,
    alpha: float,
    min_samples_per_client: int,
    partition: Dict[str, Any],
    stats: Dict[int, Dict[str, Any]],
):
    """Partition the dataset according to the Dirichlet distribution using a
    specified concentration parameter, `alpha`.

    Args:
        targets (np.ndarray): Array of data labels.
        target_indices (np.ndarray): Indices of targets. If `--iid` is not set, this is
            np.arange(len(targets)); otherwise it holds the absolute indices into the full targets.
        label_set (Set[int]): Set of unique labels.
        client_num (int): Number of clients.
        alpha (float): Concentration parameter; smaller values indicate stronger data heterogeneity.
        min_samples_per_client (int): Minimum number of data samples each client should have.
        partition (Dict[str, Any]): Dictionary to hold output data indices for each client.
        stats (Dict[int, Dict[str, Any]]): Dictionary to record clients' data distribution.
    """
    min_size = 0

    # Map each label to its corresponding indices in the target array
    indices_per_label = {label: np.where(targets == label)[0] for label in label_set}

    while min_size < min_samples_per_client:
        # Initialize empty lists to hold data indices for each client
        partition["data_indices"] = [[] for _ in range(client_num)]

        # Iterate through each label in the label_set
        for label in label_set:
            # Shuffle the indices corresponding to the current label
            np.random.shuffle(indices_per_label[label])

            # Generate a Dirichlet distribution for partitioning data among clients
            distribution = np.random.dirichlet(np.repeat(alpha, client_num))

            # Calculate the cumulative distribution to get split positions
            cumulative_distribution = np.cumsum(distribution) * len(
                indices_per_label[label]
            )
            split_indices_position = cumulative_distribution.astype(int)[:-1]

            # Split the indices based on the calculated positions
            split_indices = np.split(indices_per_label[label], split_indices_position)

            # Assign the split indices to each client
            for client_id in range(client_num):
                partition["data_indices"][client_id].extend(split_indices[client_id])

        # Update the minimum number of samples across all clients
        min_size = min(len(indices) for indices in partition["data_indices"])

    # Gather statistics and prepare the output for each client
    for client_id in range(client_num):
        stats[client_id]["x"] = len(targets[partition["data_indices"][client_id]])
        stats[client_id]["y"] = dict(
            Counter(targets[partition["data_indices"][client_id]].tolist())
        )

        # Update the data indices to use the original target indices
        partition["data_indices"][client_id] = target_indices[
            partition["data_indices"][client_id]
        ]

    # Calculate the number of samples for each client and update statistics
    sample_counts = np.array([stat["x"] for stat in stats.values()])
    stats["samples_per_client"] = {
        "mean": sample_counts.mean().item(),
        "stddev": sample_counts.std().item(),
    }
```
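For context, here is a minimal usage sketch with hypothetical toy inputs (none of these values come from the repository): 1,000 samples over 10 classes split across 5 clients. Note how `alpha` acts as the heterogeneity knob: `np.random.dirichlet(np.repeat(0.1, 5))` typically yields near-one-hot weights, while large values such as `alpha=100` yield a near-uniform split.

```python
import numpy as np

# Hypothetical toy data: 1,000 samples, 10 classes, 5 clients
targets = np.random.randint(0, 10, size=1000)
target_indices = np.arange(len(targets))  # identity mapping (non-IID case)
partition = {"data_indices": None}  # filled in by dirichlet()
stats = {i: {} for i in range(5)}  # one empty stats dict per client

dirichlet(
    targets=targets,
    target_indices=target_indices,
    label_set=set(range(10)),
    client_num=5,
    alpha=0.1,  # small alpha: heavily skewed label mix per client
    min_samples_per_client=10,
    partition=partition,
    stats=stats,
)

print(stats["samples_per_client"])  # e.g. {'mean': 200.0, 'stddev': ...}
```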
The commit also adds a new 52-line helper (@@ -0,0 +1,52 @@) that delegates partitioning to a Flower partitioner chosen by import path:

```python
from collections import Counter
import json

import numpy as np
import datasets

from data.utils.process import class_from_string


def flower_partition(
    targets: np.ndarray,
    target_indices: np.ndarray,
    label_set: set,
    client_num: int,
    flower_partitioner_class: str,
    flower_partitioner_kwargs: str,
    partition: dict,
    stats: dict,
):
    # Keep only the samples whose label is in label_set
    target_indices = [i for i in range(len(target_indices)) if targets[i] in label_set]
    targets = targets[target_indices]
    data = {"data_indices": target_indices, "label": targets}

    # Create a Hugging Face Dataset
    dataset = datasets.Dataset.from_dict(data)

    # Resolve the partitioner class from its import path and instantiate it
    flower_partitioner_kwargs = json.loads(flower_partitioner_kwargs)
    partitioner_class = class_from_string(flower_partitioner_class)
    partitioner = partitioner_class(
        num_partitions=client_num, **flower_partitioner_kwargs
    )

    # Assign the dataset to the partitioner
    partitioner.dataset = dataset
    num_samples = []

    # Load each partition and record its indices and label distribution
    for i in range(client_num):
        partition_i = partitioner.load_partition(i)
        indices = partition_i["data_indices"]
        partition["data_indices"][i] = indices
        stats[i] = {
            "x": len(indices),
            "y": dict(Counter(targets[indices].tolist())),
        }
        num_samples.append(len(partition_i))

    num_samples = np.array(num_samples)
    stats["samples_per_client"] = {
        "mean": num_samples.mean().item(),
        "stddev": num_samples.std().item(),
    }
```
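A hedged usage sketch follows. It assumes the `flwr-datasets` package is installed, that `class_from_string` resolves a dotted import path, and that `flwr_datasets.partitioner.IidPartitioner` is a valid target; all inputs are hypothetical toy values.

```python
import numpy as np

targets = np.random.randint(0, 10, size=1000)
target_indices = np.arange(len(targets))
partition = {"data_indices": [None] * 5}  # pre-sized: one slot per client
stats = {}

flower_partition(
    targets=targets,
    target_indices=target_indices,
    label_set=set(range(10)),
    client_num=5,
    flower_partitioner_class="flwr_datasets.partitioner.IidPartitioner",
    flower_partitioner_kwargs="{}",  # JSON string, parsed then forwarded as **kwargs
    partition=partition,
    stats=stats,
)

print(stats["samples_per_client"])
```

Other Flower partitioners that accept `num_partitions` should drop in the same way, e.g. `DirichletPartitioner`, with any extra constructor arguments (such as `partition_by` or `alpha`) supplied through the kwargs JSON string.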