Skip to content

Commit

Permalink
Merge pull request #336 from hua-zi/patch-2
Browse files Browse the repository at this point in the history
Update qfedavg.py
  • Loading branch information
dunzeng authored Nov 6, 2023
2 parents aff14a5 + 698746e commit e90d76e
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions fedlab/contrib/algorithm/qfedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,41 @@ def train(self, model_parameters, train_loader) -> None:
self.hk = self.q * np.float_power(
ret_loss + 1e-10, self.q - 1) * grad.norm(
)**2 + 1.0 / self.lr * np.float_power(ret_loss + 1e-10, self.q)

class qFedAvgSerialClientTrainer(SGDSerialClientTrainer):
    """Serial client trainer for qFedAvg.

    After local training, each client caches the qFedAvg update statistics
    ``delta`` and ``hk`` derived from the final-epoch training loss.
    """

    def setup_optim(self, epochs, batch_size, lr, q):
        """Configure local optimization hyper-parameters.

        Args:
            epochs: Number of local training epochs.
            batch_size: Local mini-batch size.
            lr: Local learning rate.
            q: Fairness parameter of qFedAvg.
        """
        super().setup_optim(epochs, batch_size, lr)
        self.q = q

    def train(self, model_parameters, train_loader) -> None:
        """Client trains its local model on local dataset.

        Args:
            model_parameters (torch.Tensor): Serialized model parameters.
            train_loader: Iterable yielding ``(data, target)`` batches.
        """
        self.set_model(model_parameters)
        for _ in range(self.epochs):
            self._model.train()
            # Accumulated loss of the current epoch; after the loop this
            # holds the final epoch's total, used as the qFedAvg loss term.
            epoch_loss = 0.0
            for batch, labels in train_loader:
                if self.cuda:
                    batch = batch.cuda(self.device)
                    labels = labels.cuda(self.device)

                loss = self.criterion(self._model(batch), labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.detach().item()

        # Estimate the aggregate gradient from the parameter displacement,
        # then form the qFedAvg quantities Delta_k and h_k.
        grad = (model_parameters - self.model_parameters) / self.lr
        weight = np.float_power(epoch_loss + 1e-10, self.q)
        self.delta = grad * weight
        self.hk = self.q * np.float_power(
            epoch_loss + 1e-10, self.q - 1) * grad.norm()**2 + 1.0 / self.lr * weight

        return [self.delta, self.hk]

0 comments on commit e90d76e

Please sign in to comment.