学习 #42885 [Perf][MLA] Enable FULL cudagraph capture for TRITON_MLA decode

cudagraph_support 功能和作用详解

级别 含义 适用场景
ALWAYS 3 始终支持 cudagraph,包括混合 prefill/decode 的 batch 最灵活,如 FlashInfer v3
UNIFORM_BATCH 2 支持 cudagraph,但要求 batch 内所有 query 长度相同 适合 speculative decode(每个 request 生成相同数量的 token)
UNIFORM_SINGLE_TOKEN_DECODE 1 仅支持 query_len==1 的纯 decode 场景 最严格的限制
NEVER 0 不支持 cudagraph 回退到 PIECEWISE 模式

每个 Attention Backend 的 MetadataBuilder 通过 _cudagraph_support 类变量声明自己的能力:

# Triton MLA (PR修改前) - 不支持
# 继承自 MLACommonMetadataBuilder,默认 _cudagraph_support = NEVER

# Triton MLA (PR修改后) - 支持 UNIFORM_BATCH
class TritonMLAMetadataBuilder(MLACommonMetadataBuilder):
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

vllm/v1/worker/gpu/attn_utils.py 中,系统会遍历所有 attention layers,找出最小的 cudagraph 支持级别:

# Find minimum cudagraph support across all attention backends
min_cg_support = AttentionCGSupport.ALWAYS
min_cg_attn_backend = None

for layer in attention_layers:
    builder_cls = layer.backend.get_builder_cls()
    cg_support = builder_cls.get_cudagraph_support()
    if cg_support.value < min_cg_support.value:
        min_cg_support = cg_support
        min_cg_attn_backend = layer.backend.get_name()

然后在 vllm/config/compilation.py 中根据这个值决定是否启用 FULL cudagraph 模式:

if cudagraph_mode.has_full_cudagraphs() and min_cg_support == AttentionCGSupport.NEVER:
    # 如果不支持,降级到 PIECEWISE 模式
    raise ValueError("...")

修复:

为什么 UNIFORM_BATCH 是安全的?

  1. Batch 内所有 request 的计算模式一致(query_len 相同)
  2. 内存分配按 worst-case(按 max_seq_len 而非实际 seq_len)
  3. kernel 内部没有数据依赖的动态分支会导致 replay 失败,不支持prefill

TritonMLABackend 算法核心分析

  1. 获取 batch 信息
    B = q.shape[0] 
    q_num_heads = q.shape[1]
    
  2. 预分配输出 tensor
    o = torch.zeros(B, q_num_heads, self.kv_lora_rank, ...)
    
  3. 计算 num_kv_splits
  4. 分配 attn_logits buffer
  5. 调用 Triton kernel
    decode_attention_fwd(
     q,                        # (16, 128, kv_lora_rank) query
     kv_c_and_k_pe_cache,      # KV cache(paged)
     kv_c_cache,               # KV compressed cache
     o,                        # (16, 128, 64) 输出
     lse,                      # (16, 128) LSE
     attn_metadata.decode.block_table,  # 页表
     attn_metadata.decode.seq_lens,     # [100, 500, 1200, ..., 3000]
     attn_logits,              # (16, 128, 64, 65) 中间 buffer
     num_kv_splits,            # 64
     scale,                    # attention scale
     PAGE_SIZE,                # e.g. 16
     ...
    )
    

num_kv_splits 到底是什么含义?和 SM count 是什么关系?

为什么其他mlabackend支持的级别就是never

错误!都是batch,只有这个是搞错了

实际上固定形状的attention都容易支持capture,对应就是decode。prefill还是不能支持。

所以PD分离可以产生这方面的价值

Piecewise带来的开销具体分析

PIECEWISE 的本质问题:

Full Capture 的本质优势:

如果不是DP分离,系统是怎么处理的呢

检查 batch 类型:

  1. 如果是 pure decode: 使用 FULL cudagraph capture。一次 replay 执行所有层
  2. 如果是 mixed batch: fallback 到 PIECEWISE 模式 ,prefill 部分 eager 执行

启发

  1. 如果取消full graph的限制会不会产生cpu的开销?