diff --git a/vllm_ascend/quantization/methods/w8a8_mxfp8.py b/vllm_ascend/quantization/methods/w8a8_mxfp8.py
index 1961e168..574c4d75 100644
--- a/vllm_ascend/quantization/methods/w8a8_mxfp8.py
+++ b/vllm_ascend/quantization/methods/w8a8_mxfp8.py
@@ -70,9 +70,15 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
         bias: torch.Tensor | None = None,
         tp_rank: int | None = 0,
     ) -> torch.Tensor:
+        # reshape x for Qwen VL models
+        original_shape = x.shape
+        if x.dim() > 2:
+            x = x.view(-1, x.shape[-1])
         quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(x, dst_type=torch.float8_e4m3fn)
         pertoken_scale = dynamic_scale
         output_dtype = x.dtype
+        if bias is not None and bias.dtype != torch.float32:
+            bias = bias.to(torch.float32)
         output = torch_npu.npu_quant_matmul(
             quantized_x,
@@ -85,6 +91,9 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
             output_dtype=output_dtype,
             group_sizes=[1, 1, self.group_size],
         )
+        # reshape output for Qwen VL models
+        if len(original_shape) > 2:
+            output = output.view(*original_shape[:-1], -1)
         return output
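
Illustrative note (not part of the patch): the change wraps a flatten/restore pattern around a matmul that only accepts 2-D activations, so that Qwen VL activations shaped [batch, seq, hidden] can pass through. Below is a minimal sketch of that pattern, assuming a plain torch.nn.functional.linear as a stand-in for torch_npu.npu_quant_matmul (the real method also quantizes x and applies per-token scales, omitted here); the helper name apply_with_flatten is hypothetical.

import torch


def apply_with_flatten(x: torch.Tensor, weight: torch.Tensor,
                       bias: torch.Tensor | None = None) -> torch.Tensor:
    # Collapse leading dims into one token dimension before the 2-D matmul.
    original_shape = x.shape
    if x.dim() > 2:
        x = x.view(-1, x.shape[-1])

    # Stand-in for the quantized matmul (torch_npu.npu_quant_matmul in the patch).
    output = torch.nn.functional.linear(x, weight, bias)

    # Restore the leading dims so callers see the shape they passed in.
    if len(original_shape) > 2:
        output = output.view(*original_shape[:-1], -1)
    return output


if __name__ == "__main__":
    x = torch.randn(2, 8, 16)              # [batch, seq, hidden]
    w = torch.randn(32, 16)                # [out_features, in_features]
    print(apply_with_flatten(x, w).shape)  # torch.Size([2, 8, 32])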