The Essence of Global Convolution Models

There have been exciting recent developments in AI models for sub-quadratic long-context modeling involving the use of global convolutions, ranging from Structured State Space Models (S4) (Gu et al., 2022) and Gated State Spaces (GSS) (Mehta et al., 2023) to H3 (Fu et al., 2023) and, most recently, Hyena models (Poli et al., 2023; Nguyen et al., 2023). In this blog post, we will attempt to provide a unified perspective on these models and show how they are related to each other. We will also provide a background on necessary topics, including convolution, the Fast Fourier Transform, state space models, and attention.

Prelude

Throughout this blog post, the common goal is to understand how various models define context operators. In the single feature dimension case, given a sequence \(\{ u_i \}_{i=0}^{L}, u_i \in \mathbb{R}\), we are interested in effective ways to build a contextualized representation \(y_j \in \mathbb{R}\) that captures the context \(\{u_i\}_{i=0}^j\), which is the input up to time \(j\). A good context representation \(y_j\) can have significant implications. For instance, it can allow us to predict the next time step more accurately, which is useful for various tasks such as forecasting or language modeling.

Context Operator.

Attention (Vaswani et al., 2017), for instance, is a popular context operator where \(y_j\) is a weighted sum of the input sequence \(\{u_i\}_{i=0}^j\), based on an input-dependent weight \(A_{ij}\) computed from a softmax over the dot product between each token and every other token in the past. In vector form, we can also view attention as \(\vec{y} = A \vec{u}\). In general, a context operator can be of any form \(\vec{y} = f(\vec{u})\) where \(f\) is a function. In this blog post, we will see that a global convolution operator can be an effective choice of \(f\): it builds a contextual representation in a global way by aggregating information from all input positions. This is in contrast to short or local convolutions, popular in convolutional neural networks, where the kernel has a fixed short length. The global convolution kernel can also depend on the data itself, giving rise to implicit convolution kernels.

Notation

We use the following notations throughout the blog post:

Background

The goal of this blog post is to be self-contained and accessible to a broad audience. We will cover the background necessary to deeply understand the related convolution models and how they compare and connect with the transformer’s attention. For instance, we cover the core idea and derivation of the convolution theorem and the Fast Fourier Transform (FFT), which is at the heart of being able to do long-range modeling with log-linear computation. We provide a background section on orthonormal bases in function space, which will be important for understanding the construction of the HiPPO matrix for the continuous-time input memorization problem with state space models. We will also cover Einstein summation, as it provides a convenient way to tie different operations together under general (linear) tensor operations and is used throughout the blog post. We will also briefly cover the transformer’s attention and its linear version, which provides a connection to RNNs and the recurrent nature of state space models.

If you are already familiar with certain topics, feel free to skip them. Below, we outline the recommended background for each of the convolution models so that one can decide what to focus on.

A. Convolution Theorem and Fast Fourier Transform

In this section, we will cover the following topics:

What is a Convolution?

Let us consider two \(N\)-dimensional vectors \(\mathbf{a}, \mathbf{b} \in \mathbb{R}^N\). The convolution of the two vectors, denoted as \(\mathbf{a} * \mathbf{b}\), is defined as:

\[\begin{align*} c_m = (\mathbf{a} * \mathbf{b})_m &= \sum_{n=0}^{N-1} a_n b_{m-n} \\ \end{align*}\]

In this notation, we implicitly assume that \(a_n\) or \(b_n\) where \(n \not \in \{0, \dots, N-1\}\) is undefined and is treated as \(0\).

High-Level Intuition

We can see convolution as a way to combine two signals. The first signal is the convolution kernel \(\mathbf{a}\) and the second signal is the input signal \(\mathbf{b}\) (or vice versa, since the operation is commutative). Convolution combines the signals such that the output \(c_m\) gathers contributions from all input pairs \(\{a_n, b_{m-n}\}\) whose indices add up to \(m\) exactly.

To make it more concrete, let’s expand this out with \(N=4\).

\[\begin{align*} c_0 &= a_0 b_0 \\ c_1 &= a_0 b_1 + a_1 b_0 \\ c_2 &= a_0 b_2 + a_1 b_1 + a_2 b_0 \\ c_3 &= a_0 b_3 + a_1 b_2 + a_2 b_1 + a_3 b_0 \\ c_4 &= a_1 b_3 + a_2 b_2 + a_3 b_1 \\ c_5 &= a_2 b_3 + a_3 b_2 \\ c_6 &= a_3 b_3 \\ \end{align*}\]

where we note that \(a_j\) or \(b_j\) are treated as \(0\) for \(j \not \in \{0, \dots, N-1 \}\). We can see that, for instance, \(c_2\) is the sum of \(a_0 b_2\), \(a_1 b_1\), and \(a_2 b_0\), where the indices add up to \(2\).

Convolution as a way to aggregate information from inputs.
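As a quick sanity check, here is a minimal numpy snippet (with arbitrary example values for \(\mathbf{a}\) and \(\mathbf{b}\)) verifying that the expansion above matches `np.convolve`, which computes the full linear convolution of length \(2N-1\).

```python
import numpy as np

# Arbitrary example vectors a_0..a_3 and b_0..b_3.
a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([5.0, 6.0, 7.0, 8.0])

# Full linear convolution: outputs c_0, ..., c_6 (length 2N - 1 = 7).
c = np.convolve(a, b)

# Same result computed directly from the definition c_m = sum_n a_n b_{m-n}.
N = len(a)
c_manual = np.zeros(2 * N - 1)
for m in range(2 * N - 1):
    for n in range(N):
        if 0 <= m - n < N:          # b_{m-n} is treated as 0 outside its support
            c_manual[m] += a[n] * b[m - n]

assert np.allclose(c, c_manual)
print(c)   # [ 5. 16. 34. 60. 61. 52. 32.]
```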

We can also write convolution as a matrix multiplication:

\[\mathbf{c} = \begin{pmatrix} c_0 \\ c_1 \\ c_2 \\ c_3 \\ c_4 \\ c_5 \\ c_6 \\ \end{pmatrix} = \begin{pmatrix} a_0 & 0 & 0 & 0 \\ a_1 & a_0 & 0 & 0 \\ a_2 & a_1 & a_0 & 0 \\ a_3 & a_2 & a_1 & a_0 \\ 0 & a_3 & a_2 & a_1 \\ 0 & 0 & a_3 & a_2 \\ 0 & 0 & 0 & a_3 \\ \end{pmatrix} \begin{pmatrix} b_0 \\ b_1 \\ b_2 \\ b_3 \\ \end{pmatrix} \tag{Convolution}\]

or

\[\mathbf{c} = S_a \mathbf{b}\]

where \(S_a\) is the convolution matrix representation of vector \(\mathbf{a}\). We observe that \(S_a\) is Toeplitz, meaning that each of its diagonals has constant values. Since the convolution operator is commutative (\(\mathbf{a} * \mathbf{b} = \mathbf{b} * \mathbf{a}\)), we also have \(\mathbf{c} = S_b \mathbf{a}\), where \(S_b\) is the convolution matrix representation of vector \(\mathbf{b}\).

Convolution kernel and its matrix representation.

I included a Colab notebook to visualize the kernel and its matrix representation here.

Causal vs Non-Causal Convolution

Let’s take a look at the case where \(a_n\) has non-zero values for \(n < 0\). Again, let’s use \(N=4\) for simplicity.

\[\begin{align*} c_0 &= a_{-3} b_3 + a_{-2} b_2 + a_{-1} b_1 + \textcolor{red}{a_0 b_0} + a_1 b_{-1} + a_2 b_{-2} + a_3 b_{-3} \\ c_1 &= a_{-2} b_3 + a_{-1} b_2 + \textcolor{red}{a_0 b_1} + \textcolor{red}{a_1 b_0} + a_2 b_{-1} + a_3 b_{-2} \\ c_2 &= a_{-1} b_3 + \textcolor{red}{a_0 b_2} + \textcolor{red}{a_1 b_1} + \textcolor{red}{a_2 b_0} + a_3 b_{-1} \\ \vdots & \end{align*}\]

Observe that if \(a_{n} \ne 0\) for \(n < 0\), then, for example, \(c_2\) gets a contribution from \(a_{-1}b_{3}\), where \(b_3\) is an input signal from a future time step. If we would like to perform causal modeling, where the output at the current time step is influenced only by current and previous time steps (and not future steps), all \(a_n\) for \(n < 0\) must be \(0\).

In general, due to how we index the inputs where \(a_{n} = 0\) for \(n < 0\) in the original definition, this results in the convolution being causal, which means that the signal \(c_i\) can only depend on input at time \(i\) or before. This is because if there is a term \(b_{i+m}\) for \(m>0\) that contributes to \(c_i\), the corresponding term from \(\mathbf{a}\) is \(a_{-m}\) which is zero (which makes \(b_{i+m}a_{-m} = 0\)).

In addition, throughout this blog post, we are mainly interested in mapping an input signal \(\mathbf{b}\) to an output \(\mathbf{c}\) on the same time domain \(t = 0, \dots, T-1\), in which case we can use the truncated version where \(c_m\) is only defined for \(m \in \{0, \dots, T-1\}\), implying that the corresponding Toeplitz matrix is square and lower triangular.

\[\mathbf{c} = \begin{pmatrix} c_0 \\ c_1 \\ c_2 \\ c_3 \\ \end{pmatrix} = \begin{pmatrix} a_0 & 0 & 0 & 0 \\ a_1 & a_0 & 0 & 0 \\ a_2 & a_1 & a_0 & 0 \\ a_3 & a_2 & a_1 & a_0 \\ \end{pmatrix} \begin{pmatrix} b_0 \\ b_1 \\ b_2 \\ b_3 \\ \end{pmatrix}\]
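Here is a small numpy/scipy sketch (with the same arbitrary example vectors as before) of this truncated, causal form: the convolution matrix \(S_a\) is square and lower triangular, and multiplying it with \(\mathbf{b}\) reproduces the first \(T\) entries of the full convolution.

```python
import numpy as np
from scipy.linalg import toeplitz

a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([5.0, 6.0, 7.0, 8.0])

# Square, lower-triangular Toeplitz matrix: first column is a, first row is (a_0, 0, 0, 0).
S_a = toeplitz(a, np.r_[a[0], np.zeros(len(a) - 1)])
c = S_a @ b

# Matches the first T = N entries of the full convolution (the causal part).
assert np.allclose(c, np.convolve(a, b)[: len(b)])
print(S_a)
print(c)   # [ 5. 16. 34. 60.]
```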
Examples of Convolution

We look at a few cases of convolution to develop some intuition. First, let’s consider random variables \(X\) and \(Y\) corresponding to rolling two dice. The probability distribution of \(X\) (or \(Y\)) is:

\[\begin{align*} p_X(x) &= \begin{cases} \frac{1}{6} & x \in \{1, 2, 3, 4, 5, 6\} \\ 0 & \text{otherwise} \end{cases} \end{align*}\]

The probability distribution of \(Z = X + Y\) requires summing over all possible combinations of \(X=x\) and \(Y=y\) such that \(x+y = z\). That is, if \(X=x\), \(Y\) must be \(y=z-x\). Therefore, the probability distribution of \(Z\) is exactly the convolution between \(p_X\) and \(p_Y\):

\[\begin{align*} p_Z(z) &= \sum_{x=1}^6 p_X(x) p_Y(z-x) \\ &= p_X * p_Y \end{align*}\]

This is one example where convolution shows up quite naturally when we deal with probability. Below, we show an illustration of the convolution results, for both a fair die and an unfair one.

Convolution examples.
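The dice example translates directly into code. Below is a short numpy sketch; the unfair-die weights are made up for illustration and are not tied to the figure above.

```python
import numpy as np

# p_fair[k] is the probability of rolling k + 1 on a single fair die.
p_fair = np.ones(6) / 6.0

# Distribution of Z = X + Y is the convolution p_X * p_Y; entries correspond to sums 2..12.
p_sum = np.convolve(p_fair, p_fair)
print(np.round(p_sum, 4))            # peaks at 7 with probability 6/36

# A hypothetical unfair die that heavily favors 6.
p_unfair = np.array([0.05, 0.05, 0.10, 0.10, 0.20, 0.50])
print(np.round(np.convolve(p_fair, p_unfair), 4))
```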

Extras: Convolution of Functions

The convolution of two functions \(f\) and \(g\) is the continuous case of the discrete version and is defined as:

\[(f * g)(t) = \int_{-\infty}^{\infty} f(\tau) g(t - \tau) d\tau\]

Similar to the discrete case, if we have two random variables \(X\) and \(Y\) with probability density functions \(P_X\) and \(P_Y\) respectively, then the probability density of \(Z = X + Y\) entails integrating all possible \(X=x\) and \(Y=y\) such that \(x+y = z\). That is, if \(X=x\), \(Y\) must be \(y=z-x\). Therefore, the probability density of \(Z\) is:

\[\begin{align*} p_Z(z) &= \int_{-\infty}^\infty p_X(x) p_Y(z-x) dx \\ \end{align*}\]

which is the exactly the convolution \(p_X * p_Y\).

Another example entails a convolution of a function and a Dirac delta. The Dirac delta function is defined as:

\[\delta(t) = \begin{cases} \infty & t = 0 \\ 0 & t \neq 0 \end{cases}\]

and

\[\int_{-\infty}^{\infty} \delta(t) dt = 1\]

The Dirac delta function is a special function that is zero everywhere except at \(t=0\), where it is infinite. However, the integral of the function is \(1\), which means the Dirac delta function can be thought of as a probability distribution that is entirely concentrated at \(t=0\).

Next, let’s consider the convolution of the Dirac delta function with another function \(f\):

\[\begin{align*} (f * \delta)(t) &= \int_{-\infty}^{\infty} f(\tau) \delta(t - \tau) d\tau \\ &= f(t) \end{align*}\]

We can see that the convolution of \(f\) with the Dirac delta function is simply \(f\) itself; in other words, \(\delta\) acts as an identity element for convolution. This is because the Dirac delta function is zero everywhere except when its argument \(t-\tau = 0\). Therefore, the only contribution to the integral comes from \(\tau = t\), which yields \(f(t)\). We can see in this example that the Dirac delta as a convolution kernel performs an identity operation.

Now, let’s consider the convolution of \(f\) with a shifted Dirac delta function:

\[\begin{align*} (f * \delta(t - \tau))(t) &= \int_{-\infty}^{\infty} f(\tau') \delta(t - \tau - \tau') d\tau' \\ &= f(t - \tau) \end{align*}\]

The convolution of \(f\) with a shifted Dirac delta function is simply \(f\) shifted by \(\tau\)! This is because the shifted Dirac delta function is zero everywhere except at \(\tau'=t-\tau\). Therefore, the only contribution to the integral comes from \(\tau' = t - \tau\), which yields \(f(t-\tau)\), i.e., \(f\) shifted by \(\tau\).

Discrete Fourier Transform

The Discrete Fourier Transform (DFT) is a mathematical operation used to analyze the frequency components of a discrete signal. Given a discrete sequence of values \(x[n]\) for \(n = 0, 1, 2, \ldots, N-1\), the DFT computes a set of complex coefficients \(X[k]\) for \(k = 0, 1, 2, \ldots, N-1\) that represent the frequency content of the signal. The DFT is defined as:

\[X[k] = \sum_{n=0}^{N-1} x[n] \cdot e^{-i\frac{2\pi}{N}kn}\]

Here:

The original series can be recovered from \(X[k]\) via the inverse DFT (IDFT):

\[x[m] = \frac{1}{N} \sum_{k=0}^{N-1} X[k] \cdot e^{i\frac{2\pi}{N}km}\]

To show that \(x[m]\) can be recovered exactly as above, we will make use of the following result for geometric series. For \(r \in \mathbb{C}\) and \(r \neq 1\),

\[\sum_{k=0}^{N-1} r^k = \frac{1 - r^N}{1 - r}\]

In this case, \(r = e^{-i\frac{2\pi}{N}(n-m)}\) and \(r^N = e^{-i2\pi(n-m)}\). For \(n \ne m\), the exponent of \(r^N\) is a nonzero integer multiple of \(2\pi\) while \(r \neq 1\) (since \(0 < |n-m| < N\)), which means that \(r^N = 1\) and \(\sum_{k=0}^{N-1} r^k = \frac{1 - r^N}{1 - r} = 0\). For \(n = m\), \(r = 1\) and \(\sum_{k=0}^{N-1} r^k = \sum_{k=0}^{N-1} 1 = N\). Therefore, we can concisely write

\[\sum_{k=0}^{N-1} e^{-i\frac{2\pi}{N}k(n-m)} = N \cdot \delta_{n,m}\]

where \(\delta_{n,m}\) is the Kronecker delta function, which is \(1\) if \(n=m\) and \(0\) otherwise.

Then, the iDFT becomes:

\[\begin{align*} \frac{1}{N} \sum_{k=0}^{N-1} X[k] \cdot e^{i\frac{2\pi}{N}km} &= \frac{1}{N} \sum_{k=0}^{N-1} \left( \sum_{n=0}^{N-1} x[n] \cdot e^{-i\frac{2\pi}{N}kn} \right) \cdot e^{i\frac{2\pi}{N}km} \\ &= \frac{1}{N} \sum_{n=0}^{N-1} \left( x[n] \cdot \sum_{k=0}^{N-1} e^{-i\frac{2\pi}{N}k(n-m)} \right) \\ &= \frac{1}{N} \sum_{n=0}^{N-1} x[n] \cdot N \cdot \delta_{n,m} \\ &= x[m] \end{align*}\]

which means that we recover the original signal \(x[m]\) for \(m = 0, \dots, N-1\) perfectly from the DFT coefficients \(X[k]\). There is no information loss in the DFT operation.

We can think of this as a duality between the spatial domain and the frequency domain. The DFT converts a signal from the spatial domain (\(N\) numbers) to the frequency domain (\(N\) numbers), and the iDFT converts a signal from the frequency domain back to the spatial domain. The DFT and iDFT are inverse operations of each other.

In terms of computational complexity, both the DFT and the IDFT involve summing \(N\) numbers for each of the \(N\) entries, thus incurring \(O(N^2)\) complexity. Next, we will discuss a way to perform these operations efficiently with log-linear complexity (without any approximation!) by exploiting special structures of the complex exponentials. This is the main idea behind the Fast Fourier Transform (FFT).
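To make the \(O(N^2)\) cost and the lossless inversion concrete, here is a small numpy sketch of the DFT computed directly from its definition, checked against `np.fft.fft` and recovered with `np.fft.ifft`.

```python
import numpy as np

def dft_naive(x):
    """O(N^2) DFT straight from the definition X[k] = sum_n x[n] e^{-i 2π k n / N}."""
    N = len(x)
    n = np.arange(N)
    k = n.reshape(-1, 1)
    return np.exp(-2j * np.pi * k * n / N) @ x

x = np.random.randn(128)
X = dft_naive(x)

# Matches the FFT, and the inverse transform recovers x (no information loss).
assert np.allclose(X, np.fft.fft(x))
assert np.allclose(np.fft.ifft(X).real, x)
```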

Fast Fourier Transform

Fast Discrete Fourier Transform or more commonly called Fast Fourier Transform (FFT) is an algorithm that computes the DFT efficiently. The FFT algorithm is based on the divide-and-conquer strategy, and is able to compute the DFT in \(O(N \log N)\) time, which is much faster than the naive \(O(N^2)\) algorithm. We sketch a proof below.

Let’s start with the DFT definition:

\[X[k] = \sum_{n=0}^{N-1} x[n] \cdot e^{-i\frac{2\pi}{N}kn}\]

We can rewrite this by splitting into the odd and even terms as:

\[X[k] = \sum_{n=0}^{N/2-1} x[2n] \cdot e^{-i\frac{2\pi}{N}k(2n)} + \sum_{n=0}^{N/2-1} x[2n+1] \cdot e^{-i\frac{2\pi}{N}k(2n+1)}\]

We can further rewrite this as:

\[\begin{align*} X[k] &= \sum_{n=0}^{N/2-1} x[2n] \cdot e^{-i\frac{2\pi}{N/2}kn} + e^{-i\frac{2\pi}{N}k} \sum_{n=0}^{N/2-1} x[2n+1] \cdot e^{-i\frac{2\pi}{N/2}kn} \\ &= E_k + e^{-i\frac{2\pi}{N}k} O_k \end{align*}\]

We can see that the first term is the DFT of the even terms of \(x\), denoted by \(E_k\), and the second term is the DFT of the odd terms of \(x\), denoted by \(O_k\), multiplied by a complex exponential. The key part is that we also obtain \(X[k + \frac{N}{2}]\) for free once we have \(E_k\) and \(O_k\), due to the identity:

\[X[k + \frac{N}{2}] = E_k - e^{-i\frac{2\pi}{N}k} O_k\]

To obtain \(E_k\) (or \(O_k\)), we recursively break it up into two terms, and so on. Therefore, the computational cost of obtaining \(X[k]\) consists of the cost of two DFTs of \(N/2\) elements plus \(O(1)\) operations, amortized by two, since for each \(X[k]\) we also get \(X[k+ \frac{N}{2}]\). This gives us the following recurrence relation:

\[T(N) = \frac{1}{2} \left( 2 T(N/2) + O(1) \right)\]

which yields \(T(N) = O(\log N)\) per output. Summing over all \(N\) outputs \(k\), this results in an \(O(N \log N)\) algorithm for the DFT.
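The even/odd split above translates directly into a short recursive implementation. Below is a minimal radix-2 sketch (assuming the input length is a power of two), checked against `np.fft.fft`.

```python
import numpy as np

def fft_recursive(x):
    """Radix-2 Cooley-Tukey FFT following the even/odd split above.
    Assumes len(x) is a power of two."""
    N = len(x)
    if N == 1:
        return np.asarray(x, dtype=complex)
    E = fft_recursive(x[0::2])     # DFT of even-indexed terms, E_k
    O = fft_recursive(x[1::2])     # DFT of odd-indexed terms, O_k
    w = np.exp(-2j * np.pi * np.arange(N // 2) / N)
    # X[k] = E_k + w^k O_k and X[k + N/2] = E_k - w^k O_k, computed together.
    return np.concatenate([E + w * O, E - w * O])

x = np.random.randn(256)
assert np.allclose(fft_recursive(x), np.fft.fft(x))
```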

The Convolution Theorem

In this section, we will show that a convolution of two vectors can be seen as a multiplication of their Fourier transforms. This is known as the convolution theorem. We will show this in the discrete case. The continuous case extends naturally by replacing the summation with an integral (with exchange of order justified by Fubini’s theorem).

The convolution of two sequences \(\mathbf{a}\) and \(\mathbf{b}\) is defined as:

\[\begin{align*} \mathbf{c}_m = (\mathbf{a} * \mathbf{b})_m &= \sum_{n=0}^{N-1} a_n b_{m-n} \\ \end{align*}\]

The Fourier transform of a sequence \(\mathbf{a}\) is defined as:

\[\begin{align*} \mathbf{A}_k = \left(\mathcal{F}(\mathbf{a}) \right)_k &= \sum_{n=0}^{N-1} a_n e^{-i\frac{2\pi}{N}kn} \\ \end{align*}\]

Then,

\[\begin{align*} \left( \mathcal{F} (\mathbf{a} * \mathbf{b}) \right)_k &= \sum_{m=0}^{N-1} (\mathbf{a} * \mathbf{b})_m e^{-i\frac{2\pi}{N}km} \\ &= \sum_{m=0}^{N-1} \left( \sum_{n=0}^{N-1} a_n b_{m-n} \right) e^{-i\frac{2\pi}{N}km} \\ &= \sum_{n=0}^{N-1} \sum_{m=0}^{N-1} a_n e^{-i\frac{2\pi}{N}kn} b_{m-n} e^{-i\frac{2\pi}{N}k(m-n)} \\ &= \sum_{n=0}^{N-1} a_n e^{-i\frac{2\pi}{N}kn} \sum_{m=0}^{N-1} b_{m-n} e^{-i\frac{2\pi}{N}k(m-n)} \\ &= \sum_{n=0}^{N-1} a_n e^{-i\frac{2\pi}{N}kn} \sum_{s=-n}^{N-1-n} b_s e^{-i\frac{2\pi}{N}ks} \\ &= \left( \sum_{n=0}^{N-1} a_n e^{-i\frac{2\pi}{N}kn} \right) \left( \sum_{s=0}^{N-1} b_s e^{-i\frac{2\pi}{N}ks} \right)\\ &= \mathbf{A}_k \cdot \mathbf{B}_k \\ &= \mathcal{F}(\mathbf{a}) \cdot \mathcal{F}(\mathbf{b}) \end{align*}\]

where we change variables \(s = m-n\) and use the fact that \(b_s = 0\) for \(s \not\in \{0, \dots, N-1\}\), so the sum \(\sum_{s=-n}^{N-1-n} b_s \cdot \xi\) is the same as \(\sum_{s=0}^{N-1} b_s \cdot \xi\). The proof can also be done more simply by considering the summation from \(-\infty\) to \(\infty\), where values beyond the support are zero.

Extras: Continuous Case of Convolution Theorem

In the continuous case involving convolution of two functions \(f\) and \(g\), we also provide a proof sketch below.

The convolution of two functions is defined as:

\[(f * g)(t) = \int_{-\infty}^{\infty} f(\tau) g(t - \tau) d\tau\]

Then, the Fourier transform of the convolution is:

\[\begin{align*} \mathcal{F}[f * g](\omega) &= \int_{-\infty}^{\infty} (f * g)(t) e^{-i \omega t} dt \\ &= \int_{-\infty}^{\infty} \left( \int_{-\infty}^{\infty} f(\tau) g(t - \tau) d\tau \right) e^{-i \omega t} dt \\ &= \int_{-\infty}^{\infty} f(\tau) \left( \int_{-\infty}^{\infty} g(t - \tau) e^{-i \omega t} dt \right) d\tau \\ &= \int_{-\infty}^{\infty} f(\tau) \left( \int_{-\infty}^{\infty} g(t) e^{-i \omega (t + \tau)} dt \right) d\tau \\ &= \left( \int_{-\infty}^{\infty} f(\tau) e^{-i \omega \tau} d\tau \right) \cdot \left( \int_{-\infty}^{\infty} g(t) e^{-i \omega t} dt \right) \\ &= \mathcal{F}[f](\omega) \cdot \mathcal{F}[g](\omega) \end{align*}\]

Or in other words,

\[(f*g)(t) = \mathcal{F}^{-1}\left[\mathcal{F}[f](\omega) \cdot \mathcal{F}[g](\omega)\right](t)\]

Convolution Theorem + FFT = Log Linear Convolution

To recap, the implication of the convolution theorem is that if we want to perform a convolution of long signals \(f(t)\) and \(g(t)\), which naively would incur \(O(N^2)\) computational complexity, we can reduce it to \(O(N \log N)\) without any approximation. The steps are: (1) zero-pad both signals to the length of the full convolution output, (2) compute the FFT of each signal, (3) multiply the two transforms elementwise, and (4) apply the inverse FFT to obtain the convolution.
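Here is a minimal numpy sketch of these steps, checked against the direct \(O(N^2)\) convolution.

```python
import numpy as np

def fft_conv(a, b):
    """Log-linear convolution via the convolution theorem:
    zero-pad, FFT both signals, multiply pointwise, inverse FFT."""
    n = len(a) + len(b) - 1              # length of the full linear convolution
    n_fft = 1 << (n - 1).bit_length()    # next power of two, for a fast transform
    A = np.fft.fft(a, n_fft)             # steps 1-2: zero-pad and transform a
    B = np.fft.fft(b, n_fft)             # steps 1-2: zero-pad and transform b
    return np.fft.ifft(A * B).real[:n]   # steps 3-4: multiply, invert, drop padding

a, b = np.random.randn(1000), np.random.randn(1000)
assert np.allclose(fft_conv(a, b), np.convolve(a, b))
```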

B. Orthonormal Basis in Function Space

We turn our attention to a vector space whose elements are functions. A special kind of such vector space we will consider is a Hilbert space \(\mathcal{H}\) that is equipped with a countable and dense orthonormal basis \(\{ g_n \}_{n=1}^\infty\). In essence, a Hilbert space provides a notion of projections and distance via its inner product \(\langle f, h \rangle\) and the induced norm \(\| f \| = \sqrt{\langle f, f \rangle}\). Further, a Hilbert space has a completeness property, meaning that any sequence of elements in \(\mathcal{H}\) that draws progressively closer together actually converges within the space. An example of such a Hilbert space is the space of square-integrable functions on an interval, \(L^2[a,b]\), with a corresponding inner product

\[\langle f, g \rangle = \int_a^b f(x)^* g(x) \cdot w(x) dx\]

where \(f(x)^*\) denotes the complex conjugate of \(f\). The function \(w(x)\) corresponds to the weight of different points in the interval, which reflects the measure that the inner product is defined on. For instance, we may want to weigh points far away with smaller weights. A different weight \(w(x)\) used in the integral above (i.e., a different measure) will give rise to a different inner product, which in turn defines a different notion of orthogonality and distance in the associated Hilbert space.

The power of the dense orthonormal basis \(\mathcal{G}\) is that it allows us to represent any function in the Hilbert space using a linear combination of the basis elements. That is, for any function \(u(t) \in \mathcal{H}\), there exist coefficients \(\{ c_n \}_{n=1}^\infty\) such that:

\[u(t) = \sum_{n=1}^{\infty} c_n g_n(t)\]

That is, in the orthonormal basis \(\mathcal{G}\), a function \(u\) can be represented as simply an infinite sequence \(\{c_n\}_{n=1}^\infty\). It is quite profound that an entire function whose domain consists of uncountably many numbers can be described by a countable set of numbers, the coefficients with respect to the basis.

We can also think of the partial sum \(u_m(t) = \sum_{n=1}^{m} c_n g_n(t)\) as an approximation of the function \(u(t)\), where the approximation gets better as \(m \to \infty\) (where the convergence is in the norm). Therefore, a finite vector \((c_1, c_2, \dots, c_m)\) can also be used to approximately represent an entire function, where the approximation gets better as \(m\) grows. This concept is used widely for compression of signals such as audio, images, and videos.

While this representation ensures approximation, it does not directly offer a method to find \(c_n\) for a given \(u(t)\). The key lies in the orthogonality of the basis. The set \(\{ g_n \}\) is orthonormal, implying \(\langle g_m, g_n \rangle = \delta_{m,n}\), where \(\delta_{m,n}\) is the Kronecker delta, a function that returns 1 when \(m = n\) and 0 otherwise. This orthogonality simplifies our task of finding coefficients to taking inner products, where we extract the coefficient of \(g_n\) via:

\[\begin{align*} \langle u(t), g_n(t) \rangle &= \langle \sum_{m=1}^{\infty} c_m g_m(t), g_n(t) \rangle \\ &= \sum_{m=1}^{\infty} c_m \langle g_m(t), g_n(t) \rangle \\ &= \sum_{m=1}^{\infty} c_m \delta_{m,n} \\ &= c_n \end{align*}\]

This filtering property, intrinsic to orthonormal bases, ensures that we isolate each coefficient efficiently.

Based on the weight function \(w(x)\) and the subspace of functions we operate on, the orthonormal basis can be different. For instance, for uniform weight \(w\) on an interval, the Legendre polynomials form an orthonormal basis. For periodic functions with uniform weight \(w\) on an interval, the Fourier series form an orthonormal basis. For an exponentially decaying weight \(w\), the Laguerre polynomials form an orthonormal basis. We will discuss these three examples below.

Example 1: Legendre Polynomials as Orthonormal Basis on Uniform Measure

The space of square-integrable functions over the interval \([-1, 1]\), denoted by \(L^2[-1,1]\), is an example of a Hilbert space. The inner product is defined as \(\langle f,g \rangle = \int f(x) \cdot g(x) dx\). In this space, the (suitably normalized) Legendre polynomials form a countable orthonormal basis.

The Legendre polynomials play a pivotal role across mathematics and physics. Originating from the studies of celestial mechanics by Adrien-Marie Legendre, these polynomials have since been embraced in various applications spanning from quantum mechanics to approximation theory.

Definition: The \(n^{th}\) Legendre polynomial, denoted \(P_n(x)\), is given by Rodrigues’ formula:

\[P_n(x) = \frac{1}{2^n n!} \frac{d^n}{dx^n} \left[ (x^2 - 1)^n \right]\]

More concretely,

\[\begin{align*} P_0(x) &= 1 \\ P_1(x) &= x \\ P_2(x) &= \frac{1}{2} (3x^2 - 1) \\ P_3(x) &= \frac{1}{2} (5x^3 - 3x) \\ &\ .. \end{align*}\]

These polynomials are orthogonal on the interval \([-1,1]\) with respect to the weight function \(w(x) = 1\). In other words:

\[\int_{-1}^{1} P_m(x) P_n(x) dx = \frac{2}{2n+1} \delta_{m,n}\]

where \(\delta_{m,n}\) is the Kronecker delta function.

The example above can be extended to a Hilbert space defined on \(L^2[a,b]\), where the orthonormal basis functions can be derived as follows.

We introduce a linear change of variables to adapt \([-1,1]\) to the interval \([a,b]\). By defining a new variable \(y\) such that \(y = \frac{2(x - a)}{b - a} - 1\), we transform the interval \([a,b]\) to \([-1,1]\). The Legendre polynomials on the interval \([a,b]\) are then expressed as \(P_n\left( \frac{2(x - a)}{b - a} - 1 \right)\). Defining \(\tilde{P}_n(x)\) as the normalized Legendre polynomial, and using \(dy = \frac{2}{b-a} dx\) together with \(\int_{-1}^{1} P_m(y) P_n(y) dy = \frac{2}{2n+1} \delta_{m,n}\), we have:

\[\tilde{P}_n(x) = \sqrt{\frac{2n+1}{b-a}} P_n\left( \frac{2(x - a)}{b - a} - 1 \right)\]

Example 2: Fourier Basis for Periodic Functions

On \(L^2[-L,L]\) with the inner product \(\langle f,g \rangle = \frac{1}{2L} \int_{-L}^{L} f(x)^* g(x) dx\), where \(^*\) denotes the complex conjugate, the Fourier basis is an orthonormal basis when applied to periodic functions.

The Fourier series representation of a periodic function \(f(x)\) over the interval \([-L,L]\) is given by:

\[f(x) = \sum_{n=-\infty}^{\infty} c_n e^{i\frac{2\pi n}{2L}x}\]

where \(c_n\) are the Fourier coefficients of the basis elements \(f_n = e^{i\frac{2\pi n}{2L}x}\), computed as:

\[c_n = \langle e^{i\frac{2\pi n}{2L}x}, f(x) \rangle = \frac{1}{2L} \int_{-L}^{L} e^{-i\frac{2\pi n}{2L}x} f(x) dx\]

We can extract out the coefficient \(c_n\) from the inner product with \(f_n\) because for distinct integers \(m\) and \(n\), the basis functions \(f_n\) and \(f_m\) are orthogonal on the interval \([-L,L]\). Their inner product is:

\[\begin{align*} \langle f_n , f_m \rangle &= \frac{1}{2L} \int_{-L}^{L} e^{- i\frac{2\pi m}{2L}x} e^{i\frac{2\pi n}{2L}x} dx \\ &= \frac{1}{2L} \int_{-L}^{L} e^{i\frac{2\pi (n-m)}{2L}x} dx \\ &= \frac{1}{2L} \left[ \frac{1}{i\frac{2\pi (n-m)}{2L}} e^{i\frac{2\pi (n-m)}{2L}x} \right]_{-L}^{L} \\ &= \frac{1}{2L} \frac{1}{i\pi (n-m)} \left[ e^{i\pi (n-m)} - e^{-i\pi (n-m)} \right] \\ &= \frac{1}{2L} \frac{1}{i\pi (n-m)} \left[ (-1)^{n-m} - (-1)^{n-m} \right] \\ &= 0. \end{align*}\]

For \(n=m\), we have \(\frac{1}{2L} \int_{-L}^{L} e^{- i\frac{2\pi m}{2L}x} e^{i\frac{2\pi m}{2L}x} dx = \frac{1}{2L} \int_{-L}^{L} 1 dx = 1\), which is indeed normalized.

How do we interpret complex coefficients?

The complex representation here is a convenient way to express the Fourier basis, but we can also express it in terms of cosine and sine functions, which gives us a nice physical interpretation: the function is decomposed into integer multiples of the base frequency (or simply, harmonics), with an arbitrary phase at each frequency controlled by the relative coefficients of the cosine and sine functions.

Expanding \(c_n = a_n + i b_n\) and \(e^{i \omega n x } = \cos(\omega n x) + i \sin(\omega n x)\) where \(\omega = \frac{2 \pi}{2L}\), we have

\[\begin{align*} f(x) &= \sum_{n=-\infty}^{\infty} c_n e^{i \omega n x} \\ &= \sum_{n=-\infty}^{\infty} (a_n + i b_n) (\cos(\omega n x) + i \sin(\omega n x)) \\ &= \sum_{n=-\infty}^{\infty} (a_n \cos(\omega n x) - b_n \sin(\omega n x)) + i \sum_{n=-\infty}^{\infty} (b_n \cos(\omega n x) + a_n \sin(\omega n x)) \end{align*}\]

Now, let’s simplify it by doing a sum over non-negative \(n\), from \(0\) to \(\infty\). We’ll start with the real part of the decomposition of \(f\), which is given by:

\[\begin{align*} \text{Re}[f(x)] &= \sum_{n=-\infty}^{\infty} (a_n \cos(\omega n x) - b_n \sin(\omega n x)) \\ &= a_0 + \sum_{n=1}^{\infty} (a_n \cos(\omega n x) - b_n \sin(\omega n x)) + \sum_{n=1}^{\infty} (a_{-n} \cos(-\omega n x) - b_{-n} \sin(- \omega n x)) \\ &= a_0 + \sum_{n=1}^{\infty} \left[ (a_n + a_{-n}) \cos(\omega n x) + (- b_n + b_{-n}) \sin(\omega n x) \right] \\ \end{align*}\]

The imaginary part of the decomposition of \(f\) is given by:

\[\begin{align*} \text{Im}[f(x)] &= \sum_{n=-\infty}^{\infty} (b_n \cos(\omega n x) + a_n \sin(\omega n x)) \\ &= b_0 + \sum_{n=1}^{\infty} (b_n \cos(\omega n x) + a_n \sin(\omega n x)) + \sum_{n=1}^{\infty} (b_{-n} \cos(-\omega n x) + a_{-n} \sin(- \omega n x)) \\ &= b_0 + \sum_{n=1}^{\infty} \left[ (b_n + b_{-n}) \cos(\omega n x) + (a_n - a_{-n}) \sin(\omega n x) \right] \\ \end{align*}\]

In short, the complex coefficients \(c_n\) are such that their real and imaginary parts are the coefficients of the cosine and sine functions, respectively.

Real-Valued Functions

If the function \(f(x)\) is real, then \(\text{Im}[f(x)] = 0\), which implies that \(b_0 = 0\), \(a_n = a_{-n}\), and \(b_n = -b_{-n}\). That is, for real \(f\), we do not need to compute the coefficients with respect to negative frequencies, due to the symmetry.

The simplified components for a real \(f\) becomes:

\[\begin{align*} f(x) &= a_0 + \sum_{n=1}^{\infty} \left[ (a_n + a_{-n}) \cos(\omega n x) + (- b_n + b_{-n}) \sin(\omega n x) \right] \\ &= a_0 + 2\sum_{n=1}^{\infty} \left[ a_n \cos(\omega n x) - b_n \sin(\omega n x) \right] \\ &= a_0 + 2 \sum_{n=1}^{\infty} \sqrt{a_n^2 + b_n^2} \cos(\omega n x + \phi_n) \\ \end{align*}\]

where \(\phi_n = \tan^{-1} \left( \frac{b_n}{a_n} \right)\). In this interpretation, any periodic function is a linear combination of cosine waves at integer multiples of the base frequency (harmonics), with an arbitrary phase at each frequency controlled by the relative coefficients of the cosine and sine functions. This is the physical interpretation for a real signal \(f\) that we alluded to earlier.

In practice, the complex representation is more convenient to use and has wide adoption. It is also a generalization of the sine and cosine representation that is applicable for both real and complex functions.

Example 3: Laguerre Polynomials as Orthonormal Basis on Exponential Decay Measure

In the context of Hilbert spaces, another example of an orthonormal basis is provided by the Laguerre polynomials. The space of functions on the interval \([0, \infty)\) that are square-integrable with respect to the weight \(w(x) = e^{-x}\) serves as an example where these polynomials form a countable orthonormal basis.

Definition: The \(n^{th}\) Laguerre polynomial, denoted as \(L_n(x)\), can be defined through Rodrigues’ formula:

\[L_n(x) = \frac{e^x}{n!} \frac{d^n}{dx^n} \left(e^{-x}x^n\right)\]

In more concrete terms, the first few Laguerre polynomials are as follows:

\[\begin{align*} L_0(x) &= 1 \\ L_1(x) &= 1 - x \\ L_2(x) &= \frac{1}{2}(x^2 - 4x + 2) \\ L_3(x) &= \frac{1}{6}(-x^3 + 9x^2 - 18x + 6) \\ &\vdots \end{align*}\]

These polynomials are orthogonal over the interval \([0, \infty)\) with respect to the weight function \(w(x) = e^{-x}\). In other words, they satisfy the following orthogonality condition:

\[\int_{0}^{\infty} L_m(x) L_n(x) e^{-x} dx = \delta_{m,n}\]

Here, \(\delta_{m,n}\) is the Kronecker delta function.

We can see that a different weight function, or a different measure, gives rise to a different set of orthonormal basis functions. This is a powerful concept that we will revisit later in the context of convolution models, especially in the construction of the HiPPO matrix.

Examples: Approximating Data with Legendre vs. Fourier Basis

We show a few examples here where we use the Legendre and Fourier bases to approximate functions with a finite number \(n\) of basis elements. See a Colab notebook here for the code and more examples.

Approximating data via Legendre and Fourier Basis

If we extrapolate the data outside the original domain, for Fourier series, the pattern repeats periodically. This can also be shown analytically by observing that for any \(x \in [-L,L]\), we have \(f(x) = f(x + 2L)\) in the Fourier representation.

Extrapolation via Fourier Series
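For reference, below is a small numpy sketch (separate from the Colab notebook, with a made-up target function) of approximating a signal on \([-1, 1]\) with a handful of Legendre and Fourier basis elements.

```python
import numpy as np
from numpy.polynomial import legendre

x = np.linspace(-1, 1, 400)
f = np.exp(-x**2) * np.sin(4 * x) + 0.3 * x          # hypothetical target signal
n = 8                                                # number of Legendre basis elements

# Legendre approximation: least-squares fit of the first n Legendre polynomials.
f_leg = legendre.legval(x, legendre.legfit(x, f, deg=n - 1))

# Fourier approximation: project onto e^{i pi k x} for k = -4, ..., 4 (here L = 1).
ks = np.arange(-4, 5)
basis = np.exp(1j * np.pi * np.outer(ks, x))         # shape (len(ks), len(x))
c = 0.5 * (basis.conj() @ f) * (x[1] - x[0])         # approximate inner products
f_fourier = (c @ basis).real

print(np.mean((f - f_leg) ** 2), np.mean((f - f_fourier) ** 2))
```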

C. State Space Models

State space models are a class of models used to describe the time evolution of systems. They are widely used in many fields, including control theory, signal processing, and machine learning. In this section, we will give a brief introduction to state space models, and show how they can be used to model complex dynamics.

In the continuous time scenario, linear state space models can be written as matrix-vector multiplications where the \(N\)-dimensional state vector \(\mathbf{x}(t)\) evolves over time \(t\) according to the following differential equation:

\[\begin{align*} \frac{d}{dt} \mathbf{x}(t) &= A \mathbf{x}(t) + B \mathbf{u}(t) \\ \mathbf{y}(t) &= C \mathbf{x}(t) + D \mathbf{u}(t) = C \mathbf{x}(t) \end{align*}\]

where \(u(t) \in \mathbb{R}\) is the input signal, \(A\) is an \(N \times N\) matrix, \(B\) is \(N \times 1\), and \(C\) is \(1 \times N\). \(u\) is usually called the input, and \(y\) is the output. In most cases, \(D\) is assumed to be \(0\).

Recurrent View of State Space Models

In the discretized case, the evolution goes from time step \(t=k-1\) to \(t=k\), instead of the infinitesimal step from \(t\) to \(t + dt\) in the continuous-time dynamics. We can approximate \(x_k\) using either the derivative at a single endpoint or the average of the derivatives at \(k\) and \(k-1\), which has better numerical stability (the bilinear/trapezoid method). With step size \(\Delta\),

\[x_{k} \approx x_{k-1} + \frac{\Delta}{2} (x'_{k} + x'_{k-1}) + O(\Delta^2)\]

Together with the state space equations (and approximating \(u_{k-1} \approx u_k\)), we can show that

\[\begin{align*} x_{k} &= (I - \frac{\Delta}{2} A)^{-1} (I + \frac{\Delta}{2} A) x_{k-1} + (I - \frac{\Delta}{2} A)^{-1} \Delta B \, u_k \\ & \ \text{or more succinctly}\\ x_{k} &= \bar{A} x_{k-1} + \bar{B} u_k \\ y_k &= C x_k \end{align*}\]

That is, we can obtain the current state \(x_k\) given the input \(u_k\) and only the past state \(x_{k-1}\), without needing to know the earlier states or inputs. This is a recurrent property that is useful during inference since it incurs \(O(1)\) compute per step, with no dependence on the context length. This is quite different from attention, where we need to use the cached keys and values to predict the next step, which incurs \(O(L)\) memory IO and compute during incremental decoding.

Discretized linear state space models in recurrent view and convolution view
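Below is a minimal numpy sketch of this recurrent view, assuming the bilinear discretization above; the function names are ours, and the demo system is arbitrary.

```python
import numpy as np

def discretize_bilinear(A, B, delta):
    """Bilinear (trapezoid) discretization matching the equations above:
    Abar = (I - d/2 A)^{-1} (I + d/2 A), Bbar = (I - d/2 A)^{-1} d B."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - delta / 2 * A)
    return inv @ (I + delta / 2 * A), inv @ (delta * B)

def ssm_step(A_bar, B_bar, C, x_prev, u_k):
    """One recurrent update: O(1) per step, independent of the context length."""
    x_k = A_bar @ x_prev + B_bar[:, 0] * u_k
    return x_k, (C @ x_k)[0]

# Tiny demo with an arbitrary stable 3-dimensional system and a random input.
A = np.array([[-1.0, 0.5, 0.0], [0.0, -2.0, 1.0], [0.0, 0.0, -0.5]])
B, C = np.ones((3, 1)), np.ones((1, 3))
A_bar, B_bar = discretize_bilinear(A, B, 0.1)

x = np.zeros(3)
for u_k in np.random.randn(5):
    x, y_k = ssm_step(A_bar, B_bar, C, x, u_k)
    print(y_k)
```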
Example: Harmonic Oscillator

Let’s develop some intuition for what state space models can do. We will consider a simple example of a spring attached to a mass \(m\). The spring has a spring constant \(k\), and the mass is attached to a wall. The mass is also subject to a force \(u(t)\).

The dynamics of the system is described by the following differential equation:

\[m y''(t) = -k y(t) + u(t)\]

where \(y(t)\) is the displacement from equilibrium of the mass at time \(t\). We know that in the case of \(u(t) = 0\), this should be a simple harmonic oscillator where the solutions are pure sine wave with certain phase (depending on the initial position and velocity). Let’s see how well we can model this system using a state space model.

The key is how we define the state \(\mathbf{x}\). Let \(v(t) = \frac{dy}{dt}\) denote the velocity. In this case, we can see that the differential equation above can be written as:

\[v'(t) = - \frac{k}{m} y(t) + \frac{1}{m} u(t)\]

If we define the state as

\[x = \begin{bmatrix} y \\ v \end{bmatrix}\]

then we can describe the differential equation with state space model as:

\[\mathbf{x}'(t) = \begin{bmatrix} y'(t) \\ v'(t) \end{bmatrix} = \begin{bmatrix} 0 & 1 \\ -\frac{k}{m} & 0 \\ \end{bmatrix} \begin{bmatrix} y(t) \\ v(t) \end{bmatrix} + \begin{bmatrix} 0 \\ \frac{1}{m} \\ \end{bmatrix} u(t)\]

The upper row simply says \(y'(t) = v(t)\), the definition of velocity. The lower row says \(v'(t) = -\frac{k}{m} y(t) + \frac{1}{m} u(t)\), exactly the equation for the acceleration. Then, we extract out the position by

\[y(t) = \begin{bmatrix} 1 & 0 \\ \end{bmatrix} \mathbf{x}(t) = C \mathbf{x}(t)\]

In this case, we use an initial condition such as

\[\mathbf{x}[0] = \begin{bmatrix} 1 \\ 0 \\ \end{bmatrix}\]

which corresponds to an initial position of \(1\) with zero velocity. We also use \(u(t) = 0\), meaning that no additional force is applied. We can see that the position calculated via the discretized state space follows the analytical sinusoidal solution almost perfectly. See the Colab notebook here. Note that we adapted the code given in The Annotated S4 for state space models.

State space model for spring-mass with no external force
State space model for spring-mass with external force at time 0
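As a self-contained sketch of the no-external-force experiment above (not the notebook's code), the spring-mass system can be simulated with the bilinear discretization from the recurrent-view section; the constants \(m = k = 1\) are hypothetical choices, giving a natural angular frequency of 1 rad/s.

```python
import numpy as np

m, k, delta = 1.0, 1.0, 0.01
A = np.array([[0.0, 1.0], [-k / m, 0.0]])
B = np.array([[0.0], [1.0 / m]])
C = np.array([[1.0, 0.0]])

# Bilinear discretization, as in the recurrent-view section.
I = np.eye(2)
inv = np.linalg.inv(I - delta / 2 * A)
A_bar, B_bar = inv @ (I + delta / 2 * A), inv @ (delta * B)

t = np.arange(delta, 20.0, delta)       # x_k approximates the state at time k * delta
x, y = np.array([1.0, 0.0]), []         # start at position 1 with zero velocity
for _ in t:
    x = A_bar @ x                       # u(t) = 0: no external force
    y.append(C[0] @ x)

# The simulated position closely tracks the analytical solution cos(t).
print(np.max(np.abs(np.array(y) - np.cos(t))))
```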

Just for fun, let’s also consider another case where we start from initial position and velocity \(0\), but with \(u[t]\) being something like a Dirac delta function, which injects a fixed momentum at time \(t=0\). In this case, \(u[0] = 1\) and \(u[t] = 0\) for all \(t > 0\). As shown below, the state space model is able to capture the dynamics quite nicely.

We also explore other variations, such as an exponentially decaying amplitude due to friction, and resonance, where the external force has the same frequency as the natural frequency. We also show cases where the numerical approximation breaks down, resulting in aliasing, once the frequency becomes too high compared to the granularity of the discretization.

Friction results in exponential decay of amplitude.
Impulse force at resonance frequency results in increasingly large oscillation.
High frequency compared to sampling frequency leads to aliasing (incorrect solution to the dynamical system).
Another aliasing example.

What Can We Model With Linear State Space Models?

In general, any linear differential equation can be framed as a linear state space model. This excludes some dynamics; for instance, celestial mechanics, where the gravitational force is proportional to the inverse square of the distance between two bodies, cannot be modeled with a linear state space.

Convolution View of State Space Models

Now let’s think about how we can process a vector input \(\mathbf{u} = \{ u_k \}_{k=0}^L\) to obtain an output \(\{y_k\}_{k=0}^L\) in batch. Due to the recurrence relations, we can write the output \(y_k\) as

\[\begin{align*} y_k &= \sum_{i=0}^{k} \bar{C} \bar{A}^{k-i} \bar{B} u_i + \bar{C} \bar{A}^k x_0 \\ y_k &= \sum_{i=0}^{k} \bar{K}_{k-i} u_i + 0 \\ \end{align*}\]

where \(\bar{K}_n = \bar{C} \bar{A}^n \bar{B}\) and \(x_0\) is the initial state, which is often taken to be zero. In this form, we see that \(y_k\) is a convolution between the L-dimensional vector \(\mathbf{\bar{K}}\) and the input vector \(\mathbf{u}\). It means that we can think of the state space model as a convolution model, where the convolution kernel is \(\mathbf{\bar{K}}\). This unrolled view is useful to process the entire input sequence, either during inference or training.

That is, if we have a sequence of length \(L\) as an input, once we have the kernel \(\mathbf{\bar{K}}\), we can compute the entire output \(\mathbf{y}\) in \(O(L \log L)\) time thanks to the convolution theorem and the FFT. This batched operation is important for training as well as for the initial processing of the input sequence during inference.
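The following numpy sketch materializes the kernel naively and checks that the FFT-based convolution view matches the recurrent view on an arbitrary stable discrete system (zero initial state).

```python
import numpy as np

def ssm_kernel(A_bar, B_bar, C, L):
    """Materialize the SSM convolution kernel Kbar_l = C Abar^l Bbar for l = 0, ..., L-1
    (naive O(N^2 L) loop; S4's contribution is computing this much more cleverly)."""
    K, x = [], B_bar[:, 0]
    for _ in range(L):
        K.append(C[0] @ x)
        x = A_bar @ x
    return np.array(K)

def causal_fft_conv(K, u):
    """Causal convolution y = Kbar * u in O(L log L) via the convolution theorem."""
    n_fft = 2 * len(u)                          # zero-pad to avoid circular wrap-around
    y = np.fft.ifft(np.fft.fft(K, n_fft) * np.fft.fft(u, n_fft)).real
    return y[: len(u)]

# Sanity check against the recurrent view on an arbitrary stable discrete system.
N, L = 4, 64
A_bar = 0.9 * np.linalg.qr(np.random.randn(N, N))[0]    # spectral radius 0.9: stable
B_bar, C = np.random.randn(N, 1), np.random.randn(1, N)
u = np.random.randn(L)

y_conv = causal_fft_conv(ssm_kernel(A_bar, B_bar, C, L), u)
x, y_rec = np.zeros(N), []
for u_k in u:
    x = A_bar @ x + B_bar[:, 0] * u_k
    y_rec.append(C[0] @ x)
assert np.allclose(y_conv, y_rec)
```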

SSM for Sequence Modeling

Now, you may wonder: what matrices \(A\), \(B\), \(C\) should we use? Remember from our earlier harmonic oscillator example that the matrix \(A\) controls the evolution of the dynamical system. How should we construct or interpret such a system to model something such as language or time series forecasting? In a later section, we discuss the HiPPO matrix, which defines a type of matrix \(A\) that can be used to model long-range dependencies.

D. Einstein Summation for General Tensor Operations

We cover a brief introduction to Einstein summation, which is a concise notation for tensor operations. We will use this notation to describe attention or other tensor operations in a concise manner.

Let’s consider the attention weight between query tensor \(Q\) and key tensor \(K\). The query tensor is of shape \(Q_{bhnk}\) where \(b\) is the batch index, \(h\) is the head index, \(n\) is the query length index, and \(k\) is the head dimension index. The key tensor is of shape \(K_{bhmk}\) where \(m\) is the key length index. The attention weight tensor is of shape \(A_{bhnm}\).

The attention weight tensor, before softmax, can be described in various ways such as

\[\begin{align*} A_{bhnm} &= \sum_{k} Q_{bhnk} K_{bhmk} \tag{explicit sum over $k$} \\ A_{bhnm} &= \sum Q_{bhnk} K_{bhmk} \tag{implicit sum over $k$} \\ A_{bhnm} &= Q_{bhnk} K_{bhmk} \tag{what Einstein would write} \\ A_{bhnm} &= \langle Q_{bhnk} , K_{bhmk} \rangle \tag{inner product notation} \\ A &= \langle Q, K \rangle \end{align*}\]

In all of the notations above, we are summing over the head dimension \(k\), which is the dimension that we are reducing. The sum over \(k\) need not be explicitly mentioned if the output dimension is specified. That is, since \(A_{bhnm}\) does not contain \(k\), it implies that the output is the result of reduction over \(k\). Since this is akin to inner product over \(k\) where all other axes are broadcasted, we can write it as \(\langle Q, K \rangle\) for simplicity.

Below are examples of various einsum operations. For more details on attention with einsum, see a separate blog post The Illustrated Attention via Einstein Summation.

Einsum operations
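As a concrete reference, here is the attention-weight contraction written with `np.einsum`, using the index convention above and arbitrary illustrative shapes.

```python
import numpy as np

# b = batch, h = heads, n = query length, m = key length, k/v = head dimensions.
b, h, n, m, k = 2, 4, 8, 8, 16
Q = np.random.randn(b, h, n, k)
K = np.random.randn(b, h, m, k)
V = np.random.randn(b, h, m, k)

A = np.einsum('bhnk,bhmk->bhnm', Q, K)     # reduce over k: A_{bhnm} = <Q, K>
O = np.einsum('bhnm,bhmv->bhnv', A, V)     # weighted sum of values over m
print(A.shape, O.shape)                    # (2, 4, 8, 8) (2, 4, 8, 16)
```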

E. Attention and Linear Attention

In this section, we will describe both attention (Vaswani et al., 2017) and linear attention (Katharopoulos et al., 2020). Let \(\sigma\) be a non-linear operator, where we may write \(\sigma_L\) to emphasize that the operation \(\sigma\) is non-linear over axis \(L\). We use the same notation as in the last section, where \(k,v\) are the head dimension indices for the key and value. We drop the batch dimension for simplicity.

With this notation, the attention operation (Vaswani et al., 2017) can be written as

\[\begin{align*} O_{hnv} &= \sum_{m=0}^n V_{hmv} \ \sigma_m \left( \sum_{k} Q_{hnk} K_{hmk} \right) \\ &= \sum_{m=0}^n V_{hmv} \ \sigma_m \left( A_{hmn} \right) \\ &= \sum_{m=0}^n V_{hmv} \ W_{hmn} \end{align*}\]

The causality of attention is reflected in the summation over \(m\), which is the key length. That is, the query at position \(n\) can only attend to the keys up to position \(n\). Note that we omit the multiplicative term \(\frac{1}{\sqrt{d}}\) for \(\sum_{k=1}^d Q_{hnk} K_{hmk}\) for simplicity. More details on attention can be found in The Illustrated Attention via Einstein Summation.

Linear Attention

If \(\sigma_m\) is an identity function, we can rearrange the computation so that \(K\) and \(V\) equivalently interact with each other first. That is,

\[\begin{align*} O_{hnv} &= \sum_{m=0}^n V_{hmv} \left( \sum_{k} Q_{hnk} K_{hmk} \right) \\ &= \sum_k Q_{hnk} \left( \sum_{m=0}^n V_{hmv} K_{hmk} \right) \end{align*}\]

This is a fully linear system given \(Q, K, V\). We can introduce some non-linearity back by replacing \(K\) and \(Q\) with \(K'=\phi(K)\) and \(Q' = \phi(Q)\), where \(\phi\) is a non-linear operator applied at each position along axis \(m\) or \(n\).

\[\begin{align*} O^{\text{Linear Attention}}_{hnv} &= \sum_{m=0}^n V_{hmv} \left( \sum_{k} Q'_{hnk} K'_{hmk} \right) \\ &= \sum_k Q'_{hnk} \left( \sum_{m=0}^n V_{hmv} K'_{hmk} \right) \end{align*}\]

Observe that \(\left( \sum_{m=0}^n V_{hmv} K'_{hmk} \right)\) reduces over dimension \(m\), the key/value length or context length. This means that no matter how long the sequence is, this term has the same dimensionality, and the evolution as context length increases is entirely via addition. Also, observe that

\[\begin{align*} S_{hkv}(n) &= \sum_{m=0}^{n} V_{hmv} K'_{hmk} \\ &= \sum_{m=0}^{n-1} V_{hmv} K'_{hmk} + V_{h(m=n)v} K'_{h(m=n)k} \\ &= S_{hkv}(n-1) + V_{h(m=n)v} K'_{h(m=n)k} \end{align*}\]

which means that each subsequent step is rolled into the state tensor \(S_{hkv}\), which is a recurrent property. Hence, linear attention can be seen as a recurrent operation, and it helps connect traditional attention with RNNs.

Note that in the linear attention paper (Katharopoulos et al., 2020), the authors motivate the linearization from the perspective of kernel methods. That is, we can view \(\sigma_m \left( \langle Q,K \rangle \right)\) as \(\text{Kernel}(Q,K)\), which is equivalent to \(\langle \phi(Q) , \phi(K) \rangle\) for some feature map \(\phi\). In the case of softmax, this corresponds to the exponential kernel, where \(\phi\) is an infinite-dimensional feature map. The authors then propose an alternative feature map \(\phi'\) that is finite-dimensional, for instance, the exponential linear unit \(\phi'(x) = \text{elu}(x) + 1\).

Attention and Linear Attention
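Below is a minimal numpy sketch of the recurrent form of causal linear attention using the \(\text{elu}(x)+1\) feature map mentioned above; for brevity, it omits the normalization term used in the paper, so it is illustrative rather than a faithful reimplementation.

```python
import numpy as np

def elu_plus_one(x):
    """Feature map phi'(x) = elu(x) + 1."""
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention_recurrent(Q, K, V):
    """Causal linear attention computed recurrently: the state S accumulates
    sum_m phi(K_m)^T V_m, so each step is O(1) in the sequence length.
    Shapes: Q, K are (L, d_k); V is (L, d_v). Normalization is omitted."""
    Qp, Kp = elu_plus_one(Q), elu_plus_one(K)
    S = np.zeros((Q.shape[1], V.shape[1]))     # running sum of outer(K'_m, V_m)
    outputs = []
    for n in range(Q.shape[0]):
        S = S + np.outer(Kp[n], V[n])          # roll the new token into the state
        outputs.append(Qp[n] @ S)              # O_n = Q'_n S_n
    return np.array(outputs)

Q, K, V = (np.random.randn(10, 8) for _ in range(3))
print(linear_attention_recurrent(Q, K, V).shape)   # (10, 8)
```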

Notes and Observations


Long Convolution Models

HiPPO Framework for History Representation via State Space

The HiPPO matrices (Gu et al., 2020; Voelker et al., 2019) are the \(A,B\) matrices associated with state space models, obtained for the input memorization problem where we want the state \(\vec{x}_{t}\) to capture \(u_{t' \le t}\), the entire input up to time \(t\). This can be very useful for long-range modeling, where we can build a contextualized representation whose output feature at a given time step represents the entire past. In this section, we will walk through how to approach this problem.

First, the state vector \(\vec{x}_{t}\) in state space models is \(N\) dimensional. We want such a vector at time \(t\) to represent the entire past of the input from time \(0\) to \(t\). What would be a good way to represent the entire history in an \(N\) dimensional vector, where the history can be arbitrarily long?

The key is to think about it in function space, where the input vector \((u_0, u_1, \dots, u_t)\) is interpreted as evenly spaced points from a continuous-time function \(u(t')\) defined up to \(t' = t\). Then, we choose the state \(\vec{x}\) to represent the first \(N\) coefficients of this function with respect to some orthonormal basis of the function space. Such a definition yields particular matrices \(A,B,C\).

HiPPO Framework -- Illustration from HiPPO paper.

Below, we outline the derivation steps of the HiPPO framework.

\[\frac{d}{dt} x(t) = A(t) x(t) + B(t) u(t),\]

where

\[\begin{align*} A(t) &= \frac{1}{t} A^{\text{HiPPO}} \\ B(t) &= \frac{1}{t} B^{\text{HiPPO}} \\ \end{align*}\]

and \(A^{\text{HiPPO}},B^{\text{HiPPO}}\) are the associated HiPPO matrices obtained from the derivation outlined above.

\[\begin{align*} A^{\text{HiPPO}}_{ij} &= - \begin{cases} (2i+1)^{1/2} (2j+1)^{1/2} & \text{if } \ i > j \\ i+1 & \text{if } \ i = j \\ 0 & \text{if } \ i < j \end{cases} \\ B^{\text{HiPPO}}_i &= (2i+1)^{1/2} \end{align*}\]
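For concreteness, here is a small numpy sketch constructing these matrices from the formulas above (the function name is ours).

```python
import numpy as np

def hippo_legs(N):
    """HiPPO matrices from the formulas above:
    A_ij = -sqrt(2i+1)sqrt(2j+1) if i > j, -(i+1) if i = j, 0 if i < j; B_i = sqrt(2i+1)."""
    i = np.arange(N)[:, None]
    j = np.arange(N)[None, :]
    A = -np.where(i > j,
                  np.sqrt(2 * i + 1) * np.sqrt(2 * j + 1),
                  np.where(i == j, i + 1, 0.0))
    B = np.sqrt(2 * np.arange(N) + 1.0)[:, None]
    return A, B

A, B = hippo_legs(4)
print(A)
print(B[:, 0])
```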

S4: Structured State Space

From the previous section, the HiPPO matrix provides a state space model definition that can map input signal \(u(t)\) to an output signal \(o(t)\) that captures the past history at each time step.

The next challenge to be addressed is how to compute the convolution kernel \(\mathbf{\bar{K}}\) for state space models efficiently. Once we have \(\mathbf{\bar{K}}\), computing the output from the input is fast via the convolution theorem and Fast Fourier Transform.

Recall that the convolution kernel for state space models can be written as:

\[\begin{align*} \mathbf{\bar{K}} &= \left( \bar{C} \bar{B},\ \bar{C} \bar{A} \bar{B},\ \dots,\ \bar{C} \bar{A}^{L-1} \bar{B} \right), \quad \text{i.e.,} \quad \bar{K}_\ell = \bar{C} \bar{A}^{\ell} \bar{B} \ \ \text{for } \ell = 0, \dots, L-1 \\ \end{align*}\]

where \(\bar{A}\) is the associated matrix for discrete state space dynamics of \(A\).

That is, the convolution kernel requires obtaining \(\bar{A}^\ell \bar{B}\) for all \(\ell = 0, \dots, L-1\). Since \(\bar{A}\) is an \(N \times N\) matrix and we need to apply it \(L\) times, the computational complexity with naive matrix-vector multiplication is \(O(N^2L)\). Here are the considerations outlined in the S4 paper.

Diagonal State Spaces

Interestingly, (Gupta et al., 2022) shows that using the normal part of the HiPPO matrix, without the low-rank correction, also works well in practice, with performance matching the original S4.

\[A = V D V^* - PQ^T = V ( D - (V^*P) (V^*Q)^* ) V^*\]
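To see why a diagonal (discretized) state matrix makes the kernel cheap, here is a rough numpy sketch with made-up stable eigenvalues: the kernel reduces to a Vandermonde-style product, so no matrix powers need to be materialized.

```python
import numpy as np

N, L = 16, 256
# Hypothetical stable eigenvalues of the discretized diagonal state matrix.
lam = np.random.uniform(0.5, 0.99, N) * np.exp(2j * np.pi * np.random.rand(N))
B_bar = np.random.randn(N) + 0j
C = np.random.randn(N) + 0j

# Kbar_l = sum_n C_n * lam_n^l * Bbar_n: a Vandermonde matrix-vector product, O(N L).
powers = lam[:, None] ** np.arange(L)[None, :]   # shape (N, L)
K = ((C * B_bar) @ powers).real                  # real part, as when eigenvalues come in conjugate pairs
print(K.shape)                                   # (256,)
```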

GSS: Gated State Space Models and Gated Attention Units

Gated State Space models (GSS) (Mehta et al., 2023) build on two lines of work. First is state space models, from which the paper adopts the diagonal state space parameterization.

Second is the gating mechanism as an alternative activation function, which has been shown to improve model performance and can be seen as a multiplicative residual connection, allowing gradients to flow back freely. The gating mechanism also allows using a weaker attention mechanism without quality degradation. We cover the literature on gating mechanisms in the Gating section.

In particular, GSS adopts Gated Attention Units (GAU) (Hua et al., 2022) with a diagonal state space model instead of the traditional \(L^2\) attention, together with the input-controlled gating mechanism proposed in GAU.

Below is a simplified version of the GSS model (without dimensionality reduction before the state space model). Given an input \(X\), the model performs linear projection over hidden dimension into \(U,V\) tensors.

\[U = \phi_u(X W_u), V = \phi_v(X W_v)\]

Then, the output is computed as

\[O = (U \odot \hat{V})W_o\]

where \(\hat{V}\) is the contextualized representation over spatial domain.

\[\hat{V} = \text{DSS}(V)\]

In contrast, an alternate contextualized representation is the attention mechanism where \(\hat{V} = \langle A, V \rangle\) and \(A = \text{softmax}(\langle Q, K \rangle + \text{bias})\).
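A heavily simplified numpy sketch of this block is shown below; the DSS layer is replaced by a stand-in causal global convolution with a single hypothetical decaying kernel shared across channels, which is not the paper's parameterization.

```python
import numpy as np

def causal_global_conv(V, kernel):
    """Stand-in for the DSS layer: a causal global convolution applied to each
    channel of V (shape (L, E)) with a single kernel of shape (L,)."""
    L = V.shape[0]
    Kf = np.fft.fft(kernel, 2 * L)[:, None]
    return np.fft.ifft(Kf * np.fft.fft(V, 2 * L, axis=0), axis=0).real[:L]

def gss_simplified(X, W_u, W_v, W_o, kernel):
    """Simplified GSS block as written above (no dimensionality reduction)."""
    gelu = lambda z: 0.5 * z * (1 + np.tanh(np.sqrt(2 / np.pi) * (z + 0.044715 * z**3)))
    U = gelu(X @ W_u)                    # phi_u(X W_u)
    V = gelu(X @ W_v)                    # phi_v(X W_v)
    V_hat = causal_global_conv(V, kernel)
    return (U * V_hat) @ W_o             # (U ⊙ V_hat) W_o

L, D, E = 32, 16, 64
X = np.random.randn(L, D)
W_u, W_v, W_o = np.random.randn(D, E), np.random.randn(D, E), np.random.randn(E, D)
kernel = np.exp(-0.1 * np.arange(L))     # hypothetical decaying global kernel
print(gss_simplified(X, W_u, W_v, W_o, kernel).shape)   # (32, 16)
```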

Below are some additional observations from the paper:

H3: Hungry Hungry HiPPOs

The goal of H3 is to address the gap between previous state space models (S4, S4D, etc.) and transformers. The paper aims to improve (1) the expressivity and (2) the computational efficiency of state space models.

Improved Expressivity

H3 draws inspiration from the attention mechanism in transformers and long range modeling with state space models.

Here, we offer two illustrations. The first one considers the case of single input, single output (SISO), where the interaction between the \(K,V,Q\) tensors is via elementwise multiplication.

Another illustration is the batched case, where we show how the input is split into multiple heads and how the interaction between the \(K,V,Q\) tensors is expressed via einsum. Then, the outputs from all heads are combined into the final output.

Improved Computational Efficiency

Hyena Hierarchy

The Hyena model departs from the state space literature, where the dynamics are explicitly defined by the matrices \(A,B,C\). Instead, Hyena uses a convolutional approach where the convolution kernels are constructed from a learned mapping over positional embeddings. This is considered an implicit convolution, where the kernel is parameterized by a function (here, of positions) rather than stored as explicit kernel weights, as opposed to an explicit convolution.
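Below is a rough numpy sketch of the idea of an implicit kernel: the kernel values are produced by a small (here random, untrained) MLP over sinusoidal positional features and modulated by a decay window, then used as an ordinary causal global convolution. This is illustrative only and not Hyena's exact parameterization.

```python
import numpy as np

L, P, H = 256, 8, 32
t = np.arange(L)[:, None] / L
# Sinusoidal positional features at a few frequencies (an assumed, simplified choice).
pos_feats = np.concatenate([np.sin(2 * np.pi * (2 ** p) * t) for p in range(P)], axis=1)

# "Learned" mapping from positions to kernel values (random weights stand in for training).
W1, W2 = np.random.randn(P, H), np.random.randn(H, 1)
kernel = (np.tanh(pos_feats @ W1) @ W2)[:, 0] * np.exp(-5.0 * t[:, 0])   # decay window

# The implicit kernel is then applied as a causal global convolution via FFT.
u = np.random.randn(L)
y = np.fft.ifft(np.fft.fft(kernel, 2 * L) * np.fft.fft(u, 2 * L)).real[:L]
print(y.shape)   # (256,)
```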

Illustrated Global Convolution Models

Illustration of Global Convolution Models

FAQs

References

  1. Efficiently Modeling Long Sequences with Structured State Spaces
    Albert Gu, Karan Goel, and Christopher Ré
    In The Tenth International Conference on Learning Representations, ICLR 2022, Virtual Event, April 25-29, 2022, 2022
  2. On the Parameterization and Initialization of Diagonal State Space Models
    Albert Gu, Karan Goel, Ankit Gupta, and 1 more author
    In NeurIPS, 2022
  3. Long Range Language Modeling via Gated State Spaces
    Harsh Mehta, Ankit Gupta, Ashok Cutkosky, and 1 more author
    In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023, 2023
  4. Hungry Hungry Hippos: Towards Language Modeling with State Space Models
    Daniel Y. Fu, Tri Dao, Khaled Kamal Saab, and 3 more authors
    In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023, 2023
  5. Hyena Hierarchy: Towards Larger Convolutional Language Models
    Michael Poli, Stefano Massaroli, Eric Nguyen, and 6 more authors
    In International Conference on Machine Learning, ICML 2023, 23-29 July 2023, Honolulu, Hawaii, USA, 2023
  6. HyenaDNA: Long-Range Genomic Sequence Modeling at Single Nucleotide Resolution
    Eric Nguyen, Michael Poli, Marjan Faizi, and 10 more authors
    CoRR, 2023
  7. Attention is All you Need
    Ashish Vaswani, Noam Shazeer, Niki Parmar, and 5 more authors
    In Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, December 4-9, 2017, Long Beach, CA, USA, 2017
  8. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
    Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and 1 more author
    In Proceedings of the 37th International Conference on Machine Learning, ICML 2020, 13-18 July 2020, Virtual Event, 2020
  9. HiPPO: Recurrent Memory with Optimal Polynomial Projections
    Albert Gu, Tri Dao, Stefano Ermon, and 2 more authors
    In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual, 2020
  10. Legendre Memory Units: Continuous-Time Representation in Recurrent Neural Networks
    Aaron Voelker, Ivana Kajic, and Chris Eliasmith
    In Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada, 2019
  11. How to Train your HIPPO: State Space Models with Generalized Orthogonal Basis Projections
    Albert Gu, Isys Johnson, Aman Timalsina, and 2 more authors
    In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023, 2023
  12. Long Range Arena : A Benchmark for Efficient Transformers
    Yi Tay, Mostafa Dehghani, Samira Abnar, and 7 more authors
    In 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021, 2021
  13. Diagonal State Spaces are as Effective as Structured State Spaces
    Ankit Gupta, Albert Gu, and Jonathan Berant
    In NeurIPS, 2022
  14. Transformer Quality in Linear Time
    Weizhe Hua, Zihang Dai, Hanxiao Liu, and 1 more author
    In International Conference on Machine Learning, ICML 2022, 17-23 July 2022, Baltimore, Maryland, USA, 2022
  15. Block-Recurrent Transformers
    DeLesley Hutchins, Imanol Schlag, Yuhuai Wu, and 2 more authors
    In NeurIPS, 2022
  16. Do Transformer Modifications Transfer Across Implementations and Applications?
    Sharan Narang, Hyung Won Chung, Yi Tay, and 13 more authors
    In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, EMNLP 2021, Virtual Event / Punta Cana, Dominican Republic, 7-11 November, 2021, 2021
  17. GLaM: Efficient Scaling of Language Models with Mixture-of-Experts
    Nan Du, Yanping Huang, Andrew M. Dai, and 24 more authors
    In International Conference on Machine Learning, ICML 2022, 17-23 July 2022, Baltimore, Maryland, USA, 2022
  18. LaMDA: Language Models for Dialog Applications
    Romal Thoppilan, Daniel De Freitas, Jamie Hall, and 54 more authors
    CoRR, 2022
  19. A Mathematical Framework for Transformer Circuits
    Nelson Elhage, Neel Nanda, Catherine Olsson, and 22 more authors
    Transformer Circuits Thread, 2021
    https://transformer-circuits.pub/2021/framework/index.html

Appendix

Gating

Gating has been used extensively in the deep learning literature.

Gated Linear Units

(Dauphin et al., 2017) introduced gated linear units, which process the layer input \(X\) via a projection \(X W + b\), modulated by element-wise multiplication with the input-dependent gates \(\sigma(X V + c)\), where \(\sigma\) denotes the sigmoid function. That is,

\[\begin{align*} h(X) = (X W + b) \odot \sigma(X V + c) \end{align*}\]
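A minimal numpy sketch of this unit (with random placeholder weights) is shown below.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def glu(X, W, b, V, c):
    """Gated linear unit h(X) = (X W + b) * sigmoid(X V + c), as defined above."""
    return (X @ W + b) * sigmoid(X @ V + c)

X = np.random.randn(4, 8)
W, V = np.random.randn(8, 16), np.random.randn(8, 16)
b, c = np.zeros(16), np.zeros(16)
print(glu(X, W, b, V, c).shape)   # (4, 16)
```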

Gated linear units can be seen as a multiplicative version of residual connection. That is, the gradient of the gated linear unit has a path \(\nabla X \odot \sigma(X)\) without coupling with the non linearity \(\sigma'(X)\), which can be arbitrarily low in some region of \(X\) and could potentially suppress the gradient.

\[\nabla [ X \odot \sigma(X)] = \nabla X \odot \sigma(X) + X \odot \sigma'(X) \odot \nabla X\]

Based on experiments in (Dauphin et al., 2017), gated linear units as activation functions have shown improvement over other activations such as ReLU, Tanh, or the Tanh-gating mechanism used in LSTM (Hochreiter & Schmidhuber, 1997; van den Oord et al., 2016).

GLU Variants

(Shazeer, 2020) shows that GLU variants such as SwiGLU are quite effective in transformers when used in place of ReLU or GELU between the two linear projections of the feedforward layer. SwiGLU is adopted in large-scale models such as PaLM (Chowdhery et al., 2023).
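
As a concrete illustration, a SwiGLU feedforward block can be sketched in PyTorch as below; the bias-free projections and the particular hidden width are common conventions and illustrative assumptions, not details from this post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """Feedforward block with a SwiGLU gate: (swish(X W1) * (X V)) W2."""
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)
        self.v = nn.Linear(d_model, d_hidden, bias=False)
        self.w2 = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.v(x))

x = torch.randn(2, 16, 512)
print(SwiGLUFFN(512, 1344)(x).shape)  # torch.Size([2, 16, 512])
```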

gMLP

(Liu et al., 2021) proposes gMLP for vision tasks; it uses a gating mechanism similar to GLU, but with a different formulation. gMLP enables cross-token interactions via a linear projection \(f(Z) = W Z + b\) coupled with gating as

\[s(Z) = Z \odot f(Z)\]

where \(W \in \mathbb{R}^{L \times L}\) controls cross-token interactions. The authors find that it is also effective to split \(Z\) into two parts \(Z_1, Z_2\) along the channel dimension and instead use

\[s(Z) = Z_1 \odot f(Z_2)\]
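
A rough PyTorch sketch of this split-channel spatial gating unit is given below; the layer normalization on \(Z_2\) and the near-identity initialization follow the description in (Liu et al., 2021), while the sizes and names are illustrative assumptions.

```python
import torch
import torch.nn as nn

class SpatialGatingUnit(nn.Module):
    """s(Z) = Z1 * f(Z2), where f mixes tokens: f(Z2) = W Z2 + b with W in R^{L x L}."""
    def __init__(self, d_ffn: int, seq_len: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn // 2)
        self.spatial = nn.Linear(seq_len, seq_len)  # cross-token projection
        nn.init.zeros_(self.spatial.weight)         # gate starts near identity
        nn.init.ones_(self.spatial.bias)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        z1, z2 = z.chunk(2, dim=-1)                 # split along the channel dimension
        z2 = self.norm(z2)
        z2 = self.spatial(z2.transpose(1, 2)).transpose(1, 2)  # mix across tokens
        return z1 * z2

z = torch.randn(2, 64, 256)                         # (batch, L, d_ffn)
print(SpatialGatingUnit(256, 64)(z).shape)          # torch.Size([2, 64, 128])
```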

The comparisons in (Liu et al., 2021) indicate that attention is not critical for vision tasks, and that the degradation on some NLP tasks can be compensated for by making gMLP larger. This is an interesting finding: attention may not be strictly necessary, and its absence can be compensated for by this form of gating together with scale.

Gated Attention Units

Gated attention units (Hua et al., 2022) can be described as

\[O = (U \odot \hat{V}) W_o, \quad \text{where } \hat{V} = A V, \quad U = \phi_u(X W_u), \quad V = \phi_v(X W_v), \quad U, V \in \mathbb{R}^{L \times E}\]

where \(A \in \mathbb{R}^{L \times L}\) describes token-token attention weights. This formulation allows contextualized gating via \(\hat{V}\), instead of gating by the same token's \(V\) as in an MLP-style gate.

The paper computes the matrix \(A\) as a query-key attention matrix:

\[A = \text{ReLU}^2(Q(Z) K(Z)^T) ; Z = \phi_z (X W_z)\]

The paper uses two GAU layers as a replacement for an MLP (or GLU) plus multi-head attention, with \(E = 2d\) where \(d\) is the hidden dimension, resulting in a comparable number of parameters in both setups. Consistent with (Liu et al., 2021), the paper finds that gating allows a simpler or weaker attention mechanism without quality degradation. It also admits a linear-attention variant, where \(\hat{V}\) is computed by summing \(K V\) first rather than \(Q K\) first.
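
The following simplified PyTorch sketch mirrors the equations above. For brevity, the query/key maps are plain linear layers and the attention matrix is normalized by a simple division by length, whereas (Hua et al., 2022) derives \(Q(Z)\) and \(K(Z)\) from cheap per-dimension transforms of a shared \(Z\) and uses additional scaling terms, so treat this as an approximation rather than the paper's exact layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedAttentionUnit(nn.Module):
    """O = (U * (A V)) W_o with A = relu(Q(Z) K(Z)^T)^2 (simplified)."""
    def __init__(self, d_model: int, e: int, s: int = 128):
        super().__init__()
        self.wu = nn.Linear(d_model, e)
        self.wv = nn.Linear(d_model, e)
        self.wz = nn.Linear(d_model, s)
        self.q = nn.Linear(s, s)  # simplification of the paper's per-dim scale/offset
        self.k = nn.Linear(s, s)
        self.wo = nn.Linear(e, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = F.silu(self.wu(x))                                     # U = phi_u(X W_u)
        v = F.silu(self.wv(x))                                     # V = phi_v(X W_v)
        z = F.silu(self.wz(x))                                     # Z = phi_z(X W_z)
        a = F.relu(self.q(z) @ self.k(z).transpose(-2, -1)) ** 2   # ReLU^2 attention
        a = a / x.shape[1]                                         # illustrative normalization
        return self.wo(u * (a @ v))                                # O = (U * V_hat) W_o

x = torch.randn(2, 64, 512)
print(GatedAttentionUnit(512, 1024)(x).shape)  # torch.Size([2, 64, 512]), with E = 2d
```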

More on Fourier

As illustrated in Fourier Basis, any periodic function can be written as a linear combination of sine and cosine, or more compactly, as a linear combination of complex exponentials.

However, a general function that is not periodic can also be expressed via the continuous-spectrum Fourier transform. That is, the frequency components need not be multiples of a base frequency (harmonics) but can span an entire continuous spectrum.

The Fourier transform \(\mathcal{F}\) of a function \(f\) is defined as:

\[\begin{align*} \mathcal{F}[f](\omega) &= \int_{-\infty}^{\infty} f(t) e^{-i \omega t} dt \\ &= \int_{-\infty}^{\infty} f(t) \cos(\omega t) dt - i \int_{-\infty}^{\infty} f(t) \sin(\omega t) dt \end{align*}\]

where \(\omega\) is the frequency. Note that the complex notation allows us to extract both the sine and cosine components at once. If the Fourier transform at a given frequency is real, that component belongs to the cosine wave; if it is purely imaginary, it belongs to the sine wave. A general complex value indicates the phase of the frequency component.
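
As a quick numerical sanity check of this real/imaginary interpretation (a sketch with arbitrary frequency, window, and grid choices), integrating a cosine over a symmetric window gives a purely real value, while a sine gives a purely imaginary one:

```python
import numpy as np

# Windowed Fourier transform via the trapezoidal rule; omega0, T, and the grid
# are arbitrary illustrative choices.
omega0, T = 3.0, 8.0
t = np.linspace(-T, T, 100_001)
dt = t[1] - t[0]

def windowed_ft(f_vals, omega):
    g = f_vals * np.exp(-1j * omega * t)
    return np.sum((g[:-1] + g[1:]) / 2) * dt

F_cos = windowed_ft(np.cos(omega0 * t), 2.0)
F_sin = windowed_ft(np.sin(omega0 * t), 2.0)
print(F_cos.imag, F_sin.real)  # both are numerically ~0
```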

Example: Fourier Transform of a Sine Wave

Let’s look at a simple example where \(f\) is a sine wave with frequency \(\omega_0\):

\[\begin{align*} f(t) &= \sin(\omega_0 t) \end{align*}\]

Then,

\[\begin{align*} \mathcal{F}[f](\omega) &= \int_{-\infty}^{\infty} \sin(\omega_0 t) e^{-i \omega t} dt \\ &= \int_{-\infty}^{\infty} \frac{e^{i \omega_0 t} - e^{-i \omega_0 t}}{2i} e^{-i \omega t} dt \\ &= \frac{1}{2i} \int_{-\infty}^{\infty} e^{i (\omega_0 - \omega) t} dt - \frac{1}{2i} \int_{-\infty}^{\infty} e^{i (\omega_0 + \omega) t} dt \\ &= \frac{2\pi}{2i} \left[ \delta(\omega - \omega_0) - \delta(\omega + \omega_0) \right] = \frac{\pi}{i} \left[ \delta(\omega - \omega_0) - \delta(\omega + \omega_0) \right] \end{align*}\]

where \(\delta\) is the Dirac delta function and we used the identity \(\int_{-\infty}^{\infty} e^{i (\omega_0 - \omega) t} dt = 2\pi \delta(\omega - \omega_0)\). We can see that the Fourier transform of a sine wave is a linear combination of two Dirac delta functions at \(\omega_0\) and \(-\omega_0\). The purely imaginary value means that the frequency component belongs to the sine wave (zero phase). In this case, the Fourier transform yields the same result as the Fourier series representation: we only need one frequency to represent a sine wave.

Example: Fourier Transform of a Truncated Sine Wave

Another example is where \(f\) is a truncated sine wave. That is, \(f\) is a sine wave for \(-T \leq t \leq T\) and zero otherwise. Then,

\[\begin{align*} \mathcal{F}[f](\omega) &= \int_{-T}^{T} \sin(\omega_0 t) e^{-i \omega t} dt \\ &= \int_{-T}^{T} \frac{e^{i \omega_0 t} - e^{-i \omega_0 t}}{2i} e^{-i \omega t} dt \\ &= \frac{1}{2i} \int_{-T}^{T} e^{i (\omega_0 - \omega) t} dt - \frac{1}{2i} \int_{-T}^{T} e^{i (\omega_0 + \omega) t} dt \\ &= \frac{1}{2i} \left[ \frac{e^{i (\omega_0 - \omega) T} - e^{-i (\omega_0 - \omega) T}}{i (\omega_0 - \omega)} - \frac{e^{i (\omega_0 + \omega) T} - e^{-i (\omega_0 + \omega) T}}{i (\omega_0 + \omega)} \right] \\ &= \frac{1}{2i} \left[ \frac{2i \sin((\omega_0 - \omega) T)}{i (\omega_0 - \omega)} - \frac{2i \sin((\omega_0 + \omega) T)}{i (\omega_0 + \omega)} \right] \\ &= -i \left[ \frac{\sin((\omega_0 - \omega) T)}{\omega_0 - \omega} - \frac{\sin((\omega_0 + \omega) T)}{\omega_0 + \omega} \right] \end{align*}\]

Here, due to truncation, the function is no longer periodic, which results in frequency components that spread across the spectrum. However, they are concentrated around \(\pm \omega_0\) and decay as \(\omega\) moves away from \(\pm \omega_0\).
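
We can also verify the closed form above numerically; the specific \(\omega_0\), \(T\), and integration grid below are arbitrary illustrative choices.

```python
import numpy as np

omega0, T = 2.0, 10.0
t = np.linspace(-T, T, 200_001)
dt = t[1] - t[0]

def ft_numeric(omega):
    # F(omega) = int_{-T}^{T} sin(omega0 t) exp(-i omega t) dt, via the trapezoidal rule
    f = np.sin(omega0 * t) * np.exp(-1j * omega * t)
    return np.sum((f[:-1] + f[1:]) / 2) * dt

def ft_closed(omega):
    return -1j * (np.sin((omega0 - omega) * T) / (omega0 - omega)
                  - np.sin((omega0 + omega) * T) / (omega0 + omega))

for omega in [0.5, 1.9, 2.5, 6.0]:
    assert np.isclose(ft_numeric(omega), ft_closed(omega), atol=1e-5)
```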

Example: Fourier Transform of a Box Function

Let’s look at an example of the Fourier transform of a box function:

\[\begin{align*} f(t) &= \begin{cases} 1 & |t| < L \\ 0 & \text{otherwise} \end{cases} \end{align*}\]

It can be shown that the Fourier transform of \(f\) is:

\[\begin{align*} \mathcal{F}[f](\omega) &= 2L \frac{\sin(\omega L)}{\omega L} = 2L \ \text{sinc} (\omega L) \end{align*}\]

A high-level interpretation of this Fourier transform is that a box function contains frequency components at arbitrarily high frequencies. However, the contribution of these high frequencies becomes small as \(\vert \omega \vert\) grows, since \(\sin\) is bounded and the \(\frac{1}{\omega}\) factor shrinks.
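
A tiny sketch (with arbitrary \(L\) and frequencies) of this \(1/\vert \omega \vert\) decay of the closed form:

```python
import numpy as np

L = 1.0
for omega in [1.0, 10.0, 100.0, 1000.0]:
    # |2 L sinc(omega L)| = |2 sin(omega L) / omega| is bounded by 2 / omega
    print(omega, abs(2 * np.sin(omega * L) / omega))
```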

Fun Facts

The Fourier transform is well-defined for any absolutely integrable function, which includes probability densities. The Fourier transform of a probability density (up to a sign convention in the exponent) is also known as its characteristic function, defined as \(\mathbb{E}_Y [e^{i t Y}]\). In fact, there is a one-to-one correspondence between a probability density and its Fourier transform. The characteristic function is often much easier to work with than the density itself; for example, one can prove the Central Limit Theorem by showing that the characteristic functions of the centered, scaled sums \(\frac{1}{\sqrt{N}} \sum_i (Y_i - \mu)\) converge to the characteristic function of a Gaussian distribution, which implies convergence to a Gaussian in distribution.
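
As an illustration (not a proof), here is a small Monte Carlo sketch, with arbitrary choices of distribution and sample sizes, showing the empirical characteristic function of a standardized sum approaching \(e^{-t^2/2}\), the characteristic function of a standard Gaussian:

```python
import numpy as np

rng = np.random.default_rng(0)
N, samples = 64, 100_000
Y = rng.uniform(-np.sqrt(3), np.sqrt(3), size=(samples, N))  # mean 0, variance 1
S = Y.sum(axis=1) / np.sqrt(N)                                # standardized sum

for t in [0.5, 1.0, 2.0]:
    empirical = np.mean(np.exp(1j * t * S))                   # E[exp(i t S)], estimated
    print(t, empirical, np.exp(-t ** 2 / 2))                  # close to exp(-t^2/2)
```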


Footnote

  1. We will focus on the single input and single output case for simplicity. 

  2. In general, the vectors can have different sizes, but for simplicity, we will assume that they have the same size. 

  3. See Wikipedia - Toeplitz matrix for more details on Toeplitz matrix. 

  4. In general, we only need a Dirac delta function to be relatively zero everywhere except at \(t=0\). This is the case for \(\lim_{L\to \infty} f_L(x) = L \ \text{sinc}(xL)\) which is not exactly zero outside of \(x=0\) but is relatively zero compared to at \(x = 0\). 

  5. A Hilbert space \(\mathcal{H}\) is a vector space equipped with an inner product \(\langle \cdot, \cdot \rangle\) that is complete with respect to the norm induced by the inner product, \(\lVert \cdot \rVert = \sqrt{ \langle \cdot, \cdot \rangle }\). Completeness means that any Cauchy sequence in \(\mathcal{H}\) converges to an element within \(\mathcal{H}\) itself. The convergence is with respect to the norm, which provides a notion of distance. To elaborate on completeness further, a Cauchy sequence is a sequence whose elements \(v_n, v_m\) get arbitrarily close to each other as \(n,m \to \infty\). Intuitively, a complete space has no unexpected “holes”: any sequence that is supposed to converge (a Cauchy sequence) actually converges to a point inside the space itself. 

  6. A Schauder basis is a countable basis that is dense in a complete normed space (a Banach space). To say that a basis is dense means that any element of the space can be arbitrarily closely approximated by a finite linear combination of basis elements. In other words, for every element in the space and any given small positive distance (\(\epsilon\)), there exists a finite sum of basis elements that is within that distance of the given element. This ensures that the basis “spans” the entire space in a limiting sense, even if any specific finite subset of the basis does not span the space.
     
    A crucial distinction to note is that a Schauder basis doesn’t need to be mutually orthogonal. This is because it’s typically defined in the context of a Banach space, where angles or orthogonality might not be relevant concepts. Yet, in a Hilbert space, it’s entirely possible to orthogonalize a Schauder basis using the Gram-Schmidt process. 

  7. The measure theoretic way is to define the inner product as \(\langle f, g \rangle = \int f(x)^* g(x) d\mu\) where \(\mu\) is the measure. In the case of \(L^2[a,b]\), the measure is \(d\mu = w(x) dx\) where \(w(x)\) is the weight function and is zero outside the interval \([a,b]\). 

  8. It turns out that such a representation is also unique, which implies an isomorphism between the function space and a space of sequences of numbers. More precisely, \(L^2[a,b]\) with an orthonormal Schauder basis is isomorphic to \(\ell^2\), the space of square-summable sequences. It is quite profound that we can uniquely represent a function that takes values at uncountably many points by a sequence of numbers, which is a countable object! The special structure of the orthonormal basis makes this possible. 

  9. The convergence here is in the norm defined by the inner product of the Hilbert space. The proof for convergence is out of scope for this post. 

  10. \(f\) is periodic on the interval \([-L,L]\) if \(f(-L) = f(L)\). In practice, any function can be made periodic via padding. For instance, if a function is defined over \([-L, L]\), we can pad it with zeros from \(-L-1\) to \(-L\) and from \(L\) to \(L+1\). This results in a function that is periodic over \([-L-1, L+1]\). 

  11. In the original linear attention paper, the authors keep the softmax denominator. However, we can view it as an identity function with respect to the length dimension \(m\), since the denominator involves \(\sum_m \exp(A_{hmn}) = D_{hn}\), which is independent of \(m\) and still allows the reordering of the summations \(\sum_m\) and \(\sum_k\). 

  12. In the case of state space models, we often think of \(u(t)\) as a function even though in the discretized version, it is a sequence. This is because we can think of \(u(t)\) as a function that is sampled at discrete time points. In other words, \(u(t)\) is a function that is defined for all \(t\), but we only observe it at discrete time points. 

  13. In general, an implicit convolution means that given an input \(x\), the convolution filter \(h_\theta(x)\) depends on \(x\) and also potentially some parameters \(\theta\). In the case of Hyena, the dependence on the data is not as strong – any input of the same length yields the same position embeddings, hence the same convolution kernels.