From 6316c8a710ecd8ede5e32dfe7634c150f2bb2ea7 Mon Sep 17 00:00:00 2001 From: AgentDS Date: Mon, 10 Jul 2023 13:41:57 -0400 Subject: [PATCH] update serialization tools for subtraction parameters, for both trainable/untrainable params; add 'cpu' option for serialzation functions --- fedlab/utils/serialization.py | 99 ++++++++++++++++------------------- 1 file changed, 44 insertions(+), 55 deletions(-) diff --git a/fedlab/utils/serialization.py b/fedlab/utils/serialization.py index 89655c24..6990b361 100644 --- a/fedlab/utils/serialization.py +++ b/fedlab/utils/serialization.py @@ -15,55 +15,33 @@ import torch -# def serialize_model(model: torch.nn.Module) -> torch.Tensor: -# # parameters = [param.data.view(-1) for param in model.parameters()] -# parameters = [param.data.view(-1) for param in model.state_dict().values()] -# m_parameters = torch.cat(parameters) -# m_parameters = m_parameters.cpu() - -# return m_parameters - -# def deserialize_model(model: torch.nn.Module, -# serialized_parameters: torch.Tensor, -# mode="copy"): -# current_index = 0 # keep track of where to read from grad_update - -# for param in model.state_dict().values(): -# numel = param.numel() -# size = param.size() -# if mode == "copy": -# param.copy_( -# serialized_parameters[current_index:current_index + -# numel].view(size)) -# elif mode == "add": -# param.add_( -# serialized_parameters[current_index:current_index + -# numel].view(size)) -# else: -# raise ValueError( -# "Invalid deserialize mode {}, require \"copy\" or \"add\" " -# .format(mode)) -# current_index += numel - class SerializationTool(object): @staticmethod - def serialize_model_gradients(model: torch.nn.Module) -> torch.Tensor: - """_summary_ + def serialize_model_gradients(model: torch.nn.Module, cpu:bool=True) -> torch.Tensor: + """Vectorize model gradients. Args: - model (torch.nn.Module): _description_ + model (torch.nn.Module): Model with gradients. + cpu (bool, optional): Whether move the vectorized parameter to ``torch.device('cpu')`` by force. Defaults to ``True``. If ``cpu`` is ``False``, the returned vector is on the same device as ``model``. Returns: - torch.Tensor: _description_ + torch.Tensor: Vectorized model gradients. Only contains trainable parameters. """ gradients = [param.grad.data.view(-1) for param in model.parameters()] m_gradients = torch.cat(gradients) - m_gradients = m_gradients.cpu() + if cpu: + m_gradients = m_gradients.cpu() return m_gradients @staticmethod - def deserialize_model_gradients(model: torch.nn.Module, gradients: torch.Tensor): + def deserialize_model_gradients(model: torch.nn.Module, gradients: torch.Tensor) -> None: + """Deserialize vectorized ``gradients`` into ``model``'s ``param.grad.data`` for each trainable parameter. + + Args: + model (torch.nn.Module): Model. + gradients (torch.Tensor): Vectorized gradients for single model. + """ idx = 0 for parameter in model.parameters(): layer_size = parameter.grad.numel() @@ -73,22 +51,23 @@ def deserialize_model_gradients(model: torch.nn.Module, gradients: torch.Tensor) idx += layer_size @staticmethod - def serialize_model(model: torch.nn.Module) -> torch.Tensor: - """Unfold model parameters + def serialize_model(model: torch.nn.Module, cpu:bool=True) -> torch.Tensor: + """Unfold model parameters, including trainable as well as untrainable parameters. - Unfold every layer of model, concate all of tensors into one. - Return a `torch.Tensor` with shape (size, ). + Unfold every layer of model, concate all of tensors into one vector. + Return a `torch.Tensor` with shape ``(d, )``, where ``d`` is the total number of parameters in ``model``, including trainable as well as untrainable parameters. Please note that we update the implementation. Current version of serialization includes the parameters in batchnorm layers. Args: model (torch.nn.Module): model to serialize. + cpu (bool, optional): Whether move the vectorized parameter to ``torch.device('cpu')`` by force. Defaults to ``True``. If ``cpu`` is ``False``, the returned vector is on the same device as ``model``. """ - # parameters = [param.data.view(-1) for param in model.parameters()] parameters = [param.data.view(-1) for param in model.state_dict().values()] m_parameters = torch.cat(parameters) - m_parameters = m_parameters.cpu() + if cpu: + m_parameters = m_parameters.cpu() return m_parameters @@ -96,14 +75,14 @@ def serialize_model(model: torch.nn.Module) -> torch.Tensor: def deserialize_model(model: torch.nn.Module, serialized_parameters: torch.Tensor, mode="copy"): - """Assigns serialized parameters to model.parameters. - This is done by iterating through ``model.parameters()`` and assigning the relevant params in ``grad_update``. - NOTE: this function manipulates ``model.parameters``. + """Assigns serialized parameters to parameters in ``model.state_dict()``, which includes both trainable parameters and untrainable parameters. + This is done by iterating through ``model.state_dict()`` and assigning the relevant values with the same dimension as the ``model.state_dict()``. + NOTE: this function manipulates ``model.state_dict()``. Args: model (torch.nn.Module): model to deserialize. serialized_parameters (torch.Tensor): serialized model parameters. - mode (str): deserialize mode. "copy" or "add". + mode (str): deserialize mode. Support "copy", "add", and "sub". """ current_index = 0 # keep track of where to read from grad_update @@ -118,27 +97,33 @@ def deserialize_model(model: torch.nn.Module, param.add_( serialized_parameters[current_index:current_index + numel].view(size)) + elif mode == "sub": + param.sub_( + serialized_parameters[current_index:current_index + + numel].view(size)) else: raise ValueError( - "Invalid deserialize mode {}, require \"copy\" or \"add\" " + "Invalid deserialize mode {}, require \"copy\", \"add\" or \"sub\" " .format(mode)) current_index += numel @staticmethod - def serialize_trainable_model(model: torch.nn.Module) -> torch.Tensor: - """Unfold model parameters + def serialize_trainable_model(model: torch.nn.Module, cpu:bool=True) -> torch.Tensor: + """Unfold trainable model parameters. - Unfold every layer of model, concate all of tensors into one. + Unfold every layer of model by iterating though ``model.parameters()``, then concate all of tensors into one vector. Return a `torch.Tensor` with shape (size, ). Args: model (torch.nn.Module): model to serialize. + cpu (bool, optional): Whether move the vectorized parameter to ``torch.device('cpu')`` by force. Defaults to ``True``. If ``cpu`` is ``False``, the returned vector is on the same device as ``model``. """ parameters = [param.data.view(-1) for param in model.parameters()] m_parameters = torch.cat(parameters) - m_parameters = m_parameters.cpu() + if cpu: + m_parameters = m_parameters.cpu() return m_parameters @@ -146,14 +131,14 @@ def serialize_trainable_model(model: torch.nn.Module) -> torch.Tensor: def deserialize_trainable_model(model: torch.nn.Module, serialized_parameters: torch.Tensor, mode="copy"): - """Assigns serialized parameters to model.parameters. + """Assigns serialized trainable parameters to ``model.parameters``. This is done by iterating through ``model.parameters()`` and assigning the relevant params in ``grad_update``. - NOTE: this function manipulates ``model.parameters``. + NOTE: this function manipulates ``model.parameters()``. Args: model (torch.nn.Module): model to deserialize. serialized_parameters (torch.Tensor): serialized model parameters. - mode (str): deserialize mode. "copy" or "add". + mode (str): deserialize mode. Support "copy", "add", and "sub". """ current_index = 0 # keep track of where to read from grad_update for parameter in model.parameters(): @@ -167,8 +152,12 @@ def deserialize_trainable_model(model: torch.nn.Module, parameter.data.add_( serialized_parameters[current_index:current_index + numel].view(size)) + elif mode == "sub": + parameter.data.sub_( + serialized_parameters[current_index:current_index + + numel].view(size)) else: raise ValueError( - "Invalid deserialize mode {}, require \"copy\" or \"add\" " + "Invalid deserialize mode {}, require \"copy\", \"add\" or \"sub\" " .format(mode)) current_index += numel \ No newline at end of file