add mxfp8 moe quantization (#6670)

### What this PR does / why we need it?
support mxfp8 quantization (Qwen MOE )
Using adaptor to make the hardware-specific behavior clearer and more
maintainable
### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
13397841ab

---------

Signed-off-by: fangrongcan <17343701736@163.com>
Signed-off-by: wangyao-i <iwangyao@outlook.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Eric-dot <60131170+Eric-dot@users.noreply.github.com>
Co-authored-by: fangrongcan <f00876277@china.huawei.com>
Co-authored-by: wangyao-i <iwangyao@outlook.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
Eric-dot
2026-03-02 11:04:06 +08:00
committed by GitHub
parent c324053b44
commit 3c66a970f2
10 changed files with 802 additions and 100 deletions

View File

@@ -28,6 +28,7 @@ import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
@@ -103,8 +104,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
self.ep_rank_id = get_mc2_group().rank_in_group
self.ep_world_size = get_mc2_group().world_size
self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
self.need_extra_args = get_ascend_device_type() == AscendDeviceType.A3
self.need_extra_args = get_ascend_device_type() in [AscendDeviceType.A3, AscendDeviceType.A5]
self.a5_need_extra_args = get_ascend_device_type() == AscendDeviceType.A5
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
# improve communication performance.
@@ -136,8 +137,21 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
expert_map: torch.Tensor,
mc2_mask: torch.Tensor,
global_redundant_expert_num: int = 0,
**kwargs,
):
quant_mode = 2 if self.with_quant else 0
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
comm_quant_mode = kwargs.get("comm_quant_mode")
# NOTE: quant_mode differs by quant feature:
# - Legacy int communication quantization uses quant_mode=2.
# - A5 MXFP8 communication uses quant_mode=4.
# TODO(linfeng): The quantization-related parameters need to be consolidated into a single
# dataclass, and the FP8 MoE code path should be integrated into it going forward.
if comm_quant_mode is not None:
quant_mode = comm_quant_mode
elif self.with_quant:
quant_mode = 4 if self.a5_need_extra_args and use_mxfp_quant else 2
else:
quant_mode = 0
self.moe_expert_num = len(expert_map) + global_redundant_expert_num
kwargs_mc2 = {
"x": hidden_states,
@@ -164,7 +178,12 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"tp_rank_id": 0,
}
)
if self.need_expert_scale:
if self.a5_need_extra_args and use_mxfp_quant:
y_dtype = kwargs.get("y_dtype")
if self.with_quant:
y_dtype = torch.float8_e4m3fn if y_dtype is None else y_dtype
stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype})
if self.need_expert_scale or self.a5_need_extra_args:
stage1_kwargs.update(
{
"expert_scales": topk_weights.to(torch.float32),
@@ -186,11 +205,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
**kwargs,
):
self.with_quant = with_quant
kwargs_mc2 = self.get_dispatch_mc2_kwargs(
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num, **kwargs
)
output = (
torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
@@ -337,19 +356,16 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
first_expert_idx = 0
last_expert_idx = self.num_experts_local
global_num_experts = self.num_experts_local
sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = (
torch.ops._C_ascend.npu_moe_init_routing_custom(
hidden_states,
topk_ids,
scale=pertoken_scale,
active_num=num_tokens * self.top_k,
expert_num=global_num_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
)
sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = DeviceOperator.npu_moe_init_routing(
hidden_states,
topk_ids,
scale=pertoken_scale,
active_num=num_tokens * self.top_k,
expert_num=global_num_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode