[Refactor] Adjustments to moe_comm_method selection process (#3001)

### What this PR does / why we need it?
Fix issues mentioned in
https://github.com/vllm-project/vllm-ascend/pull/2791 and some minor
refactoring.
1. Use Enum instead of string.
2. Avoid setting a new property to forward_context in
AscendFusedMoE.forward().
3. Enabling TokenDispatcherWithMoge.
4. Remove redundant code.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager


- vLLM version: v0.10.2
- vLLM main:
9607d5eb44

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-22 19:12:58 +08:00
committed by GitHub
parent bb1f0d5a62
commit 37a0715eda
14 changed files with 170 additions and 351 deletions

View File

@@ -94,7 +94,8 @@ from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
scatter_mm_placeholders)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.ascend_forward_context import (MoECommType,
set_ascend_forward_context)
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
@@ -1860,7 +1861,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
def _select_moe_comm_method(self, num_tokens: int,
with_prefill: bool) -> str:
with_prefill: bool) -> MoECommType:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
are designed for expert parallelism.
2. If expert parallel is enabled, we need to consider the soc version and the
@@ -1881,36 +1882,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
ValueError: If the soc version is unsupported.
Returns:
str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
MoECommType: The selected MoE communication method.
"""
soc_version = get_ascend_soc_version()
quant_type = getattr(self.vllm_config.model_config.hf_config,
'moe_quantize', None)
model_type = self.vllm_config.model_config.hf_config.model_type
if not self.parallel_config.enable_expert_parallel:
moe_comm_method = "allgather"
moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendSocVersion.A2}:
if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16:
moe_comm_method = "mc2"
if (num_tokens <= self.mc2_tokens_capacity
and self.parallel_config.world_size_across_dp >= 16):
moe_comm_type = MoECommType.MC2
else:
# Currently, w4a8_dynamic does not support allgatherep
if quant_type == "w4a8_dynamic":
moe_comm_method = "alltoall"
moe_comm_type = MoECommType.ALLTOALL
else:
moe_comm_method = "allgather"
moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendSocVersion.A3}:
moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
moe_comm_type = (MoECommType.MC2
if num_tokens <= self.mc2_tokens_capacity else
MoECommType.ALLTOALL)
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
if moe_comm_method == "allgather" and with_prefill:
moe_comm_method = "naivemulticast"
if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
moe_comm_type = MoECommType.NAIVE_MULTICAST
# PanguProMoE only supports allgather
if model_type == "PanguProMoE":
moe_comm_type = MoECommType.ALLGATHER
if is_global_first_rank():
logger.debug(f"num_tokens: {num_tokens}, "
f"moe_comm_method: {moe_comm_method}")
return moe_comm_method
f"moe_comm_type: {moe_comm_type}")
return moe_comm_type
@torch.inference_mode()
def execute_model(
@@ -1942,8 +1951,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process()
moe_comm_method = self._select_moe_comm_method(num_input_tokens,
self.with_prefill)
moe_comm_type = self._select_moe_comm_method(num_input_tokens,
self.with_prefill)
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
scheduler_output.total_num_scheduled_tokens
@@ -1962,7 +1971,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=self.with_prefill,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
moe_comm_type=moe_comm_type,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
num_actual_tokens=scheduler_output.
@@ -2351,8 +2360,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
moe_comm_method = self._select_moe_comm_method(num_tokens,
with_prefill)
moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.seperate_routine(). This means that we are using
@@ -2472,7 +2480,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with_prefill=with_prefill,
in_profile_run=self.in_profile_run,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
moe_comm_type=moe_comm_type,
num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,