From 4face9ed195b06a86a71f809abbc8760a8610dc7 Mon Sep 17 00:00:00 2001
From: KarhouTam
Date: Fri, 22 Mar 2024 17:29:53 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9E=20fix(pFedSim):=20Empty=20self.cli?=
 =?UTF-8?q?ent=5Ftrainable=5Fparams=20at=20server=20side?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Since gradients are no longer tracked for the tensors in
self.global_params_dict, the initialization of self.client_trainable_params
needs to be changed: filtering that dict for trainable tensors now returns an
empty list, so the parameters are selected by name from the model instead.

Also remove the `unique_model` and `default_trainer` arguments passed to
`super().__init__()`, so these two args are decided by pFedSim's parent class.
---
 src/server/pfedsim.py | 35 +++++++++++++++--------------------
 1 file changed, 15 insertions(+), 20 deletions(-)

diff --git a/src/server/pfedsim.py b/src/server/pfedsim.py
index 98c65b5..d774375 100644
--- a/src/server/pfedsim.py
+++ b/src/server/pfedsim.py
@@ -16,16 +16,10 @@ def get_pfedsim_argparser() -> ArgumentParser:
 
 
 class pFedSimServer(FedAvgServer):
-    def __init__(
-        self,
-        algo: str = "pFedSim",
-        args: Namespace = None,
-        unique_model=False,
-        default_trainer=True,
-    ):
+    def __init__(self, algo: str = "pFedSim", args: Namespace = None):
         if args is None:
             args = get_pfedsim_argparser().parse_args()
-        super().__init__(algo, args, unique_model, default_trainer)
+        super().__init__(algo, args)
         self.test_flag = False
         self.weight_matrix = torch.eye(self.client_num, device=self.device)
 
@@ -56,10 +50,13 @@ def train(self):
             console=self.logger.stdout,
         )
         self.trainer.personal_params_name.extend(
-            [name for name in self.model.state_dict() if "classifier" in name]
+            [name for name in self.model.state_dict().keys() if "classifier" in name]
         )
         self.client_trainable_params = [
-            trainable_params(self.global_params_dict, detach=True)
+            [
+                self.global_params_dict[key]
+                for key in trainable_params(self.model, requires_name=True)[1]
+            ]
             for _ in self.train_clients
         ]
 
@@ -76,16 +73,14 @@ def train(self):
             client_params_cache = []
             for client_id in self.selected_clients:
                 client_pers_params = self.generate_client_params(client_id)
-                (
-                    client_params,
-                    _,
-                    self.client_metrics[client_id][E],
-                ) = self.trainer.train(
-                    client_id=client_id,
-                    local_epoch=self.clients_local_epoch[client_id],
-                    new_parameters=client_pers_params,
-                    return_diff=False,
-                    verbose=((E + 1) % self.args.verbose_gap) == 0,
+                (client_params, _, self.client_metrics[client_id][E]) = (
+                    self.trainer.train(
+                        client_id=client_id,
+                        local_epoch=self.clients_local_epoch[client_id],
+                        new_parameters=client_pers_params,
+                        return_diff=False,
+                        verbose=((E + 1) % self.args.verbose_gap) == 0,
+                    )
                 )
                 client_params_cache.append(client_params)
 
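For reference, a minimal self-contained sketch of why the old initialization
came back empty. The tiny model and the simplified trainable_params() below
are assumptions for illustration only; they mimic, not reproduce, the helper
the patch calls, while global_params_dict and client_trainable_params mirror
the server attributes touched above.

# Sketch only: a simplified stand-in for the repo's trainable_params() helper,
# used to show why the old server-side initialization produced empty lists.
from collections import OrderedDict
from typing import Union

from torch import nn


def trainable_params(src: Union[OrderedDict, nn.Module], detach=False, requires_name=False):
    """Collect tensors that require grad; optionally also return their names."""
    named = src.items() if isinstance(src, OrderedDict) else src.named_parameters()
    params, names = [], []
    for name, tensor in named:
        if tensor.requires_grad:
            params.append(tensor.detach().clone() if detach else tensor)
            names.append(name)
    return (params, names) if requires_name else params


# Stand-in model (assumption); the real server holds the pFedSim model here.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# Server-side global parameters kept as plain detached tensors (no grad
# tracking), which is the situation described in the commit message.
global_params_dict = OrderedDict(
    (name, param.detach().clone()) for name, param in model.named_parameters()
)

# Old initialization: filters the dict by requires_grad, so it finds nothing.
old_init = trainable_params(global_params_dict, detach=True)
print(len(old_init))  # 0 -> every entry of self.client_trainable_params was empty

# New initialization (as in the patch): take the names of the model's trainable
# parameters and index the detached dict with them.
keys = trainable_params(model, requires_name=True)[1]
new_init = [global_params_dict[key] for key in keys]
print(len(new_init))  # 4 -> two weights and two biases

The model itself still tracks gradients, so its trainable parameter names
remain a reliable key set for the detached dict, which is what the new
list comprehension in the patch relies on.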