Multi-query attention can be much more efficient under large batch sizes and long context lengths.
Multi-query attention was first introduced in Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019) and was later used in PaLM for inference efficiency. In this blog, we analyze why multi-query attention can be much more efficient than traditional multi-head attention.
Multi-Query Attention at a Glance
The key change in multi-query attention is to collapse the heads of the projection matrices \(P_K\) and \(P_V\) so that they have only 1 output head instead of the full \(h\) heads. All other projection matrices (\(P_Q\) and \(P_O\)) still have size \(h \cdot d \cdot k\), while \(P_K\) and \(P_V\) have their size reduced from \(h \cdot d \cdot k\) to \(d \cdot k\).
Note that given an input \(x\) with hidden dimension \(d\), during incremental decoding, \(x\) is still projected into \(h\) heads to produce the query tensor. Since the query has \(h\) heads, the fact that the key and value tensors have only 1 head still leads to multiple head interactions during the logits and output computations: the single head of the key and value tensors is broadcast to attend against all \(h\) heads of the query \(Q\).
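To make the shapes concrete, here is a minimal numpy sketch of one incremental decoding step with multi-query attention. The dimension values and the assumption \(k = v = d/h\) are illustrative, not tied to any particular model.

```python
import numpy as np

b, h, d, m = 2, 8, 512, 40    # batch size, heads, hidden dim, context length so far
k = v = d // h                # per-head key/value dims (assuming k = v = d/h)

# Projections: P_Q and P_O keep h heads; P_K and P_V collapse to a single head.
P_Q = np.random.randn(h, d, k)    # size h*d*k (unchanged)
P_O = np.random.randn(h, d, v)    # size h*d*v (unchanged)
P_K = np.random.randn(d, k)       # reduced from h*d*k to d*k
P_V = np.random.randn(d, v)       # reduced from h*d*v to d*v

# One incremental decoding step: the new token x produces an h-head query
# but only a single-head key/value entry for the cache.
x = np.random.randn(b, d)
q = np.einsum("bd,hdk->bhk", x, P_Q)
K = np.random.randn(b, m, k)      # single-head key cache
V = np.random.randn(b, m, v)      # single-head value cache

# The single key/value head is broadcast against all h query heads.
logits = np.einsum("bhk,bmk->bhm", q, K)
weights = np.exp(logits - logits.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)
o = np.einsum("bhm,bmv->bhv", weights, V)
y = np.einsum("bhv,hdv->bd", o, P_O)
print(y.shape)                    # (2, 512)
```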
To see why such a simple change can lead to dramatically higher efficiency during incremental decoding, we first provide background on counting the memory accesses and computation required for each tensor operation (einsum). Note: one can refer to The Illustrated Attention via Einstein Summation for an introduction to einsum.
Operation and Memory Access Counting (short version)
At a high level, the number of operations and the number of memory accesses for the tensor computation \(\langle A,B \rangle \to C\) are:
Number of memory accesses: \(\small \mathcal{O}(\vert A \vert + \vert B \vert + \vert C \vert )\), where \(\small \vert A \vert\) is the size of tensor \(A\) (the product of all its dimensions). This is because each input or output needs to be read from or written to at least once.
Number of computations: \(\small \mathcal{O}(\text{product}(\text{distinct dimensions in } A \text{ and } B))\).
For example, \(\small \langle bhnv, hdv \rangle \to bhnd\) requires
\(\small \mathcal{O}(bhndv) = \mathcal{O}(bnd^2)\) operations (using \(hv = d\))
and \(\small \mathcal{O}(bhnv + hdv + bhnd)\) memory accesses for the two inputs and the output.
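As a quick sanity check, the counts for this example can be reproduced with a few lines of Python; the dimension values below are arbitrary, and we assume \(k = v = d/h\) so that \(hv = d\).

```python
b, h, n, d = 4, 8, 128, 512
v = d // h                                         # assume h*v = d

ops = b * h * n * d * v                            # product of distinct dimensions = bhndv
mem = b * h * n * v + h * d * v + b * h * n * d    # |A| + |B| + |C|

print(ops == b * n * d**2)                         # True, since h*v = d
print(mem)
```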
Operation and Memory Access Counting (longer version, can be skipped)
The number of operations for \(A,B \to C\) equals the number of duplicates (the product of the dimensions shared across both inputs and the output) times the number of operations of the base operation.
Example 1: \(bhnk, bhmk \to bhnm\) has \(bh\) duplicates, where the base operation is \(nk, mk \to nm\), since \(b\) and \(h\) are the dimensions shared across all inputs and the output. The matrix multiplication \(nk, mk \to nm\) requires \(nmk\) operations. Therefore, the total number of operations is \(\mathcal{O}(bh \cdot nmk)\).
Note: for \(nk, mk \to nm\), \(n\) and \(m\) are the non-interacting dimensions and \(k\) is the interacting dimension (it gets summed over). In general, the number of operations equals \(\text{product}(\text{set(non-interacting dimensions)}) \times \text{interacting dimension} = nm \cdot k\).
Example 2: \(bhnv, hdv \to bnd\). In this case, there are no duplicate dimensions shared across the inputs and the output. Since this can be framed as a matrix multiplication with shapes \((bn, hv) \times (hv, d) \to (bn, d)\), we see that \(bn\) and \(d\) are the non-interacting dimensions and \(hv\) is the interacting one. Therefore, the number of operations is \(\mathcal{O}(bnd \cdot hv)\).
In general, this is equivalent to \(\text{product}(\text{set of distinct dimensions in } A \text{ and } B)\), where \(A\) and \(B\) denote the input tensors.
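These rules are mechanical enough to put into a small helper. The sketch below counts operations and memory accesses for a two-operand einsum spec; `einsum_cost` is an illustrative name, and the logic follows exactly the counting rules described above.

```python
import math

def einsum_cost(spec, *shapes):
    """Rough operation and memory-access counts for a two-operand einsum.

    spec: e.g. "bhnk,bhmk->bhnm"; shapes: one shape tuple per input operand.
    """
    inputs, output = spec.split("->")
    terms = inputs.split(",")

    # Map each dimension label to its size.
    sizes = {}
    for term, shape in zip(terms, shapes):
        sizes.update(dict(zip(term, shape)))

    # Operations: product over the set of distinct dimensions in the inputs.
    ops = math.prod(sizes[c] for c in set("".join(terms)))

    # Memory accesses: |A| + |B| + |C|.
    mem = sum(math.prod(s) for s in shapes) + math.prod(sizes[c] for c in output)
    return ops, mem

b, h, n, m, d = 4, 8, 1, 2048, 512
k = v = d // h

# Example 1: bhnk, bhmk -> bhnm
print(einsum_cost("bhnk,bhmk->bhnm", (b, h, n, k), (b, h, m, k)))
# Example 2: bhnv, hdv -> bnd
print(einsum_cost("bhnv,hdv->bnd", (b, h, n, v), (h, d, v)))
```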
Memory IO Cost
Now we can analyze the memory IO cost for multi-head and multi-query attention.
Incremental Decoding
Main Takeaway
For standard multi-head attention, the calculations that incur the highest memory access cost are the logits and output calculations, which involve the following tensor operation (shown here for the logits):
Multi Head \(\langle q,K \rangle : bhk, bhmk \to bhm\)
Here, there are \(bhmk\) operations, but the operation also requires \(\mathcal{O}(bhmk)\) memory accesses (dominated by reading the key tensor \(K\)). This is the memory-bound regime (rather than the compute-bound one) and is inefficient. In contrast, for multi-query,
Multi Query \(\langle q,K \rangle : bhk, bmk \to bhm\)
which requires only \(bhk + bmk + bhm\) memory accesses while still performing \(bhmk\) operations; the dominant key-cache term no longer scales with the number of heads \(h\).
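The sketch below contrasts the two logits einsums for a single decoding step and prints the memory-access-to-operation ratio under the counting rules above; the dimension values are made up for illustration.

```python
import numpy as np

b, h, m, d = 4, 8, 256, 512
k = d // h

q    = np.random.randn(b, h, k)       # one new query token, h heads
K_mh = np.random.randn(b, h, m, k)    # multi-head key cache
K_mq = np.random.randn(b, m, k)       # multi-query key cache (single head)

# Both variants perform the same number of operations (bhmk).
logits_mh = np.einsum("bhk,bhmk->bhm", q, K_mh)
logits_mq = np.einsum("bhk,bmk->bhm", q, K_mq)   # single key head broadcast over h query heads

ops    = b * h * m * k
mem_mh = b * h * k + b * h * m * k + b * h * m   # dominated by reading the multi-head key cache
mem_mq = b * h * k + b * m * k + b * h * m       # key-cache term is h times smaller

print(mem_mh / ops)   # ~1: roughly one memory access per operation (memory-bound)
print(mem_mq / ops)   # much smaller: far fewer memory accesses per operation
```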
Additional Details
The following table provides an analysis of the number of operations and the memory access cost (in terms of tight complexity bounds) for traditional multi-head attention versus multi-query attention.
The color red denotes the changes due to multi-query attention. Other operations are the same across multi-query and multi-head attention where the difference is not stated explicitly.
Note: the number of operations is the same for multi-query and multi-head attention.
Table 1: Memory Access and Computation Complexities for Incremental Decoding with Multi-Head and Multi-Query Attention.
Note: \(r\) is the ratio of memory access complexity to computation complexity. A ratio close to 1 would indicate one memory access per computation, which would be very inefficient. An unfused softmax or dropout is an example of such an IO-inefficient operation.
Observations
For \(b \sim 1\) or \(m \sim d\), the number of memory accesses is high compared to the number of operations.
For multi-query, the offending term \(m/d\) is reduced by a factor of \(h\) to \(m/(dh)\), as sketched below.
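As a rough back-of-the-envelope sketch, one can keep only the \(1/b\) and \(m/d\) (or \(m/(dh)\)) terms of the ratio \(r\) and compare the two variants. The formulas below are this truncated approximation, not the full expressions from Table 1, and the example values are hypothetical.

```python
# Approximate per-step ratio r = memory accesses / operations during incremental decoding,
# keeping only the dominant 1/b and m/d (or m/(d*h)) terms discussed above.
def r_multi_head(b, m, d):
    return 1 / b + m / d

def r_multi_query(b, m, d, h):
    return 1 / b + m / (d * h)

b, m, d, h = 16, 2048, 4096, 32
print(r_multi_head(b, m, d))       # ~0.56: close to the memory-bound regime
print(r_multi_query(b, m, d, h))   # ~0.08: the m/d term shrinks to m/(d*h)
```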
Batch Computation Cost for Multi-Head Attention (can be skipped)
Batch computation here refers to computing attention for \(n\) query tokens at once. The analysis below shows that the number of memory accesses per operation is much less than 1-to-1, which makes it quite efficient.
The table below shows the analysis for each operation. The memory access complexities are the same for both multi-head and multi-query attention. In practice, the multi-query setting is slightly faster due to lower constants (in MQ, some \(d^2\) terms are reduced to \(dk\), for example, but the total complexity is still bounded by \(d^2\)).
Table 2: Memory Access and Computation Complexities for Batch Computation with Multi-Head and Multi-Query Attention. Note that we use \(n\) and \(m\) interchangeably in the final calculation of memory accesses and number of computations, since they are equal here.
At the end of the calculations, we use \(n=m\) for the usual context encoding case (where the query and key inputs are the same).
Note: we perform some approximations, such as (1) \(dk < d^2\) and (2) \(bnk < bnd\), to arrive at the total memory access cost.
To approximate the total computation, we assume \(d \gg n\), which means \(bnd^2 \gg bn^2d\), so the latter can be ignored.
Both MQ and MH have the same memory access complexity in the batch case, leading to the same efficiency for context encoding.
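As a quick sanity check of the batch case, consider the logits einsum \(bhnk, bhmk \to bhnm\) with \(n = m\); the toy values below are arbitrary.

```python
# Batch logits <Q, K>: bhnk, bhmk -> bhnm, with n = m (context encoding).
b, h, n, d = 8, 16, 2048, 2048
k = d // h
m = n

ops = b * h * n * m * k
mem = b * h * n * k + b * h * m * k + b * h * n * m

print(mem / ops)   # ~ 1/m + 1/n + 1/k << 1, i.e., compute-bound
```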
Implications
Context encoding is the compute-bound regime where all queries and keys interact over all positions at once. Typically, for a ~10B model, the context encoding latency on a single GPU can be around 400 ms for a 2000-token input, which equates to roughly 0.2 ms per token on average. In contrast, the per-token latency of such a model during incremental decoding would typically be around ~10 ms at best. We can see that incremental decoding is roughly 50 times (10 ms / 0.2 ms) less efficient per token.
One can typically perform incremental decoding with similar latency while increasing the batch size from 1 up to the point where GPU memory hits its limit. Increasing the batch size improves inference efficiency since the model parameters are used to compute over many samples rather than just one.
Multi-query attention can reduce memory consumption during incremental decoding quite significantly, and it also flattens the inference latency curve so that it grows much more slowly than in the multi-head case as batch size \(b\) or context length \(m\) increases.
Note: the dimensionality reduction of \(P_K\) and \(P_V\) leads to a lower number of parameters (for example, a 13B multi-head attention model becomes a 10.5B multi-query model, keeping all other configurations constant). To scale the multi-query attention model back up to a similar size, one can increase other configuration values.
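To make the parameter-count effect concrete, here is a minimal sketch of the per-layer attention projection parameters for the two variants, assuming \(hk = hv = d\) and counting only the four projection matrices; the values of \(d\) and \(h\) below are hypothetical and not the exact 13B configuration referenced above.

```python
def attn_params(d, h, k, v, multi_query=False):
    """Parameter count of the four attention projections for one layer."""
    p_q = h * d * k
    p_o = h * d * v
    p_k = d * k if multi_query else h * d * k
    p_v = d * v if multi_query else h * d * v
    return p_q + p_k + p_v + p_o

d, h = 5120, 40
k = v = d // h
print(attn_params(d, h, k, v, multi_query=False))   # 4*d^2           ~ 104.9M per layer
print(attn_params(d, h, k, v, multi_query=True))    # 2*d^2 + 2*d*k   ~ 53.7M per layer
```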
Plot on latency and memory consumption – coming soon!