[2/N][Feat] Add MC2 communication method for MoE layers (#2469)
### What this PR does / why we need it?
This method replaces the previous all-gather approach for small numbers
of tokens.
The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2
method and the existing all-gather method based on the number of input
tokens.
- Sharding the MoE communication mask across tensor-parallel ranks.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Test case fixed.
- vLLM version: v0.10.1.1
- vLLM main:
b00e69f8ca
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -11,7 +11,6 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
||||
set_forward_context)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.distributed.moe_comm_method import MoECommMethod
|
||||
|
||||
|
||||
class FusedMoEState(Enum):
|
||||
@@ -57,7 +56,7 @@ def set_ascend_forward_context(
|
||||
with_prefill: bool = True,
|
||||
in_profile_run: bool = False,
|
||||
reserved_mc2_mask: Optional[torch.Tensor] = None,
|
||||
moe_comm_method: Optional[MoECommMethod] = None,
|
||||
moe_comm_method: str = "",
|
||||
num_actual_tokens: Optional[int] = None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: Optional[BatchDescriptor] = None):
|
||||
@@ -75,7 +74,7 @@ def set_ascend_forward_context(
|
||||
batch_descriptor=batch_descriptor,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
forward_context.moe_comm_method = moe_comm_method
|
||||
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
|
||||
forward_context.with_prefill = with_prefill
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
|
||||
Reference in New Issue
Block a user