add mxfp8 moe quantization (#6670)

### What this PR does / why we need it?
support mxfp8 quantization (Qwen MOE )
Using adaptor to make the hardware-specific behavior clearer and more
maintainable
### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
13397841ab

---------

Signed-off-by: fangrongcan <17343701736@163.com>
Signed-off-by: wangyao-i <iwangyao@outlook.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Eric-dot <60131170+Eric-dot@users.noreply.github.com>
Co-authored-by: fangrongcan <f00876277@china.huawei.com>
Co-authored-by: wangyao-i <iwangyao@outlook.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
Eric-dot
2026-03-02 11:04:06 +08:00
committed by GitHub
parent c324053b44
commit 3c66a970f2
10 changed files with 802 additions and 100 deletions

View File

@@ -30,6 +30,7 @@ class QuantType(Enum):
NONE = 0
W8A8 = 1
W4A8 = 2
MXFP8 = 3
class AscendLinearScheme(ABC):

View File

@@ -15,13 +15,24 @@
# limitations under the License.
#
from collections.abc import Callable
from typing import Any
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
from .base import AscendLinearScheme
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.quantization.mxfp_compat import (
FLOAT8_E8M0FNU_DTYPE,
ensure_mxfp8_linear_available,
ensure_mxfp8_moe_available,
)
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
from .registry import register_scheme
@@ -37,6 +48,7 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
model_dtype = None
def __init__(self):
ensure_mxfp8_linear_available("W8A8_MXFP8 linear quantization")
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)
@@ -66,9 +78,9 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
quantized_x,
layer.weight,
layer.weight_scale,
scale_dtype=torch_npu.float8_e8m0fnu,
scale_dtype=FLOAT8_E8M0FNU_DTYPE,
pertoken_scale=pertoken_scale,
pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
pertoken_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
bias=bias,
output_dtype=output_dtype,
group_sizes=[1, 1, self.group_size],
@@ -81,3 +93,127 @@ class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2)
layer.weight.data = layer.weight.data.transpose(0, 1)
layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1)
@register_scheme("W8A8_MXFP8", "moe")
class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
"""FusedMoe method for Ascend W8A8_DYNAMIC."""
model_dtype = None
quant_type: QuantType = QuantType.MXFP8
def __init__(self):
ensure_mxfp8_moe_available("W8A8_MXFP8 MoE quantization")
self.ep_group = get_ep_group()
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)
ascend_config = get_ascend_config()
self.use_aclgraph = (
vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
and not vllm_config.model_config.enforce_eager
)
self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb
@staticmethod
def get_weight(
num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
param_dict["w13_weight"] = torch.empty(
num_experts, 2 * intermediate_size_per_partition, hidden_sizes, dtype=torch.float8_e4m3fn
)
param_dict["w2_weight"] = torch.empty(
num_experts, hidden_sizes, intermediate_size_per_partition, dtype=torch.float8_e4m3fn
)
return param_dict
def get_dynamic_quant_param(
self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
) -> dict[str, Any]:
param_dict = {}
param_dict["w13_weight_scale"] = torch.empty(
num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.uint8
)
param_dict["w2_weight_scale"] = torch.empty(
num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.uint8
)
return param_dict
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
**kwargs,
) -> torch.Tensor:
expected = global_num_experts - global_redundant_expert_num
assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)"
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
topk_weights = topk_weights.to(x.dtype)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=False,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask"),
use_mxfp_quant=True,
act_quant_type=torch.float8_e4m3fn,
weight_quant_type=torch.float8_e4m3fn,
scale_type=FLOAT8_E8M0FNU_DTYPE,
per_token_scale_type=FLOAT8_E8M0FNU_DTYPE,
)
def process_weights_after_loading(self, layer):
g_num, n_size, k_size = layer.w13_weight_scale.shape
layer.w13_weight_scale.data = layer.w13_weight_scale.data.reshape(g_num, n_size, k_size // 2, 2)
g_num, n_size, k_size = layer.w2_weight_scale.shape
layer.w2_weight_scale.data = layer.w2_weight_scale.data.reshape(g_num, n_size, k_size // 2, 2)
layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2)
layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(1, 2)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(1, 2)