学习 [torch.compile] Add patch for fullgraph compilation

#42686

torch.compile 的两阶段架构:Dynamo vs Inductor

  1. TorchDynamo 负责追踪模型的tensor生成一个抽象的graph(FX Graph)
  2. TorchInductor 负责生成 triton kernel/c++ 编译并缓存

两种 partition 模式的对比

模式 Dynamo 行为 Inductor 行为 问题
Dynamo Partition 切成多个小子图 每个子图单独编译 子图小,内联影响有限
Inductor Partition (Fullgraph) 尽量不切分,大子图 整个大图一起编译 子图大,内联影响被放大

和piecewise的split的关系

  1. Dynamo Partition 和 Inductor Partition 是”二选一”,不是”先后关系”

use_inductor_graph_partition=False(默认)

# 整个流程:
Python forward 代码
    
Dynamo 追踪生成 FX Graph
    
Dynamo 检查 splitting_ops在这个位置切图   这里 split
    
多个子图Subgraph 1, Subgraph 2, ...
    
每个子图分别交给 Inductor 编译
    
生成多个 cudagraph + attention eager 执行

Dynamo 是”切图的人”,Inductor 只是”编译工人”。

use_inductor_graph_partition=True

# 整个流程:
Python forward 代码
    
Dynamo 追踪生成 FX Graph
    
Dynamo 不切图整个交给 Inductor         Dynamo  split
    
Inductor 接收完整的大图
    
Inductor 做优化fusion
    
Inductor 在代码生成时根据需要 partition  这里 split
    
生成优化后的代码

这个PR引出的问题

当使用 use_inductor_graph_partition=True(Inductor Partition 模式)进行 fullgraph 编译时,性能可能变差,甚至比默认的 Dynamo Partition 还慢。

PyTorch 2.11 的 Inductor 编译器有一个 bug —— 它会过度内联被多次引用的中间张量。

vllm/env_override.py 中打一个 monkey patch,修复 Inductor 的 should_realize_on_reuse 启发式算法:

实现原理

通过monkey patch替代了torch的该方法

should_realize_on_reuse 新计算逻辑详解

核心问题

Inductor 编译时,遇到一个被多次引用的中间 tensor,需要决定:


成本模型公式

total_read_bytes * (users - 1) >= output_bytes * (1 + users)
#         ← 左式 ≥ 右式 →
#     就返回 True(物化)

公式推导

策略一:Inline(内联)

总内存流量 = 每次引用都要重新读取输入数据
           = total_read_bytes × users

解释:
- 有 users 次引用
- 每次都要从内存读取 total_read_bytes 的输入
- 重新计算中间结果

策略二:Materialize(物化)

总内存流量 = 第一次计算的读取 + 写入输出 + 后续引用的读取
           = total_read_bytes + output_bytes + output_bytes × users
           = total_read_bytes + output_bytes × (1 + users)

解释:
- total_read_bytes: 计算一次时读取的输入
- output_bytes: 把结果写入内存
- output_bytes × users: 每次引用从内存读取结果

决策条件

当 Inline 成本 ≥ Materialize 成本时,选择物化:

total_read_bytes × users ≥ total_read_bytes + output_bytes × (1 + users)

两边同时减去 total_read_bytes:

total_read_bytes × (users - 1) ≥ output_bytes × (1 + users)

启发

  1. torch的compile很难绕过,因为确实产生了很多编译级别的优化
  2. 这个成本公式的思路可以学习,在极致的优化性能的时候通过这种方式
  3. 假如完全去掉pytorch,通过手写kernel+cuda graph,除非极致的优化每一个compile提供的优化细节,如何保证不出现性能倒退?尤其是一些尾部模型?