Merge pull request #8 from APPFL/zilinghan/iiadmm
Adding IIADMM
Zilinghan authored Apr 3, 2024
2 parents b6661eb + 70f0516 commit 516ce6a
Showing 13 changed files with 743 additions and 24 deletions.
75 changes: 75 additions & 0 deletions examples/config/server_fedbuff.yaml
@@ -0,0 +1,75 @@
client_configs:
  train_configs:
    # Local trainer
    trainer: "NaiveTrainer"
    mode: "step"
    num_local_steps: 100
    optim: "Adam"
    optim_args:
      lr: 0.001
    # Loss function
    loss_fn_path: "./loss/celoss.py"
    loss_fn_name: "CELoss"
    # Client validation
    do_validation: True
    do_pre_validation: True
    metric_path: "./metric/acc.py"
    metric_name: "accuracy"
    # Differential privacy
    use_dp: False
    epsilon: 1
    clip_grad: False
    clip_value: 1
    clip_norm: 1
    # Data format
    send_gradient: True
    # Data loader
    train_batch_size: 64
    val_batch_size: 64
    train_data_shuffle: True
    val_data_shuffle: False

  model_configs:
    model_path: "./model/cnn.py"
    model_name: "CNN"
    model_kwargs:
      num_channel: 1
      num_classes: 10
      num_pixel: 28

  comm_configs:
    compressor_configs:
      enable_compression: False
      # Used if enable_compression is True
      lossy_compressor: "SZ2"
      lossless_compressor: "blosc"
      error_bounding_mode: "REL"
      error_bound: 1e-3
      flat_model_dtype: "np.float32"
      param_cutoff: 1024

server_configs:
  scheduler: "AsyncScheduler"
  scheduler_kwargs:
    num_clients: 2
    same_init_model: True
  aggregator: "FedBuffAggregator"
  aggregator_kwargs:
    client_weights_mode: "equal"
    num_clients: 2
    staleness_fn: "polynomial"
    staleness_fn_kwargs:
      a: 0.5
    alpha: 0.9
    gradient_based: True
    K: 3
  device: "cpu"
  num_global_epochs: 20
  server_validation: False
  logging_output_dirname: "./output"
  logging_output_filename: "result"
  comm_configs:
    grpc_configs:
      server_uri: localhost:50051
      max_message_size: 1048576
      use_ssl: False
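The aggregator_kwargs above map onto the FedBuffAggregator constructor shown later in this diff. A minimal sketch (not part of this PR) of reading those settings back with OmegaConf, which APPFL already uses for its DictConfig objects:

from omegaconf import OmegaConf

# Hypothetical inspection snippet: load the new config from the repo root
# and read the settings that FedBuffAggregator consumes.
cfg = OmegaConf.load("examples/config/server_fedbuff.yaml")
agg = cfg.server_configs.aggregator_kwargs
print(agg.K)             # 3: buffered updates per global step
print(agg.alpha)         # 0.9: base mixing rate
print(agg.staleness_fn)  # "polynomial": staleness discount shape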
76 changes: 76 additions & 0 deletions examples/config/server_iiadmm.yaml
@@ -0,0 +1,76 @@
client_configs:
  train_configs:
    # Local trainer
    trainer: "IIADMMTrainer"
    mode: "step"
    num_local_steps: 100
    optim: "Adam"
    optim_args:
      lr: 0.001
    # Algorithm specific
    accum_grad: True
    coeff_grad: False
    init_penalty: 100.0
    residual_balancing:
      res_on: False
      res_on_every_update: False
      tau: 1.1
      mu: 10
    # Loss function
    loss_fn_path: "./loss/celoss.py"
    loss_fn_name: "CELoss"
    # Client validation
    do_validation: True
    do_pre_validation: True
    pre_validation_interval: 1
    metric_path: "./metric/acc.py"
    metric_name: "accuracy"
    # Differential privacy
    use_dp: False
    epsilon: 1
    clip_grad: False
    clip_value: 1
    clip_norm: 1
    # Data loader
    train_batch_size: 64
    val_batch_size: 64
    train_data_shuffle: True
    val_data_shuffle: False

  model_configs:
    model_path: "./model/cnn.py"
    model_name: "CNN"
    model_kwargs:
      num_channel: 1
      num_classes: 10
      num_pixel: 28

  comm_configs:
    compressor_configs:
      enable_compression: False
      # Used if enable_compression is True
      lossy_compressor: "SZ2"
      lossless_compressor: "blosc"
      error_bounding_mode: "REL"
      error_bound: 1e-3
      flat_model_dtype: "np.float32"
      param_cutoff: 1024

server_configs:
  scheduler: "SyncScheduler"
  scheduler_kwargs:
    num_clients: 2
    same_init_model: True
  aggregator: "IIADMMAggregator"
  aggregator_kwargs:
    num_clients: 2
  device: "cpu"
  num_global_epochs: 10
  server_validation: False
  logging_output_dirname: "./output"
  logging_output_filename: "result"
  comm_configs:
    grpc_configs:
      server_uri: localhost:50051
      max_message_size: 1048576
      use_ssl: False
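The residual_balancing knobs above (res_on, tau, mu) follow the standard ADMM residual-balancing rule (Boyd et al., 2011): the penalty is rescaled by tau whenever the primal and dual residuals drift more than a factor of mu apart. The sketch below illustrates that rule only; the function name is hypothetical and the IIADMMTrainer's actual implementation may differ.

def balance_penalty(penalty: float, prim_res: float, dual_res: float,
                    tau: float = 1.1, mu: float = 10.0) -> float:
    # Grow the penalty when the primal residual dominates, shrink it when
    # the dual residual dominates, keeping the two within a factor of mu.
    if prim_res > mu * dual_res:
        return penalty * tau
    if dual_res > mu * prim_res:
        return penalty / tau
    return penalty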
examples/mpi/run_mpi_admm.py
@@ -1,9 +1,11 @@
"""
Running the ICEADMM algorithm using MPI for FL. This example mainly shows
Running the ADMM-based algorithm using MPI for FL. This example mainly shows
the extendibility of the framework to support custom algorithms. In this case,
the server and clients need to communicate primal and dual states, and a
penalty parameter. In addition, the clients also need to know its relative
sample size for local training purposes.
mpiexec -n 6 python mpi/run_mpi_admm.py --server_config config/server_iiadmm.yaml
mpiexec -n 6 python mpi/run_mpi_admm.py --server_config config/server_iceadmm.yaml
"""

import argparse
@@ -48,7 +50,7 @@
    client_agent.load_config(client_config)
    init_global_model = client_communicator.get_global_model(init_model=True)
    client_agent.load_parameters(init_global_model)
-    # (Specific to ICEADMM) Send the sample size to the server and set the client weight
+    # (Specific to ICEADMM and IIADMM) Send the sample size to the server and set the client weight
    sample_size = client_agent.get_sample_size()
    client_weight = client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size, sync=True)
    client_agent.trainer.set_weight(client_weight["client_weight"])
2 changes: 2 additions & 0 deletions src/appfl/aggregator/__init__.py
@@ -5,5 +5,7 @@
from .fedyogi_aggregator import FedYogiAggregator
from .fedadagrad_aggregator import FedAdagradAggregator
from .fedasync_aggregator import FedAsyncAggregator
+from .fedbuff_aggregator import FedBuffAggregator
from .fedcompass_aggregator import FedCompassAggregator
+from .iiadmm_aggregator import IIADMMAggregator
from .iceadmm_aggregator import ICEADMMAggregator
48 changes: 31 additions & 17 deletions src/appfl/aggregator/fedasync_aggregator.py
@@ -5,6 +5,10 @@
from typing import Union, Dict, OrderedDict, Any

class FedAsyncAggregator(BaseAggregator):
+    """
+    FedAsync Aggregator class for Federated Learning.
+    For more details, check paper: https://arxiv.org/pdf/1903.03934.pdf
+    """
    def __init__(
        self,
        model: torch.nn.Module,
@@ -26,12 +30,33 @@ def __init__(
        self.alpha = self.aggregator_config.get("alpha", 0.9)
        self.global_step = 0
        self.client_step = {}
+        self.step = {}

+    def get_parameters(self, **kwargs) -> Dict:
+        return copy.deepcopy(self.model.state_dict())
+
    def aggregate(self, client_id: Union[str, int], local_model: Union[Dict, OrderedDict], **kwargs) -> Dict:
+        global_state = copy.deepcopy(self.model.state_dict())
+
+        self.compute_steps(client_id, local_model)
+
+        for name in self.model.state_dict():
+            if name not in self.named_parameters:
+                global_state[name] = local_model[name]
+            else:
+                global_state[name] += self.step[name]
+        self.model.load_state_dict(global_state)
+        self.global_step += 1
+        self.client_step[client_id] = self.global_step
+        return global_state
+
+    def compute_steps(self, client_id: Union[str, int], local_model: Union[Dict, OrderedDict],):
+        """
+        Compute changes to the global model after the aggregation.
+        """
        if client_id not in self.client_step:
            self.client_step[client_id] = 0
        gradient_based = self.aggregator_config.get("gradient_based", False)
-        global_state = copy.deepcopy(self.model.state_dict())
        if (
            self.client_weights_mode == "sample_size" and
            hasattr(self, "client_sample_size") and
@@ -41,21 +66,11 @@ def aggregate(self, client_id: Union[str, int], local_model: Union[Dict, Ordered
        else:
            weight = 1.0 / self.aggregator_config.get("num_clients", 1)
        alpha_t = self.alpha * self.staleness_fn(self.global_step - self.client_step[client_id]) * weight
-        for name in self.model.state_dict():
-            if name in self.named_parameters:
-                if gradient_based:
-                    global_state[name] -= local_model[name] * alpha_t
-                else:
-                    global_state[name] = (1-alpha_t) * global_state[name] + alpha_t * local_model[name]
-            else:
-                global_state[name] = local_model[name]
-        self.model.load_state_dict(global_state)
-        self.global_step += 1
-        self.client_step[client_id] = self.global_step
-        return global_state
-
-    def get_parameters(self, **kwargs) -> Dict:
-        return copy.deepcopy(self.model.state_dict())
+        for name in self.named_parameters:
+            self.step[name] = (
+                alpha_t * (-local_model[name]) if gradient_based
+                else alpha_t * (local_model[name] - self.model.state_dict()[name])
+            )

    def __staleness_fn_factory(self, staleness_fn_name, **kwargs):
        if staleness_fn_name == "constant":
@@ -69,4 +84,3 @@ def __staleness_fn_factory(self, staleness_fn_name, **kwargs):
            return lambda u: 1 if u <= b else 1.0 / (a * (u - b) + 1.0)
        else:
            raise NotImplementedError
-
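For reference, alpha_t above scales the base rate alpha by a staleness discount and a per-client weight. The hinge branch is visible in the hunk; the polynomial form sketched below follows the FedAsync paper (s(u) = (u + 1)^-a) and is an assumption about the branch elided from this diff.

def polynomial_staleness(u: int, a: float = 0.5) -> float:
    # FedAsync-style discount: fresher updates (small u) count more.
    return (u + 1.0) ** (-a)

alpha, weight = 0.9, 0.5   # base mixing rate, per-client weight (2 clients)
staleness = 4              # global_step - client_step[client_id]
alpha_t = alpha * polynomial_staleness(staleness) * weight
print(round(alpha_t, 3))   # 0.201: effective step size for this stale update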
2 changes: 1 addition & 1 deletion src/appfl/aggregator/fedavg_aggregator.py
@@ -49,7 +49,7 @@ def compute_steps(self, local_models: Dict[Union[str, int], Union[Dict, OrderedD
"""
Compute the changes to the global model after the aggregation.
"""
for name in self.model.state_dict():
for name in self.named_parameters:
self.step[name] = torch.zeros_like(self.model.state_dict()[name])
for client_id, model in local_models.items():
if (
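The one-line change to fedavg_aggregator.py matters because model.state_dict() also contains non-trainable buffers (e.g. BatchNorm running statistics), which should be copied from clients rather than stepped. A standalone illustration (not part of the PR):

import torch

m = torch.nn.BatchNorm1d(4)
params = {name for name, _ in m.named_parameters()}  # {'weight', 'bias'}
buffers = sorted(set(m.state_dict()) - params)       # non-trainable entries
print(buffers)  # ['num_batches_tracked', 'running_mean', 'running_var']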
68 changes: 68 additions & 0 deletions src/appfl/aggregator/fedbuff_aggregator.py
@@ -0,0 +1,68 @@
import copy
import torch
from omegaconf import DictConfig
from appfl.aggregator import FedAsyncAggregator
from typing import Union, Dict, OrderedDict, Any

class FedBuffAggregator(FedAsyncAggregator):
    """
    FedBuff Aggregator class for Federated Learning.
    For more details, check paper: https://proceedings.mlr.press/v151/nguyen22b/nguyen22b.pdf
    """
    def __init__(
        self,
        model: torch.nn.Module,
        aggregator_config: DictConfig,
        logger: Any
    ):
        super().__init__(model, aggregator_config, logger)
        self.buff_size = 0
        self.K = self.aggregator_config.K

    def aggregate(self, client_id: Union[str, int], local_model: Union[Dict, OrderedDict], **kwargs) -> Dict:
        global_state = copy.deepcopy(self.model.state_dict())

        self.compute_steps(client_id, local_model)
        self.buff_size += 1
        if self.buff_size == self.K:
            for name in self.model.state_dict():
                if name not in self.named_parameters:
                    global_state[name] = torch.div(self.step[name], self.K)
                else:
                    global_state[name] += self.step[name]
            self.model.load_state_dict(global_state)
            self.global_step += 1
            self.buff_size = 0

        self.client_step[client_id] = self.global_step
        return global_state

    def compute_steps(self, client_id: Union[str, int], local_model: Union[Dict, OrderedDict],):
        """
        Compute changes to the global model after the aggregation.
        """
        if self.buff_size == 0:
            for name in self.model.state_dict():
                self.step[name] = torch.zeros_like(self.model.state_dict()[name])

        if client_id not in self.client_step:
            self.client_step[client_id] = 0
        gradient_based = self.aggregator_config.get("gradient_based", False)
        if (
            self.client_weights_mode == "sample_size" and
            hasattr(self, "client_sample_size") and
            client_id in self.client_sample_size
        ):
            weight = self.client_sample_size[client_id] / sum(self.client_sample_size.values())
        else:
            weight = 1.0 / self.aggregator_config.get("num_clients", 1)
        alpha_t = self.alpha * self.staleness_fn(self.global_step - self.client_step[client_id]) * weight

        for name in self.model.state_dict():
            if name in self.named_parameters:
                self.step[name] += (
                    alpha_t * (-local_model[name]) if gradient_based
                    else alpha_t * (local_model[name] - self.model.state_dict()[name])
                )
            else:
                self.step[name] += local_model[name]
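A hypothetical driver (not from this PR) showing the buffering behavior: compute_steps accumulates each update into self.step, and the global model only advances on every K-th call to aggregate. The config keys mirror server_fedbuff.yaml above; passing logger=None and the zero pseudo-gradients are assumptions made to keep the sketch self-contained.

import torch
from omegaconf import OmegaConf
from appfl.aggregator import FedBuffAggregator

model = torch.nn.Linear(4, 2)
cfg = OmegaConf.create({
    "client_weights_mode": "equal", "num_clients": 3,
    "staleness_fn": "polynomial", "staleness_fn_kwargs": {"a": 0.5},
    "alpha": 0.9, "gradient_based": True, "K": 3,
})
aggregator = FedBuffAggregator(model, cfg, logger=None)
for client_id in range(3):
    # Zero pseudo-gradients keep the sketch self-contained; the model
    # only changes once the third update fills the buffer.
    update = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}
    aggregator.aggregate(client_id, update)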
4 changes: 2 additions & 2 deletions src/appfl/aggregator/fedcompass_aggregator.py
@@ -6,8 +6,8 @@

class FedCompassAggregator(BaseAggregator):
    """
-    Aggregator for `FedCompass` semi-asynchronous federated learning algorithm.
-    Paper reference: https://arxiv.org/abs/2309.14675
+    FedCompass semi-asynchronous federated learning algorithm.
+    For more details, check paper: https://arxiv.org/abs/2309.14675
    """
    def __init__(
        self,
5 changes: 5 additions & 0 deletions src/appfl/aggregator/iceadmm_aggregator.py
@@ -7,6 +7,11 @@
from appfl.aggregator import BaseAggregator

class ICEADMMAggregator(BaseAggregator):
+    """
+    ICEADMM Aggregator class for Federated Learning.
+    It has to be used with the ICEADMMTrainer.
+    For more details, check paper: https://arxiv.org/pdf/2110.15318.pdf
+    """
    def __init__(
        self,
        model: nn.Module,