[1/N][Feat] Support MoE models with ACL Graph and refactor MoE communication logic (#2125)

### What this PR does / why we need it?
This PR refactors the MoE (Mixture of Experts) communication logic by
introducing a strategy pattern. It defines an abstract base class,
`MoECommMethod`, which encapsulates different communication strategies
for MoE layers. By decoupling the MoE implementation from any single
communication method, this change makes it simpler to add, replace, or
optimize communication strategies in the future.

Plan / Roadmap

1. Introduce `MoECommMethod`, implement `AllGatherImpl`, and adapt ACL
Graph handling to cover all scenarios (this PR).
2. Implement `MC2CommImpl` and `AllToAllCommImpl` to optimize
performance in specific scenarios.
3. Enable W8A8 / Int8 models to use `unified_fused_experts`.

Other notes

* Data-parallel (DP) communication currently does not work with vLLM's
dispatch/combine mechanisms; an alternative approach is required to
resolve this incompatibility.

- vLLM version: v0.10.0
- vLLM main:
f7ad6a1eb3

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-08-12 21:10:20 +08:00
committed by GitHub
parent 1a70564e7c
commit 992271b027
7 changed files with 764 additions and 26 deletions

View File

@@ -26,7 +26,7 @@ import types
import weakref
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
import numpy as np
import numpy.typing as npt
@@ -43,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.forward_context import DPMetadata, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -79,6 +79,9 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
DummyCommImpl,
MoECommMethod)
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -335,7 +338,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.use_aclgraph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
and not self.model_config.enforce_eager and
not ascend_config.torchair_graph_config.enabled)
self.aclgraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
@@ -375,6 +379,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
self.reserved_mc2_mask = torch.zeros(
512,
dtype=torch.bool,
device=self.device,
)
self.moe_comm_method = AllGatherCommImpl
def check_batch_sizes_consistency(self) -> None:
if not dist.is_initialized():
return
@@ -1003,6 +1015,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
mm_embeds.append(mm_embeds_item)
return mm_embeds
def get_dp_padding(self,
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
"""This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`.
Please note that vLLM may refactor or modify this function over time,
at present, we are using the version introduced in PR #18935.
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
# For DP: Don't pad when setting enforce_eager.
# This lets us set enforce_eager on the prefiller in a P/D setup and
# still use ACL graphs (enabled by this padding) on the decoder.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# Early exit.
return 0, None
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
dp_size,
device="cpu",
dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def _process_reqs(
self,
scheduler_output: "SchedulerOutput",
@@ -1025,6 +1063,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Eager mode.
num_input_tokens = total_num_scheduled_tokens
# Padding for DP
num_pad, num_tokens_across_dp_native = self.get_dp_padding(
num_input_tokens)
num_input_tokens += num_pad
modified_batch = self.attn_metadata_builder.reorder_batch(
self.input_batch, scheduler_output)
if modified_batch:
@@ -1250,13 +1293,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items()
})
moe_comm_method = self.moe_comm_method
# NOTE: Currently this padding logic is really messy,
# MC2 may not be available in eager mode
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
if self.use_aclgraph:
num_tokens_across_dp = num_tokens_across_dp_native
else:
num_input_tokens = padded_num_tokens_across_dp
# Run forward pass
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=padded_num_tokens_across_dp,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method(self.device, self.dtype,
self.model_config.hf_config),
num_actual_tokens=total_num_scheduled_tokens):
with ProfileExecuteDuration().capture_async("forward"):
self.maybe_setup_kv_connector(scheduler_output)
@@ -1865,6 +1921,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
skip_attn: bool = True,
with_prefill: bool = False,
is_torchair_compile: bool = False,
moe_comm_method: Type[MoECommMethod] = DummyCommImpl,
) -> torch.Tensor:
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
@@ -1932,6 +1989,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
in_profile_run=self.in_profile_run,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method(
self.device, self.dtype, self.model_config.hf_config),
num_actual_tokens=0,
):
hidden_states = self._generate_dummy_run_hidden_states(
@@ -2328,13 +2388,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
with graph_capture(device=self.device):
skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
self._dummy_run(
num_tokens,
skip_attn=skip_attn,
moe_comm_method=self.moe_comm_method,
)
self._dummy_run(
num_tokens,
skip_attn=skip_attn,
moe_comm_method=self.moe_comm_method,
)
def capture_model(self) -> None:
start_time = time.perf_counter()