GPU编程层次
Python DSL层 (Triton / CuTile / CuTe DSL)
直接把高层代码编译到 PTX,不经过 CUDA C++。2019 年以来最重要的范式变化。
| 框架 | 来源 | 核心理念 |
|---|---|---|
| Triton | OpenAI, 2019 | block/tile 级编程,线程映射、shared memory、warp 分配交给编译器 |
| CuTe DSL | NVIDIA, 2025 | 给 CuTe layout 代数套 Python 前端,开发体验接近 CUTLASS C++ 性能 |
| CuTile | NVIDIA, CUDA 13.1 | 官方 tile 编程模型,NVIDIA「自己下场做 Triton」的回应 |
均有自己的 MLIR 编译器,直接 lower 到 PTX。
C++ 模板层 (CUTLASS / CuTe / ThunderKittens)
动机:手写 CUDA C++ 写矩阵乘 / attention 太累、太容易出错,用 C++ 模板把高性能模式封装成可复用抽象。
| 框架 | 说明 |
|---|---|
| CUTLASS | CUDA Templates for Linear Algebra Subroutines,开源可定制的 cuBLAS |
| CuTe | CUTLASS 3.0 起,描述「数据怎么摆 + 线程怎么映射」的 layout 代数,是 CUTLASS 的地基 |
| ThunderKittens | 斯坦福 Hazy Research, 2024,用精简抽象在 H100 上打平/超越 CUTLASS 版 FlashAttention-3 |
CUDA C++ (.cu)
地基。2012 年以来唯一的选择,SIMT 模型 — grid → block(CTA) → warp → thread,自己管理寄存器、shared memory、tiling、同步。所有上层框架最终 fallback 到它。
编译链:.cu → nvcc → PTX → ptxas → SASS
日常说的「CUDA」指这门语言,也泛指 NVIDIA 整个 GPU 计算平台(语言 + 编译器 + 驱动 + 库)。大多数开发者通过 cuBLAS/cuDNN 间接使用,而非自己写。
PTX
Parallel Thread eXecution — NVIDIA 的虚拟 ISA(中间表示)。所有上层入口最后都要变成 PTX,再由驱动里的 ptxas 编成特定架构的 SASS。
实际工程里几乎没人从头用 PTX 写 kernel。真实用途:在 .cu 里通过 asm volatile 内联指令,碰 C++ 层暴露不出来的硬件特性(异步拷贝、cache hint、新硬件刚出编译器还没支持的指令)。PTX 是补丁工具,不是开发语言。
SASS
Streaming ASSembly — 特定架构的真实机器码。通常只在做极致性能分析或逆向时才看。
NVIDIA 闭源库
| 库 | 用途 |
|---|---|
| cuBLAS | 通用线性代数,GEMM / BLAS |
| cuDNN | 深度学习算子:卷积、池化、归一化、attention |
PyTorch 矩阵乘默认走 cuBLAS,卷积走 cuDNN。黑盒,不能改、看不见内部,但参数对就能拿到 NVIDIA 多年调优的性能。
CUTLASS 是 NVIDIA 给这条闭源生态做的「开源积木版本」,名字致敬 cuBLAS。
代码对比
cuBLAS — 黑盒调用
cublasHandle_t handle;
cublasCreate(&handle);
float alpha = 1.0f, beta = 0.0f;
cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N, M, N, K,
&alpha, dA, M, dB, K, &beta, dC, M);
cublasDestroy(handle);
只关心:尺寸、转置、指针。内部怎么 tile、用不用 Tensor Core、线程怎么分 — NVIDIA 全替你定了。
纯 CUDA C++ — 朴素手写
__global__ void gemm_naive(float* A, float* B, float* C, int M, int N, int K) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M && col < N) {
float sum = 0.0f;
for (int k = 0; k < K; ++k)
sum += A[row * K + k] * B[k * N + col];
C[row * N + col] = sum;
}
}
能跑,但慢 — 没用 shared memory、Tensor Core、内存访问未优化。要写快得手动加 tiling、SMEM、bank conflict 处理,换代硬件要重写。
CuTe — layout 代数描述数据编排
using namespace cute;
// 裸内存包装成带 layout 的 Tensor
Tensor mA = make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(_1{}, M));
Tensor mB = make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(_1{}, N));
Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(_1{}, M));
// 声明切块:每个 block 处理 128×128×8
auto block_tile = make_shape(Int<128>{}, Int<128>{}, Int<8>{});
// 用 layout 把全局矩阵切给当前 block
Tensor gA = local_tile(mA, block_tile, ...);
Tensor gB = local_tile(mB, block_tile, ...);
// 声明 tiled MMA:用哪条 Tensor Core 指令
TiledMMA mma = make_tiled_mma(SM80_16x8x8_F32F16F16F32_TN{}, ...);
cute::gemm(mma, gA, gB, gC);
没有手写 A[row*K+k],而是用 make_shape / make_stride / local_tile 声明「数据长什么样、怎么切」。layout 代数管数据编排,不是调现成函数。
CUTLASS — 模板组装生产级 kernel
using Gemm = device::GemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, // A
cutlass::half_t, cutlass::layout::ColumnMajor, // B
float, cutlass::layout::RowMajor, // C
float, // 累加
cutlass::arch::OpClassTensorOp, // Tensor Core
cutlass::arch::Sm90, // Hopper
Shape<_128,_128,_64>, // block tile
Shape<_64, _64, _64> // warp tile
>;
Gemm gemm_op;
gemm_op({M, N, K}, {dA, lda}, {dB, ldb}, {dC, ldc}, {alpha, beta});
你定类型、布局、架构、tile 大小、warp 分配。CUTLASS 编译成高度优化的 kernel。
Triton — Python tile 级
@triton.jit
def gemm_kernel(A, B, C, M, N, K,
stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(A + offs_m[:, None] * stride_am + (k + offs_k[None, :]) * stride_ak)
b = tl.load(B + (k + offs_k[:, None]) * stride_bk + offs_n[None, :] * stride_bn)
acc += tl.dot(a, b)
tl.store(C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, acc)
关键区别:没有 threadIdx,没有 __shared__,没有同步原语。以 tile 为单位思考 — Triton 编译器自动决定线程映射、shared memory staging、Tensor Core 指令、内存合并。
FlashAttention 演进
| 版本 | 时间 | 硬件 | 实现方式 | 关键变化 |
|---|---|---|---|---|
| FA1 | 2022 | A100 | 手写 CUDA C++ | 首次提出 tiling + online softmax, SRAM-aware |
| FA2 | 2023 | A100 | CUTLASS 3.x / CuTe | C++ 模板重写,比 FA1 快 ~2× |
| FA3 | 2024 | H100 | CUTLASS 深度适配 | WGMMA + TMA + setmaxnreg,达 ~740 TFLOPS |
| FA4 | 2025-26 | H100/B200 | CuTe DSL | Python 写,性能逼近 C++,比 Triton 版快 ~50% |
FlashAttention 4 本质是用 Python DSL 写,编译器生成 PTX。
vLLM / SGLang 的定位
-
调度/抽象层 — attention/GEMM backend 的可插拔抽象 + 运行时自动选最优后端。真正的工程壁垒,不是 kernel 本身。
-
自研 Triton kernel — 两处:(a) 适配独特数据结构的算子(PagedAttention、KV cache 抓取);(b) 跨硬件兜底后端。
-
主力高性能算子 — 全部外包给 FlashAttention、FlashInfer、CUTLASS、CuTe DSL、TRT-LLM,不自己重造。
参考