diff --git a/data/utils/datasets.py b/data/utils/datasets.py
index f5244f5..b9ea159 100644
--- a/data/utils/datasets.py
+++ b/data/utils/datasets.py
@@ -59,7 +59,7 @@ def __init__(
             root / "targets.npy"
         ):
             raise RuntimeError(
-                "run data/utils/run.py -d femnist for generating the data.npy and targets.npy first."
+                "Please run generate_data.py -d synthetic for generating the data.npy and targets.npy first."
             )
 
         data = np.load(root / "data.npy")
@@ -83,7 +83,7 @@ def __init__(self, root, *args, **kwargs) -> None:
             root / "targets.npy"
         ):
             raise RuntimeError(
-                "run data/utils/run.py -d femnist for generating the data.npy and targets.npy first."
+                "Please run generate_data.py -d synthetic for generating the data.npy and targets.npy first."
             )
 
         data = np.load(root / "data.npy")
@@ -111,7 +111,7 @@ def __init__(
             root / "targets.npy"
         ):
             raise RuntimeError(
-                "run data/utils/run.py -d femnist for generating the data.npy and targets.npy first."
+                "Please run generate_data.py -d synthetic for generating the data.npy and targets.npy first."
            )
 
         data = np.load(root / "data.npy")
diff --git a/data/utils/process.py b/data/utils/process.py
index 7fb43b6..78eac5c 100644
--- a/data/utils/process.py
+++ b/data/utils/process.py
@@ -10,6 +10,8 @@
 from matplotlib import pyplot as plt
 from PIL import Image
 
+from data.utils.datasets import FEMNIST, CelebA, Synthetic
+
 DATA_ROOT = Path(__file__).parent.parent.absolute()
 
 
@@ -81,10 +83,10 @@ def process_femnist(args, partition: Dict, stats: Dict):
     data_indices = {}
     clients_4_train, clients_4_test = None, None
     with open(DATA_ROOT / "femnist" / "preprocess_args.json", "r") as f:
-        args = json.load(f)
+        preprocess_args = json.load(f)
 
     # load data of train clients
-    if args["t"] == "sample":
+    if preprocess_args["t"] == "sample":
         train_filename_list = sorted(os.listdir(train_dir))
         test_filename_list = sorted(os.listdir(test_dir))
         for train_js_file, test_js_file in zip(train_filename_list, test_filename_list):
@@ -103,7 +105,7 @@ def process_femnist(args, partition: Dict, stats: Dict):
             targets = train_targets + test_targets
             all_data.append(np.array(data))
             all_targets.append(np.array(targets))
-            partition["data_indices"][client_cnt] = {
+            data_indices[client_cnt] = {
                 "train": list(range(data_cnt, data_cnt + len(train_data))),
                 "test": list(
                     range(data_cnt + len(train_data), data_cnt + len(data))
@@ -196,6 +198,14 @@ def process_femnist(args, partition: Dict, stats: Dict):
     }
     partition["data_indices"] = [indices for indices in data_indices.values()]
     args.client_num = client_cnt
+    return FEMNIST(
+        root=DATA_ROOT / "femnist",
+        args=None,
+        general_data_transform=None,
+        general_target_transform=None,
+        train_data_transform=None,
+        train_target_transform=None,
+    )
 
 
 def process_celeba(args, partition: Dict, stats: Dict):
@@ -217,9 +227,9 @@ def process_celeba(args, partition: Dict, stats: Dict):
     clients_4_test, clients_4_train = None, None
 
     with open(DATA_ROOT / "celeba" / "preprocess_args.json") as f:
-        args = json.load(f)
+        preprocess_args = json.load(f)
 
-    if args["t"] == "sample":
+    if preprocess_args["t"] == "sample":
         for client_cnt, ori_id in enumerate(train["users"]):
             stats[client_cnt] = {"x": None, "y": None}
             train_data = np.stack(
@@ -360,6 +370,14 @@ def process_celeba(args, partition: Dict, stats: Dict):
     }
     partition["data_indices"] = [indices for indices in data_indices.values()]
     args.client_num = client_cnt
+    return CelebA(
+        root=DATA_ROOT / "celeba",
+        args=None,
+        general_data_transform=None,
+        general_target_transform=None,
+        train_data_transform=None,
+        train_target_transform=None,
+    )
 
 
 def generate_synthetic_data(args, partition: Dict, stats: Dict):
@@ -440,6 +458,7 @@ def softmax(x):
         "std": num_samples.mean().item(),
         "stddev": num_samples.std().item(),
     }
+    return Synthetic(root=DATA_ROOT / "synthetic")
 
 
 def exclude_domain(
diff --git a/generate_data.py b/generate_data.py
index d267bcf..3be017e 100644
--- a/generate_data.py
+++ b/generate_data.py
@@ -23,7 +23,7 @@
     allocate_shards,
     semantic_partition,
 )
-from data.utils.datasets import DATASETS
+from data.utils.datasets import DATASETS, BaseDataset
 
 CURRENT_DIR = Path(__file__).parent.absolute()
 
@@ -36,22 +36,23 @@ def main(args):
     if not os.path.isdir(dataset_root):
         os.mkdir(dataset_root)
 
-    dataset = DATASETS[args.dataset](dataset_root, args)
-    targets = np.array(dataset.targets, dtype=np.int32)
-    label_set = set(range(len(dataset.classes)))
     client_num = args.client_num
     partition = {"separation": None, "data_indices": [[] for _ in range(client_num)]}
     stats = {}
-
+    dataset: BaseDataset = None
+
     if args.dataset == "femnist":
-        process_femnist(args, partition, stats)
+        dataset = process_femnist(args, partition, stats)
     elif args.dataset == "celeba":
-        process_celeba(args, partition, stats)
+        dataset = process_celeba(args, partition, stats)
     elif args.dataset == "synthetic":
-        generate_synthetic_data(args, partition, stats)
+        dataset = generate_synthetic_data(args, partition, stats)
     else:  # MEDMNIST, COVID, MNIST, CIFAR10, ...
         # NOTE: If `args.ood_domains`` is not empty, then FL-bench will map all labels (class space) to the domain space
         # and partition data according to the new `targets` array.
+        dataset = DATASETS[args.dataset](dataset_root, args)
+        targets = np.array(dataset.targets, dtype=np.int32)
+        label_set = set(range(len(dataset.classes)))
         if args.dataset in ["domain"] and args.ood_domains:
             metadata = json.load(open(dataset_root / "metadata.json", "r"))
             label_set, targets, client_num = exclude_domain(
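
Below is a minimal sketch, not part of the patch, of the dispatch contract that main() relies on after this change: the three LEAF-style processors fill partition/stats in place and now also return the dataset they materialize, while every other dataset is still instantiated from the DATASETS registry. The helper name build_dataset is hypothetical; the imports assume the module layout shown in the diff.

# Sketch only, not part of the patch. Assumes the FL-bench layout above,
# where the three processors return a BaseDataset and DATASETS maps
# dataset names to dataset classes; `build_dataset` is a hypothetical helper.
from data.utils.datasets import DATASETS, BaseDataset
from data.utils.process import (
    generate_synthetic_data,
    process_celeba,
    process_femnist,
)


def build_dataset(args, dataset_root, partition, stats) -> BaseDataset:
    """Return a ready-to-index dataset regardless of how it was produced."""
    if args.dataset == "femnist":
        # The processors fill `partition` and `stats` in place and now
        # also return the dataset object they just materialized.
        return process_femnist(args, partition, stats)
    if args.dataset == "celeba":
        return process_celeba(args, partition, stats)
    if args.dataset == "synthetic":
        return generate_synthetic_data(args, partition, stats)
    # Everything else still comes straight from the registry; the caller
    # derives `targets` and `label_set` from it afterwards, as in the diff.
    return DATASETS[args.dataset](dataset_root, args)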