Skip to content

Commit

Permalink
Publish
Browse files Browse the repository at this point in the history
  • Loading branch information
DouglasOrr committed Apr 22, 2024
1 parent beb955d commit d085a9b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions 2024-04-transformers/article.html
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,20 @@ <h2 id="multi-layer-perceptron-mlp-geglu">Multi-layer perceptron (MLP, GeGLU)</h
<p>First, make two separate projections (dot products with trained parameter matrices) to get two 16384-vectors. One of these vectors, called the <em>gate</em> is passed through a nonlinear function called <em>gelu</em> that applies to each element. We'll treat the function as a black box, but to a first approximation, <code>gelu(x)</code> is quite close to <code>max(0, x)</code>. The gate and other vector are multiplied element-by-element, then another trained down-projection produces the result, a 2048-vector.</p>
<details>
<p><summary>Understanding the MLP</summary></p>
<p>The MLP is a function from a 2048-vector representing a single token to a 2048-vector. The code shown above runs in parallel over the sequence axis, but unlike in attention, each token is processed independently. In this section, we'll build up the complex behaviour of Gemma's GeGLU MLP with a example based on 2-vectors (and a 3-vector inside the MLP).</p>
<p><strong>A. ReLU</strong></p>
<p>We start with a simplified version specified by the original transformer model, using the <em>rectified linear unit</em> or ReLU, <code>relu(a) = max(0, a)</code>:</p>
<p>The MLP is a function from a 2048-vector representing a single token to a 2048-vector. The code shown above runs in parallel over the sequence axis, but unlike in attention, each token is processed independently. In this section, we'll build up the complex behaviour of Gemma's GeGLU MLP with a tiny example based on 2-vectors (and a 3-vector inside the MLP).</p>
<p><em>Note that the acronym MLP stands for multi-layer perceptron, which is slightly old-fashioned language to talk about a chain of dot product (projections) and nonlinearities.</em></p>
<p><strong>A. Linear</strong></p>
<p>We start with a simple linear projection (dot product) from input to output, specified by a random $2 \times 2$ matrix. Each output component is a weighted sum of the inputs. The code is simply:</p>
<pre><code class="language-python">y = x @ proj
# x.shape = y.shape = (2,)</code></pre>
<p>If we look at the first component of the output as a function of the inputs, we see:</p>
<p><img alt="3D surface plot showing a flat slope." class="img-fluid" src="img/mlp_linear.png" /></p>
<p>This is what a linear projection always looks like — a flat slope. It's certainly possible for this function to capture some interesting properties of the data, especially when it works on 2048-vectors rather than 2-vectors. However, it will become much more powerful with a few additions.</p>
<p><strong>B. ReLU</strong></p>
<p>The core idea of the MLP is that we wish to introduce <em>depth</em> into the model, a sequence of layers that transform the input in steps. However, the dot product has the property that a sequence of dot products can be reduced to a single dot product, i.e. there exists a <code>proj_b</code> such that <code>(x @ proj_1) @ proj_2 == x @ proj_b</code>.</p>
<p>Sequences of dot products need to be broken up if they're going to be any more powerful than the simple linear projection we've already seen. The simplest way to do this is by transforming each element of the vector individually by an elementwise nonlinear function. One such function is the <em>rectified linear unit</em> or ReLU, <code>relu(a) = max(0, a)</code>:</p>
<p><img alt="Plot of z = relu(a), with a flat portion below zero, then a linear portion above zero." class="img-fluid" src="img/relu.png" /></p>
<p>Our ReLU MLP runs:</p>
<p>Our ReLU MLP now runs:</p>
<pre><code class="language-python">z = relu(x @ gate_proj)
y = z @ down_proj
# x.shape = y.shape = (2,)
Expand All @@ -267,23 +276,23 @@ <h2 id="multi-layer-perceptron-mlp-geglu">Multi-layer perceptron (MLP, GeGLU)</h
<p>In the top right off-white segment, we have all three components of <code>z</code> "active" (not saturating at zero), so in this region we have <code>z = x @ gate_proj</code> (the ReLU disappears). In the yellow region on the left, we know the blue component <code>z[2]</code> is saturating, so in this region we have <code>z = (x @ gate_proj) * [1, 1, 0]</code>, effectively removing that component. Within each coloured region, <code>z</code> is a linear function of <code>x</code>.</p>
<p>Once we run the down-projection, each component of <code>y</code> is a dot product between <code>z</code> and a vector of trained weights, so the result remains piecewise-linear, transitioning at the boundaries we've just seen:</p>
<p><img alt="3D surface plot showing a piecewise linear function of two input components. Each piece is coloured as per the pinwheel map above." class="img-fluid" src="img/mlp_relu.png" /></p>
<p><strong>B. ReGLU</strong></p>
<p><strong>C. ReGLU</strong></p>
<p>These piecewise linear functions are surprisingly powerful already — the pretraining procedure can manipulate the transitions as well as the slopes of each region. But we might propose more power by making each region quadratic rather than linear. This idea gives us the <em>gated linear unit</em> (GLU):</p>
<pre><code class="language-python">z = relu(x @ gate_proj) * (x @ up_proj)
y = z @ down_proj</code></pre>
<p>With the same <code>gate_proj</code>, the regions in this version are exactly the same as before; the only difference is that within each region, we have a quadratic function of <code>x</code>.</p>
<p><img alt="3D surface plot showing a piecewise quadratic function of two input components." class="img-fluid" src="img/mlp_reglu.png" /></p>
<p>Notice we can still have sharp edges at the region boundaries, but within each region the function is now curved.</p>
<p><strong>C. GeGLU</strong></p>
<p><strong>D. GeGLU</strong></p>
<p>The final change is to substitute the Gaussian error linear unit (GELU) for the ReLU. The definition isn't too important for our discussion, just that it looks like a smoother version of ReLU:</p>
<p><img alt="A plot of GELU(a) and ReLU(a). From a=-2 to a=2, GELU starts near zero, then drops before reaching z=0 when a=0 before climbing to approach z=a." class="img-fluid" src="img/gelu.png" /></p>
<p>The idea of GELU is that it will allow us to build smoother functions than ReLU. I think this is done primarily for the sake of optimisation during training, but it might be that this gives better shapes of function for inference too.</p>
<p>Plugging this into the gated linear unit, we have the full form of an GeGLU MLP as per the original code. It looks quite similar to the ReGLU, but you should be able to see that the transitions between regions are considerably smoother.</p>
<p><img alt="3D surface plot showing a smooth function of two input components." class="img-fluid" src="img/mlp_geglu.png" /></p>
<p>Our approach of considering distinct regions is broken down by the GELU, which doesn't saturate at exactly zero, so does not create strict regions where components of <code>z</code> can be discarded. However, since the GELU is quite similar to ReLU, it's still somewhat reasonable to think in terms of piecewise quadratic regions, at least at a coarse enough scale.</p>
<p><strong>Summary</strong></p>
<p>A final figure might help review the journey we've been on, from ReLU -&gt; ReGLU -&gt; GeGLU. To make things legible, we're now looking at a slice through the surfaces we've seen so far, setting <code>x[1]</code> to a constant value, and just looking at how <code>y[0]</code> depends on <code>x[0]</code>.</p>
<p><img alt="Three line plots, shown as x[0] varies from -2 to 2. The first, ReLU, is piecewise linear. The second, ReGLU, is piecewise quadratic with gradient discontinuities. The third, GeGLU, is smooth but still vaguely quadratic." class="img-fluid" src="img/mlp_slice.png" /></p>
<p>A final figure might help review the journey we've been on, from Linear -&gt; ReLU -&gt; ReGLU -&gt; GeGLU. To make things legible, we're now looking at a slice through the surfaces we've seen so far, setting <code>x[1]</code> to a constant value, and just looking at how <code>y[0]</code> depends on <code>x[0]</code>.</p>
<p><img alt="Four line plots, shown as x[0] varies from -2 to 2. The first, linear is linear. The second, ReLU, is piecewise linear. The third, ReGLU, is piecewise quadratic with gradient discontinuities. The fourth, GeGLU, is smooth but still vaguely quadratic." class="img-fluid" src="img/mlp_slice.png" /></p>
<p>So Gemma's MLP, the GeGLU, can be thought of as a piecewise-quadratic function with smooth boundaries between the pieces. Where our example had 6 regions across a 2-vector input, Gemma's MLPs may have a vast number of regions (perhaps $10^{2000}$) across their 2048-vector input.</p>
<p>The purpose of the MLP in Gemma is to use this function to independently transform each token, ready to form another attention query or ready to match against output tokens. Although MLPs cannot fuse information from across the context by themselves (which is the fundamental task of a language model), our experience shows that including the MLP makes attention much more efficient at doing exactly this.</p>
</details>
Expand Down
Binary file added 2024-04-transformers/img/mlp_linear.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified 2024-04-transformers/img/mlp_slice.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit d085a9b

Please sign in to comment.