[1/N][Feat] Support MoE models with ACL Graph and refactor MoE communication logic (#2125)

### What this PR does / why we need it?
This PR refactors the MoE (Mixture of Experts) communication logic by
introducing a strategy pattern. It defines an abstract base class,
`MoECommMethod`, which encapsulates different communication strategies
for MoE layers. By decoupling the MoE implementation from any single
communication method, this change makes it simpler to add, replace, or
optimize communication strategies in the future.

Plan / Roadmap

1. Introduce `MoECommMethod`, implement `AllGatherImpl`, and adapt ACL
Graph handling to cover all scenarios (this PR).
2. Implement `MC2CommImpl` and `AllToAllCommImpl` to optimize
performance in specific scenarios.
3. Enable W8A8 / Int8 models to use `unified_fused_experts`.

Other notes

* Data-parallel (DP) communication currently does not work with vLLM's
dispatch/combine mechanisms; an alternative approach is required to
resolve this incompatibility.

- vLLM version: v0.10.0
- vLLM main:
f7ad6a1eb3

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-08-12 21:10:20 +08:00
committed by GitHub
parent 1a70564e7c
commit 992271b027
7 changed files with 764 additions and 26 deletions

View File

@@ -5,11 +5,12 @@ from typing import Any, Optional
import torch
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context, set_forward_context
import vllm_ascend.envs as envs
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.distributed.moe_comm_method import MoECommMethod
class FusedMoEState(Enum):
@@ -54,6 +55,8 @@ def set_ascend_forward_context(
num_tokens_across_dp: Optional[torch.Tensor] = None,
with_prefill: bool = True,
in_profile_run: bool = False,
reserved_mc2_mask: Optional[torch.Tensor] = None,
moe_comm_method: Optional[MoECommMethod] = None,
num_actual_tokens: Optional[int] = None,
):
"""A context manager that stores the current forward context,
@@ -66,6 +69,7 @@ def set_ascend_forward_context(
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
forward_context = get_forward_context()
forward_context.moe_comm_method = moe_comm_method
forward_context.with_prefill = with_prefill
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
@@ -97,16 +101,17 @@ def set_ascend_forward_context(
if num_tokens is not None:
if num_actual_tokens is None:
num_actual_tokens = num_tokens
tp_world_size = get_tp_group().world_size
tp_world_size = get_tensor_model_parallel_world_size()
# NOTE: token num which need to pad to when mc2
forward_context.padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size
mc2_mask = torch.zeros(forward_context.padded_num_tokens,
dtype=torch.bool,
device=NPUPlatform.device_type)
mc2_mask[:num_actual_tokens] = True
forward_context.mc2_mask = mc2_mask
if reserved_mc2_mask is not None:
mc2_mask = reserved_mc2_mask[:forward_context.
padded_num_tokens]
mc2_mask[:num_actual_tokens] = True
mc2_mask[num_actual_tokens:] = False
forward_context.mc2_mask = mc2_mask
try:
yield