diff --git a/examples/config/server_fedbuff.yaml b/examples/config/server_fedbuff.yaml new file mode 100644 index 0000000..4f8f1fa --- /dev/null +++ b/examples/config/server_fedbuff.yaml @@ -0,0 +1,75 @@ +client_configs: + train_configs: + # Local trainer + trainer: "NaiveTrainer" + mode: "step" + num_local_steps: 100 + optim: "Adam" + optim_args: + lr: 0.001 + # Loss function + loss_fn_path: "./loss/celoss.py" + loss_fn_name: "CELoss" + # Client validation + do_validation: True + do_pre_validation: True + metric_path: "./metric/acc.py" + metric_name: "accuracy" + # Differential privacy + use_dp: False + epsilon: 1 + clip_grad: False + clip_value: 1 + clip_norm: 1 + # Data format + send_gradient: True + # Data loader + train_batch_size: 64 + val_batch_size: 64 + train_data_shuffle: True + val_data_shuffle: False + + model_configs: + model_path: "./model/cnn.py" + model_name: "CNN" + model_kwargs: + num_channel: 1 + num_classes: 10 + num_pixel: 28 + + comm_configs: + compressor_configs: + enable_compression: False + # Used if enable_compression is True + lossy_compressor: "SZ2" + lossless_compressor: "blosc" + error_bounding_mode: "REL" + error_bound: 1e-3 + flat_model_dtype: "np.float32" + param_cutoff: 1024 + +server_configs: + scheduler: "AsyncScheduler" + scheduler_kwargs: + num_clients: 2 + same_init_model: True + aggregator: "FedBuffAggregator" + aggregator_kwargs: + client_weights_mode: "equal" + num_clients: 2 + staleness_fn: "polynomial" + staleness_fn_kwargs: + a: 0.5 + alpha: 0.9 + gradient_based: True + K: 3 + device: "cpu" + num_global_epochs: 20 + server_validation: False + logging_output_dirname: "./output" + logging_output_filename: "result" + comm_configs: + grpc_configs: + server_uri: localhost:50051 + max_message_size: 1048576 + use_ssl: False \ No newline at end of file diff --git a/examples/config/server_iiadmm.yaml b/examples/config/server_iiadmm.yaml new file mode 100644 index 0000000..edf3564 --- /dev/null +++ b/examples/config/server_iiadmm.yaml @@ -0,0 +1,76 @@ +client_configs: + train_configs: + # Local trainer + trainer: "IIADMMTrainer" + mode: "step" + num_local_steps: 100 + optim: "Adam" + optim_args: + lr: 0.001 + # Algorithm specific + accum_grad: True + coeff_grad: False + init_penalty: 100.0 + residual_balancing: + res_on: False + res_on_every_update: False + tau: 1.1 + mu: 10 + # Loss function + loss_fn_path: "./loss/celoss.py" + loss_fn_name: "CELoss" + # Client validation + do_validation: True + do_pre_validation: True + pre_validation_interval: 1 + metric_path: "./metric/acc.py" + metric_name: "accuracy" + # Differential privacy + use_dp: False + epsilon: 1 + clip_grad: False + clip_value: 1 + clip_norm: 1 + # Data loader + train_batch_size: 64 + val_batch_size: 64 + train_data_shuffle: True + val_data_shuffle: False + + model_configs: + model_path: "./model/cnn.py" + model_name: "CNN" + model_kwargs: + num_channel: 1 + num_classes: 10 + num_pixel: 28 + + comm_configs: + compressor_configs: + enable_compression: False + # Used if enable_compression is True + lossy_compressor: "SZ2" + lossless_compressor: "blosc" + error_bounding_mode: "REL" + error_bound: 1e-3 + flat_model_dtype: "np.float32" + param_cutoff: 1024 + +server_configs: + scheduler: "SyncScheduler" + scheduler_kwargs: + num_clients: 2 + same_init_model: True + aggregator: "IIADMMAggregator" + aggregator_kwargs: + num_clients: 2 + device: "cpu" + num_global_epochs: 10 + server_validation: False + logging_output_dirname: "./output" + logging_output_filename: "result" + comm_configs: + grpc_configs: + server_uri: localhost:50051 + max_message_size: 1048576 + use_ssl: False \ No newline at end of file diff --git a/examples/mpi/run_mpi_iceadmm.py b/examples/mpi/run_mpi_admm.py similarity index 89% rename from examples/mpi/run_mpi_iceadmm.py rename to examples/mpi/run_mpi_admm.py index b61ab56..180406f 100644 --- a/examples/mpi/run_mpi_iceadmm.py +++ b/examples/mpi/run_mpi_admm.py @@ -1,9 +1,11 @@ """ -Running the ICEADMM algorithm using MPI for FL. This example mainly shows +Running the ADMM-based algorithm using MPI for FL. This example mainly shows the extendibility of the framework to support custom algorithms. In this case, the server and clients need to communicate primal and dual states, and a penalty parameter. In addition, the clients also need to know its relative sample size for local training purposes. +mpiexec -n 6 python mpi/run_mpi_admm.py --server_config config/server_iiadmm.yaml +mpiexec -n 6 python mpi/run_mpi_admm.py --server_config config/server_iceadmm.yaml """ import argparse @@ -48,7 +50,7 @@ client_agent.load_config(client_config) init_global_model = client_communicator.get_global_model(init_model=True) client_agent.load_parameters(init_global_model) - # (Specific to ICEADMM) Send the sample size to the server and set the client weight + # (Specific to ICEADMM and IIADMM) Send the sample size to the server and set the client weight sample_size = client_agent.get_sample_size() client_weight = client_communicator.invoke_custom_action(action='set_sample_size', sample_size=sample_size, sync=True) client_agent.trainer.set_weight(client_weight["client_weight"]) diff --git a/src/appfl/aggregator/__init__.py b/src/appfl/aggregator/__init__.py index 206cff3..33ba527 100644 --- a/src/appfl/aggregator/__init__.py +++ b/src/appfl/aggregator/__init__.py @@ -5,5 +5,7 @@ from .fedyogi_aggregator import FedYogiAggregator from .fedadagrad_aggregator import FedAdagradAggregator from .fedasync_aggregator import FedAsyncAggregator +from .fedbuff_aggregator import FedBuffAggregator from .fedcompass_aggregator import FedCompassAggregator +from .iiadmm_aggregator import IIADMMAggregator from .iceadmm_aggregator import ICEADMMAggregator \ No newline at end of file diff --git a/src/appfl/aggregator/fedasync_aggregator.py b/src/appfl/aggregator/fedasync_aggregator.py index 85edbe0..ea505a3 100644 --- a/src/appfl/aggregator/fedasync_aggregator.py +++ b/src/appfl/aggregator/fedasync_aggregator.py @@ -5,6 +5,10 @@ from typing import Union, Dict, OrderedDict, Any class FedAsyncAggregator(BaseAggregator): + """ + FedAsync Aggregator class for Federated Learning. + For more details, check paper: https://arxiv.org/pdf/1903.03934.pdf + """ def __init__( self, model: torch.nn.Module, @@ -26,12 +30,33 @@ def __init__( self.alpha = self.aggregator_config.get("alpha", 0.9) self.global_step = 0 self.client_step = {} + self.step = {} + + def get_parameters(self, **kwargs) -> Dict: + return copy.deepcopy(self.model.state_dict()) def aggregate(self, client_id: Union[str, int], local_model: Union[Dict, OrderedDict], **kwargs) -> Dict: + global_state = copy.deepcopy(self.model.state_dict()) + + self.compute_steps(client_id, local_model) + + for name in self.model.state_dict(): + if name not in self.named_parameters: + global_state[name] = local_model[name] + else: + global_state[name] += self.step[name] + self.model.load_state_dict(global_state) + self.global_step += 1 + self.client_step[client_id] = self.global_step + return global_state + + def compute_steps(self, client_id: Union[str, int], local_model: Union[Dict, OrderedDict],): + """ + Compute changes to the global model after the aggregation. + """ if client_id not in self.client_step: self.client_step[client_id] = 0 gradient_based = self.aggregator_config.get("gradient_based", False) - global_state = copy.deepcopy(self.model.state_dict()) if ( self.client_weights_mode == "sample_size" and hasattr(self, "client_sample_size") and @@ -41,21 +66,11 @@ def aggregate(self, client_id: Union[str, int], local_model: Union[Dict, Ordered else: weight = 1.0 / self.aggregator_config.get("num_clients", 1) alpha_t = self.alpha * self.staleness_fn(self.global_step - self.client_step[client_id]) * weight - for name in self.model.state_dict(): - if name in self.named_parameters: - if gradient_based: - global_state[name] -= local_model[name] * alpha_t - else: - global_state[name] = (1-alpha_t) * global_state[name] + alpha_t * local_model[name] - else: - global_state[name] = local_model[name] - self.model.load_state_dict(global_state) - self.global_step += 1 - self.client_step[client_id] = self.global_step - return global_state - - def get_parameters(self, **kwargs) -> Dict: - return copy.deepcopy(self.model.state_dict()) + for name in self.named_parameters: + self.step[name] = ( + alpha_t * (-local_model[name]) if gradient_based + else alpha_t * (local_model[name] - self.model.state_dict()[name]) + ) def __staleness_fn_factory(self, staleness_fn_name, **kwargs): if staleness_fn_name == "constant": @@ -69,4 +84,3 @@ def __staleness_fn_factory(self, staleness_fn_name, **kwargs): return lambda u: 1 if u <= b else 1.0/ (a * (u - b) + 1.0) else: raise NotImplementedError - \ No newline at end of file diff --git a/src/appfl/aggregator/fedavg_aggregator.py b/src/appfl/aggregator/fedavg_aggregator.py index 076fb60..78ace70 100644 --- a/src/appfl/aggregator/fedavg_aggregator.py +++ b/src/appfl/aggregator/fedavg_aggregator.py @@ -49,7 +49,7 @@ def compute_steps(self, local_models: Dict[Union[str, int], Union[Dict, OrderedD """ Compute the changes to the global model after the aggregation. """ - for name in self.model.state_dict(): + for name in self.named_parameters: self.step[name] = torch.zeros_like(self.model.state_dict()[name]) for client_id, model in local_models.items(): if ( diff --git a/src/appfl/aggregator/fedbuff_aggregator.py b/src/appfl/aggregator/fedbuff_aggregator.py new file mode 100644 index 0000000..b0a0fec --- /dev/null +++ b/src/appfl/aggregator/fedbuff_aggregator.py @@ -0,0 +1,68 @@ +import copy +import torch +from omegaconf import DictConfig +from appfl.aggregator import FedAsyncAggregator +from typing import Union, Dict, OrderedDict, Any + +class FedBuffAggregator(FedAsyncAggregator): + """ + FedBuff Aggregator class for Federated Learning. + For more details, check paper: https://proceedings.mlr.press/v151/nguyen22b/nguyen22b.pdf + """ + def __init__( + self, + model: torch.nn.Module, + aggregator_config: DictConfig, + logger: Any + ): + super().__init__(model, aggregator_config, logger) + self.buff_size = 0 + self.K = self.aggregator_config.K + + def aggregate(self, client_id: Union[str, int], local_model: Union[Dict, OrderedDict], **kwargs) -> Dict: + global_state = copy.deepcopy(self.model.state_dict()) + + self.compute_steps(client_id, local_model) + self.buff_size += 1 + if self.buff_size == self.K: + for name in self.model.state_dict(): + if name not in self.named_parameters: + global_state[name] = torch.div(self.step[name], self.K) + else: + global_state[name] += self.step[name] + self.model.load_state_dict(global_state) + self.global_step += 1 + self.buff_size = 0 + + self.client_step[client_id] = self.global_step + return global_state + + def compute_steps(self, client_id: Union[str, int], local_model: Union[Dict, OrderedDict],): + """ + Compute changes to the global model after the aggregation. + """ + if self.buff_size == 0: + for name in self.model.state_dict(): + self.step[name] = torch.zeros_like(self.model.state_dict()[name]) + + if client_id not in self.client_step: + self.client_step[client_id] = 0 + gradient_based = self.aggregator_config.get("gradient_based", False) + if ( + self.client_weights_mode == "sample_size" and + hasattr(self, "client_sample_size") and + client_id in self.client_sample_size + ): + weight = self.client_sample_size[client_id] / sum(self.client_sample_size.values()) + else: + weight = 1.0 / self.aggregator_config.get("num_clients", 1) + alpha_t = self.alpha * self.staleness_fn(self.global_step - self.client_step[client_id]) * weight + + for name in self.model.state_dict(): + if name in self.named_parameters: + self.step[name] += ( + alpha_t * (-local_model[name]) if gradient_based + else alpha_t * (local_model[name] - self.model.state_dict()[name]) + ) + else: + self.step[name] += local_model[name] diff --git a/src/appfl/aggregator/fedcompass_aggregator.py b/src/appfl/aggregator/fedcompass_aggregator.py index 6f8fe6f..36933ee 100644 --- a/src/appfl/aggregator/fedcompass_aggregator.py +++ b/src/appfl/aggregator/fedcompass_aggregator.py @@ -6,8 +6,8 @@ class FedCompassAggregator(BaseAggregator): """ - Aggregator for `FedCompass` semi-asynchronous federated learning algorithm. - Paper reference: https://arxiv.org/abs/2309.14675 + FedCompass semi-asynchronous federated learning algorithm. + For more details, check paper: https://arxiv.org/abs/2309.14675 """ def __init__( self, diff --git a/src/appfl/aggregator/iceadmm_aggregator.py b/src/appfl/aggregator/iceadmm_aggregator.py index 2a3f9cd..fdd7063 100644 --- a/src/appfl/aggregator/iceadmm_aggregator.py +++ b/src/appfl/aggregator/iceadmm_aggregator.py @@ -7,6 +7,11 @@ from appfl.aggregator import BaseAggregator class ICEADMMAggregator(BaseAggregator): + """ + ICEADMM Aggregator class for Federated Learning. + It has to be used with the ICEADMMTrainer. + For more details, check paper: https://arxiv.org/pdf/2110.15318.pdf + """ def __init__( self, model: nn.Module, diff --git a/src/appfl/aggregator/iiadmm_aggregator.py b/src/appfl/aggregator/iiadmm_aggregator.py new file mode 100644 index 0000000..86a6768 --- /dev/null +++ b/src/appfl/aggregator/iiadmm_aggregator.py @@ -0,0 +1,113 @@ +import copy +import torch +import torch.nn as nn +from omegaconf import DictConfig +from collections import OrderedDict +from typing import Any, Dict, Union +from appfl.aggregator import BaseAggregator + +class IIADMMAggregator(BaseAggregator): + """ + IIADMMAggregator Aggregator class for Federated Learning. + It has to be used with the IIADMMTrainer. + For more details, check paper: https://arxiv.org/pdf/2202.03672.pdf + """ + def __init__( + self, + model: nn.Module, + aggregator_config: DictConfig, + logger: Any, + ): + self.model = model + self.logger = logger + self.aggregator_config = aggregator_config + self.named_parameters = set() + for name, _ in self.model.named_parameters(): + self.named_parameters.add(name) + self.is_first_iter = True + self.penalty = OrderedDict() + self.prim_res = 0 + self.dual_res = 0 + self.global_state = OrderedDict() + self.primal_states = OrderedDict() + self.dual_states = OrderedDict() + self.primal_states_curr = OrderedDict() + self.primal_states_prev = OrderedDict() + self.device = self.aggregator_config.device if "device" in self.aggregator_config else "cpu" + + def aggregate( + self, + local_models: Dict[Union[str, int], Union[Dict, OrderedDict]], + **kwargs + ) -> Dict: + if len(self.primal_states) == 0: + self.num_clients = len(local_models) + for i in local_models: + self.primal_states[i] = OrderedDict() + self.dual_states[i] = OrderedDict() + self.primal_states_curr[i] = OrderedDict() + self.primal_states_prev[i] = OrderedDict() + # dual_state = 0 at the beginning + for name in self.named_parameters: + self.dual_states[i][name] = torch.zeros_like(self.model.state_dict()[name]) + + global_state = copy.deepcopy(self.model.state_dict()) + + for client_id, model in local_models.items(): + if model is not None: + self.primal_states[client_id] = model["primal"] + self.penalty[client_id] = model["penalty"] + + # Calculate the primal residual + primal_res = 0 + for client_id in local_models: + for name in self.named_parameters: + primal_res += torch.sum(torch.square( + global_state[name].to(self.device) + - self.primal_states[client_id][name].to(self.device) + )) + self.prim_res = torch.sqrt(primal_res).item() + + # Calculate the dual residual + dual_res = 0 + if self.is_first_iter: + for client_id in local_models: + for name in self.named_parameters: + self.primal_states_curr[client_id][name] = copy.deepcopy( + self.primal_states[client_id][name].to(self.device) + ) + self.is_first_iter = False + else: + self.primal_states_prev = copy.deepcopy(self.primal_states_curr) + for client_id in local_models: + for name in self.named_parameters: + self.primal_states_curr[client_id][name] = copy.deepcopy( + self.primal_states[client_id][name].to(self.device) + ) + for name in self.named_parameters: + res = 0 + for client_id in local_models: + res += self.penalty[client_id] * ( + self.primal_states_prev[client_id][name] + - self.primal_states_curr[client_id][name] + ) + dual_res += torch.sum(torch.square(res)) + self.dual_res = torch.sqrt(dual_res).item() + + for name, param in self.model.named_parameters(): + state_param = torch.zeros_like(param) + for client_id in local_models: + self.dual_states[client_id][name] += self.penalty[client_id] * ( + global_state[name] - self.primal_states[client_id][name] + ) + state_param += ( + self.primal_states[client_id][name] + - (1.0 / self.penalty[client_id]) * self.dual_states[client_id][name] + ) + global_state[name] = state_param / self.num_clients + + self.model.load_state_dict(global_state) + return global_state + + def get_parameters(self, **kwargs) -> Dict: + return copy.deepcopy(self.model.state_dict()) \ No newline at end of file diff --git a/src/appfl/trainer/__init__.py b/src/appfl/trainer/__init__.py index dcca3c0..7f32295 100644 --- a/src/appfl/trainer/__init__.py +++ b/src/appfl/trainer/__init__.py @@ -1,3 +1,4 @@ from .base_trainer import BaseTrainer from .naive_trainer import NaiveTrainer +from .iiadmm_trainer import IIADMMTrainer from .iceadmm_trainer import ICEADMMTrainer \ No newline at end of file diff --git a/src/appfl/trainer/iceadmm_trainer.py b/src/appfl/trainer/iceadmm_trainer.py index db58ac7..b08068d 100644 --- a/src/appfl/trainer/iceadmm_trainer.py +++ b/src/appfl/trainer/iceadmm_trainer.py @@ -71,6 +71,7 @@ def __init__( self._sanity_check() def train(self): + assert hasattr(self, "weight"), "You must set the weight of the client before training. Use `set_weight` method." self.model.train() self.model.to(self.train_configs.device) do_validation = self.train_configs.get("do_validation", False) and self.val_dataloader is not None @@ -172,7 +173,7 @@ def train(self): self.round += 1 - + """Differential Privacy""" for name, param in self.model.named_parameters(): param.data = self.primal_states[name].to(self.train_configs.device) if self.train_configs.get("use_dp", False): @@ -215,7 +216,6 @@ def _sanity_check(self): assert hasattr(self.train_configs, "num_local_epochs"), "Number of local epochs must be specified" else: assert hasattr(self.train_configs, "num_local_steps"), "Number of local steps must be specified" - assert hasattr(self, "weight"), "You must set the weight of the client before training. Use `set_weight` method." if getattr(self.train_configs, "clip_grad", False) or getattr(self.train_configs, "use_dp", False): assert hasattr(self.train_configs, "clip_value"), "Gradient clipping value must be specified" assert hasattr(self.train_configs, "clip_norm"), "Gradient clipping norm must be specified" diff --git a/src/appfl/trainer/iiadmm_trainer.py b/src/appfl/trainer/iiadmm_trainer.py new file mode 100644 index 0000000..fdeec54 --- /dev/null +++ b/src/appfl/trainer/iiadmm_trainer.py @@ -0,0 +1,363 @@ +import copy +import time +import torch +import importlib +import numpy as np +import torch.nn as nn +from omegaconf import DictConfig +from collections import OrderedDict +from torch.utils.data import DataLoader, Dataset +from typing import Any, Optional, Tuple +from appfl.trainer.base_trainer import BaseTrainer +from appfl.privacy import laplace_mechanism_output_perturb + +class IIADMMTrainer(BaseTrainer): + """ + IIADMMTrainer: + Local trainer for the IIADMM algorithm. + This trainer must be used with the IIADMMAggregator. + """ + def __init__( + self, + model: Optional[nn.Module]=None, + loss_fn: Optional[nn.Module]=None, + metric: Optional[Any]=None, + train_dataset: Optional[Dataset]=None, + val_dataset: Optional[Dataset]=None, + train_configs: DictConfig = DictConfig({}), + logger: Optional[Any]=None, + **kwargs + ): + super().__init__( + model=model, + loss_fn=loss_fn, + metric=metric, + train_dataset=train_dataset, + val_dataset=val_dataset, + train_configs=train_configs, + logger=logger, + **kwargs + ) + if not hasattr(self.train_configs, "device"): + self.train_configs.device = "cpu" + self.train_dataloader = DataLoader( + self.train_dataset, + batch_size=self.train_configs.get("train_batch_size", 32), + shuffle=self.train_configs.get("train_data_shuffle", True), + num_workers=self.train_configs.get("num_workers", 0), + ) + self.val_dataloader = DataLoader( + self.val_dataset, + batch_size=self.train_configs.get("val_batch_size", 32), + shuffle=self.train_configs.get("val_data_shuffle", False), + num_workers=self.train_configs.get("num_workers", 0), + ) if self.val_dataset is not None else None + + self.penalty = self.train_configs.get("init_penalty", 500.0) + self.is_first_iter = True + + # At initial, (1) primal_states = global_state, (2) dual_states = 0 + self.primal_states = OrderedDict() + self.dual_states = OrderedDict() + self.primal_states_curr = OrderedDict() + self.primal_states_prev = OrderedDict() + self.named_parameters = set() + + for name, param in self.model.named_parameters(): + self.named_parameters.add(name) + self.primal_states[name] = param.data + self.dual_states[name] = torch.zeros_like(param.data) + self._sanity_check() + + def train(self): + assert hasattr(self, "weight"), "You must set the weight of the client before training. Use `set_weight` method." + self.model.train() + self.model.to(self.train_configs.device) + do_validation = self.train_configs.get("do_validation", False) and self.val_dataloader is not None + do_pre_validation = self.train_configs.get("do_pre_validation", False) and do_validation + + """Set up logging title""" + if self.round == 0: + title = ( + ["Round", "Time", "Train Loss", "Train Accuracy"] + if not do_validation + else ( + ["Round", "Pre Val?", "Time", "Train Loss", "Train Accuracy", "Val Loss", "Val Accuracy"] + if do_pre_validation + else ["Round", "Time", "Train Loss", "Train Accuracy", "Val Loss", "Val Accuracy"] + ) + ) + if self.train_configs.mode == "epoch": + title.insert(1, "Epoch") + self.logger.log_title(title) + + pre_val_interval = self.train_configs.get("pre_validation_interval", 1) + if do_pre_validation and (self.round + 1) % pre_val_interval == 0: + val_loss, val_accuracy = self._validate() + content = [self.round, "Y", " ", " ", " ", val_loss, val_accuracy] + if self.train_configs.mode == "epoch": + content.insert(1, 0) + self.logger.log_content(content) + + optim_module = importlib.import_module("torch.optim") + assert hasattr(optim_module, self.train_configs.optim), f"Optimizer {self.train_configs.optim} not found in torch.optim" + optimizer = getattr(optim_module, self.train_configs.optim)(self.model.parameters(), **self.train_configs.optim_args) + + """ Inputs for the local model update """ + global_state = copy.deepcopy(self.model.state_dict()) + + """Adaptive Penalty (Residual Balancing)""" + if not hasattr(self.train_configs, "residual_balancing"): + self.train_configs.residual_balancing = DictConfig({}) + if getattr(self.train_configs.residual_balancing, "res_on", False): + prim_res = self._primal_residual_at_client(global_state) + dual_res = self._dual_residual_at_client() + self._residual_balancing(prim_res, dual_res) + + if self.train_configs.mode == "epoch": + for epoch in range(self.train_configs.num_local_epochs): + start_time = time.time() + train_loss, target_true, target_pred = 0, [], [] + for data, target in self.train_dataloader: + loss, pred, label = self._train_batch(optimizer, data, target, global_state) + train_loss += loss + target_true.append(label) + target_pred.append(pred) + train_loss /= len(self.train_dataloader) + target_true, target_pred = np.concatenate(target_true), np.concatenate(target_pred) + train_accuracy = float(self.metric(target_true, target_pred)) + if do_validation: + val_loss, val_accuracy = self._validate() + per_epoch_time = time.time() - start_time + self.logger.log_content( + [self.round, epoch, per_epoch_time, train_loss, train_accuracy] + if not do_validation + else ( + [self.round, epoch, per_epoch_time, train_loss, train_accuracy, val_loss, val_accuracy] + if not do_pre_validation + else + [self.round, epoch, 'N', per_epoch_time, train_loss, train_accuracy, val_loss, val_accuracy] + ) + ) + else: + start_time = time.time() + data_iter = iter(self.train_dataloader) + train_loss, target_true, target_pred = 0, [], [] + for _ in range(self.train_configs.num_local_steps): + try: + data, target = next(data_iter) + except: + data_iter = iter(self.train_dataloader) + data, target = next(data_iter) + loss, pred, label = self._train_batch(optimizer, data, target, global_state) + train_loss += loss + target_true.append(label) + target_pred.append(pred) + train_loss /= len(self.train_dataloader) + target_true, target_pred = np.concatenate(target_true), np.concatenate(target_pred) + train_accuracy = float(self.metric(target_true, target_pred)) + if do_validation: + val_loss, val_accuracy = self._validate() + per_step_time = time.time() - start_time + self.logger.log_content( + [self.round, per_step_time, train_loss, train_accuracy] + if not do_validation + else ( + [self.round, per_step_time, train_loss, train_accuracy, val_loss, val_accuracy] + if not do_pre_validation + else + [self.round, 'N', per_step_time, train_loss, train_accuracy, val_loss, val_accuracy] + ) + ) + + self.round += 1 + + """Update dual states""" + for name, param in self.model.named_parameters(): + self.dual_states[name] += self.penalty * (global_state[name] - self.primal_states[name]) + + """Differential Privacy""" + for name, param in self.model.named_parameters(): + param.data = self.primal_states[name].to(self.train_configs.device) + if self.train_configs.get("use_dp", False): + sensitivity = 2.0 * self.train_configs.clip_value / self.penalty + self._model_state = laplace_mechanism_output_perturb( + self.model, + sensitivity, + self.train_configs.epsilon + ) + else: + self._model_state = copy.deepcopy(self.model.state_dict()) + + """Move to CPU for communication""" + if self.train_configs.get("device", "cpu") == "cuda": + for k in self._model_state: + self._model_state[k] = self._model_state[k].cpu() + for name in self.named_parameters: + self.dual_states[name] = self.dual_states[name].cpu() + + self.model_state = OrderedDict() + self.model_state["primal"] = self._model_state + self.model_state["penalty"] = self.penalty + + def get_parameters(self) -> OrderedDict: + hasattr(self, "model_state"), "Please make sure the model has been trained before getting its parameters" + return self.model_state + + def set_weight(self, weight=1.0): + """Set the weight of the client""" + self.weight = weight + + def _sanity_check(self): + """ + Check if the necessary configurations are provided. + """ + assert hasattr(self.train_configs, "mode"), "Training mode must be specified" + assert self.train_configs.mode in ["epoch", "step"], "Training mode must be either 'epoch' or 'step'" + if self.train_configs.mode == "epoch": + assert hasattr(self.train_configs, "num_local_epochs"), "Number of local epochs must be specified" + else: + assert hasattr(self.train_configs, "num_local_steps"), "Number of local steps must be specified" + if getattr(self.train_configs, "clip_grad", False) or getattr(self.train_configs, "use_dp", False): + assert hasattr(self.train_configs, "clip_value"), "Gradient clipping value must be specified" + assert hasattr(self.train_configs, "clip_norm"), "Gradient clipping norm must be specified" + if getattr(self.train_configs, "use_dp", False): + assert hasattr(self.train_configs, "epsilon"), "Privacy budget (epsilon) must be specified" + + def _primal_residual_at_client(self, global_state) -> float: + """ + Calculate primal residual. + :param global_state: global state - input for the local model update + :return: primal residual + """ + primal_res = 0 + for name in self.named_parameters: + primal_res += torch.sum( + torch.square( + global_state[name].to(self.train_configs.device) + - self.primal_states[name].to(self.train_configs.device) + ) + ) + primal_res = torch.sqrt(primal_res).item() + return primal_res + + def _dual_residual_at_client(self) -> float: + """ + Calculate dual residual. + :return: dual residual + """ + dual_res = 0 + if self.is_first_iter: + self.primal_states_curr = self.primal_states + self.is_first_iter = False + else: + self.primal_states_prev = self.primal_states_curr + self.primal_states_curr = self.primal_states + for name in self.named_parameters: + res = self.penalty * ( + self.primal_states_prev[name] - self.primal_states_curr[name] + ) + dual_res += torch.sum(torch.square(res)) + dual_res = torch.sqrt(dual_res).item() + return dual_res + + def _residual_balancing(self, prim_res, dual_res): + if prim_res > self.train_configs.residual_balancing.mu * dual_res: + self.penalty = self.penalty * self.train_configs.residual_balancing.tau + if dual_res > self.train_configs.residual_balancing.mu * prim_res: + self.penalty = self.penalty / self.train_configs.residual_balancing.tau + + def _train_batch(self, optimizer, data, target, global_state): + """ + Train the model for one batch of data + :param optimizer: torch optimizer + :param data: input data + :param target: target label + :param global_state: global model state + :return: loss, prediction, label + """ + """Load primal states to the model""" + for name, param in self.model.named_parameters(): + param.data = self.primal_states[name].to(self.train_configs.device) + + """Adaptive Penalty (Residual Balancing)""" + if ( + getattr(self.train_configs.residual_balancing, "res_on", False) and + getattr(self.train_configs.residual_balancing, "res_on_every_update", False) + ): + prim_res = self._primal_residual_at_client(global_state) + dual_res = self._dual_residual_at_client() + self._residual_balancing(prim_res, dual_res) + + """Train the model""" + data, target = data.to(self.train_configs.device), target.to(self.train_configs.device) + if getattr(self.train_configs, "accum_grad", False) == False: + optimizer.zero_grad() + output = self.model(data) + loss = self.loss_fn(output, target) + loss.backward() + optimizer.step() + + """Gradient Clipping""" + if getattr(self.train_configs, "clip_grad", False) or getattr(self.train_configs, "use_dp", False): + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.train_configs.clip_value, + norm_type=self.train_configs.clip_norm, + ) + + """Update primal and dual states""" + coefficient = 1 + if getattr(self.train_configs, "coeff_grad", False): + coefficient = self.weight * len(target) / len(self.train_dataloader.dataset) + self._iiadmm_step(coefficient, global_state, optimizer) + + return loss.item(), output.detach().cpu().numpy(), target.detach().cpu().numpy() + + def _iiadmm_step(self, coefficient, global_state, optimizer): + """ + Update primal and dual states + """ + momentum = self.train_configs.optim_args.get("momentum", 0) + weight_decay = self.train_configs.optim_args.get("weight_decay", 0) + dampening = self.train_configs.optim_args.get("dampening", 0) + nesterov = self.train_configs.optim_args.get("nesterov", False) + for name, param in self.model.named_parameters(): + grad = copy.deepcopy(param.grad * coefficient) + if weight_decay != 0: + grad.add_(weight_decay, self.primal_states[name]) + if momentum != 0: + param_state = optimizer.state[param] + if "momentum_buffer" not in param_state: + buf = param_state["momentum_buffer"] = grad.clone() + else: + buf = param_state["momentum_buffer"] + buf.mul_(momentum).add_(1 - dampening, grad) + if nesterov: + grad.add_(momentum, buf) + else: + grad = buf + + """Update primal""" + self.primal_states[name] = global_state[name] + (1 / self.penalty) * (self.dual_states[name] - grad) + + def _validate(self) -> Tuple[float, float]: + """ + Validate the model + :return: loss, accuracy + """ + device = self.train_configs.get("device", "cpu") + self.model.eval() + val_loss = 0 + with torch.no_grad(): + target_pred, target_true = [], [] + for data, target in self.val_dataloader: + data, target = data.to(device), target.to(device) + output = self.model(data) + val_loss += self.loss_fn(output, target).item() + target_true.append(target.detach().cpu().numpy()) + target_pred.append(output.detach().cpu().numpy()) + val_loss /= len(self.val_dataloader) + val_accuracy = float(self.metric(np.concatenate(target_true), np.concatenate(target_pred))) + self.model.train() + return val_loss, val_accuracy \ No newline at end of file