diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index def3f7070..1546ecb67 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -3,11 +3,16 @@ from abc import ABC from abc import abstractmethod +from collections import OrderedDict import inspect from inspect import getcallargs +from typing import Any, Callable, Dict, Type, Union import torch +from torch import Tensor from torch.nn import Module +from torch.nn import Parameter +import torch.nn.utils.parametrize as parametrize from torch.overrides import get_testing_overrides from brevitas.fx import GraphModule @@ -154,7 +159,18 @@ def _init_new_module(self, old_module: Module, name=None): def _replace_old_module(self, model, old_module, new_module, load_state_dict=True): replace_module(model, old_module, new_module) if load_state_dict: - new_module.load_state_dict(old_module.state_dict()) + # The dictionary entries relative to parametrizations need to be ignored, as these are passed + # when invoking transfer_parametrizations_and_params. + old_module_state_dict = OrderedDict({ + k: v for k, + v in old_module.state_dict().items() if not k.startswith("parametrizations")}) + # If the old module is parametrized, these need to be transferred to the new module. Strict needs to be set to False, + # as there will be missing keys for those parameters which have any parametrizations attached. + if parametrize.is_parametrized(old_module): + new_module.load_state_dict(old_module_state_dict, strict=False) + parametrize.transfer_parametrizations_and_params(old_module, new_module) + else: + new_module.load_state_dict(old_module_state_dict) class InsertModuleCallAfter(GraphTransform): @@ -174,6 +190,76 @@ def apply(self, graph_model: GraphModule) -> GraphModule: return graph_model +class ModuleInstanceRegisterParametrization(Transform): + + def __init__( + self, old_module_instance: Module, tensor_name: str, + parametrization_module: Module) -> None: + self.old_module_instance = old_module_instance + self.tensor_name = tensor_name + self.parametrization_module = parametrization_module + + def apply(self, model: GraphModule) -> GraphModule: + for old_module in model.modules(): + if old_module is self.old_module_instance: + # register the parametrization in the old_module + parametrize.register_parametrization( + old_module, self.tensor_name, self.parametrization_module) + break + return model + + +class ModuleInstanceFuseRotationWeights(Transform): + + def __init__( + self, + old_module_instance: Module, + rot_mat: Union[Parameter, Tensor], + rot_func: Callable, + K: int, + tensor_name: str, + axis: int, + is_source: bool, + ): + self.old_module_instance = old_module_instance + self.rot_mat = rot_mat + self.rot_func = rot_func + self.K = K + self.tensor_name = tensor_name + self.axis = axis + self.is_source = is_source + + def apply(self, model: GraphModule) -> GraphModule: + for old_module in model.modules(): + if old_module is self.old_module_instance: + if hasattr(old_module, 'allocate_params'): + old_module.allocate_params(old_module) + weight = getattr(old_module, self.tensor_name).data + + if self.is_source: + if self.axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + elif self.axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + else: + raise RuntimeError("Not supported yet") + # If not a source, the module is either a sink or an orphan + else: + if self.axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + elif self.axis == 0: + weight = 
self.rot_func(weight.t(), self.rot_mat, self.K).t() + else: + raise RuntimeError("Not supported yet") + # Modify the weights in-place + getattr(old_module, self.tensor_name).data = weight + + if hasattr(old_module, 'offload_params'): + old_module.offload_params(old_module) + break + return model + + class ModuleInstanceToModuleInstance(Transform): def __init__(self, old_module_instance, new_module_instance): @@ -189,6 +275,31 @@ def apply(self, model: GraphModule) -> GraphModule: return model +class ModuleInstanceWrapModule(Transform): + + def __init__( + self, + old_module_instance: Module, + wrapper_class: Type[Module], + module_attribute: str, + kwargs_wrapper: Dict[str, Any]): + self.old_module_instance = old_module_instance + self.wrapper_class = wrapper_class + self.module_attribute = module_attribute + self.kwargs_wrapper = kwargs_wrapper + + def apply(self, model: GraphModule) -> GraphModule: + for old_module in model.modules(): + if old_module is self.old_module_instance: + kwargs = {self.module_attribute: self.old_module_instance} + kwargs.update(self.kwargs_wrapper) + new_module_instance = self.wrapper_class(**kwargs) + # init the new module based on the old one + replace_module(model, old_module, new_module_instance) + break + return model + + class ModuleToModuleByName(ModuleToModule): def __init__(self, old_module_name, new_module_class, **kwargs): diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 4e5c1a162..0da49233c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -3,6 +3,7 @@ from abc import ABC from abc import abstractmethod +from collections import defaultdict from dataclasses import dataclass from dataclasses import field from functools import partial @@ -15,25 +16,31 @@ import torch from torch.fx import GraphModule as TorchGraphModule import torch.nn as nn +import torch.nn.utils.parametrize as parametrize from brevitas import torch_version from brevitas.fx import GraphModule from brevitas.fx import Node -from brevitas.graph import ModuleToModuleByClass from brevitas.graph import ModuleToModuleByInstance from brevitas.graph.base import GraphTransform from brevitas.graph.base import InsertModuleCallAfter +from brevitas.graph.base import ModuleInstanceFuseRotationWeights +from brevitas.graph.base import ModuleInstanceRegisterParametrization from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.graph.base import ModuleInstanceWrapModule from brevitas.graph.base import Transform from brevitas.graph.hadamard import get_hadK from brevitas.graph.hadamard import matmul_hadU from brevitas.graph.hadamard import matmul_hadU_cuda +from brevitas.graph.hadamard import random_hadamard_matrix from brevitas.graph.utils import get_module from brevitas.graph.utils import get_node from brevitas.nn.equalized_layer import EqualizedModule from brevitas.nn.equalized_layer import functional_rotate_input from brevitas.nn.equalized_layer import INPUT_NAMES from brevitas.nn.equalized_layer import RotatedModule +from brevitas.nn.equalized_layer import RotationBiasParametrization +from brevitas.nn.equalized_layer import RotationWeightParametrization from brevitas.nn.quant_scale_bias import ScaleBias from brevitas.utils.torch_utils import KwargsForwardHook @@ -339,6 +346,8 @@ def _get_input_axis(module: nn.Module) -> Optional[int]: return 0 else: return None + elif isinstance(module, (RotatedModule,)): + return _get_input_axis(module.layer) else: return None @@ -367,6 +376,8 @@ def _get_output_axis(module: 
nn.Module) -> Optional[int]: return 0 else: return None + elif isinstance(module, (RotatedModule,)): + return _get_output_axis(module.layer) else: return None @@ -1275,6 +1286,10 @@ def _apply_had_device(tensor, had_K, K): def _apply_ort_device(tensor, ort, *args): ort = ort.type_as(tensor) + if tensor.shape[-1] != ort.shape[0]: + tensor_shape = tensor.shape + return torch.matmul(tensor.view(-1, tensor_shape[-1] // ort.shape[0], ort.shape[0]), + ort).view(tensor_shape) return torch.matmul(tensor, ort) @@ -1299,7 +1314,12 @@ def random_orthogonal_matrix(size): return q -def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had'): +def _apply_rotate( + model: nn.Module, + regions: List[Region], + full_rotation_method: str = 'had', + fuse_rotations: bool = True, + apply_inplace_rotations: bool = True): rewriters = [] for region in regions: insert_rotation_module = len(region.srcs) == 0 @@ -1311,6 +1331,13 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= rot_mat = random_orthogonal_matrix(hidden_dim) K = None rot_func = _apply_ort_device + elif not insert_rotation_module and not fuse_rotations: + # TODO: This might be problematic if the parameters are distributed + # across devices. Generalize this logic for safety. + device = next(model.parameters()).device + rot_mat = random_hadamard_matrix(hidden_dim, device) + K = None + rot_func = _apply_ort_device else: try: # Build hadamard rotation matrix @@ -1326,51 +1353,104 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= print("Skipping layers") continue + # If the rotation is not fused, redefine as a Parameter, to enable its optimization + if not insert_rotation_module and not fuse_rotations: + rot_mat = torch.nn.Parameter(rot_mat) + for name, indexes in region.srcs.items(): module = region.get_module_from_name(name) - if hasattr(module, 'allocate_params'): - module.allocate_params(module) axis = _get_output_axis(module) - weight = module.weight.data - if axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() - elif axis == 1: - weight = rot_func(weight, rot_mat, K) - else: - raise RuntimeError("Not supported yet") - module.weight.data = weight + if fuse_rotations: + rewriter = ModuleInstanceFuseRotationWeights( + old_module_instance=module, + rot_mat=rot_mat, + rot_func=rot_func, + K=K, + tensor_name="weight", + axis=axis, + is_source=True, + ) + rewriters.append(rewriter) - if getattr(module, 'bias', None) is not None: - bias = module.bias.data - bias = rot_func(bias, rot_mat, K) - module.bias.data = bias - if hasattr(module, 'offload_params'): - module.offload_params(module) + if getattr(module, 'bias', None) is not None: + rewriter = ModuleInstanceFuseRotationWeights( + old_module_instance=module, + rot_mat=rot_mat, + rot_func=rot_func, + K=K, + tensor_name="bias", + axis=1, + is_source=True, + ) + rewriters.append(rewriter) + else: + rewriter = ModuleInstanceRegisterParametrization( + module, + "weight", + RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + axis=axis, + is_source=True, + )) + rewriters.append(rewriter) + if getattr(module, 'bias', None) is not None: + # TODO: Consolidate RotationBiasParametrization into a single + # class, by setting output_axis = 1. 
Also, could use a single + # axis, as input_axis and output_axis are not used simultaneously + rewriter = ModuleInstanceRegisterParametrization( + module, + "bias", + RotationBiasParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + output_axis=axis, + is_source=True, + )) + rewriters.append(rewriter) for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) - if hasattr(module, 'allocate_params'): - module.allocate_params(module) axis = _get_input_axis(module) - weight = module.weight.data - if axis == 1: - _update_weights(module, rot_func(weight, rot_mat, K), 'weight') - elif axis == 0: - _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') + if not insert_rotation_module and not fuse_rotations: + rewriter = ModuleInstanceRegisterParametrization( + module, + "weight", + RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + axis=axis, + is_sink=True, + )) + rewriters.append(rewriter) else: - raise RuntimeError("Not supported yet") - - if hasattr(module, 'offload_params'): - module.offload_params(module) + # Verify that there are no parametrizations, as otherwise the underlying weights will not be updated + assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." + + rewriter = ModuleInstanceFuseRotationWeights( + old_module_instance=module, + rot_mat=rot_mat, + rot_func=rot_func, + K=K, + tensor_name="weight", + axis=axis, + is_source=False, + ) + rewriters.append(rewriter) if insert_rotation_module and len(region.srcs) == 0: - rewriter = ModuleInstanceToModuleInstance( - module, RotatedModule(had_mat=rot_mat, k=K, layer=module)) + rewriter = ModuleInstanceWrapModule( + module, RotatedModule, "layer", { + "had_mat": rot_mat, "k": K}) rewriters.append(rewriter) for r in rewriters: - model = r.apply(model) + # The parametrizations need to be registered after the potential HF hooks have been + # removed, as otherwise the device maps will not match the structure of the + # model's state_dict after the registration of the parametrizations. 
+ if apply_inplace_rotations and not isinstance(r, ModuleInstanceRegisterParametrization): + model = r.apply(model) return rewriters @@ -1463,8 +1543,13 @@ def rotate_matmuls(self, graph_module): graph_module.recompile() graph_module.graph.lint() - def apply(self, - graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + def apply( + self, + graph_model: GraphModule, + fuse_rotations: bool = True, + additional_regions: Optional[List[Region]] = None, + apply_inplace_rotations: bool = True, + ) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( graph_model, @@ -1473,6 +1558,8 @@ def apply(self, 'supported_sinks': self.supported_sinks, 'scale_invariant_layers': self.scale_invariant_layers, 'scale_invariant_function': self.scale_invariant_function}) + if additional_regions is not None: + regions.extend(additional_regions) eq_layers = set() orphan_regions = [] self.find_module(graph_model, orphan_regions) @@ -1484,11 +1571,22 @@ def apply(self, # Layerwise have only a single sink named 'sinks0' id_sink = id(o_r.get_module_from_name('sinks0')) if id_sink not in eq_layers: - regions.append(o_r) + # Orphan regions result in an in-place update of the weights, so these are applied before + # the rest of the rotations, to simplify the logic when fuse_rotations = False, as, + # otherwise, additional checks need to be incorporated to verify if the module weights + # have any parametrizations already, since in that case, the in-place update needs to + # be performed in module.parametrizations.weight.original. + # TODO: Use deque to perform this operation in O(1) + regions = [o_r] + regions if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) + rewriters = _apply_rotate( + graph_model, + regions, + self.full_rotation_method, + fuse_rotations, + apply_inplace_rotations) if self.return_rewriters: return graph_model, rewriters else: @@ -1581,7 +1679,7 @@ def __init__(self, blacklist_layer=None): self.supported_sinks = (nn.Linear) self.blacklist_layers = blacklist_layer - def apply(self, model: nn.Module) -> nn.Module: + def apply(self, model: nn.Module, fuse_rotations: bool = True) -> nn.Module: regions: List[Region] = [] self.find_module(model, regions) if len(regions) > 0: diff --git a/src/brevitas/graph/hadamard.py b/src/brevitas/graph/hadamard.py index 235e22567..c74695e00 100644 --- a/src/brevitas/graph/hadamard.py +++ b/src/brevitas/graph/hadamard.py @@ -66,6 +66,23 @@ def get_hadK(n, transpose=False): assert (is_pow2(n // 12)) K = 12 hadK = tensors['get_had12'].T if transpose else tensors['get_had12'] + # TODO: Add this matrix along with the others + elif n % 64 == 0: + assert (is_pow2(n // 64)) + K = 64 + hadK = torch.tensor([[1 if char == '+' else -1 + for char in line] + for line in hadamard_string_64.strip().split('\n')], + dtype=torch.float32, + requires_grad=False) + elif n % 16 == 0: + assert (is_pow2(n // 16)) + K = 16 + hadK = torch.tensor([[1 if char == '+' else -1 + for char in line] + for line in hadamard_string_16.strip().split('\n')], + dtype=torch.float32, + requires_grad=False) else: assert (is_pow2(n)) K = 1 @@ -106,7 +123,7 @@ def random_hadamard_matrix(size, device): Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) Q = Q * 2 - 1 Q = torch.diag(Q) - return matmul_hadU(Q).to(device) + return matmul_hadU(Q).to(device).float() def matmul_hadU_cuda(X, hadK, K): @@ -166,3 +183,86 @@ def 
apply_exact_had_to_linear(module, had_dim=-1, output=False): def is_pow2(n): return (n & (n - 1) == 0) and (n > 0) + + +hadamard_string_16 = """++++++++++++++++ ++-+-+-+-+-+-+-+- +++--++--++--++-- ++--++--++--++--+ +++++----++++---- ++-+--+-++-+--+-+ +++----++++----++ ++--+-++-+--+-++- +++++++++-------- ++-+-+-+--+-+-+-+ +++--++----++--++ ++--++--+-++--++- +++++--------++++ ++-+--+-+-+-++-+- +++----++--++++-- ++--+-++--++-+--+""" + +hadamard_string_64 = """++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- +++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++-- ++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--+ +++++----++++----++++----++++----++++----++++----++++----++++---- ++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-+ +++----++++----++++----++++----++++----++++----++++----++++----++ ++--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++- +++++++++--------++++++++--------++++++++--------++++++++-------- ++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-+ +++--++----++--++++--++----++--++++--++----++--++++--++----++--++ ++--++--+-++--++-+--++--+-++--++-+--++--+-++--++-+--++--+-++--++- +++++--------++++++++--------++++++++--------++++++++--------++++ ++-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+- +++----++--++++--++----++--++++--++----++--++++--++----++--++++-- ++--+-++--++-+--++--+-++--++-+--++--+-++--++-+--++--+-++--++-+--+ +++++++++++++++++----------------++++++++++++++++---------------- ++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-+ +++--++--++--++----++--++--++--++++--++--++--++----++--++--++--++ ++--++--++--++--+-++--++--++--++-+--++--++--++--+-++--++--++--++- +++++----++++--------++++----++++++++----++++--------++++----++++ ++-+--+-++-+--+-+-+-++-+--+-++-+-+-+--+-++-+--+-+-+-++-+--+-++-+- +++----++++----++--++++----++++--++----++++----++--++++----++++-- ++--+-++-+--+-++--++-+--+-++-+--++--+-++-+--+-++--++-+--+-++-+--+ +++++++++----------------++++++++++++++++----------------++++++++ ++-+-+-+--+-+-+-+-+-+-+-++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-++-+-+-+- +++--++----++--++--++--++++--++--++--++----++--++--++--++++--++-- ++--++--+-++--++--++--++-+--++--++--++--+-++--++--++--++-+--++--+ +++++--------++++----++++++++----++++--------++++----++++++++---- ++-+--+-+-+-++-+--+-++-+-+-+--+-++-+--+-+-+-++-+--+-++-+-+-+--+-+ +++----++--++++----++++--++----++++----++--++++----++++--++----++ ++--+-++--++-+--+-++-+--++--+-++-+--+-++--++-+--+-++-+--++--+-++- +++++++++++++++++++++++++++++++++-------------------------------- ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+--+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +++--++--++--++--++--++--++--++----++--++--++--++--++--++--++--++ ++--++--++--++--++--++--++--++--+-++--++--++--++--++--++--++--++- +++++----++++----++++----++++--------++++----++++----++++----++++ ++-+--+-++-+--+-++-+--+-++-+--+-+-+-++-+--+-++-+--+-++-+--+-++-+- +++----++++----++++----++++----++--++++----++++----++++----++++-- ++--+-++-+--+-++-+--+-++-+--+-++--++-+--+-++-+--+-++-+--+-++-+--+ +++++++++--------++++++++----------------++++++++--------++++++++ ++-+-+-+--+-+-+-++-+-+-+--+-+-+-+-+-+-+-++-+-+-+--+-+-+-++-+-+-+- +++--++----++--++++--++----++--++--++--++++--++----++--++++--++-- ++--++--+-++--++-+--++--+-++--++--++--++-+--++--+-++--++-+--++--+ +++++--------++++++++--------++++----++++++++--------++++++++---- ++-+--+-+-+-++-+-+-+--+-+-+-++-+--+-++-+-+-+--+-+-+-++-+-+-+--+-+ 
+++----++--++++--++----++--++++----++++--++----++--++++--++----++ ++--+-++--++-+--++--+-++--++-+--+-++-+--++--+-++--++-+--++--+-++- +++++++++++++++++--------------------------------++++++++++++++++ ++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-++-+-+-+-+-+-+-+- +++--++--++--++----++--++--++--++--++--++--++--++++--++--++--++-- ++--++--++--++--+-++--++--++--++--++--++--++--++-+--++--++--++--+ +++++----++++--------++++----++++----++++----++++++++----++++---- ++-+--+-++-+--+-+-+-++-+--+-++-+--+-++-+--+-++-+-+-+--+-++-+--+-+ +++----++++----++--++++----++++----++++----++++--++----++++----++ ++--+-++-+--+-++--++-+--+-++-+--+-++-+--+-++-+--++--+-++-+--+-++- +++++++++----------------++++++++--------++++++++++++++++-------- ++-+-+-+--+-+-+-+-+-+-+-++-+-+-+--+-+-+-++-+-+-+-+-+-+-+--+-+-+-+ +++--++----++--++--++--++++--++----++--++++--++--++--++----++--++ ++--++--+-++--++--++--++-+--++--+-++--++-+--++--++--++--+-++--++- +++++--------++++----++++++++--------++++++++----++++--------++++ ++-+--+-+-+-++-+--+-++-+-+-+--+-+-+-++-+-+-+--+-++-+--+-+-+-++-+- +++----++--++++----++++--++----++--++++--++----++++----++--++++-- ++--+-++--++-+--+-++-+--++--+-++--++-+--++--+-++-+--+-++--++-+--+""" diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 535f9a8f9..538ce5717 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn +import torch.nn.utils.parametrize as parametrize +from tqdm import tqdm import brevitas from brevitas.graph.base import InsertModuleCallAfter @@ -511,7 +513,7 @@ def find_module( Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its Linear submodules. """ - if _module_class_name(type(model)) in layer_map.keys(): + if _module_class_name(parametrize.type_before_parametrizations(model)) in layer_map.keys(): module_to_replace.append(model) else: for name, module in model.named_children(): @@ -532,10 +534,11 @@ def layerwise_layer_handler( find_module(model, layer_map, module_to_replace, name_blacklist) rewriters = [] for module in module_to_replace: - if layer_map[_module_class_name(type(module))] is not None: - quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))] + if layer_map[_module_class_name( + parametrize.type_before_parametrizations(module))] is not None: + quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))] rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs) rewriters.append(rewriter) - for rewriter in rewriters: + for rewriter in tqdm(rewriters, leave=False): model = rewriter.apply(model) return model diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 8413a8208..2c48f9da3 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -1,4 +1,6 @@ +import functools from inspect import signature +from typing import Callable, Optional import torch @@ -51,6 +53,11 @@ def forward(self, *args, **kwargs): return out +def _apply_ort_device(tensor, ort, *args): + ort = ort.type_as(tensor) + return torch.matmul(tensor, ort) + + class RotatedModule(torch.nn.Module): def __init__(self, layer, had_mat=None, k=None) -> None: @@ -64,20 +71,100 @@ def __init__(self, layer, had_mat=None, k=None) -> None: def forward(self, inp, **kwargs): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None - if is_cuda and 
fast_hadamard_transform is not None: - if self.had_mat is None or self.k is None: - had_K, K = get_hadK(inp.shape[-1]) - else: - had_K = self.had_mat - K = self.k - inp = matmul_hadU_cuda(inp, had_K, K) + # If k is None, we assume that an orthogonal matrix is used + if self.k is None: + inp = _apply_ort_device(inp, self.had_mat) else: - inp = matmul_hadU(inp) + if is_cuda and fast_hadamard_transform is not None: + if self.had_mat is None or self.k is None: + had_K, K = get_hadK(inp.shape[-1]) + else: + had_K = self.had_mat + K = self.k + inp = matmul_hadU_cuda(inp, had_K, K) + else: + inp = matmul_hadU(inp) o = self.layer(inp) return o +def rot_func_wrapper(weight: torch.Tensor, rot_mat: torch.Tensor, rotation_function: Callable): + weight_shape = weight.shape + rot_mat_dim = rot_mat.shape[0] + return rotation_function(weight.view(-1, weight_shape.shape[1] // rot_mat_dim, + rot_mat_dim)).view(weight_shape) + + +class RotationWeightParametrization(torch.nn.Module): + + def __init__( + self, + rot_mat: torch.nn.Parameter, + rot_func: Callable, + axis: int, + is_source: bool = False, + is_sink: bool = False, + is_orphan: bool = False, + ) -> None: + super().__init__() + self.rot_mat = rot_mat + self.rot_func = rot_func + self.axis = axis + self.is_source = is_source + self.is_sink = is_sink + self.is_orphan = is_orphan + self.K = None + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + if self.is_sink or self.is_orphan: + if self.axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + elif self.axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + else: + raise RuntimeError("Not supported yet") + + if self.is_source: + if self.axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + elif self.axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + else: + raise RuntimeError("Not supported yet") + + return weight + + +class RotationBiasParametrization(torch.nn.Module): + + def __init__( + self, + rot_mat: torch.nn.Parameter, + rot_func: Callable, + input_axis: Optional[int] = None, + output_axis: Optional[int] = None, + is_source: bool = False, + is_sink: bool = False, + is_orphan: bool = False, + ) -> None: + super().__init__() + self.rot_mat = rot_mat + self.rot_func = rot_func + self.input_axis = input_axis + self.output_axis = output_axis + self.is_source = is_source + self.is_sink = is_sink + self.is_orphan = is_orphan + self.K = None + + def forward(self, bias: torch.Tensor) -> torch.Tensor: + if self.is_source: + bias = self.rot_func(bias, self.rot_mat, self.K) + + return bias + + def functional_rotate_input(inp, transpose=False): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None if transpose: diff --git a/src/brevitas/optim/sgdg.py b/src/brevitas/optim/sgdg.py new file mode 100644 index 000000000..d1e89e39b --- /dev/null +++ b/src/brevitas/optim/sgdg.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +# This code is originally from: https://github.com/JunLi-Galios/Optimization-on-Stiefel-Manifold-via-Cayley-Transform/blob/master/stiefel_optimizer.py + +import random + +import torch +from torch.optim.optimizer import Optimizer + + +def unit(v, dim: int = 1, eps: float = 1e-8): + vnorm = norm(v, dim) + return v / vnorm.add(eps), vnorm + + +def norm(v, dim: int = 1): + assert len(v.size()) == 2 + return v.norm(p=2, dim=dim, keepdim=True) + + +def matrix_norm_one(W): + out = torch.abs(W) + out = torch.sum(out, dim=0) + out = torch.max(out) + return out + + +def Cayley_loop(X, W, tan_vec, t): # + [n, p] = X.size() + Y = X + t * tan_vec + for i in range(5): + Y = X + t * torch.matmul(W, 0.5 * (X + Y)) + + return Y.t() + + +def qr_retraction(tan_vec): # tan_vec, p-by-n, p <= n + [p, n] = tan_vec.size() + tan_vec.t_() + q, r = torch.linalg.qr(tan_vec) + d = torch.diag(r, 0) + ph = d.sign() + q *= ph.expand_as(q) + q.t_() + + return q + + +episilon = 1e-8 + + +class SGDG(Optimizer): + r"""This optimizer updates variables with two different routines + based on the boolean variable 'stiefel'. + + If stiefel is True, the variables will be updated by SGD-G proposed + as decorrelated weight matrix. + + If stiefel is False, the variables will be updated by SGD. + This routine was taken from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + + -- common parameters + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + stiefel (bool, optional): whether to use SGD-G (default: False) + + -- parameters in case stiefel is False + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + + -- parameters in case stiefel is True + omega (float, optional): orthogonality regularization factor (default: 0) + grad_clip (float, optional): threshold for gradient norm clipping (default: None) + """ + + def __init__( + self, + params, + lr: float = 1e-3, + momentum: int = 0, + dampening: int = 0, + weight_decay: int = 0, + nesterov: bool = False, + stiefel: bool = False, + omega: int = 0, + grad_clip=None, + ) -> None: + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + stiefel=stiefel, + omega=0, + grad_clip=grad_clip, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(SGDG, self).__init__(params, defaults) + + def __setstate__(self, state) -> None: + super(SGDG, self).__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + momentum = group["momentum"] + stiefel = group["stiefel"] + + for p in group["params"]: + if p.grad is None: + continue + + unity, _ = unit(p.data.view(p.size()[0], -1)) + if stiefel and unity.size()[0] <= unity.size()[1]: + weight_decay = group["weight_decay"] + dampening = group["dampening"] + nesterov = group["nesterov"] + + rand_num = random.randint(1, 101) + if rand_num == 1: + unity = qr_retraction(unity) + + g = p.grad.data.view(p.size()[0], -1) + + lr = group["lr"] + + param_state = self.state[p] + if "momentum_buffer" not in param_state: + param_state["momentum_buffer"] = torch.zeros(g.t().size()) + if p.is_cuda: + param_state["momentum_buffer"] = param_state["momentum_buffer"].cuda() + + V = param_state["momentum_buffer"] + V = momentum * V - g.t() + MX = torch.mm(V, unity) + XMX = torch.mm(unity, MX) + XXMX = torch.mm(unity.t(), XMX) + W_hat = MX - 0.5 * XXMX + W = W_hat - W_hat.t() + t = 0.5 * 2 / (matrix_norm_one(W) + episilon) + alpha = min(t, lr) + + p_new = Cayley_loop(unity.t(), W, V, alpha) + V_new = torch.mm(W, unity.t()) # n-by-p + # check_identity(p_new.t()) + p.data.copy_(p_new.view(p.size())) + V.copy_(V_new) + + else: + d_p = p.grad.data + # defined. + try: + if weight_decay != 0: + # defined. + d_p.add_(weight_decay, p.data) + except: + pass + if momentum != 0: + param_state = self.state[p] + if "momentum_buffer" not in param_state: + buf = param_state["momentum_buffer"] = d_p.clone() + else: + buf = param_state["momentum_buffer"] + # always defined. + buf.mul_(momentum).add_(1 - dampening, d_p) + # defined. + if nesterov: + d_p = d_p.add(momentum, buf) + else: + d_p = buf + + p.data.add_(-group["lr"], d_p) + + return loss diff --git a/src/brevitas_examples/common/accelerate_utils/accelerate.py b/src/brevitas_examples/common/accelerate_utils/accelerate.py index ead616ed2..369c456c0 100644 --- a/src/brevitas_examples/common/accelerate_utils/accelerate.py +++ b/src/brevitas_examples/common/accelerate_utils/accelerate.py @@ -407,7 +407,6 @@ def offload_model( else: device_map = infer_auto_device_map( model, memory_map, no_split_module_classes=model._no_split_modules) - model = dispatch_model(model, device_map) # Fixes an asymetric behavior in Accelerate where hooks are not attached at all when a single device is used. diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py new file mode 100644 index 000000000..618763498 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -0,0 +1,98 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +""" + +from dataclasses import dataclass +from dataclasses import field +from typing import Optional, Tuple + +import torch +from torch.utils.data import Dataset +import transformers +from transformers import Trainer +from transformers.tokenization_utils import PreTrainedTokenizerBase + +from brevitas.optim.sgdg import SGDG +from brevitas_examples.llm.llm_quant.rotation_utils import extract_trainable_rotation_matrices + + +@dataclass +class ModelArguments: + input_model: Optional[str] = field( + default="meta-llama/Llama-3.2-1B", metadata={"help": "Input model"}) + output_rotation_path: Optional[str] = field( + default="test-output", metadata={"help": "Output rotation checkpoint path"}) + optimized_rotation_path: Optional[str] = field( + default=None, metadata={"help": "Optimized rotation checkpoint path"}) + access_token: Optional[str] = field( + default="", + metadata={"help": "Huggingface access token to access gated repo like Llama"}, + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + output_dir: Optional[str] = field(default="/tmp/output/") + use_cpu: Optional[bool] = field(default="False") + model_max_length: Optional[int] = field( + default=2048, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)"}, + ) + + +def parse_optimization_rotation_args(unknown_args=None) -> None: + parser = transformers.HfArgumentParser(( + ModelArguments, + TrainingArguments, + )) + _, training_args = parser.parse_args_into_dataclasses(args=unknown_args) + return training_args + + +def collate_fn(kwargs_list, return_tensors="pt"): + # Keyword arguments + kwargs = {} + for curr_dict in kwargs_list: + for key, value in curr_dict.items(): + if isinstance(value, torch.Tensor): + if key not in kwargs: + kwargs[key] = [] + kwargs[key].append(value) + else: + if key not in kwargs: + kwargs[key] = value + for key, value in kwargs.items(): + if isinstance(value, list) and len(value) > 0: + kwargs[key] = torch.cat(kwargs[key], dim=0) + # FP outputs + return kwargs + + +def apply_rotation_optimization( + graph_model: torch.fx.GraphModule, + tokenizer: PreTrainedTokenizerBase, + train_dataset: Dataset, + unknown_args=None) -> None: + # Get training arguments + training_args = parse_optimization_rotation_args(unknown_args) + # Set to False the model parameters + for param in graph_model.parameters(): + param.requires_grad = False + # Collect trainable matrices + trainable_rotations = extract_trainable_rotation_matrices(graph_model) + for rot_mat in trainable_rotations: + rot_mat.requires_grad = True + optimizer = SGDG(trainable_rotations, lr=training_args.learning_rate, stiefel=True) + trainer = Trainer( + model=graph_model, + tokenizer=tokenizer, + args=training_args, + train_dataset=train_dataset, + eval_dataset=None, + data_collator=collate_fn, + optimizers=(optimizer, None)) + trainer.train() diff --git a/src/brevitas_examples/llm/llm_quant/rotation_utils.py b/src/brevitas_examples/llm/llm_quant/rotation_utils.py new file mode 100644 index 000000000..9c84aeff7 --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/rotation_utils.py @@ -0,0 +1,94 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +import re +from typing import List + +from torch import nn +from torch.fx import GraphModule +import torch.nn.utils.parametrize as parametrize + +from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.graph.base import Transform +from brevitas.graph.equalize import EqualizationIndexes +from brevitas.graph.equalize import Region +from brevitas.graph.equalize import WalkRegionState +from brevitas.nn.equalized_layer import RotationWeightParametrization + + +def find_self_attention_rotation_regions( + graph_model: GraphModule, head_dim: int, state_impl_kwargs=None) -> List[Region]: + regions = [] + # See R2 rotation matrices in https://arxiv.org/pdf/2405.16406. + for src_name, src_module in graph_model.named_modules(): + if "attn_v_proj" in src_name: + if state_impl_kwargs is not None: + state = WalkRegionState(**state_impl_kwargs) + else: + state = WalkRegionState() + + block_number_matches_src = re.findall(r'\d+', src_name) + assert len(block_number_matches_src) == 2, "Could not identify block" + block_number_src = int(block_number_matches_src[1]) + + eq_indexes = EqualizationIndexes(0, head_dim, 0) + state.add_srcs(src_name, src_module, eq_indexes) + + # Now the corresponding sink + for sink_name, sink_module in graph_model.named_modules(): + if "attn_o_proj" in sink_name: + block_number_matches_sink = re.findall(r'\d+', sink_name) + assert len(block_number_matches_sink) == 2, "Could not identify block" + block_number_sink = int(block_number_matches_sink[1]) + # If the blocks match, the region was identified + if block_number_src == block_number_sink: + eq_indexes = EqualizationIndexes(0, head_dim, state.offset) + state.add_sinks(sink_name, sink_module, eq_indexes) + region = Region( + srcs=dict(sorted(state.srcs.items())), + sinks=dict(sorted(state.sinks.items())), + name_to_module=state.name_to_module, + ) + if region not in regions: + regions.append(region) + + return regions + + +def fuse_rotations(model: nn.Module) -> None: + for module in model.modules(): + # Check if the module has any parametrizations + if hasattr(module, "parametrizations"): + # Remove weight parametrizations + parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) + # We need to check again, in case the weight parametrizations were the only ones + if hasattr(module, "parametrizations") and hasattr(module.parametrizations, "bias"): + parametrize.remove_parametrizations(module, "bias", leave_parametrized=True) + + +# TODO: Remove? We rely on ModuleInstanceRegisterParametrization +def extract_rewriters_unfused_rotations(model: nn.Module, + rewriters: List[Transform]) -> List[Transform]: + extra_rewriters = [] + for module in model.modules(): + if hasattr(module, "parametrizations"): + # Verify that the current module does not have already associated a RotatedModule + if len([r for r in rewriters if r.old_module_instance is module and + isinstance(r, ModuleInstanceToModuleInstance)]) == 0: + # Identity rewriter, only useful externaly + rewriter = ModuleInstanceToModuleInstance(module, module) + extra_rewriters.append(rewriter) + return extra_rewriters + + +def extract_trainable_rotation_matrices(model: nn.Module) -> List[nn.Parameter]: + trainable_rotations = [] + # We need to keep track of the IDs of the rotation matrices, as several modules + # can share the same parametrized rotation. 
+ ids_rot = set() + for module in model.modules(): + if isinstance(module, RotationWeightParametrization): + if id(module.rot_mat) not in ids_rot: + ids_rot.add(id(module.rot_mat)) + trainable_rotations.append(module.rot_mat) + return trainable_rotations diff --git a/src/brevitas_examples/llm/llm_quant/run_utils.py b/src/brevitas_examples/llm/llm_quant/run_utils.py index 44ba711a5..4acbd87ed 100644 --- a/src/brevitas_examples/llm/llm_quant/run_utils.py +++ b/src/brevitas_examples/llm/llm_quant/run_utils.py @@ -110,15 +110,24 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return out +def _get_tensor_weight_id(module, tensor_name): + if hasattr(module, "parametrizations") and tensor_name in module.parametrizations: + return id(module.parametrizations[tensor_name].original) + elif hasattr(module, tensor_name): + return id(getattr(module, tensor_name)) + return None + + # This functions remap rewriters so match modules in a potentially different model that shares the same underlying tensors # We rely on the fact that two versions of the same model (eager vs FX) might have different modules id (id(fx_module) != id (eager_module)) # However, the underlying tensors are still shared, so we can recostruct the mapping between the two # modules. def fix_rewriter(rewriters, old_model_ref, tensor_name): + # We need to account for reparametrizations, to make sure the underlying tensors are accessed for r in rewriters: - tensor_id = id(r.old_module_instance.weight) + tensor_id = _get_tensor_weight_id(r.old_module_instance, tensor_name) module = [ - m for m in old_model_ref.modules() - if hasattr(m, tensor_name) and id(m.weight) == tensor_id] + m for m in old_model_ref.modules() if _get_tensor_weight_id(m, tensor_name) == tensor_id + ] r.old_module_instance = module[0] return rewriters diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 495c47919..08eb92769 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -3,7 +3,11 @@ import argparse from copy import deepcopy +from functools import partial +from functools import wraps +import os import sys +from typing import Callable, List from warnings import warn import numpy as np @@ -16,6 +20,8 @@ from brevitas.export import export_torch_qcdq from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas.graph.base import ModuleInstanceFuseRotationWeights +from brevitas.graph.base import Transform from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import layerwise_quantize @@ -39,6 +45,9 @@ from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers +from brevitas_examples.llm.llm_quant.rotation_optimization import apply_rotation_optimization +from brevitas_examples.llm.llm_quant.rotation_utils import extract_rewriters_unfused_rotations +from brevitas_examples.llm.llm_quant.rotation_utils import find_self_attention_rotation_regions from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter from brevitas_examples.llm.llm_quant.run_utils import get_fx @@ -49,7 +58,56 @@ def set_seed(seed): torch.random.manual_seed(seed) -def 
fused_rotation_no_fx(model, calibration_loader, args): +def is_main_process(): + return int(os.environ.get('LOCAL_RANK', -1)) in [-1, 0] + + +def on_process(func: Callable, process_index: int): + + @wraps(func) + def _wrapper(model, *args, **kwargs): + curr_process_index = int(os.environ.get('LOCAL_RANK', -1)) + + if curr_process_index == -1 or (process_index == curr_process_index): + print(f"Applying {func.__name__} on process index {curr_process_index}") + return func(model, *args, **kwargs) + else: + print(f"Skipping function {func.__name__} on process index {curr_process_index}") + return model + + return _wrapper + + +on_main_process = partial(on_process, process_index=0) + + +@on_main_process +def apply_fused_rotations(model: torch.nn.Module, rewriters: List[Transform]) -> torch.nn.Module: + model = offload_model(model) + for r in rewriters: + if isinstance(r, ModuleInstanceFuseRotationWeights): + model = r.apply(model) + remove_hooks(model) + return model + + +@on_main_process +def evaluate_model(model: torch.nn.Module, validation_loader, args, tokenizer): + model = offload_model(model) + quant_ppl = compute_perplexity( + model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) + print(f"Perplexity ({args.dataset}): {quant_ppl:.3f}") + remove_hooks(model) + + +# TODO: Use no_grad? The result of fusing the rotations would yield tensor with requires_grad set to False, +# which might no be a problem, as that flag is set in the appropiate QAT/PTQ algorithms. +def fused_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations: bool = True, + add_self_attention_regions: bool = False): with torch.no_grad(): new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) apply_layernorm_affine_merge(new_model) @@ -58,17 +116,26 @@ def fused_rotation_no_fx(model, calibration_loader, args): for r in rewriters: r.apply(model) - new_model = offload_model(new_model) eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, return_rewriters=True) - new_model, rewriters = eq.apply(new_model) + # Regions with source v_proj and sink o_proj + self_attention_regions = ( + find_self_attention_rotation_regions( + new_model, model.config.hidden_size // + model.config.num_attention_heads) if add_self_attention_regions else None) + new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions, apply_inplace_rotations=False) + # Rewriters need to be fixed to point to the module instances of the original model rewriters = fix_rewriter(rewriters, model, 'weight') - + # The weights of the FX model and the original model are tied, so the rotation fusing has already been applied. + # Note that the parametrization registration cannot be done in a model that has been offloaded using + # offload_model, as the change in the state dictionary when registering the parametrization causes the removal + # of the hooks to crash. This is due to the fact that the device_map in the AlignDevicesHook is no longer valid. 
+ model = apply_fused_rotations(model, rewriters) for r in rewriters: - r.apply(model) - remove_hooks(new_model) + if not isinstance(r, ModuleInstanceFuseRotationWeights): + model = r.apply(model) def set_seed(seed): @@ -168,7 +235,7 @@ def validate(args): "or decreasing the sequence length (seqlen)") -def main(args): +def main(args, unknown_args=None): validate(args) set_seed(args.seed) if args.export_prefix is None: @@ -230,16 +297,30 @@ def main(args): if args.eval: assert args.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously" - print("Float model eval...") - model = offload_model(model) - float_ppl = compute_perplexity( - model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) - remove_hooks(model) - print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}") + print("Evaluating float model...") + evaluate_model(model, validation_loader, args, tokenizer) + print("Float evaluation done.") if args.replace_rmsnorm: model = replace_rmsnorm_with_torch(model, model.config) + # TODO: Refactor + if args.rotation in ['fused_no_fx_optimize', 'fused_no_fx_optimize_self_attn_region']: + for i in range(len(calibration_loader)): + del calibration_loader[i]["attention_mask"] + calibration_loader[i]["labels"] = calibration_loader[i]["input_ids"] + + def mock_save_pretrained_fn(*args, **kwargs): + pass + + # For a PretrainedModel, the Trainer in accelerate calls save_pretrained after + # finishing the optimization. However, this method no longer works after + # registering parametrizations/quantizing, so this method is mocked to prevent + # a crash. + model.save_pretrained = mock_save_pretrained_fn + model.config.use_cache = False + model.config.loss_type = "ForCausalLM" + if require_fx: if model.__class__.__name__ in _SUPPORTED_MODELS and not args.replace_rmsnorm: model = get_fx(model, is_export=args.export_target is not None) @@ -272,6 +353,12 @@ def main(args): model = eq.apply(model) elif args.rotation == 'fused_no_fx': fused_rotation_no_fx(model, calibration_loader, args) + elif args.rotation == 'fused_no_fx_optimize': + fused_rotation_no_fx( + model, calibration_loader, args, fuse_rotations=False, add_self_attention_regions=False) + elif args.rotation == 'fused_no_fx_optimize_self_attn_region': + fused_rotation_no_fx( + model, calibration_loader, args, fuse_rotations=False, add_self_attention_regions=True) # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing # with all the variability in HF implementations @@ -362,10 +449,32 @@ def main(args): if args.bias_corr: model = add_zero_bias_to_linear(model) - model = offload_model(model) + # We need to run a calibration forward pass to initialize quantization-related parameters, + # e.g. scales. In DDP, as parameters are synchronized across replicas before optimization, + # it is not needed to run this pass for every process, as the parameters of the main + # process will be broadcasted to each replica. + if is_main_process(): + model = offload_model(model) + with torch.no_grad(): + model(**calibration_loader[0]) + remove_hooks(model) + else: + # TODO: Generalize this logic. 
Currently, only ParameterFromStatsFromParameterZeroPoint + # and ParameterFromStatsFromParameterScaling have the attribute init_done + for module in model.modules(): + if hasattr(module, "init_done"): + module.init_done = True + + if args.rotation in ['fused_no_fx_optimize', 'fused_no_fx_optimize_self_attn_region']: + apply_rotation_optimization( + graph_model=model, + tokenizer=tokenizer, + train_dataset=calibration_loader, + unknown_args=unknown_args, + ) - with torch.no_grad(): - model(**calibration_loader[0]) + remove_hooks(model) + torch.cuda.empty_cache() if args.act_calibration: print("Apply act calibration...") @@ -402,11 +511,9 @@ def main(args): print("Bias correction applied.") if args.eval and not args.no_quantize: - print("Model eval...") - quant_ppl = compute_perplexity( - model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) - print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") - remove_hooks(model) + print("Evaluating quantized model...") + evaluate_model(model, validation_loader, args, tokenizer) + print("Quantized evaluation done.") if args.checkpoint_name is not None: print(f"Saving checkpoint to {args.checkpoint_name}") @@ -605,7 +712,12 @@ def parse_args(args): '--rotation', type=str, default=None, - choices=['fx', 'layerwise', 'fused_no_fx'], + choices=[ + 'fx', + 'layerwise', + 'fused_no_fx', + 'fused_no_fx_optimize', + 'fused_no_fx_optimize_self_attn_region'], help='Apply graph rotation equalization') parser.add_argument( '--rotation-mode', @@ -658,9 +770,9 @@ def parse_args(args): help= "Whether to merge the dataset sequences in case they are shorter than the requested number of samples per sequence. This is useful in case you would like to quantize or evaluate on long sequences (default: %(default)s).", ) - return parser.parse_args(args) + return parser.parse_known_args(args) if __name__ == '__main__': - args = parse_args(sys.argv[1:]) - main(args) + args, unknown_args = parse_args(sys.argv[1:]) + main(args, unknown_args) diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 035cdaadd..53b68cf1e 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -35,6 +35,8 @@ IN_SIZE_LINEAR = (1, 224, 3) IN_SIZE_CONV_SMALL = (1, 3, 32, 32) +IN_FEATURES_LINEAR = 5 + def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type): scale_factors_regions = [] @@ -352,6 +354,24 @@ def forward(self, x): return ConvTransposeModel +@pytest_cases.fixture +def linear_model(): + + class LinearModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear_0 = nn.Linear(in_features=5, out_features=5) + self.linear_1 = nn.Linear(in_features=5, out_features=5) + + def forward(self, x): + x = self.linear_0(x) + x = self.linear_1(x) + return x + + return LinearModel + + list_of_fixtures = [ 'residual_model', 'srcsinkconflict_model', @@ -528,3 +548,10 @@ def forward(self, x): rotation_fixtures = fixture_union( 'rotation_fixtures', list_of_rotation_mixtures, ids=list_of_rotation_mixtures) + +list_of_rotation_unfused_mixtures = ['linear_model'] + +rotation_unfused_fixtures = fixture_union( + 'rotation_unfused_fixtures', + list_of_rotation_unfused_mixtures, + ids=list_of_rotation_unfused_mixtures) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index afb8636e4..2acf8287b 100644 --- a/tests/brevitas/graph/test_equalization.py +++ 
b/tests/brevitas/graph/test_equalization.py @@ -2,21 +2,32 @@ # SPDX-License-Identifier: BSD-3-Clause import copy +import itertools +from typing import List, Tuple +from unittest.mock import patch +import pytest import torch +import torch.nn.utils.parametrize as parametrize from torchvision import models from brevitas.fx import symbolic_trace +from brevitas.graph.equalize import _apply_ort_device from brevitas.graph.equalize import _batch_norm from brevitas.graph.equalize import _extract_regions +from brevitas.graph.equalize import _get_input_axis +from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import _is_supported_module from brevitas.graph.equalize import _supported_layers from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import MergeLnAffine +from brevitas.graph.equalize import random_orthogonal_matrix from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module +from brevitas.nn.equalized_layer import RotationBiasParametrization +from brevitas.nn.equalized_layer import RotationWeightParametrization from tests.marker import requires_pt_ge from .equalization_fixtures import * @@ -276,3 +287,101 @@ def test_models(rotation_fixtures, partial_had): if partial_had: last_weight_new = model.linear_2.layer.weight.data assert not torch.allclose(last_weight, last_weight_new) + + +def _rotate_input_output(is_source: bool, is_sink: bool, is_orphan: bool) -> Tuple[bool, bool]: + # Verify that only one flag is enabled simultaneously + assert sum([is_source, is_sink, is_orphan]) <= 1, "Only one flag can be enabled." + + rotate_input, rotate_output = False, False + if is_source: + rotate_output = True + if is_sink or is_orphan: + rotate_input = True + + return rotate_input, rotate_output + + +def _compute_rotated_ouptut_from_matrices( + module: nn.Module, inp: torch.Tensor, rot_mat_input: torch.Tensor, + rot_mat_output: torch.Tensor): + # If the node is a sink, the input is multiplied by the inverse of the rotation matrix x <- xQ^{-1} + inp = inp @ rot_mat_input.t() + # If the node is a source, the output is multiplied by the rotation matrix o <- oQ + out = module(inp) @ rot_mat_output + # Return rotated output + return out + + +# RotationParametrizations can only have one type flag enabled simultaneously (is_source, is_sink, is_orphan). +# Moreover, orphan rotations need to be the outermost rotation, as this cancels out when rotating the input. 
+def _generate_rotation_flags(N: int) -> List[bool]: + return [ + rotation_flags for rotation_flags in itertools.product([False, True], repeat=3 * N) if ( + all([sum(rotation_flags[i * 3:(i + 1) * 3]) <= 1 + for i in range(N)]) and all([not rotation_flags[i * 3 + 2] for i in range(N - 1)])) + ] + + +@requires_pt_ge('2.4') +@pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}") +def test_composition_unfused_rotations(N): + torch.manual_seed(SEED) + + for rotation_flags in _generate_rotation_flags(N): + + in_features = IN_FEATURES_LINEAR + module = nn.Linear(in_features=in_features, out_features=in_features) + rot_module = copy.deepcopy(module) + + # Sample input to pass through the block + sample_input = torch.rand((1, in_features),) + # Composite rotation matrices + rot_mat_input = torch.eye(in_features) + rot_mat_output = torch.eye(in_features) + + for i in range(N): + module_rotation_flags = rotation_flags[i * 3:(i + 1) * 3] + is_source, is_sink, is_orphan = module_rotation_flags + rotate_input, rotate_output = _rotate_input_output(is_source, is_sink, is_orphan) + + # Generate a random matrix + rot_mat = random_orthogonal_matrix(in_features).to(dtype=torch.float32) + + # Aggregate rotation matrices + if rotate_input: + rot_mat_input = rot_mat_input @ rot_mat + if rotate_output: + rot_mat_output = rot_mat_output @ rot_mat + + # Compose rotation modules + parametrize.register_parametrization( + rot_module, + "weight", + RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=_apply_ort_device, + axis=_get_output_axis(rot_module) if is_source else _get_input_axis(rot_module), + is_source=is_source, + is_sink=is_sink, + is_orphan=is_orphan, + )) + parametrize.register_parametrization( + rot_module, + "bias", + RotationBiasParametrization( + rot_mat=rot_mat, + rot_func=_apply_ort_device, + input_axis=_get_input_axis(rot_module), + output_axis=_get_output_axis(rot_module), + is_source=is_source, + is_sink=is_sink, + is_orphan=is_orphan, + )) + + gt_output = _compute_rotated_ouptut_from_matrices( + module, sample_input, rot_mat_input, rot_mat_output) + rot_output = rot_module(sample_input) + + # Verify that the rotation operations were computed correctly + assert torch.allclose(gt_output, rot_output, atol=ATOL) diff --git a/tests/brevitas/optim/test_cailey_sgd.py b/tests/brevitas/optim/test_cailey_sgd.py new file mode 100644 index 000000000..4774ec830 --- /dev/null +++ b/tests/brevitas/optim/test_cailey_sgd.py @@ -0,0 +1,128 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. 
Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU, + NEC Laboratories America and IDIAP Research Institute nor the names + of its contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. +""" + +from copy import deepcopy +from itertools import product +import math +import sys +from typing import List, Union +import unittest + +from hypothesis import given +import numpy as np +import pytest +import pytest_cases +from pytest_cases import fixture +from scipy.stats import ortho_group +import torch +from torch.nn import Parameter +import torch.nn as nn +from torch.optim.lr_scheduler import LinearLR + +from brevitas.optim.sgdg import SGDG +from tests.conftest import SEED + +torch.manual_seed(SEED) + +from torch.testing._internal.common_optimizers import OptimizerInput + +OPTIMIZER_KWARGS = [{ + "stiefel": True}, { + "stiefel": True, "lr": 1e-2}, { + "stiefel": True, "lr": torch.tensor(0.001)}] +LR_SCHEDULER_ARGS = [ + None, + (LinearLR, { + "start_factor": 1.0, "end_factor": 0.0, "total_iters": 20}),] +DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] +DTYPES = [torch.float32] + +device_dtype_parametrize = pytest_cases.parametrize("device, dtype", list(product(DEVICES, DTYPES))) + + +class TestCaileySGD: + + @device_dtype_parametrize + @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) + @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) + def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): + optim_cls = SGDG + # Generate a random orthogonal matrix of size NxN. 
Columns represent orthonormal vector in R^{N} + N = 5 + P = 3 + weight_orthogonal = ortho_group(dim=N, seed=SEED).rvs() + weight_orthonormal = weight_orthogonal / np.linalg.norm(weight_orthogonal, ord=2, axis=0) + # Verify that the matrix is orthonormal + assert np.allclose(np.matmul(weight_orthonormal.T, weight_orthonormal), np.eye(N)) + # Initialize weights, the Cailey SGD optimizer expects a matrix of size PxN, given the + # condition unity.size()[0] <= unity.size()[1] + weight = Parameter( + torch.from_numpy(weight_orthonormal[:, :P].T).to(device=device, dtype=dtype)) + + optimizer = optim_cls([weight], **deepcopy(optimizer_kwargs)) + scheduler = None if lr_scheduler_args is None else lr_scheduler_args[0]( + optimizer, **lr_scheduler_args[1]) + + def closure(): + optimizer.zero_grad() + loss = (weight - torch.eye(N, P, device=device, dtype=dtype).t()).pow(2).sum() + loss.backward() + return loss + + initial_value = closure().item() + for _ in range(20): + closure() + optimizer.step() + if scheduler is not None: + scheduler.step() + + # Verify that iterates stay within the Stiefel manifold + assert torch.allclose( + weight.detach().cpu() @ weight.detach().cpu().t(), + torch.eye(P, P, device=device, dtype=dtype).detach().cpu(), + atol=1e-5, + rtol=1e-6) + + if optimizer_kwargs.get("maximize", False): + assert closure().item() > initial_value + else: + assert closure().item() < initial_value diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 576af04b1..ffa5f6454 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -2,11 +2,15 @@ # SPDX-License-Identifier: BSD-3-Clause from argparse import Namespace +import copy from dataclasses import dataclass +from functools import partial +from itertools import product import logging import os import platform import shutil +from unittest.mock import patch import numpy as np import onnx @@ -14,15 +18,28 @@ import pytest import pytest_cases import torch +from torch import nn +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer from brevitas import config from brevitas import torch_version +from brevitas.graph.equalize import _apply_had_device +from brevitas.nn.equalized_layer import RotatedModule # LLM example depends on optimum-amd, which requires PyTorch>=2.2 +from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model +from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch +from brevitas_examples.llm.llm_quant.rotation_utils import extract_trainable_rotation_matrices +from brevitas_examples.llm.llm_quant.rotation_utils import fuse_rotations +from brevitas_examples.llm.main import fused_rotation_no_fx from brevitas_examples.llm.main import main from brevitas_examples.llm.main import parse_args +from tests.conftest import SEED from tests.marker import jit_disabled_for_export from tests.marker import requires_pt_ge +ATOL = 1e-3 + def ptid2pathname(string): return string.replace("/", "-").replace(":", "-") @@ -114,7 +131,7 @@ def small_models_with_ppl(request): @pytest_cases.fixture() def default_run_args(request): - args = UpdatableNamespace(**vars(parse_args([]))) + args = UpdatableNamespace(**vars(parse_args([])[0])) args.nsamples = 2 args.seqlen = 2 args.model = "hf-internal-testing/tiny-random-MistralForCausalLM" @@ -520,3 +537,191 @@ def test_small_models_torch_export(caplog, torch_export_args): filepath = args.export_prefix + ".pt" torchscript_model = torch.jit.load(filepath) 
    os.remove(filepath)
+
+
+# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26
+# This function needs to be patched so that a generator can be passed in, ensuring that the same
+# orthogonal matrices are generated in the fused and unfused runs.
+def _random_orthogonal_matrix(size, generator):
+    """
+    Generate a random orthogonal matrix of the specified size.
+    First, we generate a random matrix with entries drawn from a standard normal distribution.
+    Then, we use QR decomposition to obtain an orthogonal matrix.
+    Finally, we multiply by the signs of the diagonal of R to fix the signs of the columns.
+
+    Args:
+        size (int): The size of the matrix (size x size).
+        generator (torch.Generator): Generator used to sample the matrix entries, for reproducibility.
+
+    Returns:
+        torch.Tensor: An orthogonal matrix of the specified size.
+    """
+    torch.cuda.empty_cache()
+    random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator)
+    q, r = torch.linalg.qr(random_matrix)
+    q *= torch.sign(torch.diag(r)).unsqueeze(0).float()
+    return q
+
+
+@pytest_cases.fixture(
+    ids=[
+        "llama",],
+    params=[
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "input_bit_width": None,
+            "fuse_sequences": False,
+            "act_calibration": False,},])
+def equalize_args(default_run_args, request):
+    args = default_run_args
+    export_dict = request.param
+    args.update(**export_dict)
+    yield args
+
+
+# Auxiliary helper to compare the weights (and rotation matrices) of fused and unfused rotated modules.
+def _compare_fused_unfused_rotation_modules(module_name, fused_rot_module, unfused_rot_module):
+    fused_weight = fused_rot_module.weight if isinstance(
+        fused_rot_module, nn.Linear) else fused_rot_module.layer.weight
+    fused_bias = fused_rot_module.bias if isinstance(
+        fused_rot_module, nn.Linear) else fused_rot_module.layer.bias
+    unfused_weight = unfused_rot_module.weight if isinstance(
+        unfused_rot_module, nn.Linear) else unfused_rot_module.layer.weight
+    unfused_bias = unfused_rot_module.bias if isinstance(
+        unfused_rot_module, nn.Linear) else unfused_rot_module.layer.bias
+    assert torch.allclose(fused_weight, unfused_weight, rtol=0.0, atol=0.0), f"The weights after rotation do not match for module {module_name}."
+    if fused_bias is not None:
+        assert torch.allclose(fused_bias, unfused_bias, rtol=0.0, atol=0.0), f"The biases after rotation do not match for module {module_name}."
+    # In case a RotatedModule is found, additional checks need to be done.
+    if isinstance(fused_rot_module, RotatedModule):
+        assert isinstance(unfused_rot_module, RotatedModule), f"Expected an instance of RotatedModule for module {module_name}."
+        assert torch.allclose(fused_rot_module.had_mat, unfused_rot_module.had_mat, rtol=0.0, atol=0.0), f"The rotation matrices of RotatedModule {module_name} do not match."
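The reproducibility property the test below relies on is that two identically seeded generators make the QR-based construction above return bit-identical matrices, so the fused and unfused runs see exactly the same rotations. A minimal standalone sketch of that property (the helper name _qr_orthogonal and the size 8 are illustrative, not part of the test suite):

import torch

def _qr_orthogonal(size: int, generator: torch.Generator) -> torch.Tensor:
    # Same construction as _random_orthogonal_matrix above: QR of a Gaussian matrix,
    # with column signs fixed by the sign of the diagonal of R.
    random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator)
    q, r = torch.linalg.qr(random_matrix)
    q *= torch.sign(torch.diag(r)).unsqueeze(0)
    return q

g1, g2 = torch.Generator(), torch.Generator()
g1.manual_seed(0)
g2.manual_seed(0)
q1, q2 = _qr_orthogonal(8, g1), _qr_orthogonal(8, g2)
# Identical seeds give identical matrices, so fused and unfused runs apply the same rotations
assert torch.equal(q1, q2)
# The result is orthogonal: Q @ Q^T is (numerically) the identity
assert torch.allclose(q1 @ q1.t(), torch.eye(8, dtype=torch.float64), atol=1e-10)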
+
+
+@pytest.mark.llm
+@requires_pt_ge('2.4')
+@pytest_cases.parametrize(
+    'partial_had, fused_rotations, add_additional_regions',
+    list(product([False, True], repeat=3)),
+    ids=[("fused-R1" if fused_rotations else "R1") + ("-R2" if add_additional_regions else "") +
+         ("-R3" if partial_had else "") for partial_had,
+         fused_rotations,
+         add_additional_regions in list(product([False, True], repeat=3))],
+)
+@pytest_cases.parametrize('rotation_mode', ['ort', 'had'])
+def test_small_models_rotations(
+        caplog, partial_had, fused_rotations, add_additional_regions, rotation_mode, equalize_args):
+    caplog.set_level(logging.INFO)
+    args = equalize_args
+    args.rotation_orphan_sink = partial_had
+    args.rotation_mode = rotation_mode
+
+    kwargs = {"torch_dtype": torch.float16}
+    model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs)
+    model = replace_rmsnorm_with_torch(model, model.config)
+    model.config.use_cache = False
+    print("Model loaded.")
+    model.eval()
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    # Load the data for calibration and evaluation.
+    calibration_loader = get_dataset_for_model(
+        args.model,
+        dataset_name=args.dataset,
+        tokenizer=tokenizer,
+        nsamples=args.nsamples,
+        seqlen=args.seqlen,
+        split="train",
+        seed=args.seed,
+        require_fx=False,
+        device=None,
+        fuse_sequences=args.fuse_sequences)
+
+    # We need to make sure that the same random matrices are generated in both runs
+    generator = torch.Generator()
+    generator.manual_seed(SEED)
+    # Clone the generator so the unfused run below can reuse the same rotation matrices
+    generator_clone = generator.clone_state()
+
+    # Run model and save outputs
+    with torch.no_grad():
+        original_logits = model(**calibration_loader[0]).logits
+
+    # Save a copy to apply graph rotation equalization on
+    model_copy = copy.deepcopy(model)
+
+    # offload_model is patched to behave as an identity, making the operations deterministic and
+    # enabling us to check that the tensors match exactly.
+    with patch('brevitas_examples.llm.main.offload_model', lambda m: m):
+        with patch('brevitas.graph.equalize.random_orthogonal_matrix',
+                   partial(_random_orthogonal_matrix, generator=generator)):
+            fused_rotation_no_fx(
+                model,
+                calibration_loader,
+                args,
+                fuse_rotations=True,
+                add_self_attention_regions=add_additional_regions)
+
+    # Run model and save outputs
+    with torch.no_grad():
+        expected_logits = model(**calibration_loader[0]).logits
+
+    # For the unfused run, reuse the rotation matrices applied in the fused run above, instead of
+    # sampling new random orthogonal matrices.
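+    # In 'had' mode this is achieved by patching _apply_ort_device with _apply_had_device, so the
+    # unfused parametrizations apply the same Hadamard transform as the fused run; in 'ort' mode
+    # the cloned generator replays the same random orthogonal matrices.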
+    with patch('brevitas_examples.llm.main.offload_model', lambda m: m):
+        if rotation_mode == 'had':
+            with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device):
+                fused_rotation_no_fx(
+                    model_copy,
+                    calibration_loader,
+                    args,
+                    fuse_rotations=False,
+                    add_self_attention_regions=add_additional_regions)
+        else:
+            with patch('brevitas.graph.equalize.random_orthogonal_matrix',
+                       partial(_random_orthogonal_matrix, generator=generator_clone)):
+                fused_rotation_no_fx(
+                    model_copy,
+                    calibration_loader,
+                    args,
+                    fuse_rotations=False,
+                    add_self_attention_regions=add_additional_regions)
+
+    # Fuse matrices with module weights
+    if fused_rotations:
+        fuse_rotations(model_copy)
+
+    # Run model and save outputs
+    with torch.no_grad():
+        logits = model_copy(**calibration_loader[0]).logits
+
+    # Verify that the output of the rotated model is close to that of the original FP model
+    assert torch.allclose(original_logits, logits, atol=ATOL), "Output of rotated network does not approximately match that of the original network."
+    # Verify that the fused and unfused outputs are identical
+    assert torch.allclose(expected_logits, logits, atol=0.0, rtol=0.0), "Outputs of fused/unfused rotated networks do not match exactly."
+
+    num_rotation_matrices = len(extract_trainable_rotation_matrices(model_copy))
+
+    num_rotated_modules = 0
+    # Count the number of RotatedModules
+    for module in model_copy.modules():
+        if isinstance(module, RotatedModule):
+            num_rotated_modules += 1
+
+    # Verify that the number of learnable rotation matrices is the expected one (R1, plus one R2 per block)
+    expected_number_rotation_matrices = 0 if fused_rotations else (
+        1 + (model.config.num_hidden_layers if add_additional_regions else 0))
+    assert num_rotation_matrices == expected_number_rotation_matrices, f"Expected {expected_number_rotation_matrices} learnable rotations, found {num_rotation_matrices}."
+
+    # Verify that the number of rotated modules is correct
+    expected_number_rotated_modules = 0 if not partial_had else (
+        model.config.num_hidden_layers if add_additional_regions else 2 *
+        model.config.num_hidden_layers)
+    assert num_rotated_modules == expected_number_rotated_modules, f"Expected {expected_number_rotated_modules} RotatedModules, found {num_rotated_modules}."
+
+    # Verify that the weights after fusing match
+    for name_fused_module, fused_module in model.named_modules():
+        # For Linear and RotatedModule instances, verify that the weights match
+        if isinstance(fused_module, (nn.Linear, RotatedModule)):
+            for name_unfused_module, unfused_module in model_copy.named_modules():
+                if name_fused_module == name_unfused_module:
+                    # Verify that everything matches between the fused and unfused rotation modules
+                    _compare_fused_unfused_rotation_modules(
+                        name_fused_module, fused_module, unfused_module)
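The module-by-module comparison above hinges on fused and unfused rotations being numerically interchangeable: applying a rotation through a weight parametrization and baking the rotated weight into the tensor yield the same layer. A torch-only sketch of that equivalence on a single Linear layer, using a plain parametrization in place of Brevitas' RotationWeightParametrization (the SinkRotation class and the sizes are illustrative assumptions, not library code):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class SinkRotation(nn.Module):
    """Right-multiplies the weight by Q, i.e. rotates the input space of the layer."""

    def __init__(self, rot_mat: torch.Tensor):
        super().__init__()
        self.rot_mat = rot_mat

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight @ self.rot_mat

torch.manual_seed(0)
q, _ = torch.linalg.qr(torch.randn(16, 16))  # random orthogonal matrix
layer = nn.Linear(16, 8, bias=False)
x = torch.randn(2, 16) @ q.t()  # input rotated by Q^{-1} = Q^T, as in the sink convention

# Unfused: the rotation lives in a parametrization and is applied on the fly
parametrize.register_parametrization(layer, "weight", SinkRotation(q))
out_unfused = layer(x)

# Fused: drop the parametrization but keep its effect, baking W @ Q into the weight tensor
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
out_fused = layer(x)

assert torch.allclose(out_unfused, out_fused, atol=1e-6)

Here remove_parametrizations with leave_parametrized=True is the stock PyTorch way of folding a parametrization into the underlying tensor, which mirrors the role fuse_rotations plays in the test above.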