[1/N][Feat] Support MoE models with ACL Graph and refactor MoE communication logic (#2125)
### What this PR does / why we need it?
This PR refactors the MoE (Mixture of Experts) communication logic by
introducing a strategy pattern. It defines an abstract base class,
`MoECommMethod`, which encapsulates different communication strategies
for MoE layers. By decoupling the MoE implementation from any single
communication method, this change makes it simpler to add, replace, or
optimize communication strategies in the future.
Plan / Roadmap
1. Introduce `MoECommMethod`, implement `AllGatherImpl`, and adapt ACL
Graph handling to cover all scenarios (this PR).
2. Implement `MC2CommImpl` and `AllToAllCommImpl` to optimize
performance in specific scenarios.
3. Enable W8A8 / Int8 models to use `unified_fused_experts`.
Other notes
* Data-parallel (DP) communication currently does not work with vLLM's
dispatch/combine mechanisms; an alternative approach is required to
resolve this incompatibility.
- vLLM version: v0.10.0
- vLLM main:
f7ad6a1eb3
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -26,7 +26,7 @@ import types
|
||||
import weakref
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@@ -43,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.forward_context import DPMetadata, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
@@ -79,6 +79,9 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
||||
DummyCommImpl,
|
||||
MoECommMethod)
|
||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||
@@ -335,7 +338,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
self.use_aclgraph = (self.vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not self.model_config.enforce_eager)
|
||||
and not self.model_config.enforce_eager and
|
||||
not ascend_config.torchair_graph_config.enabled)
|
||||
self.aclgraph_batch_sizes = list(
|
||||
reversed(
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||
@@ -375,6 +379,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
|
||||
self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
|
||||
|
||||
self.reserved_mc2_mask = torch.zeros(
|
||||
512,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.moe_comm_method = AllGatherCommImpl
|
||||
|
||||
def check_batch_sizes_consistency(self) -> None:
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
@@ -1003,6 +1015,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
return mm_embeds
|
||||
|
||||
def get_dp_padding(self,
|
||||
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
|
||||
"""This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`.
|
||||
Please note that vLLM may refactor or modify this function over time,
|
||||
at present, we are using the version introduced in PR #18935.
|
||||
"""
|
||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# For DP: Don't pad when setting enforce_eager.
|
||||
# This lets us set enforce_eager on the prefiller in a P/D setup and
|
||||
# still use ACL graphs (enabled by this padding) on the decoder.
|
||||
|
||||
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
|
||||
# Early exit.
|
||||
return 0, None
|
||||
|
||||
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
|
||||
num_tokens, dp_size, dp_rank)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
|
||||
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
|
||||
dp_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
||||
|
||||
def _process_reqs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@@ -1025,6 +1063,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Eager mode.
|
||||
num_input_tokens = total_num_scheduled_tokens
|
||||
|
||||
# Padding for DP
|
||||
num_pad, num_tokens_across_dp_native = self.get_dp_padding(
|
||||
num_input_tokens)
|
||||
num_input_tokens += num_pad
|
||||
|
||||
modified_batch = self.attn_metadata_builder.reorder_batch(
|
||||
self.input_batch, scheduler_output)
|
||||
if modified_batch:
|
||||
@@ -1250,13 +1293,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
|
||||
moe_comm_method = self.moe_comm_method
|
||||
|
||||
# NOTE: Currently this padding logic is really messy,
|
||||
# MC2 may not be available in eager mode
|
||||
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
|
||||
if self.use_aclgraph:
|
||||
num_tokens_across_dp = num_tokens_across_dp_native
|
||||
else:
|
||||
num_input_tokens = padded_num_tokens_across_dp
|
||||
|
||||
# Run forward pass
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=padded_num_tokens_across_dp,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=with_prefill,
|
||||
reserved_mc2_mask=self.reserved_mc2_mask,
|
||||
moe_comm_method=moe_comm_method(self.device, self.dtype,
|
||||
self.model_config.hf_config),
|
||||
num_actual_tokens=total_num_scheduled_tokens):
|
||||
with ProfileExecuteDuration().capture_async("forward"):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
@@ -1865,6 +1921,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
skip_attn: bool = True,
|
||||
with_prefill: bool = False,
|
||||
is_torchair_compile: bool = False,
|
||||
moe_comm_method: Type[MoECommMethod] = DummyCommImpl,
|
||||
) -> torch.Tensor:
|
||||
# Padding for DP
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
@@ -1932,6 +1989,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=with_prefill,
|
||||
in_profile_run=self.in_profile_run,
|
||||
reserved_mc2_mask=self.reserved_mc2_mask,
|
||||
moe_comm_method=moe_comm_method(
|
||||
self.device, self.dtype, self.model_config.hf_config),
|
||||
num_actual_tokens=0,
|
||||
):
|
||||
hidden_states = self._generate_dummy_run_hidden_states(
|
||||
@@ -2328,13 +2388,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Trigger ACL graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
|
||||
with graph_capture(device=self.device):
|
||||
skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
|
||||
for num_tokens in reversed(self.aclgraph_batch_sizes):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens)
|
||||
self._dummy_run(num_tokens)
|
||||
self._dummy_run(
|
||||
num_tokens,
|
||||
skip_attn=skip_attn,
|
||||
moe_comm_method=self.moe_comm_method,
|
||||
)
|
||||
self._dummy_run(
|
||||
num_tokens,
|
||||
skip_attn=skip_attn,
|
||||
moe_comm_method=self.moe_comm_method,
|
||||
)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
Reference in New Issue
Block a user