DeepSeek V4 TP Size Exceeds N Groups Error
n_groups 是什么?和 TP 的关系
传统模型的痛点:庞大的 W_o 矩阵
标准 MHA 中,每个 Head 输出维度 d_k(如 128),128 个 Head 拼起来总维度 16384。紧接着需要通过巨大的 W_o 矩阵将万维向量投影回 Hidden Size。
- W_o 参数量:16384 × 16384 ≈ 2.68 亿(仅单层 Attention 输出矩阵)
- 问题:Decoding 阶段每个 Token 都要和这个矩阵做 GEMM,计算密集 + 高带宽消耗
V4 的解法:Grouped Low-Rank 分解
DeepSeek-V4 认为 128 个注意力头的信息直接用稠密大矩阵投影冗余度极高,于是引入 Grouped Output Projection:
- 切分(Grouping):128 个 Head 均匀分成 16 个组,每组 8 个 Head
- Down-Projection:每组通过低秩矩阵将 8 × 128 = 1024 维投影到组内隐空间(o_lora_rank = 1024)
- 拼接与 Up-Projection:16 个组压缩后的向量拼接,再映射回 Hidden Size
实际配置:
| 模型 | o_groups | 说明 |
|---|---|---|
| DeepSeek-V4-Flash (284B) | 8 | 激活 13B |
| DeepSeek-V4-Pro (1.6T) | 16 | 激活 49B |
组大小是训练时定好的。TP 会把对应的投影矩阵拆分到各卡。
Bug 原因
TP 数量过大 → 拆分后一个 GPU 上分不到完整的组 → 崩溃
修复方案
做了 dirty fix,理想方案应该支持更大的 TP,但受 vLLM 架构限制会比较复杂。社区暂未看到更深入的解决方案。
deepgemm_post_process_fp8_weight_block 的作用
# 假设 TP=2, wo_a 形状 [1024, 512](每张卡4组)
# 1. 权重是 2D 的
wq.shape = [1024, 512] # [输入维度, 输出维度]
ws.shape = [8, 4] # [1024/128, 512/128] 每128x128块一个scale
# 2. 转成 3D 用于 grouped BMM
g = bmm_batch_size = n_local_groups = 4 # 每张卡4组
r = 512 // 4 = 128 # 每组128列
d = 1024 # 输入维度
wq: [1024, 512] → [g, r, d] = [4, 128, 1024] # 转置后
ws: [8, 4] → [g, r/128, d/128] = [4, 1, 8]
# 3. 转换 scale 格式给 DeepGEMM
dg_ws = transform_sf_into_required_layout(ws, ...)
# 把 scale 排成 DeepGEMM 要求的内存对齐格式
# 4. 返回用于后续 fp8_einsum 计算
return wq, dg_ws
核心流程:2D 权重 → 3D 分组 → scale 格式转换 → 交给 DeepGEMM 计算。