Adapt w8a8mxfp8 quantization for Qwen VL models (#7417)
### What this PR does / why we need it?
This PR adapts the `w8a8_mxfp8` quantization method to support Qwen
Vision-Language (VL) models. Key changes include:
- Reshaping multi-dimensional input tensors to 2D before the quantized
matrix multiplication.
- Reshaping the 2D output back to its original multi-dimensional format.
- Adding specific output reshaping for the visual components of Qwen VL
models.
- Casting the bias tensor to `float32` to comply with the
`npu_quant_matmul` kernel requirements.
These changes are necessary to enable `w8a8_mxfp8` quantization for
models with multi-modal inputs like Qwen VL.
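The reshape bookkeeping described above can be sketched in isolation. This is a minimal, pure-Python illustration of the shape handling only (the real path calls `torch_npu` kernels on device); `run_quant_linear` and its parameters are hypothetical names for this sketch, not part of the patch.

```python
from math import prod

def run_quant_linear(x_shape, out_features):
    """Shape bookkeeping used for Qwen VL inputs: collapse all leading
    dimensions to 2D before the quantized matmul, then restore them on
    the output. Sketch only -- real quantization/matmul omitted."""
    original_shape = x_shape
    if len(x_shape) > 2:
        # mirrors x = x.view(-1, x.shape[-1])
        x_shape = (prod(x_shape[:-1]), x_shape[-1])
    # the quantized matmul produces a (tokens, out_features) result
    out_shape = (x_shape[0], out_features)
    if len(original_shape) > 2:
        # mirrors output = output.view(*original_shape[:-1], -1)
        out_shape = (*original_shape[:-1], out_features)
    return out_shape

# A Qwen VL visual input of shape (batch, seq, hidden) round-trips
# through the 2D matmul and comes back with its leading dims intact.
print(run_quant_linear((2, 16, 1024), 4096))   # (2, 16, 4096)
print(run_quant_linear((32, 1024), 4096))      # (32, 4096)
```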
### Does this PR introduce _any_ user-facing change?
No, this is a backend enhancement to extend quantization support to new
model architectures. There are no user-facing API or behavior changes.
### How was this patch tested?
CI is expected to pass. Manual testing should be performed with a Qwen
VL model using `w8a8_mxfp8` quantization to verify correctness and
performance.
- vLLM version: v0.17.0
- vLLM main:
4497431df6
---------
Signed-off-by: ksiyuan <ksiyuan@umich.edu>
@@ -70,9 +70,15 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
        bias: torch.Tensor | None = None,
        tp_rank: int | None = 0,
    ) -> torch.Tensor:
        # reshape x for Qwen VL models
        original_shape = x.shape
        if x.dim() > 2:
            x = x.view(-1, x.shape[-1])
        quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(x, dst_type=torch.float8_e4m3fn)
        pertoken_scale = dynamic_scale
        output_dtype = x.dtype
        if bias is not None and bias.dtype != torch.float32:
            bias = bias.to(torch.float32)

        output = torch_npu.npu_quant_matmul(
            quantized_x,
@@ -85,6 +91,9 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
            output_dtype=output_dtype,
            group_sizes=[1, 1, self.group_size],
        )
        # reshape output for Qwen VL models
        if len(original_shape) > 2:
            output = output.view(*original_shape[:-1], -1)

        return output
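For context on the quantization step itself: `npu_dynamic_mx_quant` computes per-token scales on device and casts activations to `float8_e4m3fn`. The pure-Python sketch below shows only the per-row dynamic-scaling idea, not the device kernel; `FP8_E4M3_MAX` and `dynamic_quant_row` are illustrative names, and the actual cast to float8 is omitted.

```python
# Largest finite value representable in float8_e4m3fn.
FP8_E4M3_MAX = 448.0

def dynamic_quant_row(row):
    """Per-token dynamic scaling: pick a scale so the row's largest
    magnitude maps to the float8 range, then divide. (Sketch only;
    the real kernel also performs the float8 cast and rounding.)"""
    scale = max(abs(v) for v in row) / FP8_E4M3_MAX or 1.0
    quantized = [v / scale for v in row]
    return quantized, scale

q, s = dynamic_quant_row([0.5, -1.0, 2.0])
# Dequantizing (q[i] * s) recovers the inputs, since no rounding
# is modeled here; on device, float8 rounding introduces error.
```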