From 698746e62ab9c8454b22d1c4864489579e0c7da5 Mon Sep 17 00:00:00 2001 From: hua-zi <83271073+hua-zi@users.noreply.github.com> Date: Fri, 27 Oct 2023 03:14:08 +0800 Subject: [PATCH] Update qfedavg.py qfedavg train multiple clients in a single process. --- fedlab/contrib/algorithm/qfedavg.py | 38 +++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/fedlab/contrib/algorithm/qfedavg.py b/fedlab/contrib/algorithm/qfedavg.py index 9a94cc17..1f9c9e4a 100644 --- a/fedlab/contrib/algorithm/qfedavg.py +++ b/fedlab/contrib/algorithm/qfedavg.py @@ -71,3 +71,41 @@ def train(self, model_parameters, train_loader) -> None: self.hk = self.q * np.float_power( ret_loss + 1e-10, self.q - 1) * grad.norm( )**2 + 1.0 / self.lr * np.float_power(ret_loss + 1e-10, self.q) + +class qFedAvgSerialClientTrainer(SGDSerialClientTrainer): + def setup_optim(self, epochs, batch_size, lr, q): + super().setup_optim(epochs, batch_size, lr) + self.q = q + + def train(self, model_parameters, train_loader) -> None: + """Client trains its local model on local dataset. + Args: + model_parameters (torch.Tensor): Serialized model parameters. + """ + self.set_model(model_parameters) + # self._LOGGER.info("Local train procedure is running") + for ep in range(self.epochs): + self._model.train() + ret_loss = 0.0 + for data, target in train_loader: + if self.cuda: + data, target = data.cuda(self.device), target.cuda( + self.device) + + outputs = self._model(data) + loss = self.criterion(outputs, target) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + ret_loss += loss.detach().item() + # self._LOGGER.info("Local train procedure is finished") + + grad = (model_parameters - self.model_parameters) / self.lr + self.delta = grad * np.float_power(ret_loss + 1e-10, self.q) + self.hk = self.q * np.float_power( + ret_loss + 1e-10, self.q - 1) * grad.norm( + )**2 + 1.0 / self.lr * np.float_power(ret_loss + 1e-10, self.q) + + return [self.delta, self.hk]