[2/N][Feat] Add MC2 communication method for MoE layers (#2469)

### What this PR does / why we need it?
This PR adds the MC2 communication method for MoE layers, replacing the
previous all-gather approach when the number of input tokens is small.

The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2
method and the existing all-gather method based on the number of input
tokens (a minimal sketch of this selection rule follows this list).
- Sharding the MoE communication mask across tensor-parallel ranks.
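
As a rough illustration of that selection rule, here is a minimal, hypothetical sketch. Only the `"mc2"`/`"allgather"` method names and the `num_input_tokens <= mc2_tokens_capacity` threshold mirror the diff below; the standalone helper function and the per-rank constant are assumptions made for this example, not the PR's actual code.

```python
# Hypothetical sketch of the MC2 / all-gather selection rule, NOT the PR's
# actual implementation. Only the "mc2"/"allgather" names and the idea of a
# token-count threshold come from the diff; the helper and constant below are
# illustrative assumptions.

MC2_TOKENS_PER_RANK = 512  # assumed per-rank budget; the PR scales it by TP size


def select_moe_comm_method(num_input_tokens: int, tp_size: int) -> str:
    """Pick "mc2" for small batches and fall back to "allgather" otherwise."""
    mc2_tokens_capacity = MC2_TOKENS_PER_RANK * tp_size
    return "mc2" if num_input_tokens <= mc2_tokens_capacity else "allgather"


if __name__ == "__main__":
    # With TP=4 the capacity is 2048 tokens: a 1536-token batch takes the MC2
    # path, while an 8192-token batch falls back to all-gather.
    print(select_moe_comm_method(1536, 4))  # -> mc2
    print(select_moe_comm_method(8192, 4))  # -> allgather
```

Note that the runner now passes the chosen method around as a plain string instead of instantiating a `MoECommMethod` class, as the diff below shows.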

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
The existing test case was fixed accordingly.


- vLLM version: v0.10.1.1
- vLLM main: b00e69f8ca

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Author: yiz-liu
Date: 2025-08-26 19:05:23 +08:00
Committed by: GitHub
Parent: 5d8ec28009
Commit: a6bb502e70
11 changed files with 506 additions and 410 deletions


@@ -24,7 +24,7 @@ import os
import time
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
import numpy as np
import numpy.typing as npt
@@ -85,9 +85,6 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
-from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                      DummyCommImpl,
-                                                      MoECommMethod)
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -368,13 +365,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
+self.mc2_tokens_capacity = 512 * self.parallel_config.tensor_parallel_size
self.reserved_mc2_mask = torch.zeros(
-    512,
+    self.mc2_tokens_capacity,
dtype=torch.bool,
device=self.device,
)
-self.moe_comm_method = AllGatherCommImpl
+self.moe_comm_method = "mc2"
+self.fallback_moe_comm_method = "allgather"
+self.dummy_moe_comm_method = "dummy"
def _use_aclgraph(self) -> bool:
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
@@ -1622,6 +1622,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors) = (self._prepare_inputs(
scheduler_output, intermediate_tensors))
+moe_comm_method = (self.moe_comm_method
+                   if num_input_tokens <= self.mc2_tokens_capacity else
+                   self.fallback_moe_comm_method)
# Run forward pass
with ProfileExecuteDuration().capture_async("forward"):
with set_ascend_forward_context(
@@ -1631,8 +1635,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=self.with_prefill,
reserved_mc2_mask=self.reserved_mc2_mask,
-moe_comm_method=self.moe_comm_method(
-    self.device, self.dtype, self.model_config.hf_config),
+moe_comm_method=moe_comm_method,
num_actual_tokens=scheduler_output.
total_num_scheduled_tokens):
self.maybe_setup_kv_connector(scheduler_output)
@@ -1938,7 +1941,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens: int,
with_prefill: bool = False,
is_torchair_compile: bool = False,
-moe_comm_method: Type[MoECommMethod] = DummyCommImpl,
+moe_comm_method: str = "dummy",
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
force_attention: bool = False,
uniform_decode: bool = False,
@@ -2061,8 +2064,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with_prefill=with_prefill,
in_profile_run=self.in_profile_run,
reserved_mc2_mask=self.reserved_mc2_mask,
-moe_comm_method=moe_comm_method(
-    self.device, self.dtype, self.model_config.hf_config),
+moe_comm_method=moe_comm_method,
num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor):