DeepSeek V4 TP Size Exceeds N Groups Error

n_groups 是什么?和 TP 的关系

传统模型的痛点:庞大的 W_o 矩阵

标准 MHA 中,每个 Head 输出维度 d_k(如 128),128 个 Head 拼起来总维度 16384。紧接着需要通过巨大的 W_o 矩阵将万维向量投影回 Hidden Size。

V4 的解法:Grouped Low-Rank 分解

DeepSeek-V4 认为 128 个注意力头的信息直接用稠密大矩阵投影冗余度极高,于是引入 Grouped Output Projection:

  1. 切分(Grouping):128 个 Head 均匀分成 16 个组,每组 8 个 Head
  2. Down-Projection:每组通过低秩矩阵将 8 × 128 = 1024 维投影到组内隐空间(o_lora_rank = 1024)
  3. 拼接与 Up-Projection:16 个组压缩后的向量拼接,再映射回 Hidden Size

实际配置:

模型 o_groups 说明
DeepSeek-V4-Flash (284B) 8 激活 13B
DeepSeek-V4-Pro (1.6T) 16 激活 49B

组大小是训练时定好的。TP 会把对应的投影矩阵拆分到各卡。

Bug 原因

TP 数量过大 → 拆分后一个 GPU 上分不到完整的组 → 崩溃

修复方案

做了 dirty fix,理想方案应该支持更大的 TP,但受 vLLM 架构限制会比较复杂。社区暂未看到更深入的解决方案。

deepgemm_post_process_fp8_weight_block 的作用

# 假设 TP=2, wo_a 形状 [1024, 512](每张卡4组)

# 1. 权重是 2D 的
wq.shape = [1024, 512]      # [输入维度, 输出维度]
ws.shape = [8, 4]           # [1024/128, 512/128] 每128x128块一个scale

# 2. 转成 3D 用于 grouped BMM
g = bmm_batch_size = n_local_groups = 4  # 每张卡4组
r = 512 // 4 = 128          # 每组128列
d = 1024                    # 输入维度

wq: [1024, 512]  [g, r, d] = [4, 128, 1024]  # 转置后
ws: [8, 4]  [g, r/128, d/128] = [4, 1, 8]

# 3. 转换 scale 格式给 DeepGEMM
dg_ws = transform_sf_into_required_layout(ws, ...)
# 把 scale 排成 DeepGEMM 要求的内存对齐格式

# 4. 返回用于后续 fp8_einsum 计算
return wq, dg_ws

核心流程:2D 权重 → 3D 分组 → scale 格式转换 → 交给 DeepGEMM 计算。