[Refactor] Adjustments to moe_comm_method selection process (#3001)

### What this PR does / why we need it?
Fix issues mentioned in
https://github.com/vllm-project/vllm-ascend/pull/2791 and some minor
refactoring.
1. Use Enum instead of string.
2. Avoid setting a new property to forward_context in
AscendFusedMoE.forward().
3. Enabling TokenDispatcherWithMoge.
4. Remove redundant code.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager


- vLLM version: v0.10.2
- vLLM main:
9607d5eb44

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-22 19:12:58 +08:00
committed by GitHub
parent bb1f0d5a62
commit 37a0715eda
14 changed files with 170 additions and 351 deletions

View File

@@ -13,14 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any, Dict, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
@@ -28,13 +31,31 @@ from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2)
TokenDispatcherWithMC2,
TokenDispatcherWithMoge)
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
def get_moe_comm_method(
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
return _MoECommMethods.get(moe_comm_type)
def setup_moe_comm_method(moe_config):
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
_MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
moe_config)
class MoECommMethod(ABC):
"""Base class for MoE communication methods."""
def __init__(self, moe_config: FusedMoEConfig):
self.model_type = get_current_vllm_config(
).model_config.hf_config.model_type
self.moe_config = moe_config
self.mc2_mask = None
@@ -113,8 +134,8 @@ class MoECommMethod(ABC):
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8)
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"]
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales")
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
w1=w1,
@@ -126,6 +147,7 @@ class MoECommMethod(ABC):
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=topk_scales,
with_quant=use_int8_w8a8
or use_int4_w4a8,
fusion=use_int8_w8a8,
@@ -170,94 +192,21 @@ class AllGatherCommImpl(MoECommMethod):
"""
def _get_token_dispatcher(self):
return TokenDispatcherWithAllGather(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
if self.model_type == "PanguProMoE":
return TokenDispatcherWithMoge(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
else:
return TokenDispatcherWithAllGather(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
def _get_fused_moe_prepare_finalize(self):
return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
class NativeAllGatherCommImpl(AllGatherCommImpl):
"""This implementation should be compatible with all scenarios.
Note that this implementation purely consists of native PyTorch ops
and does not use any NPU-specific ops. So the performance may not be optimal.
But it is a good fallback for scenarios where NPU-specific ops are not available.
"""
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
apply_a8_quantization: bool,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
num_tokens = hidden_states.shape[0]
# Generate token indices and flatten
token_indices = torch.arange(num_tokens,
device=hidden_states.device,
dtype=torch.int64)
token_indices = (token_indices.unsqueeze(1).expand(
-1, self.moe_config.experts_per_token).reshape(-1))
# Flatten token-to-expert mappings and map to local experts
weights_flat = topk_weights.view(-1)
experts_flat = topk_ids.view(-1)
local_experts_flat = (expert_map[experts_flat]
if expert_map is not None else experts_flat)
# Filter valid token-expert pairs
mask = local_experts_flat != -1
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
# So we need to filter out invalid tokens by zeroing their weights.
# This is a workaround and should be removed after the issue is fixed
filtered_weights = torch.where(mask, weights_flat,
torch.zeros_like(weights_flat)).to(
topk_weights.dtype)
filtered_experts = torch.where(
mask,
local_experts_flat,
torch.full_like(local_experts_flat, num_experts),
).to(topk_ids.dtype)
# Sort by local expert IDs
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
self.sorted_token_indices = token_indices[sort_indices]
self.sorted_weights = filtered_weights[sort_indices]
# Compute token counts with minlength of num_experts
# This is equivalent to but faster than:
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
token_counts = torch.zeros(num_experts + 1,
device=hidden_states.device,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
expert_tokens = token_counts[:num_experts]
# Rearrange hidden_states
permuted_hidden_states = hidden_states[self.sorted_token_indices]
group_list_type = 1 # `count` mode
return permuted_hidden_states, expert_tokens, None, group_list_type
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
final_hidden_states = torch.zeros_like(hidden_states)
final_hidden_states.index_add_(0, self.sorted_token_indices,
mlp_output)
hidden_states[:] = final_hidden_states
class MC2CommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True`.

View File

@@ -21,6 +21,7 @@ import torch_npu
from torch.nn.functional import pad
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import dispose_tensor, is_310p
@@ -76,7 +77,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias1, bias2 = None, None
_output_dtype = w2_scale.dtype
is_mc2 = get_forward_context().moe_comm_method_name == "mc2commimpl"
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and is_mc2:
if w1_scale.dtype != torch.float32:
w1_scale = w1_scale.to(torch.float32)

View File

@@ -377,14 +377,13 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
# mypy: disable-error-code="override"
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
class TokenDispatcherWithMoge(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.local_ep = 1
self.local_num_experts = self.num_experts // self.local_ep
self.local_num_group = self.top_k // self.local_ep
self.local_num_experts = self.num_experts // self.ep_size
self.local_num_group = self.top_k // self.ep_size
self.bsz = None
def token_dispatch(self,
@@ -401,17 +400,6 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
self.bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -445,7 +433,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
unsorted_hidden_states = hidden_states.index_select(
0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
self.bsz, self.top_k // self.local_ep, -1).sum(1)
self.bsz, self.top_k // self.ep_size, -1).sum(1)
return final_hidden_states