[Refactor] Adjustments to moe_comm_method selection process (#3001)
### What this PR does / why we need it?
Fix the issues raised in
https://github.com/vllm-project/vllm-ascend/pull/2791 and apply some minor
refactoring:
1. Use an Enum (MoECommType) instead of a raw string (a minimal sketch follows this list).
2. Avoid setting a new property on forward_context in
AscendFusedMoE.forward().
3. Enable TokenDispatcherWithMoge.
4. Remove redundant code.
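
For reference, a minimal sketch of what the enum in point 1 could look like. The member names match those used in the diff below; the concrete string values and comments here are assumptions, and the real definition lives in vllm_ascend/ascend_forward_context.py:

```python
from enum import Enum


class MoECommType(Enum):
    """Illustrative sketch of the MoE communication backends (values assumed)."""
    ALLGATHER = "allgather"             # default; also used when expert parallel is off
    MC2 = "mc2"                         # used when the batch fits the MC2 token capacity
    ALLTOALL = "alltoall"               # used for larger batches under expert parallelism
    NAIVE_MULTICAST = "naivemulticast"  # replaces all-gather during prefill
```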
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Tested with Qwen3-30B-A3B / Qwen3-30B-A3B-W8A8 / DeepSeek-V3-W4A8-Pruing / deepseek-mtp / pangu-pro-moe-pruing under the following combinations (a launch sketch follows this list):
1. Expert parallel (EP) enabled and disabled
2. ACL graph and eager mode
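
For context, the two toggles above map to standard vLLM engine arguments. A minimal offline sketch, assuming the `vllm.LLM` entry point; the model name, parallel size, and exact argument combination are illustrative and not the actual test harness:

```python
from vllm import LLM

# Expert parallel on, eager mode (graph capture disabled).
llm_ep_eager = LLM(model="Qwen/Qwen3-30B-A3B",
                   tensor_parallel_size=2,
                   enable_expert_parallel=True,
                   enforce_eager=True)

# Expert parallel off, graph mode (graph capture left enabled).
llm_graph = LLM(model="Qwen/Qwen3-30B-A3B",
                tensor_parallel_size=2,
                enable_expert_parallel=False,
                enforce_eager=False)
```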
- vLLM version: v0.10.2
- vLLM main: 9607d5eb44
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
@@ -94,7 +94,8 @@ from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
                                   scatter_mm_placeholders)
 
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ascend_forward_context import set_ascend_forward_context
+from vllm_ascend.ascend_forward_context import (MoECommType,
+                                                 set_ascend_forward_context)
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
@@ -1860,7 +1861,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         )
 
     def _select_moe_comm_method(self, num_tokens: int,
-                                with_prefill: bool) -> str:
+                                with_prefill: bool) -> MoECommType:
         """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
         are designed for expert parallelism.
         2. If expert parallel is enabled, we need to consider the soc version and the
@@ -1881,36 +1882,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             ValueError: If the soc version is unsupported.
 
         Returns:
-            str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
+            MoECommType: The selected MoE communication method.
         """
         soc_version = get_ascend_soc_version()
         quant_type = getattr(self.vllm_config.model_config.hf_config,
                              'moe_quantize', None)
+        model_type = self.vllm_config.model_config.hf_config.model_type
 
         if not self.parallel_config.enable_expert_parallel:
-            moe_comm_method = "allgather"
+            moe_comm_type = MoECommType.ALLGATHER
         elif soc_version in {AscendSocVersion.A2}:
-            if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16:
-                moe_comm_method = "mc2"
+            if (num_tokens <= self.mc2_tokens_capacity
+                    and self.parallel_config.world_size_across_dp >= 16):
+                moe_comm_type = MoECommType.MC2
             else:
                 # Currently, w4a8_dynamic does not support allgatherep
                 if quant_type == "w4a8_dynamic":
-                    moe_comm_method = "alltoall"
+                    moe_comm_type = MoECommType.ALLTOALL
                 else:
-                    moe_comm_method = "allgather"
+                    moe_comm_type = MoECommType.ALLGATHER
 
         elif soc_version in {AscendSocVersion.A3}:
-            moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
+            moe_comm_type = (MoECommType.MC2
+                             if num_tokens <= self.mc2_tokens_capacity else
+                             MoECommType.ALLTOALL)
         else:
             raise ValueError(f"Unsupported soc_version: {soc_version}")
 
-        if moe_comm_method == "allgather" and with_prefill:
-            moe_comm_method = "naivemulticast"
+        if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
+            moe_comm_type = MoECommType.NAIVE_MULTICAST
+
+        # PanguProMoE only supports allgather
+        if model_type == "PanguProMoE":
+            moe_comm_type = MoECommType.ALLGATHER
+
         if is_global_first_rank():
             logger.debug(f"num_tokens: {num_tokens}, "
-                         f"moe_comm_method: {moe_comm_method}")
-
-        return moe_comm_method
+                         f"moe_comm_type: {moe_comm_type}")
+        return moe_comm_type
 
     @torch.inference_mode()
     def execute_model(
@@ -1942,8 +1951,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if self.dynamic_eplb:
             self.eplb_updator.take_update_info_from_eplb_process()
 
-        moe_comm_method = self._select_moe_comm_method(num_input_tokens,
-                                                       self.with_prefill)
+        moe_comm_type = self._select_moe_comm_method(num_input_tokens,
+                                                     self.with_prefill)
 
         uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
             scheduler_output.total_num_scheduled_tokens
@@ -1962,7 +1971,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 num_tokens_across_dp=num_tokens_across_dp,
                 with_prefill=self.with_prefill,
                 reserved_mc2_mask=self.reserved_mc2_mask,
-                moe_comm_method=moe_comm_method,
+                moe_comm_type=moe_comm_type,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 num_actual_tokens=scheduler_output.
@@ -2351,8 +2360,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         (num_tokens, num_tokens_across_dp, with_prefill,
          _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
 
-        moe_comm_method = self._select_moe_comm_method(num_tokens,
-                                                       with_prefill)
+        moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
 
         # If cudagraph_mode.decode_mode() == FULL and
         # cudagraph_mode.seperate_routine(). This means that we are using
@@ -2472,7 +2480,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 with_prefill=with_prefill,
                 in_profile_run=self.in_profile_run,
                 reserved_mc2_mask=self.reserved_mc2_mask,
-                moe_comm_method=moe_comm_method,
+                moe_comm_type=moe_comm_type,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
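
To summarize the new control flow outside the diff view, here is a condensed, self-contained restatement of the selection logic above. The standalone signature, parameter names, SoC strings, and enum values are illustrative assumptions; the real implementation is `NPUModelRunner._select_moe_comm_method`:

```python
from enum import Enum
from typing import Optional


class MoECommType(Enum):
    ALLGATHER = "allgather"
    MC2 = "mc2"
    ALLTOALL = "alltoall"
    NAIVE_MULTICAST = "naivemulticast"


def select_moe_comm_type(enable_expert_parallel: bool, soc_version: str,
                         num_tokens: int, mc2_tokens_capacity: int,
                         world_size_across_dp: int, quant_type: Optional[str],
                         with_prefill: bool, model_type: str) -> MoECommType:
    if not enable_expert_parallel:
        # MC2 and all-to-all are designed for expert parallelism.
        moe_comm_type = MoECommType.ALLGATHER
    elif soc_version == "A2":
        if num_tokens <= mc2_tokens_capacity and world_size_across_dp >= 16:
            moe_comm_type = MoECommType.MC2
        elif quant_type == "w4a8_dynamic":
            # w4a8_dynamic does not support all-gather under expert parallelism.
            moe_comm_type = MoECommType.ALLTOALL
        else:
            moe_comm_type = MoECommType.ALLGATHER
    elif soc_version == "A3":
        moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
                         else MoECommType.ALLTOALL)
    else:
        raise ValueError(f"Unsupported soc_version: {soc_version}")

    # All-gather is swapped for naive multicast when prefill is involved.
    if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
        moe_comm_type = MoECommType.NAIVE_MULTICAST

    # PanguProMoE only supports all-gather.
    if model_type == "PanguProMoE":
        moe_comm_type = MoECommType.ALLGATHER
    return moe_comm_type
```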