MLA
这一章主要参考苏剑林老师的科学空间博文:缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA。
MHA
Multi-Head Attention 指的就是多头注意力,是开山之作 Attention is all you need 中所提出的一种注意力机制。下面考虑增量模式下的 MHA 计算公式(方便起见,已忽略 qkT 除以 dk 的归一化步骤) 。不妨设输入的行向量序列为 (x1,x2,…,xt),xi∈Rd,现在要计算 Attention 层的结果 ot。用 1≤s≤h 表示某个注意力头。
ot(s)=Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)=xiWq(s)∈Rdk,Wq(s)∈Rd×dkki(s)=xiWk(s)∈Rdk,Wk(s)∈Rd×dkvi(s)=xiWv(s)∈Rdv,Wv(s)∈Rd×dv
实践上,常见的设置是 dk=dv=d/h。
模型 |
d |
h |
dk |
dv |
LLAMA2-7B |
4096 |
32 |
128 |
128 |
LLAMA2-70B |
8192 |
64 |
128 |
128 |
预测后续所有 token 时,k≤t(s),v≤t(s) 的值一直保持不变,这部分结果就可以缓存下来供后续生成调用,以避免不必要的重复计算,这就是所谓的 KV Cache。当模型参数量变大后,每个 Attention 结构要耗费 d×(dk+dv)×h 大小的存储空间,因此就推出了 Multi-Query Attention 和 Grouped-Query Attention 来优化显存。MQA 将 KV Cache 全部重复使用(即 k1(s)=⋯=kt(s),v1(s)=…,=vt(s) ),而 GQA 则将其分组重复使用。
MLA 缘起
MLA(Multi-head Latent Attention)应用于 DeepSeek V2 和 V3,利用低秩分解和矩阵吸收来减少 MHA 的 KV Cache 占用,同时让效果不逊于 MQA 和 GQA。其核心思想是引入了 Wc 和 ci:
ci=xiWc∈Rdc,Wc∈Rd×dcqi(s)=xiWq(s)∈Rdk,Wq(s)∈Rd×dkki(s)=ciWk(s)∈Rdk,Wk(s)∈Rdc×dkvi(s)=ciWv(s)∈Rdv,Wv(s)∈Rdc×dv
上述形式在推理时仍然要缓存彼此不同的 ki(s) 和 vi(s)。注意到以下恒等变换:
qt(s)ki(s)⊤=(xtWq(s))(ciWk(s))⊤=xt(Wq(s)Wk(s)⊤)ci⊤
如果用 Wqk(s)=Wq(s)Wk(s)⊤ 替换 Wq(s) ,MHA 里的 ki(s) 就能视为此处的 ci,这样后者可以在不同注意力头之间共享——我们称这种变换为 吸收变换。MHA 的 vi(s) 同样能做吸收变换视为 ci,因为 ot(s) 最后会拼起来过一遍投影矩阵 [ot(1),ot(2),…,ot(h)]WO,WO∈Rdvh×d,所以 Wv(s) 可以和 WO(s) 合并。
注意:吸收变换虽然在数学上等价,在低精度运算下误差会增大。
兼容 RoPE
DeepSeek 使用了 RoPE 这种“用绝对坐标包含相对关系”的位置编码。RoPE 是 dk×dk 的分块对角矩阵 Rm,满足 RmRn⊤=Rm−n。加上 RoPE 后无法进行吸收减缓,因为两个 W 之间会出现和位置相关的 Rt−i 项:
qi(s)=xiWq(s)Ri,ki(s)=ciWk(s)Riqt(s)ki(s)⊤=(xtWq(s)Rt)(ciWk(s)Ri)⊤=xt(Wq(s)Rt−iWk(s)⊤)ci⊤
MLA 的解决方法是将 qi 和 ki 的维度从 dk 扩展到 dk+dr,其中 dr 是专为兼容 RoPE 而设计的,即这部分无法进行吸收变换。MLA 对于 qi 和 ki 的处理略有不同,前者正常训练 h 个 Wqr(s),后者却借用了 MQA 的思想全局共享一个 Wkr(即 ki=xiWkrRi),这样能尽可能降低新引入的 dr 个维度的 K Cache。
标准 MLA
标准的 MLA 结构(吸收变换前)如下。有两个需要注意的点:
- 计算 qi(s) 时模仿了 ki(s) 的方式引入了 ci′ 降维,这一步和减少 KV Cache 无关。猜测是想前向时缓存 ci′。
- 计算 ki(s) 时两个部分左乘的两个向量不一致,右边仍然保持 xi,可能 RoPE 时想尽可能保持原向量。
ot=[ot(1),ot(2),⋯,ot(h)]ot(s)=Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)=[ci′Wqc(s),ci′Wqr(s)Ri]∈Rdk+dr,Wqc(s)∈Rdc′×dk,Wqr(s)∈Rdc′×drki(s)=[ciWkc(s),xiWkrRi]∈Rdk+dr,Wkc(s)∈Rdc×dk,Wkr∈Rd×drvi(s)=ciWv(s)∈Rdv,Wv(s)∈Rdc×dvci′=xiWc′∈Rdc′,Wc′∈Rd×dc′ci=xiWc∈Rdc,Wc∈Rd×dc
经过吸收变换后的 MLA 结构如下:
ot=[ot(1)Wv(1),ot(2)Wv(2),⋯,ot(h)Wv(h)]ot(s)=Attention(qt(s),k≤t,c≤t)≜∑i≤texp(qt(s)ki⊤)∑i≤texp(qt(s)ki⊤)ciqi(s)=[ci′Wqc(s)Wkc(s)⊤,ci′Wqr(s)Ri]∈Rdc+drki=[ci,xiWkrRi]∈Rdc+drWqc(s)∈Rdc′×dk,Wkc(s)∈Rdc×dk,Wqr(s)∈Rdc′×dr,Wkr∈Rd×drci′=xiWc′∈Rdc′,Wc′∈Rd×dc′ci=xiWc∈Rdc,Wc∈Rd×dc
模型 |
d |
dc′ |
dc |
dk |
dv |
dr |
DeepSeek V3-16B |
2048 |
— |
512 |
128 |
128 |
64 |
DeepSeek V3-671B |
7168 |
1536 |
512 |
128 |
128 |
64 |
DeepSeek V3 推理代码
DeepSeek V3 推理代码详见其开源的 Github。
MLA 层
【待施工】

图源:河小涌 DeepSeek-v3代码demo解读:架构与推理!