[Feat] 310P support MoE W8A8 quantization (#6641)
### What this PR does / why we need it?
This PR introduces support for W8A8 dynamic quantization for
Mixture-of-Experts (MoE) models on Ascend 310P devices. This is achieved
by:
- Implementing a new quantization scheme
`AscendW8A8DynamicFusedMoEMethod310`.
- Adding a unified MLP implementation (`unified_apply_mlp`) for 310P
that handles both quantized and unquantized paths.
- Refactoring the MoE and quantization configuration logic to correctly
route to the new 310P-specific implementations.
- Adding new e2e and unit tests to verify the functionality of MoE W8A8
quantization.
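
For background, per-token (dynamic) W8A8 means the int8 weights carry static per-channel scales while activation scales are computed on the fly for each token. A minimal pure-PyTorch sketch of the idea (illustrative only; the PR uses the fused `torch_npu.npu_quant_grouped_matmul_dequant` kernel instead):

```python
import torch

def w8a8_dynamic_matmul(x: torch.Tensor, w_int8: torch.Tensor,
                        w_scale: torch.Tensor) -> torch.Tensor:
    # Dynamic per-token activation scale: max-abs over the hidden dim.
    x_scale = x.abs().amax(dim=-1, keepdim=True) / 127.0
    x_int8 = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
    # int8 x int8 matmul accumulated in int32, then dequantized with the
    # per-token activation scale and the per-channel weight scale.
    acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()
    return acc.to(torch.float32) * x_scale * w_scale

out = w8a8_dynamic_matmul(torch.randn(4, 8),
                          torch.randint(-128, 127, (16, 8), dtype=torch.int8),
                          torch.rand(16) / 127.0)  # -> shape (4, 16)
```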
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- Added a new e2e test `test_qwen3_moe_tp2_w8a8` to test MoE W8A8
quantization in a multi-card setup (sketched below).
- Added several new unit tests for the 310P-specific MoE components,
including `experts_selector`, `fused_moe`, `moe_comm_method`, `moe_mlp`,
and the new `w8a8_dynamic` quantization method.
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd
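
For reference, the multi-card e2e test boils down to serving a W8A8 MoE checkpoint through the standard vLLM entry point, roughly as follows (the model path is a placeholder, not the path used in the actual test):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="path/to/Qwen3-MoE-W8A8",  # placeholder for a W8A8-quantized checkpoint
    quantization="ascend",           # Ascend quantization path registered by vllm-ascend
    tensor_parallel_size=2,          # tp2 multi-card setup, as in the test
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```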
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
@@ -19,7 +19,6 @@ from collections.abc import Callable
import torch

from vllm_ascend.ops.fused_moe.experts_selector import _native_select_experts
from vllm_ascend.utils import get_weight_prefetch_method


def select_experts(
@@ -55,9 +54,6 @@ def select_experts(
        topk_weights: router weights of shape (num_tokens, top_k).
        topk_ids: selected expert IDs of shape (num_tokens, top_k).
    """
    # prefetch w1_w3_proj.weight preprocess
    weight_prefetch_method = get_weight_prefetch_method()
    weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
    topk_weights, topk_ids = _native_select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
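For intuition, `_native_select_experts` with softmax scoring and renormalization reduces to a few tensor ops; a pure-PyTorch sketch (simplified; the real helper also handles grouped top-k, custom routing functions, and bias correction):

```python
import torch

def native_select_experts(router_logits: torch.Tensor, top_k: int):
    scores = router_logits.softmax(dim=-1)                 # scoring_func="softmax"
    topk_weights, topk_ids = scores.topk(top_k, dim=-1)
    # Renormalize so each token's selected expert weights sum to 1.
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids                          # (num_tokens, top_k) each

w, ids = native_select_experts(torch.randn(4, 8), top_k=2)
```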
@@ -58,7 +58,6 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):
        num_expert_group: int | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
@@ -67,7 +66,6 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):
    ) -> torch.Tensor:
        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)
        assert routed_scaling_factor == 1.0

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
@@ -195,44 +193,36 @@ class AscendFusedMoE310(FusedMoE):

        method = quant_method.quant_method
        quant_type = getattr(method, "quant_type", QuantType.NONE)
        if quant_type != QuantType.NONE:
            # TODO: w8a8 quantization will be supported soon, and only reject w4a8 here.
            raise RuntimeError("W8A8 is not supported currently.")
        return QuantType.NONE
        if quant_type not in [QuantType.NONE, QuantType.W8A8]:
            raise RuntimeError("Only Unquant and W8A8 is supported.")
        return quant_type

    def forward_impl(  # type: ignore[override]
        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
    ) -> torch.Tensor:
        assert self.quant_method is not None
        assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported."
        forward_context = get_forward_context()

        hidden_states, router_logits, _, context_metadata = forward_context.moe_comm_method.prepare(
            hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
        )

        if isinstance(hidden_states, tuple):
            hidden_states, pertoken_scale = hidden_states
        else:
            pertoken_scale = None

        # Matrix multiply.
        fused_experts_results: FusedExpertsResult = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            pertoken_scale=pertoken_scale,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self.local_expert_map,
            top_k=self.top_k,
            router_logits=router_logits,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            routed_scaling_factor=self.routed_scaling_factor,
            e_score_correction_bias=self.e_score_correction_bias,
            activation=self.activation,
            global_num_experts=self.global_num_experts,
            expert_map=self.local_expert_map,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
        )
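The `isinstance(hidden_states, tuple)` branch above relies on a simple contract: when `prepare()` quantizes the activations, it returns `(int8_hidden_states, per_token_scale)`; otherwise it returns a plain tensor. A pure-PyTorch stand-in for that contract (`fake_prepare` is illustrative, not the real comm-method API):

```python
import torch

def fake_prepare(x: torch.Tensor, quantize: bool):
    if not quantize:
        return x
    scale = x.abs().amax(dim=-1, keepdim=True) / 127.0  # per-token scale
    x_int8 = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_int8, scale

out = fake_prepare(torch.randn(4, 8), quantize=True)
if isinstance(out, tuple):              # mirrors the branch in forward_impl
    hidden_states, pertoken_scale = out
else:
    hidden_states, pertoken_scale = out, None
```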
@@ -1,39 +1,90 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from __future__ import annotations

import torch
from vllm.forward_context import get_forward_context

from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult

from .moe_mlp import unified_apply_mlp
from .token_dispatcher import TokenDispatcherWithAllGather310


class AllGatherCommImpl310(AllGatherCommImpl):
    """This implementation is the same as NativeAllGatherCommImpl,
    but uses NPU-specific ops for better performance.

    This implementation should be compatible with all scenarios, and
    thus it is the default implementation for MoE communication methods.
    It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
    and `torch_npu.npu_moe_token_unpermute` for post-processing
    to handle the token-to-expert mapping and communication efficiently.
    """

    def fused_experts(  # type: ignore[override]
        self,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_map: torch.Tensor | None = None,
        use_int8_w8a8: bool = False,
        w1_scale: torch.Tensor | None = None,
        w2_scale: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
    ) -> FusedExpertsResult:
        # This method is overridden to use the 310p-specific unified_apply_mlp
        # which provides optimized MLP computation for the 310p platform
        moe_comm_method = get_forward_context().moe_comm_method
        assert moe_comm_method is not None, "Missing communication context"

        dispatch_results = self.token_dispatcher.token_dispatch(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

        mlp_output = unified_apply_mlp(
            hidden_states=dispatch_results.hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            group_list=dispatch_results.group_list,
            group_list_type=dispatch_results.group_list_type,
            with_quant=use_int8_w8a8,
        )

        combine_results = self.token_dispatcher.token_combine(
            hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
        )

        return FusedExpertsResult(
            routed_out=combine_results.routed_out,
            group_list_type=dispatch_results.group_list_type,
            expert_tokens=dispatch_results.group_list,
        )

    def _get_token_dispatcher(self):
        return TokenDispatcherWithAllGather310(
            top_k=self.moe_config.experts_per_token,
            num_experts=self.moe_config.num_experts,
            num_local_experts=self.moe_config.num_local_experts,
        )
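The docstring's pre/post-processing pair can be pictured with a pure-PyTorch stand-in: sort token copies into expert-major order before the grouped MLPs, then undo the sort and combine with the router weights afterwards (the real ops are fused `torch_npu` kernels with different signatures):

```python
import torch

hidden = torch.randn(3, 8)                         # (num_tokens, hidden_size)
topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])  # (num_tokens, top_k)
topk_weights = torch.full((3, 2), 0.5)

# "init routing": replicate each token top_k times, sort copies by expert id.
flat_ids = topk_ids.flatten()
sort_idx = torch.argsort(flat_ids, stable=True)
expanded = hidden.repeat_interleave(2, dim=0)[sort_idx]  # expert-major order

# ... grouped expert MLPs would run on `expanded` here ...

# "token unpermute": undo the sort, weight each copy, sum per token.
unsorted = torch.empty_like(expanded)
unsorted[sort_idx] = expanded
out = (unsorted.view(3, 2, 8) * topk_weights.unsqueeze(-1)).sum(dim=1)
```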
vllm_ascend/_310p/fused_moe/moe_mlp.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.


import torch
import torch_npu


def quant_apply_mlp(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    group_list: torch.Tensor,
    group_list_type: int = 1,
) -> torch.Tensor:
    if group_list_type == 1:
        # Convert group_list to cumulative sum format if group_list is count format
        group_list = torch.cumsum(group_list, dim=0)

    hidden_states = torch_npu.npu_quant_grouped_matmul_dequant(
        x=hidden_states, quantized_weight=w1, weight_scale=w1_scale, group_list=group_list, quant_mode="pertoken"
    )
    hidden_states = torch_npu.npu_swiglu(hidden_states)
    hidden_states = torch_npu.npu_quant_grouped_matmul_dequant(
        x=hidden_states, quantized_weight=w2, weight_scale=w2_scale, group_list=group_list, quant_mode="pertoken"
    )
    return hidden_states


def unquant_apply_mlp(
    hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, group_list: torch.Tensor, group_list_type: int = 1
) -> torch.Tensor:
    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]
    act_out = torch_npu.npu_swiglu(gate_up_out)

    hidden_states = torch_npu.npu_grouped_matmul(
        x=[act_out],
        weight=[w2],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]
    return hidden_states


def unified_apply_mlp(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    group_list: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    group_list_type: int = 1,
    with_quant: bool = False,
) -> torch.Tensor:
    if with_quant:
        assert w1_scale is not None and w2_scale is not None
        return quant_apply_mlp(
            hidden_states=hidden_states,
            w1=w1,
            w1_scale=w1_scale,
            w2=w2,
            w2_scale=w2_scale,
            group_list=group_list,
            group_list_type=group_list_type,
        )
    else:
        return unquant_apply_mlp(
            hidden_states=hidden_states, w1=w1, w2=w2, group_list=group_list, group_list_type=group_list_type
        )
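The `group_list_type` handling above distinguishes two encodings of the same routing information: count format (tokens per expert) and cumulative-sum format (running offsets). A small illustration of the conversion `quant_apply_mlp` performs:

```python
import torch

topk_ids = torch.tensor([[0, 2], [1, 2], [0, 1]])         # (num_tokens, top_k)
counts = torch.bincount(topk_ids.flatten(), minlength=4)  # count format
print(counts)                        # tensor([2, 2, 2, 0])
print(torch.cumsum(counts, dim=0))   # tensor([2, 4, 6, 6]) -- cumsum format
```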
@@ -32,21 +32,14 @@ class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def token_dispatch(
    def token_dispatch(  # type: ignore[override]
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        expert_map: torch.Tensor | None = None,
        global_redundant_expert_num: int = 0,
        mc2_mask: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        with_quant: bool = False,
        dynamic_eplb: bool = False,
        pertoken_scale: torch.Tensor | None = None,
    ):
        if with_quant:
            raise RuntimeError("Quant is not supported for 310P currently.")
        self.original_shape = hidden_states.shape

        num_tokens = hidden_states.shape[:-1].numel()
@@ -77,7 +70,6 @@ class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):

        return TokenDispatchResult(
            hidden_states=sorted_hidden_states,
            dynamic_scale=None,
            group_list=expert_tokens,
            group_list_type=group_list_type,
            context_metadata=context_metadata,