Skip to content

Commit

Permalink
🐞 fix(data): Fix bug in processing FEMNIST, CelebA
Browse files · Browse the repository at this point in the history
  • Loading branch information
KarhouTam committed Feb 29, 2024
1 parent 7ebd658 commit 234f25d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 16 deletions.
6 changes: 3 additions & 3 deletions data/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
root / "targets.npy"
):
raise RuntimeError(
"run data/utils/run.py -d femnist for generating the data.npy and targets.npy first."
"Please run generate_data.py -d synthetic for generating the data.npy and targets.npy first."
)

data = np.load(root / "data.npy")
Expand All @@ -83,7 +83,7 @@ def __init__(self, root, *args, **kwargs) -> None:
root / "targets.npy"
):
raise RuntimeError(
"run data/utils/run.py -d femnist for generating the data.npy and targets.npy first."
"Please run generate_data.py -d synthetic for generating the data.npy and targets.npy first."
)

data = np.load(root / "data.npy")
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(
root / "targets.npy"
):
raise RuntimeError(
"run data/utils/run.py -d femnist for generating the data.npy and targets.npy first."
"Please run generate_data.py -d synthetic for generating the data.npy and targets.npy first."
)

data = np.load(root / "data.npy")
Expand Down
29 changes: 24 additions & 5 deletions data/utils/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from matplotlib import pyplot as plt
from PIL import Image

from data.utils.datasets import FEMNIST, CelebA, Synthetic

DATA_ROOT = Path(__file__).parent.parent.absolute()


Expand Down Expand Up @@ -81,10 +83,10 @@ def process_femnist(args, partition: Dict, stats: Dict):
data_indices = {}
clients_4_train, clients_4_test = None, None
with open(DATA_ROOT / "femnist" / "preprocess_args.json", "r") as f:
args = json.load(f)
preprocess_args = json.load(f)

# load data of train clients
if args["t"] == "sample":
if preprocess_args["t"] == "sample":
train_filename_list = sorted(os.listdir(train_dir))
test_filename_list = sorted(os.listdir(test_dir))
for train_js_file, test_js_file in zip(train_filename_list, test_filename_list):
Expand All @@ -103,7 +105,7 @@ def process_femnist(args, partition: Dict, stats: Dict):
targets = train_targets + test_targets
all_data.append(np.array(data))
all_targets.append(np.array(targets))
partition["data_indices"][client_cnt] = {
data_indices[client_cnt] = {
"train": list(range(data_cnt, data_cnt + len(train_data))),
"test": list(
range(data_cnt + len(train_data), data_cnt + len(data))
Expand Down Expand Up @@ -196,6 +198,14 @@ def process_femnist(args, partition: Dict, stats: Dict):
}
partition["data_indices"] = [indices for indices in data_indices.values()]
args.client_num = client_cnt
return FEMNIST(
root=DATA_ROOT / "femnist",
args=None,
general_data_transform=None,
general_target_transform=None,
train_data_transform=None,
train_target_transform=None,
)


def process_celeba(args, partition: Dict, stats: Dict):
Expand All @@ -217,9 +227,9 @@ def process_celeba(args, partition: Dict, stats: Dict):
clients_4_test, clients_4_train = None, None

with open(DATA_ROOT / "celeba" / "preprocess_args.json") as f:
args = json.load(f)
preprocess_args = json.load(f)

if args["t"] == "sample":
if preprocess_args["t"] == "sample":
for client_cnt, ori_id in enumerate(train["users"]):
stats[client_cnt] = {"x": None, "y": None}
train_data = np.stack(
Expand Down Expand Up @@ -360,6 +370,14 @@ def process_celeba(args, partition: Dict, stats: Dict):
}
partition["data_indices"] = [indices for indices in data_indices.values()]
args.client_num = client_cnt
return CelebA(
root=DATA_ROOT / "celeba",
args=None,
general_data_transform=None,
general_target_transform=None,
train_data_transform=None,
train_target_transform=None,
)


def generate_synthetic_data(args, partition: Dict, stats: Dict):
Expand Down Expand Up @@ -440,6 +458,7 @@ def softmax(x):
"std": num_samples.mean().item(),
"stddev": num_samples.std().item(),
}
return Synthetic(root=DATA_ROOT / "synthetic")


def exclude_domain(
Expand Down
17 changes: 9 additions & 8 deletions generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
allocate_shards,
semantic_partition,
)
from data.utils.datasets import DATASETS
from data.utils.datasets import DATASETS, BaseDataset

CURRENT_DIR = Path(__file__).parent.absolute()

Expand All @@ -36,22 +36,23 @@ def main(args):
if not os.path.isdir(dataset_root):
os.mkdir(dataset_root)

dataset = DATASETS[args.dataset](dataset_root, args)
targets = np.array(dataset.targets, dtype=np.int32)
label_set = set(range(len(dataset.classes)))
client_num = args.client_num
partition = {"separation": None, "data_indices": [[] for _ in range(client_num)]}
stats = {}

dataset: BaseDataset = None

if args.dataset == "femnist":
process_femnist(args, partition, stats)
dataset = process_femnist(args, partition, stats)
elif args.dataset == "celeba":
process_celeba(args, partition, stats)
dataset = process_celeba(args, partition, stats)
elif args.dataset == "synthetic":
generate_synthetic_data(args, partition, stats)
dataset = generate_synthetic_data(args, partition, stats)
else: # MEDMNIST, COVID, MNIST, CIFAR10, ...
# NOTE: If `args.ood_domains`` is not empty, then FL-bench will map all labels (class space) to the domain space
# and partition data according to the new `targets` array.
dataset = DATASETS[args.dataset](dataset_root, args)
targets = np.array(dataset.targets, dtype=np.int32)
label_set = set(range(len(dataset.classes)))
if args.dataset in ["domain"] and args.ood_domains:
metadata = json.load(open(dataset_root / "metadata.json", "r"))
label_set, targets, client_num = exclude_domain(
Expand Down

0 comments on commit 234f25d

Please sign in to comment.