MLA

这一章主要参考苏剑林老师的科学空间博文:缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA

MHA

Multi-Head Attention 指的就是多头注意力,是开山之作 Attention is all you need 中所提出的一种注意力机制。下面考虑增量模式下的 MHA 计算公式(方便起见,已忽略 qkT\boldsymbol{q}\boldsymbol{k}^T 除以 dk\sqrt{d_k} 的归一化步骤) 。不妨设输入的行向量序列为 (x1,x2,,xt),xiRd(\boldsymbol{x}_1,\boldsymbol{x}_2,\dots,\boldsymbol{x}_t),x_i \in \mathbb{R}^{d},现在要计算 Attention 层的结果 ot\boldsymbol{o}_t。用 1sh1 \le s \le h 表示某个注意力头。

ot(s)=Attention(qt(s),kt(s),vt(s))itexp(qt(s)ki(s))vi(s)itexp(qt(s)ki(s))qi(s)=xiWq(s)Rdk,Wq(s)Rd×dkki(s)=xiWk(s)Rdk,Wk(s)Rd×dkvi(s)=xiWv(s)Rdv,Wv(s)Rd×dv\begin{equation} \begin{gathered} \boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{k}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d\times d_k} \\ \boldsymbol{v}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d\times d_v} \end{gathered} \end{equation}

实践上,常见的设置是 dk=dv=d/hd_k = d_v = d / h

模型 dd hh dkd_k dvd_v
LLAMA2-7B 40964096 3232 128128 128128
LLAMA2-70B 81928192 6464 128128 128128

预测后续所有 token 时,kt(s),vt(s)\boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)} 的值一直保持不变,这部分结果就可以缓存下来供后续生成调用,以避免不必要的重复计算,这就是所谓的 KV Cache。当模型参数量变大后,每个 Attention 结构要耗费 d×(dk+dv)×hd \times (d_k+d_v) \times h 大小的存储空间,因此就推出了 Multi-Query Attention 和 Grouped-Query Attention 来优化显存。MQA 将 KV Cache 全部重复使用(即 k1(s)==kt(s),v1(s)=,=vt(s)\boldsymbol{k}_1^{(s)}=\dots=\boldsymbol{k}_t^{(s)},\boldsymbol{v}_1^{(s)}=\dots,=\boldsymbol{v}_t^{(s)} ),而 GQA 则将其分组重复使用。

MLA 缘起

MLA(Multi-head Latent Attention)应用于 DeepSeek V2 和 V3,利用低秩分解和矩阵吸收来减少 MHA 的 KV Cache 占用,同时让效果不逊于 MQA 和 GQA。其核心思想是引入了 Wc\boldsymbol{W}_cci\boldsymbol{c}_i

ci=xiWcRdc,WcRd×dcqi(s)=xiWq(s)Rdk,Wq(s)Rd×dkki(s)=ciWk(s)Rdk,Wk(s)Rdc×dkvi(s)=ciWv(s)Rdv,Wv(s)Rdc×dv\begin{equation} \begin{gathered} \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\[10pt] \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_c\times d_k} \\ \boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \end{gathered} \end{equation}

上述形式在推理时仍然要缓存彼此不同的 ki(s)\boldsymbol{k}_i^{(s)}vi(s)\boldsymbol{v}_i^{(s)}。注意到以下恒等变换:

qt(s)ki(s)=(xtWq(s))(ciWk(s))=xt(Wq(s)Wk(s))ci\begin{equation}\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top} = \left(\boldsymbol{x}_t\boldsymbol{W}_q^{(s)}\right) \left(\boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\right){}^{\top} = \boldsymbol{x}_t\left(\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}\right)\boldsymbol{c}_i^{\top} \end{equation} \\

如果用 Wqk(s)=Wq(s)Wk(s)\boldsymbol{W}_{qk}^{(s)}=\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top} 替换 Wq(s)\boldsymbol{W}_q^{(s)} ,MHA 里的 ki(s)\boldsymbol{k}_i^{(s)} 就能视为此处的 ci\boldsymbol{c}_i,这样后者可以在不同注意力头之间共享——我们称这种变换为 吸收变换。MHA 的 vi(s)\boldsymbol{v}_i^{(s)} 同样能做吸收变换视为 ci\boldsymbol{c}_i,因为 ot(s)\boldsymbol{o}_t^{(s)} 最后会拼起来过一遍投影矩阵 [ot(1),ot(2),,ot(h)]WO,WORdvh×d[\boldsymbol{o}_t^{(1)},\boldsymbol{o}_t^{(2)},\dots,\boldsymbol{o}_t^{(h)}]\boldsymbol{W}_O,\quad \boldsymbol{W}_O \in\mathbb{R}^{d_vh\times d},所以 Wv(s)\boldsymbol{W}_v^{(s)} 可以和 WO(s)\boldsymbol{W}_O^{(s)} 合并。

注意:吸收变换虽然在数学上等价,在低精度运算下误差会增大。

兼容 RoPE

DeepSeek 使用了 RoPE 这种“用绝对坐标包含相对关系”的位置编码。RoPE 是 dk×dkd_k \times d_k 的分块对角矩阵 Rm\boldsymbol{\mathcal{R}}_m,满足 RmRn=Rmn\boldsymbol{\mathcal{R}}_m\boldsymbol{\mathcal{R}}_n^{\top}=\boldsymbol{\mathcal{R}}_{m-n}。加上 RoPE 后无法进行吸收减缓,因为两个 W\boldsymbol{W} 之间会出现和位置相关的 Rti\boldsymbol{\mathcal{R}}_{t-i} 项:

qi(s)=xiWq(s)Ri,ki(s)=ciWk(s)Riqt(s)ki(s)=(xtWq(s)Rt)(ciWk(s)Ri)=xt(Wq(s)RtiWk(s))ci\begin{equation} \boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}{\boldsymbol{\mathcal{R}}_i}\quad,\quad\boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_k^{(s)}{\boldsymbol{\mathcal{R}}_i} \\ \boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top} = \left(\boldsymbol{x}_t\boldsymbol{W}_q^{(s)}{\boldsymbol{\mathcal{R}}_t}\right) \left(\boldsymbol{c}_i\boldsymbol{W}_k^{(s)}{\boldsymbol{\mathcal{R}}_i}\right){}^{\top} = \boldsymbol{x}_t\left(\boldsymbol{W}_q^{(s)}{\boldsymbol{\mathcal{R}}_{t-i}} \boldsymbol{W}_k^{(s)}{}^{\top}\right)\boldsymbol{c}_i^{\top} \end{equation}

MLA 的解决方法是将 qi\boldsymbol{q}_iki\boldsymbol{k}_i 的维度从 dkd_k 扩展到 dk+drd_k+d_r,其中 drd_r 是专为兼容 RoPE 而设计的,即这部分无法进行吸收变换。MLA 对于 qi\boldsymbol{q}_iki\boldsymbol{k}_i 的处理略有不同,前者正常训练 hhWqr(s)\boldsymbol{W}_{qr}^{(s)},后者却借用了 MQA 的思想全局共享一个 Wkr\boldsymbol{W}_{kr}(即 ki=xiWkrRi\boldsymbol{k}_i=\boldsymbol{x}_i\boldsymbol{W}_{kr}{\boldsymbol{\mathcal{R}}_i}),这样能尽可能降低新引入的 drd_r 个维度的 K Cache。

标准 MLA

标准的 MLA 结构(吸收变换前)如下。有两个需要注意的点:

  • 计算 qi(s)\boldsymbol{q}_i^{(s)} 时模仿了 ki(s)\boldsymbol{k}_i^{(s)} 的方式引入了 ci\boldsymbol{c}_i' 降维,这一步和减少 KV Cache 无关。猜测是想前向时缓存 ci\boldsymbol{c}_i'
  • 计算 ki(s)\boldsymbol{k}_i^{(s)} 时两个部分左乘的两个向量不一致,右边仍然保持 xi\boldsymbol{x}_i,可能 RoPE 时想尽可能保持原向量。

ot=[ot(1),ot(2),,ot(h)]ot(s)=Attention(qt(s),kt(s),vt(s))itexp(qt(s)ki(s))vi(s)itexp(qt(s)ki(s))qi(s)=[ciWqc(s),ciWqr(s)Ri]Rdk+dr,Wqc(s)Rdc×dk,Wqr(s)Rdc×drki(s)=[ciWkc(s),xiWkrRi]Rdk+dr,Wkc(s)Rdc×dk,WkrRd×drvi(s)=ciWv(s)Rdv,Wv(s)Rdc×dvci=xiWcRdc,WcRd×dcci=xiWcRdc,WcRd×dc\begin{equation} \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \left[\boldsymbol{c}_i'\boldsymbol{W}_{qc}^{(s)}, \boldsymbol{c}_i'\boldsymbol{W}_{qr}^{(s)}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k + d_r},\quad \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c'\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c'\times d_r}\\ \boldsymbol{k}_i^{(s)} = \left[\boldsymbol{c}_i\boldsymbol{W}_{kc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{kr}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k+d_r},\quad \boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k}, \boldsymbol{W}_{kr}\in\mathbb{R}^{d\times d_r} \\ \boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt] \boldsymbol{c}_i' = \boldsymbol{x}_i \boldsymbol{W}_c'\in\mathbb{R}^{d_c'},\quad \boldsymbol{W}_c'\in\mathbb{R}^{d\times d_c'} \\ \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\ \end{gathered} \end{equation}

经过吸收变换后的 MLA 结构如下:

ot=[ot(1)Wv(1),ot(2)Wv(2),,ot(h)Wv(h)]ot(s)=Attention(qt(s),kt,ct)itexp(qt(s)ki)ciitexp(qt(s)ki)qi(s)=[ciWqc(s)Wkc(s),ciWqr(s)Ri]Rdc+drki=[ci,xiWkrRi]Rdc+drWqc(s)Rdc×dk,Wkc(s)Rdc×dk,Wqr(s)Rdc×dr,WkrRd×drci=xiWcRdc,WcRd×dcci=xiWcRdc,WcRd×dc\begin{equation} \begin{gathered} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}\boldsymbol{W}_v^{(1)}, \boldsymbol{o}_t^{(2)}\boldsymbol{W}_v^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\boldsymbol{W}_v^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t} ,\boldsymbol{c}_{\leq t}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i{}^{\top}\right)\boldsymbol{c}_i}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} = \left[\boldsymbol{c}_i'\boldsymbol{W}_{qc}^{(s)}\boldsymbol{W}_{kc}^{(s)}{}^{\top}, \boldsymbol{c}_i'\boldsymbol{W}_{qr}^{(s)}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_c + d_r}\\ \boldsymbol{k}_i = \left[\boldsymbol{c}_i, \boldsymbol{x}_i\boldsymbol{W}_{kr}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_c+d_r}\\ \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c'\times d_k},\boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c'\times d_r},\boldsymbol{W}_{kr}\in\mathbb{R}^{d\times d_r} \\[10pt] \boldsymbol{c}_i' = \boldsymbol{x}_i \boldsymbol{W}_c'\in\mathbb{R}^{d_c'},\quad \boldsymbol{W}_c'\in\mathbb{R}^{d\times d_c'} \\ \boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\ \end{gathered} \end{equation}

模型 dd dcd_c' dcd_c dkd_k dvd_v drd_r
DeepSeek V3-16B 20482048 512512 128128 128128 6464
DeepSeek V3-671B 71687168 15361536 512512 128128 128128 6464

DeepSeek V3 推理代码

DeepSeek V3 推理代码详见其开源的 Github

MLA 层

【待施工】

图源:河小涌 DeepSeek-v3代码demo解读:架构与推理!