<html>
<head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<script src='https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=default'></script>
<!-- <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/3.2.2/es5/latest.min.js"></script> -->
<title>Low-Rank Pruning of Llama2</title>
<link rel="stylesheet" type="text/css" href="styling.css">
<link rel="icon" type="image/png" href="figs/aana_logo.png">
<meta name="description" content="An exploration of model pruning for machine learning, focusing on reducing model size and improving speed for deployment on resource-constrained devices. Discusses structured and unstructured sparsity and low-rank pruning, and introduces a rank-reduction approach that is compatible with LoRA (Low-Rank Adaptation) for efficient training of large language models such as Llama2-7B.">
<meta name="keywords" content="Model Pruning, Machine Learning, Low-Rank Pruning, Sparsity, LoRA, LLama2-7B, Model Compression, Singular Value Decomposition, Transformer Models, Neural Networks, AI Optimization">
<meta name="author" content="Hicham Badri and Appu Shaji, Mobius Labs GmbH">
<!-- Specific tags for Open Graph / social media sharing -->
<meta property="og:title" content="Low Rank Pruning of Llama2">
<meta property="og:description" content="An in-depth article discussing the intricacies of model pruning in machine learning, with a focus on low-rank techniques and their application in large language models for improved performance and efficiency.">
<meta property="og:image" content="https://mobiusml.github.io/low-rank-llama2/figs/pseudo-code.png">
<meta property="og:url" content="https://mobiusml.github.io/low-rank-llama2/">
<meta property="og:type" content="article">
<!-- Twitter Card data -->
<meta name="twitter:card" content="summary_large_image">
<meta name="twitter:title" content="Low Rank Pruning of Llama2">
<meta name="twitter:description" content="Discover the advanced strategies for model pruning in AI, highlighting low-rank pruning and sparsity-aware optimizations for large language models such as LLama2-7B.">
<meta name="twitter:image" content="https://mobiusml.github.io/low-rank-llama2/figs/pseudo-code.png">
<meta name="twitter:creator" content="@appughar">
<!-- Meta tags for article publishing date and modification date -->
<meta name="article:published_time" content="2023-11-03T08:00:00+00:00">
<meta name="article:modified_time" content="2023-11-03T09:00:00+00:00">
</head>
<body>
<article id="low-rank-sparsity" class="page sans">
<header>
<h1 class="page-title">Low-Rank Pruning of Llama2</h1>
</header>
<div class="page-body">
<p><a href="https://scholar.google.com/citations?user=LxweMX4AAAAJ&hl=en"><mark
class="highlight-gray">Hicham Badri</mark></a><mark class="highlight-gray">, </mark><a
href="https://scholar.google.com/citations?user=HxZDDzUAAAAJ&hl=en"><mark class="highlight-gray">Appu Shaji</mark></a></p>
<p><a href="https://www.mobiuslabs.com/"><mark
class="highlight-gray">Mobius Labs GmbH</mark></a></p>
<hr />
<p>Machine learning models have grown steadily in size and complexity in recent years. Large language models (LLMs), which rely mainly on transformers as building blocks, now have a very large number of parameters and require significant compute, and both are expected to keep increasing as larger models are released.
</p>
<p>In this article, we explore low-rankness as a pruning technique for the <a href="https://huggingface.co/meta-llama/Llama-2-7b">Llama2-7B base model</a>. We show that, by splitting almost all the linear layer weights into low-rank pairs <em>without fine-tuning</em> and leveraging LoRA for custom training, we can achieve the following without <em>implementing custom kernels</em>:</p>
<ul>
<li>~50% reduction in the number of parameters.</li>
<li>Up to ~50% faster training vs. <a href="https://github.com/TimDettmers/bitsandbytes">bitsandbytes’s</a> 8-bit quantization.</li>
<li>Up to ~1.25x inference speed-up.</li>
</ul>
<hr id="header_seperator" />
<div class="column-list">
<div style="width:32%" class="column">
<!-- <p class="page-description"><img src="./baby_aana.png" /></p> -->
<figure class="image" style="text-align:left"><a
href="figs/baby_aana.png"><img
style="width:240px"
src="figs/baby_aana.png" /></a>
</figure>
<p>
<strong>Table of Contents</strong>
</p>
<nav class="block-color-gray table_of_contents">
<div class="table_of_contents-item table_of_contents-indent-0"><a class="table_of_contents-link"
href="#intro">Introduction</a></div>
<div class="table_of_contents-item table_of_contents-indent-0"><a class="table_of_contents-link"
href="#lowrankpruning">Low-Rank Pruning</a>
</div>
<div class="table_of_contents-item table_of_contents-indent-0"><a class="table_of_contents-link"
href="#pruningllama2">Low-Rank Pruning of Llama2 Models</a></div>
<div class="table_of_contents-item table_of_contents-indent-0"><a class="table_of_contents-link"
href="#benchmark">Speed Benchmarks</a></div>
<div class="table_of_contents-item table_of_contents-indent-0"><a class="table_of_contents-link"
href="#dataset">Dataset Performance</a></div>
<div class="table_of_contents-item table_of_contents-indent-0"><a class="table_of_contents-link"
href="#conclusion">Conclusion</a></div>
<hr />
<div class="table_of_contents-item table_of_contents-indent-0"> Support code is available at <a href="https://github.com/mobiusml/low-rank-llama2/tree/main/code"><mark
class="highlight-gray">https://github.com/mobiusml/low-rank-llama2/tree/main/code</mark></a></div>
<hr />
<div class="table_of_contents-item table_of_contents-indent-0">Coming soon: We will be releasing a blog post about model quantization along with its weights.</div>
<hr />
</nav>
</div>
<div style="width:75%" class="column">
<h2 id="intro" class="">Introduction</h2>
<p>Model pruning refers to the process of removing redundant information from machine learning models to make them “leaner”. As a result, the pruned model is smaller and should run faster, which makes it suitable for deployment on resource-constrained devices or in real-time applications. Pruning can be combined with other techniques such as quantization to further optimize runtime. The most popular pruning approaches are based on discarding neurons, layer channels or entire layers. This kind of pruning is referred to as “sparsification”.
</p>
<p>In practice, however, sparse pruning has many limitations. Achieving an actual speed-up requires custom sparsity-aware matrix multiplication (<i>matmul</i>) operations. For the moment, these are only partially supported on <a href="https://developer.nvidia.com/blog/accelerating-inference-with-sparsity-using-ampere-and-tensorrt/">Ampere GPUs</a> or on CPUs via <a href="https://neuralmagic.com/">NeuralMagic</a>. In PyTorch, sparse matrix multiplication operations are not optimized. For example, there is no implementation of the batched <i>matmul</i> operation with sparse matrices. Rewriting it with the existing operations requires some reshaping, and the result is 2-3x slower.
</p>
<p>Structured sparsity, on the other hand, consists of discarding weights in a structured way. For instance, we can remove columns, channels or whole blocks of a matrix. This way, in theory, the model can be pruned without requiring specialized software/hardware for optimized runtime. However, some structured sparsity methods still require optimized software to achieve faster runtime. For example, block-sparsity requires implementing dedicated GPU kernels for block-sparse <i>matmul</i> such as <a href="https://openai.com/research/block-sparse-gpu-kernels">OpenAI's Block-sparse GPU kernels</a>.
</p>
<p>In practice, however, structured sparsity cannot be pushed as far as unstructured sparsity without a larger drop in accuracy. As a result, the performance gain is usually very limited.
</p>
<h2 id="lowrankpruning" class="">Low-Rank Pruning</h2>
<p>The idea of low-rank pruning is to factorize the weight matrix <b>W</b> of a linear layer as the product of two matrices <b>A</b> and <b>B</b>, such that <b>A</b> and <b>B</b> have far fewer columns and rows, respectively:</p>
<figure id="low_rank"><img style="width:480px" src="figs/Matrix2.png" /></figure>
<p>Ideally, we would like the chained <i>matmul</i> with <b>A</b> and <b>B</b> to be faster and use less memory than the original <i>matmul</i> with <b>W</b>, while the model's predictions stay as close as possible to those obtained with the unaltered weights. In the rest of the article, we refer to the number of columns of <b>A</b> (equal to the number of rows of <b>B</b>) as the maximum rank, denoted by <i>max_rank</i>.</p>
<p>There are various ways to obtain such a factorization (SVD, QR, etc.). We use a truncated SVD to get the matrix pairs, as follows:</p>
<figure><img style="width:480px;" src="figs/get_lowrank_tuple.png" /></figure>
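<p>For readers who prefer text to a screenshot, here is a minimal NumPy sketch of a truncated-SVD factorization. It is an illustration under our own assumptions, not the code from the figure: the function name <code>get_lowrank_pair</code>, the <code>in_features x out_features</code> layout of <b>W</b>, and the choice to split the singular values evenly between the two factors are all hypothetical.</p>

```python
import numpy as np

def get_lowrank_pair(W, max_rank):
    """Return (A, B) with A @ B approximating W, keeping max_rank singular values.

    Hypothetical helper: W is assumed to be laid out as (in_features, out_features).
    """
    # Thin SVD: W = U @ diag(S) @ Vh, singular values sorted in descending order.
    U, S, Vh = np.linalg.svd(W, full_matrices=False)
    # Split each kept singular value evenly between the two factors.
    sqrt_S = np.sqrt(S[:max_rank])
    A = U[:, :max_rank] * sqrt_S             # shape (in_features, max_rank)
    B = sqrt_S[:, None] * Vh[:max_rank, :]   # shape (max_rank, out_features)
    return A, B

# Toy check on a matrix whose true rank is 16.
rng = np.random.default_rng(0)
W = rng.standard_normal((64, 16)) @ rng.standard_normal((16, 48))
A, B = get_lowrank_pair(W, max_rank=16)
rel_err = np.linalg.norm(W - A @ B) / np.linalg.norm(W)
print(A.shape, B.shape, rel_err)
```

<p>On this toy matrix the reconstruction error is at the level of floating-point noise because <i>max_rank</i> matches the true rank. On real weights, a smaller <i>max_rank</i> discards the trailing singular values and the error grows accordingly.</p>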
<p>The idea of using low-rankness is not new in the context of Transformer models. Low-rank estimation has received considerable attention, primarily for model compression. The works in <a href="https://arxiv.org/pdf/2004.04124.pdf">https://arxiv.org/pdf/2004.04124.pdf</a> and <a href="https://neurips2022-enlsp.github.io/papers/paper_33.pdf">https://neurips2022-enlsp.github.io/papers/paper_33.pdf</a> study low-rank compression of BERT and GPT models, respectively. Another approach, described in <a href="https://openreview.net/pdf?id=uPv9Y3gmAI5">https://openreview.net/pdf?id=uPv9Y3gmAI5</a>, uses weighted low-rank estimation to compress BERT-based models. Furthermore, the work in <a href="https://cs.nju.edu.cn/wujx/paper/AAAI2023_AFM.pdf">https://cs.nju.edu.cn/wujx/paper/AAAI2023_AFM.pdf</a> takes a different perspective by focusing on low-rank compression of the model features, as opposed to the model weights.</p>
<p>Among these approaches, one that has gained significant popularity is <a href="https://arxiv.org/abs/2106.09685">LoRA (Low-Rank Adaptation)</a>. LoRA's core concept revolves around training supplementary low-rank parameters to adapt large models. This technique enables the training of expansive models while drastically reducing the memory requirements. </p>
<p>Pruning typically requires fine-tuning on a large dataset, which is very expensive even for smaller LLMs such as Llama2-7B. We find that, by applying low-rank estimation, freezing the weights and leveraging LoRA instead for custom training, we can achieve significant efficiency gains, as we explain in the next section.
</p>
<h2 id="pruningllama2" class="">Low-Rank Pruning of Llama2 Models</h2>
<p>When we analyze the weights of the Llama2-7B model, we find that many are in fact already low-rank, especially those of the attention layers (<b>Q</b>,<b>K</b>,<b>O</b>). The graph below shows the distribution of the average normalized singular values per layer type. We normalize the singular values by the largest one (which is the same as normalizing the weight matrix by its L2 norm) so that we can average them across layers and get a single plot per layer type. Most of the energy is clearly concentrated in a subset of the singular values. More specifically, about 80% of the energy is concentrated in the first half of the singular values of the <b>Q</b>,<b>K</b>,<b>V</b>,<b>O</b> layers of the attention modules. The first layers of the attention module tend to have an even lower rank. For instance, 88% of the energy of the first <b>Q</b> layer is concentrated in the first 1024 (25%) of its singular values.</p>
<figure><img style="width:480px;" src="figs/svd_distribution.png"/></figure>
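<p>A rough sketch of how such a profile could be computed for a single weight matrix is shown below. The function name <code>energy_profile</code> is hypothetical, and "energy" is assumed here to mean the cumulative share of the sum of singular values. Summing squared singular values is another common convention, and the article does not say which one was used for the plot.</p>

```python
import numpy as np

def energy_profile(W):
    """Return (normalized singular values, cumulative energy share) for W.

    Assumption: energy is the running sum of singular values divided by their total.
    """
    s = np.linalg.svd(W, compute_uv=False)   # descending order
    s_norm = s / s[0]                         # divide by the largest singular value
    cum_energy = np.cumsum(s) / np.sum(s)
    return s_norm, cum_energy

# Toy matrix with known singular values 4, 3, 2, 1.
W_toy = np.diag([4.0, 3.0, 2.0, 1.0])
s_norm, cum_energy = energy_profile(W_toy)
half = len(cum_energy) // 2
energy_in_first_half = cum_energy[half - 1]   # (4 + 3) / 10 for this toy matrix
print(s_norm, energy_in_first_half)
```

<p>Averaging <code>s_norm</code> over all layers of the same type would give one curve per layer type, in the spirit of the plot above.</p>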
<p>In practice, we found that the rank of the original attention and MLP layers can be reduced from 4096 to 1024 and 2048 respectively, while still delivering good performance after LoRA training. This is a 4x rank reduction in the attention layers and 2x for the MLP layers, which is quite aggressive given that these weights are frozen after pruning.
</p>
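<p>To make the size reduction concrete, the sketch below counts parameters before and after factorization. The helper name is hypothetical, and the 4096x11008 MLP shape is taken from the public Llama2-7B configuration rather than from this article, so treat it as an assumption.</p>

```python
def lowrank_param_ratio(in_features, out_features, rank):
    """Parameters of the (A, B) pair divided by parameters of the dense W."""
    dense = in_features * out_features
    factored = rank * (in_features + out_features)
    return factored / dense

# Attention projection: 4096 x 4096 weight, rank 1024.
attn_ratio = lowrank_param_ratio(4096, 4096, 1024)
# MLP projection: 4096 x 11008 weight (assumed shape), rank 2048.
mlp_ratio = lowrank_param_ratio(4096, 11008, 2048)
print(attn_ratio, round(mlp_ratio, 3))
```

<p>Under these assumptions a pruned attention projection keeps half of its parameters and a pruned MLP projection roughly 69%. The ratio drops below 1 only when the rank is smaller than <code>in_features * out_features / (in_features + out_features)</code>, which is 2048 for a 4096x4096 matrix.</p>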
<p>We summarize the steps for training and inference using the proposed approach:
</p>
<h4>Training Mode</h4>
<ul>
<li>For each linear layer, we run SVD on its weights <b>W</b> to get the <b>A</b>,<b>B</b> matrix pair such that the product <b>AB</b> approximates <b>W</b>, truncating the singular values to the predefined <i>max_rank</i> as explained in the previous section. The only layer that we keep full-rank is <b>v_proj</b>, because the rank of its weights tends to be higher.</li>
<li>We freeze all the weights and use LoRA with the <b>r</b> parameter to create the new trainable parameters.
</li>
</ul>
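<p>The two training steps can be pictured with the following forward-pass sketch in NumPy. It is not the training code: the shapes are toy values, the names <code>A_L</code> and <code>B_L</code> follow the notation used in the inference section, the zero initialization of <code>B_L</code> is an assumption borrowed from the usual LoRA setup, and the LoRA scaling factor is omitted for brevity.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, max_rank, r = 32, 32, 8, 2      # toy sizes, not the Llama2 ones

# Frozen low-rank pair obtained from the pruning step.
A = rng.standard_normal((d_in, max_rank))
B = rng.standard_normal((max_rank, d_out))

# Trainable LoRA pair; the second factor starts at zero so the adapter is a no-op at first.
A_L = rng.standard_normal((d_in, r)) * 0.01
B_L = np.zeros((r, d_out))

def forward(x):
    """Frozen low-rank path plus the LoRA path, each as two thin matmuls."""
    return (x @ A) @ B + (x @ A_L) @ B_L

x = rng.standard_normal((4, d_in))
y = forward(x)
print(y.shape)
```

<p>Only <code>A_L</code> and <code>B_L</code> would receive gradients during training, while <code>A</code> and <code>B</code> stay fixed.</p>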
<h4>Inference Mode</h4>
<p>After training, we need to re-estimate new pairs of matrices that combine the original low-rank weights and the newly trained LoRA weights:</p>
<ul>
<li>For each linear layer that was pruned, we have the <b>A</b>,<b>B</b> as well as the LoRA pairs that we refer to as <b>AL</b>,<b>BL</b> </li>
<li><a href="https://www.ic.unicamp.br/~meidanis/PUB/Doutorado/2012-Biller/Marsaglia1964.pdf">Since the rank of the sum of two matrices is less than or equal to the sum of their ranks</a>,
$$ {\text{rank}({\bf AB}+{\bf A_L} {\bf B_L} ) \le \text{rank}({\bf AB}) + \text{rank}({\bf A_LB_L})} $$
we can safely combine the 4 weights by applying truncated SVD to the sum of their matrix products, truncating at the sum of their ranks, to build the new low-rank pair:
$$ {\bf AB} + {\bf A_LB_L} \Longrightarrow{\bf \bar{A}\bar{B}}$$
$$ { \text{rank}({\bf \bar{A}\bar{B}} ) \le \text{max_rank} + \text{r} } $$
</li>
<li>Now we can use the new pair and remove the older <b>A</b>,<b>B</b> and LoRA weights.
</li>
</ul>
<p>The illustration below shows the difference between the standard LoRA approach and the proposed low-rank LoRA merging method. Note that the result is a pair of matrices.</p>
<figure><center><img style="width:480px" src="figs/merging.png"/></center></figure>
<p>The code below summarizes the merging logic:</p>
<figure><center><img style="width:640px" src="figs/pseudo-code.png" /></center></figure>
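<p>As a text companion to the figure, here is a minimal NumPy sketch of the merge. The function name <code>merge_lowrank_lora</code> and the even split of singular values between the two factors are our own assumptions, and the LoRA scaling factor is assumed to be already folded into <code>A_L</code> or <code>B_L</code>.</p>

```python
import numpy as np

def merge_lowrank_lora(A, B, A_L, B_L):
    """Combine a frozen pair (A, B) and a LoRA pair (A_L, B_L) into one pair.

    The new inner dimension is max_rank + r, the sum of the two inner dimensions.
    """
    rank = A.shape[1] + A_L.shape[1]
    W = A @ B + A_L @ B_L                     # dense sum, only needed once at merge time
    U, S, Vh = np.linalg.svd(W, full_matrices=False)
    sqrt_S = np.sqrt(S[:rank])
    A_bar = U[:, :rank] * sqrt_S
    B_bar = sqrt_S[:, None] * Vh[:rank, :]
    return A_bar, B_bar

rng = np.random.default_rng(0)
A, B = rng.standard_normal((64, 8)), rng.standard_normal((8, 48))
A_L, B_L = rng.standard_normal((64, 2)), rng.standard_normal((2, 48))
A_bar, B_bar = merge_lowrank_lora(A, B, A_L, B_L)
merge_err = np.linalg.norm(A_bar @ B_bar - (A @ B + A_L @ B_L))
print(A_bar.shape, B_bar.shape, merge_err)
```

<p>Because the sum has rank at most <i>max_rank</i> + <b>r</b>, truncating at that rank loses nothing beyond floating-point error, which is what the rank inequality above guarantees.</p>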
<h2 id="benchmark">Speed Benchmarks</h2>
<p>We report the inference speed-up relative to the original Llama2-7B model, using the HuggingFace implementations with fp16 precision. When LoRA weights are merged into the original model, the resulting matrices keep the same dimensions as in the original model. In the pruned version, however, the rank of the matrix pairs increases by the LoRA rank r. For instance, in the attention layers, the original weight matrix W has dimensions 4096x4096. With a rank of 2048 and a LoRA rank of 32, the resulting pair A and B will be 4096x2080 and 2080x4096, respectively. Reducing the rank gives a larger speed-up but has a detrimental effect on prediction accuracy.</p>
<figure style="align-items: left; justify-content: left;">
<img style="margin-right: 10px; max-width: 75%; height: auto;" src="figs/titan.png" />
<img style="margin-right: 10px; max-width: 75%; height: auto;" src="figs/a100.png" />
</figure>
<h2 id="dataset">Dataset Performance</h2>
<p>We present results on 5 datasets, evaluating both the unaltered and pruned Llama2-7B models using perplexity (lower is better). For the original model, we use the default LoRA settings (<b>r</b>=8). For the pruned model, we raise the LoRA rank to 32. Remarkably, the pruned model performs well despite the removal of approximately half of the original weights, and the pruning step itself involves no fine-tuning!</p>
<p>It is worth noting that the performance of the pruned model on OpenOrca-1M is better than that of the original model on 100k samples. This indicates that the pruned model has the capacity to learn but needs more samples to compensate for the noise introduced by pruning.</p>
<table>
<tr>
<td><b>Dataset</b></td>
<td><b>LLama2-7B</b></td>
<td><b>LLama2-7B pruned</b></td>
</tr>
<tr>
<td><a href="https://huggingface.co/datasets/vicgalle/alpaca-gpt4">vicgalle/alpaca-gpt4</a></td>
<td>3.49</td>
<td>4.11</td>
</tr>
<tr>
<td><a href="https://huggingface.co/datasets/databricks/databricks-dolly-15k">databricks/databricks-dolly-15k</a></td>
<td>4.13</td>
<td>5.86</td>
</tr>
<tr>
<td><a href="https://huggingface.co/datasets/knkarthick/dialogsum">knkarthick/dialogsum</a></td>
<td>3.78</td>
<td>4.82</td>
</tr>
<tr>
<td><a href="https://huggingface.co/datasets/ArtifactAI/arxiv-math-instruct-50k">ArtifactAI/arxiv-math-instruct-50k</a></td>
<td>3.08</td>
<td>3.73</td>
</tr>
<tr>
<td><a href="https://huggingface.co/datasets/Open-Orca/OpenOrca">Open-Orca/OpenOrca - 100k </a></td>
<td>3.51</td>
<td>4.27</td>
</tr>
<tr>
<td><a href="https://huggingface.co/datasets/Open-Orca/OpenOrca">Open-Orca/OpenOrca - 1M</a></td>
<td>-</td>
<td>3.43</td>
</tr>
<tr>
<td><i>Average (excluding OpenOrca - 1M)</i></td>
<td><i>3.60</i></td>
<td><i>4.56</i></td>
</tr>
</table>
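<p>For reference, perplexity is conventionally the exponential of the average negative log-likelihood per token. The sketch below shows that definition only. The helper name is hypothetical and the article does not describe its exact evaluation script, so this is not a reproduction of how the numbers in the table were obtained.</p>

```python
import math

def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood over the evaluated tokens."""
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)

# A model that assigns probability 0.25 to every token has perplexity close to 4.
uniform_log_probs = [math.log(0.25)] * 8
ppl = perplexity(uniform_log_probs)
print(ppl)
```

<p>Under this definition, a perplexity of 3.60 means the model is, on average, about as uncertain as a uniform choice among 3.6 tokens.</p>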
<h2 id="conclusion">Conclusion</h2>
<p>In this article, we've demonstrated the utility of low-rank pruning as an effective method for accelerating large language models like LLama2-7B. Unlike sparse pruning, which often requires custom hardware or software configurations to realize significant speed gains, low-rank pruning doesn't require specialized kernel operations and can seamlessly integrate with existing matrix multiplication (<i><a href="https://pytorch.org/blog/inside-the-matrix/">matmul</a></i>) implementations.
</p>
<p>Nevertheless, there is ample scope for further refinements, and we aspire for this article to serve as an inspiration to the research community. We encourage researchers to embrace low-rank pruning and explore its synergistic potential when combined with other pruning and quantization techniques.
</p>
<p>We provide code examples at <a href="https://github.com/mobiusml/low-rank-llama2/tree/main/code">https://github.com/mobiusml/low-rank-llama2/tree/main/code</a>
</p>
<hr />
<div>
<p style="text-align: center;">Please feel free to <a href="mailto:hicham@mobiuslabs.com">contact us</a>.</p>
<p style="text-align: center; color:hotpink;">Coming soon: We will be releasing a blog post about model quantization.</p>
</div>
</div>
</div>
</div>
</article>
</body>
</html>