[2/N][Feat] Add MC2 communication method for MoE layers (#2469)
### What this PR does / why we need it?
The new MC2 method replaces the previous all-gather approach when the number of input tokens is small.
The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2 method and the existing all-gather method based on the number of input tokens (see the sketch after this list).
- Sharding the MoE communication mask across tensor-parallel ranks.
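
For illustration, a minimal sketch of the selection logic. Names mirror the diff below; `tensor_parallel_size = 4` and the standalone `select_moe_comm_method` helper are illustrative assumptions, not part of this PR:

```python
# Sketch of the dynamic MoE-comm selection added in this PR.
# tensor_parallel_size = 4 is an assumed example value; the runner reads
# it from parallel_config. With TP=4 the MC2 capacity is 512 * 4 = 2048.
tensor_parallel_size = 4
mc2_tokens_capacity = 512 * tensor_parallel_size


def select_moe_comm_method(num_input_tokens: int) -> str:
    # Use MC2 while the batch fits the reserved MC2 capacity,
    # otherwise fall back to all-gather.
    return "mc2" if num_input_tokens <= mc2_tokens_capacity else "allgather"


assert select_moe_comm_method(1024) == "mc2"
assert select_moe_comm_method(4096) == "allgather"
```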
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Existing test cases were fixed accordingly.
- vLLM version: v0.10.1.1
- vLLM main: b00e69f8ca
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@@ -24,7 +24,7 @@ import os
 import time
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -85,9 +85,6 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
 from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
-from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                     DummyCommImpl,
-                                                     MoECommMethod)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -368,13 +365,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
         self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
 
+        self.mc2_tokens_capacity = 512 * self.parallel_config.tensor_parallel_size
         self.reserved_mc2_mask = torch.zeros(
-            512,
+            self.mc2_tokens_capacity,
             dtype=torch.bool,
             device=self.device,
         )
 
-        self.moe_comm_method = AllGatherCommImpl
+        self.moe_comm_method = "mc2"
+        self.fallback_moe_comm_method = "allgather"
+        self.dummy_moe_comm_method = "dummy"
 
     def _use_aclgraph(self) -> bool:
         return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
@@ -1622,6 +1622,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 intermediate_tensors) = (self._prepare_inputs(
                     scheduler_output, intermediate_tensors))
 
+        moe_comm_method = (self.moe_comm_method
+                           if num_input_tokens <= self.mc2_tokens_capacity else
+                           self.fallback_moe_comm_method)
+
         # Run forward pass
         with ProfileExecuteDuration().capture_async("forward"):
             with set_ascend_forward_context(
@@ -1631,8 +1635,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     num_tokens_across_dp=num_tokens_across_dp,
                     with_prefill=self.with_prefill,
                     reserved_mc2_mask=self.reserved_mc2_mask,
-                    moe_comm_method=self.moe_comm_method(
-                        self.device, self.dtype, self.model_config.hf_config),
+                    moe_comm_method=moe_comm_method,
                     num_actual_tokens=scheduler_output.
                     total_num_scheduled_tokens):
                 self.maybe_setup_kv_connector(scheduler_output)
@@ -1938,7 +1941,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_tokens: int,
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
-        moe_comm_method: Type[MoECommMethod] = DummyCommImpl,
+        moe_comm_method: str = "dummy",
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         force_attention: bool = False,
         uniform_decode: bool = False,
@@ -2061,8 +2064,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 with_prefill=with_prefill,
                 in_profile_run=self.in_profile_run,
                 reserved_mc2_mask=self.reserved_mc2_mask,
-                moe_comm_method=moe_comm_method(
-                    self.device, self.dtype, self.model_config.hf_config),
+                moe_comm_method=moe_comm_method,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor):