diff --git a/examples/config/server_iceadmm.yaml b/examples/config/server_iceadmm.yaml
new file mode 100644
index 0000000..f16fbd0
--- /dev/null
+++ b/examples/config/server_iceadmm.yaml
@@ -0,0 +1,77 @@
+client_configs:
+  train_configs:
+    # Local trainer
+    trainer: "ICEADMMTrainer"
+    mode: "step"
+    num_local_steps: 100
+    optim: "Adam"
+    optim_args:
+      lr: 0.001
+    # Algorithm specific
+    accum_grad: True
+    coeff_grad: False
+    init_penalty: 500.0
+    residual_balancing:
+      res_on: False
+      res_on_every_update: False
+      tau: 2
+      mu: 2
+    init_proximity: 0
+    # Loss function
+    loss_fn_path: "./loss/celoss.py"
+    loss_fn_name: "CELoss"
+    # Client validation
+    do_validation: True
+    do_pre_validation: True
+    pre_validation_interval: 1
+    metric_path: "./metric/acc.py"
+    metric_name: "accuracy"
+    # Differential privacy
+    use_dp: False
+    epsilon: 1
+    clip_grad: False
+    clip_value: 1
+    clip_norm: 1
+    # Data loader
+    train_batch_size: 64
+    val_batch_size: 64
+    train_data_shuffle: True
+    val_data_shuffle: False
+
+  model_configs:
+    model_path: "./model/cnn.py"
+    model_name: "CNN"
+    model_kwargs:
+      num_channel: 1
+      num_classes: 10
+      num_pixel: 28
+
+  comm_configs:
+    compressor_configs:
+      enable_compression: False
+      # Used if enable_compression is True
+      lossy_compressor: "SZ2"
+      lossless_compressor: "blosc"
+      error_bounding_mode: "REL"
+      error_bound: 1e-3
+      flat_model_dtype: "np.float32"
+      param_cutoff: 1024
+
+server_configs:
+  scheduler: "SyncScheduler"
+  scheduler_kwargs:
+    num_clients: 2
+    same_init_model: True
+  aggregator: "ICEADMMAggregator"
+  aggregator_kwargs:
+    num_clients: 2
+    device: "cpu"
+  num_global_epochs: 10
+  server_validation: False
+  logging_output_dirname: "./output"
+  logging_output_filename: "result"
+  comm_configs:
+    grpc_configs:
+      server_uri: localhost:50051
+      max_message_size: 1048576
+      use_ssl: False
\ No newline at end of file
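Note that the same `num_clients` value appears under both `scheduler_kwargs` and `aggregator_kwargs`, and the launch scripts below override both at runtime. A minimal sketch of loading and patching this config (the path assumes the repository root as the working directory):

from omegaconf import OmegaConf

server_config = OmegaConf.load("examples/config/server_iceadmm.yaml")
# Keep the two client counts consistent when changing the federation size.
server_config.server_configs.scheduler_kwargs.num_clients = 4
server_config.server_configs.aggregator_kwargs.num_clients = 4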
+""" +from omegaconf import OmegaConf +from appfl.agent import APPFLClientAgent +from appfl.comm.grpc import GRPCClientCommunicator + +client_agent_config = OmegaConf.load("config/client_1.yaml") + +client_agent = APPFLClientAgent(client_agent_config=client_agent_config) +client_communicator = GRPCClientCommunicator( + client_id = client_agent.get_id(), + **client_agent_config.comm_configs.grpc_configs, +) + +client_config = client_communicator.get_configuration() +client_agent.load_config(client_config) + +init_global_model = client_communicator.get_global_model(init_model=True) +client_agent.load_parameters(init_global_model) + +# Send the number of local data to the server +sample_size = client_agent.get_sample_size() +client_weight = client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size, sync=True) +client_agent.trainer.set_weight(client_weight["client_weight"]) + +while True: + client_agent.train() + local_model = client_agent.get_parameters() + new_global_model, metadata = client_communicator.update_global_model(local_model) + if metadata['status'] == 'DONE': + break + if 'local_steps' in metadata: + client_agent.trainer.train_configs.num_local_steps = metadata['local_steps'] + client_agent.load_parameters(new_global_model) \ No newline at end of file diff --git a/examples/grpc/run_client_iceadmm_2.py b/examples/grpc/run_client_iceadmm_2.py new file mode 100644 index 0000000..1f22d6d --- /dev/null +++ b/examples/grpc/run_client_iceadmm_2.py @@ -0,0 +1,39 @@ +""" +Running the ICEADMM algorithm using gRPC for FL. This example mainly shows +the extendibility of the framework to support custom algorithms. In this case, +the server and clients need to communicate primal and dual states, and a +penalty parameter. In addition, the clients also need to know its relative +sample size for local training purposes. +""" +from omegaconf import OmegaConf +from appfl.agent import APPFLClientAgent +from appfl.comm.grpc import GRPCClientCommunicator + +client_agent_config = OmegaConf.load("config/client_2.yaml") + +client_agent = APPFLClientAgent(client_agent_config=client_agent_config) +client_communicator = GRPCClientCommunicator( + client_id = client_agent.get_id(), + **client_agent_config.comm_configs.grpc_configs, +) + +client_config = client_communicator.get_configuration() +client_agent.load_config(client_config) + +init_global_model = client_communicator.get_global_model(init_model=True) +client_agent.load_parameters(init_global_model) + +# Send the number of local data to the server +sample_size = client_agent.get_sample_size() +client_weight = client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size, sync=True) +client_agent.trainer.set_weight(client_weight["client_weight"]) + +while True: + client_agent.train() + local_model = client_agent.get_parameters() + new_global_model, metadata = client_communicator.update_global_model(local_model) + if metadata['status'] == 'DONE': + break + if 'local_steps' in metadata: + client_agent.trainer.train_configs.num_local_steps = metadata['local_steps'] + client_agent.load_parameters(new_global_model) \ No newline at end of file diff --git a/examples/mpi/run_mpi_iceadmm.py b/examples/mpi/run_mpi_iceadmm.py new file mode 100644 index 0000000..b61ab56 --- /dev/null +++ b/examples/mpi/run_mpi_iceadmm.py @@ -0,0 +1,64 @@ +""" +Running the ICEADMM algorithm using MPI for FL. This example mainly shows +the extendibility of the framework to support custom algorithms. 
diff --git a/examples/mpi/run_mpi_iceadmm.py b/examples/mpi/run_mpi_iceadmm.py
new file mode 100644
index 0000000..b61ab56
--- /dev/null
+++ b/examples/mpi/run_mpi_iceadmm.py
@@ -0,0 +1,64 @@
+"""
+Running the ICEADMM algorithm using MPI for FL. This example mainly shows
+the extensibility of the framework to support custom algorithms. In this case,
+the server and clients need to communicate primal and dual states, and a
+penalty parameter. In addition, each client also needs to know its relative
+sample size for local training purposes.
+"""
+
+import argparse
+from mpi4py import MPI
+from omegaconf import OmegaConf
+from appfl.agent import APPFLClientAgent, APPFLServerAgent
+from appfl.comm.mpi import MPIClientCommunicator, MPIServerCommunicator
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--server_config", type=str, default="config/server_iceadmm.yaml")
+parser.add_argument("--client_config", type=str, default="config/client_1.yaml")
+args = parser.parse_args()
+
+comm = MPI.COMM_WORLD
+rank = comm.Get_rank()
+size = comm.Get_size()
+num_clients = size - 1
+
+if rank == 0:
+    # Load and set the server configurations
+    server_agent_config = OmegaConf.load(args.server_config)
+    server_agent_config.server_configs.scheduler_kwargs.num_clients = num_clients
+    if hasattr(server_agent_config.server_configs.aggregator_kwargs, "num_clients"):
+        server_agent_config.server_configs.aggregator_kwargs.num_clients = num_clients
+    # Create the server agent and communicator
+    server_agent = APPFLServerAgent(server_agent_config=server_agent_config)
+    server_communicator = MPIServerCommunicator(comm, server_agent)
+    # Start the server to serve the clients
+    server_communicator.serve()
+else:
+    # Set the client configurations
+    client_agent_config = OmegaConf.load(args.client_config)
+    client_agent_config.train_configs.logging_id = f'Client{rank}'
+    client_agent_config.data_configs.dataset_kwargs.num_clients = num_clients
+    client_agent_config.data_configs.dataset_kwargs.client_id = rank - 1
+    client_agent_config.data_configs.dataset_kwargs.visualization = (rank == 1)
+    # Create the client agent and communicator
+    client_agent = APPFLClientAgent(client_agent_config=client_agent_config)
+    client_communicator = MPIClientCommunicator(comm, server_rank=0)
+    # Load the configurations and initial global model
+    client_config = client_communicator.get_configuration()
+    client_agent.load_config(client_config)
+    init_global_model = client_communicator.get_global_model(init_model=True)
+    client_agent.load_parameters(init_global_model)
+    # (Specific to ICEADMM) Send the sample size to the server and set the client weight
+    sample_size = client_agent.get_sample_size()
+    client_weight = client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size, sync=True)
+    client_agent.trainer.set_weight(client_weight["client_weight"])
+    # Local training and global model update iterations
+    while True:
+        client_agent.train()
+        local_model = client_agent.get_parameters()
+        new_global_model, metadata = client_communicator.update_global_model(local_model)
+        if metadata['status'] == 'DONE':
+            break
+        if 'local_steps' in metadata:
+            client_agent.trainer.train_configs.num_local_steps = metadata['local_steps']
+        client_agent.load_parameters(new_global_model)
\ No newline at end of file
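Since every rank beyond 0 becomes one client, the example can presumably be launched with a standard MPI launcher, e.g. one server plus two clients via:

mpiexec -n 3 python run_mpi_iceadmm.py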
diff --git a/src/appfl/agent/server.py b/src/appfl/agent/server.py
index bf7b669..f475839 100644
--- a/src/appfl/agent/server.py
+++ b/src/appfl/agent/server.py
@@ -10,7 +10,7 @@
 from appfl.logger import ServerAgentFileLogger
 from concurrent.futures import Future
 from omegaconf import OmegaConf, DictConfig
-from typing import Union, Dict, OrderedDict, Tuple
+from typing import Union, Dict, OrderedDict, Tuple, Optional

 class APPFLServerAgent:
     """
@@ -91,9 +91,55 @@
     def set_sample_size(
         self,
         client_id: Union[int, str],
         sample_size: int,
-    ) -> None:
-        """Set the size of the local dataset of a client."""
+        sync: bool = False,
+        blocking: bool = False,
+    ) -> Optional[Union[Dict, Future]]:
+        """
+        Set the size of the local dataset of a client.
+        :param: client_id: A unique client id for the server to distinguish clients, which can be obtained via `ClientAgent.get_id()`.
+        :param: sample_size: The size of the local dataset of a client.
+        :param: sync: Whether to synchronize the sample size among all clients. If `True`, the method can return the relative weight of the client.
+        :param: blocking: Whether to block the client until the sample sizes of all clients are synchronized.
+            If `True`, the method will return the relative weight of the client.
+            Otherwise, the method may return a `Future` object of the relative weight, which will be resolved
+            when the sample sizes of all clients are synchronized.
+        """
+        if sync:
+            assert (
+                hasattr(self.server_agent_config.server_configs, "num_clients") or
+                hasattr(self.server_agent_config.server_configs.scheduler_kwargs, "num_clients") or
+                hasattr(self.server_agent_config.server_configs.aggregator_kwargs, "num_clients")
+            ), "The number of clients should be set in the server configurations."
+            num_clients = (
+                self.server_agent_config.server_configs.num_clients
+                if hasattr(self.server_agent_config.server_configs, "num_clients")
+                else self.server_agent_config.server_configs.scheduler_kwargs.num_clients
+                if hasattr(self.server_agent_config.server_configs.scheduler_kwargs, "num_clients")
+                else self.server_agent_config.server_configs.aggregator_kwargs.num_clients
+            )
         self.aggregator.set_client_sample_size(client_id, sample_size)
+        if sync:
+            if not hasattr(self, "_client_sample_size"):
+                self._client_sample_size = {}
+                self._client_sample_size_future = {}
+                self._client_sample_size_lock = threading.Lock()
+            with self._client_sample_size_lock:
+                self._client_sample_size[client_id] = sample_size
+                future = Future()
+                self._client_sample_size_future[client_id] = future
+                if len(self._client_sample_size) == num_clients:
+                    total_sample_size = sum(self._client_sample_size.values())
+                    for cid in self._client_sample_size_future:
+                        self._client_sample_size_future[cid].set_result(
+                            {"client_weight": self._client_sample_size[cid] / total_sample_size}
+                        )
+                    self._client_sample_size = {}
+                    self._client_sample_size_future = {}
+            if blocking:
+                return future.result()
+            else:
+                return future
+        return None

     def training_finished(self, internal_check: bool = False) -> bool:
         """Notify the client whether the training is finished."""
diff --git a/src/appfl/aggregator/__init__.py b/src/appfl/aggregator/__init__.py
index 85a49cd..84b0685 100644
--- a/src/appfl/aggregator/__init__.py
+++ b/src/appfl/aggregator/__init__.py
@@ -1,4 +1,5 @@
 from .base_aggregator import BaseAggregator
 from .fedavg_aggregator import FedAvgAggregator
 from .fedasync_aggregator import FedAsyncAggregator
-from .fedcompass_aggregator import FedCompassAggregator
\ No newline at end of file
+from .fedcompass_aggregator import FedCompassAggregator
+from .iceadmm_aggregator import ICEADMMAggregator
\ No newline at end of file
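The sync path above is a simple barrier built on concurrent.futures.Future: every caller registers a Future, and the last caller to arrive resolves all of them with the normalized weights (the code also assumes the threading module is already imported in server.py). A self-contained sketch of the same rendezvous pattern:

import threading
from concurrent.futures import Future

NUM_CLIENTS = 2
sizes, futures, lock = {}, {}, threading.Lock()

def set_sample_size(client_id, sample_size):
    with lock:
        sizes[client_id] = sample_size
        futures[client_id] = Future()
        if len(sizes) == NUM_CLIENTS:  # last arrival resolves everyone
            total = sum(sizes.values())
            for cid, fut in futures.items():
                fut.set_result({"client_weight": sizes[cid] / total})
    return futures[client_id]

f1 = set_sample_size("client1", 6000)
f2 = set_sample_size("client2", 4000)
print(f1.result(), f2.result())  # {'client_weight': 0.6} {'client_weight': 0.4}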
diff --git a/src/appfl/aggregator/iceadmm_aggregator.py b/src/appfl/aggregator/iceadmm_aggregator.py
new file mode 100644
index 0000000..2a3f9cd
--- /dev/null
+++ b/src/appfl/aggregator/iceadmm_aggregator.py
@@ -0,0 +1,107 @@
+import copy
+import torch
+import torch.nn as nn
+from omegaconf import DictConfig
+from collections import OrderedDict
+from typing import Any, Dict, Union
+from appfl.aggregator import BaseAggregator
+
+class ICEADMMAggregator(BaseAggregator):
+    def __init__(
+        self,
+        model: nn.Module,
+        aggregator_config: DictConfig,
+        logger: Any,
+    ):
+        self.model = model
+        self.logger = logger
+        self.aggregator_config = aggregator_config
+        self.named_parameters = set()
+        for name, _ in self.model.named_parameters():
+            self.named_parameters.add(name)
+        self.is_first_iter = True
+        self.penalty = OrderedDict()
+        self.prim_res = 0
+        self.dual_res = 0
+        self.global_state = OrderedDict()
+        self.primal_states = OrderedDict()
+        self.dual_states = OrderedDict()
+        self.primal_states_curr = OrderedDict()
+        self.primal_states_prev = OrderedDict()
+        self.device = self.aggregator_config.device if "device" in self.aggregator_config else "cpu"
+
+    def aggregate(
+        self,
+        local_models: Dict[Union[str, int], Union[Dict, OrderedDict]],
+        **kwargs
+    ) -> Dict:
+        if len(self.primal_states) == 0:
+            self.num_clients = len(local_models)
+            for i in local_models:
+                self.primal_states[i] = OrderedDict()
+                self.dual_states[i] = OrderedDict()
+                self.primal_states_curr[i] = OrderedDict()
+                self.primal_states_prev[i] = OrderedDict()
+
+        global_state = copy.deepcopy(self.model.state_dict())
+
+        for client_id, model in local_models.items():
+            if model is not None:
+                self.primal_states[client_id] = model["primal"]
+                self.dual_states[client_id] = model["dual"]
+                self.penalty[client_id] = model["penalty"]
+
+        # Calculate the primal residual
+        primal_res = 0
+        for client_id in local_models:
+            for name in self.named_parameters:
+                primal_res += torch.sum(torch.square(
+                    global_state[name].to(self.device)
+                    - self.primal_states[client_id][name].to(self.device)
+                ))
+        self.prim_res = torch.sqrt(primal_res).item()
+
+        # Calculate the dual residual
+        dual_res = 0
+        if self.is_first_iter:
+            for client_id in local_models:
+                for name in self.named_parameters:
+                    self.primal_states_curr[client_id][name] = copy.deepcopy(
+                        self.primal_states[client_id][name].to(self.device)
+                    )
+            self.is_first_iter = False
+        else:
+            self.primal_states_prev = copy.deepcopy(self.primal_states_curr)
+            for client_id in local_models:
+                for name in self.named_parameters:
+                    self.primal_states_curr[client_id][name] = copy.deepcopy(
+                        self.primal_states[client_id][name].to(self.device)
+                    )
+            for name in self.named_parameters:
+                res = 0
+                for client_id in local_models:
+                    res += self.penalty[client_id] * (
+                        self.primal_states_prev[client_id][name]
+                        - self.primal_states_curr[client_id][name]
+                    )
+                dual_res += torch.sum(torch.square(res))
+            self.dual_res = torch.sqrt(dual_res).item()
+
+        total_penalty = 0
+        for client_id in local_models:
+            total_penalty += self.penalty[client_id]
+
+        for name, param in self.model.named_parameters():
+            state_param = torch.zeros_like(param)
+            for client_id in local_models:
+                self.primal_states[client_id][name] = self.primal_states[client_id][name].to(self.device)
+                self.dual_states[client_id][name] = self.dual_states[client_id][name].to(self.device)
+                state_param += (self.penalty[client_id] / total_penalty) * self.primal_states[client_id][name] + (1.0 / total_penalty) * self.dual_states[client_id][name]
+            global_state[name] = state_param
+
+        self.model.load_state_dict(global_state)
+        return global_state
+
+    def get_parameters(self, **kwargs) -> Dict:
+        return copy.deepcopy(self.model.state_dict())
\ No newline at end of file
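In the usual ICEADMM notation (z_i for client primal states, y_i for dual states, rho_i for penalties), the aggregation loop above computes, per model parameter,

    w  <-  sum_i (rho_i / sum_j rho_j) * z_i  +  (1 / sum_j rho_j) * sum_i y_i

while prim_res tracks sqrt(sum_i ||w - z_i||^2) and dual_res tracks the penalty-weighted change in the primal states between consecutive rounds.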
diff --git a/src/appfl/comm/grpc/grpc_server_communicator.py b/src/appfl/comm/grpc/grpc_server_communicator.py
index 6eda394..ebc17fd 100644
--- a/src/appfl/comm/grpc/grpc_server_communicator.py
+++ b/src/appfl/comm/grpc/grpc_server_communicator.py
@@ -2,6 +2,7 @@
 import logging
 from typing import Optional
 from omegaconf import OmegaConf
+from concurrent.futures import Future
 from .grpc_communicator_pb2 import *
 from .grpc_communicator_pb2_grpc import *
 from appfl.agent import APPFLServerAgent
@@ -125,11 +126,19 @@ def InvokeCustomAction(self, request, context):
         meta_data = json.loads(request.meta_data)
         if action == "set_sample_size":
             assert "sample_size" in meta_data, "The metadata should contain parameter `sample_size`."
-            datasize = meta_data['sample_size']
-            self.server_agent.set_sample_size(client_id, datasize)
-            response = CustomActionResponse(
-                header=ServerHeader(status=ServerStatus.RUN),
-            )
+            ret_val = self.server_agent.set_sample_size(client_id, **meta_data)
+            if ret_val is None:
+                response = CustomActionResponse(
+                    header=ServerHeader(status=ServerStatus.RUN),
+                )
+            else:
+                if isinstance(ret_val, Future):
+                    ret_val = ret_val.result()
+                results = json.dumps(ret_val)
+                response = CustomActionResponse(
+                    header=ServerHeader(status=ServerStatus.RUN),
+                    results=results,
+                )
             return response
         else:
             raise NotImplementedError(f"Custom action {action} is not implemented.")
diff --git a/src/appfl/comm/mpi/mpi_server_communicator.py b/src/appfl/comm/mpi/mpi_server_communicator.py
index 1a0d93c..a755920 100644
--- a/src/appfl/comm/mpi/mpi_server_communicator.py
+++ b/src/appfl/comm/mpi/mpi_server_communicator.py
@@ -1,3 +1,4 @@
+import time
 import json
 import logging
 from mpi4py import MPI
@@ -21,7 +22,8 @@ def __init__(
         self.comm_size = comm.Get_size()
         self.server_agent = server_agent
         self.logger = logger if logger is not None else self._default_logger()
-        self._response_futures: Dict[int, Future] = {}
+        self._global_model_futures: Dict[int, Future] = {}
+        self._meta_data_futures: Dict[int, Future] = {}

     def serve(self):
         """
@@ -30,17 +32,19 @@ def serve(self):
         self.logger.info(f"Server starting...")
         status = MPI.Status()
         while not self.server_agent.server_terminated():
-            self.comm.probe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
-            source = status.Get_source()
-            tag = status.Get_tag()
-            count = status.Get_count(MPI.BYTE)
-            request_buffer = bytearray(count)
-            self.comm.Recv(request_buffer, source=source, tag=tag)
-            request = byte_to_request(request_buffer)
-            response = self._request_handler(client_id=source, request_tag=tag, request=request)
-            if response is not None:
-                response_bytes = response_to_byte(response)
-                self.comm.Send(response_bytes, dest=source, tag=source)
+            time.sleep(0.1)
+            msg_flag = self.comm.iprobe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
+            if msg_flag:
+                source = status.Get_source()
+                tag = status.Get_tag()
+                count = status.Get_count(MPI.BYTE)
+                request_buffer = bytearray(count)
+                self.comm.Recv(request_buffer, source=source, tag=tag)
+                request = byte_to_request(request_buffer)
+                response = self._request_handler(client_id=source, request_tag=tag, request=request)
+                if response is not None:
+                    response_bytes = response_to_byte(response)
+                    self.comm.Send(response_bytes, dest=source, tag=source)
         self.logger.info(f"Server terminated.")

     def _request_handler(
@@ -119,8 +123,8 @@ def _get_global_model(
                 meta_data=meta_data,
             )
         else:
-            self._response_futures[client_id] = model
-            self._check_response_futures()
+            self._global_model_futures[client_id] = model
+            self._check_global_model_futures()
         return None

     def _update_global_model(
@@ -156,38 +160,63 @@
                 meta_data=meta_data,
            )
        else:
-            self._response_futures[client_id] = global_model
-            self._check_response_futures()
+            self._global_model_futures[client_id] = global_model
+            self._check_global_model_futures()
         return None

     def _invoke_custom_action(
         self,
         client_id: int,
         request: MPITaskRequest,
-    ) -> MPITaskResponse:
+    ) -> Optional[MPITaskResponse]:
         """
         Invoke custom action on the server.
         :param: `client_id`: A unique client ID, which is the rank of the client in MPI (only for logging purpose now)
         :param: `request.meta_data`: JSON serialized metadata dictionary (if needed)
         :return `response.status`: Server status
+        :return `response.meta_data`: JSON serialized metadata dictionary (if needed)
         """
         self.logger.info(f"Received InvokeCustomAction request from client {client_id}")
         meta_data = json.loads(request.meta_data) if len(request.meta_data) > 0 else {}
         assert "action" in meta_data, "The action is not specified in the metadata"
         action = meta_data["action"]
+        del meta_data["action"]
         if action == "set_sample_size":
-            sample_size = meta_data["sample_size"]
-            self.server_agent.set_sample_size(client_id, sample_size)
-            return MPITaskResponse(status=MPIServerStatus.RUN.value)
+            meta_data["blocking"] = False
+            ret_val = self.server_agent.set_sample_size(client_id, **meta_data)
+            if ret_val is None:
+                return MPITaskResponse(status=MPIServerStatus.RUN.value)
+            else:
+                self._meta_data_futures[client_id] = ret_val
+                self._check_meta_data_futures()
+                return None
         else:
             raise NotImplementedError(f"Custom action {action} is not implemented.")

+    def _check_meta_data_futures(self):
+        """
+        Return the updated metadata to the client if the metadata `Future` object is available.
+        """
+        delete_keys = []
+        for client_id, future in self._meta_data_futures.items():
+            if future.done():
+                meta_data = future.result()
+                response = MPITaskResponse(
+                    status=MPIServerStatus.RUN.value,
+                    meta_data=json.dumps(meta_data),
+                )
+                response_bytes = response_to_byte(response)
+                self.comm.Send(response_bytes, dest=client_id, tag=client_id)
+                delete_keys.append(client_id)
+        for key in delete_keys:
+            del self._meta_data_futures[key]
+
-    def _check_response_futures(self):
+    def _check_global_model_futures(self):
         """
         Return the updated global model to the client if the global model `Future` object is available.
         """
         delete_keys = []
-        for client_id, future in self._response_futures.items():
+        for client_id, future in self._global_model_futures.items():
             if future.done():
                 global_model = future.result()
                 if isinstance(global_model, tuple):
@@ -206,7 +235,7 @@
                 self.comm.Send(response_bytes, dest=client_id, tag=client_id)
                 delete_keys.append(client_id)
         for key in delete_keys:
-            del self._response_futures[key]
+            del self._global_model_futures[key]

     def _default_logger(self):
         """Create a default logger for the gRPC server if no logger provided."""
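Unlike the gRPC communicator, which handles each client on its own thread and can afford to block on ret_val.result(), the MPI server is a single loop: blocking inside _invoke_custom_action would deadlock, because the remaining clients' set_sample_size requests could then never be received. It therefore forces blocking=False, parks the Future, and answers once it resolves. A minimal sketch of this deferred-reply pattern (send_to is a hypothetical stand-in for comm.Send):

from concurrent.futures import Future

pending = {}  # client rank -> unresolved reply

def send_to(rank, payload):  # hypothetical transport helper
    print(f"reply to client {rank}: {payload}")

def reply_or_defer(rank, fut):
    if fut.done():
        send_to(rank, fut.result())  # answer immediately
    else:
        pending[rank] = fut          # park it for a later pass

def flush_pending():                 # called again on later loop iterations
    for rank in [r for r, f in pending.items() if f.done()]:
        send_to(rank, pending.pop(rank).result())

fut = Future()
reply_or_defer(1, fut)                      # not done yet -> parked
fut.set_result({"client_weight": 0.5})
flush_pending()                             # reply to client 1: {'client_weight': 0.5}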
""" delete_keys = [] - for client_id, future in self._response_futures.items(): + for client_id, future in self._global_model_futures.items(): if future.done(): global_model = future.result() if isinstance(global_model, tuple): @@ -206,7 +235,7 @@ def _check_response_futures(self): self.comm.Send(response_bytes, dest=client_id, tag=client_id) delete_keys.append(client_id) for key in delete_keys: - del self._response_futures[key] + del self._global_model_futures[key] def _default_logger(self): """Create a default logger for the gRPC server if no logger provided.""" diff --git a/src/appfl/compressor/compressor.py b/src/appfl/compressor/compressor.py index 38fcade..08d592e 100644 --- a/src/appfl/compressor/compressor.py +++ b/src/appfl/compressor/compressor.py @@ -80,9 +80,12 @@ def compress_model( num_lossy_elements = 0 compressed_models = OrderedDict() for key, weights in model.items(): - comprsessed_weights, lossy_elements = self._compress_weights(weights) - compressed_models[key] = comprsessed_weights - lossy_elements += lossy_elements + if isinstance(weights, dict) or isinstance(weights, OrderedDict): + comprsessed_weights, lossy_elements = self._compress_weights(weights) + compressed_models[key] = comprsessed_weights + lossy_elements += lossy_elements + else: + compressed_models[key] = weights else: compressed_models, num_lossy_elements = self._compress_weights(model) return pickle.dumps(compressed_models) @@ -123,7 +126,10 @@ def decompress_model( if is_nested: decompressed_model = OrderedDict() for key, value in compressed_model.items(): - decompressed_model[key] = self._decompress_model(value, model) + if isinstance(value, dict) or isinstance(value, OrderedDict): + decompressed_model[key] = self._decompress_model(value, model) + else: + decompressed_model[key] = value else: decompressed_model = self._decompress_model(compressed_model, model) return decompressed_model diff --git a/src/appfl/trainer/__init__.py b/src/appfl/trainer/__init__.py index a3aeb0f..dcca3c0 100644 --- a/src/appfl/trainer/__init__.py +++ b/src/appfl/trainer/__init__.py @@ -1,2 +1,3 @@ from .base_trainer import BaseTrainer -from .naive_trainer import NaiveTrainer \ No newline at end of file +from .naive_trainer import NaiveTrainer +from .iceadmm_trainer import ICEADMMTrainer \ No newline at end of file diff --git a/src/appfl/trainer/iceadmm_trainer.py b/src/appfl/trainer/iceadmm_trainer.py new file mode 100644 index 0000000..62cccf5 --- /dev/null +++ b/src/appfl/trainer/iceadmm_trainer.py @@ -0,0 +1,355 @@ +import copy +import time +import torch +import importlib +import numpy as np +import torch.nn as nn +from omegaconf import DictConfig +from collections import OrderedDict +from torch.utils.data import DataLoader +from typing import Any, Optional, Tuple +from appfl.trainer.base_trainer import BaseTrainer +from appfl.privacy import laplace_mechanism_output_perturb + +class ICEADMMTrainer(BaseTrainer): + """ + ICEADMM Trainer: + Local trainer for the ICEADMM algorithm. + This trainer must be used with the ICEADMMAggregator. 
+ """ + def __init__( + self, + model: Optional[nn.Module]=None, + loss_fn: Optional[nn.Module]=None, + metric: Optional[Any]=None, + train_dataset: Optional[DataLoader]=None, + val_dataset: Optional[DataLoader]=None, + train_configs: DictConfig = DictConfig({}), + logger: Optional[Any]=None, + **kwargs + ): + super().__init__( + model=model, + loss_fn=loss_fn, + metric=metric, + train_dataset=train_dataset, + val_dataset=val_dataset, + train_configs=train_configs, + logger=logger, + **kwargs + ) + if not hasattr(self.train_configs, "device"): + self.train_configs.device = "cpu" + self.train_dataloader = DataLoader( + self.train_dataset, + batch_size=self.train_configs.get("train_batch_size", 32), + shuffle=self.train_configs.get("train_data_shuffle", True), + num_workers=self.train_configs.get("num_workers", 0), + ) + self.val_dataloader = DataLoader( + self.val_dataset, + batch_size=self.train_configs.get("val_batch_size", 32), + shuffle=self.train_configs.get("val_data_shuffle", False), + num_workers=self.train_configs.get("num_workers", 0), + ) if self.val_dataset is not None else None + + self.penalty = self.train_configs.get("init_penalty", 500.0) + self.proximity = self.train_configs.get("init_proximity", 0) + self.is_first_iter = True + + # At initial, (1) primal_states = global_state, (2) dual_states = 0 + self.primal_states = OrderedDict() + self.dual_states = OrderedDict() + self.primal_states_curr = OrderedDict() + self.primal_states_prev = OrderedDict() + self.named_parameters = set() + + for name, param in self.model.named_parameters(): + self.named_parameters.add(name) + self.primal_states[name] = param.data + self.dual_states[name] = torch.zeros_like(param.data) + + def train(self): + self._sanity_check() + self.model.train() + self.model.to(self.train_configs.device) + do_validation = self.train_configs.get("do_validation", False) and self.val_dataloader is not None + do_pre_validation = self.train_configs.get("do_pre_validation", False) and do_validation + + """Set up logging title""" + if self.round == 0: + title = ( + ["Round", "Time", "Train Loss", "Train Accuracy"] + if not do_validation + else ( + ["Round", "Pre Val?", "Time", "Train Loss", "Train Accuracy", "Val Loss", "Val Accuracy"] + if do_pre_validation + else ["Round", "Time", "Train Loss", "Train Accuracy", "Val Loss", "Val Accuracy"] + ) + ) + if self.train_configs.mode == "epoch": + title.insert(1, "Epoch") + self.logger.log_title(title) + + pre_val_interval = self.train_configs.get("pre_validation_interval", 1) + if do_pre_validation and (self.round + 1) % pre_val_interval == 0: + val_loss, val_accuracy = self._validate() + content = [self.round, "Y", " ", " ", " ", val_loss, val_accuracy] + if self.train_configs.mode == "epoch": + content.insert(1, 0) + self.logger.log_content(content) + + optim_module = importlib.import_module("torch.optim") + assert hasattr(optim_module, self.train_configs.optim), f"Optimizer {self.train_configs.optim} not found in torch.optim" + optimizer = getattr(optim_module, self.train_configs.optim)(self.model.parameters(), **self.train_configs.optim_args) + + """ Inputs for the local model update """ + global_state = copy.deepcopy(self.model.state_dict()) + + """Adaptive Penalty (Residual Balancing)""" + if not hasattr(self.train_configs, "residual_balancing"): + self.train_configs.residual_balancing = DictConfig({}) + if getattr(self.train_configs.residual_balancing, "res_on", False): + prim_res = self._primal_residual_at_client(global_state) + dual_res = 
+
+    def train(self):
+        self._sanity_check()
+        self.model.train()
+        self.model.to(self.train_configs.device)
+        do_validation = self.train_configs.get("do_validation", False) and self.val_dataloader is not None
+        do_pre_validation = self.train_configs.get("do_pre_validation", False) and do_validation
+
+        """Set up logging title"""
+        if self.round == 0:
+            title = (
+                ["Round", "Time", "Train Loss", "Train Accuracy"]
+                if not do_validation
+                else (
+                    ["Round", "Pre Val?", "Time", "Train Loss", "Train Accuracy", "Val Loss", "Val Accuracy"]
+                    if do_pre_validation
+                    else ["Round", "Time", "Train Loss", "Train Accuracy", "Val Loss", "Val Accuracy"]
+                )
+            )
+            if self.train_configs.mode == "epoch":
+                title.insert(1, "Epoch")
+            self.logger.log_title(title)
+
+        pre_val_interval = self.train_configs.get("pre_validation_interval", 1)
+        if do_pre_validation and (self.round + 1) % pre_val_interval == 0:
+            val_loss, val_accuracy = self._validate()
+            content = [self.round, "Y", " ", " ", " ", val_loss, val_accuracy]
+            if self.train_configs.mode == "epoch":
+                content.insert(1, 0)
+            self.logger.log_content(content)
+
+        optim_module = importlib.import_module("torch.optim")
+        assert hasattr(optim_module, self.train_configs.optim), f"Optimizer {self.train_configs.optim} not found in torch.optim"
+        optimizer = getattr(optim_module, self.train_configs.optim)(self.model.parameters(), **self.train_configs.optim_args)
+
+        """Inputs for the local model update"""
+        global_state = copy.deepcopy(self.model.state_dict())
+
+        """Adaptive Penalty (Residual Balancing)"""
+        if not hasattr(self.train_configs, "residual_balancing"):
+            self.train_configs.residual_balancing = DictConfig({})
+        if getattr(self.train_configs.residual_balancing, "res_on", False):
+            prim_res = self._primal_residual_at_client(global_state)
+            dual_res = self._dual_residual_at_client()
+            self._residual_balancing(prim_res, dual_res)
+
+        if self.train_configs.mode == "epoch":
+            for epoch in range(self.train_configs.num_local_epochs):
+                start_time = time.time()
+                train_loss, target_true, target_pred = 0, [], []
+                for data, target in self.train_dataloader:
+                    loss, pred, label = self._train_batch(optimizer, data, target, global_state)
+                    train_loss += loss
+                    target_true.append(label)
+                    target_pred.append(pred)
+                train_loss /= len(self.train_dataloader)
+                target_true, target_pred = np.concatenate(target_true), np.concatenate(target_pred)
+                train_accuracy = float(self.metric(target_true, target_pred))
+                if do_validation:
+                    val_loss, val_accuracy = self._validate()
+                per_epoch_time = time.time() - start_time
+                self.logger.log_content(
+                    [self.round, epoch, per_epoch_time, train_loss, train_accuracy]
+                    if not do_validation
+                    else (
+                        [self.round, epoch, per_epoch_time, train_loss, train_accuracy, val_loss, val_accuracy]
+                        if not do_pre_validation
+                        else [self.round, epoch, 'N', per_epoch_time, train_loss, train_accuracy, val_loss, val_accuracy]
+                    )
+                )
+        else:
+            start_time = time.time()
+            data_iter = iter(self.train_dataloader)
+            train_loss, target_true, target_pred = 0, [], []
+            for _ in range(self.train_configs.num_local_steps):
+                try:
+                    data, target = next(data_iter)
+                except StopIteration:
+                    data_iter = iter(self.train_dataloader)
+                    data, target = next(data_iter)
+                loss, pred, label = self._train_batch(optimizer, data, target, global_state)
+                train_loss += loss
+                target_true.append(label)
+                target_pred.append(pred)
+            train_loss /= len(self.train_dataloader)
+            target_true, target_pred = np.concatenate(target_true), np.concatenate(target_pred)
+            train_accuracy = float(self.metric(target_true, target_pred))
+            if do_validation:
+                val_loss, val_accuracy = self._validate()
+            per_step_time = time.time() - start_time
+            self.logger.log_content(
+                [self.round, per_step_time, train_loss, train_accuracy]
+                if not do_validation
+                else (
+                    [self.round, per_step_time, train_loss, train_accuracy, val_loss, val_accuracy]
+                    if not do_pre_validation
+                    else [self.round, 'N', per_step_time, train_loss, train_accuracy, val_loss, val_accuracy]
+                )
+            )
+
+        self.round += 1
+
+        for name, param in self.model.named_parameters():
+            param.data = self.primal_states[name].to(self.train_configs.device)
+        if self.train_configs.get("use_dp", False):
+            sensitivity = 2.0 * self.train_configs.clip_value / self.penalty
+            self._model_state = laplace_mechanism_output_perturb(
+                self.model,
+                sensitivity,
+                self.train_configs.epsilon
+            )
+        else:
+            self._model_state = copy.deepcopy(self.model.state_dict())
+
+        """Move to CPU for communication"""
+        if self.train_configs.get("device", "cpu") == "cuda":
+            for k in self._model_state:
+                self._model_state[k] = self._model_state[k].cpu()
+            for name in self.named_parameters:
+                self.dual_states[name] = self.dual_states[name].cpu()
+
+        self.model_state = OrderedDict()
+        self.model_state["primal"] = self._model_state
+        self.model_state["dual"] = self.dual_states
+        self.model_state["penalty"] = self.penalty
+
+    def get_parameters(self) -> OrderedDict:
+        assert hasattr(self, "model_state"), "Please make sure the model has been trained before getting its parameters"
+        return self.model_state
+
+    def set_weight(self, weight=1.0):
+        """Set the weight of the client"""
+        self.weight = weight
+
+ """ + assert hasattr(self.train_configs, "mode"), "Training mode must be specified" + assert self.train_configs.mode in ["epoch", "step"], "Training mode must be either 'epoch' or 'step'" + if self.train_configs.mode == "epoch": + assert hasattr(self.train_configs, "num_local_epochs"), "Number of local epochs must be specified" + else: + assert hasattr(self.train_configs, "num_local_steps"), "Number of local steps must be specified" + assert hasattr(self, "weight"), "You must set the weight of the client before training. Use `set_weight` method." + if getattr(self.train_configs, "clip_grad", False) or getattr(self.train_configs, "use_dp", False): + assert hasattr(self.train_configs, "clip_value"), "Gradient clipping value must be specified" + assert hasattr(self.train_configs, "clip_norm"), "Gradient clipping norm must be specified" + if getattr(self.train_configs, "use_dp", False): + assert hasattr(self.train_configs, "epsilon"), "Privacy budget (epsilon) must be specified" + + def _primal_residual_at_client(self, global_state) -> float: + """ + Calculate primal residual. + :param global_state: global state - input for the local model update + :return: primal residual + """ + primal_res = 0 + for name in self.named_parameters: + primal_res += torch.sum( + torch.square( + global_state[name].to(self.train_configs.device) + - self.primal_states[name].to(self.train_configs.device) + ) + ) + primal_res = torch.sqrt(primal_res).item() + return primal_res + + def _dual_residual_at_client(self) -> float: + """ + Calculate dual residual. + :return: dual residual + """ + dual_res = 0 + if self.is_first_iter: + self.primal_states_curr = self.primal_states + self.is_first_iter = False + else: + self.primal_states_prev = self.primal_states_curr + self.primal_states_curr = self.primal_states + for name in self.named_parameters: + res = self.penalty * ( + self.primal_states_prev[name] - self.primal_states_curr[name] + ) + dual_res += torch.sum(torch.square(res)) + dual_res = torch.sqrt(dual_res).item() + return dual_res + + def _residual_balancing(self, prim_res, dual_res): + if prim_res > self.train_configs.residual_balancing.mu * dual_res: + self.penalty = self.penalty * self.train_configs.residual_balancing.tau + if dual_res > self.train_configs.residual_balancing.mu * prim_res: + self.penalty = self.penalty / self.train_configs.residual_balancing.tau + + def _train_batch(self, optimizer, data, target, global_state): + """ + Train the model for one batch of data + :param optimizer: torch optimizer + :param data: input data + :param target: target label + :param global_state: global model state + :return: loss, prediction, label + """ + """Load primal states to the model""" + for name, param in self.model.named_parameters(): + param.data = self.primal_states[name].to(self.train_configs.device) + + """Adaptive Penalty (Residual Balancing)""" + if ( + getattr(self.train_configs.residual_balancing, "res_on", False) and + getattr(self.train_configs.residual_balancing, "res_on_every_update", False) + ): + prim_res = self._primal_residual_at_client(global_state) + dual_res = self._dual_residual_at_client() + self._residual_balancing(prim_res, dual_res) + + """Train the model""" + data, target = data.to(self.train_configs.device), target.to(self.train_configs.device) + if getattr(self.train_configs, "accum_grad", False) == False: + optimizer.zero_grad() + output = self.model(data) + loss = self.loss_fn(output, target) + loss.backward() + optimizer.step() + + """Gradient Clipping""" + if 
+    def _train_batch(self, optimizer, data, target, global_state):
+        """
+        Train the model for one batch of data.
+        :param optimizer: torch optimizer
+        :param data: input data
+        :param target: target label
+        :param global_state: global model state
+        :return: loss, prediction, label
+        """
+        """Load primal states to the model"""
+        for name, param in self.model.named_parameters():
+            param.data = self.primal_states[name].to(self.train_configs.device)
+
+        """Adaptive Penalty (Residual Balancing)"""
+        if (
+            getattr(self.train_configs.residual_balancing, "res_on", False) and
+            getattr(self.train_configs.residual_balancing, "res_on_every_update", False)
+        ):
+            prim_res = self._primal_residual_at_client(global_state)
+            dual_res = self._dual_residual_at_client()
+            self._residual_balancing(prim_res, dual_res)
+
+        """Train the model"""
+        data, target = data.to(self.train_configs.device), target.to(self.train_configs.device)
+        if not getattr(self.train_configs, "accum_grad", False):
+            optimizer.zero_grad()
+        output = self.model(data)
+        loss = self.loss_fn(output, target)
+        loss.backward()
+        optimizer.step()
+
+        """Gradient Clipping"""
+        if getattr(self.train_configs, "clip_grad", False) or getattr(self.train_configs, "use_dp", False):
+            torch.nn.utils.clip_grad_norm_(
+                self.model.parameters(),
+                self.train_configs.clip_value,
+                norm_type=self.train_configs.clip_norm,
+            )
+
+        """Update primal and dual states"""
+        coefficient = 1
+        if getattr(self.train_configs, "coeff_grad", False):
+            coefficient = self.weight * len(target) / len(self.train_dataloader.dataset)
+        self._iceadmm_step(coefficient, global_state)
+
+        return loss.item(), output.detach().cpu().numpy(), target.detach().cpu().numpy()
+
+    def _iceadmm_step(self, coefficient, global_state):
+        """
+        Update the primal and dual states.
+        """
+        for name, param in self.model.named_parameters():
+            self.primal_states[name] = self.primal_states[name].to(self.train_configs.device)
+            self.dual_states[name] = self.dual_states[name].to(self.train_configs.device)
+            global_state[name] = global_state[name].to(self.train_configs.device)
+
+            grad = param.grad * coefficient
+            """Update primal"""
+            self.primal_states[name] = self.primal_states[name] - (
+                self.penalty * (self.primal_states[name] - global_state[name])
+                + grad
+                + self.dual_states[name]
+            ) / (self.weight * self.proximity + self.penalty)
+            """Update dual"""
+            self.dual_states[name] = self.dual_states[name] + self.penalty * (
+                self.primal_states[name] - global_state[name]
+            )
+
+    def _validate(self) -> Tuple[float, float]:
+        """
+        Validate the model.
+        :return: loss, accuracy
+        """
+        device = self.train_configs.get("device", "cpu")
+        self.model.eval()
+        val_loss = 0
+        with torch.no_grad():
+            target_pred, target_true = [], []
+            for data, target in self.val_dataloader:
+                data, target = data.to(device), target.to(device)
+                output = self.model(data)
+                val_loss += self.loss_fn(output, target).item()
+                target_true.append(target.detach().cpu().numpy())
+                target_pred.append(output.detach().cpu().numpy())
+        val_loss /= len(self.val_dataloader)
+        val_accuracy = float(self.metric(np.concatenate(target_true), np.concatenate(target_pred)))
+        self.model.train()
+        return val_loss, val_accuracy
\ No newline at end of file
diff --git a/src/appfl/trainer/naive_trainer.py b/src/appfl/trainer/naive_trainer.py
index b4e2aa4..2eb160f 100644
--- a/src/appfl/trainer/naive_trainer.py
+++ b/src/appfl/trainer/naive_trainer.py
@@ -79,7 +79,7 @@ def train(self):
         # Set up logging title
         if self.round == 0:
             title = (
-                ["Round", "Pre Val?" "Time", "Train Loss", "Train Accuracy"]
+                ["Round", "Time", "Train Loss", "Train Accuracy"]
                 if not do_validation
                 else (
                     ["Round", "Pre Val?", "Time", "Train Loss", "Train Accuracy", "Val Loss", "Val Accuracy"]
@@ -93,7 +93,10 @@

         if do_pre_validation:
             val_loss, val_accuracy = self._validate()
-            self.logger.log_content([self.round, "Y", " ", " ", " ", val_loss, val_accuracy])
+            content = [self.round, "Y", " ", " ", " ", val_loss, val_accuracy]
+            if self.train_configs.mode == "epoch":
+                content.insert(1, 0)
+            self.logger.log_content(content)

         # Start training
         optim_module = importlib.import_module("torch.optim")
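The first naive_trainer.py hunk fixes a missing comma: Python implicitly concatenates adjacent string literals, so the no-validation header silently collapsed two column names into one:

title = ["Round", "Pre Val?" "Time", "Train Loss", "Train Accuracy"]
print(title)  # ['Round', 'Pre Val?Time', 'Train Loss', 'Train Accuracy']

Since the no-validation table has no pre-validation column, the fix drops "Pre Val?" from that branch entirely.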