Fix bug #41469: the XPU platform lacks the awq_dequantize CUDA operator

Analysis

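  1. The XPU platform has no implementation of the awq_dequantize CUDA operator.
  2. The calling code nevertheless invokes it directly.
  3. The AMD platform actually has the same problem, but AMD sets an environment variable in the platform layer to force a fallback to the Triton implementation.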

Evidence

How the AMD platform handles it

# vllm/platforms/rocm.py
# Declares the platform-specific handling

@classmethod
def verify_quantization(cls, quant: str) -> None:
    super().verify_quantization(quant)
    if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ:
        logger.warning(
            "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ "
            "is not set, enabling VLLM_USE_TRITON_AWQ."
        )
    os.environ["VLLM_USE_TRITON_AWQ"] = "1"

Initialization of verify_quantization:

XPU has no platform-specific handling in verify_quantization
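
A minimal sketch of what such an override might look like on the XPU side. This is not the actual vLLM patch: `Platform` and `XPUPlatform` below are hypothetical stand-in classes so the example runs on its own, and `os.environ` is read directly instead of going through `envs`. Only the variable name VLLM_USE_TRITON_AWQ is taken from the ROCm code above.

```python
import logging
import os

logger = logging.getLogger(__name__)


class Platform:
    """Hypothetical stand-in for the base platform class."""

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        # This stand-in accepts every quantization method.
        return None


class XPUPlatform(Platform):
    """Hypothetical XPU platform mirroring the ROCm override."""

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        super().verify_quantization(quant)
        # Assumption for this sketch: the flag is read from os.environ.
        if quant == "awq" and os.environ.get("VLLM_USE_TRITON_AWQ") != "1":
            logger.warning(
                "Using AWQ quantization with XPU, but VLLM_USE_TRITON_AWQ "
                "is not set, enabling VLLM_USE_TRITON_AWQ."
            )
            os.environ["VLLM_USE_TRITON_AWQ"] = "1"
```

Unlike the ROCm snippet, this sketch sets the variable only when AWQ is requested. That is a choice made for the illustration, not something the source prescribes.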

Binding logic

// csrc/libtorch_stable/torch_bindings.cpp 

#ifndef USE_ROCM
// ...
// Dequantization for AWQ.
  ops.def(
      "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
// ...
// AWQ ops
  ops.impl("awq_dequantize", TORCH_BOX(&awq_dequantize));

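torch_bindings.cpp is compiled into a .so file. It can contain initialization code that runs when the library is loaded from Python, and that is where the registration above takes place.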

STABLE_TORCH_LIBRARY_FRAGMENT TODO

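#ifndef USE_ROCM means: compile this block only when the build is not declared as an AMD (ROCm) build. As an approximation: it only takes effect on the NVIDIA platform.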

Fallback Mechanism

def awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy):
    if envs.VLLM_USE_TRITON_AWQ:
        from vllm.model_executor.layers.quantization.awq_triton import awq_dequantize_triton
        return awq_dequantize_triton(qweight, scales, zeros)
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy)

When VLLM_USE_TRITON_AWQ=1, the code uses the pure Triton kernel implementation instead of the CUDA C++ kernel, which works on any platform that supports Triton (including XPU with Triton-XPU).
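
The two code paths can be reproduced without vLLM or torch using the toy sketch below. Everything in it is a hypothetical stand-in: `use_triton_awq` replaces `envs.VLLM_USE_TRITON_AWQ`, `triton_dequantize` replaces the Triton kernel, and `compiled_ops` is an empty namespace playing the role of a build in which awq_dequantize was never registered.

```python
import os
from types import SimpleNamespace


def use_triton_awq() -> bool:
    """Stand-in for envs.VLLM_USE_TRITON_AWQ, read at call time."""
    return os.environ.get("VLLM_USE_TRITON_AWQ", "0") == "1"


def triton_dequantize(qweight, scales, zeros):
    """Placeholder for the Triton kernel; returns a tag, not a tensor."""
    return "triton"


# Hypothetical compiled-op namespace with no awq_dequantize registered,
# which is the situation described for XPU.
compiled_ops = SimpleNamespace()


def awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy):
    if use_triton_awq():
        return triton_dequantize(qweight, scales, zeros)
    # Without the flag the call goes straight to the compiled op and
    # raises AttributeError because the op does not exist.
    return compiled_ops.awq_dequantize(
        qweight, scales, zeros, split_k_iters, thx, thy
    )
```

With the flag set to "1" the call returns through the Triton placeholder. With it unset, the attribute lookup on `compiled_ops` fails, which is the toy equivalent of the missing operator on XPU.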

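Takeaways

  1. There are probably quite a few small bugs like this one. The main cause is that there is no centralized way of handling them: the dispatch code lacks hardware detection and checks, which makes it hard to maintain.
  2. That said, the value of fixing this class of problem is limited, since these are all rare corner cases.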