[2/N][Feat] Add MC2 communication method for MoE layers (#2469)

### What this PR does / why we need it?
This method replaces the previous all-gather approach for small numbers
of tokens.

The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2
method and the existing all-gather method based on the number of input
tokens.
- Sharding the MoE communication mask across tensor-parallel ranks.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Test case fixed.


- vLLM version: v0.10.1.1
- vLLM main:
b00e69f8ca

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-08-26 19:05:23 +08:00
committed by GitHub
parent 5d8ec28009
commit a6bb502e70
11 changed files with 506 additions and 410 deletions

View File

@@ -20,7 +20,6 @@ import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase
from vllm.utils import logger
class NPUCommunicator(DeviceCommunicatorBase):
@@ -35,12 +34,6 @@ class NPUCommunicator(DeviceCommunicatorBase):
# init device according to rank
self.device = torch.npu.current_device()
if self.use_all2all:
from vllm.distributed.device_communicators.all2all import \
NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def all_to_all(self,
input_: torch.Tensor,
scatter_dim: int = 0,
@@ -80,17 +73,3 @@ class NPUCommunicator(DeviceCommunicatorBase):
dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
return output_tensor
# TODO: Add ut for dispatch and combine
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states