There have been exciting recent developments in AI models for sub-quadratic long-context modeling built on global convolutions, ranging from Structured State Space Models (S4) (Gu et al., 2022) and Gated State Spaces (GSS) (Mehta et al., 2023) to H3 (Fu et al., 2023) and, most recently, the Hyena models (Poli et al., 2023; Nguyen et al., 2023). In this blog post, we attempt to provide a unified perspective on these models and show how they relate to each other. We also provide background on the necessary topics, including convolution, the Fast Fourier Transform, state space models, and attention.
Prelude
Throughout this blog post, the common goal is to understand how various models define context operators. In the single-feature case, given an input sequence1, we are interested in effective ways to build a contextualized representation that captures the context, that is, the input up to the current time step. A good context representation can have significant implications. For instance, it can allow us to predict the next time step more accurately, which is useful for various tasks such as forecasting or language modeling.
Context Operator.
Attention (Vaswani et al., 2017), for instance, is a popular context operator where the output is a weighted sum of the input sequence, based on input-dependent weights computed from a softmax over dot products between each token and every other token in the past. In vector form, we can also view attention as . In general, a context operator can be of any form where is a function. In this blog post, we will see that a global convolution operator can be an effective context operator, which builds a contextual representation in a global way by aggregating information from all input positions. This is in contrast to short or local convolution, where the kernel has a fixed length, as is popular in convolutional neural networks. The global convolution kernel can also depend on the data itself, giving rise to implicit convolution kernels.
Notation
We use the following notations throughout the blog post:
We interchangeably think of different points along the sequence as time or spatial dimension. For instance, we can think of as the input at time or the input at position .
We use , , or sometimes to denote the sequence length.
represents either element-wise multiplication or Einstein summation, when appropriate.
denotes the convolution of two vectors and . In this post, we primarily deal with convolutions of equal-length vectors.
We either use boldface or to denote a vector.
Background
The goal of this blog post is to be self-contained and accessible to a broad audience. We cover the background necessary to deeply understand the related convolution models and how they compare and connect with the transformer's attention. For instance, we cover the core idea and derivation of the convolution theorem and the Fast Fourier Transform (FFT), which are at the heart of long-range modeling with log-linear computation. We provide a background section on orthonormal bases in function space, which is important for understanding the construction of the HiPPO matrix for the continuous-time input memorization problem with state space models. We also cover Einstein summation, as it provides a convenient way to tie different operations together under general (linear) tensor operations and is used throughout the blog post. Finally, we briefly cover the transformer's attention and its linear version, which provides a connection to RNNs and the recurrent nature of state space models.
If you are already familiar with certain topics, feel free to skip them. The outline below lists the recommended background for each of the convolution models so that you can decide what to focus on.
S4: requires background A and C and understanding of HiPPO framework
H3: requires background A, C and understanding of HiPPO matrix and S4
Hyena: requires background A. All background topics recommended.
A. Convolution Theorem and Fast Fourier Transform
In this section, we will cover the following topics:
The concept of convolution, including some illustrations and examples such as the probability interpretation: the density of the sum of two independent random variables is the convolution of their densities.
The Discrete Fourier Transform as a way to convert a time-domain signal into a frequency-domain signal, and vice versa.
The Fast Fourier Transform as a way to compute the Discrete Fourier Transform efficiently.
The convolution theorem, which states that convolution in the time domain is equivalent to multiplication in the frequency domain.
What is a Convolution?
Let us consider two N-dimensional vectors 2. The convolution of the two vectors, denoted as is defined as:
In this notation, we implicitly assume that any term whose index falls outside the valid range is undefined and treated as zero.
High-Level Intuition
We can see convolution as a way to combine two signals. The first signal is the convolution kernel and the second signal is the input signal (or vice versa, since the operation is commutative). Convolution combines the signals such that each output position gathers contributions from all input pairs whose indices add up to exactly that position.
To make it more concrete, let’s expand this out with .
where we note that out-of-range terms are treated as zero.
We can see, for instance, that each output entry is the sum of products of input pairs whose indices add up to that entry's index.
Convolution as a way to aggregate information from inputs.
We can also write convolution as a matrix multiplication:
or
where is the convolution matrix representation of vector .
We observe that it is Toeplitz3, meaning that each diagonal (running from the top-left toward the bottom-right) has constant values.
Since the convolution operator is commutative (), we also have where is the convolution matrix representation of vector .
Convolution kernel and its matrix representation.
I included a Colab notebook to visualize the kernel and its matrix representation here.
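To make the definitions concrete, here is a minimal NumPy sketch (assuming the causal, truncated setting used throughout this post, where out-of-range terms are zero and the output is kept at the input length) that checks the direct sum formula against the Toeplitz matrix form:

```python
import numpy as np

def conv_direct(h, x):
    """y[k] = sum_j h[j] * x[k - j], with out-of-range terms treated as zero."""
    N = len(x)
    y = np.zeros(N)
    for k in range(N):
        for j in range(k + 1):
            y[k] += h[j] * x[k - j]
    return y

def conv_matrix(h):
    """Lower-triangular Toeplitz matrix H such that H @ x equals the causal convolution h * x."""
    N = len(h)
    H = np.zeros((N, N))
    for i in range(N):
        for j in range(i + 1):
            H[i, j] = h[i - j]          # constant along each diagonal
    return H

h, x = np.random.randn(8), np.random.randn(8)
assert np.allclose(conv_direct(h, x), conv_matrix(h) @ x)
assert np.allclose(conv_direct(h, x), np.convolve(h, x)[:8])   # truncated to the input length
```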
Causal vs Non-Causal Convolution
Let’s take a look at the case where the kernel has non-zero values at negative indices. Again, let’s use for simplicity.
Observe that if for , then, for example, gets the contribution from where is an input signal from a future time step. If we would like to perform causal modeling where the output of the current time step is influenced only by current and previous time steps (and not future steps), all for must be .
In general, due to how we index the inputs where for in the original definition, this results in the convolution being causal, which means that the signal can only depend on input at time or before. This is because if there is a term for that contributes to , the corresponding term from is which is zero (which makes ).
In addition, throughout this blog post we are mainly interested in mapping an input signal to an output on the same time domain, in which case we can use the truncated version where outputs are only defined for the original time steps, implying that the corresponding Toeplitz matrix is square and lower triangular.
Examples of Convolution
We look at a few cases of convolution to develop some intuition. First, let's consider two random variables corresponding to rolling two dice. The probability distribution of each die is:
The probability distribution of requires summing over all possible combinations of and such that . That is, if , must be . Therefore, the probability distribution of is exactly the convolution between and :
This is one of the examples where convolution shows up quite naturally when we deal with probability. Below, we show an illustration of the convolution results, for both a fair die and an unfair one.
Convolution examples.
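As a quick sanity check (a small sketch, assuming fair six-sided dice), the distribution of the sum can be computed with a single call to `np.convolve`:

```python
import numpy as np

p_x = np.full(6, 1 / 6)          # P(X = 1), ..., P(X = 6) for a fair die
p_y = np.full(6, 1 / 6)          # same for the second die
p_sum = np.convolve(p_x, p_y)    # P(X + Y = 2), ..., P(X + Y = 12)
print(dict(zip(range(2, 13), np.round(p_sum, 4))))
assert np.isclose(p_sum.sum(), 1.0)
```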
Extras: Convolution of Functions
The convolution of two functions and is the continuous case of the discrete version and is defined as:
Similar to the discrete case, if we have two random variables and with probability density functions and respectively, then the probability density of entails integrating all possible and such that . That is, if , must be . Therefore, the probability density of is:
which is exactly the convolution of the two densities.
Another example entails a convolution of a function and a Dirac delta. The Dirac delta function is defined as:
and
The Dirac delta function is a special function that is zero everywhere4 except at a single point, where it is infinite. However, its integral is 1, which implies that the Dirac delta function can be thought of as a probability distribution that is entirely concentrated at that point.
Next, let’s consider the convolution of the Dirac delta function with another function :
We can see that the convolution of with the Dirac delta function is simply the function itself; that is, the Dirac delta acts as an identity element for convolution. This is because the Dirac delta function is zero everywhere except where its argument vanishes, so the only contribution to the integral comes from that point, which yields the function itself. We can see in this example that the Dirac delta as a convolution kernel performs an identity operation.
Now, let’s consider the convolution of with a shifted Dirac delta function:
The convolution of with a shifted Dirac delta function is simply shifted by ! This is because the shifted Dirac delta function is zero everywhere except at . Therefore, the only contribution to the integral comes from , which is simply shifted by .
Discrete Fourier Transform
The Discrete Fourier Transform (DFT) is a mathematical operation used to analyze the frequency components of a discrete signal. Given a discrete sequence of values for , the DFT computes a set of complex coefficients for that represent the frequency content of the signal. The DFT is defined as:
Here:
is the DFT coefficient at frequency .
is the input signal at time index .
is the total number of samples in the input signal.
The coefficients are complex and can be interpreted as the coefficients of the sine and cosine components (see more in Fourier Basis).
The original series can be recovered from via the inverse DFT (IDFT):
To show that the original signal can be recovered exactly as above, we will make use of the following result for geometric series. For and ,
In this case, and . For , the exponent is an integer multiple of , which means that and . For , and . Therefore, we can concisely write
where is the Kronecker delta function, which is if and otherwise.
Then, the iDFT becomes:
which means that we recover the original signal for perfectly from the DFT coefficients . There is no information loss in the DFT operation.
We can think of this as a duality between the spatial domain and the frequency domain. The DFT converts a signal from the spatial domain ( numbers) to the frequency domain ( numbers), and the iDFT converts a signal from the frequency domain back to the spatial domain. The DFT and iDFT are inverse operations of each other.
In terms of computational complexity, both the DFT and the iDFT involve summing numbers for each of the entries, thus incurring quadratic complexity. Next, we will discuss a way to perform these operations efficiently, with log-linear complexity (without any approximation!), by exploiting special structures of the complex exponentials. This is the main idea behind the Fast Fourier Transform (FFT).
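Before turning to the FFT, here is a minimal reference sketch of the quadratic-time DFT as a matrix-vector product, checked against `np.fft`:

```python
import numpy as np

def naive_dft(x):
    """O(N^2) DFT: X_k = sum_n x_n * exp(-2j * pi * k * n / N)."""
    N = len(x)
    n = np.arange(N)
    W = np.exp(-2j * np.pi * np.outer(n, n) / N)   # N x N DFT matrix
    return W @ x

x = np.random.randn(64)
X = naive_dft(x)
assert np.allclose(X, np.fft.fft(x))               # matches the library DFT
assert np.allclose(np.fft.ifft(X), x)              # the inverse DFT recovers x exactly
```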
Fast Fourier Transform
The fast discrete Fourier transform, more commonly called the Fast Fourier Transform (FFT), is an algorithm that computes the DFT efficiently. The FFT is based on a divide-and-conquer strategy and computes the DFT in log-linear time, which is much faster than the naive algorithm. We sketch a proof below.
Let’s start with the DFT definition:
We can rewrite this by splitting into the odd and even terms as:
We can further rewrite this as:
We can see that the first term is the DFT of the even terms of , denoted by , and the second term is the DFT of the odd terms of , denoted by , multiplied by a complex exponential. The key part is that we also obtain for free once we have and , due to the identity:
To obtain (or ), we recursively break it up into two terms, and so on. Therefore, the computational complexity of the DFT consists of the complexity of two half-length DFTs plus linear work, amortized by two, since for each , we also get . This gives us the following recurrence relation:
which yields . For all , this results in an algorithm for DFT.
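The recursion above translates directly into code. Below is a minimal radix-2 Cooley-Tukey sketch (assuming the length is a power of two), checked against `np.fft.fft`:

```python
import numpy as np

def fft_recursive(x):
    """Radix-2 Cooley-Tukey FFT; len(x) must be a power of two."""
    N = len(x)
    if N == 1:
        return x.astype(complex)
    E = fft_recursive(x[0::2])                           # DFT of the even-indexed terms
    O = fft_recursive(x[1::2])                           # DFT of the odd-indexed terms
    w = np.exp(-2j * np.pi * np.arange(N // 2) / N)      # twiddle factors
    return np.concatenate([E + w * O, E - w * O])        # the second half comes "for free"

x = np.random.randn(256)
assert np.allclose(fft_recursive(x), np.fft.fft(x))
```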
The Convolution Theorem
In this section, we will show that a convolution of two vectors can be seen as a multiplication of their Fourier transforms. This is known as the convolution theorem. We will show this in the discrete case. The continuous case extends naturally by replacing the summation with an integral (with exchange of order justified by Fubini’s theorem).
The convolution of two sequences and is defined as:
The Fourier transform of a sequence is defined as:
Then,
where we change variables and use the fact that out-of-support terms are zero, so the sum is unchanged. The proof can also be done more simply by considering the summation from to , where values beyond the support are zero.
Extras: Continuous Case of Convolution Theorem
In the continuous case involving convolution of two functions and , we also provide a proof sketch below.
The convolution of two functions is defined as:
Then, the Fourier transform of the convolution is:
Or in other words,
Convolution Theorem + FFT = Log Linear Convolution
To recap, the implication of the convolution theorem is that if we want to perform a convolution of long signals and , which naively would incur computational complexity, we can reduce it to without any approximation. Below are the steps:
Compute the Fourier transforms of the vectors and , each of length , which incurs log-linear cost via the Fast Fourier Transform. Here, we obtain the frequency components and , each of length .
Multiply the Fourier transforms of two vectors and , incurring , and finally compute the inverse Fourier transform of , which incurs another .
In total, the convolution can be done in log-linear time via the Fast Fourier Transform and the convolution theorem, instead of the usual quadratic time. Magic!
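Here is a short sketch of these steps (zero-padding to avoid circular wrap-around, and truncating the output back to the input length as in the causal setting above):

```python
import numpy as np

def fft_conv(x, h):
    """Causal convolution of equal-length vectors via the FFT, in log-linear time."""
    N = len(x)
    n_fft = 2 * N                                        # enough padding for a linear convolution
    y = np.fft.irfft(np.fft.rfft(x, n_fft) * np.fft.rfft(h, n_fft), n_fft)
    return y[:N]                                         # keep outputs on the original time grid

x, h = np.random.randn(1024), np.random.randn(1024)
assert np.allclose(fft_conv(x, h), np.convolve(x, h)[:1024])   # matches the quadratic version
```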
This sub-quadratic behavior allows fast long range modeling, and is the foundation of convolution models such as S4, S4d, GSS, H3, and Hyena.
B. Orthonormal Basis in Function Space
We turn our attention to a vector space whose elements are functions. A special kind of such vector space we will consider is a Hilbert space5 equipped with a countable and dense orthonormal basis6. In essence, a Hilbert space provides a notion of projections and distance via its inner product and the induced norm . Further, a Hilbert space has a completeness property, meaning that any sequence of elements in that draws progressively closer together actually converges within the space. An example of such a Hilbert space is the space of square-integrable functions on an interval , with a corresponding inner product
where denotes the complex conjugate of . The function corresponds to the weight of different points in the interval which reflects the measure that the inner product is defined on.7 For instance, we may want to weigh points far away with smaller weights. Different weight used in the integral above (or different measure) will give rise to a different inner product, which in turn defines a different notion of orthogonality and distance in the associated Hilbert space.
The power of the dense orthonormal basis is that it allows us to represent any function in the Hilbert space using a linear combination of the basis elements. That is, for any function , there exist coefficients such that:
That is, in the orthonormal basis , a function can be represented as simply an infinite sequence . 8
It is quite profound that an entire function whose domain consists of uncountably many numbers can be described by a countable set of numbers, the coefficients of the respective basis.
We can also think of the partial sum as an approximation of the function , where the approximation gets better as (where the convergence is in the norm9). Therefore, a finite vector can also be used to approximately represent an entire function where the approximation gets better as is larger. This concept is used widely for compression of signals such as audio, images, and videos.
While this representation ensures approximation, it does not directly offer a method to find for a given . The key lies in the orthogonality of the basis. The set is orthonormal, implying , where is the Kronecker delta, a function that returns 1 when and 0 otherwise. This orthogonality simplifies our task of finding coefficients to taking inner products, where we extract the coefficient of via:
This filtering property, intrinsic to orthonormal bases, ensures that we isolate each coefficient efficiently.
Based on the weight function and the subspace of functions we operate on, the orthonormal basis can be different. For instance, for uniform weight on an interval, the Legendre polynomials form an orthonormal basis. For periodic functions with uniform weight on an interval, the Fourier series form an orthonormal basis. For an exponentially decaying weight , the Laguerre polynomials form an orthonormal basis. We will discuss these three examples below.
Example 1: Legendre Polynomials as Orthonormal Basis on Uniform Measure
The space of square-integrable functions over the interval , denoted by , is an example of a Hilbert space. The inner product is defined as . In this space, the Legendre polynomials form a countable orthonormal basis.
The Legendre polynomials play a pivotal role across mathematics and physics. Originating from Adrien-Marie Legendre's studies of celestial mechanics, these polynomials have since been embraced in applications spanning quantum mechanics to approximation theory.
Definition:
The Legendre polynomial, denoted , is given by Rodrigues’ formula:
More concretely,
These polynomials are orthogonal on the interval with respect to the weight function . In other words:
where is the Kronecker delta function.
The example above can be extended to a Hilbert space defined on , where the orthonormal basis functions can be derived as follows.
We introduce a linear change of variables to adapt to the interval . By defining a new variable such that , we transform the interval to . The Legendre polynomials on the interval are then expressed as . Define as a normalized Legendre polynomial, together with and , we have:
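As a small illustration (a sketch using numerical integration on a grid), we can project a function on [0, 1] onto the first N normalized shifted Legendre polynomials and reconstruct it from those coefficients; the example function below is an arbitrary choice for demonstration:

```python
import numpy as np
from numpy.polynomial.legendre import Legendre

def shifted_legendre(n, t):
    """Normalized Legendre polynomial on [0, 1]: sqrt(2n + 1) * P_n(2t - 1)."""
    return np.sqrt(2 * n + 1) * Legendre.basis(n)(2 * t - 1)

def legendre_coeffs(f, N, num_points=2000):
    """c_n = <f, p_n>, approximated with the trapezoid rule on a grid."""
    t = np.linspace(0.0, 1.0, num_points)
    return np.array([np.trapz(f(t) * shifted_legendre(n, t), t) for n in range(N)]), t

f = lambda t: np.exp(-t) * np.sin(6 * np.pi * t)
coeffs, t = legendre_coeffs(f, N=24)
f_hat = sum(c * shifted_legendre(n, t) for n, c in enumerate(coeffs))
print(np.max(np.abs(f_hat - f(t))))    # the approximation error shrinks as N grows
```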
Example 2: Fourier Basis for Periodic Functions
On with the inner product where denotes complex conjugate, the Fourier basis is an orthonormal basis when applied to periodic functions.10
The Fourier series representation of a periodic function over the interval is given by:
where are the Fourier coefficients of the basis elements , computed as:
We can extract out the coefficient from the inner product with because for distinct integers and , the basis functions and are orthogonal on the interval . Their inner product is:
For , we have , which is indeed normalized.
How do we interpret complex coefficients?
The complex representation here is a convenient way to express the Fourier basis, but we can also express it in terms of cosine and sine functions, which gives us a nice physical interpretation: the function is decomposed into waves at integer multiples of the base frequency (or simply, harmonics), with an arbitrary phase at each frequency controlled by the relative coefficients of the cosine and sine functions.
Expanding and where , we have
Now, let’s simplify it by summing over non-negative , from to . We’ll start with the real part of the decomposition of , which is given by:
The imaginary part of the decomposition of is given by:
In short, the complex coefficients are such that their real and imaginary parts are the coefficients of the cosine and sine functions, respectively.
Real-Valued Functions
If the function is real, then , which implies that and and . That is, for real , we do not need to compute the coefficients for negative frequencies, due to the symmetry.
The simplified components for a real becomes:
where . In this interpretation, any periodic function is a linear combination of cosine waves at integer multiples of the base frequency (harmonics), with an arbitrary phase at each frequency controlled by the relative coefficients of the cosine and sine functions. This is the physical interpretation for a real signal we alluded to earlier.
In practice, the complex representation is more convenient to use and has wide adoption. It is also a generalization of the sine and cosine representation that is applicable for both real and complex functions.
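As a small numerical check of these identities (a sketch on a discrete grid, assuming the convention c_k = (1/T) ∫ f(t) exp(-2πikt/T) dt and a period of 1; the test signal is an arbitrary choice), we can compute a few complex coefficients of a real periodic signal and verify the conjugate symmetry and the amplitude relations:

```python
import numpy as np

T = 1.0
t = np.linspace(0, T, 4096, endpoint=False)
f = 0.7 * np.cos(2 * np.pi * 3 * t) + 0.2 * np.sin(2 * np.pi * 5 * t)   # real, periodic

def fourier_coeff(f, t, k, T=1.0):
    """c_k = (1/T) * integral of f(t) exp(-2 pi i k t / T) dt, approximated on the grid."""
    return np.mean(f * np.exp(-2j * np.pi * k * t / T))

c3 = fourier_coeff(f, t, 3)
assert np.isclose(fourier_coeff(f, t, -3), np.conj(c3))                 # c_{-k} = conj(c_k)
assert np.isclose(2 * c3.real, 0.7, atol=1e-6)                          # cosine amplitude from Re(c_k)
assert np.isclose(-2 * fourier_coeff(f, t, 5).imag, 0.2, atol=1e-6)     # sine amplitude from -Im(c_k)
```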
Example 3: Laguerre Polynomials as Orthonormal Basis on Exponential Decay Measure
In the context of Hilbert spaces, another example of an orthonormal basis is provided by the Laguerre polynomials. The space of square-integrable functions over the interval , denoted as , serves as an example where these polynomials form a countable orthonormal basis.
Definition:
The Laguerre polynomial, denoted as , can be defined through Rodrigues’ formula:
In more concrete terms, the first few Laguerre polynomials are as follows:
These polynomials are orthogonal over the interval with respect to the weight function . In other words, they satisfy the following orthogonality condition:
Here, is the Kronecker delta function.
We can see that a different weight function, or a different measure, gives rise to a different set of orthonormal basis functions. This is a powerful concept that we will revisit later in the context of convolution models, especially in the construction of the HiPPO matrix.
Examples: Approximating Data with Legendre vs. Fourier Basis
We show a few examples here where we use the Legendre and Fourier bases to approximate functions with a finite number of basis elements. See the Colab notebook here for the code and more examples.
Approximating data via Legendre and Fourier Basis
If we extrapolate the data outside the original domain, for Fourier series, the pattern repeats periodically. This can also be shown analytically by observing that for any , we have in the Fourier representation.
Extrapolation via Fourier Series
C. State Space Models
State space models are a class of models used to describe the time evolution of systems. They are widely used in many fields, including control theory, signal processing, and machine learning. In this section, we will give a brief introduction to state space models, and show how they can be used to model complex dynamics.
In the continuous time scenario, linear state space models can be written as matrix-vector multiplications where the -dimensional state vector evolves over time according to the following differential equation:
where , is an matrix, is , and is . is usually called the input vector, and is the output vector.
In most cases is assumed to be .
Recurrent View of State Space Models
In the discretized case, the evolution goes from time step to , instead of the infinitesimal step from to in the continuous-time dynamics. We can approximate the evolution by either using the derivative at , or using the average of the derivatives at and for better numerical stability (the bilinear/trapezoid method). With step size ,
Together with the state space equations, we can show that
That is, we can obtain the current state given only the current input and the previous state, without needing to know earlier states or inputs. This is a recurrent property that is useful during inference, since each step incurs a fixed amount of compute that does not depend on the context length. This is quite different from attention, where we need to use the cached keys and values to predict the next step, which incurs memory I/O and compute that grow with the context length during incremental decoding.
Discretized linear state space models in recurrent view and convolution view
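Below is a minimal sketch of the bilinear discretization and the recurrent update (the formulas follow the standard bilinear/trapezoid method mentioned above; treat the example system as illustrative rather than any particular paper's code):

```python
import numpy as np

def discretize_bilinear(A, B, dt):
    """Bilinear (trapezoid) discretization of x' = A x + B u."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - dt / 2 * A)
    return inv @ (I + dt / 2 * A), inv @ (dt * B)        # A_bar, B_bar

def ssm_recurrent(A_bar, B_bar, C, u, x0=None):
    """x_k = A_bar x_{k-1} + B_bar u_k,  y_k = C x_k  (constant work per step)."""
    x = np.zeros(A_bar.shape[0]) if x0 is None else x0
    ys = []
    for u_k in u:
        x = A_bar @ x + B_bar * u_k
        ys.append(C @ x)
    return np.array(ys)

# Sanity check: x' = -x + u with a unit step input converges to 1.
A, B, C = np.array([[-1.0]]), np.array([1.0]), np.array([1.0])
A_bar, B_bar = discretize_bilinear(A, B, dt=0.1)
y = ssm_recurrent(A_bar, B_bar, C, u=np.ones(100))
assert abs(y[-1] - 1.0) < 1e-2
```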
Example: Harmonic Oscillator
Let’s develop some intuition for what state space models can do. We will consider a simple example of a mass attached to a spring with spring constant ; the other end of the spring is attached to a wall. The mass is also subject to an external force .
The dynamics of the system is described by the following differential equation:
where is the displacement from equilibrium of the mass at time . We know that in the case of , this should be a simple harmonic oscillator whose solutions are pure sine waves with a certain phase (depending on the initial position and velocity). Let’s see how well we can model this system using a state space model.
The key is how we define the state . Let denote the velocity. In this case, we can see that the differential equation above can be written as:
If we define the state as
then we can describe the differential equation with state space model as:
The upper row simply says , the definition of velocity. The lower row says , exactly the equation for the acceleration. Then, we extract out the position by
In this case we use the initial condition such as
which corresponds to an initial position of 1 with zero velocity. We also use , meaning that no additional force is applied. We can see that the position calculated via the discretized state space follows the sine solution quite perfectly. See the Colab notebook here. Note that we adapted the code given in The Annotated S4 for state space models.
State space model for spring-mass with no external force
State space model for spring-mass with external force at time 0
Just for fun, let’s also consider another case where we start from initial position and velocity , but with being something like a Dirac Delta function which injects a fixed momentum at time . In this case, and for all . As shown below, the state space models are able to capture the dynamics quite nicely.
We also explore other variations, such as exponentially decaying amplitude due to friction, and resonance, where the external force has the same frequency as the natural frequency. We also show cases where the numerical approximation can break down, resulting in aliasing effects once the frequency becomes too high compared to the granularity of the discretization.
Friction results in exponential decay of amplitude.
Impulse force at resonance frequency results in increasingly large oscillation.
High frequency compared to sampling frequency leads to aliasing (incorrect solution to the dynamical system).
Another aliasing example.
What Can We Model With Linear State Space Models?
In general, any linear differential equation can be framed as a linear state space model. This excludes some dynamics; for instance, celestial mechanics, where the gravitational force is proportional to the inverse square of the distance between two bodies, cannot be modeled with a linear state space.
Convolution View of State Space Models
Now let’s think about how we can process a vector input to obtain an output in batch. Due to the recurrence relations, we can write the output as
where and is the initial state, which is often taken to be zero. In this form, we see that is a convolution between the L-dimensional vector and the input vector . It means that we can think of the state space model as a convolution model, where the convolution kernel is . This unrolled view is useful to process the entire input sequence, either during inference or training.
That is, if we have a sequence of length as an input, once we have the kernel , we can compute the entire output in log-linear time due to the convolution theorem + FFT. This batched operation is important for training as well as for the initial processing of the input sequence during inference.
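A small sketch to verify this equivalence numerically (assuming an already-discretized system chosen at random for illustration): unroll the recurrence into the kernel (C B, C A B, C A^2 B, ...) and compare an FFT-based convolution against the step-by-step recurrence.

```python
import numpy as np

def ssm_kernel(A_bar, B_bar, C, L):
    """K_i = C @ A_bar^i @ B_bar for i = 0..L-1 (naive unrolling)."""
    K, v = [], B_bar
    for _ in range(L):
        K.append(float(C @ v))
        v = A_bar @ v
    return np.array(K)

rng = np.random.default_rng(0)
N, L = 4, 256
A_bar = 0.95 * np.linalg.qr(rng.standard_normal((N, N)))[0]   # a stable discrete A_bar
B_bar, C = rng.standard_normal(N), rng.standard_normal(N)
u = rng.standard_normal(L)

# Recurrent view: one step at a time.
x, y_rec = np.zeros(N), []
for u_k in u:
    x = A_bar @ x + B_bar * u_k
    y_rec.append(float(C @ x))

# Convolution view: FFT of the unrolled kernel, zero-padded to avoid wrap-around.
K = ssm_kernel(A_bar, B_bar, C, L)
y_conv = np.fft.irfft(np.fft.rfft(K, 2 * L) * np.fft.rfft(u, 2 * L), 2 * L)[:L]
assert np.allclose(y_conv, y_rec)
```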
SSM for Sequence Modeling
Now, you may wonder: what matrices , , should we use? Remember from our earlier harmonic oscillator example that the matrix controls the evolution of the dynamical system. How should we construct or interpret such a system to model something like language or time series? In a later section, we discuss the HiPPO matrix, which defines a type of matrix that can be used to model long-range dependencies.
D. Einstein Summation for General Tensor Operations
We cover a brief introduction to Einstein summation, which is a concise notation for tensor operations. We will use this notation to describe attention or other tensor operations in a concise manner.
Let’s consider the attention weight between query tensor and key tensor . The query tensor is of shape where is the batch index, is the head index, is the query length index, and is the head dimension index. The key tensor is of shape where is the key length index. The attention weight tensor is of shape .
The attention weight tensor, before softmax, can be described in various ways such as
In all of the notations above, we are summing over the head dimension , which is the dimension that we are reducing. The sum over need not be explicitly mentioned if the output dimension is specified. That is, since does not contain , it implies that the output is the result of reduction over . Since this is akin to inner product over where all other axes are broadcasted, we can write it as for simplicity.
In this section, we will describe both attention (Vaswani et al., 2017) and linear attention (Katharopoulos et al., 2020). Let be a non linear operator where we may denote to emphasize that the operation is non linear over axis .
We use the same notation as in the last section where both are the head dimension index for key and value. We drop the batch dimension for simplicity.
The causality of attention is reflected in the summation over , which is the key length. That is, the query at position can only attend to the keys up to position . Note that we omit the multiplicative term for for simplicity.
More details on attention can be found in The Illustrated Attention via Einstein Summation.
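For concreteness, here is a sketch of standard causal attention written with `np.einsum`, using the (batch, head, length, head-dim) index convention from above (the 1/sqrt(d) scaling, omitted in the text for simplicity, is included here):

```python
import numpy as np

def causal_attention(Q, K, V):
    """Multi-head causal attention; Q, K, V have shape (B, H, L, d)."""
    B, H, L, d = Q.shape
    scores = np.einsum("bhqd,bhkd->bhqk", Q, K) / np.sqrt(d)    # reduce over the head dimension d
    mask = np.tril(np.ones((L, L), dtype=bool))                  # query q may only see keys k <= q
    scores = np.where(mask, scores, -np.inf)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)                    # softmax over the key positions
    return np.einsum("bhqk,bhkv->bhqv", weights, V)              # weighted sum of the values

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((2, 4, 8, 16)) for _ in range(3))
out = causal_attention(Q, K, V)                                  # shape (2, 4, 8, 16)
```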
Linear Attention
If is an identity function11, we can rearrange things where and equivalently interact with each other first. That is,
This is a fully linear system given . We can introduce some non linearity over the key length back by replacing and with and , where is a non linear operator over axis or .
Observe that reduces over dimension , the key/value length or context length. This means that no matter how long the sequence is, this term has the same dimensionality, and the evolution as context length increases is entirely via addition. Also, observe that
which means that the subsequent spatial step rolls into the state vector , which is a recurrent property. Hence, linear attention can be seen as a recurrent operation and helps connect the traditional attention with RNNs.
Note that in the linear attention paper (Katharopoulos et al., 2020), the authors motivate the linearization from the perspective of kernel methods. That is, we can view as , which is equivalent to for some feature map . In the case of softmax, this corresponds to the exponential kernel, whose feature map is infinite dimensional. The authors then propose an alternative finite-dimensional feature map, for instance based on the exponential linear unit .
Attention and Linear Attention
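Below is a sketch of causal linear attention in its recurrent form, using the elu + 1 feature map from the paper; note that the key-value interaction is accumulated first, so each step does a constant amount of work:

```python
import numpy as np

def elu_plus_one(x):
    return np.where(x > 0, x + 1.0, np.exp(x))        # feature map phi from Katharopoulos et al.

def causal_linear_attention(Q, K, V):
    """Shapes (L, d); the running state S_t = sum_{k<=t} phi(K_k) V_k^T has fixed size d x d."""
    phi_Q, phi_K = elu_plus_one(Q), elu_plus_one(K)
    L, d = Q.shape
    S = np.zeros((d, d))                              # phi(K)^T V interaction happens first
    z = np.zeros(d)                                   # running normalizer sum_{k<=t} phi(K_k)
    out = np.zeros_like(V)
    for t in range(L):
        S += np.outer(phi_K[t], V[t])
        z += phi_K[t]
        out[t] = (phi_Q[t] @ S) / (phi_Q[t] @ z)      # constant-time incremental decoding step
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 16)) for _ in range(3))
out = causal_linear_attention(Q, K, V)
```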
Notes and Observations
The difference between transformer attention and linear attention is the order of reduction and the non linearity operation along the length dimension.
Linear attention inspires the H3 architecture, where the main difference is that the map applied along the length dimension in H3 is based on a state space model.
In the traditional attention, is an outer product with respect to and . However, in linear attention, is an outer product with respect to the head dimension and .
Therefore, in linear attention, the quadratic complexity is shifted to the head dimension () in , with linear complexity in the sequence length. This is in contrast to traditional attention, where we have quadratic complexity in the length () in the expression but linear complexity in the head dimension. This is an artifact of the order of operations: whether we interact and first, or and first.
For linear attention, the term depends implicitly on the length (even though the different length results in the same dimension of this state tensor). During training, needs to be computed for every for causality. That is, is computed as . This is in contrast to traditional attention where is computed as where the term has associated contextual representation for each explicitly.
The recurrence property in linear attention allows incremental decoding in constant time, if we already have the state up to length and want to process the next step. This is in contrast to the per-step complexity of traditional attention, which grows linearly with the context length.
Long Convolution Models
HiPPO Framework for History Representation via State Space
The HiPPO matrices (Gu et al., 2020; Voelker et al., 2019) are matrices associated with state space models, obtained for the input memorization problem where we want the state to capture , the entire input12 up to time . This can be very useful for long-range modeling, where we can build a contextualized representation in which the output feature at a given time step represents the entire past. In this section, we will walk through how to approach this problem.
First, the state vector in state space models is dimensional. We want such a vector at time to represent the entire past of the input from time to . What would be a good way to represent the entire history in an dimensional vector, where the history can be arbitrarily long?
The key is to think about it in function space, where the input vector is interpreted as evenly spaced points from a continuous-time function defined up to . Then, we define the state to represent the first coefficients of this function with respect to some orthonormal basis in the function space. Such a definition yields particular matrices .
HiPPO Framework -- Illustration from HiPPO paper.
Below, we outline the derivation steps of the HiPPO framework.
Define how we want to weigh different points in the history. For long-range modeling, a sensible choice is uniform weighting, so that we can take signals from points far away. In more technical terms, we choose an appropriate measure, which defines the integral and hence the inner product for the Hilbert space. For uniform weighting, we can use a measure scaled such that the total measure is 1 from time to when we want to memorize the input up to time . Let’s call this measure 𝟙, which depends on .
Based on inner product (which incorporates the weight), we can then choose an orthonormal basis. For uniform weighting, the (scaled) Legendre polynomials form an orthonormal basis with respect to the inner product .
We choose the dimensional state vector to be exactly the coefficients where . The function representation via the Legendre polynomials is .
Since these coefficients correspond to the projection of onto the basis , the function representation minimizes the distance between and the true input . In other words, the function represented by the coefficients is the best N-dimensional approximation of in the Hilbert space with respect to the inner product . As becomes larger, the distance becomes arbitrarily small.
With this definition , we can show that or can be expressed as a linear combination of and , which means that it is a linear state space model! The details can be found in (Gu et al., 2020), Appendix D, where there are derivations for different measures as well.
In the vectorized form, the derivative of the state vector where is the coefficient for Legendre orthonormal polynomials that best represent can be written as
where
and are the associated HiPPO matrices obtained from the derivation outlined above.
This is a time-varying state space model, since the matrix depends on ! Due to this time dependence, it can no longer be expressed as a convolution.
However, dropping actually works in practice and enables long-range modeling in the S4 model (Gu et al., 2022; Voelker et al., 2019). According to (Gu et al., 2023), this time-invariant version is a valid state space model that corresponds to using exponentially warped Legendre polynomials. Therefore, using directly yields a valid time-independent state space model.
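For reference, here is a sketch of the (time-invariant) HiPPO-LegS construction in code, following the form used in the HiPPO paper and The Annotated S4:

```python
import numpy as np

def hippo_legs(N):
    """HiPPO-LegS: A[n, n] = -(n + 1), A[n, k] = -sqrt((2n + 1)(2k + 1)) for n > k, else 0."""
    q = np.sqrt(2 * np.arange(N) + 1)
    A = -(np.tril(np.outer(q, q)) - np.diag(np.arange(N)))
    B = q[:, None]                                    # B[n] = sqrt(2n + 1)
    return A, B

A, B = hippo_legs(4)
print(A)
```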
S4: Structured State Space
From the previous section, the HiPPO matrix provides a state space model definition that can map input signal to an output signal that captures the past history at each time step.
The next challenge to be addressed is how to compute the convolution kernel for state space models efficiently. Once we have , computing the output from the input is fast via the convolution theorem and Fast Fourier Transform.
Recall that the convolution kernel for state space models can be written as:
where is the associated matrix for discrete state space dynamics of .
That is, the convolution kernel requires obtaining for all . Since is an matrix, and we need to do it times, the computational complexity with naive matrix multiplication is . Here are the considerations outlined in the S4 paper.
At first glance, attempting to diagonalize seems reasonable since if for some diagonal matrix , then where can be done in time. However, the authors show that this is not numerically stable since entries of that diagonalize the HiPPO matrices are exponentially large in state size .
A more ideal scenario is if were diagonalizable by a unitary matrix . Note that a unitary matrix has properties and is very well-conditioned, hence will not suffer numerical instabilities. Such a matrix that is diagonalizable by a unitary matrix is called normal.
The HiPPO matrix is not normal. However, according to Theorem 1 of (Gu et al., 2022), it can be written as normal plus low rank (NPLR).
The S4 paper developed an algorithm to compute the convolution kernel specifically for state space models whose matrix is NPLR (see Algorithm 1 in (Gu et al., 2022)). Specifically, Theorem 3 in (Gu et al., 2022) states that can be obtained with operations and .
Therefore, we now have a method to obtain the contextualized output from the input quite efficiently: first, obtain the convolution kernel, which costs , and then compute the output via the convolution theorem and FFT, which costs .
Overall, the model performs well on various long range modeling tasks, including the Long Range Arena benchmark (Tay et al., 2021).
More details on the proofs are covered extensively in the blog post The Annotated S4 as well as the original paper.
Diagonal State Spaces
Interestingly, (Gupta et al., 2022) shows that using only the normal part of the HiPPO matrix, without the low-rank correction, also works well in practice, with performance matching the original S4.
Normal plus low rank (NPLR) in S4 can also be conjugated into a diagonal plus low rank matrix (DPLR). That is,
In general, the state space dynamics described by () or () are equivalent, since they yield the same convolution kernel; the transformation can be seen as a change of basis of the internal representation of the states .
That is, we can equivalently use as the HiPPO matrix. In the diagonal case, we use only . This diagonal representation drastically simplifies the algorithm.
Practically, the paper emphasizes the importance of initializing with the diagonal part of the HiPPO matrix's DPLR representation rather than randomly, since random initialization is shown to be less effective. In addition, the real parts of the diagonal entries (the eigenvalues) are constrained to be non-positive, which is reasoned to be essential for long-range modeling; otherwise, values can grow arbitrarily large as the length increases.
Later, in (Gu et al., 2022), the authors show that the diagonal version of the HiPPO matrix is a noisy approximation that becomes closer to the original as the dimension of the internal state space approaches infinity. This does not hold for arbitrary NPLR matrices, but arises from the structural properties of the HiPPO matrix.
GSS: Gated State Space Models and Gated Attention Units
Gated State Space models (GSS) (Mehta et al., 2023) build on two lines of work. The first is state space models, where the paper adopts the diagonal state space.
The second is the gating mechanism as an alternative activation function, which has been shown to improve model performance and can be seen as a multiplicative residual connection, allowing gradients to flow back freely. The gating mechanism also allows using a weaker attention mechanism without quality degradation. We cover the literature on gating mechanisms in the Gating section.
In particular, GSS adopts Gated Attention Units (GAU) (Hua et al., 2022) with a diagonal state space model instead of the traditional attention, together with the input-controlled gating mechanism proposed in GAU.
Below is a simplified version of the GSS model (without dimensionality reduction before the state space model). Given an input , the model performs linear projection over hidden dimension into tensors.
Then, the output is computed as
where is the contextualized representation over spatial domain.
In contrast, an alternate contextualized representation is the attention mechanism where and .
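To make the structure concrete, here is a schematic sketch of a gated state-space block in this spirit. It is not the exact GSS parameterization (GSS reduces dimensionality before the SSM and learns per-feature DSS kernels); the stand-in decaying kernel and the weight shapes below are illustrative assumptions.

```python
import numpy as np

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def gated_ssm_block(X, W_u, W_v, kernel, W_o):
    """Project, contextualize one branch over the length axis, gate with the other, project out."""
    u = gelu(X @ W_u)                                  # branch to be contextualized, shape (L, d)
    v = gelu(X @ W_v)                                  # gating branch, shape (L, d)
    L = X.shape[0]
    U = np.fft.rfft(u, 2 * L, axis=0) * np.fft.rfft(kernel, 2 * L)[:, None]
    y = np.fft.irfft(U, 2 * L, axis=0)[:L]             # causal global convolution along the length axis
    return (v * y) @ W_o                               # multiplicative gate, then output projection

rng = np.random.default_rng(0)
L, d = 32, 8
X = rng.standard_normal((L, d))
W_u, W_v, W_o = (0.1 * rng.standard_normal((d, d)) for _ in range(3))
kernel = np.exp(-np.arange(L) / 8.0)                   # stand-in decaying kernel; GSS obtains it from a DSS
out = gated_ssm_block(X, W_u, W_v, kernel, W_o)        # shape (L, d)
```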
Below are some additional observations from the paper:
The GSS paper conducted experiments aimed at language modeling, at a much larger compute scale than previous work on state space models.
Contrary to the previous work on state space models, this paper found that for language modeling tasks, initialization does not matter significantly. This is in contrast to the sensitivity of initialization in (Gu et al., 2022; Gupta et al., 2022; Gu et al., 2022).
The paper observed consistent generalization to longer inputs; while the training uses up to 4k length, the model is evaluated on sequence lengths up to 65k where the performance becomes significantly better with longer context.
Aside: There are a few considerations in the paper that make the model run faster on accelerators by projecting the input into a lower dimensionality before the state space model stage. We omit this step for simplicity. See the paper for an extensive comparison with Block-Recurrent Transformers (Hutchins et al., 2022) and great coverage of related work on long-range transformer architectures.
H3: Hungry Hungry HiPPOs
The goal of H3 is to address the gap between previous state space models (S4, S4d, etc.) and transformers. The paper aims to improve (1) the expressivity and (2) the computational efficiency of state space models.
Improved Expressivity
H3 draws inspiration from the attention mechanism in transformers and long range modeling with state space models.
Starting from the input to the current layer, we project it to the query , key and value tensors, similar to the input projections in transformers attention.
Perform shift state space model on . This can be seen as a non linear operation on the length axis of , similar to how the linear attention uses a non linear operation on individual tensors (instead of applying non linear function on entire like in the traditional attention).
The motivation for the shift SSM operation is the observation that SSMs struggle with recalling earlier tokens and comparing tokens across the sequence.
Drawing inspiration from the linear attention model where and interact first, we perform a diagonal SSM on .
Multiply with to obtain the output.
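Putting these steps together, here is a schematic single-head sketch of an H3-style layer in the SISO (elementwise) case. The shift SSM and the diagonal SSM are stood in for by causal convolutions with a one-step-shift kernel and a decaying kernel, respectively; these kernels and weight shapes are illustrative assumptions, not the paper's parameterization.

```python
import numpy as np

def causal_conv(x, k):
    """Per-feature causal convolution along the length axis; x: (L, d), k: (L,)."""
    L = x.shape[0]
    Kf = np.fft.rfft(k, 2 * L)[:, None]
    return np.fft.irfft(np.fft.rfft(x, 2 * L, axis=0) * Kf, 2 * L, axis=0)[:L]

def h3_like_layer(X, W_q, W_k, W_v, k_shift, k_diag, W_o):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    K_shifted = causal_conv(K, k_shift)          # "shift SSM" acting on K
    KV = causal_conv(K_shifted * V, k_diag)      # "diagonal SSM" on K * V: K and V interact first
    return (Q * KV) @ W_o                        # multiply with Q, then project out

rng = np.random.default_rng(0)
L, d = 64, 8
X = rng.standard_normal((L, d))
W_q, W_k, W_v, W_o = (0.1 * rng.standard_normal((d, d)) for _ in range(4))
k_shift = np.zeros(L); k_shift[1] = 1.0          # stand-in: shift the sequence by one step
k_diag = np.exp(-np.arange(L) / 16.0)            # stand-in: decaying long-convolution kernel
out = h3_like_layer(X, W_q, W_k, W_v, k_shift, k_diag, W_o)
```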
Here, we offer two illustrations. The first considers the case of single input, single output (SISO), where the interaction between tensors is via elementwise multiplication.
The second is the batched case, where we illustrate how the input is split into multiple heads and how the interactions between tensors are expressed via einsum. The outputs from all heads are then grouped together into the final output.
Improved Computational Efficiency
The H3 paper is the first to develop fused kernels specialized for convolution and FFT to increase the hardware utilization.
The paper also proposes a state-passing algorithm that allows SSMs to scale to very long context lengths.
Hyena Hierarchy
The Hyena model departs from the state space literature, where the dynamics are explicitly defined by the matrices . Instead, the Hyena model uses a convolutional approach in which the convolution kernels are constructed by a learned mapping from position embeddings. This is considered an implicit13 convolution, as opposed to an explicit convolution, where the convolution kernel is parameterized directly and does not depend on the data.
A recurrent model can be seen as a global convolution model. However, the reverse is not true. The Hyena model is a global convolution model that is not recurrent.
Another key aspect of the Hyena model is the generalization of how we think of projected layer inputs. In models such as attention or H3, we think of the projected layer inputs as Q,K,V tensors. In the Hyena model, these projected layer inputs are generalized to be M + 1 copies without any specific interpretation.
The Hyena model can be seen as a generalization of the H3, GSS, S4/DSS models.
The Hyena paper emphasizes the operator perspective, where we view attention or convolution as a data-dependent operator.
In transformers, this data-dependent operator is controlled by the query and key (and costs O(L^2)). It acts on the value tensor V, which then produces the output O.
Convolution models can also be seen as data-dependent operators, where the operator is controlled by the input. The key difference is that, thanks to the convolution theorem and the Fast Fourier Transform, the operator can be applied in O(L log L) instead of O(L^2).
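As a rough sketch of what an implicit kernel looks like (the positional features, the small MLP, and the decay window below are illustrative assumptions in the spirit of Hyena, not the paper's exact parameterization):

```python
import numpy as np

def implicit_kernel(L, d_model, d_pos=8, seed=0):
    """Map positional features through a small MLP to get one length-L filter per channel."""
    rng = np.random.default_rng(seed)
    t = np.arange(L) / L
    freqs = np.arange(1, d_pos // 2 + 1)
    pos = np.concatenate([np.sin(2 * np.pi * freqs[None] * t[:, None]),
                          np.cos(2 * np.pi * freqs[None] * t[:, None])], axis=1)   # (L, d_pos)
    W1, W2 = rng.standard_normal((d_pos, 32)), rng.standard_normal((32, d_model))
    h = np.tanh(pos @ W1) @ W2                     # (L, d_model): the "implicit" filter values
    window = np.exp(-3.0 * t)[:, None]             # a decay window modulates the filter
    return h * window                              # same positions -> same kernel, for any input

h = implicit_kernel(L=128, d_model=4)              # the kernel is then applied via FFT convolution
```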
Illustrated Global Convolution Models
The overall goal for the illustration is to precisely describe the specification of various models via diagrams with a consistent notation. By adopting a consistent notation in a unified illustration, we hope to provide a clear picture of how different models are related and how they differ.
The diagrams are meant to be as precise as possible, in the sense that one can translate them to code without ambiguity (except for some constant scaling, which is omitted for simplicity). This means we have to portray the case where the input feature is high-dimensional, in contrast to the single-input single-output description where the context operator deals with a single feature dimension (vector in and vector out).
While most of the techniques presented in this blog post are about the context operator, which mixes information along the length axis, in the high-dimensional feature case there are also crucial steps related to mixing different feature dimensions together. This happens at the stages of reading the input and writing the output.
For instance, given a layer input , we may project it with linear maps to obtain tensor (or ), where each feature in is a linear combination of all features in the input. This is the input-reading step. (See Transformer Circuits (Elhage et al., 2021) for details regarding this view of reading inputs from the residual stream.)
There can be many channels of these projected inputs. For instance, in transformers, we can see has separate channels where each channel operates independently until the writing stage. Each channel input is -dimensional where is the feature dimension of the layer input .
In Hyena, a d-dimensional input is projected to copies of dimensional tensors with channels where each channel has feature dimension.
After the different views of the input (such as ) are obtained, a context operator mixes information along the spatial (length) dimension and produces the output.
Then, the channel outputs are aggregated together via a simple concatenation and potentially another linear projection. This is the output-writing step.
Gating can be seen as an einsum in general. For instance, it can be an element-wise multiplication along one axis and an outer product along another simultaneously. Using this gating mechanism right after a complex operation, especially a non-linear one, may help gradients flow freely, as hypothesized in the gating literature. (See the Appendix on Gating for more details.)
Illustration of Global Convolution Models
FAQs
Coming soon! (please leave questions below in Gitcus)
References
Efficiently Modeling Long Sequences with Structured State Spaces
Albert Gu, Karan Goel, and Christopher Ré
In The Tenth International Conference on Learning Representations, ICLR 2022, Virtual Event, April 25-29, 2022
On the Parameterization and Initialization of Diagonal State Space Models
Albert Gu, Karan Goel, Ankit Gupta, and 1 more author
In NeurIPS, 2022
Long Range Language Modeling via Gated State Spaces
Harsh Mehta, Ankit Gupta, Ashok Cutkosky, and 1 more author
In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023
Hungry Hungry Hippos: Towards Language Modeling with State Space Models
Daniel Y. Fu, Tri Dao, Khaled Kamal Saab, and 3 more authors
In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023
Hyena Hierarchy: Towards Larger Convolutional Language Models
Michael Poli, Stefano Massaroli, Eric Nguyen, and 6 more authors
In International Conference on Machine Learning, ICML 2023, 23-29 July 2023, Honolulu, Hawaii, USA
HyenaDNA: Long-Range Genomic Sequence Modeling at Single Nucleotide Resolution
Eric Nguyen, Michael Poli, Marjan Faizi, and 10 more authors
CoRR, 2023
Attention is All you Need
Ashish Vaswani, Noam Shazeer, Niki Parmar, and 5 more authors
In Advances in Neural Information Processing Systems 30: Annual Conference on Neural Information Processing Systems 2017, December 4-9, 2017, Long Beach, CA, USA
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and 1 more author
In Proceedings of the 37th International Conference on Machine Learning, ICML 2020, 13-18 July 2020, Virtual Event
HiPPO: Recurrent Memory with Optimal Polynomial Projections
Albert Gu, Tri Dao, Stefano Ermon, and 2 more authors
In Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual
Legendre Memory Units: Continuous-Time Representation in Recurrent Neural Networks
Aaron Voelker, Ivana Kajic, and Chris Eliasmith
In Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, December 8-14, 2019, Vancouver, BC, Canada
How to Train your HIPPO: State Space Models with Generalized Orthogonal Basis Projections
Albert Gu, Isys Johnson, Aman Timalsina, and 2 more authors
In The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023
Long Range Arena: A Benchmark for Efficient Transformers
Yi Tay, Mostafa Dehghani, Samira Abnar, and 7 more authors
In 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021
Diagonal State Spaces are as Effective as Structured State Spaces
Ankit Gupta, Albert Gu, and Jonathan Berant
In NeurIPS, 2022
Transformer Quality in Linear Time
Weizhe Hua, Zihang Dai, Hanxiao Liu, and 1 more author
In International Conference on Machine Learning, ICML 2022, 17-23 July 2022, Baltimore, Maryland, USA
Block-Recurrent Transformers
DeLesley Hutchins, Imanol Schlag, Yuhuai Wu, and 2 more authors
In NeurIPS, 2022
Do Transformer Modifications Transfer Across Implementations and Applications?
Sharan Narang, Hyung Won Chung, Yi Tay, and 13 more authors
In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, EMNLP 2021, Virtual Event / Punta Cana, Dominican Republic, 7-11 November, 2021
GLaM: Efficient Scaling of Language Models with Mixture-of-Experts
Nan Du, Yanping Huang, Andrew M. Dai, and 24 more authors
In International Conference on Machine Learning, ICML 2022, 17-23 July 2022, Baltimore, Maryland, USA
LaMDA: Language Models for Dialog Applications
Romal Thoppilan, Daniel De Freitas, Jamie Hall, and 54 more authors
CoRR, 2022
A Mathematical Framework for Transformer Circuits
Nelson Elhage, Neel Nanda, Catherine Olsson, and 22 more authors
Gating
Gating has been used extensively in the deep learning literature.
Gated Linear Units
(Dauphin et al., 2017) introduced gated linear units (GLU), which process the layer input by a projection into , modulated by element-wise multiplication with the input-dependent gates , where denotes the sigmoid function. That is,
Gated linear units can be seen as a multiplicative version of a residual connection. That is, the gradient of the gated linear unit has a path that does not couple with the nonlinearity , which can be arbitrarily small in some regions of and could otherwise suppress the gradient.
(Shazeer, 2020) shows that variants of GLU, such as SwiGLU, are quite effective for transformers when used in place of ReLU or GELU between the two linear projections in feedforward layers. SwiGLU is adopted in large-scale models such as PaLM (Chowdhery et al., 2023).
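A minimal sketch of these two variants (biases omitted for brevity; the weight shapes are arbitrary choices for illustration):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def glu(x, W, V):
    """GLU: the linear branch x @ W is modulated by the input-dependent gate sigmoid(x @ V)."""
    return (x @ W) * sigmoid(x @ V)

def swiglu(x, W, V):
    """SwiGLU: Swish(x @ W) * (x @ V), used in place of ReLU/GELU inside feedforward layers."""
    a = x @ W
    return (a * sigmoid(a)) * (x @ V)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16))
W, V = rng.standard_normal((16, 32)), rng.standard_normal((16, 32))
y = swiglu(x, W, V)        # shape (4, 32)
```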
gMLP
(Liu et al., 2021) proposes gMLP for vision tasks, which uses a gating mechanism similar to GLU, but with a different formulation. gMLP enables cross-token interactions via a linear projection coupled with gating as
where controls cross-token interactions. The authors find that it is also effective to split into two parts along the channel dimension and instead use
The comparisons shown in (Liu et al., 2021) indicate that attention is not critical for vision tasks, and that the degradation on some NLP tasks can be compensated for by making gMLP larger. This is an interesting indication that attention may not be necessary and can be compensated for by this form of gating together with scale.
Gated Attention Units
Gated Attention Units (GAU) (Hua et al., 2022) couple gating with a matrix that describes token-token attention weights, which the paper uses as the query-key attention matrix. This formulation allows contextualized gating, instead of gating by a projection of the same token as in gMLP.
The paper uses two GAU layers as a replacement for an MLP (or GLU) plus multi-head attention, with the hidden dimension chosen so that the number of parameters is comparable in both scenarios. The paper finds that, consistent with (Liu et al., 2021), gating allows a simpler or weaker attention mechanism without quality degradation, and also incorporates linear attention, where the key and value interact first rather than the query and key.
More on Fourier
As illustrated in Fourier Basis, any periodic function can be written as a linear combination of sine and cosine, or more compactly, as a linear combination of complex exponentials.
However, a general function that is not periodic can also be expressed in terms of the continuous-spectrum Fourier transform. That is, the frequency components need not be multiples of a base frequency (harmonics), but can span an entire continuous spectrum.
The Fourier transform of a function is defined as:
where is the frequency. Note that the complex notation allows us to extract components of both sine and cosine at once. If the Fourier transform is real, then frequency belongs to the cosine wave, and if it is pure imaginary, then the frequency belongs to the sine wave. A general complex number indicates the phase of the frequency component.
Example: Fourier Transform of a Sine Wave
Let’s look at a simple example where is a sine wave with frequency :
Then,
where is the Dirac delta function. We can see that the Fourier transform of a sine wave is a linear combination of two Dirac delta functions at and . The purely imaginary value means that the frequency belongs to the sine wave (phase zero). We can see that in this case, the Fourier transform yields the same result as the Fourier series representation – that is, we only need one frequency to represent a sine wave.
Example: Fourier Transform of a Truncated Sine Wave
Another example is where is a truncated sine wave. That is, is a sine wave for and zero otherwise. Then,
Here, due to the truncation, the function is no longer periodic, which results in frequency components that spread across the spectrum. However, we can see that they are concentrated around and dissipate as gets larger.
Example: Fourier Transform of a Box Function
Let’s look at an example of the Fourier transform of a box function:
It can be shown that the Fourier transform of is:
A high-level interpretation of this Fourier transform is that a box function has frequency components even at infinitely high frequencies. However, the contribution of such high frequencies gets small as gets larger, since the range of is bounded and the term gets smaller.
Fun Facts
The Fourier Transform is well-defined for any absolutely integrable function, which includes probability densities. Another name used for Fourier transform of a probability density is a characteristic function, defined as . In fact there is a one-to-one correspondence between a probability density and its Fourier transform. This characteristic function is often much easier to deal with than the density function itself; for example, one can easily prove the Central Limit Theorem using characteristic functions by showing that the characteristic functions of converges to the characteristic function of a Gaussian distribution, which implies that converges to a Gaussian in distribution.
Footnote
We will focus on the single input and single output case for simplicity. ↩
In general, the vectors can have different sizes, but for simplicity, we will assume that they have the same size. ↩
In general, we only need a Dirac delta function to be relatively zero everywhere except at . This is the case for which is not exactly zero outside of but is relatively zero compared to at . ↩
A Hilbert space is a vector space equipped with an inner product that is complete with respect to the norm induced by the inner product . Completeness means that any Cauchy sequence in converges to an element in within itself. The convergence is with respect to the norm, which provides a notion of distance. To elaborate on completeness further, a Cauchy sequence is a sequence such that elements are getting closer and closer together as . Intuitively, a complete space means that there are no unexpected “holes” in the space where any sequence that is supposed to converge (a Cauchy sequence) actually converges inside the space itself. ↩
A Schauder basis is a countable basis that is dense in a complete normed space (a Banach space). To say that a basis is dense means that any element of the space can be arbitrarily closely approximated by a finite linear combination of basis elements. In other words, for every element in the space and any given small positive distance (), there exists a finite sum of basis elements that is within that distance of the given element. This ensures that the basis “spans” the entire space in a limiting sense, even if any specific finite subset of the basis does not span the space.
A crucial distinction to note is that a Schauder basis doesn’t need to be mutually orthogonal. This is because it’s typically defined in the context of a Banach space, where angles or orthogonality might not be relevant concepts. Yet, in a Hilbert space, it’s entirely possible to orthogonalize a Schauder basis using the Gram-Schmidt process. ↩
The measure theoretic way is to define the inner product as where is the measure. In the case of , the measure is where is the weight function and is zero outside the interval . ↩
It turns out that such a representation is also unique, which implies an isomorphism between the space of continuous functions and the space of sequences of numbers. More precisely, with a Schauder basis is isomorphic to , the space of square-summable sequences. It is quite profound that we can uniquely represent a function that takes values on uncountably many points with a sequence of numbers, which is a countable set! The special structure of the orthonormal basis allows this to happen. ↩
The convergence here is in the norm defined by the inner product of the Hilbert space. The proof for convergence is out of scope for this post. ↩
is periodic in the interval if . In practice, any function can be made periodic via padding. For instance, if a function is defined over , then we can pad it with zeros from to and also from to . This would result in a function that is periodic over . ↩
In the original linear attention paper, the authors keep the softmax denominator. However, we can view it as an identity function with respect to the length dimension, since the denominator involves , which is independent of and still allows the reordering of the summation and . ↩
In the case of state space models, we often think of as a function even though in the discretized version, it is a sequence. This is because we can think of as a function that is sampled at discrete time points. In other words, is a function that is defined for all , but we only observe it at discrete time points. ↩
In general, an implicit convolution means that given an input , the convolution filter depends on and also potentially some parameters . In the case of Hyena, the dependence on the data is not as strong – any input of the same length yields the same position embeddings, hence the same convolution kernels. ↩