
Commit

Merge pull request #5 from APPFL/zilinghan/ppfl
Adding a demo PPFL algorithm to showcase the framework's extensibility
Zilinghan authored Mar 31, 2024
2 parents 4a0322d + ef934eb commit a30b800
Showing 13 changed files with 815 additions and 39 deletions.
77 changes: 77 additions & 0 deletions examples/config/server_iceadmm.yaml
@@ -0,0 +1,77 @@
client_configs:
  train_configs:
    # Local trainer
    trainer: "ICEADMMTrainer"
    mode: "step"
    num_local_steps: 100
    optim: "Adam"
    optim_args:
      lr: 0.001
    # Algorithm specific
    accum_grad: True
    coeff_grad: False
    init_penalty: 500.0
    residual_balancing:
      res_on: False
      res_on_every_update: False
      tau: 2
      mu: 2
    init_proximity: 0
    # Loss function
    loss_fn_path: "./loss/celoss.py"
    loss_fn_name: "CELoss"
    # Client validation
    do_validation: True
    do_pre_validation: True
    pre_validation_interval: 1
    metric_path: "./metric/acc.py"
    metric_name: "accuracy"
    # Differential privacy
    use_dp: False
    epsilon: 1
    clip_grad: False
    clip_value: 1
    clip_norm: 1
    # Data loader
    train_batch_size: 64
    val_batch_size: 64
    train_data_shuffle: True
    val_data_shuffle: False

  model_configs:
    model_path: "./model/cnn.py"
    model_name: "CNN"
    model_kwargs:
      num_channel: 1
      num_classes: 10
      num_pixel: 28

  comm_configs:
    compressor_configs:
      enable_compression: False
      # Used only if enable_compression is True
      lossy_compressor: "SZ2"
      lossless_compressor: "blosc"
      error_bounding_mode: "REL"
      error_bound: 1e-3
      flat_model_dtype: "np.float32"
      param_cutoff: 1024

server_configs:
  scheduler: "SyncScheduler"
  scheduler_kwargs:
    num_clients: 2
    same_init_model: True
  aggregator: "ICEADMMAggregator"
  aggregator_kwargs:
    num_clients: 2
  device: "cpu"
  num_global_epochs: 10
  server_validation: False
  logging_output_dirname: "./output"
  logging_output_filename: "result"
  comm_configs:
    grpc_configs:
      server_uri: localhost:50051
      max_message_size: 1048576
      use_ssl: False
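
As a quick sanity check, a config like this can be loaded and inspected with OmegaConf, the same way the example scripts below do. A minimal sketch, assuming the `examples` directory as the working directory:

from omegaconf import OmegaConf

# Load the ICEADMM server config and confirm the algorithm-specific entries.
server_agent_config = OmegaConf.load("config/server_iceadmm.yaml")
print(server_agent_config.server_configs.aggregator)             # ICEADMMAggregator
print(server_agent_config.client_configs.train_configs.trainer)  # ICEADMMTrainer
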
39 changes: 39 additions & 0 deletions examples/grpc/run_client_iceadmm_1.py
@@ -0,0 +1,39 @@
"""
Running the ICEADMM algorithm using gRPC for FL. This example mainly shows
the extendibility of the framework to support custom algorithms. In this case,
the server and clients need to communicate primal and dual states, and a
penalty parameter. In addition, the clients also need to know its relative
sample size for local training purposes.
"""
from omegaconf import OmegaConf
from appfl.agent import APPFLClientAgent
from appfl.comm.grpc import GRPCClientCommunicator

client_agent_config = OmegaConf.load("config/client_1.yaml")

client_agent = APPFLClientAgent(client_agent_config=client_agent_config)
client_communicator = GRPCClientCommunicator(
    client_id=client_agent.get_id(),
    **client_agent_config.comm_configs.grpc_configs,
)

client_config = client_communicator.get_configuration()
client_agent.load_config(client_config)

init_global_model = client_communicator.get_global_model(init_model=True)
client_agent.load_parameters(init_global_model)

# Send the size of the local dataset to the server
sample_size = client_agent.get_sample_size()
client_weight = client_communicator.invoke_custom_action(
    action='set_sample_size', sample_size=sample_size, sync=True
)
client_agent.trainer.set_weight(client_weight["client_weight"])

while True:
    client_agent.train()
    local_model = client_agent.get_parameters()
    new_global_model, metadata = client_communicator.update_global_model(local_model)
    if metadata['status'] == 'DONE':
        break
    if 'local_steps' in metadata:
        client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
    client_agent.load_parameters(new_global_model)
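
For context on the `set_sample_size` custom action used above: once every client has reported, the server computes each client's relative weight as its sample size divided by the total (see the `set_sample_size` changes in `src/appfl/agent/server.py` below). A minimal sketch with hypothetical sizes:

sample_sizes = {"client1": 6000, "client2": 4000}  # hypothetical local dataset sizes
total = sum(sample_sizes.values())
client_weights = {cid: n / total for cid, n in sample_sizes.items()}
# -> {"client1": 0.6, "client2": 0.4}; each weight is passed to trainer.set_weight()
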
39 changes: 39 additions & 0 deletions examples/grpc/run_client_iceadmm_2.py
@@ -0,0 +1,39 @@
"""
Running the ICEADMM algorithm using gRPC for FL. This example mainly shows
the extendibility of the framework to support custom algorithms. In this case,
the server and clients need to communicate primal and dual states, and a
penalty parameter. In addition, the clients also need to know its relative
sample size for local training purposes.
"""
from omegaconf import OmegaConf
from appfl.agent import APPFLClientAgent
from appfl.comm.grpc import GRPCClientCommunicator

client_agent_config = OmegaConf.load("config/client_2.yaml")

client_agent = APPFLClientAgent(client_agent_config=client_agent_config)
client_communicator = GRPCClientCommunicator(
client_id = client_agent.get_id(),
**client_agent_config.comm_configs.grpc_configs,
)

client_config = client_communicator.get_configuration()
client_agent.load_config(client_config)

init_global_model = client_communicator.get_global_model(init_model=True)
client_agent.load_parameters(init_global_model)

# Send the size of the local dataset to the server
sample_size = client_agent.get_sample_size()
client_weight = client_communicator.invoke_custom_action(
    action='set_sample_size', sample_size=sample_size, sync=True
)
client_agent.trainer.set_weight(client_weight["client_weight"])

while True:
    client_agent.train()
    local_model = client_agent.get_parameters()
    new_global_model, metadata = client_communicator.update_global_model(local_model)
    if metadata['status'] == 'DONE':
        break
    if 'local_steps' in metadata:
        client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
    client_agent.load_parameters(new_global_model)
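
For orientation on the primal/dual/penalty terminology in these docstrings: ICEADMM follows the general consensus-ADMM template, in which each client $i$ holds a primal state $z_i$ and a dual state $\lambda_i$, coupled to the global model $w$ through a penalty $\rho$ (cf. `init_penalty` in the config). A generic sketch of the consensus-ADMM augmented Lagrangian, for intuition only (the exact ICEADMM updates live in `ICEADMMTrainer` and `ICEADMMAggregator`):

$$
\mathcal{L}_\rho\bigl(w, \{z_i\}, \{\lambda_i\}\bigr)
= \sum_{i} \Bigl[ f_i(z_i) + \lambda_i^\top (z_i - w) + \tfrac{\rho}{2}\,\lVert z_i - w \rVert^2 \Bigr],
$$

where $f_i$ is the local loss of client $i$. Exchanging $z_i$, $\lambda_i$, and $\rho$ between server and clients is exactly the extra state these examples communicate.
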
64 changes: 64 additions & 0 deletions examples/mpi/run_mpi_iceadmm.py
@@ -0,0 +1,64 @@
"""
Running the ICEADMM algorithm using MPI for FL. This example mainly shows
the extendibility of the framework to support custom algorithms. In this case,
the server and clients need to communicate primal and dual states, and a
penalty parameter. In addition, the clients also need to know its relative
sample size for local training purposes.
"""

import argparse
from mpi4py import MPI
from omegaconf import OmegaConf
from appfl.agent import APPFLClientAgent, APPFLServerAgent
from appfl.comm.mpi import MPIClientCommunicator, MPIServerCommunicator

parser = argparse.ArgumentParser()
parser.add_argument("--server_config", type=str, default="config/server_iceadmm.yaml")
parser.add_argument("--client_config", type=str, default="config/client_1.yaml")
args = parser.parse_args()

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
num_clients = size - 1

if rank == 0:
    # Load and set the server configurations
    server_agent_config = OmegaConf.load(args.server_config)
    server_agent_config.server_configs.scheduler_kwargs.num_clients = num_clients
    if hasattr(server_agent_config.server_configs.aggregator_kwargs, "num_clients"):
        server_agent_config.server_configs.aggregator_kwargs.num_clients = num_clients
    # Create the server agent and communicator
    server_agent = APPFLServerAgent(server_agent_config=server_agent_config)
    server_communicator = MPIServerCommunicator(comm, server_agent)
    # Start the server to serve the clients
    server_communicator.serve()
else:
    # Set the client configurations
    client_agent_config = OmegaConf.load(args.client_config)
    client_agent_config.train_configs.logging_id = f'Client{rank}'
    client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients
    client_agent_config.data_configs.dataset_kwargs.client_id = rank - 1
    client_agent_config.data_configs.dataset_kwargs.visualization = (rank == 1)
    # Create the client agent and communicator
    client_agent = APPFLClientAgent(client_agent_config=client_agent_config)
    client_communicator = MPIClientCommunicator(comm, server_rank=0)
    # Load the configurations and initial global model
    client_config = client_communicator.get_configuration()
    client_agent.load_config(client_config)
    init_global_model = client_communicator.get_global_model(init_model=True)
    client_agent.load_parameters(init_global_model)
    # (Specific to ICEADMM) Send the sample size to the server and set the client weight
    sample_size = client_agent.get_sample_size()
    client_weight = client_communicator.invoke_custom_action(
        action='set_sample_size', sample_size=sample_size, sync=True
    )
    client_agent.trainer.set_weight(client_weight["client_weight"])
    # Local training and global model update iterations
    while True:
        client_agent.train()
        local_model = client_agent.get_parameters()
        new_global_model, metadata = client_communicator.update_global_model(local_model)
        if metadata['status'] == 'DONE':
            break
        if 'local_steps' in metadata:
            client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
        client_agent.load_parameters(new_global_model)
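
With two clients, this script would be launched with something like `mpiexec -n 3 python ./mpi/run_mpi_iceadmm.py` (one server rank plus two client ranks); the exact launcher invocation depends on the local MPI installation.
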
52 changes: 49 additions & 3 deletions src/appfl/agent/server.py
@@ -10,7 +10,7 @@
from appfl.logger import ServerAgentFileLogger
from concurrent.futures import Future
from omegaconf import OmegaConf, DictConfig
from typing import Union, Dict, OrderedDict, Tuple, Optional

class APPFLServerAgent:
"""
@@ -91,9 +91,55 @@ def set_sample_size(
        self,
        client_id: Union[int, str],
        sample_size: int,
        sync: bool = False,
        blocking: bool = False,
    ) -> Optional[Union[Dict, Future]]:
        """
        Set the size of the local dataset of a client.
        :param client_id: A unique client id for the server to distinguish clients, which can be obtained via `ClientAgent.get_id()`.
        :param sample_size: The size of the local dataset of a client.
        :param sync: Whether to synchronize the sample size among all clients. If `True`, the method can return the relative weight of the client.
        :param blocking: Whether to block the client until the sample sizes of all clients are synchronized.
            If `True`, the method will return the relative weight of the client.
            Otherwise, the method may return a `Future` object of the relative weight, which will be resolved
            when the sample sizes of all clients are synchronized.
        """
        if sync:
            assert (
                hasattr(self.server_agent_config.server_configs, "num_clients") or
                hasattr(self.server_agent_config.server_configs.scheduler_kwargs, "num_clients") or
                hasattr(self.server_agent_config.server_configs.aggregator_kwargs, "num_clients")
            ), "The number of clients should be set in the server configurations."
            num_clients = (
                self.server_agent_config.server_configs.num_clients if
                hasattr(self.server_agent_config.server_configs, "num_clients") else
                self.server_agent_config.server_configs.scheduler_kwargs.num_clients if
                hasattr(self.server_agent_config.server_configs.scheduler_kwargs, "num_clients") else
                self.server_agent_config.server_configs.aggregator_kwargs.num_clients
            )
        self.aggregator.set_client_sample_size(client_id, sample_size)
        if sync:
            if not hasattr(self, "_client_sample_size"):
                self._client_sample_size = {}
                self._client_sample_size_future = {}
                self._client_sample_size_lock = threading.Lock()
            with self._client_sample_size_lock:
                self._client_sample_size[client_id] = sample_size
                future = Future()
                self._client_sample_size_future[client_id] = future
                if len(self._client_sample_size) == num_clients:
                    total_sample_size = sum(self._client_sample_size.values())
                    for cid, cid_future in self._client_sample_size_future.items():
                        cid_future.set_result(
                            {"client_weight": self._client_sample_size[cid] / total_sample_size}
                        )
                    self._client_sample_size = {}
                    self._client_sample_size_future = {}
            if blocking:
                return future.result()
            else:
                return future
        return None

def training_finished(self, internal_check: bool = False) -> bool:
"""Notify the client whether the training is finished."""
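
A usage sketch of the new `sync` path, assuming a server configured for two clients (client ids and sample sizes hypothetical): the first report returns a pending `Future`, and the second resolves both futures with each client's relative weight.

fut1 = server_agent.set_sample_size("client1", 6000, sync=True)  # pending Future
fut2 = server_agent.set_sample_size("client2", 4000, sync=True)  # resolves both
print(fut1.result())  # {'client_weight': 0.6}
print(fut2.result())  # {'client_weight': 0.4}
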
3 changes: 2 additions & 1 deletion src/appfl/aggregator/__init__.py
@@ -1,4 +1,5 @@
from .base_aggregator import BaseAggregator
from .fedavg_aggregator import FedAvgAggregator
from .fedasync_aggregator import FedAsyncAggregator
from .fedcompass_aggregator import FedCompassAggregator
from .iceadmm_aggregator import ICEADMMAggregator
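
Adding this import is how a new aggregator becomes importable from `appfl.aggregator`. A hedged skeleton of what `iceadmm_aggregator.py` plugs into (only `set_client_sample_size` is confirmed by the `server.py` diff above; `aggregate` is an assumption about the `BaseAggregator` interface):

from .base_aggregator import BaseAggregator

class MyCustomAggregator(BaseAggregator):
    def set_client_sample_size(self, client_id, sample_size):
        # Invoked by APPFLServerAgent.set_sample_size (see the server.py diff above).
        self.client_sample_size = getattr(self, "client_sample_size", {})
        self.client_sample_size[client_id] = sample_size

    def aggregate(self, client_id, local_model, **kwargs):
        # Assumed aggregation hook: fold a client's update into the global state.
        ...
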
