add mxfp8 moe quantization (#6670)

### What this PR does / why we need it?
Support MXFP8 quantization for Qwen MoE models. Hardware-specific behavior
is routed through a device adaptor (`DeviceOperator`), which makes it
clearer and more maintainable.
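
As a rough illustration of the adaptor idea (a hedged sketch, not this PR's implementation: only `DeviceOperator` and the `use_mxfp_quant` flag name come from this change; the kernels below are hypothetical stand-ins):

```python
from typing import Callable

import torch


def _mxfp8_dynamic_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical stand-in for an MXFP8 activation-quant kernel: cast to
    # float8_e4m3fn with a per-row scale (real MX scales are per-block).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    return (x / scale).to(torch.float8_e4m3fn), scale.squeeze(-1)


def _int8_dynamic_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical stand-in for the pre-existing int8 dynamic-quant path.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127.0
    return torch.round(x / scale).to(torch.int8), scale.squeeze(-1)


class DeviceOperatorSketch:
    """Callers name the op once; the backend kernel is chosen in one place."""

    @staticmethod
    def npu_dynamic_quant(hidden_states: torch.Tensor,
                          use_mxfp_quant: bool = False):
        kernel: Callable = (_mxfp8_dynamic_quant
                            if use_mxfp_quant else _int8_dynamic_quant)
        return kernel(hidden_states)


quantized, scales = DeviceOperatorSketch.npu_dynamic_quant(
    torch.randn(4, 8), use_mxfp_quant=True)
```

Keeping the dispatch behind one entry point is what lets call sites such as `quant_apply_mlp` stay free of per-backend branching.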
### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
13397841ab

---------

Signed-off-by: fangrongcan <17343701736@163.com>
Signed-off-by: wangyao-i <iwangyao@outlook.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Eric-dot <60131170+Eric-dot@users.noreply.github.com>
Co-authored-by: fangrongcan <f00876277@china.huawei.com>
Co-authored-by: wangyao-i <iwangyao@outlook.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Author: Eric-dot
Date: 2026-03-02 11:04:06 +08:00 (committed by GitHub)
Parent: c324053b44
Commit: 3c66a970f2
10 changed files with 802 additions and 100 deletions


@@ -22,7 +22,11 @@ from vllm.forward_context import get_forward_context
 from vllm.triton_utils import HAS_TRITON
 
 from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.device.device_op import DeviceOperator
 from vllm_ascend.ops.activation import AscendSwigluOAIAndMul
+from vllm_ascend.quantization.mxfp_compat import (
+    ensure_mxfp8_moe_available,
+)
 from vllm_ascend.utils import (
     dispose_tensor,
     enable_custom_op,
@@ -66,12 +70,22 @@ def cumsum_group_list(
     )
 
 
+def _require_single_tensor_for_swiglu_quant(
+    tensor_or_list: list[torch.Tensor] | torch.Tensor, *, name: str
+) -> torch.Tensor:
+    if isinstance(tensor_or_list, list):
+        if len(tensor_or_list) != 1:
+            raise ValueError(f"{name} must be a tensor or a single-element list, but got {len(tensor_or_list)}.")
+        return tensor_or_list[0]
+    return tensor_or_list
+
+
 def quant_apply_mlp(
     hidden_states: torch.Tensor,
-    w1: list[torch.Tensor],
-    w1_scale: list[torch.Tensor],
-    w2: list[torch.Tensor],
-    w2_scale: list[torch.Tensor],
+    w1: list[torch.Tensor] | torch.Tensor,
+    w1_scale: list[torch.Tensor] | torch.Tensor,
+    w2: list[torch.Tensor] | torch.Tensor,
+    w2_scale: list[torch.Tensor] | torch.Tensor,
     group_list: torch.Tensor,
     group_list_type: int = 1,
     dynamic_scale: torch.Tensor = None,
@@ -81,15 +95,45 @@ def quant_apply_mlp(
     w2_offset: torch.Tensor | None = None,
     fusion: bool = False,
     dynamic_eplb: bool = False,
+    **kwargs,
 ) -> torch.Tensor:
+    # TODO(linfeng): The parameter list has grown unwieldy; parameter differences
+    # introduced by the quantization modes will be consolidated into a dataclass in a follow-up.
+    use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
+    act_quant_type = torch.float8_e4m3fn
+    weight_quant_type = None
+    scale_type = None
+    per_token_scale_type = None
+    use_bf16 = True
+    input_hidden_dtype = hidden_states.dtype
+    use_gmm_swiglu_quant_fusion = use_mxfp_quant or (fusion and not dynamic_eplb)
+    if use_mxfp_quant:
+        act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn)
+        weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn)
+        scale_type = kwargs.get("scale_type")
+        per_token_scale_type = kwargs.get("per_token_scale_type")
+        use_bf16 = kwargs.get("use_bf16", True)
+        ensure_mxfp8_moe_available("MXFP MoE MLP path")
+        if w1_scale_bias is not None or w2_scale_bias is not None:
+            raise NotImplementedError("MXFP path does not support scale_bias yet.")
+        if w1_offset is not None or w2_offset is not None:
+            raise NotImplementedError("MXFP path does not support antiquant offset yet.")
     if w1_offset is not None:
         unquantized_hidden_states = hidden_states
         quantized_hidden_states = None
     elif dynamic_scale is None:
         unquantized_hidden_states = hidden_states
-        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
+        hidden_states, pertoken_scale = DeviceOperator.npu_dynamic_quant(
+            hidden_states=hidden_states,
+            dynamic_scale=None,
+            act_quant_type=act_quant_type,
+            use_mxfp_quant=use_mxfp_quant,
+        )
         # Dispose the original unquantized hidden states
         # to save npu memory because they're no longer used.
         dispose_tensor(unquantized_hidden_states)
         quantized_hidden_states = None
     else:
@@ -98,13 +142,14 @@ def quant_apply_mlp(
         quantized_hidden_states = hidden_states
     bias1, bias2 = None, None
-    _output_dtype = w2_scale[0].dtype
+    _output_dtype = w2_scale[0].dtype if isinstance(w2_scale, list) else w2_scale.dtype
     weight_prefetch_method = get_weight_prefetch_method()
-    weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
+    if weight_prefetch_method:
+        weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
     is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
     if w1_scale_bias is None and w1_offset is None and is_mc2:
-        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
                 x=hidden_states,
@@ -113,14 +158,16 @@ def quant_apply_mlp(
                 x_scale=pertoken_scale,
                 group_list=cumsum_group_list(group_list, group_list_type, 0),
             )
-        elif fusion and not dynamic_eplb:
+        elif use_gmm_swiglu_quant_fusion:
             # gmm1: gate_up_proj & act_fn: swiglu
-            hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
+            hidden_states, swiglu_out_scale, _ = DeviceOperator.npu_grouped_matmul_swiglu_quant(
                 x=hidden_states,
-                weight=w1[0],
+                weight=_require_single_tensor_for_swiglu_quant(w1, name="w1"),
                 group_list=cumsum_group_list(group_list, group_list_type, 0),
-                weight_scale=w1_scale[0],
+                weight_scale=_require_single_tensor_for_swiglu_quant(w1_scale, name="w1_scale"),
                 x_scale=pertoken_scale,
+                bias=None,
+                use_mxfp_quant=use_mxfp_quant,
             )
             if quantized_hidden_states is not None:
                 dispose_tensor(quantized_hidden_states)
@@ -152,17 +199,23 @@ def quant_apply_mlp(
                 quant_mode=1,
             )
        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
+        hidden_states = DeviceOperator.npu_grouped_matmul_gmm2(
+            hidden_states=hidden_states,
             weight=w2,
-            scale=w2_scale,
-            per_token_scale=[swiglu_out_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
+            weight_scale=w2_scale,
+            per_token_scale=swiglu_out_scale,
             group_list=group_list,
-            output_dtype=w2_scale[0].dtype,
-        )[0]
+            group_list_type=group_list_type,
+            input_dtype=input_hidden_dtype,
+            act_quant_type=act_quant_type,
+            weight_quant_type=weight_quant_type,
+            scale_type=scale_type,
+            per_token_scale_type=per_token_scale_type,
+            use_bf16=use_bf16,
+            use_mxfp_quant=use_mxfp_quant,
+            bias=None,
+            fallback_output_dtype=w2_scale[0].dtype if isinstance(w2_scale, list) else w2_scale.dtype,
+        )
     elif w1_offset is not None:
         # gmm1: gate_up_proj
         hidden_states = torch_npu.npu_grouped_matmul(
@@ -201,7 +254,7 @@ def quant_apply_mlp(
         # TODO w4a8 scene: dynamic acquisition of dtype in the future
         _output_dtype = torch.bfloat16
-        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
+        if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb) and not use_mxfp_quant:
             # gmm1: gate_up_proj & act_fn: swiglu
             hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
                 x=hidden_states,
@@ -211,15 +264,15 @@ def quant_apply_mlp(
                 group_list=cumsum_group_list(group_list, group_list_type, 0),
                 bias=bias1,
             )
-        elif fusion and not dynamic_eplb:
-            # gmm1: gate_up_proj & act_fn: swiglu
-            hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
+        elif use_gmm_swiglu_quant_fusion:
+            hidden_states, swiglu_out_scale, _ = DeviceOperator.npu_grouped_matmul_swiglu_quant(
                 x=hidden_states,
-                weight=w1[0],
-                bias=bias1,
+                weight=_require_single_tensor_for_swiglu_quant(w1, name="w1"),
                 group_list=cumsum_group_list(group_list, group_list_type, 0),
-                weight_scale=w1_scale[0],
+                weight_scale=_require_single_tensor_for_swiglu_quant(w1_scale, name="w1_scale"),
                 x_scale=pertoken_scale,
+                bias=bias1,
+                use_mxfp_quant=use_mxfp_quant,
             )
             if quantized_hidden_states is not None:
                 dispose_tensor(quantized_hidden_states)
@@ -251,18 +304,23 @@ def quant_apply_mlp(
            hidden_states = torch_npu.npu_swiglu(hidden_states)
            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
        # gmm2: down_proj
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
+        hidden_states = DeviceOperator.npu_grouped_matmul_gmm2(
+            hidden_states=hidden_states,
             weight=w2,
-            scale=w2_scale,
-            bias=bias2,
-            per_token_scale=[swiglu_out_scale],
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
+            weight_scale=w2_scale,
+            per_token_scale=swiglu_out_scale,
             group_list=group_list,
-            output_dtype=_output_dtype,
-        )[0]
+            group_list_type=group_list_type,
+            input_dtype=input_hidden_dtype,
+            act_quant_type=act_quant_type,
+            weight_quant_type=weight_quant_type,
+            scale_type=scale_type,
+            per_token_scale_type=per_token_scale_type,
+            use_bf16=use_bf16,
+            use_mxfp_quant=use_mxfp_quant,
+            bias=bias2,
+            fallback_output_dtype=_output_dtype,
+        )
     return hidden_states
@@ -334,26 +392,13 @@ def unified_apply_mlp(
     fusion: bool = False,
     need_trans: bool = True,
     dynamic_eplb: bool = False,
+    **kwargs,
 ) -> torch.Tensor:
-    if with_quant:
-        assert w1_scale is not None and w2_scale is not None
-        return quant_apply_mlp(
-            hidden_states=hidden_states,
-            w1=w1,
-            w1_scale=w1_scale,
-            w2=w2,
-            w2_scale=w2_scale,
-            group_list=group_list,
-            dynamic_scale=dynamic_scale,
-            group_list_type=group_list_type,
-            w1_scale_bias=w1_scale_bias,
-            w2_scale_bias=w2_scale_bias,
-            w1_offset=w1_offset,
-            w2_offset=w2_offset,
-            fusion=fusion,
-            dynamic_eplb=dynamic_eplb,
-        )
-    else:
+    """
+    Unified MoE MLP entry.
+    Quant path is dispatched by DeviceOperator with explicit quant-type flags.
+    """
+    if not with_quant:
         return unquant_apply_mlp(
             hidden_states=hidden_states,
             w1=w1,
@@ -366,3 +411,34 @@ def unified_apply_mlp(
             topk_scales=topk_scales,
             need_trans=need_trans,
         )
+    assert w1_scale is not None and w2_scale is not None
+    # TODO(linfeng): The parameter list has grown unwieldy; parameter differences
+    # introduced by the quantization modes will be consolidated into a dataclass in a follow-up.
+    act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn)
+    weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn)
+    scale_type = kwargs.get("scale_type")
+    per_token_scale_type = kwargs.get("per_token_scale_type")
+    use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
+    return quant_apply_mlp(
+        hidden_states=hidden_states,
+        w1=w1,
+        w1_scale=w1_scale,
+        w2=w2,
+        w2_scale=w2_scale,
+        group_list=group_list,
+        dynamic_scale=dynamic_scale,
+        group_list_type=group_list_type,
+        w1_scale_bias=w1_scale_bias,
+        w2_scale_bias=w2_scale_bias,
+        w1_offset=w1_offset,
+        w2_offset=w2_offset,
+        fusion=fusion,
+        dynamic_eplb=dynamic_eplb,
+        act_quant_type=act_quant_type,
+        weight_quant_type=weight_quant_type,
+        scale_type=scale_type,
+        per_token_scale_type=per_token_scale_type,
+        use_mxfp_quant=use_mxfp_quant,
+        use_bf16=kwargs.get("use_bf16", True),
+    )
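
For orientation, a minimal sketch of the kwargs the MXFP8 path threads through `unified_apply_mlp` -> `quant_apply_mlp` -> `DeviceOperator`. The keyword names come from the diff above; the concrete dtype choices (in particular the container for MX block scales) are illustrative assumptions, not values confirmed by this change.

```python
import torch

# Keyword names taken from the diff; the dtypes are illustrative assumptions.
mxfp8_kwargs = dict(
    use_mxfp_quant=True,
    act_quant_type=torch.float8_e4m3fn,     # activation storage dtype
    weight_quant_type=torch.float8_e4m3fn,  # weight storage dtype
    scale_type=torch.uint8,                 # assumed container for E8M0 block scales
    per_token_scale_type=torch.uint8,       # assumed, mirrors scale_type
    use_bf16=True,                          # keep grouped-matmul outputs in bfloat16
)

# With with_quant=True, unified_apply_mlp forwards these flags unchanged into
# quant_apply_mlp, which passes them on to the DeviceOperator kernels.
```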