Skip to content

Commit

Permalink
Floco fixes and better default parameters (#143)
Browse files Browse the repository at this point in the history
Bugfix and better default parameters (10 simplex endpoints and client projection after 10 rounds) for Floco ([Federated Learning over Connected Modes](https://openreview.net/forum?id=JL2eMCfDW8))

Tested on cifar10 (`python generate_data.py -d cifar10 -a 0.1 -cn 100`).

Local test accuracies (reproduce with `dataset.name=cifar10 common.test.client.interval=1`):
- FedAvg (`python main.py method=fedavg`): 53.84%
- Floco (`python main.py method=floco`): 77.15%
- Floco+ (`python main.py method=floco +floco.pers_epoch=1`): 78.95%
  • Loading branch information
birnbaum authored Dec 22, 2024
1 parent 6fd0266 commit a393db5
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 32 deletions.
10 changes: 4 additions & 6 deletions src/client/floco.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ def __init__(self, **commons):

def set_parameters(self, package: dict[str, Any]):
super().set_parameters(package)
if package["subregion_parameters"]:
self.model.sample_from = package["sample_from"]
self.model.subregion_parameters = package["subregion_parameters"]
self.model.subregion_parameters = package["subregion_parameters"]
if self.args.floco.pers_epoch > 0: # Floco+
self.global_params = OrderedDict(
(key, param.to(self.device))
Expand Down Expand Up @@ -59,9 +57,9 @@ def fit(self):
@torch.no_grad()
def evaluate(self):
    """Evaluate the client model without tracking gradients.

    For Floco+ (``pers_epoch > 0``) the personalized model is evaluated;
    for plain Floco the shared model is used. Returns whatever the parent
    class's ``evaluate`` returns (the diff's fix was precisely to propagate
    that return value instead of dropping it).
    """
    if self.args.floco.pers_epoch > 0:  # Floco+: evaluate the personalized model
        return super().evaluate(self.pers_model)
    # Plain Floco: evaluate the regular (shared) model
    return super().evaluate()


def training_loop(
Expand Down Expand Up @@ -93,6 +91,6 @@ def training_loop(


def _regularize_pers_model(model, reg_model_params, lamda):
for pers_param, global_param in zip(model.parameters(), reg_model_params):
for pers_param, global_param in zip(model.parameters(), reg_model_params.values()):
if pers_param.requires_grad and pers_param.grad is not None:
pers_param.grad.data += lamda * pers_param.data - global_param.data
46 changes: 20 additions & 26 deletions src/server/floco.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class FlocoServer(FedAvgServer):
@staticmethod
def get_hyperparams(args_list=None) -> Namespace:
parser = ArgumentParser()
parser.add_argument("--endpoints", type=int, default=1)
parser.add_argument("--tau", type=int, default=100)
parser.add_argument("--endpoints", type=int, default=10)
parser.add_argument("--tau", type=int, default=10)
parser.add_argument("--rho", type=float, default=0.1)

# Floco+ (only used if pers_epoch > 0)
Expand Down Expand Up @@ -74,23 +74,15 @@ def train_one_round(self):
self.clients_personalized_model_params[client_id] = client_packages[
client_id
]["personalized_model_params"]
self.aggregate(client_packages)
self.aggregate_client_updates(client_packages)

def package(self, client_id: int):
server_package = super().package(client_id)
if self.projected_clients is None:
server_package["sample_from"] = (
"simplex_center" if self.testing else "simplex_uniform"
)
server_package["subregion_parameters"] = None
else:
server_package["sample_from"] = (
"subregion_center" if self.testing else "subregion_uniform"
)
server_package["subregion_parameters"] = (
self.projected_clients[client_id],
self.args.floco.rho,
)
server_package["subregion_parameters"] = (
None
if self.projected_clients is None
else (self.projected_clients[client_id], self.args.floco.rho)
)
if self.args.floco.pers_epoch > 0: # Floco+
server_package["personalized_model_params"] = (
self.clients_personalized_model_params[client_id]
Expand All @@ -114,20 +106,22 @@ def __init__(self, args) -> None:
bias=True,
seed=self.args.common.seed,
)
self.sample_from = "simplex_center"
self.subregion_parameters = None

def forward(self, x):
    """Set the simplex mixing weights (``alphas``) for the classifier, then
    delegate to the parent forward pass.

    Sampling strategy (post-commit logic only — the old ``sample_from``
    string dispatch was removed by this commit):

    * before projection (``subregion_parameters is None``):
        - training: sample uniformly from the probability simplex by
          normalizing i.i.d. Exponential(1) draws (equivalent to a flat
          Dirichlet);
        - eval: use the simplex center (uniform weights).
    * after projection:
        - training: sample uniformly from the client's L1-ball subregion;
        - eval: use the subregion center.
    """
    endpoints = self.args.floco.endpoints
    if self.subregion_parameters is None:  # before projection
        if self.training:  # sample uniformly from simplex for training
            sample = np.random.exponential(scale=1.0, size=endpoints)
            self.classifier.alphas = sample / sample.sum()
        else:  # use simplex center for testing
            simplex_center = tuple(1 / endpoints for _ in range(endpoints))
            self.classifier.alphas = simplex_center
    else:  # after projection
        if self.training:  # sample uniformly from subregion for training
            self.classifier.alphas = _sample_L1_ball(*self.subregion_parameters)
        else:  # use subregion center for testing
            self.classifier.alphas = self.subregion_parameters[0]
    return super().forward(x)


Expand Down

0 comments on commit a393db5

Please sign in to comment.