跳转到主要内容

KV Cache:推理性能的命根子

1. 为什么需要 KV Cache?

1.1 回顾 Autoregressive Decoding

你已经知道 decode 阶段每次只生成 1 个 token。关键问题是:生成第 N 个 token 时,attention 需要看到前面所有 N-1 个 token 的 Key 和 Value。

如果不缓存,每生成一个 token 都要重新计算所有历史 token 的 K、V 向量:

生成 token 1: 计算 K₁, V₁
生成 token 2: 重新计算 K₁, V₁, 再计算 K₂, V₂
生成 token 3: 重新计算 K₁, V₁, K₂, V₂, 再计算 K₃, V₃
...
生成 token N: 重新计算 K₁..K_{N-1}, V₁..V_{N-1}  ← O(N²) 次计算!

KV Cache 的核心思想:把已经算过的 K、V 存起来,下次直接用

生成 token 1: 计算 K₁, V₁ → 存入 cache
生成 token 2: 从 cache 读 K₁, V₁, 计算 K₂, V₂ → 存入 cache
生成 token 3: 从 cache 读 K₁, V₁, K₂, V₂, 计算 K₃, V₃ → 存入 cache
...
每步只需计算 1 个新 token 的 K, V  ← O(N) 次计算

代价:用显存换计算量。计算从 O(N²) 降到 O(N),但显存占用线性增长。

1.2 KV Cache 在 Attention 中的位置

回忆 Self-Attention 的计算:

Q = X · W_Q    (Query)
K = X · W_K    (Key)
V = X · W_V    (Value)

Attention(Q, K, V) = softmax(Q · K^T / √d_k) · V

在 decode 阶段:

  • Q:只有当前新 token 的 query(1 个向量)
  • K:cache 中所有历史 token 的 key + 当前新 token 的 key
  • V:cache 中所有历史 token 的 value + 当前新 token 的 value

所以每一步 decode,KV Cache 都在增长:

Step 1: cache = [K₁, V₁]
Step 2: cache = [K₁, K₂], [V₁, V₂]
Step 3: cache = [K₁, K₂, K₃], [V₁, V₂, V₃]
...
Step N: cache = [K₁...K_N], [V₁...V_N]

2. KV Cache 内存计算公式

2.1 核心公式(MHA - Multi-Head Attention)

KV Cache 内存 = 2 × layers × heads × head_dim × seq_len × batch_size × bytes_per_element

各参数含义:

参数含义示例 (Llama-3-8B)
2K 和 V 两个张量固定值
layersTransformer 层数32
heads注意力头数32 (MHA)
head_dim每个头的维度128
seq_len序列长度2048
batch_size批大小1
bytes_per_element数据精度2 (FP16/BF16)

2.2 实际计算示例

Llama-3-8B (MHA), FP16, seq_len=2048, batch=1:

2 × 32 × 32 × 128 × 2048 × 1 × 2 bytes
= 2 × 32 × 32 × 128 × 2048 × 2
= 1,073,741,824 bytes
= 1 GB

一条请求的 KV Cache 就要 1GB 显存!

batch=32 时:

1 GB × 32 = 32 GB  ← 光 KV Cache 就 32GB

加上模型权重(8B × 2 bytes = 16GB),总共需要 48GB,A100-40GB 直接 OOM。

2.3 不同模型的 KV Cache 对比

模型LayersHeadshead_dimKV Cache/token/batch (FP16)
Llama-3-8B3232 (MHA)1280.5 MB/token
Llama-3-70B8064 (MHA)1282.5 MB/token
Mistral-7B328 (GQA)1280.125 MB/token

注意 Mistral 用了 GQA(Grouped-Query Attention),KV heads 只有 8 个而不是 32 个, KV Cache 直接缩小 4 倍!

2.4 MHA vs MQA vs GQA 对 KV Cache 的影响

这是现代模型架构减少 KV Cache 的核心手段:

MHA (Multi-Head Attention):
  每个 attention head 都有独立的 K, V
  KV heads = Q heads = 32
  KV Cache 最大

MQA (Multi-Query Attention):
  所有 Q heads 共享 1 组 K, V
  KV heads = 1
  KV Cache 最小,但质量可能下降

GQA (Grouped-Query Attention):
  Q heads 分组,每组共享 1 组 K, V
  KV heads = Q heads / group_size (如 32/4 = 8)
  KV Cache 适中,质量接近 MHA
MHA:  Q₁K₁V₁  Q₂K₂V₂  Q₃K₃V₃  Q₄K₄V₄   ← 4组 KV
GQA:  Q₁Q₂→K₁V₁  Q₃Q₄→K₂V₂              ← 2组 KV (省一半)
MQA:  Q₁Q₂Q₃Q₄→K₁V₁                      ← 1组 KV (省最多)

GQA 公式修正:

KV Cache = 2 × layers × kv_heads × head_dim × seq_len × batch × bytes

注意这里用 kv_heads 而不是 q_heads


3. PagedAttention:用虚拟内存思想管理 KV Cache

3.1 传统 KV Cache 的问题

传统方式为每个请求预分配一块连续显存来存 KV Cache:

请求 A: [████████░░░░░░░░]  ← 预分配 max_seq_len,实际只用了一半
请求 B: [██████░░░░░░░░░░]  ← 更浪费
请求 C: [无法分配]           ← 虽然总空闲够,但找不到连续空间

三大问题:

  1. 内部碎片:预分配 max_seq_len 但实际用不完,浪费 60-80%
  2. 外部碎片:请求结束释放后,空闲块不连续,新请求放不进去
  3. 过度预留:不知道请求会生成多长,只能按最大值预留

3.2 PagedAttention 的核心思想

vLLM 论文(SOSP 2023)借鉴了操作系统的虚拟内存分页机制:

操作系统:  虚拟页 → 页表 → 物理页帧(不需要连续)
vLLM:     逻辑 KV block → block table → 物理 GPU 显存块(不需要连续)

工作方式:

1. 把 KV Cache 切成固定大小的 block(如 16 tokens 一块)
2. 每个请求维护一个 block table(逻辑块 → 物理块的映射)
3. 按需分配:生成新 token 时才分配新 block
4. 物理块可以在显存中任意位置,不需要连续

请求 A 的 block table:
  逻辑块 0 → 物理块 7
  逻辑块 1 → 物理块 2
  逻辑块 2 → 物理块 15

请求 B 的 block table:
  逻辑块 0 → 物理块 3
  逻辑块 1 → 物理块 11

效果:

传统方式:  [AAAA____BBBB____CCCC____]  ← 大量浪费
PagedAttn: [AB CA BC AB CC BA]          ← 几乎零浪费

3.3 PagedAttention 的关键优势

特性传统方式PagedAttention
内存分配预分配连续大块按需分配小块
内部碎片严重(60-80%浪费)极小(最后一个 block 可能有少量)
外部碎片严重无(不需要连续)
内存利用率~20-40%~96-98%
并发请求数多 2-4 倍
内存共享困难天然支持(Copy-on-Write)

3.4 Copy-on-Write 共享

当多个请求共享相同的 prompt(如 system prompt),PagedAttention 可以让它们共享同一份 KV Cache 物理块:

请求 A (system prompt + 用户问题A):
  block table: [共享块0, 共享块1, 共享块2, A专属块3, A专属块4]

请求 B (system prompt + 用户问题B):
  block table: [共享块0, 共享块1, 共享块2, B专属块3, B专属块4]
                ↑ 同一份物理内存,不重复存储

这在 chat 场景下(大量请求共享 system prompt)节省巨大。


4. 长文本推理中的 KV Cache 管理策略

当 seq_len 达到 128K 甚至 1M 时,即使用了 GQA + PagedAttention,KV Cache 仍然是巨大的挑战。

4.1 滑动窗口注意力 (Sliding Window Attention)

Mistral 采用的策略:只保留最近 W 个 token 的 KV Cache,更早的丢弃。

窗口大小 W = 4096

token 位置:  1  2  3  ...  4096  4097  4098  ...
KV Cache:                  [████████████████]
                            ↑ 只保留最近 4096 个
                            token 1-xxx 的 KV 被丢弃

优点:KV Cache 大小固定,不随序列增长 缺点:丢失远距离依赖信息

4.2 KV Cache 量化

对 KV Cache 本身做量化(FP16 → INT8 / INT4):

原始 KV Cache (FP16): 1 GB
INT8 量化后:          0.5 GB  (省 50%)
INT4 量化后:          0.25 GB (省 75%)

vLLM 和 TensorRT-LLM 都支持 KV Cache 量化,精度损失通常可接受。

4.3 KV Cache 压缩与驱逐

更高级的策略:

  • H2O (Heavy-Hitter Oracle):保留 attention score 最高的 token 的 KV,驱逐不重要的
  • StreamingLLM:只保留开头几个 token(attention sink)+ 最近的窗口
  • Scissorhands:基于历史 attention pattern 预测哪些 token 未来不会被关注
StreamingLLM 策略:
[sink tokens] + ... (丢弃) ... + [recent window]
[████]                           [████████████]
 ↑ 前4个token                    ↑ 最近N个token
 (attention sink)

5. 关键要点总结

┌─────────────────────────────────────────────────────────┐
│  KV Cache 核心认知                                        │
├─────────────────────────────────────────────────────────┤
│  1. KV Cache = 用显存换计算量,是推理加速的基础              │
│  2. 内存公式: 2 × L × kv_heads × d × seq × batch × dtype  │
│  3. GQA 比 MHA 省数倍 KV Cache(Llama3/Mistral 都用 GQA)  │
│  4. PagedAttention 用分页消除碎片,利用率从 ~30% → ~97%     │
│  5. 长文本场景需要滑动窗口/量化/驱逐等额外策略              │
│  6. KV Cache 是推理服务 OOM 的头号原因                      │
└─────────────────────────────────────────────────────────┘

6. 延伸阅读

修改历史1 次提交