学习 #42885 [Perf][MLA] Enable FULL cudagraph capture for TRITON_MLA decode
cudagraph_support 功能和作用详解
| 级别 | 值 | 含义 | 适用场景 |
|---|---|---|---|
| ALWAYS | 3 | 始终支持 cudagraph,包括混合 prefill/decode 的 batch | 最灵活,如 FlashInfer v3 |
| UNIFORM_BATCH | 2 | 支持 cudagraph,但要求 batch 内所有 query 长度相同 | 适合 speculative decode(每个 request 生成相同数量的 token) |
| UNIFORM_SINGLE_TOKEN_DECODE | 1 | 仅支持 query_len==1 的纯 decode 场景 | 最严格的限制 |
| NEVER | 0 | 不支持 cudagraph | 回退到 PIECEWISE 模式 |
每个 Attention Backend 的 MetadataBuilder 通过 _cudagraph_support 类变量声明自己的能力:
# Triton MLA (PR修改前) - 不支持
# 继承自 MLACommonMetadataBuilder,默认 _cudagraph_support = NEVER
# Triton MLA (PR修改后) - 支持 UNIFORM_BATCH
class TritonMLAMetadataBuilder(MLACommonMetadataBuilder):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
在 vllm/v1/worker/gpu/attn_utils.py 中,系统会遍历所有 attention layers,找出最小的 cudagraph 支持级别:
# Find minimum cudagraph support across all attention backends
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_attn_backend = None
for layer in attention_layers:
builder_cls = layer.backend.get_builder_cls()
cg_support = builder_cls.get_cudagraph_support()
if cg_support.value < min_cg_support.value:
min_cg_support = cg_support
min_cg_attn_backend = layer.backend.get_name()
然后在 vllm/config/compilation.py 中根据这个值决定是否启用 FULL cudagraph 模式:
if cudagraph_mode.has_full_cudagraphs() and min_cg_support == AttentionCGSupport.NEVER:
# 如果不支持,降级到 PIECEWISE 模式
raise ValueError("...")
MLACommonMetadataBuilder默认_cudagraph_support = NEVER- 导致 Triton MLA backend 无法使用 FULL cudagraph 模式
- 每次 decode 都要经过 Python dispatch,61层 × 50 tokens = 3050 次分发,耗费 2.3s CPU 时间
修复:
- 新增
TritonMLAMetadataBuilder子类,声明_cudagraph_support = UNIFORM_BATCH - 让系统知道 Triton MLA 可以支持 cudagraph 捕获
- decode 阶段可以被完整捕获到 CUDA Graph 中,避免 per-layer dispatch
为什么 UNIFORM_BATCH 是安全的?
- Batch 内所有 request 的计算模式一致(query_len 相同)
- 内存分配按 worst-case(按 max_seq_len 而非实际 seq_len)
- kernel 内部没有数据依赖的动态分支会导致 replay 失败,不支持prefill
TritonMLABackend 算法核心分析
- 获取 batch 信息
B = q.shape[0] q_num_heads = q.shape[1] - 预分配输出 tensor
o = torch.zeros(B, q_num_heads, self.kv_lora_rank, ...) - 计算
num_kv_splits - 分配
attn_logitsbuffer - 调用 Triton kernel
decode_attention_fwd( q, # (16, 128, kv_lora_rank) query kv_c_and_k_pe_cache, # KV cache(paged) kv_c_cache, # KV compressed cache o, # (16, 128, 64) 输出 lse, # (16, 128) LSE attn_metadata.decode.block_table, # 页表 attn_metadata.decode.seq_lens, # [100, 500, 1200, ..., 3000] attn_logits, # (16, 128, 64, 65) 中间 buffer num_kv_splits, # 64 scale, # attention scale PAGE_SIZE, # e.g. 16 ... )
num_kv_splits 到底是什么含义?和 SM count 是什么关系?
-
max_splits = SM_count × 2(或 ×4)是物理上限 -
为什么?因为每个 split 对应一个 CUDA thread block
-
同时运行的 block 数不能超过 SM 能承载的数量
-
如果
num_kv_splits > SM_count,多余的 block 只能排队等 -
所以
num_kv_splits设得再大也没意义,反而浪费内存
为什么其他mlabackend支持的级别就是never
错误!都是batch,只有这个是搞错了
实际上固定形状的attention都容易支持capture,对应就是decode。prefill还是不能支持。
所以PD分离可以产生这方面的价值
Piecewise带来的开销具体分析
PIECEWISE 的本质问题:
- Attention 是模型的热点路径(占 50-70% 的执行时间)
- PIECEWISE 让这部分每次都在 eager mode 执行
- 每次 decode step 都要重新走一遍完整的 CPU → GPU 调用链
Full Capture 的本质优势:
- Attention 和其他层一样,被录制成固定的 kernel 序列
- decode step 只需要 1 次 launch,GPU 自动执行所有操作
- CPU 只需要”点火”,不需要”指挥每一步”
如果不是DP分离,系统是怎么处理的呢
检查 batch 类型:
- 如果是 pure decode: 使用 FULL cudagraph capture。一次 replay 执行所有层
- 如果是 mixed batch: fallback 到 PIECEWISE 模式 ,prefill 部分 eager 执行
启发
- 如果取消full graph的限制会不会产生cpu的开销?