From e613cda067ad48c3f8deb436741226d3e12902f8 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Thu, 28 Mar 2024 22:37:54 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(datasets):=20Separate=20the=20?= =?UTF-8?q?data=20preprocessing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The data preprocessings of trainset and testset are defined separately with no overlap. Add `train()` and `eval()` for `BaseDataset` that for informing FL-bench that which data preprocessing should be adopted. 🐞fix(Elastic): Remove deprecated arg in torch's lr_scheduler 🐞fix(ADCOL): Add definitions of some variables in `__init__()` --- data/utils/datasets.py | 152 ++++++++++++++++++++-------------------- src/client/adcol.py | 2 + src/client/apfl.py | 1 + src/client/ditto.py | 1 + src/client/fedavg.py | 51 ++++++++------ src/client/fedbabu.py | 1 + src/client/feddyn.py | 1 + src/client/fedfomo.py | 4 -- src/client/fedgen.py | 1 + src/client/fediir.py | 1 + src/client/fedproto.py | 1 + src/client/fedprox.py | 1 + src/client/fedrep.py | 2 + src/client/fedrod.py | 4 +- src/client/fedsr.py | 1 + src/client/knnper.py | 4 +- src/client/metafed.py | 2 + src/client/moon.py | 1 + src/client/perfedavg.py | 1 + src/client/pfedme.py | 1 + src/client/scaffold.py | 1 + src/server/adcol.py | 1 + src/server/elastic.py | 2 +- 23 files changed, 134 insertions(+), 103 deletions(-) diff --git a/data/utils/datasets.py b/data/utils/datasets.py index b9ea159..6237732 100644 --- a/data/utils/datasets.py +++ b/data/utils/datasets.py @@ -22,22 +22,28 @@ def __init__(self) -> None: self.targets: torch.Tensor = None self.train_data_transform = None self.train_target_transform = None - self.general_data_transform = None - self.general_target_transform = None - self.enable_train_transform = True + self.test_data_transform = None + self.test_target_transform = None + self.data_transform = None + self.target_transform = None def __getitem__(self, index): data, targets = self.data[index], self.targets[index] - if self.enable_train_transform and self.train_data_transform is not None: - data = self.train_data_transform(data) - if self.enable_train_transform and self.train_target_transform is not None: - targets = self.train_target_transform(targets) - if self.general_data_transform is not None: - data = self.general_data_transform(data) - if self.general_target_transform is not None: - targets = self.general_target_transform(targets) + if self.data_transform is not None: + data = self.data_transform(data) + if self.target_transform is not None: + targets = self.target_transform(targets) + return data, targets + def train(self): + self.data_transform = self.train_data_transform + self.target_transform = self.train_target_transform + + def eval(self): + self.data_transform = self.test_data_transform + self.target_transform = self.test_target_transform + def __len__(self): return len(self.targets) @@ -47,8 +53,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ) -> None: @@ -68,8 +74,8 @@ def __init__( self.data = torch.from_numpy(data).float().reshape(-1, 1, 28, 28) self.targets = torch.from_numpy(targets).long() self.classes = list(range(62)) - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -99,8 +105,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ) -> None: @@ -119,8 +125,8 @@ def __init__( self.data = torch.from_numpy(data).permute([0, -1, 1, 2]).float() self.targets = torch.from_numpy(targets).long() - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform self.classes = [0, 1] @@ -131,8 +137,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -146,8 +152,8 @@ def __init__( self.targets = ( torch.Tensor(np.load(root / "raw" / "ydata.npy")).long().squeeze() ) - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -157,8 +163,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -174,8 +180,8 @@ def __init__( torch.Tensor(np.load(root / "raw" / "ydata.npy")).long().squeeze() ) self.classes = [0, 1, 2, 3] - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -185,8 +191,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -203,8 +209,8 @@ def __init__( self.data = torch.cat([train_data, test_data]) self.targets = torch.cat([train_targets, test_targets]) self.classes = list(range(10)) - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -214,8 +220,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -232,8 +238,8 @@ def __init__( self.data = torch.cat([train_data, test_data]) self.targets = torch.cat([train_targets, test_targets]) self.classes = list(range(10)) - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -243,8 +249,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -258,8 +264,8 @@ def __init__( self.data = torch.cat([train_data, test_data]) self.targets = torch.cat([train_targets, test_targets]) self.classes = train_part.classes - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -269,8 +275,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -284,8 +290,8 @@ def __init__( self.data = torch.cat([train_data, test_data]) self.targets = torch.cat([train_targets, test_targets]) self.classes = train_part.classes - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -295,8 +301,8 @@ def __init__( self, root, args, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -319,8 +325,8 @@ def __init__( self.data = torch.cat([train_data, test_data]) self.targets = torch.cat([train_targets, test_targets]) self.classes = train_part.classes - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -330,8 +336,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -345,8 +351,8 @@ def __init__( self.data = torch.cat([train_data, test_data]) self.targets = torch.cat([train_targets, test_targets]) self.classes = train_part.classes - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -356,8 +362,8 @@ def __init__( self, root, args, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -371,8 +377,8 @@ def __init__( self.data = torch.cat([train_data, test_data]) self.targets = torch.cat([train_targets, test_targets]) self.classes = train_part.classes - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform super_class = None @@ -420,8 +426,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -472,8 +478,8 @@ def __init__( self.data = torch.load(root / "data.pt") self.targets = torch.load(root / "targets.pt") - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -483,8 +489,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ): @@ -528,8 +534,8 @@ def __init__( self.data = torch.load(root / "data.pt") self.targets = torch.load(root / "targets.pt") - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform @@ -539,8 +545,8 @@ def __init__( self, root, args=None, - general_data_transform=None, - general_target_transform=None, + test_data_transform=None, + test_target_transform=None, train_data_transform=None, train_target_transform=None, ) -> None: @@ -576,22 +582,18 @@ def __init__( transforms.ToTensor(), ] ) - self.general_data_transform = general_data_transform - self.general_target_transform = general_target_transform + self.test_data_transform = test_data_transform + self.test_target_transform = test_target_transform self.train_data_transform = train_data_transform self.train_target_transform = train_target_transform def __getitem__(self, index): data = self.pre_transform(Image.open(self.filename_list[index]).convert("RGB")) targets = self.targets[index] - if self.enable_train_transform and self.train_data_transform is not None: - data = self.train_data_transform(data) - if self.enable_train_transform and self.train_target_transform is not None: - targets = self.train_target_transform(targets) - if self.general_data_transform is not None: - data = self.general_data_transform(data) - if self.general_target_transform is not None: - targets = self.general_target_transform(targets) + if self.data_transform is not None: + data = self.data_transform(data) + if self.target_transform is not None: + targets = self.target_transform(targets) return data, targets diff --git a/src/client/adcol.py b/src/client/adcol.py index 96bd378..f2897b0 100644 --- a/src/client/adcol.py +++ b/src/client/adcol.py @@ -15,10 +15,12 @@ def __init__(self, model, discriminator, args, logger, device, client_num): self.discriminator = discriminator self.discriminator.to(self.device) self.client_num = client_num + self.featrure_list = [] def fit(self): self.model.train() self.discriminator.eval() + self.dataset.train() self.featrure_list = [] for i in range(self.local_epoch): for x, y in self.trainloader: diff --git a/src/client/apfl.py b/src/client/apfl.py index fe40c0e..7c4aa7d 100644 --- a/src/client/apfl.py +++ b/src/client/apfl.py @@ -53,6 +53,7 @@ def save_state(self): def fit(self): self.model.train() self.local_model.train() + self.dataset.train() for i in range(self.local_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/ditto.py b/src/client/ditto.py index 596d349..1e9ba84 100644 --- a/src/client/ditto.py +++ b/src/client/ditto.py @@ -32,6 +32,7 @@ def save_state(self): def fit(self): self.model.train() + self.dataset.train() for _ in range(self.local_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/fedavg.py b/src/client/fedavg.py index ab6a581..b44fb7d 100644 --- a/src/client/fedavg.py +++ b/src/client/fedavg.py @@ -15,7 +15,7 @@ from src.utils.metrics import Metrics from src.utils.models import DecoupledModel from src.utils.constants import DATA_MEAN, DATA_STD -from data.utils.datasets import DATASETS +from data.utils.datasets import DATASETS, BaseDataset class FedAvgClient: @@ -40,8 +40,18 @@ def __init__( self.data_indices: List[List[int]] = partition["data_indices"] - # --------- you can define your own data transformation strategy here ------------ - general_data_transform = transforms.Compose( + # --------- you can define your custom data preprocessing strategy here ------------ + test_data_transform = transforms.Compose( + [ + transforms.Normalize( + DATA_MEAN[self.args.dataset], DATA_STD[self.args.dataset] + ) + ] + if self.args.dataset in DATA_MEAN and self.args.dataset in DATA_STD + else [] + ) + test_target_transform = transforms.Compose([]) + train_data_transform = transforms.Compose( [ transforms.Normalize( DATA_MEAN[self.args.dataset], DATA_STD[self.args.dataset] @@ -50,26 +60,27 @@ def __init__( if self.args.dataset in DATA_MEAN and self.args.dataset in DATA_STD else [] ) - general_target_transform = transforms.Compose([]) - train_data_transform = transforms.Compose([]) train_target_transform = transforms.Compose([]) # -------------------------------------------------------------------------------- - self.dataset = DATASETS[self.args.dataset]( + self.dataset: BaseDataset = DATASETS[self.args.dataset]( root=PROJECT_DIR / "data" / args.dataset, args=args.dataset_args, - general_data_transform=general_data_transform, - general_target_transform=general_target_transform, + test_data_transform=test_data_transform, + test_target_transform=test_target_transform, train_data_transform=train_data_transform, train_target_transform=train_target_transform, ) - self.trainloader: DataLoader = None - self.valloader: DataLoader = None - self.testloader: DataLoader = None - self.trainset: Subset = Subset(self.dataset, indices=[]) - self.valset: Subset = Subset(self.dataset, indices=[]) - self.testset: Subset = Subset(self.dataset, indices=[]) + # don't bother with the [0], which is only for avoiding raising runtime error by setting indices=[] in Subset() with shuffle=True in DataLoader() + self.trainset = Subset(self.dataset, indices=[0]) + self.valset = Subset(self.dataset, indices=[]) + self.testset = Subset(self.dataset, indices=[]) + self.trainloader = DataLoader( + self.trainset, batch_size=self.args.batch_size, shuffle=True + ) + self.valloader = DataLoader(self.valset, batch_size=self.args.batch_size) + self.testloader = DataLoader(self.testset, batch_size=self.args.batch_size) self.test_flag = False self.model = model.to(self.device) @@ -102,11 +113,8 @@ def __init__( def load_dataset(self): """This function is for loading data indices for No.`self.client_id` client.""" self.trainset.indices = self.data_indices[self.client_id]["train"] - self.testset.indices = self.data_indices[self.client_id]["test"] self.valset.indices = self.data_indices[self.client_id]["val"] - self.trainloader = DataLoader(self.trainset, self.args.batch_size, shuffle=True) - self.valloader = DataLoader(self.valset, self.args.batch_size) - self.testloader = DataLoader(self.testset, self.args.batch_size) + self.testset.indices = self.data_indices[self.client_id]["test"] def train_and_log(self, verbose=False) -> Dict[str, Dict[str, float]]: """This function includes the local training and logging process. @@ -192,8 +200,8 @@ def train( new_parameters (OrderedDict[str, torch.Tensor]): Parameters of FL model. return_diff (bool, optional): - Set as `True` to send the difference between FL model parameters that before and after training; - Set as `False` to send FL model parameters without any change. Defaults to True. + `True`: to send the differences between FL model parameters that before and after training; + `False`: to send FL model parameters without any change. Defaults to True. verbose (bool, optional): Set to `True` for print logging info onto the stdout (Controled by the server by default). Defaults to False. @@ -228,6 +236,7 @@ def fit(self): If you wanna implement your method and your method has different local training operations to FedAvg, this method has to be overrided. """ self.model.train() + self.dataset.train() for _ in range(self.local_epoch): for x, y in self.trainloader: # When the current batch size is 1, the batchNorm2d modules in the model would raise error. @@ -257,6 +266,7 @@ def evaluate(self, model: torch.nn.Module = None) -> Dict[str, Metrics]: target_model = self.model if model is None else model target_model.eval() + self.dataset.eval() train_metrics = Metrics() val_metrics = Metrics() test_metrics = Metrics() @@ -326,6 +336,7 @@ def finetune(self): This function will only be activated in FL test epoches. """ self.model.train() + self.dataset.train() for _ in range(self.args.finetune_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/fedbabu.py b/src/client/fedbabu.py index a745dd3..22a09de 100644 --- a/src/client/fedbabu.py +++ b/src/client/fedbabu.py @@ -7,6 +7,7 @@ def __init__(self, model, args, logger, device): def fit(self): self.model.train() + self.dataset.train() for _ in range(self.local_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/feddyn.py b/src/client/feddyn.py index 3e53698..7c9995a 100644 --- a/src/client/feddyn.py +++ b/src/client/feddyn.py @@ -34,6 +34,7 @@ def train( def fit(self): self.model.train() + self.dataset.train() for _ in range(self.local_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/fedfomo.py b/src/client/fedfomo.py index a10aa15..92f28c0 100644 --- a/src/client/fedfomo.py +++ b/src/client/fedfomo.py @@ -3,7 +3,6 @@ from typing import Dict, List import torch -from torch.utils.data import DataLoader, Subset from fedavg import FedAvgClient from src.utils.tools import trainable_params, evalutate_model, vectorize @@ -16,8 +15,6 @@ def __init__(self, model, args, logger, device, client_num): self.eval_model = deepcopy(self.model) self.weight_vector = torch.zeros(client_num, device=self.device) self.trainable_params_name = trainable_params(self.model, requires_name=True)[1] - self.valset = Subset(self.dataset, indices=[]) - self.valloader: DataLoader = None def train( self, @@ -42,7 +39,6 @@ def load_dataset(self): num_val_samples = int(len(self.trainset) * self.args.valset_ratio) self.valset.indices = self.trainset.indices[:num_val_samples] self.trainset.indices = self.trainset.indices[num_val_samples:] - self.valloader = DataLoader(self.valset, 32, shuffle=True) def set_parameters(self, received_params: Dict[int, List[torch.Tensor]]): local_params_dict = OrderedDict( diff --git a/src/client/fedgen.py b/src/client/fedgen.py index 318ff59..f181501 100644 --- a/src/client/fedgen.py +++ b/src/client/fedgen.py @@ -57,6 +57,7 @@ def set_parameters(self, new_parameters: OrderedDict[str, torch.Tensor]): def fit(self): self.model.train() self.generator.train() + self.dataset.train() for _ in range(self.local_epoch): for x, y in self.trainloader: if len(y) <= 1: diff --git a/src/client/fediir.py b/src/client/fediir.py index fddbaf2..6d69b3b 100644 --- a/src/client/fediir.py +++ b/src/client/fediir.py @@ -11,6 +11,7 @@ def __init__(self, model, args, logger, device): def fit(self): self.model.train() + self.dataset.train() for i in range(self.local_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/fedproto.py b/src/client/fedproto.py index 44305fa..44d4316 100644 --- a/src/client/fedproto.py +++ b/src/client/fedproto.py @@ -64,6 +64,7 @@ def train( def fit(self): self.model.train() + self.dataset.train() for _ in range(self.local_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/fedprox.py b/src/client/fedprox.py index 0f18e5c..e403191 100644 --- a/src/client/fedprox.py +++ b/src/client/fedprox.py @@ -8,6 +8,7 @@ def __init__(self, model, args, logger, device): def fit(self): self.model.train() + self.dataset.train() global_params = trainable_params(self.model, detach=True) for _ in range(self.local_epoch): for x, y in self.trainloader: diff --git a/src/client/fedrep.py b/src/client/fedrep.py index 155fbae..420ad93 100644 --- a/src/client/fedrep.py +++ b/src/client/fedrep.py @@ -7,6 +7,7 @@ def __init__(self, model, args, logger, device): def fit(self): self.model.train() + self.dataset.train() for E in range(self.local_epoch): for x, y in self.trainloader: if len(x) <= 1: @@ -31,6 +32,7 @@ def fit(self): def finetune(self): self.model.train() + self.dataset.train() full_model = False if full_model: # fine-tune the full model diff --git a/src/client/fedrod.py b/src/client/fedrod.py index 63d03db..c4bcaf4 100644 --- a/src/client/fedrod.py +++ b/src/client/fedrod.py @@ -94,6 +94,8 @@ def train( ) def fit(self): + self.model.train() + self.dataset.train() label_counts = torch.tensor( count_labels(self.dataset, self.trainset.indices), device=self.device ) @@ -115,7 +117,6 @@ def fit(self): .clone() ) - self.model.train() for _ in range(self.local_epoch): for x, y in self.trainloader: if len(x) <= 1: @@ -149,6 +150,7 @@ def fit(self): def finetune(self): self.model.train() + self.dataset.train() for _ in range(self.args.finetune_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/fedsr.py b/src/client/fedsr.py index 75c5516..77e6fd0 100644 --- a/src/client/fedsr.py +++ b/src/client/fedsr.py @@ -10,6 +10,7 @@ def __init__(self, model, args, logger, device): def fit(self): self.model.train() + self.dataset.train() for i in range(self.local_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/knnper.py b/src/client/knnper.py index 1fc9893..a6d3250 100644 --- a/src/client/knnper.py +++ b/src/client/knnper.py @@ -1,5 +1,5 @@ import random -from typing import Dict, List +from typing import List import torch import numpy as np @@ -17,7 +17,7 @@ def __init__(self, model, args, logger, device): @torch.no_grad() def evaluate(self, model=None): if self.test_flag: - self.dataset.enable_train_transform = False + self.dataset.eval() target_model = self.model if model is None else model target_model.eval() criterion = torch.nn.CrossEntropyLoss(reduction="sum") diff --git a/src/client/metafed.py b/src/client/metafed.py index a49afac..4934430 100644 --- a/src/client/metafed.py +++ b/src/client/metafed.py @@ -60,7 +60,9 @@ def train( return trainable_params(self.model, detach=True), stats def fit(self): + self.model.train() self.teacher.eval() + self.dataset.train() for _ in range(self.local_epoch): for x, y in self.trainloader: x, y = x.to(self.device), y.to(self.device) diff --git a/src/client/moon.py b/src/client/moon.py index 554aec7..e1daff1 100644 --- a/src/client/moon.py +++ b/src/client/moon.py @@ -29,6 +29,7 @@ def set_parameters(self, new_parameters): def fit(self): self.model.train() + self.dataset.train() for _ in range(self.local_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/perfedavg.py b/src/client/perfedavg.py index 4800f1b..e0059bd 100644 --- a/src/client/perfedavg.py +++ b/src/client/perfedavg.py @@ -40,6 +40,7 @@ def load_dataset(self): def fit(self): self.model.train() + self.dataset.train() for _ in range(self.local_epoch): for _ in range(len(self.trainloader) // (2 + (self.args.version == "hf"))): x0, y0 = self.get_data_batch() diff --git a/src/client/pfedme.py b/src/client/pfedme.py index 1264edd..ea58919 100644 --- a/src/client/pfedme.py +++ b/src/client/pfedme.py @@ -44,6 +44,7 @@ def save_state(self): def fit(self): self.model.train() + self.dataset.train() for _ in range(self.args.local_epoch): for x, y in self.trainloader: if len(x) <= 1: diff --git a/src/client/scaffold.py b/src/client/scaffold.py index 15ae502..b66bcca 100644 --- a/src/client/scaffold.py +++ b/src/client/scaffold.py @@ -61,6 +61,7 @@ def train( def fit(self): self.model.train() + self.dataset.train() for _ in range(self.args.local_epoch): x, y = self.get_data_batch() logits = self.model(x) diff --git a/src/server/adcol.py b/src/server/adcol.py index d8623eb..8be53c5 100644 --- a/src/server/adcol.py +++ b/src/server/adcol.py @@ -82,6 +82,7 @@ def __init__( self.client_num, ) self.feature_dataloader = None + self.features = {} def train_one_round(self): delta_cache = [] diff --git a/src/server/elastic.py b/src/server/elastic.py index a62d09c..7f4fdc6 100644 --- a/src/server/elastic.py +++ b/src/server/elastic.py @@ -30,7 +30,7 @@ def __init__( deepcopy(self.model), self.args, self.logger, self.device ) self.client_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - self.trainer.optimizer, self.args.global_epoch, verbose=False + self.trainer.optimizer, self.args.global_epoch ) def train_one_round(self):