From 21e5e00853e47bb66882b982ba19e0c445b0f684 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 28 Nov 2024 15:17:35 +0000 Subject: [PATCH 01/11] Cailey SGD --- src/brevitas/optim/sgdg.py | 197 ++++++++++++++++++++++++ tests/brevitas/optim/test_cailey_sgd.py | 128 +++++++++++++++ 2 files changed, 325 insertions(+) create mode 100644 src/brevitas/optim/sgdg.py create mode 100644 tests/brevitas/optim/test_cailey_sgd.py diff --git a/src/brevitas/optim/sgdg.py b/src/brevitas/optim/sgdg.py new file mode 100644 index 000000000..d1e89e39b --- /dev/null +++ b/src/brevitas/optim/sgdg.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# This code is originally from: https://github.com/JunLi-Galios/Optimization-on-Stiefel-Manifold-via-Cayley-Transform/blob/master/stiefel_optimizer.py + +import random + +import torch +from torch.optim.optimizer import Optimizer + + +def unit(v, dim: int = 1, eps: float = 1e-8): + vnorm = norm(v, dim) + return v / vnorm.add(eps), vnorm + + +def norm(v, dim: int = 1): + assert len(v.size()) == 2 + return v.norm(p=2, dim=dim, keepdim=True) + + +def matrix_norm_one(W): + out = torch.abs(W) + out = torch.sum(out, dim=0) + out = torch.max(out) + return out + + +def Cayley_loop(X, W, tan_vec, t): # + [n, p] = X.size() + Y = X + t * tan_vec + for i in range(5): + Y = X + t * torch.matmul(W, 0.5 * (X + Y)) + + return Y.t() + + +def qr_retraction(tan_vec): # tan_vec, p-by-n, p <= n + [p, n] = tan_vec.size() + tan_vec.t_() + q, r = torch.linalg.qr(tan_vec) + d = torch.diag(r, 0) + ph = d.sign() + q *= ph.expand_as(q) + q.t_() + + return q + + +episilon = 1e-8 + + +class SGDG(Optimizer): + r"""This optimizer updates variables with two different routines + based on the boolean variable 'stiefel'. + + If stiefel is True, the variables will be updated by SGD-G proposed + as decorrelated weight matrix. + + If stiefel is False, the variables will be updated by SGD. + This routine was taken from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + + -- common parameters + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + stiefel (bool, optional): whether to use SGD-G (default: False) + + -- parameters in case stiefel is False + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + + -- parameters in case stiefel is True + omega (float, optional): orthogonality regularization factor (default: 0) + grad_clip (float, optional): threshold for gradient norm clipping (default: None) + """ + + def __init__( + self, + params, + lr: float = 1e-3, + momentum: int = 0, + dampening: int = 0, + weight_decay: int = 0, + nesterov: bool = False, + stiefel: bool = False, + omega: int = 0, + grad_clip=None, + ) -> None: + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + stiefel=stiefel, + omega=0, + grad_clip=grad_clip, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(SGDG, self).__init__(params, defaults) + + def __setstate__(self, state) -> None: + super(SGDG, self).__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + momentum = group["momentum"] + stiefel = group["stiefel"] + + for p in group["params"]: + if p.grad is None: + continue + + unity, _ = unit(p.data.view(p.size()[0], -1)) + if stiefel and unity.size()[0] <= unity.size()[1]: + weight_decay = group["weight_decay"] + dampening = group["dampening"] + nesterov = group["nesterov"] + + rand_num = random.randint(1, 101) + if rand_num == 1: + unity = qr_retraction(unity) + + g = p.grad.data.view(p.size()[0], -1) + + lr = group["lr"] + + param_state = self.state[p] + if "momentum_buffer" not in param_state: + param_state["momentum_buffer"] = torch.zeros(g.t().size()) + if p.is_cuda: + param_state["momentum_buffer"] = param_state["momentum_buffer"].cuda() + + V = param_state["momentum_buffer"] + V = momentum * V - g.t() + MX = torch.mm(V, unity) + XMX = torch.mm(unity, MX) + XXMX = torch.mm(unity.t(), XMX) + W_hat = MX - 0.5 * XXMX + W = W_hat - W_hat.t() + t = 0.5 * 2 / (matrix_norm_one(W) + episilon) + alpha = min(t, lr) + + p_new = Cayley_loop(unity.t(), W, V, alpha) + V_new = torch.mm(W, unity.t()) # n-by-p + # check_identity(p_new.t()) + p.data.copy_(p_new.view(p.size())) + V.copy_(V_new) + + else: + d_p = p.grad.data + # defined. + try: + if weight_decay != 0: + # defined. + d_p.add_(weight_decay, p.data) + except: + pass + if momentum != 0: + param_state = self.state[p] + if "momentum_buffer" not in param_state: + buf = param_state["momentum_buffer"] = d_p.clone() + else: + buf = param_state["momentum_buffer"] + # always defined. + buf.mul_(momentum).add_(1 - dampening, d_p) + # defined. + if nesterov: + d_p = d_p.add(momentum, buf) + else: + d_p = buf + + p.data.add_(-group["lr"], d_p) + + return loss diff --git a/tests/brevitas/optim/test_cailey_sgd.py b/tests/brevitas/optim/test_cailey_sgd.py new file mode 100644 index 000000000..4774ec830 --- /dev/null +++ b/tests/brevitas/optim/test_cailey_sgd.py @@ -0,0 +1,128 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU, + NEC Laboratories America and IDIAP Research Institute nor the names + of its contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. +""" + +from copy import deepcopy +from itertools import product +import math +import sys +from typing import List, Union +import unittest + +from hypothesis import given +import numpy as np +import pytest +import pytest_cases +from pytest_cases import fixture +from scipy.stats import ortho_group +import torch +from torch.nn import Parameter +import torch.nn as nn +from torch.optim.lr_scheduler import LinearLR + +from brevitas.optim.sgdg import SGDG +from tests.conftest import SEED + +torch.manual_seed(SEED) + +from torch.testing._internal.common_optimizers import OptimizerInput + +OPTIMIZER_KWARGS = [{ + "stiefel": True}, { + "stiefel": True, "lr": 1e-2}, { + "stiefel": True, "lr": torch.tensor(0.001)}] +LR_SCHEDULER_ARGS = [ + None, + (LinearLR, { + "start_factor": 1.0, "end_factor": 0.0, "total_iters": 20}),] +DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] +DTYPES = [torch.float32] + +device_dtype_parametrize = pytest_cases.parametrize("device, dtype", list(product(DEVICES, DTYPES))) + + +class TestCaileySGD: + + @device_dtype_parametrize + @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) + @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) + def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): + optim_cls = SGDG + # Generate a random orthogonal matrix of size NxN. Columns represent orthonormal vector in R^{N} + N = 5 + P = 3 + weight_orthogonal = ortho_group(dim=N, seed=SEED).rvs() + weight_orthonormal = weight_orthogonal / np.linalg.norm(weight_orthogonal, ord=2, axis=0) + # Verify that the matrix is orthonormal + assert np.allclose(np.matmul(weight_orthonormal.T, weight_orthonormal), np.eye(N)) + # Initialize weights, the Cailey SGD optimizer expects a matrix of size PxN, given the + # condition unity.size()[0] <= unity.size()[1] + weight = Parameter( + torch.from_numpy(weight_orthonormal[:, :P].T).to(device=device, dtype=dtype)) + + optimizer = optim_cls([weight], **deepcopy(optimizer_kwargs)) + scheduler = None if lr_scheduler_args is None else lr_scheduler_args[0]( + optimizer, **lr_scheduler_args[1]) + + def closure(): + optimizer.zero_grad() + loss = (weight - torch.eye(N, P, device=device, dtype=dtype).t()).pow(2).sum() + loss.backward() + return loss + + initial_value = closure().item() + for _ in range(20): + closure() + optimizer.step() + if scheduler is not None: + scheduler.step() + + # Verify that iterates stay within the Stiefel manifold + assert torch.allclose( + weight.detach().cpu() @ weight.detach().cpu().t(), + torch.eye(P, P, device=device, dtype=dtype).detach().cpu(), + atol=1e-5, + rtol=1e-6) + + if optimizer_kwargs.get("maximize", False): + assert closure().item() > initial_value + else: + assert closure().item() < initial_value From ee00639ef3395e6cb3078d7224980ea000031a51 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 6 Dec 2024 14:01:11 +0000 Subject: [PATCH 02/11] Draft implementation unfused rotation --- src/brevitas/graph/equalize.py | 95 ++++++++- src/brevitas/nn/equalized_layer.py | 95 +++++++++ tests/brevitas/graph/equalization_fixtures.py | 27 +++ tests/brevitas/graph/test_equalization.py | 196 ++++++++++++++++++ 4 files changed, 409 insertions(+), 4 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 4e5c1a162..aad00ca99 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -3,6 +3,7 @@ from abc import ABC from abc import abstractmethod +from collections import defaultdict from dataclasses import dataclass from dataclasses import field from functools import partial @@ -34,6 +35,7 @@ from brevitas.nn.equalized_layer import functional_rotate_input from brevitas.nn.equalized_layer import INPUT_NAMES from brevitas.nn.equalized_layer import RotatedModule +from brevitas.nn.equalized_layer import UnfusedRotatedModule from brevitas.nn.quant_scale_bias import ScaleBias from brevitas.utils.torch_utils import KwargsForwardHook @@ -339,6 +341,8 @@ def _get_input_axis(module: nn.Module) -> Optional[int]: return 0 else: return None + elif isinstance(module, (UnfusedRotatedModule)): + return _get_input_axis(module.module) else: return None @@ -367,6 +371,8 @@ def _get_output_axis(module: nn.Module) -> Optional[int]: return 0 else: return None + elif isinstance(module, (UnfusedRotatedModule)): + return _get_output_axis(module.module) else: return None @@ -1307,7 +1313,7 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= if not insert_rotation_module and not region.is_valid: continue hidden_dim = region.max_shape_sinks - if not insert_rotation_module and full_rotation_method == 'ort': + if full_rotation_method == 'ort': rot_mat = random_orthogonal_matrix(hidden_dim) K = None rot_func = _apply_ort_device @@ -1374,6 +1380,82 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= return rewriters +@dataclass +class UnfusedRotation: + rot_mat: torch.Tensor + is_sink: bool + is_source: bool + is_orphan: bool + + +def _apply_unfused_rotate(model: nn.Module, regions: List[Region], full_rotation_method='ort'): + rewriters = [] + fused_rotated_modules = defaultdict(list) + rot_func = _apply_ort_device + + for region in regions: + insert_rotation_module = len(region.srcs) == 0 + + if not insert_rotation_module and not region.is_valid: + continue + hidden_dim = region.max_shape_sinks + + rot_mat = random_orthogonal_matrix(hidden_dim) + + for name, indexes in region.srcs.items(): + module = region.get_module_from_name(name) + + fused_rotated_modules[module].append( + UnfusedRotation( + rot_mat=rot_mat, + is_sink=False, + is_source=True, + is_orphan=False, + )) + + for name, indexes in region.sinks.items(): + module = region.get_module_from_name(name) + + if insert_rotation_module and len(region.srcs) == 0: + fused_rotated_modules[module].append( + UnfusedRotation( + rot_mat=rot_mat, + is_sink=False, + is_source=False, + is_orphan=True, + )) + else: + fused_rotated_modules[module].append( + UnfusedRotation( + rot_mat=rot_mat, + is_sink=True, + is_source=False, + is_orphan=False, + )) + + for module, rotation_modules in fused_rotated_modules.items(): + rotation_module = module + for rotation_module_dataclass in rotation_modules: + rotation_module = UnfusedRotatedModule( + module=rotation_module, + rot_func=rot_func, + rot_mat=rotation_module_dataclass.rot_mat, + _get_input_axis=_get_input_axis, + _get_output_axis=_get_output_axis, + is_source=rotation_module_dataclass.is_source, + is_sink=rotation_module_dataclass.is_sink, + is_orphan=rotation_module_dataclass.is_orphan, + ) + rewriter = ModuleInstanceToModuleInstance( + module, + rotation_module, + ) + rewriters.append(rewriter) + for r in rewriters: + model = r.apply(model) + return rewriters + + def _replace_bias(next_module, new_bias): new_bias = new_bias.view(-1) if next_module.bias is not None: @@ -1463,8 +1545,10 @@ def rotate_matmuls(self, graph_module): graph_module.recompile() graph_module.graph.lint() - def apply(self, - graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + def apply( + self, + graph_model: GraphModule, + fuse_rotations: bool = True) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( graph_model, @@ -1488,7 +1572,10 @@ def apply(self, if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) + if fuse_rotations: + rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) + else: + rewriters = _apply_unfused_rotate(graph_model, regions, self.full_rotation_method) if self.return_rewriters: return graph_model, rewriters else: diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 8413a8208..09c6dd70e 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -1,4 +1,5 @@ from inspect import signature +from typing import Callable, Optional import torch @@ -78,6 +79,100 @@ def forward(self, inp, **kwargs): return o +class UnfusedRotatedModule(torch.nn.Module): + + def __init__( + self, + module: torch.nn.Module, + rot_func: Callable, + rot_mat: torch.Tensor, + _get_input_axis: Callable, + _get_output_axis: Callable, + is_source: bool = False, + is_sink: bool = False, + is_orphan: bool = False, + ) -> None: + super().__init__() + self.module = module + self.rot_func = rot_func + self.rot_mat = torch.nn.Parameter(rot_mat).cpu() + + # TODO: This were included to prevent circular imports. + self._get_input_axis = _get_input_axis + self._get_output_axis = _get_output_axis + + self.is_source = is_source + self.is_sink = is_sink + self.is_orphan = is_orphan + + # These properties enable propagating the fusing to the module weights + @property + def weight(self) -> Optional[torch.Tensor]: + weight = getattr(self.module, 'weight', None) + # Add rotation and let these being propagated till the parent + # unfused rotated module + if self.is_sink or self.is_orphan: + axis = self._get_input_axis(self.module) + if axis == 1: + weight = self.rot_func(weight, self.rot_mat) + elif axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat).t() + else: + raise RuntimeError("Not supported yet") + + if self.is_source: + axis = self._get_output_axis(self.module) + if axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat).t() + elif axis == 1: + weight = self.rot_func(weight, self.rot_mat) + else: + raise RuntimeError("Not supported yet") + + return weight + + @property + def bias(self) -> Optional[torch.Tensor]: + bias = getattr(self.module, 'bias', None) + # Propagate bias adding the rotations incrementally + if self.is_source: + if bias is not None: + bias = self.rot_func(bias, self.rot_mat) + + return bias + + def forward(self, inp, **kwargs): + # Rotated matrices + weight = self.weight.data + bias = self.bias.data if self.bias is not None else None + + # Propagate calls till getting to the original module being rotated + child_module = self.module + # Iterate until the original module is reached, keeping the rotations that need to be performed on the input + while isinstance(child_module, UnfusedRotatedModule): + child_module = child_module.module + # child_module contains the original module in the network. Before applying its forward method, we need to + # rotate the inpute appropiately + if self.is_orphan: + # Rotate the input for an orphan sink + inp = self.rot_func(inp, self.rot_mat) + # Modify the weights, and run the original model forward. After that, restore the previous values. + if weight is not None: + orig_weight = child_module.weight.data + child_module.weight.data = weight + if bias is not None: + orig_bias = child_module.bias.data + child_module.bias.data = bias + # Call forward of the original module + o = child_module(inp) + # Restore un-rotated weights + child_module.weight.data = orig_weight + if bias is not None: + child_module.bias.data = orig_bias + # Return rotated output + return o + + def functional_rotate_input(inp, transpose=False): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None if transpose: diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 035cdaadd..53b68cf1e 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -35,6 +35,8 @@ IN_SIZE_LINEAR = (1, 224, 3) IN_SIZE_CONV_SMALL = (1, 3, 32, 32) +IN_FEATURES_LINEAR = 5 + def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type): scale_factors_regions = [] @@ -352,6 +354,24 @@ def forward(self, x): return ConvTransposeModel +@pytest_cases.fixture +def linear_model(): + + class LinearModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear_0 = nn.Linear(in_features=5, out_features=5) + self.linear_1 = nn.Linear(in_features=5, out_features=5) + + def forward(self, x): + x = self.linear_0(x) + x = self.linear_1(x) + return x + + return LinearModel + + list_of_fixtures = [ 'residual_model', 'srcsinkconflict_model', @@ -528,3 +548,10 @@ def forward(self, x): rotation_fixtures = fixture_union( 'rotation_fixtures', list_of_rotation_mixtures, ids=list_of_rotation_mixtures) + +list_of_rotation_unfused_mixtures = ['linear_model'] + +rotation_unfused_fixtures = fixture_union( + 'rotation_unfused_fixtures', + list_of_rotation_unfused_mixtures, + ids=list_of_rotation_unfused_mixtures) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index afb8636e4..2c4b4ecee 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -2,21 +2,33 @@ # SPDX-License-Identifier: BSD-3-Clause import copy +from functools import partial +import itertools +from typing import List, Tuple +from unittest.mock import patch +import pytest import torch from torchvision import models from brevitas.fx import symbolic_trace +# TODO: Refactor to prevent circular import +from brevitas.graph.equalize import _apply_ort_device from brevitas.graph.equalize import _batch_norm from brevitas.graph.equalize import _extract_regions +from brevitas.graph.equalize import _get_input_axis +from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import _is_supported_module from brevitas.graph.equalize import _supported_layers from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import MergeLnAffine +from brevitas.graph.equalize import random_orthogonal_matrix from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module +from brevitas.nn.equalized_layer import RotatedModule +from brevitas.nn.equalized_layer import UnfusedRotatedModule from tests.marker import requires_pt_ge from .equalization_fixtures import * @@ -276,3 +288,187 @@ def test_models(rotation_fixtures, partial_had): if partial_had: last_weight_new = model.linear_2.layer.weight.data assert not torch.allclose(last_weight, last_weight_new) + + +def _rotate_input_output(is_source: bool, is_sink: bool, is_orphan: bool) -> Tuple[bool, bool]: + # Verify that only one flag is enabled at the same time + assert sum([is_source, is_sink, is_orphan]) <= 1, "Only one flag can be enabled." + + rotate_input, rotate_output = False, False + if is_source: + rotate_output = True + if is_sink: + rotate_input = True + + return rotate_input, rotate_output + + +def _compute_rotated_ouptut_from_matrices( + module: nn.Module, inp: torch.Tensor, rot_mat_input: torch.Tensor, + rot_mat_output: torch.Tensor): + # If the node is a sink, the input is multiplied by the inverse of the rotation matrix x <- xQ^{-1} + inp = inp @ rot_mat_input.t() + # If the node is a source, the output is multiplied by the rotation matrix o <- oQ + out = module(inp) @ rot_mat_output + # Return rotated output + return out + + +# NOTE: The assumption is that only one flag can be true simultaneously +# NOTE: Orphans need to be taken care of. A module can only be orphan once. +def _generate_rotation_flags(N: int) -> List[bool]: + return [ + rotation_flags for rotation_flags in itertools.product([False, True], repeat=3 * N) if ( + all([sum(rotation_flags[i * 3:(i + 1) * 3]) <= 1 for i in range(N)]) and + # Only outermost rotation can be orphan + all([not rotation_flags[i * 3 + 2] for i in range(N - 1)]))] + + +@requires_pt_ge('2.4') +@pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}") +def test_composition_unfused_rotation_layer(N): + torch.manual_seed(SEED) + + for rotation_flags in _generate_rotation_flags(N): + + in_features = IN_FEATURES_LINEAR + module = nn.Linear(in_features=in_features, out_features=in_features) + + # Sample input to pass through the block + sample_input = torch.rand((1, in_features),) + + # Compose rotation modules + rotated_module = module + + # Composite rotation matrices + rot_mat_input = torch.eye(in_features) + rot_mat_output = torch.eye(in_features) + + for i in range(N): + module_rotation_flags = rotation_flags[i * 3:(i + 1) * 3] + is_source, is_sink, is_orphan = module_rotation_flags + rotate_input, rotate_output = _rotate_input_output(is_source, is_sink, is_orphan) + + # Generate a random matrix + rot_mat = random_orthogonal_matrix(in_features).to(dtype=torch.float32) + + # Aggregate rotation matrices + if rotate_input: + rot_mat_input = rot_mat_input @ rot_mat + if rotate_output: + rot_mat_output = rot_mat_output @ rot_mat + + # Compose rotation modules + rotated_module = UnfusedRotatedModule( + module=rotated_module, + rot_func=_apply_ort_device, + _get_input_axis=_get_input_axis, + _get_output_axis=_get_output_axis, + rot_mat=rot_mat, + is_source=is_source, + is_sink=is_sink, + is_orphan=is_orphan, + ) + + # Compute outputs to compare + gt_output = _compute_rotated_ouptut_from_matrices( + module, sample_input, rot_mat_input, rot_mat_output) + rot_output = rotated_module(sample_input) + + # Verify that the rotation operations were computed correctly + assert torch.allclose(gt_output, rot_output, atol=ATOL) + + +# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26 +def _random_orthogonal_matrix(size, generator): + """ + Generate a random orthogonal matrix of the specified size. + First, we generate a random matrix with entries from a standard distribution. + Then, we use QR decomposition to obtain an orthogonal matrix. + Finally, we multiply by a diagonal matrix with diag r to adjust the signs. + + Args: + size (int): The size of the matrix (size x size). + + Returns: + torch.Tensor: An orthogonal matrix of the specified size. + """ + torch.cuda.empty_cache() + random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator) + q, r = torch.linalg.qr(random_matrix) + q *= torch.sign(torch.diag(r)).unsqueeze(0).float() + return q + + +# This test verifies that the weights returned by the unfused rotate modules +# match those when fusing +@requires_pt_ge('2.4') +@pytest_cases.parametrize('partial_had', [False, True]) +def test_models_unfused_rotations(rotation_fixtures, partial_had): + + in_shape = IN_SIZE_LINEAR + + model_class = rotation_fixtures + model = model_class() + + model.eval() + inp = torch.rand(in_shape) + with torch.no_grad(): + expected_out = model(inp) + + model = symbolic_trace(model) + merge = MergeLnAffine() + model = merge.apply(model) + eq = GraphRotationEqualization(orphan_sink=partial_had, full_rotation_method='ort') + + # Save a copy to apply graph rotation equalization on + model_copy = copy.deepcopy(model) + + # We need to make sure that the same random matrices are being generated + generator = torch.Generator() + generator.manual_seed(SEED) + # Clone generator to make sure we can use the same rotation matrices + generator_clone = generator.clone_state() + + # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)) as mock_ort_generator: + # Apply rotation equalization while controlling the random matrices that are generated + model = eq.apply(model) + + # Now rotate but without fusing the rotation matrices + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)) as mock_ort_generator: + # Apply rotation equalization while controlling the random matrices that are generated + model_copy = eq.apply(model_copy, fuse_rotations=False) + + with torch.no_grad(): + out = model_copy(inp) + + # Verify that the output of the model does not change after incorporating the rotations + assert torch.allclose(expected_out, out) + + # Verify that weight matrices + for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): + if model_node.op == 'call_module': + module = get_module(model, model_node.target) + module_copy = get_module(model_copy, model_copy_node.target) + if isinstance(module, (nn.Linear, RotatedModule)): + weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight + bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias + weight_copy = module_copy.weight + bias_copy = module_copy.bias + assert torch.allclose(weight, weight_copy, atol=ATOL) + if bias is not None: + assert torch.allclose(bias, bias_copy, atol=ATOL) + # For a RotatedModule, corresponding to an orphan node, additional checks need to be done + if isinstance(module, RotatedModule): + # The outermost should be an orphan + rotated_module = module_copy + assert rotated_module.is_orphan, "Unfused rotated module needs to be an orphan." + # Check that the inner UnfusedRotatedModules are not orphans + while isinstance(rotated_module.module, UnfusedRotatedModule): + assert not rotated_module.module.is_orphan, "Inner unfused rotated modules cannot be orphans." + rotated_module = rotated_module.module + # Verify that the rotation matrices match + assert torch.allclose(module.had_mat, module_copy.rot_mat) From 26081d0180c566b5a56db0afe0f11542e0b6072b Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 6 Dec 2024 17:55:22 +0000 Subject: [PATCH 03/11] Enable rotation matrix fusing --- src/brevitas/graph/equalize.py | 34 ++++++++ src/brevitas/nn/equalized_layer.py | 30 +++++-- .../llm/llm_quant/rotation_optimization.py | 81 +++++++++++++++++++ src/brevitas_examples/llm/main.py | 17 ++-- tests/brevitas/graph/test_equalization.py | 79 +++++++++++++++++- 5 files changed, 225 insertions(+), 16 deletions(-) create mode 100644 src/brevitas_examples/llm/llm_quant/rotation_optimization.py diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index aad00ca99..b0e02cc78 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1380,6 +1380,40 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= return rewriters +def _fuse_rotations(model: nn.Module): + rewriters = [] + + def _fuse_rotations_aux(module: nn.Module): + if isinstance(module, UnfusedRotatedModule): + unrotated_module = module.unrotated_module + rot_weight = module.weight.data + + # Fuse rotations with weights + unrotated_module.weight.data = rot_weight + # Fuse rotations with bias if existent + if module.bias is not None: + rot_bias = module.bias.data + unrotated_module.bias.data = rot_bias + + # Use rotated module if orphan + if module.is_orphan: + rewriter = ModuleInstanceToModuleInstance( + module, RotatedModule(had_mat=module.rot_mat, k=None, layer=unrotated_module)) + else: + rewriter = ModuleInstanceToModuleInstance(module, unrotated_module) + # Save rewriter + rewriters.append(rewriter) + else: + for child_module in module.children(): + _fuse_rotations_aux(child_module) + + # Populate rewriters + _fuse_rotations_aux(model) + # Apply rewriter to fuse the weights + for r in rewriters: + model = r.apply(model) + + @dataclass class UnfusedRotation: rot_mat: torch.Tensor diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 09c6dd70e..899612796 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -52,6 +52,11 @@ def forward(self, *args, **kwargs): return out +def _apply_ort_device(tensor, ort, *args): + ort = ort.type_as(tensor) + return torch.matmul(tensor, ort) + + class RotatedModule(torch.nn.Module): def __init__(self, layer, had_mat=None, k=None) -> None: @@ -65,15 +70,19 @@ def __init__(self, layer, had_mat=None, k=None) -> None: def forward(self, inp, **kwargs): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None - if is_cuda and fast_hadamard_transform is not None: - if self.had_mat is None or self.k is None: - had_K, K = get_hadK(inp.shape[-1]) - else: - had_K = self.had_mat - K = self.k - inp = matmul_hadU_cuda(inp, had_K, K) + # If k is None, we assume that an orthogonal matrix is used + if self.k is None: + inp = _apply_ort_device(inp, self.had_mat) else: - inp = matmul_hadU(inp) + if is_cuda and fast_hadamard_transform is not None: + if self.had_mat is None or self.k is None: + had_K, K = get_hadK(inp.shape[-1]) + else: + had_K = self.had_mat + K = self.k + inp = matmul_hadU_cuda(inp, had_K, K) + else: + inp = matmul_hadU(inp) o = self.layer(inp) return o @@ -141,6 +150,11 @@ def bias(self) -> Optional[torch.Tensor]: return bias + @property + def unrotated_module(self) -> torch.nn.Module: + return self.module.unrotated_module if isinstance( + self.module, UnfusedRotatedModule) else self.module + def forward(self, inp, **kwargs): # Rotated matrices weight = self.weight.data diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py new file mode 100644 index 000000000..809d07bbe --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -0,0 +1,81 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +""" + +from dataclasses import dataclass +from dataclasses import field +from typing import Optional, Tuple + +import torch +from torch.utils.data import Dataset +from tqdm import tqdm +import transformers +from transformers import default_data_collator +from transformers import Trainer +from transformers.tokenization_utils import PreTrainedTokenizerBase + +from brevitas.nn.equalized_layer import UnfusedRotatedModule +from brevitas.optim.sgdg import SGDG + + +@dataclass +class ModelArguments: + input_model: Optional[str] = field( + default="hf-internal-testing/tiny-random-LlamaForCausalLM", + metadata={"help": "Input model"}) + output_rotation_path: Optional[str] = field( + default="test-output", metadata={"help": "Output rotation checkpoint path"}) + optimized_rotation_path: Optional[str] = field( + default=None, metadata={"help": "Optimized rotation checkpoint path"}) + access_token: Optional[str] = field( + default="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", + metadata={"help": "Huggingface access token to access gated repo like Llama"}, + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + output_dir: Optional[str] = field(default="/tmp/output/") + model_max_length: Optional[int] = field( + default=2048, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)"}, + ) + + +def parse_optimization_rotation_args(unknown_args=None) -> None: + parser = transformers.HfArgumentParser(( + ModelArguments, + TrainingArguments, + )) + _, training_args = parser.parse_args_into_dataclasses(args=unknown_args) + return training_args + + +def apply_rotation_optimization( + graph_model: torch.fx.GraphModule, + tokenizer: PreTrainedTokenizerBase, + train_dataset: Dataset, + unknown_args=None) -> None: + # Get training arguments + training_args = parse_optimization_rotation_args(unknown_args) + # Collect trainable matrices + trainable_parameters = [] + for module in graph_model.modules(): + if isinstance(module, UnfusedRotatedModule): + if not module.is_sink: + trainable_parameters.append(module.rot_mat) + # Initialize optimizer + optimizer = SGDG(trainable_parameters, lr=training_args.learning_rate, stiefel=True) + trainer = Trainer( + model=graph_model, + tokenizer=tokenizer, + args=training_args, + train_dataset=train_dataset, + eval_dataset=None, + data_collator=default_data_collator, + optimizers=(optimizer, None)) + trainer.train() diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 495c47919..b3d6d4e03 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -39,6 +39,7 @@ from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers +from brevitas_examples.llm.llm_quant.rotation_optimization import apply_rotation_optimization from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter from brevitas_examples.llm.llm_quant.run_utils import get_fx @@ -168,7 +169,7 @@ def validate(args): "or decreasing the sequence length (seqlen)") -def main(args): +def main(args, unknown_args=None): validate(args) set_seed(args.seed) if args.export_prefix is None: @@ -265,7 +266,7 @@ def main(args): model = offload_model(model) eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode) - model = eq.apply(model) + model = eq.apply(model, fuse_rotations=not args.rotation_optimize) remove_hooks(model) elif args.rotation == 'layerwise': eq = LayerwiseActivationRotation() @@ -335,6 +336,7 @@ def main(args): quantize_embedding=False) if not args.quantize_last_layer: if require_fx: + # TODO: Fix when using UnfusedRotation, layer_map[type(last_module)][1] crashes last_node = [node for node in model.graph.nodes if node.op == 'call_module'][-1] last_module = get_module(model, last_node.target) last_layer_kwargs = layer_map[type(last_module)][1] @@ -614,6 +616,11 @@ def parse_args(args): help= 'If GraphRotation is enabled, decide how to compute the random rotation matrix that is fully fused. Online or partial rotation will always be Hadamard' ) + # TODO: Make sure in argument validator that + parser.add_argument( + '--rotation-optimize', + action='store_true', + help='Whether to optimize the rotation matrices.') parser.add_argument( '--rotation-orphan-sink', action="store_true", @@ -658,9 +665,9 @@ def parse_args(args): help= "Whether to merge the dataset sequences in case they are shorter than the requested number of samples per sequence. This is useful in case you would like to quantize or evaluate on long sequences (default: %(default)s).", ) - return parser.parse_args(args) + return parser.parse_known_args(args) if __name__ == '__main__': - args = parse_args(sys.argv[1:]) - main(args) + args, unknown_args = parse_args(sys.argv[1:]) + main(args, unknown_args) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 2c4b4ecee..d8ea0bcb2 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -16,6 +16,7 @@ from brevitas.graph.equalize import _apply_ort_device from brevitas.graph.equalize import _batch_norm from brevitas.graph.equalize import _extract_regions +from brevitas.graph.equalize import _fuse_rotations from brevitas.graph.equalize import _get_input_axis from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import _is_supported_module @@ -413,8 +414,6 @@ def test_models_unfused_rotations(rotation_fixtures, partial_had): model.eval() inp = torch.rand(in_shape) - with torch.no_grad(): - expected_out = model(inp) model = symbolic_trace(model) merge = MergeLnAffine() @@ -446,7 +445,8 @@ def test_models_unfused_rotations(rotation_fixtures, partial_had): out = model_copy(inp) # Verify that the output of the model does not change after incorporating the rotations - assert torch.allclose(expected_out, out) + with torch.no_grad(): + expected_out = model(inp) # Verify that weight matrices for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): @@ -472,3 +472,76 @@ def test_models_unfused_rotations(rotation_fixtures, partial_had): rotated_module = rotated_module.module # Verify that the rotation matrices match assert torch.allclose(module.had_mat, module_copy.rot_mat) + + +# This test verifies that the weights returned by the unfused rotate modules +# match those when fusing +@requires_pt_ge('2.4') +@pytest_cases.parametrize('partial_had', [False, True]) +def test_models_fused_rotations(rotation_fixtures, partial_had): + + in_shape = IN_SIZE_LINEAR + + model_class = rotation_fixtures + model = model_class() + + model.eval() + inp = torch.rand(in_shape) + + model = symbolic_trace(model) + merge = MergeLnAffine() + model = merge.apply(model) + eq = GraphRotationEqualization(orphan_sink=partial_had, full_rotation_method='ort') + + # Save a copy to apply graph rotation equalization on + model_copy = copy.deepcopy(model) + + # We need to make sure that the same random matrices are being generated + generator = torch.Generator() + generator.manual_seed(SEED) + # Clone generator to make sure we can use the same rotation matrices + generator_clone = generator.clone_state() + + # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)) as mock_ort_generator: + # Apply rotation equalization while controlling the random matrices that are generated + model = eq.apply(model) + + with torch.no_grad(): + expected_out = model(inp) + + # Now rotate but without fusing the rotation matrices + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)) as mock_ort_generator: + # Apply rotation equalization while controlling the random matrices that are generated + model_copy = eq.apply(model_copy, fuse_rotations=False) + + # Fuse the rotations and make sure the behaviour is the same + _fuse_rotations(model_copy) + + with torch.no_grad(): + out = model_copy(inp) + + # Verify that the output of the model does not change after incorporating the rotations + assert torch.allclose(expected_out, out) + + # Verify that weight matrices + for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): + if model_node.op == 'call_module': + module = get_module(model, model_node.target) + module_copy = get_module(model_copy, model_copy_node.target) + if isinstance(module, (nn.Linear, RotatedModule)): + weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight + bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias + weight_copy = module_copy.weight if isinstance( + module_copy, nn.Linear) else module_copy.layer.weight + bias_copy = module_copy.bias if isinstance( + module_copy, nn.Linear) else module_copy.layer.bias + assert torch.allclose(weight, weight_copy, atol=ATOL) + if bias is not None: + assert torch.allclose(bias, bias_copy, atol=ATOL) + # For a RotatedModule, corresponding to an orphan node, additional checks need to be done + if isinstance(module, RotatedModule): + # Verify that the rotation matrices match + assert torch.allclose(module.had_mat, module_copy.had_mat) From 349a6aeee6a15e6b340eb7f27eb5408787a60938 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 6 Dec 2024 17:56:51 +0000 Subject: [PATCH 04/11] Remove default --- src/brevitas_examples/llm/llm_quant/rotation_optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 809d07bbe..8baf07596 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -29,7 +29,7 @@ class ModelArguments: optimized_rotation_path: Optional[str] = field( default=None, metadata={"help": "Optimized rotation checkpoint path"}) access_token: Optional[str] = field( - default="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", + default="", metadata={"help": "Huggingface access token to access gated repo like Llama"}, ) From d129fe9bf9475e037e8116735c98700441de00c0 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 12 Dec 2024 14:24:33 +0000 Subject: [PATCH 05/11] New tests --- src/brevitas/graph/equalize.py | 266 +++++++++++- src/brevitas/graph/hadamard.py | 100 +++++ src/brevitas/nn/equalized_layer.py | 123 +++++- .../llm/llm_quant/rotation_optimization.py | 32 +- src/brevitas_examples/llm/main.py | 68 ++- tests/brevitas/graph/test_equalization.py | 390 ++++++++++++++---- tests/brevitas_examples/test_llm.py | 2 +- 7 files changed, 880 insertions(+), 101 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index b0e02cc78..0cd8ea47c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -29,6 +29,7 @@ from brevitas.graph.hadamard import get_hadK from brevitas.graph.hadamard import matmul_hadU from brevitas.graph.hadamard import matmul_hadU_cuda +from brevitas.graph.hadamard import random_hadamard_matrix from brevitas.graph.utils import get_module from brevitas.graph.utils import get_node from brevitas.nn.equalized_layer import EqualizedModule @@ -341,8 +342,11 @@ def _get_input_axis(module: nn.Module) -> Optional[int]: return 0 else: return None + # TODO: Remove with parametrizations elif isinstance(module, (UnfusedRotatedModule)): return _get_input_axis(module.module) + elif isinstance(module, (RotatedModule,)): + return _get_input_axis(module.layer) else: return None @@ -371,8 +375,11 @@ def _get_output_axis(module: nn.Module) -> Optional[int]: return 0 else: return None + # TODO: Remove with parametrizations elif isinstance(module, (UnfusedRotatedModule)): return _get_output_axis(module.module) + elif isinstance(module, (RotatedModule,)): + return _get_output_axis(module.layer) else: return None @@ -1281,6 +1288,10 @@ def _apply_had_device(tensor, had_K, K): def _apply_ort_device(tensor, ort, *args): ort = ort.type_as(tensor) + if tensor.shape[-1] != ort.shape[0]: + tensor_shape = tensor.shape + return torch.matmul(tensor.view(-1, tensor_shape[-1] // ort.shape[0], ort.shape[0]), + ort).view(tensor_shape) return torch.matmul(tensor, ort) @@ -1313,6 +1324,7 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= if not insert_rotation_module and not region.is_valid: continue hidden_dim = region.max_shape_sinks + # TODO: Include again not insert_rotation_module if full_rotation_method == 'ort': rot_mat = random_orthogonal_matrix(hidden_dim) K = None @@ -1606,10 +1618,8 @@ def apply( if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - if fuse_rotations: - rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) - else: - rewriters = _apply_unfused_rotate(graph_model, regions, self.full_rotation_method) + _apply_rotate_fn = _apply_rotate if fuse_rotations else _apply_unfused_rotate + rewriters = _apply_rotate_fn(graph_model, regions, self.full_rotation_method) if self.return_rewriters: return graph_model, rewriters else: @@ -1702,9 +1712,253 @@ def __init__(self, blacklist_layer=None): self.supported_sinks = (nn.Linear) self.blacklist_layers = blacklist_layer - def apply(self, model: nn.Module) -> nn.Module: + def apply(self, model: nn.Module, fuse_rotations: bool = True) -> nn.Module: regions: List[Region] = [] self.find_module(model, regions) if len(regions) > 0: - _apply_rotate(model, regions) + _apply_rotate_fn = _apply_rotate if fuse_rotations else _apply_unfused_rotate + _apply_rotate_fn(model, regions) return model + + +def find_missing_rotation_regions(graph_model: GraphModule, + head_dim: int, + state_impl_kwargs=None) -> List[Region]: + import re + + regions = [] + # Add R2 regions, this should be innermost + for src_name, src_module in graph_model.named_modules(): + if "attn_v_proj" in src_name: + if state_impl_kwargs is not None: + state = WalkRegionState(**state_impl_kwargs) + else: + state = WalkRegionState() + + block_number_matches_src = re.findall(r'\d+', src_name) + assert len(block_number_matches_src) == 2, "Could not identify block" + block_number_src = int(block_number_matches_src[1]) + + eq_indexes = EqualizationIndexes(0, head_dim, 0) + state.add_srcs(src_name, src_module, eq_indexes) + + # Now the corresponding sink + for sink_name, sink_module in graph_model.named_modules(): + if "attn_o_proj" in sink_name: + block_number_matches_sink = re.findall(r'\d+', sink_name) + assert len(block_number_matches_sink) == 2, "Could not identify block" + block_number_sink = int(block_number_matches_sink[1]) + # If the blocks match, we identified the region + if block_number_src == block_number_sink: + eq_indexes = EqualizationIndexes(0, head_dim, state.offset) + state.add_sinks(sink_name, sink_module, eq_indexes) + # Instantiate region and add to list + region = Region( + srcs=dict(sorted(state.srcs.items())), + sinks=dict(sorted(state.sinks.items())), + name_to_module=state.name_to_module, + ) + if region not in regions: + regions.append(region) + + return regions + + +def _apply_rotate_fused_rotations( + model: nn.Module, + regions: List[Region], + full_rotation_method: str = 'had', + fuse_rotations: bool = True): + rewriters = [] + # Dictionary to append the unfused rotated modules for optimization + unfused_rotated_modules = defaultdict(list) + # Dictionary to keep track of the modules that are assigned to a RotatedModule + fused_rotated_modules = {} + # List to keep track of the rotation matrices added to the + rotation_matrices = [] + + for region in regions: + insert_rotation_module = len(region.srcs) == 0 + + if not insert_rotation_module and not region.is_valid: + continue + hidden_dim = region.max_shape_sinks + if not insert_rotation_module and full_rotation_method == 'ort': + rot_mat = random_orthogonal_matrix( + hidden_dim) if fuse_rotations else torch.nn.Parameter( + random_orthogonal_matrix(hidden_dim)) + K = None + rot_func = _apply_ort_device + # Store rotation matrix for optimization + rotation_matrices.append(rot_mat) + elif not insert_rotation_module and not fuse_rotations: + # TODO: Make it more general + rot_mat = torch.nn.Parameter(random_hadamard_matrix(hidden_dim, torch.device('cpu'))) + K = None + rot_func = _apply_ort_device + # Store rotation matrix for optimization + rotation_matrices.append(rot_mat) + else: + try: + # Build hadamard rotation matrix + rot_mat, K = get_hadK(hidden_dim) + rot_func = _apply_had_device + except AssertionError as e: + print(f"Incomptible shapes {hidden_dim}") + if not insert_rotation_module: + print("Falling back to orthogonal matrices") + rot_mat = random_orthogonal_matrix(hidden_dim) + K = None + rot_func = _apply_ort_device + print("Skipping layers") + continue + + for name, indexes in region.srcs.items(): + module = region.get_module_from_name(name) + + if not insert_rotation_module and fuse_rotations: + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + + axis = _get_output_axis(module) + weight = module.weight.data + + if axis == 0: + weight = rot_func(weight.t(), rot_mat, K).t() + elif axis == 1: + weight = rot_func(weight, rot_mat, K) + else: + raise RuntimeError("Not supported yet") + module.weight.data = weight + + if getattr(module, 'bias', None) is not None: + bias = module.bias.data + bias = rot_func(bias, rot_mat, K) + module.bias.data = bias + + if hasattr(module, 'offload_params'): + module.offload_params(module) + else: + unfused_rotated_modules[module].append( + UnfusedRotation( + rot_mat=rot_mat, + is_sink=False, + is_source=True, + is_orphan=False, + )) + + for name, indexes in region.sinks.items(): + module = region.get_module_from_name(name) + + if not insert_rotation_module and not fuse_rotations: + unfused_rotated_modules[module].append( + UnfusedRotation( + rot_mat=rot_mat, + is_sink=True, + is_source=False, + is_orphan=False, + )) + else: + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + + axis = _get_input_axis(module) + weight = module.weight.data + + if axis == 1: + _update_weights(module, rot_func(weight, rot_mat, K), 'weight') + elif axis == 0: + _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') + else: + raise RuntimeError("Not supported yet") + + if hasattr(module, 'offload_params'): + module.offload_params(module) + + if insert_rotation_module: + if module not in fused_rotated_modules: + fused_rotated_modules[module] = RotatedModule( + had_mat=rot_mat, k=K, layer=module) + else: + raise RuntimeError( + "Only one RotatedModule at most can be assigned to a module.") + # For this to work, we need to have the following hierarchy UnfusedRotatedModule -> (RotatedModule) -> Linear + for module, rotation_modules in unfused_rotated_modules.items(): + # Verify that at most one RotatedModule is available + rotation_module = module if module not in fused_rotated_modules else fused_rotated_modules[ + module] + for rotation_module_dataclass in rotation_modules: + rotation_module = UnfusedRotatedModule( + module=rotation_module, + rot_func=rot_func, + rot_mat=rotation_module_dataclass.rot_mat, + _get_input_axis=_get_input_axis, + _get_output_axis=_get_output_axis, + is_source=rotation_module_dataclass.is_source, + is_sink=rotation_module_dataclass.is_sink, + is_orphan=rotation_module_dataclass.is_orphan, + ) + # Instantiate rewriters + rewriter = ModuleInstanceToModuleInstance(module, rotation_module) + rewriters.append(rewriter) + # Add missing RotatedModules, in case there are any + for module, rotation_module in fused_rotated_modules.items(): + if module not in unfused_rotated_modules: + rewriter = ModuleInstanceToModuleInstance(module, rotation_module) + rewriters.append(rewriter) + for r in rewriters: + model = r.apply(model) + return rewriters, rotation_matrices + + +class GraphRotationEqualizationOptimization(GraphRotationEqualization): + + def __init__( + self, + blacklist_layers: Optional[List[str]] = None, + orphan_sink: bool = False, + rotate_matmul: bool = False, + full_rotation_method: str = 'had', + ) -> None: + super(GraphRotationEqualizationOptimization, self).__init__( + blacklist_layers=blacklist_layers, + orphan_sink=orphan_sink, + rotate_matmul=rotate_matmul, + full_rotation_method=full_rotation_method, + return_rewriters=True, + ) + + def apply( + self, + graph_model: GraphModule, + fuse_rotations: bool = True, + additional_regions: Optional[List] = None + ) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + rewriters = [] + regions = _extract_regions( + graph_model, + state_impl_kwargs={ + 'supported_srcs': self.supported_srcs, + 'supported_sinks': self.supported_sinks, + 'scale_invariant_layers': self.scale_invariant_layers, + 'scale_invariant_function': self.scale_invariant_function}) + if additional_regions is not None: + regions.extend(additional_regions) + eq_layers = set() + orphan_regions = [] + self.find_module(graph_model, orphan_regions) + for r in regions: + id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names] + eq_layers.update(id_list) + if self.orphan_sink: + for o_r in orphan_regions: + # Layerwise have only a single sink named 'sinks0' + id_sink = id(o_r.get_module_from_name('sinks0')) + if id_sink not in eq_layers: + # TODO: Change data structure to insert in the beginning with O(1) + regions = [o_r] + regions + if self.rotate_matmul: + self.rotate_matmuls(graph_model) + if len(regions) > 0: + rewriters, rotation_matrices = _apply_rotate_fused_rotations(graph_model, regions, self.full_rotation_method, fuse_rotations) + return graph_model, rewriters, rotation_matrices diff --git a/src/brevitas/graph/hadamard.py b/src/brevitas/graph/hadamard.py index 235e22567..29a09ebed 100644 --- a/src/brevitas/graph/hadamard.py +++ b/src/brevitas/graph/hadamard.py @@ -66,6 +66,23 @@ def get_hadK(n, transpose=False): assert (is_pow2(n // 12)) K = 12 hadK = tensors['get_had12'].T if transpose else tensors['get_had12'] + # TODO: Add this matrix along with the others + elif n % 64 == 0: + assert (is_pow2(n // 64)) + K = 64 + hadK = torch.tensor([[1 if char == '+' else -1 + for char in line] + for line in hadamard_string_64.strip().split('\n')], + dtype=torch.float32, + requires_grad=False) + elif n % 16 == 0: + assert (is_pow2(n // 16)) + K = 16 + hadK = torch.tensor([[1 if char == '+' else -1 + for char in line] + for line in hadamard_string_16.strip().split('\n')], + dtype=torch.float32, + requires_grad=False) else: assert (is_pow2(n)) K = 1 @@ -166,3 +183,86 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False): def is_pow2(n): return (n & (n - 1) == 0) and (n > 0) + + +hadamard_string_16 = """++++++++++++++++ ++-+-+-+-+-+-+-+- +++--++--++--++-- ++--++--++--++--+ +++++----++++---- ++-+--+-++-+--+-+ +++----++++----++ ++--+-++-+--+-++- +++++++++-------- ++-+-+-+--+-+-+-+ +++--++----++--++ ++--++--+-++--++- +++++--------++++ ++-+--+-+-+-++-+- +++----++--++++-- ++--+-++--++-+--+""" + +hadamard_string_64 = """++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- +++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++-- ++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--+ +++++----++++----++++----++++----++++----++++----++++----++++---- ++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-+ +++----++++----++++----++++----++++----++++----++++----++++----++ ++--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++- +++++++++--------++++++++--------++++++++--------++++++++-------- ++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-+ +++--++----++--++++--++----++--++++--++----++--++++--++----++--++ ++--++--+-++--++-+--++--+-++--++-+--++--+-++--++-+--++--+-++--++- +++++--------++++++++--------++++++++--------++++++++--------++++ ++-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+- +++----++--++++--++----++--++++--++----++--++++--++----++--++++-- ++--+-++--++-+--++--+-++--++-+--++--+-++--++-+--++--+-++--++-+--+ +++++++++++++++++----------------++++++++++++++++---------------- ++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-+ +++--++--++--++----++--++--++--++++--++--++--++----++--++--++--++ ++--++--++--++--+-++--++--++--++-+--++--++--++--+-++--++--++--++- +++++----++++--------++++----++++++++----++++--------++++----++++ ++-+--+-++-+--+-+-+-++-+--+-++-+-+-+--+-++-+--+-+-+-++-+--+-++-+- +++----++++----++--++++----++++--++----++++----++--++++----++++-- ++--+-++-+--+-++--++-+--+-++-+--++--+-++-+--+-++--++-+--+-++-+--+ +++++++++----------------++++++++++++++++----------------++++++++ ++-+-+-+--+-+-+-+-+-+-+-++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-++-+-+-+- +++--++----++--++--++--++++--++--++--++----++--++--++--++++--++-- ++--++--+-++--++--++--++-+--++--++--++--+-++--++--++--++-+--++--+ +++++--------++++----++++++++----++++--------++++----++++++++---- ++-+--+-+-+-++-+--+-++-+-+-+--+-++-+--+-+-+-++-+--+-++-+-+-+--+-+ +++----++--++++----++++--++----++++----++--++++----++++--++----++ ++--+-++--++-+--+-++-+--++--+-++-+--+-++--++-+--+-++-+--++--+-++- +++++++++++++++++++++++++++++++++-------------------------------- ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+--+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +++--++--++--++--++--++--++--++----++--++--++--++--++--++--++--++ ++--++--++--++--++--++--++--++--+-++--++--++--++--++--++--++--++- +++++----++++----++++----++++--------++++----++++----++++----++++ ++-+--+-++-+--+-++-+--+-++-+--+-+-+-++-+--+-++-+--+-++-+--+-++-+- +++----++++----++++----++++----++--++++----++++----++++----++++-- ++--+-++-+--+-++-+--+-++-+--+-++--++-+--+-++-+--+-++-+--+-++-+--+ +++++++++--------++++++++----------------++++++++--------++++++++ ++-+-+-+--+-+-+-++-+-+-+--+-+-+-+-+-+-+-++-+-+-+--+-+-+-++-+-+-+- +++--++----++--++++--++----++--++--++--++++--++----++--++++--++-- ++--++--+-++--++-+--++--+-++--++--++--++-+--++--+-++--++-+--++--+ +++++--------++++++++--------++++----++++++++--------++++++++---- ++-+--+-+-+-++-+-+-+--+-+-+-++-+--+-++-+-+-+--+-+-+-++-+-+-+--+-+ +++----++--++++--++----++--++++----++++--++----++--++++--++----++ ++--+-++--++-+--++--+-++--++-+--+-++-+--++--+-++--++-+--++--+-++- +++++++++++++++++--------------------------------++++++++++++++++ ++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-++-+-+-+-+-+-+-+- +++--++--++--++----++--++--++--++--++--++--++--++++--++--++--++-- ++--++--++--++--+-++--++--++--++--++--++--++--++-+--++--++--++--+ +++++----++++--------++++----++++----++++----++++++++----++++---- ++-+--+-++-+--+-+-+-++-+--+-++-+--+-++-+--+-++-+-+-+--+-++-+--+-+ +++----++++----++--++++----++++----++++----++++--++----++++----++ ++--+-++-+--+-++--++-+--+-++-+--+-++-+--+-++-+--++--+-++-+--+-++- +++++++++----------------++++++++--------++++++++++++++++-------- ++-+-+-+--+-+-+-+-+-+-+-++-+-+-+--+-+-+-++-+-+-+-+-+-+-+--+-+-+-+ +++--++----++--++--++--++++--++----++--++++--++--++--++----++--++ ++--++--+-++--++--++--++-+--++--+-++--++-+--++--++--++--+-++--++- +++++--------++++----++++++++--------++++++++----++++--------++++ ++-+--+-+-+-++-+--+-++-+-+-+--+-+-+-++-+-+-+--+-++-+--+-+-+-++-+- +++----++--++++----++++--++----++--++++--++----++++----++--++++-- ++--+-++--++-+--+-++-+--++--+-++--++-+--++--+-++-+--+-++--++-+--+""" diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 899612796..fbe045124 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -1,3 +1,4 @@ +import functools from inspect import signature from typing import Callable, Optional @@ -68,6 +69,14 @@ def __init__(self, layer, had_mat=None, k=None) -> None: self.layer = layer self.k = k + @property + def weight(self) -> Optional[torch.Tensor]: + return getattr(self.layer, 'weight', None) + + @property + def bias(self) -> Optional[torch.Tensor]: + return getattr(self.layer, 'bias', None) + def forward(self, inp, **kwargs): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None # If k is None, we assume that an orthogonal matrix is used @@ -88,13 +97,91 @@ def forward(self, inp, **kwargs): return o +def rot_func_wrapper(weight: torch.Tensor, rot_mat: torch.Tensor, rotation_function: Callable): + weight_shape = weight.shape + rot_mat_dim = rot_mat.shape[0] + return rotation_function(weight.view(-1, weight_shape.shape[1] // rot_mat_dim, + rot_mat_dim)).view(weight_shape) + + +class RotationWeightParametrization(torch.nn.Module): + + def __init__( + self, + rot_mat: torch.nn.Parameter, + rot_func: Callable, + input_axis: Optional[int] = None, + output_axis: Optional[int] = None, + is_source: bool = False, + is_sink: bool = False, + is_orphan: bool = False, + ) -> None: + super().__init__() + self.rot_mat = rot_mat + self.rot_func = rot_func + self.input_axis = input_axis + self.output_axis = output_axis + self.is_source = is_source + self.is_sink = is_sink + self.is_orphan = is_orphan + self.K = None + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + if self.is_sink or self.is_orphan: + if self.input_axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + elif self.input_axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + else: + raise RuntimeError("Not supported yet") + + if self.is_source: + if self.output_axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + elif self.output_axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + else: + raise RuntimeError("Not supported yet") + + return weight + + +class RotationBiasParametrization(torch.nn.Module): + + def __init__( + self, + rot_mat: torch.nn.Parameter, + rot_func: Callable, + input_axis: Optional[int] = None, + output_axis: Optional[int] = None, + is_source: bool = False, + is_sink: bool = False, + is_orphan: bool = False, + ) -> None: + super().__init__() + self.rot_mat = rot_mat + self.rot_func = rot_func + self.input_axis = input_axis + self.output_axis = output_axis + self.is_source = is_source + self.is_sink = is_sink + self.is_orphan = is_orphan + self.K = None + + def forward(self, bias: torch.Tensor) -> torch.Tensor: + if self.is_source: + bias = self.rot_func(bias, self.rot_mat, self.K) + + return bias + + class UnfusedRotatedModule(torch.nn.Module): def __init__( self, module: torch.nn.Module, rot_func: Callable, - rot_mat: torch.Tensor, + rot_mat: torch.nn.Parameter, _get_input_axis: Callable, _get_output_axis: Callable, is_source: bool = False, @@ -104,7 +191,8 @@ def __init__( super().__init__() self.module = module self.rot_func = rot_func - self.rot_mat = torch.nn.Parameter(rot_mat).cpu() + self.rot_mat = rot_mat + self.K = None # TODO: This were included to prevent circular imports. self._get_input_axis = _get_input_axis @@ -114,6 +202,25 @@ def __init__( self.is_sink = is_sink self.is_orphan = is_orphan + # TODO: Does it make sense the extra complexity just to prevent the view operation? + # Probably if no reshaping needs to be done, no change is required + def _wrap_rot(self) -> bool: + weight_shape = self.module.weight.shape + rot_dim = self.rot_mat.shape[0] + if self.is_sink or self.is_orphan: + weight_shape_dim = weight_shape[self._get_input_axis(self.module)] + elif self.is_source: + weight_shape_dim = weight_shape[self._get_output_axis(self.module)] + else: + weight_shape_dim = None + + if weight_shape_dim is not None: + if rot_dim != weight_shape_dim: + assert weight_shape_dim % rot_dim == 0, "Sizes need to be divisibile" + return True + # No need to incorporate additional view operations + return False + # These properties enable propagating the fusing to the module weights @property def weight(self) -> Optional[torch.Tensor]: @@ -123,18 +230,18 @@ def weight(self) -> Optional[torch.Tensor]: if self.is_sink or self.is_orphan: axis = self._get_input_axis(self.module) if axis == 1: - weight = self.rot_func(weight, self.rot_mat) + weight = self.rot_func(weight, self.rot_mat, self.K) elif axis == 0: - weight = self.rot_func(weight.t(), self.rot_mat).t() + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() else: raise RuntimeError("Not supported yet") if self.is_source: axis = self._get_output_axis(self.module) if axis == 0: - weight = self.rot_func(weight.t(), self.rot_mat).t() + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() elif axis == 1: - weight = self.rot_func(weight, self.rot_mat) + weight = self.rot_func(weight, self.rot_mat, self.K) else: raise RuntimeError("Not supported yet") @@ -146,7 +253,7 @@ def bias(self) -> Optional[torch.Tensor]: # Propagate bias adding the rotations incrementally if self.is_source: if bias is not None: - bias = self.rot_func(bias, self.rot_mat) + bias = self.rot_func(bias, self.rot_mat, self.K) return bias @@ -169,7 +276,7 @@ def forward(self, inp, **kwargs): # rotate the inpute appropiately if self.is_orphan: # Rotate the input for an orphan sink - inp = self.rot_func(inp, self.rot_mat) + inp = self.rot_func(inp, self.rot_mat, self.K) # Modify the weights, and run the original model forward. After that, restore the previous values. if weight is not None: orig_weight = child_module.weight.data diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 8baf07596..7b2baecae 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -38,6 +38,7 @@ class ModelArguments: class TrainingArguments(transformers.TrainingArguments): cache_dir: Optional[str] = field(default=None) output_dir: Optional[str] = field(default="/tmp/output/") + use_cpu: Optional[bool] = field(default="False") model_max_length: Optional[int] = field( default=2048, metadata={ @@ -55,6 +56,25 @@ def parse_optimization_rotation_args(unknown_args=None) -> None: return training_args +def collate_fn(kwargs_list, return_tensors="pt"): + # Keyword arguments + kwargs = {} + for curr_dict in kwargs_list: + for key, value in curr_dict.items(): + if isinstance(value, torch.Tensor): + if key not in kwargs: + kwargs[key] = [] + kwargs[key].append(value) + else: + if key not in kwargs: + kwargs[key] = value + for key, value in kwargs.items(): + if isinstance(value, list) and len(value) > 0: + kwargs[key] = torch.cat(kwargs[key], dim=0) + # FP outputs + return kwargs + + def apply_rotation_optimization( graph_model: torch.fx.GraphModule, tokenizer: PreTrainedTokenizerBase, @@ -62,12 +82,20 @@ def apply_rotation_optimization( unknown_args=None) -> None: # Get training arguments training_args = parse_optimization_rotation_args(unknown_args) + # Set to False the model parameters + for param in graph_model.parameters(): + param.requires_grad = False # Collect trainable matrices trainable_parameters = [] + ids_rot = set() for module in graph_model.modules(): if isinstance(module, UnfusedRotatedModule): - if not module.is_sink: + if id(module.rot_mat) not in ids_rot: + ids_rot.add(id(module.rot_mat)) trainable_parameters.append(module.rot_mat) + # Collect parameters for the rotation matrices + for rot_mat in trainable_parameters: + rot_mat.requires_grad = True # Initialize optimizer optimizer = SGDG(trainable_parameters, lr=training_args.learning_rate, stiefel=True) trainer = Trainer( @@ -76,6 +104,6 @@ def apply_rotation_optimization( args=training_args, train_dataset=train_dataset, eval_dataset=None, - data_collator=default_data_collator, + data_collator=collate_fn, optimizers=(optimizer, None)) trainer.train() diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b3d6d4e03..df522e901 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -16,7 +16,9 @@ from brevitas.export import export_torch_qcdq from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas.graph.equalize import find_missing_rotation_regions from brevitas.graph.equalize import GraphRotationEqualization +from brevitas.graph.equalize import GraphRotationEqualizationOptimization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.utils import get_module @@ -50,7 +52,7 @@ def set_seed(seed): torch.random.manual_seed(seed) -def fused_rotation_no_fx(model, calibration_loader, args): +def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = False): with torch.no_grad(): new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) apply_layernorm_affine_merge(new_model) @@ -64,7 +66,7 @@ def fused_rotation_no_fx(model, calibration_loader, args): orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, return_rewriters=True) - new_model, rewriters = eq.apply(new_model) + new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations) rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: @@ -72,6 +74,36 @@ def fused_rotation_no_fx(model, calibration_loader, args): remove_hooks(new_model) +def fused_optimized_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations: bool = False, + add_additional_regions: bool = False): + with torch.no_grad(): + new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) + apply_layernorm_affine_merge(new_model) + new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True) + rewriters = fix_rewriter(rewriters, model, 'weight') + + for r in rewriters: + r.apply(model) + #new_model = offload_model(new_model) + additional_regions = find_missing_rotation_regions( + new_model, model.config.hidden_size // + model.config.num_attention_heads) if add_additional_regions else None + eq = GraphRotationEqualizationOptimization( + orphan_sink=args.rotation_orphan_sink, + full_rotation_method=args.rotation_mode, + ) + new_model, rewriters, rotation_matrices = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=additional_regions) + rewriters = fix_rewriter(rewriters, model, 'weight') + + for r in rewriters: + r.apply(model) + #remove_hooks(new_model) + + def set_seed(seed): np.random.seed(seed) torch.random.manual_seed(seed) @@ -241,6 +273,15 @@ def main(args, unknown_args=None): if args.replace_rmsnorm: model = replace_rmsnorm_with_torch(model, model.config) + # TODO: Refactor + if args.rotation == 'fused_no_fx_optimize': + for i in range(len(calibration_loader)): + del calibration_loader[i]["attention_mask"] + calibration_loader[i]["labels"] = calibration_loader[i]["input_ids"] + + model.config.use_cache = False + model.config.loss_type = "ForCausalLM" + if require_fx: if model.__class__.__name__ in _SUPPORTED_MODELS and not args.replace_rmsnorm: model = get_fx(model, is_export=args.export_target is not None) @@ -266,13 +307,16 @@ def main(args, unknown_args=None): model = offload_model(model) eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode) - model = eq.apply(model, fuse_rotations=not args.rotation_optimize) + model = eq.apply(model) remove_hooks(model) elif args.rotation == 'layerwise': eq = LayerwiseActivationRotation() model = eq.apply(model) elif args.rotation == 'fused_no_fx': fused_rotation_no_fx(model, calibration_loader, args) + elif args.rotation == 'fused_no_fx_optimize': + fused_optimized_rotation_no_fx( + model, calibration_loader, args, fuse_rotations=False, add_additional_regions=True) # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing # with all the variability in HF implementations @@ -369,6 +413,17 @@ def main(args, unknown_args=None): with torch.no_grad(): model(**calibration_loader[0]) + # TODO: Refactor + remove_hooks(model) + + if args.rotation == 'fused_no_fx_optimize': + apply_rotation_optimization( + graph_model=model, + tokenizer=tokenizer, + train_dataset=calibration_loader, + unknown_args=unknown_args, + ) + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader) @@ -607,7 +662,7 @@ def parse_args(args): '--rotation', type=str, default=None, - choices=['fx', 'layerwise', 'fused_no_fx'], + choices=['fx', 'layerwise', 'fused_no_fx', 'fused_no_fx_optimize'], help='Apply graph rotation equalization') parser.add_argument( '--rotation-mode', @@ -616,11 +671,6 @@ def parse_args(args): help= 'If GraphRotation is enabled, decide how to compute the random rotation matrix that is fully fused. Online or partial rotation will always be Hadamard' ) - # TODO: Make sure in argument validator that - parser.add_argument( - '--rotation-optimize', - action='store_true', - help='Whether to optimize the rotation matrices.') parser.add_argument( '--rotation-orphan-sink', action="store_true", diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index d8ea0bcb2..0f65db168 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -25,6 +25,7 @@ from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import MergeLnAffine from brevitas.graph.equalize import random_orthogonal_matrix +from brevitas.graph.hadamard import matmul_hadU from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module @@ -401,11 +402,47 @@ def _random_orthogonal_matrix(size, generator): return q +def _random_hadamard_matrix(size, device, generator): + # See https://github.com/Cornell-RelaxML/quip-sharp , Section "Randomized Hadamard Transformation" + Q = torch.randint(low=0, high=2, size=(size,), generator=generator).to(torch.float64) + Q = Q * 2 - 1 + Q = torch.diag(Q) + return matmul_hadU(Q).to(device) + + +def _compare_module_weights_fused_unfused(gt_module, rot_module, fused_rotations=False): + gt_weight = gt_module.weight if isinstance(gt_module, nn.Linear) else gt_module.layer.weight + gt_bias = gt_module.bias if isinstance(gt_module, nn.Linear) else gt_module.layer.bias + if fused_rotations: + rot_weight = rot_module.weight if isinstance( + rot_module, nn.Linear) else rot_module.layer.weight + rot_bias = rot_module.bias if isinstance(rot_module, nn.Linear) else rot_module.layer.bias + else: + rot_weight = rot_module.weight + rot_bias = rot_module.bias + assert torch.allclose(gt_weight, rot_weight, rtol=0.0, atol=0.0) + if gt_bias is not None: + assert torch.allclose(gt_bias, rot_bias, rtol=0.0, atol=0.0) + # For a RotatedModule, corresponding to an orphan node, additional checks need to be done + if isinstance(gt_module, RotatedModule): + if not fused_rotations: + # The outermost should be an orphan + child_rot_module = rot_module + assert child_rot_module.is_orphan, "Unfused rotated module needs to be an orphan." + # Check that the inner UnfusedRotatedModules are not orphans + while isinstance(child_rot_module.module, UnfusedRotatedModule): + assert not child_rot_module.module.is_orphan, "Inner unfused rotated modules cannot be orphans." + child_rot_module = child_rot_module.module + # Verify that the rotation matrices match + assert torch.allclose(gt_module.had_mat, rot_module.rot_mat) + + # This test verifies that the weights returned by the unfused rotate modules # match those when fusing @requires_pt_ge('2.4') @pytest_cases.parametrize('partial_had', [False, True]) -def test_models_unfused_rotations(rotation_fixtures, partial_had): +@pytest_cases.parametrize('fused_rotations', [False, True]) +def test_models_rotations(rotation_fixtures, partial_had, fused_rotations): in_shape = IN_SIZE_LINEAR @@ -431,22 +468,28 @@ def test_models_unfused_rotations(rotation_fixtures, partial_had): # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)) as mock_ort_generator: + partial(_random_orthogonal_matrix, generator=generator)): # Apply rotation equalization while controlling the random matrices that are generated model = eq.apply(model) + with torch.no_grad(): + expected_out = model(inp) + # Now rotate but without fusing the rotation matrices with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)) as mock_ort_generator: + partial(_random_orthogonal_matrix, generator=generator_clone)): # Apply rotation equalization while controlling the random matrices that are generated model_copy = eq.apply(model_copy, fuse_rotations=False) + # Fuse the rotations and make sure the behaviour is the same + if fused_rotations: + _fuse_rotations(model_copy) + with torch.no_grad(): out = model_copy(inp) # Verify that the output of the model does not change after incorporating the rotations - with torch.no_grad(): - expected_out = model(inp) + assert torch.allclose(expected_out, out, rtol=0.0, atol=0.0) # Verify that weight matrices for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): @@ -454,47 +497,85 @@ def test_models_unfused_rotations(rotation_fixtures, partial_had): module = get_module(model, model_node.target) module_copy = get_module(model_copy, model_copy_node.target) if isinstance(module, (nn.Linear, RotatedModule)): - weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight - bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias - weight_copy = module_copy.weight - bias_copy = module_copy.bias - assert torch.allclose(weight, weight_copy, atol=ATOL) - if bias is not None: - assert torch.allclose(bias, bias_copy, atol=ATOL) - # For a RotatedModule, corresponding to an orphan node, additional checks need to be done - if isinstance(module, RotatedModule): - # The outermost should be an orphan - rotated_module = module_copy - assert rotated_module.is_orphan, "Unfused rotated module needs to be an orphan." - # Check that the inner UnfusedRotatedModules are not orphans - while isinstance(rotated_module.module, UnfusedRotatedModule): - assert not rotated_module.module.is_orphan, "Inner unfused rotated modules cannot be orphans." - rotated_module = rotated_module.module - # Verify that the rotation matrices match - assert torch.allclose(module.had_mat, module_copy.rot_mat) - - -# This test verifies that the weights returned by the unfused rotate modules -# match those when fusing + _compare_module_weights_fused_unfused(module, module_copy, fused_rotations) + + +def _compare_module_weights(module, module_copy): + weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight + bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias + weight_copy = module_copy.weight + bias_copy = module_copy.bias + assert torch.allclose(weight, weight_copy, rtol=0.0, atol=0.0) + if bias is not None: + assert torch.allclose(bias, bias_copy, rtol=0.0, atol=0.0) + + +import logging + +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + +from brevitas.graph.equalize import find_missing_rotation_regions +from brevitas_examples.common.accelerate_utils.accelerate import offload_model +from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks +from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model +from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge +from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm +from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch +from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter +from brevitas_examples.llm.main import fused_optimized_rotation_no_fx +from brevitas_examples.llm.main import fused_rotation_no_fx +from tests.brevitas_examples.test_llm import default_run_args + + +@pytest_cases.fixture( + ids=[ + "llama",], + params=[ + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "input_bit_width": None, + "fuse_sequences": False, + "act_calibration": False,},]) +def equalize_args(default_run_args, request): + args = default_run_args + export_dict = request.param + args.update(**export_dict) + yield args + + +@pytest.mark.llm @requires_pt_ge('2.4') @pytest_cases.parametrize('partial_had', [False, True]) -def test_models_fused_rotations(rotation_fixtures, partial_had): - - in_shape = IN_SIZE_LINEAR - - model_class = rotation_fixtures - model = model_class() - +@pytest_cases.parametrize('fused_rotations', [False, True]) +def test_small_models_equalize_legacy_rotation_orthogonal( + caplog, partial_had, fused_rotations, equalize_args): + import os + os.environ["HF_HUB_CACHE"] = "/scratch/hf_models/" + caplog.set_level(logging.INFO) + args = equalize_args + args.rotation_orphan_sink = partial_had + args.rotation_mode = 'ort' + + kwargs = {"torch_dtype": torch.float16} + model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) + model = replace_rmsnorm_with_torch(model, model.config) + model.config.use_cache = False + print("Model loaded.") model.eval() - inp = torch.rand(in_shape) - - model = symbolic_trace(model) - merge = MergeLnAffine() - model = merge.apply(model) - eq = GraphRotationEqualization(orphan_sink=partial_had, full_rotation_method='ort') - - # Save a copy to apply graph rotation equalization on - model_copy = copy.deepcopy(model) + tokenizer = AutoTokenizer.from_pretrained(args.model) + # Load the data for calibration and evaluation. + calibration_loader = get_dataset_for_model( + args.model, + dataset_name=args.dataset, + tokenizer=tokenizer, + nsamples=args.nsamples, + seqlen=args.seqlen, + split="train", + seed=args.seed, + require_fx=False, + device=None, + fuse_sequences=args.fuse_sequences) # We need to make sure that the same random matrices are being generated generator = torch.Generator() @@ -502,46 +583,205 @@ def test_models_fused_rotations(rotation_fixtures, partial_had): # Clone generator to make sure we can use the same rotation matrices generator_clone = generator.clone_state() + # Save a copy to apply graph rotation equalization on + model_copy = copy.deepcopy(model) + # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)) as mock_ort_generator: - # Apply rotation equalization while controlling the random matrices that are generated - model = eq.apply(model) + partial(_random_orthogonal_matrix, generator=generator)): + with patch('brevitas.graph.hadamard.random_hadamard_matrix', + partial(_random_hadamard_matrix, generator=generator)): + fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations=True) + # Run model and save outputs with torch.no_grad(): - expected_out = model(inp) + expected_logits = model(**calibration_loader[0]).logits - # Now rotate but without fusing the rotation matrices + # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)) as mock_ort_generator: - # Apply rotation equalization while controlling the random matrices that are generated - model_copy = eq.apply(model_copy, fuse_rotations=False) + partial(_random_orthogonal_matrix, generator=generator_clone)): + with patch('brevitas.graph.hadamard.random_hadamard_matrix', + partial(_random_hadamard_matrix, generator=generator_clone)): + fused_rotation_no_fx(model_copy, calibration_loader, args, fuse_rotations=False) - # Fuse the rotations and make sure the behaviour is the same - _fuse_rotations(model_copy) + if fused_rotations: + _fuse_rotations(model_copy) + # Run model and save outputs with torch.no_grad(): - out = model_copy(inp) + logits = model_copy(**calibration_loader[0]).logits + + # Verify that the output is the same + assert torch.allclose(expected_logits, logits) + + # Verify that the weights after fusing match + for name_fused_module, fused_module in model.named_modules(): + # For linear modules verify that the weights match + if isinstance(fused_module, (nn.Linear, RotatedModule)): + for name_unfused_Module, unfused_module in model_copy.named_modules(): + if name_fused_module == name_unfused_Module: + _compare_module_weights(fused_module, unfused_module) + # For a RotatedModule, corresponding to an orphan node, additional checks need to be done + if isinstance(fused_module, RotatedModule): + # Verify that the outer module is an orphan + if fused_rotations: + assert isinstance(unfused_module, RotatedModule) + assert torch.allclose(fused_module.had_mat, unfused_module.had_mat) + else: + assert unfused_module.is_orphan + # Verify that the rotation matrices match + assert torch.allclose(fused_module.had_mat, unfused_module.rot_mat) + + +from itertools import product + +from brevitas.graph.equalize import _apply_had_device +from brevitas.graph.hadamard import get_hadK + + +# NOTE: This test works because in R2 we patch the rotation method, so the appropiate matrix is not effectively used. This is because when the fast_hadamard_transform is not avai +@pytest.mark.llm +@requires_pt_ge('2.4') +@pytest_cases.parametrize( + 'partial_had, fused_rotations, add_additional_regions', + list(product([False, True], repeat=3)), + ids=[("fused-R1" if fused_rotations else "R1") + ("-R2" if add_additional_regions else "") + + ("-R3" if partial_had else "") for partial_had, + fused_rotations, + add_additional_regions in list(product([False, True], repeat=3))], +) +@pytest_cases.parametrize('rotation_mode', ['ort', 'had']) +def test_small_models_equalize_mixed_fused_unfused( + caplog, partial_had, fused_rotations, add_additional_regions, rotation_mode, equalize_args): + import os + os.environ["HF_HUB_CACHE"] = "/scratch/hf_models/" + caplog.set_level(logging.INFO) + args = equalize_args + args.rotation_orphan_sink = partial_had + args.rotation_mode = rotation_mode + + kwargs = {"torch_dtype": torch.float16} + model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) + model = replace_rmsnorm_with_torch(model, model.config) + model.config.use_cache = False + print("Model loaded.") + model.eval() + tokenizer = AutoTokenizer.from_pretrained(args.model) + # Load the data for calibration and evaluation. + calibration_loader = get_dataset_for_model( + args.model, + dataset_name=args.dataset, + tokenizer=tokenizer, + nsamples=args.nsamples, + seqlen=args.seqlen, + split="train", + seed=args.seed, + require_fx=False, + device=None, + fuse_sequences=args.fuse_sequences) - # Verify that the output of the model does not change after incorporating the rotations - assert torch.allclose(expected_out, out) + # We need to make sure that the same random matrices are being generated + generator = torch.Generator() + generator.manual_seed(SEED) + # Clone generator to make sure we can use the same rotation matrices + generator_clone = generator.clone_state() - # Verify that weight matrices - for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): - if model_node.op == 'call_module': - module = get_module(model, model_node.target) - module_copy = get_module(model_copy, model_copy_node.target) - if isinstance(module, (nn.Linear, RotatedModule)): - weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight - bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias - weight_copy = module_copy.weight if isinstance( - module_copy, nn.Linear) else module_copy.layer.weight - bias_copy = module_copy.bias if isinstance( - module_copy, nn.Linear) else module_copy.layer.bias - assert torch.allclose(weight, weight_copy, atol=ATOL) - if bias is not None: - assert torch.allclose(bias, bias_copy, atol=ATOL) - # For a RotatedModule, corresponding to an orphan node, additional checks need to be done - if isinstance(module, RotatedModule): - # Verify that the rotation matrices match - assert torch.allclose(module.had_mat, module_copy.had_mat) + # Run model and save outputs + with torch.no_grad(): + original_logits = model(**calibration_loader[0]).logits + + # Save a copy to apply graph rotation equalization on + model_copy = copy.deepcopy(model) + + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)): + fused_optimized_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations=True, + add_additional_regions=add_additional_regions) + + # Run model and save outputs + with torch.no_grad(): + expected_logits = model(**calibration_loader[0]).logits + + # Instead of random orthogonal matrices, we want to use the same ones as when the activations are not fused. + if rotation_mode == 'had': + with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device): + fused_optimized_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_additional_regions=add_additional_regions) + else: + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)): + fused_optimized_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_additional_regions=add_additional_regions) + + # Fuse matrices with module weights + if fused_rotations: + _fuse_rotations(model_copy) + + ids_rot = set() + num_rotation_matrices = 0 + # Count the number of unique rotation matrices + for module in model_copy.modules(): + if isinstance(module, UnfusedRotatedModule): + if id(module.rot_mat) not in ids_rot: + num_rotation_matrices += 1 + ids_rot.add(id(module.rot_mat)) + + num_rotated_modules = 0 + # Count the number of RotatedModules + for module in model_copy.modules(): + if isinstance(module, RotatedModule): + num_rotated_modules += 1 + + # Run model and save outputs + with torch.no_grad(): + logits = model_copy(**calibration_loader[0]).logits + + # Verify that the number of learnable rotation matrices is the expected (R1 + one R2 per block) + expected_number_rotation_matrices = 0 if fused_rotations else ( + 1 + (model.config.num_hidden_layers if add_additional_regions else 0)) + assert num_rotation_matrices == expected_number_rotation_matrices, f"Expected {expected_number_rotation_matrices} learnable rotations, found {num_rotation_matrices}." + + # Verify that the number of rotated modules is correct + expected_number_rotated_modules = 0 if not partial_had else ( + model.config.num_hidden_layers if add_additional_regions else 2 * + model.config.num_hidden_layers) + assert num_rotated_modules == expected_number_rotated_modules, f"Expected {expected_number_rotated_modules} learnable rotations, found {num_rotated_modules}." + + # Verify that the rotated module output is similar to the original FP + assert torch.allclose(original_logits, logits, atol=ATOL) + # Verify that the output is the same + assert torch.allclose(expected_logits, logits, atol=0.0, rtol=0.0) + + # Verify that the weights after fusing match + for name_fused_module, fused_module in model.named_modules(): + # For linear modules verify that the weights match + if isinstance(fused_module, (nn.Linear, RotatedModule)): + for name_unfused_Module, unfused_module in model_copy.named_modules(): + if name_fused_module == name_unfused_Module: + _compare_module_weights(fused_module, unfused_module) + # In case a RotatedModule is found, additional checks need to be done. + if isinstance(fused_module, RotatedModule): + if fused_rotations: + assert isinstance(unfused_module, RotatedModule) + assert torch.allclose(fused_module.had_mat, unfused_module.had_mat, rtol=0.0, atol=0.0), "The rotation matrices do not match." + else: + # Iterate over child nodes until finding the innermost RotatedModule + child_module = unfused_module + while isinstance(child_module, UnfusedRotatedModule): + assert not child_module.is_orphan, "UnfusedRotatedModule should not be an orphan." + child_module = child_module.module + # After finding the inner Rotated Module, they need to be compared + assert isinstance(child_module, RotatedModule), "Inner module should be RotatedModule." + assert torch.allclose(fused_module.had_mat, child_module.had_mat, rtol=0.0, atol=0.0), "The rotation matrices do not match." diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 576af04b1..f141c59ec 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -114,7 +114,7 @@ def small_models_with_ppl(request): @pytest_cases.fixture() def default_run_args(request): - args = UpdatableNamespace(**vars(parse_args([]))) + args = UpdatableNamespace(**vars(parse_args([])[0])) args.nsamples = 2 args.seqlen = 2 args.model = "hf-internal-testing/tiny-random-MistralForCausalLM" From 9a273686c6aa716121bf369b6522ad9b88e97f40 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 13 Dec 2024 11:10:35 +0000 Subject: [PATCH 06/11] Switch to reparametrizations and refactor --- src/brevitas/graph/equalize.py | 296 +++-------- src/brevitas/nn/equalized_layer.py | 119 ----- .../llm/llm_quant/rotation_optimization.py | 17 +- .../llm/llm_quant/rotation_utils.py | 92 ++++ .../llm/llm_quant/run_utils.py | 15 +- src/brevitas_examples/llm/main.py | 14 +- tests/brevitas/graph/test_equalization.py | 475 ++---------------- tests/brevitas_examples/test_llm.py | 201 ++++++++ 8 files changed, 420 insertions(+), 809 deletions(-) create mode 100644 src/brevitas_examples/llm/llm_quant/rotation_utils.py diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 0cd8ea47c..7056f18cd 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -16,6 +16,7 @@ import torch from torch.fx import GraphModule as TorchGraphModule import torch.nn as nn +import torch.nn.utils.parametrize as parametrize from brevitas import torch_version from brevitas.fx import GraphModule @@ -36,7 +37,8 @@ from brevitas.nn.equalized_layer import functional_rotate_input from brevitas.nn.equalized_layer import INPUT_NAMES from brevitas.nn.equalized_layer import RotatedModule -from brevitas.nn.equalized_layer import UnfusedRotatedModule +from brevitas.nn.equalized_layer import RotationBiasParametrization +from brevitas.nn.equalized_layer import RotationWeightParametrization from brevitas.nn.quant_scale_bias import ScaleBias from brevitas.utils.torch_utils import KwargsForwardHook @@ -342,9 +344,6 @@ def _get_input_axis(module: nn.Module) -> Optional[int]: return 0 else: return None - # TODO: Remove with parametrizations - elif isinstance(module, (UnfusedRotatedModule)): - return _get_input_axis(module.module) elif isinstance(module, (RotatedModule,)): return _get_input_axis(module.layer) else: @@ -375,9 +374,6 @@ def _get_output_axis(module: nn.Module) -> Optional[int]: return 0 else: return None - # TODO: Remove with parametrizations - elif isinstance(module, (UnfusedRotatedModule)): - return _get_output_axis(module.module) elif isinstance(module, (RotatedModule,)): return _get_output_axis(module.layer) else: @@ -1324,8 +1320,7 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= if not insert_rotation_module and not region.is_valid: continue hidden_dim = region.max_shape_sinks - # TODO: Include again not insert_rotation_module - if full_rotation_method == 'ort': + if not insert_rotation_module and full_rotation_method == 'ort': rot_mat = random_orthogonal_matrix(hidden_dim) K = None rot_func = _apply_ort_device @@ -1392,116 +1387,6 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= return rewriters -def _fuse_rotations(model: nn.Module): - rewriters = [] - - def _fuse_rotations_aux(module: nn.Module): - if isinstance(module, UnfusedRotatedModule): - unrotated_module = module.unrotated_module - rot_weight = module.weight.data - - # Fuse rotations with weights - unrotated_module.weight.data = rot_weight - # Fuse rotations with bias if existent - if module.bias is not None: - rot_bias = module.bias.data - unrotated_module.bias.data = rot_bias - - # Use rotated module if orphan - if module.is_orphan: - rewriter = ModuleInstanceToModuleInstance( - module, RotatedModule(had_mat=module.rot_mat, k=None, layer=unrotated_module)) - else: - rewriter = ModuleInstanceToModuleInstance(module, unrotated_module) - # Save rewriter - rewriters.append(rewriter) - else: - for child_module in module.children(): - _fuse_rotations_aux(child_module) - - # Populate rewriters - _fuse_rotations_aux(model) - # Apply rewriter to fuse the weights - for r in rewriters: - model = r.apply(model) - - -@dataclass -class UnfusedRotation: - rot_mat: torch.Tensor - is_sink: bool - is_source: bool - is_orphan: bool - - -def _apply_unfused_rotate(model: nn.Module, regions: List[Region], full_rotation_method='ort'): - rewriters = [] - fused_rotated_modules = defaultdict(list) - rot_func = _apply_ort_device - - for region in regions: - insert_rotation_module = len(region.srcs) == 0 - - if not insert_rotation_module and not region.is_valid: - continue - hidden_dim = region.max_shape_sinks - - rot_mat = random_orthogonal_matrix(hidden_dim) - - for name, indexes in region.srcs.items(): - module = region.get_module_from_name(name) - - fused_rotated_modules[module].append( - UnfusedRotation( - rot_mat=rot_mat, - is_sink=False, - is_source=True, - is_orphan=False, - )) - - for name, indexes in region.sinks.items(): - module = region.get_module_from_name(name) - - if insert_rotation_module and len(region.srcs) == 0: - fused_rotated_modules[module].append( - UnfusedRotation( - rot_mat=rot_mat, - is_sink=False, - is_source=False, - is_orphan=True, - )) - else: - fused_rotated_modules[module].append( - UnfusedRotation( - rot_mat=rot_mat, - is_sink=True, - is_source=False, - is_orphan=False, - )) - - for module, rotation_modules in fused_rotated_modules.items(): - rotation_module = module - for rotation_module_dataclass in rotation_modules: - rotation_module = UnfusedRotatedModule( - module=rotation_module, - rot_func=rot_func, - rot_mat=rotation_module_dataclass.rot_mat, - _get_input_axis=_get_input_axis, - _get_output_axis=_get_output_axis, - is_source=rotation_module_dataclass.is_source, - is_sink=rotation_module_dataclass.is_sink, - is_orphan=rotation_module_dataclass.is_orphan, - ) - rewriter = ModuleInstanceToModuleInstance( - module, - rotation_module, - ) - rewriters.append(rewriter) - for r in rewriters: - model = r.apply(model) - return rewriters - - def _replace_bias(next_module, new_bias): new_bias = new_bias.view(-1) if next_module.bias is not None: @@ -1591,10 +1476,8 @@ def rotate_matmuls(self, graph_module): graph_module.recompile() graph_module.graph.lint() - def apply( - self, - graph_model: GraphModule, - fuse_rotations: bool = True) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + def apply(self, + graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( graph_model, @@ -1618,8 +1501,7 @@ def apply( if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - _apply_rotate_fn = _apply_rotate if fuse_rotations else _apply_unfused_rotate - rewriters = _apply_rotate_fn(graph_model, regions, self.full_rotation_method) + rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) if self.return_rewriters: return graph_model, rewriters else: @@ -1716,67 +1598,16 @@ def apply(self, model: nn.Module, fuse_rotations: bool = True) -> nn.Module: regions: List[Region] = [] self.find_module(model, regions) if len(regions) > 0: - _apply_rotate_fn = _apply_rotate if fuse_rotations else _apply_unfused_rotate - _apply_rotate_fn(model, regions) + _apply_rotate(model, regions) return model -def find_missing_rotation_regions(graph_model: GraphModule, - head_dim: int, - state_impl_kwargs=None) -> List[Region]: - import re - - regions = [] - # Add R2 regions, this should be innermost - for src_name, src_module in graph_model.named_modules(): - if "attn_v_proj" in src_name: - if state_impl_kwargs is not None: - state = WalkRegionState(**state_impl_kwargs) - else: - state = WalkRegionState() - - block_number_matches_src = re.findall(r'\d+', src_name) - assert len(block_number_matches_src) == 2, "Could not identify block" - block_number_src = int(block_number_matches_src[1]) - - eq_indexes = EqualizationIndexes(0, head_dim, 0) - state.add_srcs(src_name, src_module, eq_indexes) - - # Now the corresponding sink - for sink_name, sink_module in graph_model.named_modules(): - if "attn_o_proj" in sink_name: - block_number_matches_sink = re.findall(r'\d+', sink_name) - assert len(block_number_matches_sink) == 2, "Could not identify block" - block_number_sink = int(block_number_matches_sink[1]) - # If the blocks match, we identified the region - if block_number_src == block_number_sink: - eq_indexes = EqualizationIndexes(0, head_dim, state.offset) - state.add_sinks(sink_name, sink_module, eq_indexes) - # Instantiate region and add to list - region = Region( - srcs=dict(sorted(state.srcs.items())), - sinks=dict(sorted(state.sinks.items())), - name_to_module=state.name_to_module, - ) - if region not in regions: - regions.append(region) - - return regions - - def _apply_rotate_fused_rotations( model: nn.Module, regions: List[Region], - full_rotation_method: str = 'had', + full_rotation_method='had', fuse_rotations: bool = True): rewriters = [] - # Dictionary to append the unfused rotated modules for optimization - unfused_rotated_modules = defaultdict(list) - # Dictionary to keep track of the modules that are assigned to a RotatedModule - fused_rotated_modules = {} - # List to keep track of the rotation matrices added to the - rotation_matrices = [] - for region in regions: insert_rotation_module = len(region.srcs) == 0 @@ -1784,20 +1615,18 @@ def _apply_rotate_fused_rotations( continue hidden_dim = region.max_shape_sinks if not insert_rotation_module and full_rotation_method == 'ort': - rot_mat = random_orthogonal_matrix( - hidden_dim) if fuse_rotations else torch.nn.Parameter( - random_orthogonal_matrix(hidden_dim)) + rot_mat = random_orthogonal_matrix(hidden_dim) + # If the rotations are not fused, redefine as parameter + if not fuse_rotations: + rot_mat = torch.nn.Parameter(rot_mat) K = None rot_func = _apply_ort_device - # Store rotation matrix for optimization - rotation_matrices.append(rot_mat) elif not insert_rotation_module and not fuse_rotations: - # TODO: Make it more general - rot_mat = torch.nn.Parameter(random_hadamard_matrix(hidden_dim, torch.device('cpu'))) + # TODO: Generalize + device = next(model.parameters()).device + rot_mat = torch.nn.Parameter(random_hadamard_matrix(hidden_dim, device)) K = None rot_func = _apply_ort_device - # Store rotation matrix for optimization - rotation_matrices.append(rot_mat) else: try: # Build hadamard rotation matrix @@ -1815,12 +1644,17 @@ def _apply_rotate_fused_rotations( for name, indexes in region.srcs.items(): module = region.get_module_from_name(name) + axis = _get_output_axis(module) + + assert not insert_rotation_module, "Orphan regions must not have sources." if not insert_rotation_module and fuse_rotations: + # Verify that there are no parametrizations, as otherwise the underlying data will not be updated + assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." + if hasattr(module, 'allocate_params'): module.allocate_params(module) - axis = _get_output_axis(module) weight = module.weight.data if axis == 0: @@ -1838,31 +1672,48 @@ def _apply_rotate_fused_rotations( if hasattr(module, 'offload_params'): module.offload_params(module) - else: - unfused_rotated_modules[module].append( - UnfusedRotation( + elif not insert_rotation_module and not fuse_rotations: + # Parametrize weights and possibly bias with unfused rotations + parametrize.register_parametrization( + module, + "weight", + RotationWeightParametrization( rot_mat=rot_mat, - is_sink=False, + rot_func=rot_func, + output_axis=axis, is_source=True, - is_orphan=False, )) + if getattr(module, 'bias', None) is not None: + parametrize.register_parametrization( + module, + "bias", + RotationBiasParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + output_axis=axis, + is_source=True, + )) for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) + axis = _get_input_axis(module) if not insert_rotation_module and not fuse_rotations: - unfused_rotated_modules[module].append( - UnfusedRotation( + parametrize.register_parametrization( + module, + "weight", + RotationWeightParametrization( rot_mat=rot_mat, + rot_func=rot_func, + input_axis=axis, is_sink=True, - is_source=False, - is_orphan=False, )) else: + # Verify that there are no parametrizations, as otherwise the underlying data will not be updated + assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." + if hasattr(module, 'allocate_params'): module.allocate_params(module) - - axis = _get_input_axis(module) weight = module.weight.data if axis == 1: @@ -1875,42 +1726,16 @@ def _apply_rotate_fused_rotations( if hasattr(module, 'offload_params'): module.offload_params(module) - if insert_rotation_module: - if module not in fused_rotated_modules: - fused_rotated_modules[module] = RotatedModule( - had_mat=rot_mat, k=K, layer=module) - else: - raise RuntimeError( - "Only one RotatedModule at most can be assigned to a module.") - # For this to work, we need to have the following hierarchy UnfusedRotatedModule -> (RotatedModule) -> Linear - for module, rotation_modules in unfused_rotated_modules.items(): - # Verify that at most one RotatedModule is available - rotation_module = module if module not in fused_rotated_modules else fused_rotated_modules[ - module] - for rotation_module_dataclass in rotation_modules: - rotation_module = UnfusedRotatedModule( - module=rotation_module, - rot_func=rot_func, - rot_mat=rotation_module_dataclass.rot_mat, - _get_input_axis=_get_input_axis, - _get_output_axis=_get_output_axis, - is_source=rotation_module_dataclass.is_source, - is_sink=rotation_module_dataclass.is_sink, - is_orphan=rotation_module_dataclass.is_orphan, - ) - # Instantiate rewriters - rewriter = ModuleInstanceToModuleInstance(module, rotation_module) - rewriters.append(rewriter) - # Add missing RotatedModules, in case there are any - for module, rotation_module in fused_rotated_modules.items(): - if module not in unfused_rotated_modules: - rewriter = ModuleInstanceToModuleInstance(module, rotation_module) - rewriters.append(rewriter) + if insert_rotation_module and len(region.srcs) == 0: + rewriter = ModuleInstanceToModuleInstance( + module, RotatedModule(had_mat=rot_mat, k=K, layer=module)) + rewriters.append(rewriter) for r in rewriters: model = r.apply(model) - return rewriters, rotation_matrices + return rewriters +# TODO: Consolidate with GraphRotationEqualization class GraphRotationEqualizationOptimization(GraphRotationEqualization): def __init__( @@ -1955,10 +1780,13 @@ def apply( # Layerwise have only a single sink named 'sinks0' id_sink = id(o_r.get_module_from_name('sinks0')) if id_sink not in eq_layers: - # TODO: Change data structure to insert in the beginning with O(1) regions = [o_r] + regions if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - rewriters, rotation_matrices = _apply_rotate_fused_rotations(graph_model, regions, self.full_rotation_method, fuse_rotations) - return graph_model, rewriters, rotation_matrices + rewriters = _apply_rotate_fused_rotations( + graph_model, regions, self.full_rotation_method, fuse_rotations) + if self.return_rewriters: + return graph_model, rewriters + else: + return graph_model diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index fbe045124..ccd812713 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -175,125 +175,6 @@ def forward(self, bias: torch.Tensor) -> torch.Tensor: return bias -class UnfusedRotatedModule(torch.nn.Module): - - def __init__( - self, - module: torch.nn.Module, - rot_func: Callable, - rot_mat: torch.nn.Parameter, - _get_input_axis: Callable, - _get_output_axis: Callable, - is_source: bool = False, - is_sink: bool = False, - is_orphan: bool = False, - ) -> None: - super().__init__() - self.module = module - self.rot_func = rot_func - self.rot_mat = rot_mat - self.K = None - - # TODO: This were included to prevent circular imports. - self._get_input_axis = _get_input_axis - self._get_output_axis = _get_output_axis - - self.is_source = is_source - self.is_sink = is_sink - self.is_orphan = is_orphan - - # TODO: Does it make sense the extra complexity just to prevent the view operation? - # Probably if no reshaping needs to be done, no change is required - def _wrap_rot(self) -> bool: - weight_shape = self.module.weight.shape - rot_dim = self.rot_mat.shape[0] - if self.is_sink or self.is_orphan: - weight_shape_dim = weight_shape[self._get_input_axis(self.module)] - elif self.is_source: - weight_shape_dim = weight_shape[self._get_output_axis(self.module)] - else: - weight_shape_dim = None - - if weight_shape_dim is not None: - if rot_dim != weight_shape_dim: - assert weight_shape_dim % rot_dim == 0, "Sizes need to be divisibile" - return True - # No need to incorporate additional view operations - return False - - # These properties enable propagating the fusing to the module weights - @property - def weight(self) -> Optional[torch.Tensor]: - weight = getattr(self.module, 'weight', None) - # Add rotation and let these being propagated till the parent - # unfused rotated module - if self.is_sink or self.is_orphan: - axis = self._get_input_axis(self.module) - if axis == 1: - weight = self.rot_func(weight, self.rot_mat, self.K) - elif axis == 0: - weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() - else: - raise RuntimeError("Not supported yet") - - if self.is_source: - axis = self._get_output_axis(self.module) - if axis == 0: - weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() - elif axis == 1: - weight = self.rot_func(weight, self.rot_mat, self.K) - else: - raise RuntimeError("Not supported yet") - - return weight - - @property - def bias(self) -> Optional[torch.Tensor]: - bias = getattr(self.module, 'bias', None) - # Propagate bias adding the rotations incrementally - if self.is_source: - if bias is not None: - bias = self.rot_func(bias, self.rot_mat, self.K) - - return bias - - @property - def unrotated_module(self) -> torch.nn.Module: - return self.module.unrotated_module if isinstance( - self.module, UnfusedRotatedModule) else self.module - - def forward(self, inp, **kwargs): - # Rotated matrices - weight = self.weight.data - bias = self.bias.data if self.bias is not None else None - - # Propagate calls till getting to the original module being rotated - child_module = self.module - # Iterate until the original module is reached, keeping the rotations that need to be performed on the input - while isinstance(child_module, UnfusedRotatedModule): - child_module = child_module.module - # child_module contains the original module in the network. Before applying its forward method, we need to - # rotate the inpute appropiately - if self.is_orphan: - # Rotate the input for an orphan sink - inp = self.rot_func(inp, self.rot_mat, self.K) - # Modify the weights, and run the original model forward. After that, restore the previous values. - if weight is not None: - orig_weight = child_module.weight.data - child_module.weight.data = weight - if bias is not None: - orig_bias = child_module.bias.data - child_module.bias.data = bias - # Call forward of the original module - o = child_module(inp) - # Restore un-rotated weights - child_module.weight.data = orig_weight - if bias is not None: - child_module.bias.data = orig_bias - # Return rotated output - return o - - def functional_rotate_input(inp, transpose=False): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None if transpose: diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 7b2baecae..93b7edf7e 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -9,14 +9,12 @@ import torch from torch.utils.data import Dataset -from tqdm import tqdm import transformers -from transformers import default_data_collator from transformers import Trainer from transformers.tokenization_utils import PreTrainedTokenizerBase -from brevitas.nn.equalized_layer import UnfusedRotatedModule from brevitas.optim.sgdg import SGDG +from brevitas_examples.llm.llm_quant.rotation_utils import extract_trainable_rotation_matrices @dataclass @@ -86,18 +84,11 @@ def apply_rotation_optimization( for param in graph_model.parameters(): param.requires_grad = False # Collect trainable matrices - trainable_parameters = [] - ids_rot = set() - for module in graph_model.modules(): - if isinstance(module, UnfusedRotatedModule): - if id(module.rot_mat) not in ids_rot: - ids_rot.add(id(module.rot_mat)) - trainable_parameters.append(module.rot_mat) - # Collect parameters for the rotation matrices - for rot_mat in trainable_parameters: + trainable_rotations = extract_trainable_rotation_matrices(graph_model) + for rot_mat in trainable_rotations: rot_mat.requires_grad = True # Initialize optimizer - optimizer = SGDG(trainable_parameters, lr=training_args.learning_rate, stiefel=True) + optimizer = SGDG(trainable_rotations, lr=training_args.learning_rate, stiefel=True) trainer = Trainer( model=graph_model, tokenizer=tokenizer, diff --git a/src/brevitas_examples/llm/llm_quant/rotation_utils.py b/src/brevitas_examples/llm/llm_quant/rotation_utils.py new file mode 100644 index 000000000..037de166d --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/rotation_utils.py @@ -0,0 +1,92 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import re +from typing import List + +from torch import nn +from torch.fx import GraphModule +import torch.nn.utils.parametrize as parametrize + +from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.graph.base import Transform +from brevitas.graph.equalize import EqualizationIndexes +from brevitas.graph.equalize import Region +from brevitas.graph.equalize import WalkRegionState +from brevitas.nn.equalized_layer import RotationWeightParametrization + + +def find_self_attention_rotation_regions( + graph_model: GraphModule, head_dim: int, state_impl_kwargs=None) -> List[Region]: + regions = [] + # See R2 rotation matrices in https://arxiv.org/pdf/2405.16406. + for src_name, src_module in graph_model.named_modules(): + if "attn_v_proj" in src_name: + if state_impl_kwargs is not None: + state = WalkRegionState(**state_impl_kwargs) + else: + state = WalkRegionState() + + block_number_matches_src = re.findall(r'\d+', src_name) + assert len(block_number_matches_src) == 2, "Could not identify block" + block_number_src = int(block_number_matches_src[1]) + + eq_indexes = EqualizationIndexes(0, head_dim, 0) + state.add_srcs(src_name, src_module, eq_indexes) + + # Now the corresponding sink + for sink_name, sink_module in graph_model.named_modules(): + if "attn_o_proj" in sink_name: + block_number_matches_sink = re.findall(r'\d+', sink_name) + assert len(block_number_matches_sink) == 2, "Could not identify block" + block_number_sink = int(block_number_matches_sink[1]) + # If the blocks match, the region was identified + if block_number_src == block_number_sink: + eq_indexes = EqualizationIndexes(0, head_dim, state.offset) + state.add_sinks(sink_name, sink_module, eq_indexes) + region = Region( + srcs=dict(sorted(state.srcs.items())), + sinks=dict(sorted(state.sinks.items())), + name_to_module=state.name_to_module, + ) + if region not in regions: + regions.append(region) + + return regions + + +def fuse_rotations(model: nn.Module) -> None: + for module in model.modules(): + # Check if the module has any parametrizations + if hasattr(module, "parametrizations"): + # Remove weight parametrizations + parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) + # We need to check again, in case the weight parametrizations were the only ones + if hasattr(module, "parametrizations") and hasattr(module.parametrizations, "bias"): + parametrize.remove_parametrizations(module, "bias", leave_parametrized=True) + + +def extract_rewriters_unfused_rotations(model: nn.Module, + rewriters: List[Transform]) -> List[Transform]: + extra_rewriters = [] + for module in model.modules(): + if hasattr(module, "parametrizations"): + # Verify that the current module does not have already associated a RotatedModule + if len([r for r in rewriters if r.old_module_instance is module]) == 0: + # Identity rewriter, only useful externaly + rewriter = ModuleInstanceToModuleInstance(module, module) + extra_rewriters.append(rewriter) + return extra_rewriters + + +def extract_trainable_rotation_matrices(model: nn.Module) -> List[nn.Parameter]: + trainable_rotations = [] + # We need to keep track of the IDs of the rotation matrices, as several modules + # can share the same parametrized rotation. + ids_rot = set() + for module in model.modules(): + if isinstance(module, RotationWeightParametrization): + if id(module.rot_mat) not in ids_rot: + ids_rot.add(id(module.rot_mat)) + trainable_rotations.append(module.rot_mat) + return trainable_rotations diff --git a/src/brevitas_examples/llm/llm_quant/run_utils.py b/src/brevitas_examples/llm/llm_quant/run_utils.py index 44ba711a5..4acbd87ed 100644 --- a/src/brevitas_examples/llm/llm_quant/run_utils.py +++ b/src/brevitas_examples/llm/llm_quant/run_utils.py @@ -110,15 +110,24 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return out +def _get_tensor_weight_id(module, tensor_name): + if hasattr(module, "parametrizations") and tensor_name in module.parametrizations: + return id(module.parametrizations[tensor_name].original) + elif hasattr(module, tensor_name): + return id(getattr(module, tensor_name)) + return None + + # This functions remap rewriters so match modules in a potentially different model that shares the same underlying tensors # We rely on the fact that two versions of the same model (eager vs FX) might have different modules id (id(fx_module) != id (eager_module)) # However, the underlying tensors are still shared, so we can recostruct the mapping between the two # modules. def fix_rewriter(rewriters, old_model_ref, tensor_name): + # We need to account for reparametrizations, to make sure the underlying tensors are accessed for r in rewriters: - tensor_id = id(r.old_module_instance.weight) + tensor_id = _get_tensor_weight_id(r.old_module_instance, tensor_name) module = [ - m for m in old_model_ref.modules() - if hasattr(m, tensor_name) and id(m.weight) == tensor_id] + m for m in old_model_ref.modules() if _get_tensor_weight_id(m, tensor_name) == tensor_id + ] r.old_module_instance = module[0] return rewriters diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index df522e901..de01b9d7f 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -16,7 +16,6 @@ from brevitas.export import export_torch_qcdq from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas.graph.equalize import find_missing_rotation_regions from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import GraphRotationEqualizationOptimization from brevitas.graph.equalize import LayerwiseActivationRotation @@ -42,6 +41,8 @@ from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers from brevitas_examples.llm.llm_quant.rotation_optimization import apply_rotation_optimization +from brevitas_examples.llm.llm_quant.rotation_utils import extract_rewriters_unfused_rotations +from brevitas_examples.llm.llm_quant.rotation_utils import find_self_attention_rotation_regions from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter from brevitas_examples.llm.llm_quant.run_utils import get_fx @@ -89,14 +90,21 @@ def fused_optimized_rotation_no_fx( for r in rewriters: r.apply(model) #new_model = offload_model(new_model) - additional_regions = find_missing_rotation_regions( + + # Regions with source o_proj and sink down_proj + self_attention_regions = find_self_attention_rotation_regions( new_model, model.config.hidden_size // model.config.num_attention_heads) if add_additional_regions else None eq = GraphRotationEqualizationOptimization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, ) - new_model, rewriters, rotation_matrices = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=additional_regions) + new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions) + + # Retrieve additional rewriters for unfused rotations + rewriters_unfused_rotations = extract_rewriters_unfused_rotations(new_model, rewriters) + rewriters.extend(rewriters_unfused_rotations) + rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 0f65db168..1e913845c 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -2,21 +2,19 @@ # SPDX-License-Identifier: BSD-3-Clause import copy -from functools import partial import itertools from typing import List, Tuple from unittest.mock import patch import pytest import torch +import torch.nn.utils.parametrize as parametrize from torchvision import models from brevitas.fx import symbolic_trace -# TODO: Refactor to prevent circular import from brevitas.graph.equalize import _apply_ort_device from brevitas.graph.equalize import _batch_norm from brevitas.graph.equalize import _extract_regions -from brevitas.graph.equalize import _fuse_rotations from brevitas.graph.equalize import _get_input_axis from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import _is_supported_module @@ -25,12 +23,11 @@ from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import MergeLnAffine from brevitas.graph.equalize import random_orthogonal_matrix -from brevitas.graph.hadamard import matmul_hadU from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module -from brevitas.nn.equalized_layer import RotatedModule -from brevitas.nn.equalized_layer import UnfusedRotatedModule +from brevitas.nn.equalized_layer import RotationBiasParametrization +from brevitas.nn.equalized_layer import RotationWeightParametrization from tests.marker import requires_pt_ge from .equalization_fixtures import * @@ -293,13 +290,13 @@ def test_models(rotation_fixtures, partial_had): def _rotate_input_output(is_source: bool, is_sink: bool, is_orphan: bool) -> Tuple[bool, bool]: - # Verify that only one flag is enabled at the same time + # Verify that only one flag is enabled simultaneously assert sum([is_source, is_sink, is_orphan]) <= 1, "Only one flag can be enabled." rotate_input, rotate_output = False, False if is_source: rotate_output = True - if is_sink: + if is_sink or is_orphan: rotate_input = True return rotate_input, rotate_output @@ -316,32 +313,29 @@ def _compute_rotated_ouptut_from_matrices( return out -# NOTE: The assumption is that only one flag can be true simultaneously -# NOTE: Orphans need to be taken care of. A module can only be orphan once. +# RotationParametrizations can only have one type flag enabled simultaneously (is_source, is_sink, is_orphan). +# Moreover, orphan rotations need to be the outermost rotation, as this cancels out when rotating the input. def _generate_rotation_flags(N: int) -> List[bool]: return [ rotation_flags for rotation_flags in itertools.product([False, True], repeat=3 * N) if ( - all([sum(rotation_flags[i * 3:(i + 1) * 3]) <= 1 for i in range(N)]) and - # Only outermost rotation can be orphan - all([not rotation_flags[i * 3 + 2] for i in range(N - 1)]))] + all([sum(rotation_flags[i * 3:(i + 1) * 3]) <= 1 + for i in range(N)]) and all([not rotation_flags[i * 3 + 2] for i in range(N - 1)])) + ] @requires_pt_ge('2.4') @pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}") -def test_composition_unfused_rotation_layer(N): +def test_composition_unfused_rotations(N): torch.manual_seed(SEED) for rotation_flags in _generate_rotation_flags(N): in_features = IN_FEATURES_LINEAR module = nn.Linear(in_features=in_features, out_features=in_features) + rot_module = copy.deepcopy(module) # Sample input to pass through the block sample_input = torch.rand((1, in_features),) - - # Compose rotation modules - rotated_module = module - # Composite rotation matrices rot_mat_input = torch.eye(in_features) rot_mat_output = torch.eye(in_features) @@ -361,427 +355,34 @@ def test_composition_unfused_rotation_layer(N): rot_mat_output = rot_mat_output @ rot_mat # Compose rotation modules - rotated_module = UnfusedRotatedModule( - module=rotated_module, - rot_func=_apply_ort_device, - _get_input_axis=_get_input_axis, - _get_output_axis=_get_output_axis, - rot_mat=rot_mat, - is_source=is_source, - is_sink=is_sink, - is_orphan=is_orphan, - ) - - # Compute outputs to compare + parametrize.register_parametrization( + rot_module, + "weight", + RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=_apply_ort_device, + input_axis=_get_input_axis(rot_module), + output_axis=_get_output_axis(rot_module), + is_source=is_source, + is_sink=is_sink, + is_orphan=is_orphan, + )) + parametrize.register_parametrization( + rot_module, + "bias", + RotationBiasParametrization( + rot_mat=rot_mat, + rot_func=_apply_ort_device, + input_axis=_get_input_axis(rot_module), + output_axis=_get_output_axis(rot_module), + is_source=is_source, + is_sink=is_sink, + is_orphan=is_orphan, + )) + gt_output = _compute_rotated_ouptut_from_matrices( module, sample_input, rot_mat_input, rot_mat_output) - rot_output = rotated_module(sample_input) + rot_output = rot_module(sample_input) # Verify that the rotation operations were computed correctly assert torch.allclose(gt_output, rot_output, atol=ATOL) - - -# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26 -def _random_orthogonal_matrix(size, generator): - """ - Generate a random orthogonal matrix of the specified size. - First, we generate a random matrix with entries from a standard distribution. - Then, we use QR decomposition to obtain an orthogonal matrix. - Finally, we multiply by a diagonal matrix with diag r to adjust the signs. - - Args: - size (int): The size of the matrix (size x size). - - Returns: - torch.Tensor: An orthogonal matrix of the specified size. - """ - torch.cuda.empty_cache() - random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator) - q, r = torch.linalg.qr(random_matrix) - q *= torch.sign(torch.diag(r)).unsqueeze(0).float() - return q - - -def _random_hadamard_matrix(size, device, generator): - # See https://github.com/Cornell-RelaxML/quip-sharp , Section "Randomized Hadamard Transformation" - Q = torch.randint(low=0, high=2, size=(size,), generator=generator).to(torch.float64) - Q = Q * 2 - 1 - Q = torch.diag(Q) - return matmul_hadU(Q).to(device) - - -def _compare_module_weights_fused_unfused(gt_module, rot_module, fused_rotations=False): - gt_weight = gt_module.weight if isinstance(gt_module, nn.Linear) else gt_module.layer.weight - gt_bias = gt_module.bias if isinstance(gt_module, nn.Linear) else gt_module.layer.bias - if fused_rotations: - rot_weight = rot_module.weight if isinstance( - rot_module, nn.Linear) else rot_module.layer.weight - rot_bias = rot_module.bias if isinstance(rot_module, nn.Linear) else rot_module.layer.bias - else: - rot_weight = rot_module.weight - rot_bias = rot_module.bias - assert torch.allclose(gt_weight, rot_weight, rtol=0.0, atol=0.0) - if gt_bias is not None: - assert torch.allclose(gt_bias, rot_bias, rtol=0.0, atol=0.0) - # For a RotatedModule, corresponding to an orphan node, additional checks need to be done - if isinstance(gt_module, RotatedModule): - if not fused_rotations: - # The outermost should be an orphan - child_rot_module = rot_module - assert child_rot_module.is_orphan, "Unfused rotated module needs to be an orphan." - # Check that the inner UnfusedRotatedModules are not orphans - while isinstance(child_rot_module.module, UnfusedRotatedModule): - assert not child_rot_module.module.is_orphan, "Inner unfused rotated modules cannot be orphans." - child_rot_module = child_rot_module.module - # Verify that the rotation matrices match - assert torch.allclose(gt_module.had_mat, rot_module.rot_mat) - - -# This test verifies that the weights returned by the unfused rotate modules -# match those when fusing -@requires_pt_ge('2.4') -@pytest_cases.parametrize('partial_had', [False, True]) -@pytest_cases.parametrize('fused_rotations', [False, True]) -def test_models_rotations(rotation_fixtures, partial_had, fused_rotations): - - in_shape = IN_SIZE_LINEAR - - model_class = rotation_fixtures - model = model_class() - - model.eval() - inp = torch.rand(in_shape) - - model = symbolic_trace(model) - merge = MergeLnAffine() - model = merge.apply(model) - eq = GraphRotationEqualization(orphan_sink=partial_had, full_rotation_method='ort') - - # Save a copy to apply graph rotation equalization on - model_copy = copy.deepcopy(model) - - # We need to make sure that the same random matrices are being generated - generator = torch.Generator() - generator.manual_seed(SEED) - # Clone generator to make sure we can use the same rotation matrices - generator_clone = generator.clone_state() - - # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)): - # Apply rotation equalization while controlling the random matrices that are generated - model = eq.apply(model) - - with torch.no_grad(): - expected_out = model(inp) - - # Now rotate but without fusing the rotation matrices - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)): - # Apply rotation equalization while controlling the random matrices that are generated - model_copy = eq.apply(model_copy, fuse_rotations=False) - - # Fuse the rotations and make sure the behaviour is the same - if fused_rotations: - _fuse_rotations(model_copy) - - with torch.no_grad(): - out = model_copy(inp) - - # Verify that the output of the model does not change after incorporating the rotations - assert torch.allclose(expected_out, out, rtol=0.0, atol=0.0) - - # Verify that weight matrices - for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): - if model_node.op == 'call_module': - module = get_module(model, model_node.target) - module_copy = get_module(model_copy, model_copy_node.target) - if isinstance(module, (nn.Linear, RotatedModule)): - _compare_module_weights_fused_unfused(module, module_copy, fused_rotations) - - -def _compare_module_weights(module, module_copy): - weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight - bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias - weight_copy = module_copy.weight - bias_copy = module_copy.bias - assert torch.allclose(weight, weight_copy, rtol=0.0, atol=0.0) - if bias is not None: - assert torch.allclose(bias, bias_copy, rtol=0.0, atol=0.0) - - -import logging - -from transformers import AutoModelForCausalLM -from transformers import AutoTokenizer - -from brevitas.graph.equalize import find_missing_rotation_regions -from brevitas_examples.common.accelerate_utils.accelerate import offload_model -from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks -from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model -from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge -from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm -from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch -from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter -from brevitas_examples.llm.main import fused_optimized_rotation_no_fx -from brevitas_examples.llm.main import fused_rotation_no_fx -from tests.brevitas_examples.test_llm import default_run_args - - -@pytest_cases.fixture( - ids=[ - "llama",], - params=[ - { - "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", - "input_bit_width": None, - "fuse_sequences": False, - "act_calibration": False,},]) -def equalize_args(default_run_args, request): - args = default_run_args - export_dict = request.param - args.update(**export_dict) - yield args - - -@pytest.mark.llm -@requires_pt_ge('2.4') -@pytest_cases.parametrize('partial_had', [False, True]) -@pytest_cases.parametrize('fused_rotations', [False, True]) -def test_small_models_equalize_legacy_rotation_orthogonal( - caplog, partial_had, fused_rotations, equalize_args): - import os - os.environ["HF_HUB_CACHE"] = "/scratch/hf_models/" - caplog.set_level(logging.INFO) - args = equalize_args - args.rotation_orphan_sink = partial_had - args.rotation_mode = 'ort' - - kwargs = {"torch_dtype": torch.float16} - model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) - model = replace_rmsnorm_with_torch(model, model.config) - model.config.use_cache = False - print("Model loaded.") - model.eval() - tokenizer = AutoTokenizer.from_pretrained(args.model) - # Load the data for calibration and evaluation. - calibration_loader = get_dataset_for_model( - args.model, - dataset_name=args.dataset, - tokenizer=tokenizer, - nsamples=args.nsamples, - seqlen=args.seqlen, - split="train", - seed=args.seed, - require_fx=False, - device=None, - fuse_sequences=args.fuse_sequences) - - # We need to make sure that the same random matrices are being generated - generator = torch.Generator() - generator.manual_seed(SEED) - # Clone generator to make sure we can use the same rotation matrices - generator_clone = generator.clone_state() - - # Save a copy to apply graph rotation equalization on - model_copy = copy.deepcopy(model) - - # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)): - with patch('brevitas.graph.hadamard.random_hadamard_matrix', - partial(_random_hadamard_matrix, generator=generator)): - fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations=True) - - # Run model and save outputs - with torch.no_grad(): - expected_logits = model(**calibration_loader[0]).logits - - # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)): - with patch('brevitas.graph.hadamard.random_hadamard_matrix', - partial(_random_hadamard_matrix, generator=generator_clone)): - fused_rotation_no_fx(model_copy, calibration_loader, args, fuse_rotations=False) - - if fused_rotations: - _fuse_rotations(model_copy) - - # Run model and save outputs - with torch.no_grad(): - logits = model_copy(**calibration_loader[0]).logits - - # Verify that the output is the same - assert torch.allclose(expected_logits, logits) - - # Verify that the weights after fusing match - for name_fused_module, fused_module in model.named_modules(): - # For linear modules verify that the weights match - if isinstance(fused_module, (nn.Linear, RotatedModule)): - for name_unfused_Module, unfused_module in model_copy.named_modules(): - if name_fused_module == name_unfused_Module: - _compare_module_weights(fused_module, unfused_module) - # For a RotatedModule, corresponding to an orphan node, additional checks need to be done - if isinstance(fused_module, RotatedModule): - # Verify that the outer module is an orphan - if fused_rotations: - assert isinstance(unfused_module, RotatedModule) - assert torch.allclose(fused_module.had_mat, unfused_module.had_mat) - else: - assert unfused_module.is_orphan - # Verify that the rotation matrices match - assert torch.allclose(fused_module.had_mat, unfused_module.rot_mat) - - -from itertools import product - -from brevitas.graph.equalize import _apply_had_device -from brevitas.graph.hadamard import get_hadK - - -# NOTE: This test works because in R2 we patch the rotation method, so the appropiate matrix is not effectively used. This is because when the fast_hadamard_transform is not avai -@pytest.mark.llm -@requires_pt_ge('2.4') -@pytest_cases.parametrize( - 'partial_had, fused_rotations, add_additional_regions', - list(product([False, True], repeat=3)), - ids=[("fused-R1" if fused_rotations else "R1") + ("-R2" if add_additional_regions else "") + - ("-R3" if partial_had else "") for partial_had, - fused_rotations, - add_additional_regions in list(product([False, True], repeat=3))], -) -@pytest_cases.parametrize('rotation_mode', ['ort', 'had']) -def test_small_models_equalize_mixed_fused_unfused( - caplog, partial_had, fused_rotations, add_additional_regions, rotation_mode, equalize_args): - import os - os.environ["HF_HUB_CACHE"] = "/scratch/hf_models/" - caplog.set_level(logging.INFO) - args = equalize_args - args.rotation_orphan_sink = partial_had - args.rotation_mode = rotation_mode - - kwargs = {"torch_dtype": torch.float16} - model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) - model = replace_rmsnorm_with_torch(model, model.config) - model.config.use_cache = False - print("Model loaded.") - model.eval() - tokenizer = AutoTokenizer.from_pretrained(args.model) - # Load the data for calibration and evaluation. - calibration_loader = get_dataset_for_model( - args.model, - dataset_name=args.dataset, - tokenizer=tokenizer, - nsamples=args.nsamples, - seqlen=args.seqlen, - split="train", - seed=args.seed, - require_fx=False, - device=None, - fuse_sequences=args.fuse_sequences) - - # We need to make sure that the same random matrices are being generated - generator = torch.Generator() - generator.manual_seed(SEED) - # Clone generator to make sure we can use the same rotation matrices - generator_clone = generator.clone_state() - - # Run model and save outputs - with torch.no_grad(): - original_logits = model(**calibration_loader[0]).logits - - # Save a copy to apply graph rotation equalization on - model_copy = copy.deepcopy(model) - - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)): - fused_optimized_rotation_no_fx( - model, - calibration_loader, - args, - fuse_rotations=True, - add_additional_regions=add_additional_regions) - - # Run model and save outputs - with torch.no_grad(): - expected_logits = model(**calibration_loader[0]).logits - - # Instead of random orthogonal matrices, we want to use the same ones as when the activations are not fused. - if rotation_mode == 'had': - with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device): - fused_optimized_rotation_no_fx( - model_copy, - calibration_loader, - args, - fuse_rotations=False, - add_additional_regions=add_additional_regions) - else: - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)): - fused_optimized_rotation_no_fx( - model_copy, - calibration_loader, - args, - fuse_rotations=False, - add_additional_regions=add_additional_regions) - - # Fuse matrices with module weights - if fused_rotations: - _fuse_rotations(model_copy) - - ids_rot = set() - num_rotation_matrices = 0 - # Count the number of unique rotation matrices - for module in model_copy.modules(): - if isinstance(module, UnfusedRotatedModule): - if id(module.rot_mat) not in ids_rot: - num_rotation_matrices += 1 - ids_rot.add(id(module.rot_mat)) - - num_rotated_modules = 0 - # Count the number of RotatedModules - for module in model_copy.modules(): - if isinstance(module, RotatedModule): - num_rotated_modules += 1 - - # Run model and save outputs - with torch.no_grad(): - logits = model_copy(**calibration_loader[0]).logits - - # Verify that the number of learnable rotation matrices is the expected (R1 + one R2 per block) - expected_number_rotation_matrices = 0 if fused_rotations else ( - 1 + (model.config.num_hidden_layers if add_additional_regions else 0)) - assert num_rotation_matrices == expected_number_rotation_matrices, f"Expected {expected_number_rotation_matrices} learnable rotations, found {num_rotation_matrices}." - - # Verify that the number of rotated modules is correct - expected_number_rotated_modules = 0 if not partial_had else ( - model.config.num_hidden_layers if add_additional_regions else 2 * - model.config.num_hidden_layers) - assert num_rotated_modules == expected_number_rotated_modules, f"Expected {expected_number_rotated_modules} learnable rotations, found {num_rotated_modules}." - - # Verify that the rotated module output is similar to the original FP - assert torch.allclose(original_logits, logits, atol=ATOL) - # Verify that the output is the same - assert torch.allclose(expected_logits, logits, atol=0.0, rtol=0.0) - - # Verify that the weights after fusing match - for name_fused_module, fused_module in model.named_modules(): - # For linear modules verify that the weights match - if isinstance(fused_module, (nn.Linear, RotatedModule)): - for name_unfused_Module, unfused_module in model_copy.named_modules(): - if name_fused_module == name_unfused_Module: - _compare_module_weights(fused_module, unfused_module) - # In case a RotatedModule is found, additional checks need to be done. - if isinstance(fused_module, RotatedModule): - if fused_rotations: - assert isinstance(unfused_module, RotatedModule) - assert torch.allclose(fused_module.had_mat, unfused_module.had_mat, rtol=0.0, atol=0.0), "The rotation matrices do not match." - else: - # Iterate over child nodes until finding the innermost RotatedModule - child_module = unfused_module - while isinstance(child_module, UnfusedRotatedModule): - assert not child_module.is_orphan, "UnfusedRotatedModule should not be an orphan." - child_module = child_module.module - # After finding the inner Rotated Module, they need to be compared - assert isinstance(child_module, RotatedModule), "Inner module should be RotatedModule." - assert torch.allclose(fused_module.had_mat, child_module.had_mat, rtol=0.0, atol=0.0), "The rotation matrices do not match." diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index f141c59ec..f7e11095d 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -2,11 +2,15 @@ # SPDX-License-Identifier: BSD-3-Clause from argparse import Namespace +import copy from dataclasses import dataclass +from functools import partial +from itertools import product import logging import os import platform import shutil +from unittest.mock import patch import numpy as np import onnx @@ -14,15 +18,28 @@ import pytest import pytest_cases import torch +from torch import nn +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer from brevitas import config from brevitas import torch_version +from brevitas.graph.equalize import _apply_had_device +from brevitas.nn.equalized_layer import RotatedModule # LLM example depends on optimum-amd, which requires PyTorch>=2.2 +from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model +from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch +from brevitas_examples.llm.llm_quant.rotation_utils import extract_trainable_rotation_matrices +from brevitas_examples.llm.llm_quant.rotation_utils import fuse_rotations +from brevitas_examples.llm.main import fused_optimized_rotation_no_fx from brevitas_examples.llm.main import main from brevitas_examples.llm.main import parse_args +from tests.conftest import SEED from tests.marker import jit_disabled_for_export from tests.marker import requires_pt_ge +ATOL = 1e-3 + def ptid2pathname(string): return string.replace("/", "-").replace(":", "-") @@ -520,3 +537,187 @@ def test_small_models_torch_export(caplog, torch_export_args): filepath = args.export_prefix + ".pt" torchscript_model = torch.jit.load(filepath) os.remove(filepath) + + +# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26 +# This functions needs to be patches to enable passing the generator and ensuring that the orthogonal +# matrices generated are the same. +def _random_orthogonal_matrix(size, generator): + """ + Generate a random orthogonal matrix of the specified size. + First, we generate a random matrix with entries from a standard distribution. + Then, we use QR decomposition to obtain an orthogonal matrix. + Finally, we multiply by a diagonal matrix with diag r to adjust the signs. + + Args: + size (int): The size of the matrix (size x size). + + Returns: + torch.Tensor: An orthogonal matrix of the specified size. + """ + torch.cuda.empty_cache() + random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator) + q, r = torch.linalg.qr(random_matrix) + q *= torch.sign(torch.diag(r)).unsqueeze(0).float() + return q + + +@pytest_cases.fixture( + ids=[ + "llama",], + params=[ + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "input_bit_width": None, + "fuse_sequences": False, + "act_calibration": False,},]) +def equalize_args(default_run_args, request): + args = default_run_args + export_dict = request.param + args.update(**export_dict) + yield args + + +# Auxiliar method to compare the weights in rotated modules. +def _compare_fused_unfused_rotation_modules(module_name, fused_rot_module, unfused_rot_module): + fused_weight = fused_rot_module.weight if isinstance( + fused_rot_module, nn.Linear) else fused_rot_module.layer.weight + fused_bias = fused_rot_module.bias if isinstance( + fused_rot_module, nn.Linear) else fused_rot_module.layer.bias + unfused_weight = unfused_rot_module.weight if isinstance( + unfused_rot_module, nn.Linear) else unfused_rot_module.layer.weight + unfused_bias = unfused_rot_module.bias if isinstance( + unfused_rot_module, nn.Linear) else unfused_rot_module.layer.bias + assert torch.allclose(fused_weight, unfused_weight, rtol=0.0, atol=0.0), f"The weights after rotation do not match for module {module_name}." + if fused_bias is not None: + assert torch.allclose(fused_bias, unfused_bias, rtol=0.0, atol=0.0), f"The bias after rotation do not match for module {module_name}." + # In case a RotatedModule is found, additional checks need to be done. + if isinstance(fused_rot_module, RotatedModule): + assert isinstance(unfused_rot_module, RotatedModule), f"Expected an instance of RotatedModule for module {module_name}." + assert torch.allclose(fused_rot_module.had_mat, unfused_rot_module.had_mat, rtol=0.0, atol=0.0), f"The rotation matrices of RotatedModule {module_name} do not match." + + +@pytest.mark.llm +@requires_pt_ge('2.4') +@pytest_cases.parametrize( + 'partial_had, fused_rotations, add_additional_regions', + list(product([False, True], repeat=3)), + ids=[("fused-R1" if fused_rotations else "R1") + ("-R2" if add_additional_regions else "") + + ("-R3" if partial_had else "") for partial_had, + fused_rotations, + add_additional_regions in list(product([False, True], repeat=3))], +) +@pytest_cases.parametrize('rotation_mode', ['ort', 'had']) +def test_small_models_rotations( + caplog, partial_had, fused_rotations, add_additional_regions, rotation_mode, equalize_args): + caplog.set_level(logging.INFO) + args = equalize_args + args.rotation_orphan_sink = partial_had + args.rotation_mode = rotation_mode + + kwargs = {"torch_dtype": torch.float16} + model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) + model = replace_rmsnorm_with_torch(model, model.config) + model.config.use_cache = False + print("Model loaded.") + model.eval() + tokenizer = AutoTokenizer.from_pretrained(args.model) + # Load the data for calibration and evaluation. + calibration_loader = get_dataset_for_model( + args.model, + dataset_name=args.dataset, + tokenizer=tokenizer, + nsamples=args.nsamples, + seqlen=args.seqlen, + split="train", + seed=args.seed, + require_fx=False, + device=None, + fuse_sequences=args.fuse_sequences) + + # We need to make sure that the same random matrices are being generated + generator = torch.Generator() + generator.manual_seed(SEED) + # Clone generator to make sure we can use the same rotation matrices + generator_clone = generator.clone_state() + + # Run model and save outputs + with torch.no_grad(): + original_logits = model(**calibration_loader[0]).logits + + # Save a copy to apply graph rotation equalization on + model_copy = copy.deepcopy(model) + + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)): + fused_optimized_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations=True, + add_additional_regions=add_additional_regions) + + # Run model and save outputs + with torch.no_grad(): + expected_logits = model(**calibration_loader[0]).logits + + # Instead of random orthogonal matrices, we want to use the same ones as when the activations are not fused. + if rotation_mode == 'had': + with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device): + fused_optimized_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_additional_regions=add_additional_regions) + else: + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)): + fused_optimized_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_additional_regions=add_additional_regions) + + # Fuse matrices with module weights + if fused_rotations: + fuse_rotations(model_copy) + + # Run model and save outputs + with torch.no_grad(): + logits = model_copy(**calibration_loader[0]).logits + + # Verify that the rotated module output is similar to the original FP + assert torch.allclose(original_logits, logits, atol=ATOL), "Output of rotated network does not approximately match that of the original network." + # Verify that the output is the same + assert torch.allclose(expected_logits, logits, atol=0.0, rtol=0.0), "Outputs of fused/unfused rotated networks do not match exactly." + + num_rotation_matrices = len(extract_trainable_rotation_matrices(model_copy)) + + num_rotated_modules = 0 + # Count the number of RotatedModules + for module in model_copy.modules(): + if isinstance(module, RotatedModule): + num_rotated_modules += 1 + + # Verify that the number of learnable rotation matrices is the expected (R1 + one R2 per block) + expected_number_rotation_matrices = 0 if fused_rotations else ( + 1 + (model.config.num_hidden_layers if add_additional_regions else 0)) + assert num_rotation_matrices == expected_number_rotation_matrices, f"Expected {expected_number_rotation_matrices} learnable rotations, found {num_rotation_matrices}." + + # Verify that the number of rotated modules is correct + expected_number_rotated_modules = 0 if not partial_had else ( + model.config.num_hidden_layers if add_additional_regions else 2 * + model.config.num_hidden_layers) + assert num_rotated_modules == expected_number_rotated_modules, f"Expected {expected_number_rotated_modules} RotatedModules found {num_rotated_modules}." + + # Verify that the weights after fusing match + for name_fused_module, fused_module in model.named_modules(): + # For linear modules verify that the weights match + if isinstance(fused_module, (nn.Linear, RotatedModule)): + for name_unfused_Module, unfused_module in model_copy.named_modules(): + if name_fused_module == name_unfused_Module: + # Verify that everything matches between the fused and unfused rotation modules + _compare_fused_unfused_rotation_modules( + name_fused_module, fused_module, unfused_module) From 08eb355d122486607d0b26c7d1ab0c1584f45680 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 13 Dec 2024 11:12:04 +0000 Subject: [PATCH 07/11] Unsaved changes in llm main --- src/brevitas_examples/llm/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index de01b9d7f..fb5b9f368 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -91,7 +91,7 @@ def fused_optimized_rotation_no_fx( r.apply(model) #new_model = offload_model(new_model) - # Regions with source o_proj and sink down_proj + # Regions with source v_proj and sink o_proj self_attention_regions = find_self_attention_rotation_regions( new_model, model.config.hidden_size // model.config.num_attention_heads) if add_additional_regions else None From b59dc247fe68c7cf75624eea539b13f52dd0f961 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 13 Dec 2024 17:20:51 +0000 Subject: [PATCH 08/11] Consolidate fused/unfused rotations --- src/brevitas/graph/equalize.py | 317 ++++++------------ src/brevitas/graph/hadamard.py | 2 +- .../llm/llm_quant/rotation_optimization.py | 1 - src/brevitas_examples/llm/main.py | 72 ++-- tests/brevitas_examples/test_llm.py | 56 ++-- 5 files changed, 155 insertions(+), 293 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 7056f18cd..faae03a62 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1312,7 +1312,11 @@ def random_orthogonal_matrix(size): return q -def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had'): +def _apply_rotate( + model: nn.Module, + regions: List[Region], + full_rotation_method: str = 'had', + fuse_rotations: bool = True): rewriters = [] for region in regions: insert_rotation_module = len(region.srcs) == 0 @@ -1324,6 +1328,13 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= rot_mat = random_orthogonal_matrix(hidden_dim) K = None rot_func = _apply_ort_device + elif not insert_rotation_module and not fuse_rotations: + # TODO: This might be problematic if the parameters are distributed + # across devices. Generalize this logic for safety. + device = next(model.parameters()).device + rot_mat = random_hadamard_matrix(hidden_dim, device) + K = None + rot_func = _apply_ort_device else: try: # Build hadamard rotation matrix @@ -1339,44 +1350,85 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= print("Skipping layers") continue + # If the rotation is not fused, redefine as a Parameter, to enable its optimization + if not fuse_rotations: + rot_mat = torch.nn.Parameter(rot_mat) + for name, indexes in region.srcs.items(): module = region.get_module_from_name(name) - if hasattr(module, 'allocate_params'): - module.allocate_params(module) axis = _get_output_axis(module) - weight = module.weight.data - if axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() - elif axis == 1: - weight = rot_func(weight, rot_mat, K) - else: - raise RuntimeError("Not supported yet") - module.weight.data = weight + if fuse_rotations: + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + weight = module.weight.data - if getattr(module, 'bias', None) is not None: - bias = module.bias.data - bias = rot_func(bias, rot_mat, K) - module.bias.data = bias - if hasattr(module, 'offload_params'): - module.offload_params(module) + if axis == 0: + weight = rot_func(weight.t(), rot_mat, K).t() + elif axis == 1: + weight = rot_func(weight, rot_mat, K) + else: + raise RuntimeError("Not supported yet") + module.weight.data = weight + + if getattr(module, 'bias', None) is not None: + bias = module.bias.data + bias = rot_func(bias, rot_mat, K) + module.bias.data = bias + if hasattr(module, 'offload_params'): + module.offload_params(module) + else: + parametrize.register_parametrization( + module, + "weight", + RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + output_axis=axis, + is_source=True, + )) + if getattr(module, 'bias', None) is not None: + parametrize.register_parametrization( + module, + "bias", + RotationBiasParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + output_axis=axis, + is_source=True, + )) for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) - if hasattr(module, 'allocate_params'): - module.allocate_params(module) axis = _get_input_axis(module) - weight = module.weight.data - if axis == 1: - _update_weights(module, rot_func(weight, rot_mat, K), 'weight') - elif axis == 0: - _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') + if not insert_rotation_module and not fuse_rotations: + parametrize.register_parametrization( + module, + "weight", + RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + input_axis=axis, + is_sink=True, + )) else: - raise RuntimeError("Not supported yet") + # Verify that there are no parametrizations, as otherwise the underlying weights will not be updated + assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." + + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + weight = module.weight.data + + if axis == 1: + _update_weights(module, rot_func(weight, rot_mat, K), 'weight') + elif axis == 0: + _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') + else: + raise RuntimeError("Not supported yet") - if hasattr(module, 'offload_params'): - module.offload_params(module) + if hasattr(module, 'offload_params'): + module.offload_params(module) if insert_rotation_module and len(region.srcs) == 0: rewriter = ModuleInstanceToModuleInstance( @@ -1476,8 +1528,12 @@ def rotate_matmuls(self, graph_module): graph_module.recompile() graph_module.graph.lint() - def apply(self, - graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + def apply( + self, + graph_model: GraphModule, + fuse_rotations: bool = True, + additional_regions: Optional[List[Region]] = None + ) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( graph_model, @@ -1486,6 +1542,8 @@ def apply(self, 'supported_sinks': self.supported_sinks, 'scale_invariant_layers': self.scale_invariant_layers, 'scale_invariant_function': self.scale_invariant_function}) + if additional_regions is not None: + regions.extend(additional_regions) eq_layers = set() orphan_regions = [] self.find_module(graph_model, orphan_regions) @@ -1497,11 +1555,18 @@ def apply(self, # Layerwise have only a single sink named 'sinks0' id_sink = id(o_r.get_module_from_name('sinks0')) if id_sink not in eq_layers: - regions.append(o_r) + # Orphan regions result in an in-place update of the weights, so these are applied before + # the rest of the rotations, to simplify the logic when fuse_rotations = False, as, + # otherwise, additional checks need to be incorporated to verify if the module weights + # have any parametrizations already, since in that case, the in-place update needs to + # be performed in module.parametrizations.weight.original. + # TODO: Use deque to perform this operation in O(1) + regions = [o_r] + regions if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) + rewriters = _apply_rotate( + graph_model, regions, self.full_rotation_method, fuse_rotations) if self.return_rewriters: return graph_model, rewriters else: @@ -1600,193 +1665,3 @@ def apply(self, model: nn.Module, fuse_rotations: bool = True) -> nn.Module: if len(regions) > 0: _apply_rotate(model, regions) return model - - -def _apply_rotate_fused_rotations( - model: nn.Module, - regions: List[Region], - full_rotation_method='had', - fuse_rotations: bool = True): - rewriters = [] - for region in regions: - insert_rotation_module = len(region.srcs) == 0 - - if not insert_rotation_module and not region.is_valid: - continue - hidden_dim = region.max_shape_sinks - if not insert_rotation_module and full_rotation_method == 'ort': - rot_mat = random_orthogonal_matrix(hidden_dim) - # If the rotations are not fused, redefine as parameter - if not fuse_rotations: - rot_mat = torch.nn.Parameter(rot_mat) - K = None - rot_func = _apply_ort_device - elif not insert_rotation_module and not fuse_rotations: - # TODO: Generalize - device = next(model.parameters()).device - rot_mat = torch.nn.Parameter(random_hadamard_matrix(hidden_dim, device)) - K = None - rot_func = _apply_ort_device - else: - try: - # Build hadamard rotation matrix - rot_mat, K = get_hadK(hidden_dim) - rot_func = _apply_had_device - except AssertionError as e: - print(f"Incomptible shapes {hidden_dim}") - if not insert_rotation_module: - print("Falling back to orthogonal matrices") - rot_mat = random_orthogonal_matrix(hidden_dim) - K = None - rot_func = _apply_ort_device - print("Skipping layers") - continue - - for name, indexes in region.srcs.items(): - module = region.get_module_from_name(name) - axis = _get_output_axis(module) - - assert not insert_rotation_module, "Orphan regions must not have sources." - - if not insert_rotation_module and fuse_rotations: - # Verify that there are no parametrizations, as otherwise the underlying data will not be updated - assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." - - if hasattr(module, 'allocate_params'): - module.allocate_params(module) - - weight = module.weight.data - - if axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() - elif axis == 1: - weight = rot_func(weight, rot_mat, K) - else: - raise RuntimeError("Not supported yet") - module.weight.data = weight - - if getattr(module, 'bias', None) is not None: - bias = module.bias.data - bias = rot_func(bias, rot_mat, K) - module.bias.data = bias - - if hasattr(module, 'offload_params'): - module.offload_params(module) - elif not insert_rotation_module and not fuse_rotations: - # Parametrize weights and possibly bias with unfused rotations - parametrize.register_parametrization( - module, - "weight", - RotationWeightParametrization( - rot_mat=rot_mat, - rot_func=rot_func, - output_axis=axis, - is_source=True, - )) - if getattr(module, 'bias', None) is not None: - parametrize.register_parametrization( - module, - "bias", - RotationBiasParametrization( - rot_mat=rot_mat, - rot_func=rot_func, - output_axis=axis, - is_source=True, - )) - - for name, indexes in region.sinks.items(): - module = region.get_module_from_name(name) - axis = _get_input_axis(module) - - if not insert_rotation_module and not fuse_rotations: - parametrize.register_parametrization( - module, - "weight", - RotationWeightParametrization( - rot_mat=rot_mat, - rot_func=rot_func, - input_axis=axis, - is_sink=True, - )) - else: - # Verify that there are no parametrizations, as otherwise the underlying data will not be updated - assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." - - if hasattr(module, 'allocate_params'): - module.allocate_params(module) - weight = module.weight.data - - if axis == 1: - _update_weights(module, rot_func(weight, rot_mat, K), 'weight') - elif axis == 0: - _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') - else: - raise RuntimeError("Not supported yet") - - if hasattr(module, 'offload_params'): - module.offload_params(module) - - if insert_rotation_module and len(region.srcs) == 0: - rewriter = ModuleInstanceToModuleInstance( - module, RotatedModule(had_mat=rot_mat, k=K, layer=module)) - rewriters.append(rewriter) - for r in rewriters: - model = r.apply(model) - return rewriters - - -# TODO: Consolidate with GraphRotationEqualization -class GraphRotationEqualizationOptimization(GraphRotationEqualization): - - def __init__( - self, - blacklist_layers: Optional[List[str]] = None, - orphan_sink: bool = False, - rotate_matmul: bool = False, - full_rotation_method: str = 'had', - ) -> None: - super(GraphRotationEqualizationOptimization, self).__init__( - blacklist_layers=blacklist_layers, - orphan_sink=orphan_sink, - rotate_matmul=rotate_matmul, - full_rotation_method=full_rotation_method, - return_rewriters=True, - ) - - def apply( - self, - graph_model: GraphModule, - fuse_rotations: bool = True, - additional_regions: Optional[List] = None - ) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: - rewriters = [] - regions = _extract_regions( - graph_model, - state_impl_kwargs={ - 'supported_srcs': self.supported_srcs, - 'supported_sinks': self.supported_sinks, - 'scale_invariant_layers': self.scale_invariant_layers, - 'scale_invariant_function': self.scale_invariant_function}) - if additional_regions is not None: - regions.extend(additional_regions) - eq_layers = set() - orphan_regions = [] - self.find_module(graph_model, orphan_regions) - for r in regions: - id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names] - eq_layers.update(id_list) - if self.orphan_sink: - for o_r in orphan_regions: - # Layerwise have only a single sink named 'sinks0' - id_sink = id(o_r.get_module_from_name('sinks0')) - if id_sink not in eq_layers: - regions = [o_r] + regions - if self.rotate_matmul: - self.rotate_matmuls(graph_model) - if len(regions) > 0: - rewriters = _apply_rotate_fused_rotations( - graph_model, regions, self.full_rotation_method, fuse_rotations) - if self.return_rewriters: - return graph_model, rewriters - else: - return graph_model diff --git a/src/brevitas/graph/hadamard.py b/src/brevitas/graph/hadamard.py index 29a09ebed..c74695e00 100644 --- a/src/brevitas/graph/hadamard.py +++ b/src/brevitas/graph/hadamard.py @@ -123,7 +123,7 @@ def random_hadamard_matrix(size, device): Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) Q = Q * 2 - 1 Q = torch.diag(Q) - return matmul_hadU(Q).to(device) + return matmul_hadU(Q).to(device).float() def matmul_hadU_cuda(X, hadK, K): diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 93b7edf7e..31cf00051 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -87,7 +87,6 @@ def apply_rotation_optimization( trainable_rotations = extract_trainable_rotation_matrices(graph_model) for rot_mat in trainable_rotations: rot_mat.requires_grad = True - # Initialize optimizer optimizer = SGDG(trainable_rotations, lr=training_args.learning_rate, stiefel=True) trainer = Trainer( model=graph_model, diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index fb5b9f368..95a3055dd 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -17,7 +17,6 @@ from brevitas.export import export_torch_qcdq from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas.graph.equalize import GraphRotationEqualization -from brevitas.graph.equalize import GraphRotationEqualizationOptimization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.utils import get_module @@ -53,7 +52,12 @@ def set_seed(seed): torch.random.manual_seed(seed) -def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = False): +def fused_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations: bool = True, + add_self_attention_regions: bool = False): with torch.no_grad(): new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) apply_layernorm_affine_merge(new_model) @@ -67,49 +71,22 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, return_rewriters=True) - new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations) - rewriters = fix_rewriter(rewriters, model, 'weight') - - for r in rewriters: - r.apply(model) - remove_hooks(new_model) - - -def fused_optimized_rotation_no_fx( - model, - calibration_loader, - args, - fuse_rotations: bool = False, - add_additional_regions: bool = False): - with torch.no_grad(): - new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) - apply_layernorm_affine_merge(new_model) - new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True) - rewriters = fix_rewriter(rewriters, model, 'weight') - - for r in rewriters: - r.apply(model) - #new_model = offload_model(new_model) - # Regions with source v_proj and sink o_proj - self_attention_regions = find_self_attention_rotation_regions( - new_model, model.config.hidden_size // - model.config.num_attention_heads) if add_additional_regions else None - eq = GraphRotationEqualizationOptimization( - orphan_sink=args.rotation_orphan_sink, - full_rotation_method=args.rotation_mode, - ) + self_attention_regions = ( + find_self_attention_rotation_regions( + new_model, model.config.hidden_size // + model.config.num_attention_heads) if add_self_attention_regions else None) new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions) - - # Retrieve additional rewriters for unfused rotations - rewriters_unfused_rotations = extract_rewriters_unfused_rotations(new_model, rewriters) - rewriters.extend(rewriters_unfused_rotations) + # Additional rewriters need to be added if rotations are not fused + if not fuse_rotations: + rewriters_unfused_rotations = extract_rewriters_unfused_rotations(new_model, rewriters) + rewriters.extend(rewriters_unfused_rotations) rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: r.apply(model) - #remove_hooks(new_model) + remove_hooks(new_model) def set_seed(seed): @@ -282,7 +259,7 @@ def main(args, unknown_args=None): model = replace_rmsnorm_with_torch(model, model.config) # TODO: Refactor - if args.rotation == 'fused_no_fx_optimize': + if args.rotation in ['fused_no_fx_optimize', 'fused_no_fx_optimize_self_attn_region']: for i in range(len(calibration_loader)): del calibration_loader[i]["attention_mask"] calibration_loader[i]["labels"] = calibration_loader[i]["input_ids"] @@ -323,8 +300,11 @@ def main(args, unknown_args=None): elif args.rotation == 'fused_no_fx': fused_rotation_no_fx(model, calibration_loader, args) elif args.rotation == 'fused_no_fx_optimize': - fused_optimized_rotation_no_fx( - model, calibration_loader, args, fuse_rotations=False, add_additional_regions=True) + fused_rotation_no_fx( + model, calibration_loader, args, fuse_rotations=False, add_self_attention_regions=False) + elif args.rotation == 'fused_no_fx_optimize_self_attn_region': + fused_rotation_no_fx( + model, calibration_loader, args, fuse_rotations=False, add_self_attention_regions=True) # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing # with all the variability in HF implementations @@ -421,10 +401,9 @@ def main(args, unknown_args=None): with torch.no_grad(): model(**calibration_loader[0]) - # TODO: Refactor remove_hooks(model) - if args.rotation == 'fused_no_fx_optimize': + if args.rotation in ['fused_no_fx_optimize', 'fused_no_fx_optimize_self_attn_region']: apply_rotation_optimization( graph_model=model, tokenizer=tokenizer, @@ -670,7 +649,12 @@ def parse_args(args): '--rotation', type=str, default=None, - choices=['fx', 'layerwise', 'fused_no_fx', 'fused_no_fx_optimize'], + choices=[ + 'fx', + 'layerwise', + 'fused_no_fx', + 'fused_no_fx_optimize', + 'fused_no_fx_optimize_self_attn_region'], help='Apply graph rotation equalization') parser.add_argument( '--rotation-mode', diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index f7e11095d..ffa5f6454 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -31,7 +31,7 @@ from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch from brevitas_examples.llm.llm_quant.rotation_utils import extract_trainable_rotation_matrices from brevitas_examples.llm.llm_quant.rotation_utils import fuse_rotations -from brevitas_examples.llm.main import fused_optimized_rotation_no_fx +from brevitas_examples.llm.main import fused_rotation_no_fx from brevitas_examples.llm.main import main from brevitas_examples.llm.main import parse_args from tests.conftest import SEED @@ -648,37 +648,41 @@ def test_small_models_rotations( # Save a copy to apply graph rotation equalization on model_copy = copy.deepcopy(model) - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)): - fused_optimized_rotation_no_fx( - model, - calibration_loader, - args, - fuse_rotations=True, - add_additional_regions=add_additional_regions) + # offload_model is patched to behave as an identity, thus making sure that the operations + # are deterministic, enabling to test that the tensors match exactly. + with patch('brevitas_examples.llm.main.offload_model', lambda m: m): + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)): + fused_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations=True, + add_self_attention_regions=add_additional_regions) # Run model and save outputs with torch.no_grad(): expected_logits = model(**calibration_loader[0]).logits # Instead of random orthogonal matrices, we want to use the same ones as when the activations are not fused. - if rotation_mode == 'had': - with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device): - fused_optimized_rotation_no_fx( - model_copy, - calibration_loader, - args, - fuse_rotations=False, - add_additional_regions=add_additional_regions) - else: - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)): - fused_optimized_rotation_no_fx( - model_copy, - calibration_loader, - args, - fuse_rotations=False, - add_additional_regions=add_additional_regions) + with patch('brevitas_examples.llm.main.offload_model', lambda m: m): + if rotation_mode == 'had': + with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device): + fused_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_self_attention_regions=add_additional_regions) + else: + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)): + fused_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_self_attention_regions=add_additional_regions) # Fuse matrices with module weights if fused_rotations: From 53488cc2d38142729cc080519a895c75bcdd285f Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 16 Dec 2024 11:17:32 +0000 Subject: [PATCH 09/11] Enable quantization of parametrized modules --- src/brevitas/graph/base.py | 15 ++++++++++++++- src/brevitas/graph/quantize_impl.py | 8 +++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index def3f7070..dae5160ce 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -3,11 +3,13 @@ from abc import ABC from abc import abstractmethod +from collections import OrderedDict import inspect from inspect import getcallargs import torch from torch.nn import Module +import torch.nn.utils.parametrize as parametrize from torch.overrides import get_testing_overrides from brevitas.fx import GraphModule @@ -154,7 +156,18 @@ def _init_new_module(self, old_module: Module, name=None): def _replace_old_module(self, model, old_module, new_module, load_state_dict=True): replace_module(model, old_module, new_module) if load_state_dict: - new_module.load_state_dict(old_module.state_dict()) + # The dictionary entries relative to parametrizations need to be ignored, as these are passed + # when invoking transfer_parametrizations_and_params. + old_module_state_dict = OrderedDict({ + k: v for k, + v in old_module.state_dict().items() if not k.startswith("parametrizations")}) + # If the old module is parametrized, these need to be transferred to the new module. Strict needs to be set to False, + # as there will be missing keys for those parameters which have any parametrizations attached. + if parametrize.is_parametrized(old_module): + new_module.load_state_dict(old_module_state_dict, strict=False) + parametrize.transfer_parametrizations_and_params(old_module, new_module) + else: + new_module.load_state_dict(old_module_state_dict) class InsertModuleCallAfter(GraphTransform): diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 535f9a8f9..a4d348ab5 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn +import torch.nn.utils.parametrize as parametrize import brevitas from brevitas.graph.base import InsertModuleCallAfter @@ -511,7 +512,7 @@ def find_module( Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its Linear submodules. """ - if _module_class_name(type(model)) in layer_map.keys(): + if _module_class_name(parametrize.type_before_parametrizations(model)) in layer_map.keys(): module_to_replace.append(model) else: for name, module in model.named_children(): @@ -532,8 +533,9 @@ def layerwise_layer_handler( find_module(model, layer_map, module_to_replace, name_blacklist) rewriters = [] for module in module_to_replace: - if layer_map[_module_class_name(type(module))] is not None: - quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))] + if layer_map[_module_class_name( + parametrize.type_before_parametrizations(module))] is not None: + quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))] rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs) rewriters.append(rewriter) for rewriter in rewriters: From 0a209ee4363643717cca5d74620761fbfe0ba019 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 31 Dec 2024 11:04:44 +0000 Subject: [PATCH 10/11] Fix logic registering parametrizations --- src/brevitas/graph/base.py | 98 +++++++++++++++++ src/brevitas/graph/equalize.py | 104 +++++++++++------- src/brevitas/graph/quantize_impl.py | 3 +- src/brevitas/nn/equalized_layer.py | 22 +--- .../llm/llm_quant/rotation_optimization.py | 3 +- .../llm/llm_quant/rotation_utils.py | 4 +- src/brevitas_examples/llm/main.py | 68 ++++++++++-- tests/brevitas/graph/test_equalization.py | 3 +- 8 files changed, 229 insertions(+), 76 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index dae5160ce..1546ecb67 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -6,9 +6,12 @@ from collections import OrderedDict import inspect from inspect import getcallargs +from typing import Any, Callable, Dict, Type, Union import torch +from torch import Tensor from torch.nn import Module +from torch.nn import Parameter import torch.nn.utils.parametrize as parametrize from torch.overrides import get_testing_overrides @@ -187,6 +190,76 @@ def apply(self, graph_model: GraphModule) -> GraphModule: return graph_model +class ModuleInstanceRegisterParametrization(Transform): + + def __init__( + self, old_module_instance: Module, tensor_name: str, + parametrization_module: Module) -> None: + self.old_module_instance = old_module_instance + self.tensor_name = tensor_name + self.parametrization_module = parametrization_module + + def apply(self, model: GraphModule) -> GraphModule: + for old_module in model.modules(): + if old_module is self.old_module_instance: + # register the parametrization in the old_module + parametrize.register_parametrization( + old_module, self.tensor_name, self.parametrization_module) + break + return model + + +class ModuleInstanceFuseRotationWeights(Transform): + + def __init__( + self, + old_module_instance: Module, + rot_mat: Union[Parameter, Tensor], + rot_func: Callable, + K: int, + tensor_name: str, + axis: int, + is_source: bool, + ): + self.old_module_instance = old_module_instance + self.rot_mat = rot_mat + self.rot_func = rot_func + self.K = K + self.tensor_name = tensor_name + self.axis = axis + self.is_source = is_source + + def apply(self, model: GraphModule) -> GraphModule: + for old_module in model.modules(): + if old_module is self.old_module_instance: + if hasattr(old_module, 'allocate_params'): + old_module.allocate_params(old_module) + weight = getattr(old_module, self.tensor_name).data + + if self.is_source: + if self.axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + elif self.axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + else: + raise RuntimeError("Not supported yet") + # If not a source, the module is either a sink or an orphan + else: + if self.axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + elif self.axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + else: + raise RuntimeError("Not supported yet") + # Modify the weights in-place + getattr(old_module, self.tensor_name).data = weight + + if hasattr(old_module, 'offload_params'): + old_module.offload_params(old_module) + break + return model + + class ModuleInstanceToModuleInstance(Transform): def __init__(self, old_module_instance, new_module_instance): @@ -202,6 +275,31 @@ def apply(self, model: GraphModule) -> GraphModule: return model +class ModuleInstanceWrapModule(Transform): + + def __init__( + self, + old_module_instance: Module, + wrapper_class: Type[Module], + module_attribute: str, + kwargs_wrapper: Dict[str, Any]): + self.old_module_instance = old_module_instance + self.wrapper_class = wrapper_class + self.module_attribute = module_attribute + self.kwargs_wrapper = kwargs_wrapper + + def apply(self, model: GraphModule) -> GraphModule: + for old_module in model.modules(): + if old_module is self.old_module_instance: + kwargs = {self.module_attribute: self.old_module_instance} + kwargs.update(self.kwargs_wrapper) + new_module_instance = self.wrapper_class(**kwargs) + # init the new module based on the old one + replace_module(model, old_module, new_module_instance) + break + return model + + class ModuleToModuleByName(ModuleToModule): def __init__(self, old_module_name, new_module_class, **kwargs): diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index faae03a62..0da49233c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -21,11 +21,13 @@ from brevitas import torch_version from brevitas.fx import GraphModule from brevitas.fx import Node -from brevitas.graph import ModuleToModuleByClass from brevitas.graph import ModuleToModuleByInstance from brevitas.graph.base import GraphTransform from brevitas.graph.base import InsertModuleCallAfter +from brevitas.graph.base import ModuleInstanceFuseRotationWeights +from brevitas.graph.base import ModuleInstanceRegisterParametrization from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.graph.base import ModuleInstanceWrapModule from brevitas.graph.base import Transform from brevitas.graph.hadamard import get_hadK from brevitas.graph.hadamard import matmul_hadU @@ -1316,7 +1318,8 @@ def _apply_rotate( model: nn.Module, regions: List[Region], full_rotation_method: str = 'had', - fuse_rotations: bool = True): + fuse_rotations: bool = True, + apply_inplace_rotations: bool = True): rewriters = [] for region in regions: insert_rotation_module = len(region.srcs) == 0 @@ -1351,7 +1354,7 @@ def _apply_rotate( continue # If the rotation is not fused, redefine as a Parameter, to enable its optimization - if not fuse_rotations: + if not insert_rotation_module and not fuse_rotations: rot_mat = torch.nn.Parameter(rot_mat) for name, indexes in region.srcs.items(): @@ -1359,36 +1362,44 @@ def _apply_rotate( axis = _get_output_axis(module) if fuse_rotations: - if hasattr(module, 'allocate_params'): - module.allocate_params(module) - weight = module.weight.data - - if axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() - elif axis == 1: - weight = rot_func(weight, rot_mat, K) - else: - raise RuntimeError("Not supported yet") - module.weight.data = weight + rewriter = ModuleInstanceFuseRotationWeights( + old_module_instance=module, + rot_mat=rot_mat, + rot_func=rot_func, + K=K, + tensor_name="weight", + axis=axis, + is_source=True, + ) + rewriters.append(rewriter) if getattr(module, 'bias', None) is not None: - bias = module.bias.data - bias = rot_func(bias, rot_mat, K) - module.bias.data = bias - if hasattr(module, 'offload_params'): - module.offload_params(module) + rewriter = ModuleInstanceFuseRotationWeights( + old_module_instance=module, + rot_mat=rot_mat, + rot_func=rot_func, + K=K, + tensor_name="bias", + axis=1, + is_source=True, + ) + rewriters.append(rewriter) else: - parametrize.register_parametrization( + rewriter = ModuleInstanceRegisterParametrization( module, "weight", RotationWeightParametrization( rot_mat=rot_mat, rot_func=rot_func, - output_axis=axis, + axis=axis, is_source=True, )) + rewriters.append(rewriter) if getattr(module, 'bias', None) is not None: - parametrize.register_parametrization( + # TODO: Consolidate RotationBiasParametrization into a single + # class, by setting output_axis = 1. Also, could use a single + # axis, as input_axis and output_axis are not used simultaneously + rewriter = ModuleInstanceRegisterParametrization( module, "bias", RotationBiasParametrization( @@ -1397,45 +1408,49 @@ def _apply_rotate( output_axis=axis, is_source=True, )) + rewriters.append(rewriter) for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) axis = _get_input_axis(module) if not insert_rotation_module and not fuse_rotations: - parametrize.register_parametrization( + rewriter = ModuleInstanceRegisterParametrization( module, "weight", RotationWeightParametrization( rot_mat=rot_mat, rot_func=rot_func, - input_axis=axis, + axis=axis, is_sink=True, )) + rewriters.append(rewriter) else: # Verify that there are no parametrizations, as otherwise the underlying weights will not be updated assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." - if hasattr(module, 'allocate_params'): - module.allocate_params(module) - weight = module.weight.data - - if axis == 1: - _update_weights(module, rot_func(weight, rot_mat, K), 'weight') - elif axis == 0: - _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') - else: - raise RuntimeError("Not supported yet") - - if hasattr(module, 'offload_params'): - module.offload_params(module) + rewriter = ModuleInstanceFuseRotationWeights( + old_module_instance=module, + rot_mat=rot_mat, + rot_func=rot_func, + K=K, + tensor_name="weight", + axis=axis, + is_source=False, + ) + rewriters.append(rewriter) if insert_rotation_module and len(region.srcs) == 0: - rewriter = ModuleInstanceToModuleInstance( - module, RotatedModule(had_mat=rot_mat, k=K, layer=module)) + rewriter = ModuleInstanceWrapModule( + module, RotatedModule, "layer", { + "had_mat": rot_mat, "k": K}) rewriters.append(rewriter) for r in rewriters: - model = r.apply(model) + # The parametrizations need to be registered after the potential HF hooks have been + # removed, as otherwise the device maps will not match the structure of the + # model's state_dict after the registration of the parametrizations. + if apply_inplace_rotations and not isinstance(r, ModuleInstanceRegisterParametrization): + model = r.apply(model) return rewriters @@ -1532,7 +1547,8 @@ def apply( self, graph_model: GraphModule, fuse_rotations: bool = True, - additional_regions: Optional[List[Region]] = None + additional_regions: Optional[List[Region]] = None, + apply_inplace_rotations: bool = True, ) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( @@ -1566,7 +1582,11 @@ def apply( self.rotate_matmuls(graph_model) if len(regions) > 0: rewriters = _apply_rotate( - graph_model, regions, self.full_rotation_method, fuse_rotations) + graph_model, + regions, + self.full_rotation_method, + fuse_rotations, + apply_inplace_rotations) if self.return_rewriters: return graph_model, rewriters else: diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index a4d348ab5..538ce5717 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn import torch.nn.utils.parametrize as parametrize +from tqdm import tqdm import brevitas from brevitas.graph.base import InsertModuleCallAfter @@ -538,6 +539,6 @@ def layerwise_layer_handler( quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))] rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs) rewriters.append(rewriter) - for rewriter in rewriters: + for rewriter in tqdm(rewriters, leave=False): model = rewriter.apply(model) return model diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index ccd812713..2c48f9da3 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -69,14 +69,6 @@ def __init__(self, layer, had_mat=None, k=None) -> None: self.layer = layer self.k = k - @property - def weight(self) -> Optional[torch.Tensor]: - return getattr(self.layer, 'weight', None) - - @property - def bias(self) -> Optional[torch.Tensor]: - return getattr(self.layer, 'bias', None) - def forward(self, inp, **kwargs): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None # If k is None, we assume that an orthogonal matrix is used @@ -110,8 +102,7 @@ def __init__( self, rot_mat: torch.nn.Parameter, rot_func: Callable, - input_axis: Optional[int] = None, - output_axis: Optional[int] = None, + axis: int, is_source: bool = False, is_sink: bool = False, is_orphan: bool = False, @@ -119,8 +110,7 @@ def __init__( super().__init__() self.rot_mat = rot_mat self.rot_func = rot_func - self.input_axis = input_axis - self.output_axis = output_axis + self.axis = axis self.is_source = is_source self.is_sink = is_sink self.is_orphan = is_orphan @@ -128,17 +118,17 @@ def __init__( def forward(self, weight: torch.Tensor) -> torch.Tensor: if self.is_sink or self.is_orphan: - if self.input_axis == 1: + if self.axis == 1: weight = self.rot_func(weight, self.rot_mat, self.K) - elif self.input_axis == 0: + elif self.axis == 0: weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() else: raise RuntimeError("Not supported yet") if self.is_source: - if self.output_axis == 0: + if self.axis == 0: weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() - elif self.output_axis == 1: + elif self.axis == 1: weight = self.rot_func(weight, self.rot_mat, self.K) else: raise RuntimeError("Not supported yet") diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 31cf00051..618763498 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -20,8 +20,7 @@ @dataclass class ModelArguments: input_model: Optional[str] = field( - default="hf-internal-testing/tiny-random-LlamaForCausalLM", - metadata={"help": "Input model"}) + default="meta-llama/Llama-3.2-1B", metadata={"help": "Input model"}) output_rotation_path: Optional[str] = field( default="test-output", metadata={"help": "Output rotation checkpoint path"}) optimized_rotation_path: Optional[str] = field( diff --git a/src/brevitas_examples/llm/llm_quant/rotation_utils.py b/src/brevitas_examples/llm/llm_quant/rotation_utils.py index 037de166d..9c84aeff7 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_utils.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_utils.py @@ -66,13 +66,15 @@ def fuse_rotations(model: nn.Module) -> None: parametrize.remove_parametrizations(module, "bias", leave_parametrized=True) +# TODO: Remove? We rely on ModuleInstanceRegisterParametrization def extract_rewriters_unfused_rotations(model: nn.Module, rewriters: List[Transform]) -> List[Transform]: extra_rewriters = [] for module in model.modules(): if hasattr(module, "parametrizations"): # Verify that the current module does not have already associated a RotatedModule - if len([r for r in rewriters if r.old_module_instance is module]) == 0: + if len([r for r in rewriters if r.old_module_instance is module and + isinstance(r, ModuleInstanceToModuleInstance)]) == 0: # Identity rewriter, only useful externaly rewriter = ModuleInstanceToModuleInstance(module, module) extra_rewriters.append(rewriter) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 95a3055dd..947314185 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -3,7 +3,10 @@ import argparse from copy import deepcopy +from functools import wraps +import os import sys +from typing import Callable, List from warnings import warn import numpy as np @@ -16,6 +19,7 @@ from brevitas.export import export_torch_qcdq from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas.graph.base import ModuleInstanceFuseRotationWeights from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import layerwise_quantize @@ -52,6 +56,37 @@ def set_seed(seed): torch.random.manual_seed(seed) +def on_process(process_index: int): + + def decorator(func: Callable): + + @wraps(func) + def _wrapper(model, *args, **kwargs): + curr_process_index = int(os.environ.get('LOCAL_RANK', -1)) + + if curr_process_index == -1 or (process_index == curr_process_index): + print(f"Applying {func.__name__} on process index {curr_process_index}") + return func(model, *args, **kwargs) + else: + print(f"Skipping function {func.__name__} on process index {curr_process_index}") + return model + + return _wrapper + + return decorator + + +def apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.Module: + model = offload_model(model) + for r in rewriters: + if isinstance(r, ModuleInstanceFuseRotationWeights): + model = r.apply(model) + remove_hooks(model) + return model + + +# TODO: Use no_grad? The result of fusing the rotations would yield tensor with requires_grad set to False, +# which might no be a problem, as that flag is set in the appropiate QAT/PTQ algorithms. def fused_rotation_no_fx( model, calibration_loader, @@ -66,7 +101,6 @@ def fused_rotation_no_fx( for r in rewriters: r.apply(model) - new_model = offload_model(new_model) eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, @@ -76,17 +110,17 @@ def fused_rotation_no_fx( find_self_attention_rotation_regions( new_model, model.config.hidden_size // model.config.num_attention_heads) if add_self_attention_regions else None) - new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions) - # Additional rewriters need to be added if rotations are not fused - if not fuse_rotations: - rewriters_unfused_rotations = extract_rewriters_unfused_rotations(new_model, rewriters) - rewriters.extend(rewriters_unfused_rotations) - + new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions, apply_inplace_rotations=False) + # Rewriters need to be fixed to point to the module instances of the original model rewriters = fix_rewriter(rewriters, model, 'weight') - + # The weights of the FX model and the original model are tied, so the rotation fusing has already been applied. + # Note that the parametrization registration cannot be done in a model that has been offloaded using + # offload_model, as the change in the state dictionary when registering the parametrization causes the removal + # of the hooks to crash. This is due to the fact that the device_map in the AlignDevicesHook is no longer valid. + model = apply_fused_rotations(model, rewriters) for r in rewriters: - r.apply(model) - remove_hooks(new_model) + if not isinstance(r, ModuleInstanceFuseRotationWeights): + model = r.apply(model) def set_seed(seed): @@ -264,6 +298,14 @@ def main(args, unknown_args=None): del calibration_loader[i]["attention_mask"] calibration_loader[i]["labels"] = calibration_loader[i]["input_ids"] + def mock_save_pretrained_fn(*args, **kwargs): + pass + + # For a PretrainedModel, the Trainer in accelerate calls save_pretrained after + # finishing the optimization. However, this method no longer works after + # registering parametrizations/quantizing, so this method is mocked to prevent + # a crash. + model.save_pretrained = mock_save_pretrained_fn model.config.use_cache = False model.config.loss_type = "ForCausalLM" @@ -368,7 +410,6 @@ def main(args, unknown_args=None): quantize_embedding=False) if not args.quantize_last_layer: if require_fx: - # TODO: Fix when using UnfusedRotation, layer_map[type(last_module)][1] crashes last_node = [node for node in model.graph.nodes if node.op == 'call_module'][-1] last_module = get_module(model, last_node.target) last_layer_kwargs = layer_map[type(last_module)][1] @@ -411,6 +452,8 @@ def main(args, unknown_args=None): unknown_args=unknown_args, ) + remove_hooks(model) + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader) @@ -447,10 +490,11 @@ def main(args, unknown_args=None): if args.eval and not args.no_quantize: print("Model eval...") + model = offload_model(model) quant_ppl = compute_perplexity( model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") - remove_hooks(model) + remove_hooks(model) if args.checkpoint_name is not None: print(f"Saving checkpoint to {args.checkpoint_name}") diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 1e913845c..2acf8287b 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -361,8 +361,7 @@ def test_composition_unfused_rotations(N): RotationWeightParametrization( rot_mat=rot_mat, rot_func=_apply_ort_device, - input_axis=_get_input_axis(rot_module), - output_axis=_get_output_axis(rot_module), + axis=_get_output_axis(rot_module) if is_source else _get_input_axis(rot_module), is_source=is_source, is_sink=is_sink, is_orphan=is_orphan, From adc5238bdff782ef8892e069e67cfeda164d9996 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 2 Jan 2025 12:39:40 +0000 Subject: [PATCH 11/11] Fix multi-GPU setup --- .../common/accelerate_utils/accelerate.py | 1 - src/brevitas_examples/llm/main.py | 83 ++++++++++++------- 2 files changed, 51 insertions(+), 33 deletions(-) diff --git a/src/brevitas_examples/common/accelerate_utils/accelerate.py b/src/brevitas_examples/common/accelerate_utils/accelerate.py index ead616ed2..369c456c0 100644 --- a/src/brevitas_examples/common/accelerate_utils/accelerate.py +++ b/src/brevitas_examples/common/accelerate_utils/accelerate.py @@ -407,7 +407,6 @@ def offload_model( else: device_map = infer_auto_device_map( model, memory_map, no_split_module_classes=model._no_split_modules) - model = dispatch_model(model, device_map) # Fixes an asymetric behavior in Accelerate where hooks are not attached at all when a single device is used. diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 947314185..08eb92769 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -3,6 +3,7 @@ import argparse from copy import deepcopy +from functools import partial from functools import wraps import os import sys @@ -20,6 +21,7 @@ from brevitas.export import export_torch_qcdq from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas.graph.base import ModuleInstanceFuseRotationWeights +from brevitas.graph.base import Transform from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import layerwise_quantize @@ -56,27 +58,31 @@ def set_seed(seed): torch.random.manual_seed(seed) -def on_process(process_index: int): +def is_main_process(): + return int(os.environ.get('LOCAL_RANK', -1)) in [-1, 0] - def decorator(func: Callable): - @wraps(func) - def _wrapper(model, *args, **kwargs): - curr_process_index = int(os.environ.get('LOCAL_RANK', -1)) +def on_process(func: Callable, process_index: int): - if curr_process_index == -1 or (process_index == curr_process_index): - print(f"Applying {func.__name__} on process index {curr_process_index}") - return func(model, *args, **kwargs) - else: - print(f"Skipping function {func.__name__} on process index {curr_process_index}") - return model + @wraps(func) + def _wrapper(model, *args, **kwargs): + curr_process_index = int(os.environ.get('LOCAL_RANK', -1)) + + if curr_process_index == -1 or (process_index == curr_process_index): + print(f"Applying {func.__name__} on process index {curr_process_index}") + return func(model, *args, **kwargs) + else: + print(f"Skipping function {func.__name__} on process index {curr_process_index}") + return model - return _wrapper + return _wrapper - return decorator +on_main_process = partial(on_process, process_index=0) -def apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.Module: + +@on_main_process +def apply_fused_rotations(model: torch.nn.Module, rewriters: List[Transform]) -> torch.nn.Module: model = offload_model(model) for r in rewriters: if isinstance(r, ModuleInstanceFuseRotationWeights): @@ -85,6 +91,15 @@ def apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.M return model +@on_main_process +def evaluate_model(model: torch.nn.Module, validation_loader, args, tokenizer): + model = offload_model(model) + quant_ppl = compute_perplexity( + model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) + print(f"Perplexity ({args.dataset}): {quant_ppl:.3f}") + remove_hooks(model) + + # TODO: Use no_grad? The result of fusing the rotations would yield tensor with requires_grad set to False, # which might no be a problem, as that flag is set in the appropiate QAT/PTQ algorithms. def fused_rotation_no_fx( @@ -282,12 +297,9 @@ def main(args, unknown_args=None): if args.eval: assert args.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously" - print("Float model eval...") - model = offload_model(model) - float_ppl = compute_perplexity( - model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) - remove_hooks(model) - print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}") + print("Evaluating float model...") + evaluate_model(model, validation_loader, args, tokenizer) + print("Float evaluation done.") if args.replace_rmsnorm: model = replace_rmsnorm_with_torch(model, model.config) @@ -437,12 +449,21 @@ def mock_save_pretrained_fn(*args, **kwargs): if args.bias_corr: model = add_zero_bias_to_linear(model) - model = offload_model(model) - - with torch.no_grad(): - model(**calibration_loader[0]) - - remove_hooks(model) + # We need to run a calibration forward pass to initialize quantization-related parameters, + # e.g. scales. In DDP, as parameters are synchronized across replicas before optimization, + # it is not needed to run this pass for every process, as the parameters of the main + # process will be broadcasted to each replica. + if is_main_process(): + model = offload_model(model) + with torch.no_grad(): + model(**calibration_loader[0]) + remove_hooks(model) + else: + # TODO: Generalize this logic. Currently, only ParameterFromStatsFromParameterZeroPoint + # and ParameterFromStatsFromParameterScaling have the attribute init_done + for module in model.modules(): + if hasattr(module, "init_done"): + module.init_done = True if args.rotation in ['fused_no_fx_optimize', 'fused_no_fx_optimize_self_attn_region']: apply_rotation_optimization( @@ -453,6 +474,7 @@ def mock_save_pretrained_fn(*args, **kwargs): ) remove_hooks(model) + torch.cuda.empty_cache() if args.act_calibration: print("Apply act calibration...") @@ -489,12 +511,9 @@ def mock_save_pretrained_fn(*args, **kwargs): print("Bias correction applied.") if args.eval and not args.no_quantize: - print("Model eval...") - model = offload_model(model) - quant_ppl = compute_perplexity( - model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) - print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") - remove_hooks(model) + print("Evaluating quantized model...") + evaluate_model(model, validation_loader, args, tokenizer) + print("Quantized evaluation done.") if args.checkpoint_name is not None: print(f"Saving checkpoint to {args.checkpoint_name}")