Skip to content

Commit

Permalink
Finish IIADMM example
Browse files Browse the repository at this point in the history
  • Loading branch information
Zilinghan committed Apr 3, 2024
1 parent b3aa0b6 commit 70f0516
Show file tree
Hide file tree
Showing 7 changed files with 447 additions and 4 deletions.
76 changes: 76 additions & 0 deletions examples/config/server_iiadmm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
client_configs:
train_configs:
# Local trainer
trainer: "IIADMMTrainer"
mode: "step"
num_local_steps: 100
optim: "Adam"
optim_args:
lr: 0.001
# Algorithm specific
accum_grad: True
coeff_grad: False
init_penalty: 100.0
residual_balancing:
res_on: False
res_on_every_update: False
tau: 1.1
mu: 10
# Loss function
loss_fn_path: "./loss/celoss.py"
loss_fn_name: "CELoss"
# Client validation
do_validation: True
do_pre_validation: True
pre_validation_interval: 1
metric_path: "./metric/acc.py"
metric_name: "accuracy"
# Differential privacy
use_dp: False
epsilon: 1
clip_grad: False
clip_value: 1
clip_norm: 1
# Data loader
train_batch_size: 64
val_batch_size: 64
train_data_shuffle: True
val_data_shuffle: False

model_configs:
model_path: "./model/cnn.py"
model_name: "CNN"
model_kwargs:
num_channel: 1
num_classes: 10
num_pixel: 28

comm_configs:
compressor_configs:
enable_compression: False
# Used if enable_compression is True
lossy_compressor: "SZ2"
lossless_compressor: "blosc"
error_bounding_mode: "REL"
error_bound: 1e-3
flat_model_dtype: "np.float32"
param_cutoff: 1024

server_configs:
scheduler: "SyncScheduler"
scheduler_kwargs:
num_clients: 2
same_init_model: True
aggregator: "IIADMMAggregator"
aggregator_kwargs:
num_clients: 2
device: "cpu"
num_global_epochs: 10
server_validation: False
logging_output_dirname: "./output"
logging_output_filename: "result"
comm_configs:
grpc_configs:
server_uri: localhost:50051
max_message_size: 1048576
use_ssl: False
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""
Running the ICEADMM algorithm using MPI for FL. This example mainly shows
Running the ADMM-based algorithm using MPI for FL. This example mainly shows
the extendibility of the framework to support custom algorithms. In this case,
the server and clients need to communicate primal and dual states, and a
penalty parameter. In addition, the clients also need to know its relative
sample size for local training purposes.
mpiexec -n 6 python mpi/run_mpi_admm.py --server_config config/server_iiadmm.yaml
mpiexec -n 6 python mpi/run_mpi_admm.py --server_config config/server_iceadmm.yaml
"""

import argparse
Expand Down Expand Up @@ -48,7 +50,7 @@
client_agent.load_config(client_config)
init_global_model = client_communicator.get_global_model(init_model=True)
client_agent.load_parameters(init_global_model)
# (Specific to ICEADMM) Send the sample size to the server and set the client weight
# (Specific to ICEADMM and IIADMM) Send the sample size to the server and set the client weight
sample_size = client_agent.get_sample_size()
client_weight = client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size, sync=True)
client_agent.trainer.set_weight(client_weight["client_weight"])
Expand Down
1 change: 1 addition & 0 deletions src/appfl/aggregator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
from .fedasync_aggregator import FedAsyncAggregator
from .fedbuff_aggregator import FedBuffAggregator
from .fedcompass_aggregator import FedCompassAggregator
from .iiadmm_aggregator import IIADMMAggregator
from .iceadmm_aggregator import ICEADMMAggregator
File renamed without changes.
1 change: 1 addition & 0 deletions src/appfl/trainer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .base_trainer import BaseTrainer
from .naive_trainer import NaiveTrainer
from .iiadmm_trainer import IIADMMTrainer
from .iceadmm_trainer import ICEADMMTrainer
4 changes: 2 additions & 2 deletions src/appfl/trainer/iceadmm_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
self._sanity_check()

def train(self):
assert hasattr(self, "weight"), "You must set the weight of the client before training. Use `set_weight` method."
self.model.train()
self.model.to(self.train_configs.device)
do_validation = self.train_configs.get("do_validation", False) and self.val_dataloader is not None
Expand Down Expand Up @@ -172,7 +173,7 @@ def train(self):

self.round += 1


"""Differential Privacy"""
for name, param in self.model.named_parameters():
param.data = self.primal_states[name].to(self.train_configs.device)
if self.train_configs.get("use_dp", False):
Expand Down Expand Up @@ -215,7 +216,6 @@ def _sanity_check(self):
assert hasattr(self.train_configs, "num_local_epochs"), "Number of local epochs must be specified"
else:
assert hasattr(self.train_configs, "num_local_steps"), "Number of local steps must be specified"
assert hasattr(self, "weight"), "You must set the weight of the client before training. Use `set_weight` method."
if getattr(self.train_configs, "clip_grad", False) or getattr(self.train_configs, "use_dp", False):
assert hasattr(self.train_configs, "clip_value"), "Gradient clipping value must be specified"
assert hasattr(self.train_configs, "clip_norm"), "Gradient clipping norm must be specified"
Expand Down
Loading

0 comments on commit 70f0516

Please sign in to comment.