KV Cache:推理性能的命根子
1. 为什么需要 KV Cache?
1.1 回顾 Autoregressive Decoding
你已经知道 decode 阶段每次只生成 1 个 token。关键问题是:生成第 N 个 token 时,attention 需要看到前面所有 N-1 个 token 的 Key 和 Value。
如果不缓存,每生成一个 token 都要重新计算所有历史 token 的 K、V 向量:
生成 token 1: 计算 K₁, V₁
生成 token 2: 重新计算 K₁, V₁, 再计算 K₂, V₂
生成 token 3: 重新计算 K₁, V₁, K₂, V₂, 再计算 K₃, V₃
...
生成 token N: 重新计算 K₁..K_{N-1}, V₁..V_{N-1} ← O(N²) 次计算!
KV Cache 的核心思想:把已经算过的 K、V 存起来,下次直接用。
生成 token 1: 计算 K₁, V₁ → 存入 cache
生成 token 2: 从 cache 读 K₁, V₁, 计算 K₂, V₂ → 存入 cache
生成 token 3: 从 cache 读 K₁, V₁, K₂, V₂, 计算 K₃, V₃ → 存入 cache
...
每步只需计算 1 个新 token 的 K, V ← O(N) 次计算
代价:用显存换计算量。计算从 O(N²) 降到 O(N),但显存占用线性增长。
1.2 KV Cache 在 Attention 中的位置
回忆 Self-Attention 的计算:
Q = X · W_Q (Query)
K = X · W_K (Key)
V = X · W_V (Value)
Attention(Q, K, V) = softmax(Q · K^T / √d_k) · V
在 decode 阶段:
- Q:只有当前新 token 的 query(1 个向量)
- K:cache 中所有历史 token 的 key + 当前新 token 的 key
- V:cache 中所有历史 token 的 value + 当前新 token 的 value
所以每一步 decode,KV Cache 都在增长:
Step 1: cache = [K₁, V₁]
Step 2: cache = [K₁, K₂], [V₁, V₂]
Step 3: cache = [K₁, K₂, K₃], [V₁, V₂, V₃]
...
Step N: cache = [K₁...K_N], [V₁...V_N]
2. KV Cache 内存计算公式
2.1 核心公式(MHA - Multi-Head Attention)
KV Cache 内存 = 2 × layers × heads × head_dim × seq_len × batch_size × bytes_per_element
各参数含义:
| 参数 | 含义 | 示例 (Llama-3-8B) |
|---|---|---|
2 | K 和 V 两个张量 | 固定值 |
layers | Transformer 层数 | 32 |
heads | 注意力头数 | 32 (MHA) |
head_dim | 每个头的维度 | 128 |
seq_len | 序列长度 | 2048 |
batch_size | 批大小 | 1 |
bytes_per_element | 数据精度 | 2 (FP16/BF16) |
2.2 实际计算示例
Llama-3-8B (MHA), FP16, seq_len=2048, batch=1:
2 × 32 × 32 × 128 × 2048 × 1 × 2 bytes
= 2 × 32 × 32 × 128 × 2048 × 2
= 1,073,741,824 bytes
= 1 GB
一条请求的 KV Cache 就要 1GB 显存!
batch=32 时:
1 GB × 32 = 32 GB ← 光 KV Cache 就 32GB
加上模型权重(8B × 2 bytes = 16GB),总共需要 48GB,A100-40GB 直接 OOM。
2.3 不同模型的 KV Cache 对比
| 模型 | Layers | Heads | head_dim | KV Cache/token/batch (FP16) |
|---|---|---|---|---|
| Llama-3-8B | 32 | 32 (MHA) | 128 | 0.5 MB/token |
| Llama-3-70B | 80 | 64 (MHA) | 128 | 2.5 MB/token |
| Mistral-7B | 32 | 8 (GQA) | 128 | 0.125 MB/token |
注意 Mistral 用了 GQA(Grouped-Query Attention),KV heads 只有 8 个而不是 32 个, KV Cache 直接缩小 4 倍!
2.4 MHA vs MQA vs GQA 对 KV Cache 的影响
这是现代模型架构减少 KV Cache 的核心手段:
MHA (Multi-Head Attention):
每个 attention head 都有独立的 K, V
KV heads = Q heads = 32
KV Cache 最大
MQA (Multi-Query Attention):
所有 Q heads 共享 1 组 K, V
KV heads = 1
KV Cache 最小,但质量可能下降
GQA (Grouped-Query Attention):
Q heads 分组,每组共享 1 组 K, V
KV heads = Q heads / group_size (如 32/4 = 8)
KV Cache 适中,质量接近 MHA
MHA: Q₁K₁V₁ Q₂K₂V₂ Q₃K₃V₃ Q₄K₄V₄ ← 4组 KV
GQA: Q₁Q₂→K₁V₁ Q₃Q₄→K₂V₂ ← 2组 KV (省一半)
MQA: Q₁Q₂Q₃Q₄→K₁V₁ ← 1组 KV (省最多)
GQA 公式修正:
KV Cache = 2 × layers × kv_heads × head_dim × seq_len × batch × bytes
注意这里用 kv_heads 而不是 q_heads。
3. PagedAttention:用虚拟内存思想管理 KV Cache
3.1 传统 KV Cache 的问题
传统方式为每个请求预分配一块连续显存来存 KV Cache:
请求 A: [████████░░░░░░░░] ← 预分配 max_seq_len,实际只用了一半
请求 B: [██████░░░░░░░░░░] ← 更浪费
请求 C: [无法分配] ← 虽然总空闲够,但找不到连续空间
三大问题:
- 内部碎片:预分配 max_seq_len 但实际用不完,浪费 60-80%
- 外部碎片:请求结束释放后,空闲块不连续,新请求放不进去
- 过度预留:不知道请求会生成多长,只能按最大值预留
3.2 PagedAttention 的核心思想
vLLM 论文(SOSP 2023)借鉴了操作系统的虚拟内存分页机制:
操作系统: 虚拟页 → 页表 → 物理页帧(不需要连续)
vLLM: 逻辑 KV block → block table → 物理 GPU 显存块(不需要连续)
工作方式:
1. 把 KV Cache 切成固定大小的 block(如 16 tokens 一块)
2. 每个请求维护一个 block table(逻辑块 → 物理块的映射)
3. 按需分配:生成新 token 时才分配新 block
4. 物理块可以在显存中任意位置,不需要连续
请求 A 的 block table:
逻辑块 0 → 物理块 7
逻辑块 1 → 物理块 2
逻辑块 2 → 物理块 15
请求 B 的 block table:
逻辑块 0 → 物理块 3
逻辑块 1 → 物理块 11
效果:
传统方式: [AAAA____BBBB____CCCC____] ← 大量浪费
PagedAttn: [AB CA BC AB CC BA] ← 几乎零浪费
3.3 PagedAttention 的关键优势
| 特性 | 传统方式 | PagedAttention |
|---|---|---|
| 内存分配 | 预分配连续大块 | 按需分配小块 |
| 内部碎片 | 严重(60-80%浪费) | 极小(最后一个 block 可能有少量) |
| 外部碎片 | 严重 | 无(不需要连续) |
| 内存利用率 | ~20-40% | ~96-98% |
| 并发请求数 | 少 | 多 2-4 倍 |
| 内存共享 | 困难 | 天然支持(Copy-on-Write) |
3.4 Copy-on-Write 共享
当多个请求共享相同的 prompt(如 system prompt),PagedAttention 可以让它们共享同一份 KV Cache 物理块:
请求 A (system prompt + 用户问题A):
block table: [共享块0, 共享块1, 共享块2, A专属块3, A专属块4]
请求 B (system prompt + 用户问题B):
block table: [共享块0, 共享块1, 共享块2, B专属块3, B专属块4]
↑ 同一份物理内存,不重复存储
这在 chat 场景下(大量请求共享 system prompt)节省巨大。
4. 长文本推理中的 KV Cache 管理策略
当 seq_len 达到 128K 甚至 1M 时,即使用了 GQA + PagedAttention,KV Cache 仍然是巨大的挑战。
4.1 滑动窗口注意力 (Sliding Window Attention)
Mistral 采用的策略:只保留最近 W 个 token 的 KV Cache,更早的丢弃。
窗口大小 W = 4096
token 位置: 1 2 3 ... 4096 4097 4098 ...
KV Cache: [████████████████]
↑ 只保留最近 4096 个
token 1-xxx 的 KV 被丢弃
优点:KV Cache 大小固定,不随序列增长 缺点:丢失远距离依赖信息
4.2 KV Cache 量化
对 KV Cache 本身做量化(FP16 → INT8 / INT4):
原始 KV Cache (FP16): 1 GB
INT8 量化后: 0.5 GB (省 50%)
INT4 量化后: 0.25 GB (省 75%)
vLLM 和 TensorRT-LLM 都支持 KV Cache 量化,精度损失通常可接受。
4.3 KV Cache 压缩与驱逐
更高级的策略:
- H2O (Heavy-Hitter Oracle):保留 attention score 最高的 token 的 KV,驱逐不重要的
- StreamingLLM:只保留开头几个 token(attention sink)+ 最近的窗口
- Scissorhands:基于历史 attention pattern 预测哪些 token 未来不会被关注
StreamingLLM 策略:
[sink tokens] + ... (丢弃) ... + [recent window]
[████] [████████████]
↑ 前4个token ↑ 最近N个token
(attention sink)
5. 关键要点总结
┌─────────────────────────────────────────────────────────┐
│ KV Cache 核心认知 │
├─────────────────────────────────────────────────────────┤
│ 1. KV Cache = 用显存换计算量,是推理加速的基础 │
│ 2. 内存公式: 2 × L × kv_heads × d × seq × batch × dtype │
│ 3. GQA 比 MHA 省数倍 KV Cache(Llama3/Mistral 都用 GQA) │
│ 4. PagedAttention 用分页消除碎片,利用率从 ~30% → ~97% │
│ 5. 长文本场景需要滑动窗口/量化/驱逐等额外策略 │
│ 6. KV Cache 是推理服务 OOM 的头号原因 │
└─────────────────────────────────────────────────────────┘
6. 延伸阅读
- vLLM 论文: Efficient Memory Management for LLMs with PagedAttention (SOSP 2023)
- KV Cache Memory Calculation for LLMs — 详细的内存计算指南
- PagedAttention Memory-Level Analysis — 深入分析 PagedAttention 的内存行为
- How PagedAttention Resolves Memory Waste — Red Hat 的工程视角解读
修改历史1 次提交
- docs(ai-systems): add comprehensive LLM inference documentationxiaocheng··
7c98505