Merge pull request #339 from hua-zi/patch-3
Fix Scaffold
dunzeng authored Nov 6, 2023
2 parents e90d76e + f8bf067 · commit 7564726
Showing 1 changed file with 20 additions and 6 deletions.
fedlab/contrib/algorithm/scaffold.py
@@ -74,19 +74,33 @@ def train(self, id, model_parameters, global_c, train_loader):
                 self.optimizer.zero_grad()
                 loss.backward()
 
-                grad = self.model_gradients
+                # grad = self.model_gradients
+                grad = self.model_grads
                 grad = grad - self.cs[id] + global_c
                 idx = 0
-                for parameter in self._model.parameters():
-                    layer_size = parameter.grad.numel()
-                    shape = parameter.grad.shape
-                    #parameter.grad = parameter.grad - self.cs[id][idx:idx + layer_size].view(parameter.grad.shape) + global_c[idx:idx + layer_size].view(parameter.grad.shape)
-                    parameter.grad.data[:] = grad[idx:idx+layer_size].view(shape)[:]
+
+                parameters = self._model.parameters()
+                for p in self._model.state_dict().values():
+                    if p.grad is None: # Batchnorm have no grad
+                        layer_size = p.numel()
+                    else:
+                        parameter = next(parameters)
+                        layer_size = parameter.data.numel()
+                        shape = parameter.grad.shape
+                        parameter.grad.data[:] = grad[idx:idx+layer_size].view(shape)[:]
                     idx += layer_size
 
+                # for parameter in self._model.parameters():
+                #     layer_size = parameter.grad.numel()
+                #     shape = parameter.grad.shape
+                #     #parameter.grad = parameter.grad - self.cs[id][idx:idx + layer_size].view(parameter.grad.shape) + global_c[idx:idx + layer_size].view(parameter.grad.shape)
+                #     parameter.grad.data[:] = grad[idx:idx+layer_size].view(shape)[:]
+                #     idx += layer_size
+
                 self.optimizer.step()
 
         dy = self.model_parameters - frz_model
         dc = -1.0 / (self.epochs * len(train_loader) * self.lr) * dy - global_c
         self.cs[id] += dc
         return [dy, dc]
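Note on the change: SCAFFOLD corrects each local mini-batch gradient to g - c_i + c before the optimizer step, where c_i is the client control variate (self.cs[id]) and c is the server one (global_c). The diff suggests the flattened gradient and control-variate vectors are laid out over the model's full state_dict, which for models with BatchNorm also contains grad-less buffers (running_mean, running_var, num_batches_tracked). The old loop walked parameters() only, so its offset into the flattened vectors drifted for such models; the fix walks state_dict() and advances the offset past no-grad entries as well.

Below is a minimal, self-contained sketch of the same idea, not FedLab's code: apply_scaffold_correction, c_local, and c_global are names invented here, and the state_dict-ordered layout of the flattened vectors is an assumption carried over from the diff.

import torch
from torch import nn

def apply_scaffold_correction(model: nn.Module,
                              c_local: torch.Tensor,
                              c_global: torch.Tensor) -> None:
    """In place, set g <- g - c_local + c_global on every parameter gradient.

    Hypothetical helper; assumes c_local/c_global are flattened in
    state_dict order, buffers included."""
    params = dict(model.named_parameters())
    idx = 0
    for name, tensor in model.state_dict().items():
        n = tensor.numel()
        p = params.get(name)
        if p is not None and p.grad is not None:
            g = p.grad
            g.sub_(c_local[idx:idx + n].view_as(g))
            g.add_(c_global[idx:idx + n].view_as(g))
        # Buffers (and grad-less parameters) still advance the offset,
        # keeping later slices aligned with the flattened vectors.
        idx += n

# Toy usage: BatchNorm buffers inflate the flattened size beyond parameters().
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
total = sum(t.numel() for t in model.state_dict().values())
model(torch.randn(16, 4)).sum().backward()
apply_scaffold_correction(model, torch.zeros(total), torch.zeros(total))

The unchanged tail of the hunk matches Option II of the SCAFFOLD paper (Karimireddy et al., 2020): with dy = y_i - x the local model delta after K = self.epochs * len(train_loader) steps of size self.lr, dc = -dy / (K * lr) - c, so self.cs[id] += dc realizes the paper's update c_i <- c_i - c + (x - y_i) / (K * lr).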
