跳转到主要内容
AI Systems

KV Cache:推理性能的命根子

约 14 分钟阅读

1. 为什么需要 KV Cache?

1.1 回顾 Autoregressive Decoding

你已经知道 decode 阶段每次只生成 1 个 token。关键问题是:生成第 N 个 token 时,attention 需要看到前面所有 N-1 个 token 的 Key 和 Value。

如果不缓存,decode 阶段每生成一个 token 都要重新计算所有历史 token 的 K、V 向量:

decode 阶段(无 KV Cache 的假设):
  生成 token 1: 计算 1 个 token 的 K, V                    → 计算量 ∝ 1
  生成 token 2: 重新计算前 2 个 token 的 K, V               → 计算量 ∝ 2
  生成 token 3: 重新计算前 3 个 token 的 K, V               → 计算量 ∝ 3
  ...
  生成 token N: 重新计算前 N 个 token 的 K, V               → 计算量 ∝ N

  总计算量 = 1 + 2 + 3 + ... + N = N(N+1)/2 ≈ O(N²)

KV Cache 的核心思想:把已经算过的 K、V 存起来,下次直接用

decode 阶段(有 KV Cache):
  生成 token 1: 计算 K₁, V₁ → 存入 cache                  → 计算量 ∝ 1
  生成 token 2: 从 cache 读 K₁,V₁, 只算 K₂,V₂ → 存入 cache → 计算量 ∝ 1
  生成 token 3: 从 cache 读 K₁₋₂,V₁₋₂, 只算 K₃,V₃ → 存入  → 计算量 ∝ 1
  ...
  每步只需计算 1 个新 token 的 K, V

  总计算量 = 1 × N = O(N)

代价:用显存换计算量。计算从 O(N²) 降到 O(N),但显存占用线性增长。

1.2 前置知识:Self-Attention 与 Multi-Head Attention

Self-Attention(自注意力)

Self-Attention 让序列中的每个 token 去看同一个序列中所有其他 token,决定”我应该关注谁”:

句子: "The cat sat on the mat because it was tired"

处理 "it" 时 → 给每个 token 打相关性分数 → "cat" 得分最高
→ "it" 的输出表示会融入 "cat" 的信息

计算过程:

输入序列 X = [x₁, x₂, ..., x_n]

1. 每个 token 通过三个权重矩阵生成三个向量:
   Q = X · W_Q   (Query:  "我在找什么?")
   K = X · W_K   (Key:    "我能提供什么?")
   V = X · W_V   (Value:  "我的实际内容")

2. 计算相关性: Score = Q · K^T → [n × n] 矩阵
3. 归一化:     Weights = softmax(Score / √d_k)
4. 加权求和:   Output = Weights · V

叫 “self” 是因为 Q、K、V 都来自同一个序列。

Q·K^T 的数学含义:向量点积 = 相似度

Q·K^T 的本质是点积(dot product),衡量两个向量的相似度:

点积: a · b = |a| × |b| × cos(θ)

cos(0°)=1  → 方向相同 → 点积大 → 最相似
cos(90°)=0 → 正交     → 点积零 → 不相关

Q·K^T 得到 [n × n] 矩阵,每个元素 qᵢ·kⱼ = token i 和 token j 的匹配度

除以 √d_k 是为了防止高维点积值过大导致 softmax 变成 one-hot(梯度消失)。

为什么需要独立的 K,不能直接用 V 匹配?

K 和 V 的职责不同:K 负责”被搜索到”,V 负责”提供内容”

类比图书馆:
  Q = 你的搜索词
  K = 每本书的索引标签(用来匹配)
  V = 每本书的实际内容(匹配后取出)

如果没有 K,直接用 V 匹配:
  V 既要当好索引(简洁、区分度高),又要当好内容(丰富、完整)
  两个目标矛盾 → 效果差

有了 K:
  W_K 学习: 让相关 token 的 key 在投影空间里方向接近(点积大)
  W_V 学习: 让 token 携带丰富的语义内容
  各司其职 → 效果好

前沿研究:2025 年论文 Key and Value Weights Are Probably All You Need 证明 Q/K/V 三个权重矩阵中有一个可以用单位矩阵替代(参数减少 25%), 但完全干掉 K 在大模型上仍有质量损失。另一个方向是 Linear Attention(Mamba/RWKV), 改变 Q·K 的交互方式,将 KV Cache 从线性增长变为固定大小。

Multi-Head Attention(多头注意力)

单头 attention 只能学一种关注模式。Multi-Head 把 attention 拆成多个”头”,每个头独立学不同的模式:

hidden_dim = 4096 (Llama-3-8B)
num_heads  = 32
head_dim   = hidden_dim / num_heads = 4096 / 32 = 128

把 Q, K, V 各拆成 32 份,每份维度 128:
  Head 0:  Attention(Q₀[seq,128], K₀[seq,128], V₀[seq,128]) → Out₀[seq,128]
  Head 1:  Attention(Q₁[seq,128], K₁[seq,128], V₁[seq,128]) → Out₁[seq,128]
  ...
  Head 31: Attention(Q₃₁[seq,128], K₃₁[seq,128], V₃₁[seq,128]) → Out₃₁[seq,128]

最后拼回去: Output = Concat(Out₀, ..., Out₃₁) · W_O → [seq, 4096]

Layers(层数)

Transformer 是多层堆叠,每层做同样的事(attention + FFN),但权重独立:

输入 embedding

  Layer 0:  Multi-Head Attention → FFN
  Layer 1:  Multi-Head Attention → FFN
  ...
  Layer 31: Multi-Head Attention → FFN    ← Llama-3-8B 共 32 层

输出 logits

每一层都有独立的 KV Cache(因为每层的 K、V 不同)

1.3 KV Cache 在 Attention 中的位置

有了上面的基础,Self-Attention 的公式就很清晰了:

Q = X · W_Q    (Query)
K = X · W_K    (Key)
V = X · W_V    (Value)

Attention(Q, K, V) = softmax(Q · K^T / √d_k) · V

在 decode 阶段:

  • Q:只有当前新 token 的 query(1 个向量)
  • K:cache 中所有历史 token 的 key + 当前新 token 的 key
  • V:cache 中所有历史 token 的 value + 当前新 token 的 value

所以每一步 decode,KV Cache 都在增长:

Step 1: cache = [K₁, V₁]
Step 2: cache = [K₁, K₂], [V₁, V₂]
Step 3: cache = [K₁, K₂, K₃], [V₁, V₂, V₃]
...
Step N: cache = [K₁...K_N], [V₁...V_N]

2. KV Cache 内存计算公式

2.1 核心公式(MHA - Multi-Head Attention)

KV Cache 内存 = 2 × layers × heads × head_dim × seq_len × batch_size × bytes_per_element

各参数含义(回顾 1.2 节的概念):

参数含义为什么影响 KV Cache示例 (Llama-3-8B)
2K 和 V 两个张量每个 token 要存 K 和 V 两份固定值
layersTransformer 层数每层有独立的 KV,层越多份数越多32
headsKV 注意力头数每个头有独立的 K、V 向量32 (MHA)
head_dim每个头的维度每个 K/V 向量的长度,= hidden_dim / num_heads128 (= 4096/32)
seq_len已缓存的 token 数每个 token 都要存,序列越长 cache 越大2048
batch_size同时处理的请求数每个请求有独立的 KV Cache1
bytes_per_element数据精度FP16=2字节, INT8=1字节, INT4=0.5字节2 (FP16/BF16)

2.2 实际计算示例

Llama-3-8B (MHA), FP16, seq_len=2048, batch=1:

2 × 32 × 32 × 128 × 2048 × 1 × 2 bytes
= 2 × 32 × 32 × 128 × 2048 × 2
= 1,073,741,824 bytes
= 1 GB

一条请求的 KV Cache 就要 1GB 显存!

batch=32 时:

1 GB × 32 = 32 GB  ← 光 KV Cache 就 32GB

加上模型权重(8B × 2 bytes = 16GB),总共需要 48GB,A100-40GB 直接 OOM。

2.3 不同模型的 KV Cache 对比

模型LayersHeadshead_dimKV Cache/token/batch (FP16)
Llama-3-8B3232 (MHA)1280.5 MB/token
Llama-3-70B8064 (MHA)1282.5 MB/token
Mistral-7B328 (GQA)1280.125 MB/token

注意 Mistral 用了 GQA(Grouped-Query Attention),KV heads 只有 8 个而不是 32 个, KV Cache 直接缩小 4 倍!

2.4 MHA vs MQA vs GQA 对 KV Cache 的影响

这是现代模型架构减少 KV Cache 的核心手段:

MHA (Multi-Head Attention):
  每个 attention head 都有独立的 K, V
  KV heads = Q heads = 32
  KV Cache 最大

MQA (Multi-Query Attention):
  所有 Q heads 共享 1 组 K, V
  KV heads = 1
  KV Cache 最小,但质量可能下降

GQA (Grouped-Query Attention):
  Q heads 分组,每组共享 1 组 K, V
  KV heads = Q heads / group_size (如 32/4 = 8)
  KV Cache 适中,质量接近 MHA
MHA:  Q₁K₁V₁  Q₂K₂V₂  Q₃K₃V₃  Q₄K₄V₄   ← 4组 KV
GQA:  Q₁Q₂→K₁V₁  Q₃Q₄→K₂V₂              ← 2组 KV (省一半)
MQA:  Q₁Q₂Q₃Q₄→K₁V₁                      ← 1组 KV (省最多)

GQA 公式修正:

KV Cache = 2 × layers × kv_heads × head_dim × seq_len × batch × bytes

注意这里用 kv_heads 而不是 q_heads

2.5 深入理解 GQA (Grouped-Query Attention)

GQA 是 Google 在 2023 年论文 GQA: Training Generalized Multi-Query Attention from Multi-Head Checkpoints 中提出的,目前已成为主流大模型的标配(Llama 2/3、Mistral、Gemma、Qwen 等)。

为什么需要 GQA?

MHA 的问题很直接:KV Cache 太大。MQA 把 KV heads 压到 1 个,Cache 最小,但质量下降明显——所有 Q head 被迫共享同一组 K、V,表达能力受限。

GQA 的思路是找一个中间点:把 Q heads 分成若干组,每组共享一组 K、V

以 Llama-3-8B 为例 (q_heads=32, kv_heads=8, group_size=4):

Group 0:  Q₀  Q₁  Q₂  Q₃   → 共享 K₀, V₀
Group 1:  Q₄  Q₅  Q₆  Q₇   → 共享 K₁, V₁
Group 2:  Q₈  Q₉  Q₁₀ Q₁₁  → 共享 K₂, V₂
...
Group 7:  Q₂₈ Q₂₉ Q₃₀ Q₃₁ → 共享 K₇, V₇

group_size = q_heads / kv_heads = 32 / 8 = 4

理解 kv_heads 和 head_dim

类比一下:

  • MHA = 32 个学生,每人配一本独立的参考书(K)和笔记本(V)
  • GQA = 32 个学生分成 8 组,每组共用一本参考书和笔记本
  • MQA = 32 个学生全班共用一本

kv_heads 就是”有多少组独立的 K、V”。head_dim 是每组 KV 的向量维度——每个 head 用多少维来表示一个 token。两者相乘就是一个 token 在一层中 KV Cache 的维度总量。

kv_heads 直接决定 KV Cache 的大小:

每个 token 在一层中的 KV Cache = 2 × kv_heads × head_dim × bytes

MHA: 2 × 32 × 128 × 2 = 16,384 bytes/token/layer
GQA: 2 × 8  × 128 × 2 = 4,096  bytes/token/layer
                                  ↑ 少了 4 倍

GQA 的计算过程

在实际计算中,GQA 需要把 KV 广播(broadcast)到对应的 Q heads:

1. 投影阶段:
   Q = X · W_Q  → shape: [batch, seq, 32, 128]   ← 32 个 Q heads
   K = X · W_K  → shape: [batch, seq, 8, 128]    ← 只有 8 个 KV heads
   V = X · W_V  → shape: [batch, seq, 8, 128]

2. 广播 KV 到每个 Q head:
   K_expanded = K.repeat_interleave(4, dim=2)  → [batch, seq, 32, 128]
   V_expanded = V.repeat_interleave(4, dim=2)  → [batch, seq, 32, 128]
   
   等价于:
   K₀ → Q₀,Q₁,Q₂,Q₃ 共用
   K₁ → Q₄,Q₅,Q₆,Q₇ 共用
   ...

3. 之后和标准 MHA 一样计算 attention:
   Score = Q · K_expanded^T / √d_k
   Output = softmax(Score) · V_expanded

注意:广播只是逻辑上的展开,实际实现中(如 FlashAttention)会直接用索引映射, 不会真的复制 KV 数据,所以不增加显存开销。

参数量对比

GQA 减少的不只是 KV Cache,还有模型参数本身:

W_K 和 W_V 的参数量:

MHA:  W_K = [hidden_dim, q_heads × head_dim]  = [4096, 4096]  → 16M params × 2
GQA:  W_K = [hidden_dim, kv_heads × head_dim] = [4096, 1024]  → 4M params × 2
MQA:  W_K = [hidden_dim, 1 × head_dim]        = [4096, 128]   → 0.5M params × 2

GQA 的 KV 投影参数量 = MHA 的 1/4(当 group_size=4 时)

KV Cache 节省量

以 Llama-3-8B 为例,seq_len=2048, batch=1, FP16:

MHA (kv_heads=32):
  2 × 32 × 32 × 128 × 2048 × 2 = 1,073,741,824 bytes = 1 GB

GQA (kv_heads=8):
  2 × 32 × 8 × 128 × 2048 × 2  = 268,435,456 bytes   = 256 MB

节省: 1 GB → 256 MB,缩小 4 倍

batch=32 时差距更明显:

MHA: 32 GB KV Cache  → A100-80GB 勉强能跑
GQA: 8 GB KV Cache   → 同样的显存能服务 4 倍并发

为什么 GQA 质量接近 MHA?

直觉上,同组内的 Q heads 学到的 attention pattern 往往是相似的。Google 的论文通过实验验证了这一点:

实验结果 (论文 Table 1, T5-XXL 在多任务上的表现):

MHA (kv_heads=64):  基准
GQA (kv_heads=8):   质量损失 < 0.5%,推理速度接近 MQA
MQA (kv_heads=1):   质量损失 ~1-3%,推理速度最快

关键发现:

  • 从 MHA checkpoint 出发,通过 mean pooling 合并相邻 KV heads 的权重,再做少量 fine-tuning(原始训练量的 5%),就能得到高质量的 GQA 模型
  • 这种 “uptraining” 方式比从头训练 GQA 更高效

主流模型的 GQA 配置

模型Q HeadsKV HeadsGroup SizeKV Cache 相对 MHA
Llama-2-70B64881/8
Llama-3-8B32841/4
Llama-3-70B64881/8
Mistral-7B32841/4
Gemma-7B161611/1 (实际是 MHA)
Qwen-2-72B64881/8
DeepSeek-V21281 (MLA)特殊架构

2.6 MLA:比 GQA 更激进的 KV Cache 压缩

GQA 通过减少 KV heads 来压缩 KV Cache,但每个 KV head 仍然存完整的 head_dim 维向量。DeepSeek-V2(2024)提出了 MLA(Multi-head Latent Attention),思路完全不同:不减少 head 数量,而是把所有 heads 的 K、V 联合压缩到一个低维潜向量里

核心思想:低秩 KV 联合压缩

MHA/GQA 的 KV Cache 存的是每个 head 的完整 K、V 向量。MLA 的做法是:先把 hidden state 投影到一个很小的潜空间,推理时只缓存这个小向量,需要时再解压回完整的 K、V。

MHA/GQA 的做法:
  h_t → W_K → K_t (完整的 key)     ← 缓存这个
  h_t → W_V → V_t (完整的 value)   ← 缓存这个

MLA 的做法:
  h_t → W_DKV → c_t^KV (压缩的潜向量)  ← 只缓存这个!
  c_t^KV → W_UK → K_t (需要时解压)      ← 不需要缓存
  c_t^KV → W_UV → V_t (需要时解压)      ← 不需要缓存

关键在于 c_t^KV 的维度 d_c 远小于原始 KV 的维度 n_h × d_h

DeepSeek-V2 的参数:
  n_h = 128 (attention heads)
  d_h = 128 (head dim)
  d_c = 512 (压缩维度) = 4 × d_h

原始 KV 维度: n_h × d_h = 128 × 128 = 16,384
压缩后维度:   d_c = 512

压缩比: 16,384 / 512 = 32 倍

为什么能压缩这么多?

本质上是利用了 KV 的低秩特性。128 个 head 的 K、V 向量之间存在大量冗余——它们都是从同一个 hidden state 线性投影出来的。MLA 用一个低秩瓶颈层(类似 autoencoder)来捕捉这些共享信息:

类比:
  GQA = 把 32 个学生分成 8 组,每组共用一本完整的参考书
  MLA = 把所有参考书的内容压缩成一本摘要,需要时再展开成完整版

  GQA 减少的是"书的数量"
  MLA 减少的是"每本书的厚度"(而且只保留一本摘要)

推理时的矩阵吸收技巧

MLA 还有一个精妙的工程优化:解压矩阵 W_UKW_UV 可以分别被吸收进 W_QW_O 中,推理时根本不需要显式计算 K 和 V:

标准计算:
  score = q_t · (W_UK · c_t^KV)^T = q_t · c_t^KV^T · W_UK^T

吸收后:
  score = (q_t · W_UK^T) · c_t^KV^T
          ↑ 这部分可以预计算合并到 W_Q 里

效果: 推理时直接用 c_t^KV 做 attention,不需要解压

RoPE 的兼容问题

有一个细节:RoPE(旋转位置编码)是作用在 K 上的,但 MLA 的 K 是从潜向量解压出来的,如果对解压后的 K 加 RoPE,就无法做矩阵吸收了(因为 RoPE 破坏了线性关系)。

DeepSeek-V2 的解决方案是解耦 RoPE:额外维护一个小的 k_t^R 向量专门携带位置信息,和压缩部分的 k_t^C 拼接起来:

k_t = [k_t^C ; k_t^R]
       ↑ 从潜向量解压    ↑ 单独计算,带 RoPE
       不需要缓存         需要缓存,但维度很小 (d_h^R = d_h/2 = 64)

所以 MLA 实际缓存的是 c_t^KV(512 维)+ k_t^R(64 维)= 576 维/token/layer。

KV Cache 对比

机制每 token 每层缓存量DeepSeek-V2 实际值相对 MHA
MHA2 × n_h × d_h2 × 128 × 128 = 32,768
GQA (8 groups)2 × n_g × d_h2 × 8 × 128 = 2,0481/16
MQA2 × d_h2 × 128 = 2561/128
MLAd_c + d_h^R512 + 64 = 5761/57

DeepSeek 论文报告 KV Cache 减少 93.3%,同时性能超过 MHA。

MLA 的核心洞察:KV Cache 的瓶颈不在于 head 数量,而在于信息冗余。 与其减少 head(GQA),不如直接压缩信息本身。 代价是推理时多了一步矩阵乘法(解压),但这个计算量相比节省的显存带宽是值得的。

参考:DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model (2024)


3. PagedAttention:用虚拟内存思想管理 KV Cache

3.1 传统 KV Cache 的问题

传统方式为每个请求预分配一块连续显存来存 KV Cache:

请求 A: [████████░░░░░░░░]  ← 预分配 max_seq_len,实际只用了一半
请求 B: [██████░░░░░░░░░░]  ← 更浪费
请求 C: [无法分配]           ← 虽然总空闲够,但找不到连续空间

三大问题:

  1. 内部碎片:预分配 max_seq_len 但实际用不完,浪费 60-80%
  2. 外部碎片:请求结束释放后,空闲块不连续,新请求放不进去
  3. 过度预留:不知道请求会生成多长,只能按最大值预留

3.2 PagedAttention 的核心思想

vLLM 论文(SOSP 2023)借鉴了操作系统的虚拟内存分页机制:

操作系统:  虚拟页 → 页表 → 物理页帧(不需要连续)
vLLM:     逻辑 KV block → block table → 物理 GPU 显存块(不需要连续)

工作方式:

1. 把 KV Cache 切成固定大小的 block(如 16 tokens 一块)
2. 每个请求维护一个 block table(逻辑块 → 物理块的映射)
3. 按需分配:生成新 token 时才分配新 block
4. 物理块可以在显存中任意位置,不需要连续

请求 A 的 block table:
  逻辑块 0 → 物理块 7
  逻辑块 1 → 物理块 2
  逻辑块 2 → 物理块 15

请求 B 的 block table:
  逻辑块 0 → 物理块 3
  逻辑块 1 → 物理块 11

效果:

传统方式:  [AAAA____BBBB____CCCC____]  ← 大量浪费
PagedAttn: [AB CA BC AB CC BA]          ← 几乎零浪费

3.3 PagedAttention 的关键优势

特性传统方式PagedAttention
内存分配预分配连续大块按需分配小块
内部碎片严重(60-80%浪费)极小(最后一个 block 可能有少量)
外部碎片严重无(不需要连续)
内存利用率~20-40%~96-98%
并发请求数多 2-4 倍
内存共享困难天然支持(Copy-on-Write)

3.4 Copy-on-Write 共享

当多个请求共享相同的 prompt(如 system prompt),PagedAttention 可以让它们共享同一份 KV Cache 物理块:

请求 A (system prompt + 用户问题A):
  block table: [共享块0, 共享块1, 共享块2, A专属块3, A专属块4]

请求 B (system prompt + 用户问题B):
  block table: [共享块0, 共享块1, 共享块2, B专属块3, B专属块4]
                ↑ 同一份物理内存,不重复存储

这在 chat 场景下(大量请求共享 system prompt)节省巨大。


4. 长文本推理中的 KV Cache 管理策略

当 seq_len 达到 128K 甚至 1M 时,即使用了 GQA + PagedAttention,KV Cache 仍然是巨大的挑战。

4.1 滑动窗口注意力 (Sliding Window Attention)

Mistral 采用的策略:只保留最近 W 个 token 的 KV Cache,更早的丢弃。

窗口大小 W = 4096

token 位置:  1  2  3  ...  4096  4097  4098  ...
KV Cache:                  [████████████████]
                            ↑ 只保留最近 4096 个
                            token 1-xxx 的 KV 被丢弃

优点:KV Cache 大小固定,不随序列增长 缺点:丢失远距离依赖信息

4.2 KV Cache 量化

对 KV Cache 本身做量化(FP16 → INT8 / INT4):

原始 KV Cache (FP16): 1 GB
INT8 量化后:          0.5 GB  (省 50%)
INT4 量化后:          0.25 GB (省 75%)

vLLM 和 TensorRT-LLM 都支持 KV Cache 量化,精度损失通常可接受。

4.3 KV Cache 压缩与驱逐

更高级的策略:

  • H2O (Heavy-Hitter Oracle):保留 attention score 最高的 token 的 KV,驱逐不重要的
  • StreamingLLM:只保留开头几个 token(attention sink)+ 最近的窗口
  • Scissorhands:基于历史 attention pattern 预测哪些 token 未来不会被关注
StreamingLLM 策略:
[sink tokens] + ... (丢弃) ... + [recent window]
[████]                           [████████████]
 ↑ 前4个token                    ↑ 最近N个token
 (attention sink)

5. 关键要点总结

KV Cache 核心认知
  • KV Cache = 用显存换计算量,是推理加速的基础
  • 内存公式: 2 × L × kv_heads × d × seq × batch × dtype
  • GQA 比 MHA 省数倍 KV Cache(Llama3/Mistral 都用 GQA)
  • PagedAttention 用分页消除碎片,利用率从 ~30% → ~97%
  • 长文本场景需要滑动窗口/量化/驱逐等额外策略
  • KV Cache 是推理服务 OOM 的头号原因

6. 延伸阅读