Memory IO Efficiency of Multi-Query Attention

Multi-query attention can be much more efficient than standard multi-head attention at large batch sizes and long context lengths.

Multi-query attention was first introduced in [1] and was later used in PaLM [2] for inference efficiency. In this blog, we analyze why multi-query attention can be much more efficient than traditional multi-head attention.

Multi-Query Attention at a Glance

The key change in multi-query attention is to collapse the key and value projection matrices P_K and P_V to a single output head instead of the full h heads. All other projection matrices (P_Q and P_O) keep their size of h·d_k per input dimension, while P_K and P_V are reduced from h·d_k to d_k.

Note that given an input x with hidden dimension d, during incremental decoding x is still projected to h heads to produce the query tensor. Since the query has h heads while the key and value tensors have only one, the per-head interactions in the logits and output computations are preserved: the single key/value head is broadcast to attend against all h query heads.
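
As a minimal sketch (assuming arbitrary shapes and random weights; names such as P_q and K_prev are illustrative, not from any library), a single incremental decoding step of multi-query attention written with numpy einsum could look like this:

```python
import numpy as np

b, h, d, m = 2, 8, 512, 128              # batch, query heads, model dim, cached context length
k = v = d // h                           # per-head key/value dims

rng = np.random.default_rng(0)
x = rng.standard_normal((b, d))          # hidden state of the newly decoded token
P_q = rng.standard_normal((h, d, k))     # query projection: h heads
P_k = rng.standard_normal((d, k))        # key projection: a single shared head
P_v = rng.standard_normal((d, v))        # value projection: a single shared head
P_o = rng.standard_normal((h, d, v))     # output projection: h heads
K_prev = rng.standard_normal((b, m, k))  # cached keys, no head axis
V_prev = rng.standard_normal((b, m, v))  # cached values, no head axis

q = np.einsum('bd,hdk->bhk', x, P_q)     # the query still has h heads
K = np.concatenate([K_prev, np.einsum('bd,dk->bk', x, P_k)[:, None]], axis=1)
V = np.concatenate([V_prev, np.einsum('bd,dv->bv', x, P_v)[:, None]], axis=1)

logits = np.einsum('bhk,bmk->bhm', q, K)             # 1 key head broadcast over h query heads
weights = np.exp(logits - logits.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)            # softmax over the m+1 context positions
out = np.einsum('bhm,bmv->bhv', weights, V)          # 1 value head broadcast over h query heads
y = np.einsum('bhv,hdv->bd', out, P_o)               # back to the model dimension
print(q.shape, K.shape, logits.shape, y.shape)       # (2, 8, 64) (2, 129, 64) (2, 8, 129) (2, 512)
```

The only difference from multi-head attention is that P_k, P_v, K_prev, and V_prev carry no head axis; the einsum strings 'bhk,bmk->bhm' and 'bhm,bmv->bhv' broadcast the single key/value head across all h query heads.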

To see why such a simple change can lead to dramatically higher efficiency during incremental decoding, we first provide background on counting the memory accesses and computation required for each tensor operation (einsum). Note: One can refer to The Illustrated Attention via Einstein Summation for an introduction to einsum.

Operation and Memory Access Counting (short version)

At a high level, for a tensor operation (einsum) A, B → C, the number of operations is the product of the sizes of all distinct dimensions that appear in the einsum, and the memory access is the total number of elements in the tensors involved, i.e. |A| + |B| + |C|.
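
As an illustrative sketch of this counting rule (not any library API; `einsum_cost` is a hypothetical helper), one could compute both quantities for an einsum spec as follows:

```python
from math import prod

def einsum_cost(spec: str, **dim_sizes: int):
    """Count multiply-adds and memory accesses for an einsum like 'bhk,bmk->bhm'.

    Operations    = product of the sizes of all distinct index letters.
    Memory access = total number of elements in the inputs and the output.
    """
    inputs, output = spec.split('->')
    operands = inputs.split(',') + [output]
    ops = prod(dim_sizes[i] for i in set(''.join(operands)))
    mem = sum(prod(dim_sizes[i] for i in term) for term in operands)
    return ops, mem

# Multi-head vs. multi-query logits during incremental decoding
sizes = dict(b=32, h=16, m=2048, k=64)
print(einsum_cost('bhk,bhmk->bhm', **sizes))  # MH: ops ~ bhmk, mem ~ bhmk
print(einsum_cost('bhk,bmk->bhm',  **sizes))  # MQ: ops ~ bhmk, mem ~ bhk + bmk + bhm
```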

Operation and Memory Access Counting (longer version, can be skipped)

Memory IO Cost

Now we can analyze the memory IO cost for multi-head and multi-query attention.

Incremental Decoding

Main Takeaway: The calculations that incur the highest memory access cost in normal multi-head attention are the logits and output computations, which involve the following tensor operation (shown here for the logits):

Multi-Head: q, K : bhk, bhmk → bhm

Here there are bhmk operations, but they also require roughly bhmk memory accesses (dominated by reading the cached key tensor of size bhmk), i.e. about one memory access per operation. This is the memory-bound regime (rather than the compute-bound one) and is inefficient. In contrast, the multi-query version

Multi-Query: q, K : bhk, bmk → bhm

performs the same bhmk operations but requires only bhk + bmk + bhm memory accesses.
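
To make this concrete, here is a small worked example with hypothetical sizes (b=32, h=16, m=2048, k=64, so d = hk = 1024), counting elements moved versus multiply-adds for the logits einsum alone:

```python
b, h, m, k = 32, 16, 2048, 64

ops = b * h * m * k                      # same for multi-head and multi-query

mh_mem = b*h*k + b*h*m*k + b*h*m         # q + K (h heads) + logits
mq_mem = b*h*k + b*m*k   + b*h*m         # q + K (1 shared head) + logits

print(f"multi-head : {ops/mh_mem:5.1f} ops per element moved")   # ~1.0  -> memory bound
print(f"multi-query: {ops/mq_mem:5.1f} ops per element moved")   # ~12.7 -> far less IO per op
```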

Figure 2: Multi-Query Attention vs. Multi-Head Attention. Multi-query is almost identical to multi-head, except that the key and value projection matrices have only one head.

Additional Details

The following table analyzes the number of operations and the memory access cost (as tight complexity bounds) for traditional multi-head attention versus multi-query attention.


Table 1: Memory Access and Computation Complexities for Incremental Decoding with Multi-Head and Multi-Query Attention.

| Operation | Einsum | Memory Access | Computation |
|---|---|---|---|
| Input (x) | bd | | |
| q = x · P_q | bd, hdk → bhk | bd + hdk = bd + d² | bdhk = bd² |
| K = x · P_k (+ K_prev) | [MH] bd, hdk → bhk (+ bmhk) | bd + d² | bdhk = bd² |
| | [MQ] bd, dk → bk (+ bmk) | bd + dk | |
| V = x · P_v (+ V_prev) | [MH] bd, hdv → bhv (+ bmhv) | bd + d² | bdhv = bd² |
| | [MQ] bd, dv → bv (+ bmv) | bd + dv | |
| logits = q · K | [MH] bhk, bhmk → bhm | bhk + bhmk = bd + bmd | bhmk = bmd |
| | [MQ] bhk, bmk → bhm | bd + bmk + bhm | |
| weights = softmax(logits) | bhm → bhm | bhm | bhm |
| out (O) = weights · V | [MH] bhm, bhmv → bhv | bhm + bhmv = bhm + bmd | bhmv = bmd |
| | [MQ] bhm, bmv → bhv | bhm + bmv + bhv | |
| y = O · P_O | bhv, hdv → bd | bd + d² | bdhv = bd² |
| Total: Multi-Head | | bd + bmd + d² | bhm + bmd + bd² ≈ bd² |
| Total: Multi-Query | | bd + bmk + d² | ≈ bd² |
| r: Multi-Head | | 1/d + m/d + 1/b | |
| r: Multi-Query | | 1/d + m/(dh) + 1/b | |

Note: r is the ratio of the memory access complexity to the computation complexity. A ratio close to 1 indicates roughly one memory access per computation, which is very inefficient. An unfused softmax or dropout is an example of such an IO-inefficient operation.
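
Plugging hypothetical sizes into the totals of Table 1 gives a feel for r; this is only a rough sketch of the asymptotic expressions, ignoring constant factors:

```python
def r_multi_head(b, m, d):
    # memory (bd + bmd + d^2) over computation (~ bd^2), from the Table 1 totals
    return 1/d + m/d + 1/b

def r_multi_query(b, m, d, h):
    # the dominant m/d term shrinks by the number of heads h
    return 1/d + m/(d*h) + 1/b

b, m, d, h = 32, 2048, 1024, 16
print(f"multi-head : r = {r_multi_head(b, m, d):.3f}")       # ~2.03 -> memory bound
print(f"multi-query: r = {r_multi_query(b, m, d, h):.3f}")   # ~0.16 -> roughly 10x fewer accesses per op
```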

Observations

Batch Computation Cost for Multi-Head Attention (can be skipped)

Batch computation here refers to computing attention for n tokens at once. The analysis below shows that the number of memory accesses per operation is much less than one, which makes it quite efficient.

The table below shows the analysis for each operation. The memory access complexities are the same for both multi-head and multi-query attention. In practice, the multi-query setting is slightly faster due to lower constants (in multi-query, some d² terms are reduced to dk, for example, but the total complexity is still bounded by d²).


Table 2: Memory Access and Computation Complexities for Batch Computation with Multi-Head and Multi-Query Attention. Note that in the final totals we use n and m interchangeably, since they are equal here.

| Operation | Einsum | Memory Access | Computation |
|---|---|---|---|
| Input (M, N) | bmd, bnd | | |
| Q = N · P_q | bnd, dhk → bhnk | bnd + dhk = bnd + d² | bndhk = bnd² |
| K = M · P_k | [MH] bmd, dhk → bhmk | bmd + d² | bmdhk = bmd² |
| | [MQ] bmd, dk → bmk | bmd + dk | |
| V = M · P_v | [MH] bmd, dhv → bhmv | bmd + d² | bmdhv = bmd² |
| | [MQ] bmd, dv → bmv | bmd + dv | |
| logits = Q · K | [MH] bhnk, bhmk → bhnm | bnd + bmd + bhn² | bhmnk = bmnd = bn²d |
| | [MQ] bhnk, bmk → bhnm | bnd + bmk + bhn² | |
| weights = softmax(logits) | bhnm → bhnm | bhnm | bhnm |
| out (O) = weights · V | [MH] bhnm, bhmv → bhnv | bhnm + bhmv = bhnm + bmd | bhnmv = bmnd = bn²d |
| | [MQ] bhnm, bmv → bhnv | bhnm + bmv + bnd | |
| y = O · P_O | bhnv, hvd → bnd | bnd + d² | bndhv = bnd² |
| Total: Multi-Head | | bnd + bhn² + d² | bnd² + bn²d |
| Total: Multi-Query | | bnd + bhn² + d² | bnd² + bn²d |
| r: Multi-Head | | 1/d + 1/k + 1/(bn) ≪ 1 | |
| r: Multi-Query | | 1/d + 1/k + 1/(bn) ≪ 1 | |
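
For the batch case of Table 2, the same kind of plug-in (again only the asymptotic expression, with illustrative sizes) confirms that both variants stay far below one memory access per operation:

```python
def r_batch(b, n, d, k):
    # memory (bnd + bhn^2 + d^2) over computation (bnd^2 + bn^2*d); identical for MH and MQ
    return 1/d + 1/k + 1/(b*n)

print(f"batch attention: r = {r_batch(b=32, n=2048, d=1024, k=64):.4f}")   # ~0.017 << 1
```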


Explanation

Implications


References

1. Fast Transformer Decoding: One Write-Head is All You Need [PDF]
   Shazeer, N., 2019. CoRR, Vol abs/1911.02150.
2. PaLM: Scaling Language Modeling with Pathways [link]
   Chowdhery, A., Narang, S., Devlin, J., Bosma, M., Mishra, G., Roberts, A., Barham, P., Chung, H.W., Sutton, C., Gehrmann, S., Schuh, P., Shi, K., Tsvyashchenko, S., Maynez, J., Rao, A., Barnes, P., Tay, Y., Shazeer, N., Prabhakaran, V., Reif, E., Du, N., Hutchinson, B., Pope, R., Bradbury, J., Austin, J., Isard, M., Gur-Ari, G., Yin, P., Duke, T., Levskaya, A., Ghemawat, S., Dev, S., Michalewski, H., Garcia, X., Misra, V., Robinson, K., Fedus, L., Zhou, D., Ippolito, D., Luan, D., Lim, H., Zoph, B., Spiridonov, A., Sepassi, R., Dohan, D., Agrawal, S., Omernick, M., Dai, A.M., Pillai, T.S., Pellat, M., Lewkowycz, A., Moreira, E., Child, R., Polozov, O., Lee, K., Zhou, Z., Wang, X., Saeta, B., Diaz, M., Firat, O., Catasta, M., Wei, J., Meier-Hellstern, K., Eck, D., Dean, J., Petrov, S. and Fiedel, N., 2022. CoRR, Vol abs/2204.02311. DOI: 10.48550/arXiv.2204.02311