Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Scaffold #339

Merged
merged 4 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions fedlab/contrib/algorithm/scaffold.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,33 @@ def train(self, id, model_parameters, global_c, train_loader):
self.optimizer.zero_grad()
loss.backward()

grad = self.model_gradients
# grad = self.model_gradients
grad = self.model_grads
grad = grad - self.cs[id] + global_c
idx = 0
for parameter in self._model.parameters():
layer_size = parameter.grad.numel()
shape = parameter.grad.shape
#parameter.grad = parameter.grad - self.cs[id][idx:idx + layer_size].view(parameter.grad.shape) + global_c[idx:idx + layer_size].view(parameter.grad.shape)
parameter.grad.data[:] = grad[idx:idx+layer_size].view(shape)[:]

parameters = self._model.parameters()
for p in self._model.state_dict().values():
if p.grad is None: # Batchnorm have no grad
layer_size = p.numel()
else:
parameter = next(parameters)
layer_size = parameter.data.numel()
shape = parameter.grad.shape
parameter.grad.data[:] = grad[idx:idx+layer_size].view(shape)[:]
idx += layer_size

# for parameter in self._model.parameters():
# layer_size = parameter.grad.numel()
# shape = parameter.grad.shape
# #parameter.grad = parameter.grad - self.cs[id][idx:idx + layer_size].view(parameter.grad.shape) + global_c[idx:idx + layer_size].view(parameter.grad.shape)
# parameter.grad.data[:] = grad[idx:idx+layer_size].view(shape)[:]
# idx += layer_size

self.optimizer.step()

dy = self.model_parameters - frz_model
dc = -1.0 / (self.epochs * len(train_loader) * self.lr) * dy - global_c
self.cs[id] += dc
return [dy, dc]

16 changes: 15 additions & 1 deletion fedlab/core/model_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ def model_parameters(self) -> torch.Tensor:
"""Return serialized model parameters."""
return SerializationTool.serialize_model(self._model)

@property
def model_grads(self) -> torch.Tensor:
"""Return serialized model gradients(base on model.state_dict(), Shape is the same as model_parameters)."""
params = self._model.state_dict()
for name, p in self._model.named_parameters():
params[name].grad = p.grad
for key in params:
if params[key].grad is None:
params[key].grad = torch.zeros_like(params[key])
gradients = [param.grad.data.view(-1) for param in params.values()]
m_gradients = torch.cat(gradients)
m_gradients = m_gradients.cpu()
return m_gradients

@property
def model_gradients(self) -> torch.Tensor:
"""Return serialized model gradients."""
Expand Down Expand Up @@ -117,4 +131,4 @@ def set_model(self, parameters: torch.Tensor = None, id: int = None):
if id is None:
super().set_model(parameters)
else:
super().set_model(self.parameters[id])
super().set_model(self.parameters[id])
Loading