Adapt w8a8mxfp8 quantization for Qwen VL models (#7417)

### What this PR does / why we need it?

This PR adapts the `w8a8_mxfp8` quantization method to support Qwen
Vision-Language (VL) models. Key changes include:
- Reshaping multi-dimensional input tensors to 2D before the quantized
matrix multiplication.
- Reshaping the 2D output back to its original multi-dimensional format.
- Adding specific output reshaping for the visual components of Qwen VL
models.
- Casting the bias tensor to `float32` to comply with the
`npu_quant_matmul` kernel requirements.

These changes are necessary to enable `w8a8_mxfp8` quantization for
models with multi-modal inputs like Qwen VL.

### Does this PR introduce _any_ user-facing change?

No, this is a backend enhancement to extend quantization support to new
model architectures. There are no user-facing API or behavior changes.

### How was this patch tested?

CI is expected to pass. Manual testing should be performed with a Qwen
VL model using `w8a8_mxfp8` quantization to verify correctness and
performance.

- vLLM version: v0.17.0
- vLLM main:
4497431df6

---------

Signed-off-by: ksiyuan <ksiyuan@umich.edu>
This commit is contained in:
Siyuan Kong
2026-03-20 16:18:58 +08:00
committed by GitHub
parent 4e6dbe0956
commit a16c99141b

View File

@@ -70,9 +70,15 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
bias: torch.Tensor | None = None,
tp_rank: int | None = 0,
) -> torch.Tensor:
# reshape x for Qwen VL models
original_shape = x.shape
if x.dim() > 2:
x = x.view(-1, x.shape[-1])
quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(x, dst_type=torch.float8_e4m3fn)
pertoken_scale = dynamic_scale
output_dtype = x.dtype
if bias is not None and bias.dtype != torch.float32:
bias = bias.to(torch.float32)
output = torch_npu.npu_quant_matmul(
quantized_x,
@@ -85,6 +91,9 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
output_dtype=output_dtype,
group_sizes=[1, 1, self.group_size],
)
# reshape output for Qwen VL models
if len(original_shape) > 2:
output = output.view(*original_shape[:-1], -1)
return output