-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from APPFL/zilinghan/ppfl
Adding a demo ppfl algorithm for showcasing framework's extendibility
- Loading branch information
Showing
13 changed files
with
815 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
client_configs: | ||
train_configs: | ||
# Local trainer | ||
trainer: "ICEADMMTrainer" | ||
mode: "step" | ||
num_local_steps: 100 | ||
optim: "Adam" | ||
optim_args: | ||
lr: 0.001 | ||
# Algorithm specific | ||
accum_grad: True | ||
coeff_grad: False | ||
init_penalty: 500.0 | ||
residual_balancing: | ||
res_on: False | ||
res_on_every_update: False | ||
tau: 2 | ||
mu: 2 | ||
init_proximity: 0 | ||
# Loss function | ||
loss_fn_path: "./loss/celoss.py" | ||
loss_fn_name: "CELoss" | ||
# Client validation | ||
do_validation: True | ||
do_pre_validation: True | ||
pre_validation_interval: 1 | ||
metric_path: "./metric/acc.py" | ||
metric_name: "accuracy" | ||
# Differential privacy | ||
use_dp: False | ||
epsilon: 1 | ||
clip_grad: False | ||
clip_value: 1 | ||
clip_norm: 1 | ||
# Data loader | ||
train_batch_size: 64 | ||
val_batch_size: 64 | ||
train_data_shuffle: True | ||
val_data_shuffle: False | ||
|
||
model_configs: | ||
model_path: "./model/cnn.py" | ||
model_name: "CNN" | ||
model_kwargs: | ||
num_channel: 1 | ||
num_classes: 10 | ||
num_pixel: 28 | ||
|
||
comm_configs: | ||
compressor_configs: | ||
enable_compression: False | ||
# Used if enable_compression is True | ||
lossy_compressor: "SZ2" | ||
lossless_compressor: "blosc" | ||
error_bounding_mode: "REL" | ||
error_bound: 1e-3 | ||
flat_model_dtype: "np.float32" | ||
param_cutoff: 1024 | ||
|
||
server_configs: | ||
scheduler: "SyncScheduler" | ||
scheduler_kwargs: | ||
num_clients: 2 | ||
same_init_model: True | ||
aggregator: "ICEADMMAggregator" | ||
aggregator_kwargs: | ||
num_clients: 2 | ||
device: "cpu" | ||
num_global_epochs: 10 | ||
server_validation: False | ||
logging_output_dirname: "./output" | ||
logging_output_filename: "result" | ||
comm_configs: | ||
grpc_configs: | ||
server_uri: localhost:50051 | ||
max_message_size: 1048576 | ||
use_ssl: False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
""" | ||
Running the ICEADMM algorithm using gRPC for FL. This example mainly shows | ||
the extendibility of the framework to support custom algorithms. In this case, | ||
the server and clients need to communicate primal and dual states, and a | ||
penalty parameter. In addition, the clients also need to know its relative | ||
sample size for local training purposes. | ||
""" | ||
from omegaconf import OmegaConf | ||
from appfl.agent import APPFLClientAgent | ||
from appfl.comm.grpc import GRPCClientCommunicator | ||
|
||
client_agent_config = OmegaConf.load("config/client_1.yaml") | ||
|
||
client_agent = APPFLClientAgent(client_agent_config=client_agent_config) | ||
client_communicator = GRPCClientCommunicator( | ||
client_id = client_agent.get_id(), | ||
**client_agent_config.comm_configs.grpc_configs, | ||
) | ||
|
||
client_config = client_communicator.get_configuration() | ||
client_agent.load_config(client_config) | ||
|
||
init_global_model = client_communicator.get_global_model(init_model=True) | ||
client_agent.load_parameters(init_global_model) | ||
|
||
# Send the number of local data to the server | ||
sample_size = client_agent.get_sample_size() | ||
client_weight = client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size, sync=True) | ||
client_agent.trainer.set_weight(client_weight["client_weight"]) | ||
|
||
while True: | ||
client_agent.train() | ||
local_model = client_agent.get_parameters() | ||
new_global_model, metadata = client_communicator.update_global_model(local_model) | ||
if metadata['status'] == 'DONE': | ||
break | ||
if 'local_steps' in metadata: | ||
client_agent.trainer.train_configs.num_local_steps = metadata['local_steps'] | ||
client_agent.load_parameters(new_global_model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
""" | ||
Running the ICEADMM algorithm using gRPC for FL. This example mainly shows | ||
the extendibility of the framework to support custom algorithms. In this case, | ||
the server and clients need to communicate primal and dual states, and a | ||
penalty parameter. In addition, the clients also need to know its relative | ||
sample size for local training purposes. | ||
""" | ||
from omegaconf import OmegaConf | ||
from appfl.agent import APPFLClientAgent | ||
from appfl.comm.grpc import GRPCClientCommunicator | ||
|
||
client_agent_config = OmegaConf.load("config/client_2.yaml") | ||
|
||
client_agent = APPFLClientAgent(client_agent_config=client_agent_config) | ||
client_communicator = GRPCClientCommunicator( | ||
client_id = client_agent.get_id(), | ||
**client_agent_config.comm_configs.grpc_configs, | ||
) | ||
|
||
client_config = client_communicator.get_configuration() | ||
client_agent.load_config(client_config) | ||
|
||
init_global_model = client_communicator.get_global_model(init_model=True) | ||
client_agent.load_parameters(init_global_model) | ||
|
||
# Send the number of local data to the server | ||
sample_size = client_agent.get_sample_size() | ||
client_weight = client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size, sync=True) | ||
client_agent.trainer.set_weight(client_weight["client_weight"]) | ||
|
||
while True: | ||
client_agent.train() | ||
local_model = client_agent.get_parameters() | ||
new_global_model, metadata = client_communicator.update_global_model(local_model) | ||
if metadata['status'] == 'DONE': | ||
break | ||
if 'local_steps' in metadata: | ||
client_agent.trainer.train_configs.num_local_steps = metadata['local_steps'] | ||
client_agent.load_parameters(new_global_model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
""" | ||
Running the ICEADMM algorithm using MPI for FL. This example mainly shows | ||
the extendibility of the framework to support custom algorithms. In this case, | ||
the server and clients need to communicate primal and dual states, and a | ||
penalty parameter. In addition, the clients also need to know its relative | ||
sample size for local training purposes. | ||
""" | ||
|
||
import argparse | ||
from mpi4py import MPI | ||
from omegaconf import OmegaConf | ||
from appfl.agent import APPFLClientAgent, APPFLServerAgent | ||
from appfl.comm.mpi import MPIClientCommunicator, MPIServerCommunicator | ||
|
||
argparse = argparse.ArgumentParser() | ||
argparse.add_argument("--server_config", type=str, default="config/server_iceadmm.yaml") | ||
argparse.add_argument("--client_config", type=str, default="config/client_1.yaml") | ||
args = argparse.parse_args() | ||
|
||
comm = MPI.COMM_WORLD | ||
rank = comm.Get_rank() | ||
size = comm.Get_size() | ||
num_clients = size - 1 | ||
|
||
if rank == 0: | ||
# Load and set the server configurations | ||
server_agent_config = OmegaConf.load(args.server_config) | ||
server_agent_config.server_configs.scheduler_kwargs.num_clients = num_clients | ||
if hasattr(server_agent_config.server_configs.aggregator_kwargs, "num_clients"): | ||
server_agent_config.server_configs.aggregator_kwargs.num_clients = num_clients | ||
# Create the server agent and communicator | ||
server_agent = APPFLServerAgent(server_agent_config=server_agent_config) | ||
server_communicator = MPIServerCommunicator(comm, server_agent) | ||
# Start the server to serve the clients | ||
server_communicator.serve() | ||
else: | ||
# Set the client configurations | ||
client_agent_config = OmegaConf.load(args.client_config) | ||
client_agent_config.train_configs.logging_id = f'Client{rank}' | ||
client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients | ||
client_agent_config.data_configs.dataset_kwargs.client_id = rank - 1 | ||
client_agent_config.data_configs.dataset_kwargs.visualization = True if rank == 1 else False | ||
# Create the client agent and communicator | ||
client_agent = APPFLClientAgent(client_agent_config=client_agent_config) | ||
client_communicator = MPIClientCommunicator(comm, server_rank=0) | ||
# Load the configurations and initial global model | ||
client_config = client_communicator.get_configuration() | ||
client_agent.load_config(client_config) | ||
init_global_model = client_communicator.get_global_model(init_model=True) | ||
client_agent.load_parameters(init_global_model) | ||
# (Specific to ICEADMM) Send the sample size to the server and set the client weight | ||
sample_size = client_agent.get_sample_size() | ||
client_weight = client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size, sync=True) | ||
client_agent.trainer.set_weight(client_weight["client_weight"]) | ||
# Local training and global model update iterations | ||
while True: | ||
client_agent.train() | ||
local_model = client_agent.get_parameters() | ||
new_global_model, metadata = client_communicator.update_global_model(local_model) | ||
if metadata['status'] == 'DONE': | ||
break | ||
if 'local_steps' in metadata: | ||
client_agent.trainer.train_configs.num_local_steps = metadata['local_steps'] | ||
client_agent.load_parameters(new_global_model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .base_aggregator import BaseAggregator | ||
from .fedavg_aggregator import FedAvgAggregator | ||
from .fedasync_aggregator import FedAsyncAggregator | ||
from .fedcompass_aggregator import FedCompassAggregator | ||
from .fedcompass_aggregator import FedCompassAggregator | ||
from .iceadmm_aggregator import ICEADMMAggregator |
Oops, something went wrong.