Skip to content

Commit

Permalink
🐞 fix(pFedSim): Empty self.client_trainable_params at server side
Browse files Browse the repository at this point in the history
Because gradients are no longer tracked for the tensors in self.global_params_dict, the initialization of self.client_trainable_params needs to be changed.

Also delete the arguments `unique_model` and `default_trainer` that were passed to `super().__init__()`, meaning these two arguments are now decided by pFedSim's parent class.
  • Loading branch information
KarhouTam committed Mar 22, 2024
1 parent cd935ac commit 4face9e
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions src/server/pfedsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,10 @@ def get_pfedsim_argparser() -> ArgumentParser:


class pFedSimServer(FedAvgServer):
def __init__(
self,
algo: str = "pFedSim",
args: Namespace = None,
unique_model=False,
default_trainer=True,
):
def __init__(self, algo: str = "pFedSim", args: Namespace = None):
if args is None:
args = get_pfedsim_argparser().parse_args()
super().__init__(algo, args, unique_model, default_trainer)
super().__init__(algo, args)
self.test_flag = False
self.weight_matrix = torch.eye(self.client_num, device=self.device)

Expand Down Expand Up @@ -56,10 +50,13 @@ def train(self):
console=self.logger.stdout,
)
self.trainer.personal_params_name.extend(
[name for name in self.model.state_dict() if "classifier" in name]
[name for name in self.model.state_dict().keys() if "classifier" in name]
)
self.client_trainable_params = [
trainable_params(self.global_params_dict, detach=True)
[
self.global_params_dict[key]
for key in trainable_params(self.model, requires_name=True)[1]
]
for _ in self.train_clients
]

Expand All @@ -76,16 +73,14 @@ def train(self):
client_params_cache = []
for client_id in self.selected_clients:
client_pers_params = self.generate_client_params(client_id)
(
client_params,
_,
self.client_metrics[client_id][E],
) = self.trainer.train(
client_id=client_id,
local_epoch=self.clients_local_epoch[client_id],
new_parameters=client_pers_params,
return_diff=False,
verbose=((E + 1) % self.args.verbose_gap) == 0,
(client_params, _, self.client_metrics[client_id][E]) = (
self.trainer.train(
client_id=client_id,
local_epoch=self.clients_local_epoch[client_id],
new_parameters=client_pers_params,
return_diff=False,
verbose=((E + 1) % self.args.verbose_gap) == 0,
)
)
client_params_cache.append(client_params)

Expand Down

0 comments on commit 4face9e

Please sign in to comment.