diff --git a/fedlab/core/model_maintainer.py b/fedlab/core/model_maintainer.py index 89d1e8b0..836e3c1f 100644 --- a/fedlab/core/model_maintainer.py +++ b/fedlab/core/model_maintainer.py @@ -60,6 +60,20 @@ def model_parameters(self) -> torch.Tensor: """Return serialized model parameters.""" return SerializationTool.serialize_model(self._model) + @property + def model_grads(self) -> torch.Tensor: + """Return serialized model gradients (based on model.state_dict(); shape is the same as model_parameters).""" + params = self._model.state_dict() + for name, p in self._model.named_parameters(): + params[name].grad = p.grad + for key in params: + if params[key].grad is None: + params[key].grad = torch.zeros_like(params[key]) + gradients = [param.grad.data.view(-1) for param in params.values()] + m_gradients = torch.cat(gradients) + m_gradients = m_gradients.cpu() + return m_gradients + @property + def model_gradients(self) -> torch.Tensor: + """Return serialized model gradients.""" @@ -117,4 +131,4 @@ def set_model(self, parameters: torch.Tensor = None, id: int = None): if id is None: super().set_model(parameters) else: - super().set_model(self.parameters[id]) \ No newline at end of file + super().set_model(self.parameters[id])