学习 [torch.compile] Add patch for fullgraph compilation
#42686
torch.compile 的两阶段架构:Dynamo vs Inductor
- TorchDynamo 负责追踪模型的tensor生成一个抽象的graph(FX Graph)
- TorchInductor 负责生成 triton kernel/c++ 编译并缓存
两种 partition 模式的对比
| 模式 | Dynamo 行为 | Inductor 行为 | 问题 |
|---|---|---|---|
| Dynamo Partition | 切成多个小子图 | 每个子图单独编译 | 子图小,内联影响有限 |
| Inductor Partition (Fullgraph) | 尽量不切分,大子图 | 整个大图一起编译 | 子图大,内联影响被放大 |
和piecewise的split的关系
- Dynamo Partition 和 Inductor Partition 是”二选一”,不是”先后关系”
use_inductor_graph_partition=False(默认)
# 整个流程:
Python forward 代码
↓
Dynamo 追踪,生成 FX Graph
↓
Dynamo 检查 splitting_ops,在这个位置切图: ← 这里 split!
↓
多个子图(Subgraph 1, Subgraph 2, ...)
↓
每个子图分别交给 Inductor 编译
↓
生成多个 cudagraph + attention eager 执行
Dynamo 是”切图的人”,Inductor 只是”编译工人”。
use_inductor_graph_partition=True
# 整个流程:
Python forward 代码
↓
Dynamo 追踪,生成 FX Graph
↓
Dynamo 不切图!整个交给 Inductor ← Dynamo 不 split!
↓
Inductor 接收完整的大图
↓
Inductor 做优化、fusion
↓
Inductor 在代码生成时,根据需要 partition ← 这里 split!
↓
生成优化后的代码
这个PR引出的问题
当使用 use_inductor_graph_partition=True(Inductor Partition 模式)进行 fullgraph 编译时,性能可能变差,甚至比默认的 Dynamo Partition 还慢。
PyTorch 2.11 的 Inductor 编译器有一个 bug —— 它会过度内联被多次引用的中间张量。
在 vllm/env_override.py 中打一个 monkey patch,修复 Inductor 的 should_realize_on_reuse 启发式算法:
实现原理
通过monkey patch替代了torch的该方法
should_realize_on_reuse 新计算逻辑详解
核心问题
Inductor 编译时,遇到一个被多次引用的中间 tensor,需要决定:
- Inline(内联): 每次使用时重新计算
- Materialize(物化): 计算一次并存储到内存,后续直接读取
成本模型公式
total_read_bytes * (users - 1) >= output_bytes * (1 + users)
# ← 左式 ≥ 右式 →
# 就返回 True(物化)
公式推导
策略一:Inline(内联)
总内存流量 = 每次引用都要重新读取输入数据
= total_read_bytes × users
解释:
- 有 users 次引用
- 每次都要从内存读取 total_read_bytes 的输入
- 重新计算中间结果
策略二:Materialize(物化)
总内存流量 = 第一次计算的读取 + 写入输出 + 后续引用的读取
= total_read_bytes + output_bytes + output_bytes × users
= total_read_bytes + output_bytes × (1 + users)
解释:
- total_read_bytes: 计算一次时读取的输入
- output_bytes: 把结果写入内存
- output_bytes × users: 每次引用从内存读取结果
决策条件
当 Inline 成本 ≥ Materialize 成本时,选择物化:
total_read_bytes × users ≥ total_read_bytes + output_bytes × (1 + users)
两边同时减去 total_read_bytes:
total_read_bytes × (users - 1) ≥ output_bytes × (1 + users)
启发
- torch的compile很难绕过,因为确实产生了很多编译级别的优化
- 这个成本公式的思路可以学习,在极致的优化性能的时候通过这种方式
- 假如完全去掉pytorch,通过手写kernel+cuda graph,除非极致的优化每一个compile提供的优化细节,如何保证不出现性能倒退?尤其是一些尾部模型?