From d8cb221b13e4bc1be1ac0bb43a7be7a87a121224 Mon Sep 17 00:00:00 2001 From: hua-zi <83271073+hua-zi@users.noreply.github.com> Date: Fri, 3 Nov 2023 22:12:33 +0800 Subject: [PATCH] Update model_maintainer.py Return serialized model gradients (based on model.state_dict(); the shape is the same as model_parameters). --- fedlab/core/model_maintainer.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/fedlab/core/model_maintainer.py b/fedlab/core/model_maintainer.py index 89d1e8b0..836e3c1f 100644 --- a/fedlab/core/model_maintainer.py +++ b/fedlab/core/model_maintainer.py @@ -60,6 +60,20 @@ def model_parameters(self) -> torch.Tensor: """Return serialized model parameters.""" return SerializationTool.serialize_model(self._model) + @property + def model_grads(self) -> torch.Tensor: + """Return serialized model gradients (based on model.state_dict(); the shape is the same as model_parameters).""" + params = self._model.state_dict() + for name, p in self._model.named_parameters(): + params[name].grad = p.grad + for key in params: + if params[key].grad is None: + params[key].grad = torch.zeros_like(params[key]) + gradients = [param.grad.data.view(-1) for param in params.values()] + m_gradients = torch.cat(gradients) + m_gradients = m_gradients.cpu() + return m_gradients + @property + def model_gradients(self) -> torch.Tensor: + """Return serialized model gradients.""" @@ -117,4 +131,4 @@ def set_model(self, parameters: torch.Tensor = None, id: int = None): if id is None: super().set_model(parameters) else: - super().set_model(self.parameters[id]) \ No newline at end of file + super().set_model(self.parameters[id])