date | tags
---|---
2024-01-14 | paper
Albert Gu, Karan Goel, Christopher Ré
arXiv Preprint
Year: 2021
Note 1: I have changed the bizarre notation used in the paper to the one used in Mamba. More specifically, the authors use u and x to denote the input and hidden state vectors, respectively; in these notes they have been changed to x and h.
Note 2: these notes cover an intro to State Space Models only, not the S4 model itself. The goal of this review was to build enough context to understand the Mamba paper.
The authors of this paper introduce a new state-space model called Structured State Space (S4). This model is an evolution of previous L-SSMs (Linear State-Space Models) that removes their prohibitive memory and computational cost.
State space models are defined as follows:
$$ \begin{aligned} \mathbf{h}'(t) &= \mathbf{A}\mathbf{h}(t) + \mathbf{B}\mathbf{x}(t) \\ \mathbf{y}(t) &= \mathbf{C}\mathbf{h}(t) + \mathbf{D}\mathbf{x}(t) \end{aligned} $$
In the above formulas:
- $\mathbf{h}(\cdot)$ is called the "state vector", and it represents the hidden state of the system, $\mathbf{h}(t) \in \mathbb{R}^n$.
- $\mathbf{y}(\cdot)$ is called the "output vector", with $\mathbf{y}(t) \in \mathbb{R}^q$.
- $\mathbf{x}(\cdot)$ is called the "input vector", with $\mathbf{x}(t) \in \mathbb{R}^p$.
- $\mathbf{A}$ is the "state matrix", with $\dim[\mathbf{A}] = n \times n$.
- $\mathbf{B}$ is the "input matrix", with $\dim[\mathbf{B}] = n \times p$.
- $\mathbf{C}$ is the "output matrix", with $\dim[\mathbf{C}] = q \times n$.
- $\mathbf{D}$ is the "feedthrough matrix" (a short simulation sketch of these objects follows below).
SSMs are unstable due to the fact that they entail a recurrent computation, where the state vector is recurrently multiplied by the state matrix $\mathbf{A}$. To address this, the state matrix is initialized with the HiPPO matrix:
$$ \mathbf{A}_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} $$
This matrix employs Legendre polynomials in order to compress the history of the input series.
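A minimal NumPy sketch (my own, not from the paper's code) of constructing this matrix directly from the case definition above:

```python
import numpy as np

def hippo_matrix(n_state: int) -> np.ndarray:
    """Build the HiPPO matrix A of size n_state x n_state from the case definition."""
    A = np.zeros((n_state, n_state))
    for n in range(n_state):
        for k in range(n_state):
            if n > k:
                A[n, k] = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
            elif n == k:
                A[n, k] = n + 1
            # n < k: entry stays 0 (upper triangle)
    return -A  # note the leading minus sign in the definition

print(hippo_matrix(4))
```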
SSMs are defined over continuous domains. In neural networks, one is usually interested in discrete domains. Discretization can be done through different methods, such as the Euler method or, as the authors of the paper show, the bilinear method. The discrete approximations of the matrices are denoted with a bar:
$$ \begin{aligned} h_{k} &= \bar{\mathbf{A}} h_{k-1} + \bar{\mathbf{B}} x_k & \bar{\mathbf{A}} &= (\mathbf{I} - \Delta/2 \cdot \mathbf{A})^{-1}(\mathbf{I} + \Delta/2 \cdot \mathbf{A}) \\ y_k &= \bar{\mathbf{C}} h_k & \bar{\mathbf{B}} &= (\mathbf{I} - \Delta/2 \cdot \mathbf{A})^{-1} \Delta \mathbf{B} \\ && \bar{\mathbf{C}} &= \mathbf{C} \end{aligned} $$
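A short NumPy sketch of the bilinear discretization above (function and variable names are my own, not the paper's):

```python
import numpy as np

def discretize_bilinear(A, B, C, delta):
    """Bilinear (Tustin) discretization of a continuous SSM (A, B, C) with step size delta."""
    n = A.shape[0]
    I = np.eye(n)
    inv = np.linalg.inv(I - delta / 2 * A)
    A_bar = inv @ (I + delta / 2 * A)
    B_bar = inv @ (delta * B)
    C_bar = C  # C is unchanged by the discretization
    return A_bar, B_bar, C_bar
```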
Due to the linear nature of SSMs, the outputs can be computed in parallel by convolving a carefully crafted kernel with the input. For the first time steps (assuming $h_{-1} = 0$):
$$\begin{aligned}
h_0 = \bar{\mathbf{B}}x_0 \quad&\rightarrow&\quad y_0 = \bar{\mathbf{C}}h_0 = \bar{\mathbf{C}}\bar{\mathbf{B}}x_0 \\
h_1 = \bar{\mathbf{A}} h_0 + \bar{\mathbf{B}}x_1 = \bar{\mathbf{A}} \bar{\mathbf{B}}x_0 + \bar{\mathbf{B}}x_1 \quad&\rightarrow&\quad y_1 = \bar{\mathbf{C}}h_1 = \bar{\mathbf{C}}\bar{\mathbf{A}} \bar{\mathbf{B}}x_0 + \bar{\mathbf{C}}\bar{\mathbf{B}}x_1 \\
h_2 = \bar{\mathbf{A}} h_1 + \bar{\mathbf{B}}x_2 = \bar{\mathbf{A}}^2 \bar{\mathbf{B}}x_0 + \bar{\mathbf{A}} \bar{\mathbf{B}}x_1 + \bar{\mathbf{B}}x_2 \quad&\rightarrow&\quad y_2 = \bar{\mathbf{C}}h_2 = \bar{\mathbf{C}}\bar{\mathbf{A}}^2 \bar{\mathbf{B}}x_0 + \bar{\mathbf{C}}\bar{\mathbf{A}} \bar{\mathbf{B}}x_1 + \bar{\mathbf{C}}\bar{\mathbf{B}}x_2 \\
&\dots&
\end{aligned}$$
Generalizing the pattern above, we get the following expression.
$$ \begin{aligned} h_k = \sum_{i=0}^k\bar{\mathbf{A}}^{k-i} \bar{\mathbf{B}}x_i \quad&\rightarrow&\quad y_k = \bar{\mathbf{C}}h_k = \sum_{i=0}^k\bar{\mathbf{C}}\bar{\mathbf{A}}^{k-i} \bar{\mathbf{B}}x_i
\end{aligned} $$
This can be seen as a convolution operation, $y = \bar{\mathbf{K}} * x$, where the kernel is $\bar{\mathbf{K}} = (\bar{\mathbf{C}}\bar{\mathbf{B}}, \bar{\mathbf{C}}\bar{\mathbf{A}}\bar{\mathbf{B}}, \dots, \bar{\mathbf{C}}\bar{\mathbf{A}}^{L-1}\bar{\mathbf{B}})$ for a sequence of length $L$.
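To make the equivalence between the two modes concrete, here is a small NumPy sketch (my own illustration, single input/output channel, names are not from the paper) that computes the output both recurrently and via the kernel, and checks that they agree:

```python
import numpy as np

def run_recurrent(A_bar, B_bar, C_bar, x):
    """Sequential (RNN-like) mode: h_k = A_bar h_{k-1} + B_bar x_k, y_k = C_bar h_k."""
    n = A_bar.shape[0]
    h = np.zeros((n, 1))
    ys = []
    for x_k in x:
        h = A_bar @ h + B_bar * x_k
        ys.append((C_bar @ h).item())
    return np.array(ys)

def run_convolutional(A_bar, B_bar, C_bar, x):
    """Parallel (CNN-like) mode: y = K_bar * x with K_bar = (CB, CAB, ..., CA^{L-1}B)."""
    L = len(x)
    K = np.array([(C_bar @ np.linalg.matrix_power(A_bar, i) @ B_bar).item()
                  for i in range(L)])
    # Causal convolution: y_k = sum_{i<=k} K[k - i] * x[i]
    return np.convolve(x, K)[:L]

# Tiny example with random matrices (A_bar scaled to keep powers well-behaved)
rng = np.random.default_rng(0)
n = 4
A_bar = 0.5 * rng.standard_normal((n, n)) / np.sqrt(n)
B_bar = rng.standard_normal((n, 1))
C_bar = rng.standard_normal((1, n))
x = rng.standard_normal(16)

assert np.allclose(run_recurrent(A_bar, B_bar, C_bar, x),
                   run_convolutional(A_bar, B_bar, C_bar, x))
```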
- State Models are algorithms for sequential modelling that allow RNN-like computation (for inference) as well as CNN-like computation (for teacher-forcing training).
- State Models work by building a mapping $x(t)\in\mathbb{R}\rightarrow y(t)\in\mathbb{R}$ through an implicit hidden state $h(t)\in\mathbb{R}^n$.
- State Models are linear in the time dimension. They do not contain any non-linearity, which allows them to be run in parallel (convolutional mode).
- The state-transition matrix $A$ (as well as the input matrix $B$ and the output matrix $C$) is fixed and does not depend on the input, the hidden state, or the time step. It is commonly said that the model dynamics of SMs are linear time-invariant (LTI). In other words, the same computation is repeated over and over for each time step.
- SMs are applied independently to each channel, hence some mechanism to merge channels is required beyond the SMs (see the sketch below).
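As a brief illustration of the last point (my own sketch, using a single SSM shared across channels for simplicity), the recurrence runs on each channel separately and nothing exchanges information between channels, so a separate mixing layer is needed:

```python
import numpy as np

def ssm_channelwise(A_bar, B_bar, C_bar, X):
    """Apply one (shared) SSM independently to each channel of X with shape (L, channels)."""
    L, channels = X.shape
    n = A_bar.shape[0]
    Y = np.zeros_like(X)
    for c in range(channels):
        h = np.zeros((n, 1))
        for k in range(L):
            h = A_bar @ h + B_bar * X[k, c]
            Y[k, c] = (C_bar @ h).item()
    # Channels never interact above; mixing must happen elsewhere,
    # e.g. Y @ W_mix with a learned (channels x channels) matrix.
    return Y
```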
- Umar Jamil's video on "Mamba and S4 Explained" https://www.youtube.com/watch?v=8Q_tqwpTpVU
- Albert Gu's presentation (main author): https://www.youtube.com/watch?v=luCBXCErkCs
- Sasha Rush's blog on "The annotated S4" https://srush.github.io/annotated-s4