Memory IO Efficiency of Multi-Query Attention

Multi-query attention can be much more efficient at large batch sizes and long context lengths.

Multi-query attention was first introduced by Shazeer (2019) and was later used in PaLM for inference efficiency. In this blog, we analyze why multi-query attention can be much more efficient than traditional multi-head attention.

Multi-Query Attention at a Glance

The key difference in multi-query attention is that the projection matrices \(P_K\) and \(P_V\) are collapsed to a single output head instead of the full \(h\) heads. All other projection matrices (\(P_Q\) and \(P_O\)) keep their size \(hdk\); \(P_K\) and \(P_V\) have their size reduced from \(hdk\) to \(dk\).

Note that given an input \(x\) with hidden dimension \(d\), during incremental decoding \(x\) is still projected into \(h\) heads to produce the query tensor. Since the query has \(h\) heads, the logits and output computations still involve all heads: the single head of the key and value tensors is broadcast to attend with every head of \(Q\), as in the sketch below.
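To make the shapes concrete, here is a minimal NumPy sketch of one incremental-decoding step with multi-query attention; the sizes and variable names are illustrative assumptions, not taken from any particular implementation.

```python
# Minimal sketch of one incremental-decoding step with multi-query attention.
# Sizes are illustrative; only the einsum shapes matter.
import numpy as np

b, h, d, m = 2, 8, 512, 128          # batch, heads, model dim, cached length
k = v = d // h                        # per-head key/value dims

rng = np.random.default_rng(0)
x = rng.standard_normal((b, d))       # current token's hidden state

P_q = rng.standard_normal((h, d, k))  # query projection keeps h heads
P_k = rng.standard_normal((d, k))     # key projection: a single head (MQ)
P_v = rng.standard_normal((d, v))     # value projection: a single head (MQ)

K_prev = rng.standard_normal((b, m, k))   # cached keys, one head only
V_prev = rng.standard_normal((b, m, v))   # cached values, one head only

q = np.einsum("bd,hdk->bhk", x, P_q)      # query: h heads
K = np.concatenate([K_prev, np.einsum("bd,dk->bk", x, P_k)[:, None]], axis=1)
V = np.concatenate([V_prev, np.einsum("bd,dv->bv", x, P_v)[:, None]], axis=1)

# The single key/value head is broadcast against all h query heads.
logits = np.einsum("bhk,bmk->bhm", q, K)       # bhk, bmk -> bhm
weights = np.exp(logits - logits.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)
out = np.einsum("bhm,bmv->bhv", weights, V)    # bhm, bmv -> bhv
print(out.shape)                               # (b, h, v)
```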

To see why such a simple change leads to dramatically higher efficiency during incremental decoding, we first provide background on counting the memory accesses and computation required for each tensor operation (einsum). Note: one can refer to The Illustrated Attention via Einstein Summation for an introduction to einsum.

Operation and Memory Access Counting (short version)

At a high level, for the tensor computation \(\langle A,B \rangle \to C\) written as an einsum, the number of operations is the product of the sizes of all distinct dimensions appearing in \(A\), \(B\), and \(C\), and the number of memory accesses is the total size of the tensors involved, \(|A| + |B| + |C|\).
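As an illustration, the small helper below (my own sketch; the function name and dimension sizes are assumptions) counts both quantities for the two logits einsums compared later in this post.

```python
# Count multiply-adds and memory accesses for an einsum "A,B->C",
# given the size of every index letter.
import math

def einsum_cost(spec, sizes):
    inputs, out = spec.replace(" ", "").split("->")
    a, b = inputs.split(",")
    # operations: product over the union of all index dimensions
    ops = math.prod(sizes[i] for i in set(a) | set(b) | set(out))
    # memory access: total size of the two inputs plus the output
    mem = sum(math.prod(sizes[i] for i in t) for t in (a, b, out))
    return ops, mem

sizes = dict(b=32, h=8, m=1024, k=64)
print(einsum_cost("bhk,bhmk->bhm", sizes))  # multi-head logits: mem ~ ops
print(einsum_cost("bhk,bmk->bhm", sizes))   # multi-query logits: mem << ops
```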

Operation and Memory Access Counting (longer version, can be skipped)

Memory IO Cost

Now we can analyze the memory IO cost for multi-head and multi-query attention.

Incremental Decoding

Main Takeaway: The calculations that incur the highest memory access cost in ordinary multi-head attention are the logits and output computations, which involve the following tensor operation (shown here for the logits):

Multi Head: \(\langle q,K \rangle : bhk, bhmk \to bhm\)
Here there are \(bhmk\) operations, but they also require roughly \(bhmk\) memory accesses (dominated by loading the per-head key cache), i.e., close to one memory access per operation. This is the memory-bound regime (rather than the compute-bound one) and is inefficient. In contrast, for multi-query,
Multi Query: \(\langle q,K \rangle : bhk, bmk \to bhm\), which requires only \(bhk + bmk\) (plus \(bhm\) for the output) memory accesses.
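As a rough illustration with made-up sizes (say \(b=32\), \(h=16\), \(m=2048\), \(k=128\)), the multi-head logits step touches \(bhmk \approx 1.3 \times 10^8\) elements of memory, while the multi-query version touches roughly \(bhk + bmk + bhm \approx 9.5 \times 10^6\), about \(14\times\) fewer.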

Figure 2: Multi-Query Attention vs Multi-Head Attention. Multi-query is almost identical to multi-head, except that the key and value projection matrices have only 1 head.

Additional Details

The following table analyzes the number of operations and the memory access cost (as tight complexity bounds) for traditional multi-head attention versus multi-query attention.


Table 1: Memory Access and Computation Complexities for Incremental Decoding with Multi-Head and Multi-Query Attention.

\[\scriptsize{ \begin{array}{l|l|c|c} \textbf{Operation} & \textbf{Einsum} & \textbf{Memory Access} & \textbf{Computation} \\\hline \text{Input (x) : bd} & & \\ \rule{0pt}{2em} q = \langle x, P_q \rangle & bd,hdk \rightarrow bhk & bd + hdk = bd + d^2 & bdhk = bd^2 \\ \rule{0pt}{1.5em} K = \langle x, P_k \rangle \ (+ K_{prev}) & [MH] \ bd,{\color{red}{h}} dk \rightarrow b{\color{red}{h}}k \ (+ bm{\color{red}{h}}k) & bd + {\color{red}{d^2}} & bdhk = bd^2 \\ & [MQ] \ bd,dk \rightarrow bk \ (+ bmk) & bd + {\color{red}{dk}} & \\ \rule{0pt}{2em} V = \langle x, P_v \rangle \ (+ V_{prev}) & [MH] \ bd,{\color{red}{h}}dv \rightarrow bhv \ (+ bm{\color{red}{h}}v) & bd + {\color{red}{d^2}} & bdhv = bd^2 \\ & [MQ] \ bd,dv \rightarrow bv \ (+ bmv) & bd + {\color{red}{dv}} & \\ \rule{0pt}{2em} \text{logits} = \langle q, K \rangle & [MH] \ bhk,b{\color{red}{h}}mk \rightarrow bhm & bhk + bhmk = bd + bm{\color{red}{d}} & bhmk = bmd \\ & [MQ] \ bhk,bmk \rightarrow bhm & bd + bm{\color{red}{k}} + {\color{red}{bhm}} & \\ \rule{0pt}{2em} \text{weights: softmax} & & bhm & bhm \\ \rule{0pt}{2em} \text{out(O)} = \langle \text{weights}, V \rangle & [MH] \ bhm,b{\color{red}{h}}mv \rightarrow bhv & bhm + bhmv = bhm + bm{\color{red}{d}} & bhmv = bmd \\ & [MQ] \ bhm,bmv \rightarrow bhv & bhm + bm{\color{red}{v}} + {\color{red}{bhv}} & \\ \rule{0pt}{2em} y=\langle O, P_O \rangle & bhv,hdv \rightarrow bd & bd + d^2 & bdhv = bd^2 \\ \rule{0pt}{2em} \text{Total}\text{: Multi Head} & & bd + bmd + d^2 & bhm + bm{\color{red}{d}} + bd^2 \approx bd^2 \\ \text{Total}\text{: Multi Query} & & bd + bm{\color{red}{k}} + d^2 & \\ \hline \rule{0pt}{1em} r: \text{Multi Head} & & 1/d + m/{\color{red}{d}} + 1/b & \\ r: \text{Multi Query} & & 1/d + m/({\color{red}{dh}}) + 1/b & \\ \end{array} }\]

Note: \(r\) is the ratio of memory access complexity to computation complexity. A ratio close to 1 indicates roughly one memory access per computation, which is very inefficient. An unfused softmax or dropout is an example of such an IO-inefficient operation.
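As a sanity check of these ratios, the small sketch below plugs in illustrative model sizes (hypothetical values, not taken from any measured model).

```python
# Numeric check of the ratio r = memory access / computation for
# incremental decoding, using the formulas from Table 1.
b, h, d, m = 32, 16, 2048, 1024

r_multi_head  = 1 / d + m / d       + 1 / b
r_multi_query = 1 / d + m / (d * h) + 1 / b

print(f"multi-head : r ~ {r_multi_head:.3f}")   # ~0.53, close to memory-bound
print(f"multi-query: r ~ {r_multi_query:.3f}")  # ~0.06, far less memory-bound
```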

Observations

Batch Computation Cost for Multi-Head Attention (can be skipped)

Batch computation in this case refers to computing attention for \(n\) tokens at once. The analysis below shows that the number of memory accesses per operation is much less than 1-to-1, which makes it quite efficient.

The table below shows the analysis for each operation. The memory access complexities are the same for both multi-head and multi-query attention. In practice, the multi-query setting is slightly faster due to lower constants. (In MQ, some \(d^2\) terms are reduced to \(dk\), for example, but the total complexity is still bounded by \(d^2\).)


Table 2: Memory Access and Computation Complexities for Batch Computation with Multi-Head and Multi-Query Attention. Note that \(n\) and \(m\) are used interchangeably in the final memory access and computation totals, since they are equal here.

\[\scriptsize{ \begin{array}{l|l|c|c} \textbf{Operation} & \textbf{Einsum} & \textbf{Memory Access} & \textbf{Computation} \\\hline \text{Input M, N : bmd, bnd} & & \\ \rule{0pt}{2em} q = \langle N, P_q \rangle & bnd,dhk \rightarrow bhnk & bnd + dhk = bnd + d^2 & bndhk = bnd^2 \\ \rule{0pt}{1.5em} K = \langle M, P_k \rangle & [MH] \ bmd,d{\color{red}{h}}k \rightarrow b{\color{red}{h}}mk & bmd + {\color{red}{d^2}} & bmdhk = bmd^2 \\ & [MQ] \ bmd,dk \rightarrow bmk & bmd + {\color{red}{dk}} & \\ \rule{0pt}{2em} V = \langle M, P_v \rangle & [MH] \ bmd,d{\color{red}{h}}v \rightarrow b{\color{red}{h}}mv & bmd + {\color{red}{d^2}} & bmdhv = bmd^2 \\ & [MQ] \ bmd,dv \rightarrow bmv & bmd + {\color{red}{dv}} & \\ \rule{0pt}{2em} \text{logits} = \langle Q, K \rangle & [MH] \ bhnk,b{\color{red}{h}}mk \rightarrow bhnm & bnd + bm{\color{red}{d}} + bhn^2 & bhmnk = bmnd = bn^2d \\ & [MQ] \ bhnk,bmk \rightarrow bhnm & bnd + bm{\color{red}{k}} + bhn^2 & \\ \rule{0pt}{2em} \text{weights: softmax} & & bhnm & bhnm \\ \rule{0pt}{2em} \text{out(O)} = \langle \text{weights}, V \rangle & [MH] \ bhnm,b{\color{red}{h}}mv \rightarrow bhnv & bhnm + bhmv = bhnm + bm{\color{red}{d}} & bhnmv = bmnd = bn^2d \\ & [MQ] \ bhnm,bmv \rightarrow bhnv & bhnm + bm{\color{red}{v}} + {\color{red}{bnd}} & \\ \rule{0pt}{2em} y=\langle O, P_O \rangle & bhnv,hvd \rightarrow bnd & bnd + d^2 & bndhv = bnd^2 \\ \rule{0pt}{2em} \text{Total}\text{: Multi Head} & & \approx bnd + bhn^2 + d^2 & bnd^2 + bn^2d \approx bnd^2 \\ \text{Total}\text{: Multi Query} & & \approx bnd + bhn^2 + d^2 & \\ \hline \rule{0pt}{1em} r: \text{Multi Head} & & 1/d + 1/k + 1/(bn) << 1 & \\ r: \text{Multi Query} & & 1/d + 1/k + 1/(bn) << 1 & \\ \end{array} }\]
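The last two rows can be checked numerically in the same way; the sketch below uses illustrative sizes (an assumption, not a measurement) to show how far below 1 the batch ratio sits.

```python
# Numeric check of the batch-computation ratio r ~ 1/d + 1/k + 1/(b*n)
# from Table 2, with illustrative sizes.
b, h, d, n = 32, 16, 2048, 1024
k = d // h

r_batch = 1 / d + 1 / k + 1 / (b * n)
print(f"batch r ~ {r_batch:.4f}")  # ~0.008 << 1, i.e., compute-bound
```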


Explanation

Implications