GPU编程层次

Python DSL层 (Triton / CuTile / CuTe DSL)

直接把高层代码编译到 PTX,不经过 CUDA C++。2019 年以来最重要的范式变化。

框架 来源 核心理念
Triton OpenAI, 2019 block/tile 级编程,线程映射、shared memory、warp 分配交给编译器
CuTe DSL NVIDIA, 2025 给 CuTe layout 代数套 Python 前端,开发体验接近 CUTLASS C++ 性能
CuTile NVIDIA, CUDA 13.1 官方 tile 编程模型,NVIDIA「自己下场做 Triton」的回应

均有自己的 MLIR 编译器,直接 lower 到 PTX。

C++ 模板层 (CUTLASS / CuTe / ThunderKittens)

动机:手写 CUDA C++ 写矩阵乘 / attention 太累、太容易出错,用 C++ 模板把高性能模式封装成可复用抽象。

框架 说明
CUTLASS CUDA Templates for Linear Algebra Subroutines,开源可定制的 cuBLAS
CuTe CUTLASS 3.0 起,描述「数据怎么摆 + 线程怎么映射」的 layout 代数,是 CUTLASS 的地基
ThunderKittens 斯坦福 Hazy Research, 2024,用精简抽象在 H100 上打平/超越 CUTLASS 版 FlashAttention-3

CUDA C++ (.cu)

地基。2012 年以来唯一的选择,SIMT 模型 — grid → block(CTA) → warp → thread,自己管理寄存器、shared memory、tiling、同步。所有上层框架最终 fallback 到它。

编译链:.cu → nvcc → PTX → ptxas → SASS

日常说的「CUDA」指这门语言,也泛指 NVIDIA 整个 GPU 计算平台(语言 + 编译器 + 驱动 + 库)。大多数开发者通过 cuBLAS/cuDNN 间接使用,而非自己写。

PTX

Parallel Thread eXecution — NVIDIA 的虚拟 ISA(中间表示)。所有上层入口最后都要变成 PTX,再由驱动里的 ptxas 编成特定架构的 SASS。

实际工程里几乎没人从头用 PTX 写 kernel。真实用途:在 .cu 里通过 asm volatile 内联指令,碰 C++ 层暴露不出来的硬件特性(异步拷贝、cache hint、新硬件刚出编译器还没支持的指令)。PTX 是补丁工具,不是开发语言。

SASS

Streaming ASSembly — 特定架构的真实机器码。通常只在做极致性能分析或逆向时才看。

NVIDIA 闭源库

用途
cuBLAS 通用线性代数,GEMM / BLAS
cuDNN 深度学习算子:卷积、池化、归一化、attention

PyTorch 矩阵乘默认走 cuBLAS,卷积走 cuDNN。黑盒,不能改、看不见内部,但参数对就能拿到 NVIDIA 多年调优的性能。

CUTLASS 是 NVIDIA 给这条闭源生态做的「开源积木版本」,名字致敬 cuBLAS。

代码对比

cuBLAS — 黑盒调用

cublasHandle_t handle;
cublasCreate(&handle);

float alpha = 1.0f, beta = 0.0f;
cublasSgemm(handle,
            CUBLAS_OP_N, CUBLAS_OP_N, M, N, K,
            &alpha, dA, M, dB, K, &beta, dC, M);

cublasDestroy(handle);

只关心:尺寸、转置、指针。内部怎么 tile、用不用 Tensor Core、线程怎么分 — NVIDIA 全替你定了。

纯 CUDA C++ — 朴素手写

__global__ void gemm_naive(float* A, float* B, float* C, int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; ++k)
            sum += A[row * K + k] * B[k * N + col];
        C[row * N + col] = sum;
    }
}

能跑,但慢 — 没用 shared memory、Tensor Core、内存访问未优化。要写快得手动加 tiling、SMEM、bank conflict 处理,换代硬件要重写。

CuTe — layout 代数描述数据编排

using namespace cute;

// 裸内存包装成带 layout 的 Tensor
Tensor mA = make_tensor(make_gmem_ptr(A), make_shape(M, K), make_stride(_1{}, M));
Tensor mB = make_tensor(make_gmem_ptr(B), make_shape(N, K), make_stride(_1{}, N));
Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M, N), make_stride(_1{}, M));

// 声明切块:每个 block 处理 128×128×8
auto block_tile = make_shape(Int<128>{}, Int<128>{}, Int<8>{});

// 用 layout 把全局矩阵切给当前 block
Tensor gA = local_tile(mA, block_tile, ...);
Tensor gB = local_tile(mB, block_tile, ...);

// 声明 tiled MMA:用哪条 Tensor Core 指令
TiledMMA mma = make_tiled_mma(SM80_16x8x8_F32F16F16F32_TN{}, ...);

cute::gemm(mma, gA, gB, gC);

没有手写 A[row*K+k],而是用 make_shape / make_stride / local_tile 声明「数据长什么样、怎么切」。layout 代数管数据编排,不是调现成函数。

CUTLASS — 模板组装生产级 kernel

using Gemm = device::GemmUniversal<
    cutlass::half_t, cutlass::layout::RowMajor,    // A
    cutlass::half_t, cutlass::layout::ColumnMajor, // B
    float,           cutlass::layout::RowMajor,    // C
    float,                                         // 累加
    cutlass::arch::OpClassTensorOp,                // Tensor Core
    cutlass::arch::Sm90,                           // Hopper
    Shape<_128,_128,_64>,                          // block tile
    Shape<_64, _64, _64>                           // warp tile
>;

Gemm gemm_op;
gemm_op({M, N, K}, {dA, lda}, {dB, ldb}, {dC, ldc}, {alpha, beta});

你定类型、布局、架构、tile 大小、warp 分配。CUTLASS 编译成高度优化的 kernel。

Triton — Python tile 级

@triton.jit
def gemm_kernel(A, B, C, M, N, K,
                stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        a = tl.load(A + offs_m[:, None] * stride_am + (k + offs_k[None, :]) * stride_ak)
        b = tl.load(B + (k + offs_k[:, None]) * stride_bk + offs_n[None, :] * stride_bn)
        acc += tl.dot(a, b)

    tl.store(C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, acc)

关键区别:没有 threadIdx,没有 __shared__,没有同步原语。以 tile 为单位思考 — Triton 编译器自动决定线程映射、shared memory staging、Tensor Core 指令、内存合并。

FlashAttention 演进

版本 时间 硬件 实现方式 关键变化
FA1 2022 A100 手写 CUDA C++ 首次提出 tiling + online softmax, SRAM-aware
FA2 2023 A100 CUTLASS 3.x / CuTe C++ 模板重写,比 FA1 快 ~2×
FA3 2024 H100 CUTLASS 深度适配 WGMMA + TMA + setmaxnreg,达 ~740 TFLOPS
FA4 2025-26 H100/B200 CuTe DSL Python 写,性能逼近 C++,比 Triton 版快 ~50%

FlashAttention 4 本质是用 Python DSL 写,编译器生成 PTX。

vLLM / SGLang 的定位

  1. 调度/抽象层 — attention/GEMM backend 的可插拔抽象 + 运行时自动选最优后端。真正的工程壁垒,不是 kernel 本身。

  2. 自研 Triton kernel — 两处:(a) 适配独特数据结构的算子(PagedAttention、KV cache 抓取);(b) 跨硬件兜底后端。

  3. 主力高性能算子 — 全部外包给 FlashAttention、FlashInfer、CUTLASS、CuTe DSL、TRT-LLM,不自己重造。


参考