Fix bug #41469: the XPU platform lacks the CUDA awq_dequantize operator
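Analysis
- The XPU platform lacks an implementation of one CUDA operator, awq_dequantize.
- The application code calls that operator directly anyway.
- AMD platforms have the same problem, but the AMD platform layer sets an environment variable that forces a fallback.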
Evidence
How the AMD platform handles it
# vllm/platforms/rocm.py
# Platform-specific handling declared by the ROCm platform class
@classmethod
def verify_quantization(cls, quant: str) -> None:
    super().verify_quantization(quant)
    if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
        logger.warning(
            "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ "
            "is not set, enabling VLLM_USE_TRITON_AWQ."
        )
        os.environ["VLLM_USE_TRITON_AWQ"] = "1"
How verify_quantization gets called:
- The user starts vLLM
  - vllm/config/model.py: ModelConfig.__post_init__()
    - self._verify_quantization()
      - current_platform.verify_quantization(self.quantization)
The XPU platform has no such special handling in verify_quantization.
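One possible direction is to give XPU the same kind of override that ROCm has. The following is a minimal, self-contained sketch under stated assumptions, not vLLM's actual code: `Platform` and `XPUPlatform` are simplified stand-ins for the real classes, and the flag is read directly from `os.environ` instead of through `envs.VLLM_USE_TRITON_AWQ`.

```python
import logging
import os

logger = logging.getLogger(__name__)


class Platform:
    """Simplified stand-in for the platform base class (hypothetical)."""

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        # The real base class validates `quant`; this stand-in accepts anything.
        return None


class XPUPlatform(Platform):
    """Hypothetical XPU override that mirrors the ROCm workaround."""

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        super().verify_quantization(quant)
        # Assumption: "1" means enabled; the real flag parsing may differ.
        triton_awq_enabled = os.environ.get("VLLM_USE_TRITON_AWQ", "0") == "1"
        if quant == "awq" and not triton_awq_enabled:
            logger.warning(
                "Using AWQ quantization with XPU, but VLLM_USE_TRITON_AWQ "
                "is not set, enabling VLLM_USE_TRITON_AWQ."
            )
            os.environ["VLLM_USE_TRITON_AWQ"] = "1"
```

With this shape, the flag is only forced when the model actually uses AWQ, so other quantization methods are unaffected.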
Binding logic
// csrc/libtorch_stable/torch_bindings.cpp
#ifndef USE_ROCM
  // ...
  // Dequantization for AWQ.
  ops.def(
      "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
  // ...
  // AWQ ops
  ops.impl("awq_dequantize", TORCH_BOX(&awq_dequantize));
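After compilation, torch_bindings.cpp ends up in a shared object (.so) file. A shared object can contain initialization code that runs when Python loads it, and that is when the registration above takes place.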
STABLE_TORCH_LIBRARY_FRAGMENT TODO
#ifndef USE_ROCM
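Meaning: the guarded block is compiled only when USE_ROCM (the AMD build macro) is not defined.
Roughly: the registration only happens on NVIDIA platforms.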
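To confirm on a given machine whether that registration actually ran, the op namespace can be probed at runtime. This is a small diagnostic sketch; `has_native_awq_dequantize` is a hypothetical helper name, and it assumes that an unregistered op surfaces as a missing attribute on `torch.ops._C`. It returns False when PyTorch is not installed.

```python
def has_native_awq_dequantize() -> bool:
    """Report whether a native awq_dequantize op is registered under torch.ops._C."""
    try:
        import torch
    except ImportError:
        # Without PyTorch there is no op registry to inspect.
        return False
    # Looking up an unregistered op raises AttributeError, so hasattr is False.
    return hasattr(torch.ops._C, "awq_dequantize")


if __name__ == "__main__":
    print("native awq_dequantize registered:", has_native_awq_dequantize())
```

The result depends on whether the compiled extension has been loaded in the current process, so the check is best run after importing vLLM.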
Fallback Mechanism
import torch

import vllm.envs as envs


def awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy):
    # Use the Triton kernel when the flag is set; otherwise call the native op.
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import awq_dequantize_triton
        return awq_dequantize_triton(qweight, scales, zeros)
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)
When VLLM_USE_TRITON_AWQ=1, the code uses the pure Triton kernel implementation instead of the CUDA C++ kernel, which works on any platform that supports Triton (including XPU with Triton-XPU).
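Takeaways
- Small bugs like this are probably fairly common. The main cause is that there is no centralized way to handle them: the dispatch code lacks hardware detection and checks, which makes it hard to maintain.
- The value of fixing this class of problem is limited, though, since these are all rare corner cases.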
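As an illustration of the first point, a centralized check could decide the backend in one place and fail with a clear message instead of crashing on a missing kernel. This is a hypothetical sketch only: `select_awq_dequantize_backend` and its parameters are invented for illustration and are not part of vLLM.

```python
def select_awq_dequantize_backend(
    use_triton: bool,
    native_op_available: bool,
    triton_available: bool,
) -> str:
    """Choose an awq_dequantize backend name from explicit capability flags."""
    if use_triton:
        if not triton_available:
            raise RuntimeError(
                "VLLM_USE_TRITON_AWQ is set, but Triton is not available"
            )
        return "triton"
    if native_op_available:
        return "native"
    if triton_available:
        # Fall back automatically rather than fail on a missing native kernel.
        return "triton"
    raise RuntimeError("No awq_dequantize implementation for this platform")
```

For example, a platform without the native op but with Triton would resolve to the Triton path even when the flag is unset, which is the behavior the ROCm workaround achieves through the environment variable.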