From a16c99141b0830240eeff0cbe01bfc3c833c62fb Mon Sep 17 00:00:00 2001
From: Siyuan Kong <101997501+ksiyuan@users.noreply.github.com>
Date: Fri, 20 Mar 2026 16:18:58 +0800
Subject: [PATCH] Adapt w8a8mxfp8 quantization for Qwen VL models (#7417)

### What this PR does / why we need it?

This PR adapts the `w8a8_mxfp8` quantization method to support Qwen
Vision-Language (VL) models. Key changes include:

- Reshaping multi-dimensional input tensors to 2D before the quantized
  matrix multiplication.
- Reshaping the 2D output back to its original multi-dimensional format.
- Adding specific output reshaping for the visual components of Qwen VL
  models.
- Casting the bias tensor to `float32` to comply with the
  `npu_quant_matmul` kernel requirements.

These changes are necessary to enable `w8a8_mxfp8` quantization for
models with multi-modal inputs like Qwen VL.

### Does this PR introduce _any_ user-facing change?

No, this is a backend enhancement to extend quantization support to new
model architectures. There are no user-facing API or behavior changes.

### How was this patch tested?

CI is expected to pass. Manual testing should be performed with a Qwen
VL model using `w8a8_mxfp8` quantization to verify correctness and
performance.

- vLLM version: v0.17.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/4497431df654e46fb1fb5e64bf8611e762ae5d87

---------

Signed-off-by: ksiyuan
---
 vllm_ascend/quantization/methods/w8a8_mxfp8.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/vllm_ascend/quantization/methods/w8a8_mxfp8.py b/vllm_ascend/quantization/methods/w8a8_mxfp8.py
index 1961e168..574c4d75 100644
--- a/vllm_ascend/quantization/methods/w8a8_mxfp8.py
+++ b/vllm_ascend/quantization/methods/w8a8_mxfp8.py
@@ -70,9 +70,15 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
         bias: torch.Tensor | None = None,
         tp_rank: int | None = 0,
     ) -> torch.Tensor:
+        # reshape x for Qwen VL models
+        original_shape = x.shape
+        if x.dim() > 2:
+            x = x.view(-1, x.shape[-1])
         quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(x, dst_type=torch.float8_e4m3fn)
         pertoken_scale = dynamic_scale
         output_dtype = x.dtype
+        if bias is not None and bias.dtype != torch.float32:
+            bias = bias.to(torch.float32)
 
         output = torch_npu.npu_quant_matmul(
             quantized_x,
@@ -85,6 +91,9 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
             output_dtype=output_dtype,
             group_sizes=[1, 1, self.group_size],
         )
+        # reshape output for Qwen VL models
+        if len(original_shape) > 2:
+            output = output.view(*original_shape[:-1], -1)
         return output
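
For reference, below is a minimal, CPU-runnable sketch of the shape-handling pattern this patch applies. The helper names (`quant_matmul_stub`, `apply_with_nd_input`) and the plain-`torch` matmul stand-in are illustrative assumptions only, not part of the patch; the actual code path calls `torch_npu.npu_dynamic_mx_quant` and `torch_npu.npu_quant_matmul`, which require Ascend NPU hardware and the `torch_npu` package.

```python
import torch


def quant_matmul_stub(x2d: torch.Tensor, weight: torch.Tensor,
                      bias: torch.Tensor | None) -> torch.Tensor:
    """Stand-in for the quantized 2D matmul.

    On Ascend hardware the patch quantizes x2d with
    torch_npu.npu_dynamic_mx_quant and multiplies with
    torch_npu.npu_quant_matmul; a plain matmul keeps this
    sketch runnable anywhere.
    """
    if bias is not None:
        # Mirrors the patch: the NPU kernel requires a float32 bias.
        bias = bias.to(torch.float32)
        return x2d @ weight + bias
    return x2d @ weight


def apply_with_nd_input(x: torch.Tensor, weight: torch.Tensor,
                        bias: torch.Tensor | None = None) -> torch.Tensor:
    # The same three steps the patch adds around npu_quant_matmul:
    # remember the original shape, flatten every leading dim into one
    # (Qwen VL visual inputs are 3D or higher), run the 2D matmul,
    # then restore the leading dims on the output.
    original_shape = x.shape
    if x.dim() > 2:
        x = x.view(-1, x.shape[-1])
    output = quant_matmul_stub(x, weight, bias)
    if len(original_shape) > 2:
        output = output.view(*original_shape[:-1], -1)
    return output


if __name__ == "__main__":
    # A 3D visual-token activation: (batch, seq_len, hidden).
    x = torch.randn(2, 16, 64)
    w = torch.randn(64, 128)
    b = torch.randn(128)
    print(apply_with_nd_input(x, w, b).shape)  # torch.Size([2, 16, 128])
```

Flattening to `(-1, hidden)` is the standard way to feed N-D activations to a kernel that only accepts 2D inputs; `view` is zero-copy for contiguous tensors, so the reshapes around the matmul add no data movement.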