Commit 3b7d958
🐞 fix(method): Fix methods
Affected methods: APFL, Ditto, MetaFed, pFedMe
KarhouTam committed May 22, 2024
1 parent e1d1d27 commit 3b7d958
Showing 5 changed files with 13 additions and 7 deletions.
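
Taken together, the five diffs fix two recurring problems: clients that call optimizer.add_param_group() (APFL, Ditto) now re-snapshot the optimizer state afterwards so later resets match the final param-group layout, and the learning-rate scheduler is now treated as optional everywhere, with fedavg.py exposing lr_scheduler_cls (possibly None) and MetaFed/pFedMe guarding on it before use.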
src/client/apfl.py (1 addition, 0 deletions)

@@ -25,6 +25,7 @@ def _re_init(src):
             return deepcopy(target.state_dict())
 
         self.optimizer.add_param_group({"params": trainable_params(self.local_model)})
+        self.init_optimizer_state = deepcopy(self.optimizer.state_dict())
 
     def set_parameters(self, package: dict[str, Any]):
         super().set_parameters(package)
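
Why the added line matters: FedAvgClient takes its init_optimizer_state snapshot in its own __init__, before APFL appends local_model's parameters as a second param group, so loading that stale snapshot later would fail with a param-group mismatch. A minimal sketch of the snapshot-then-reset pattern, with toy Linear modules standing in for the real models:

import torch
from copy import deepcopy

# Two modules share one optimizer, as with APFL's global and local models.
global_model = torch.nn.Linear(10, 2)
local_model = torch.nn.Linear(10, 2)

optimizer = torch.optim.SGD(global_model.parameters(), lr=0.01, momentum=0.9)
optimizer.add_param_group({"params": local_model.parameters()})

# Snapshot AFTER add_param_group so the saved state covers both groups;
# torch.optim.Optimizer.load_state_dict() rejects a state dict whose
# number of param groups differs from the optimizer's.
init_optimizer_state = deepcopy(optimizer.state_dict())

# ... local training accumulates momentum buffers here ...

# Reset for the next round without rebuilding the optimizer:
optimizer.load_state_dict(init_optimizer_state)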
src/client/ditto.py (1 addition, 0 deletions)

@@ -11,6 +11,7 @@ def __init__(self, **commons):
         super().__init__(**commons)
         self.pers_model = deepcopy(self.model).to(self.device)
         self.optimizer.add_param_group({"params": trainable_params(self.pers_model)})
+        self.init_optimizer_state = deepcopy(self.optimizer.state_dict())
 
     def set_parameters(self, package: dict[str, Any]):
         super().set_parameters(package)
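
Same fix as in apfl.py: Ditto's personalized model adds a second param group after FedAvgClient has already taken its snapshot, so the snapshot is retaken once the extra group is in place.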
src/client/fedavg.py (2 additions, 0 deletions)

@@ -43,7 +43,9 @@ def __init__(
         self.init_optimizer_state = deepcopy(self.optimizer.state_dict())
         self.lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None
         self.init_lr_scheduler_state: dict = None
+        self.lr_scheduler_cls = None
         if lr_scheduler_cls is not None:
+            self.lr_scheduler_cls = lr_scheduler_cls
             self.lr_scheduler = lr_scheduler_cls(optimizer=self.optimizer)
             self.init_lr_scheduler_state = deepcopy(self.lr_scheduler.state_dict())
 
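
The point of assigning self.lr_scheduler_cls unconditionally (defaulting to None) is that subclasses can now test it instead of assuming it exists. A sketch of the guarded construction; the functools.partial factory is an assumption about how a scheduler class with pre-bound hyperparameters might be supplied:

from copy import deepcopy
from functools import partial

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The "class" is a factory with everything except the optimizer bound.
lr_scheduler_cls = partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5)
# lr_scheduler_cls = None  # equally valid: no scheduler configured

lr_scheduler = None
init_lr_scheduler_state = None
if lr_scheduler_cls is not None:
    lr_scheduler = lr_scheduler_cls(optimizer=optimizer)
    init_lr_scheduler_state = deepcopy(lr_scheduler.state_dict())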
src/client/metafed.py (7 additions, 6 deletions)

@@ -7,7 +7,7 @@
 
 from src.client.fedavg import FedAvgClient
 from torch.utils.data import Subset, DataLoader
-from src.utils.tools import trainable_params, evalutate_model
+from src.utils.tools import evalutate_model
 
 
 class MetaFedClient(FedAvgClient):
@@ -60,12 +60,13 @@ def set_parameters(self, package: dict[str, Any]):
         if package["optimizer_state"]:
             self.optimizer.load_state_dict(package["optimizer_state"])
         else:
-            self.optimizer = self.optimizer_cls(params=trainable_params(self.model))
+            self.optimizer.load_state_dict(self.init_optimizer_state)
 
-        if package["lr_scheduler_state"]:
-            self.lr_scheduler.load_state_dict(package["lr_scheduler_state"])
-        elif self.lr_scheduler_cls is not None:
-            self.lr_scheduler = self.lr_scheduler_cls(optimizer=self.optimizer)
+        if self.lr_scheduler is not None:
+            if package["lr_scheduler_state"]:
+                self.lr_scheduler.load_state_dict(package["lr_scheduler_state"])
+            else:
+                self.lr_scheduler.load_state_dict(self.init_lr_scheduler_state)
 
         self.client_flag = package["client_flag"]
         self.model.load_state_dict(package["student_model_params"])
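
Two distinct bugs are fixed here. Re-instantiating the optimizer from trainable_params(self.model) silently discarded any extra param groups and diverged from the init_optimizer_state snapshot; loading that snapshot instead restores a clean state for exactly the groups the optimizer was built with. And the old scheduler branch dereferenced self.lr_scheduler even when no scheduler was configured, which the new outer None check prevents. With trainable_params no longer used in this file, its import is dropped.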
src/client/pfedme.py (2 additions, 1 deletion)

@@ -20,7 +20,8 @@ def __init__(self, **commons):
             self.args.pfedme.lamda,
             self.args.pfedme.mu,
         )
-        self.lr_scheduler = self.lr_scheduler_cls(self.optimizer)
+        if self.lr_scheduler_cls is not None:
+            self.lr_scheduler = self.lr_scheduler_cls(self.optimizer)
 
     def set_parameters(self, package: dict[str, Any]):
         super().set_parameters(package)
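
pFedMe replaces the optimizer built by FedAvgClient with its own (the constructor whose arguments close just above this hunk), so it has to rebuild the scheduler around the new optimizer. Before this fix the call was unconditional, so with no scheduler configured self.lr_scheduler_cls was None and calling it raised a TypeError; the guard, together with fedavg.py now always defining self.lr_scheduler_cls, makes the scheduler genuinely optional.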
