Files
xc-llm-ascend/vllm_ascend/distributed/moe_comm_method.py
yiz-liu a6bb502e70 [2/N][Feat] Add MC2 communication method for MoE layers (#2469)
### What this PR does / why we need it?
This PR adds the MC2 communication method for MoE layers, which replaces
the previous all-gather approach when the number of tokens is small.

The key changes include:
- A new `AscendFusedMoE` layer that handles token splitting, local
computation, and final aggregation via all-gather.
- Logic in the model runner to dynamically select between the new MC2
method and the existing all-gather method based on the number of input
tokens.
- Sharding the MoE communication mask across tensor-parallel ranks.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Test case fixed.


- vLLM version: v0.10.1.1
- vLLM main:
b00e69f8ca

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-08-26 19:05:23 +08:00

497 lines
20 KiB
Python

from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
class MoECommMethod(ABC):
    """Abstract interface for MoE communication methods.

    Subclasses supply four hooks: ``prepare`` / ``finalize`` run before and
    after ``quant_method.apply``, while ``permute`` / ``unpermute`` run before
    and after the expert MLP.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        # Stored so that the hooks can read the MoE configuration later.
        self.moe_config = moe_config

    @abstractmethod
    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Set up the communication method ahead of ``quant_method.apply``.

        Implementations may initialise whatever resources or configuration
        they need here.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The (possibly transformed)
                hidden states and router logits.
        """

    @abstractmethod
    def finalize(
        self,
        hidden_states: torch.Tensor,
        reduce_results: bool,
    ) -> torch.Tensor:
        """Wrap up the communication method after ``quant_method.apply``.

        Implementations may release resources or undo configuration that
        ``prepare`` put in place.

        Returns:
            torch.Tensor: The final hidden states.
        """

    @abstractmethod
    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Pre-processing step executed before the MLP.

        Args:
            hidden_states (torch.Tensor): Shape (num_tokens, hidden_size).
            topk_ids (torch.Tensor): Shape (num_tokens, top_k_num).
            topk_weights (torch.Tensor): Shape (num_tokens, top_k_num).
            expert_map (torch.Tensor): Shape (global_num_experts, ); maps
                global expert IDs to local expert IDs.
            num_experts (int): Number of local experts, i.e. the experts
                hosted on this device.

        Returns:
            tuple[torch.Tensor, torch.Tensor, int]: A 3-tuple of
                - permuted_hidden_states (torch.Tensor): Shape
                  (num_tokens * top_k_num, hidden_size); ``hidden_states``
                  permuted according to ``topk_ids``.
                - expert_tokens (torch.Tensor): Shape (num_experts, ); the
                  number of tokens assigned to each expert.
                - group_list_type (int): 0 for `cumsum`, 1 for `count`.
                  Mainly consumed by `npu_grouped_matmul` to decide how to
                  handle the output.
        """

    @abstractmethod
    def unpermute(
        self,
        mlp_output: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> None:
        """Post-processing step executed after the MLP.

        Args:
            mlp_output (torch.Tensor): Shape
                (num_tokens * top_k_num, hidden_size); the MLP result.
            hidden_states (torch.Tensor): Shape (num_tokens, hidden_size);
                updated with the final output.
        """
class DummyCommImpl(MoECommMethod):
    """Placeholder communication method that performs no communication.

    Every hook either passes its inputs straight through or returns tensors
    that merely have the expected shapes.
    """

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Hand both inputs back unchanged."""
        return hidden_states, router_logits

    def finalize(
        self,
        hidden_states: torch.Tensor,
        reduce_results: bool,
    ) -> torch.Tensor:
        """Hand ``hidden_states`` back unchanged; ``reduce_results`` is ignored."""
        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Produce outputs with the right shapes without doing real routing.

        Returns:
            tuple[torch.Tensor, torch.Tensor, int]: ``hidden_states`` with
                every row repeated ``topk_ids.shape[1]`` times, an all-zero
                int64 tensor of shape (num_experts, ), and group list type 0.
        """
        # Each token row is duplicated once per selected expert.
        repeated_states = hidden_states.repeat_interleave(topk_ids.shape[1],
                                                          dim=0)
        zero_counts = torch.zeros((num_experts, ),
                                  dtype=torch.int64,
                                  device=hidden_states.device)
        return repeated_states, zero_counts, 0

    def unpermute(
        self,
        mlp_output: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> None:
        """Leave ``hidden_states`` untouched."""
        return None
class AllGatherCommImpl(MoECommMethod):
    """This implementation is the same as NativeAllGatherCommImpl,
    but uses NPU-specific ops for better performance.

    This implementation should be compatible with all scenarios, and
    thus it is the default implementation for MoE communication methods.
    It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
    and `torch_npu.npu_moe_token_unpermute` for post-processing
    to handle the token-to-expert mapping and communication efficiently.

    NOTE(Yizhou): TBH, it is really weird that we were supposed to use
    `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
    or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
    for pre-processing and post-processing, respectively.
    But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
    use `torch_npu.npu_moe_token_unpermute` instead.
    This is a workaround and should be removed after the issue is fixed.
    """

    def prepare(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """When DP size > 1, pad the hidden states and router logits for communication.

        Both tensors are padded along dim 0 up to
        `forward_context.max_tokens_across_dp` and then all-gathered over the
        DP group. The unpadded token count is remembered in `self.num_tokens`
        so that `finalize` can strip the padding again.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The (possibly padded and
                gathered) hidden states and router logits.
        """
        if self.moe_config.dp_size > 1:
            forward_context = get_forward_context()
            max_tokens_across_dp = forward_context.max_tokens_across_dp

            self.num_tokens = hidden_states.shape[0]
            pad_size = max_tokens_across_dp - self.num_tokens
            if pad_size > 0:
                # (0, 0, 0, pad_size) pads only the end of dim 0.
                hidden_states = nn.functional.pad(hidden_states,
                                                  (0, 0, 0, pad_size))
                router_logits = nn.functional.pad(router_logits,
                                                  (0, 0, 0, pad_size))

            hidden_states = self.moe_config.dp_group.all_gather(
                hidden_states, 0)
            router_logits = self.moe_config.dp_group.all_gather(
                router_logits, 0)

        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """When DP size > 1, reduce-scatter the hidden states to get the final output.

        When TP size > 1, all-reduce the hidden states to get the final output.

        Args:
            hidden_states (torch.Tensor): Output of the MoE computation.
            reduce_results (bool): The all-reduce is performed only when this
                is True and TP size > 1 or EP size > 1.

        Returns:
            torch.Tensor: The final hidden states.
        """
        if self.moe_config.dp_size > 1:
            hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0)
            # Drop the rows that `prepare` added as padding.
            hidden_states = hidden_states[:self.num_tokens]

        if reduce_results and (self.moe_config.tp_size > 1
                               or self.moe_config.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: Optional[torch.Tensor],
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Permute tokens by expert with `torch_npu.npu_moe_init_routing_v2`.

        `topk_ids`, the (possibly masked) `topk_weights` and the row index
        returned by the routing op are stored on `self` for `unpermute`.

        Args:
            hidden_states (torch.Tensor): Input tokens; dim 0 is the token
                dimension.
            topk_ids (torch.Tensor): Selected expert IDs per token.
            topk_weights (torch.Tensor): Routing weights matching `topk_ids`.
            expert_map (Optional[torch.Tensor]): Mapping from global expert
                IDs to local expert IDs, with -1 marking experts that are not
                local. When None, the active expert range starts at 0.
            num_experts (int): Number of local experts.

        Returns:
            tuple[torch.Tensor, torch.Tensor, int]: The permuted hidden
                states, the per-expert token counts, and group list type 1
                (`count` mode).
        """
        num_tokens = hidden_states.shape[0]

        self.topk_weights = topk_weights
        self.topk_ids = topk_ids

        first_expert_idx = 0
        if expert_map is not None:
            # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
            # So we need to filter out invalid tokens by zeroing their weights.
            # This is a workaround and should be removed after the issue is fixed
            mask = expert_map[topk_ids] != -1
            # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
            # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
            self.topk_weights = torch.where(mask, topk_weights, 0.0)

            first_expert_idx = self.moe_config.ep_rank * num_experts
        last_expert_idx = first_expert_idx + num_experts

        permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
            torch_npu.npu_moe_init_routing_v2(
                hidden_states,
                topk_ids,
                active_num=num_tokens * self.moe_config.experts_per_token,
                expert_num=self.moe_config.num_experts,
                expert_tokens_num_type=1,  # Only support `count` mode now
                expert_tokens_num_flag=True,  # Output `expert_tokens`
                active_expert_range=[first_expert_idx, last_expert_idx],
                quant_mode=-1,
            ))
        # Needed by `unpermute` to restore the original token order.
        self.expanded_row_idx = expanded_row_idx

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Write the un-permuted, weighted MLP output into `hidden_states`.

        Uses the row index and weights that `permute` stored on `self`, so
        `permute` must have been called first.

        Args:
            mlp_output (torch.Tensor): Output of the expert MLP in permuted
                order.
            hidden_states (torch.Tensor): Destination tensor, overwritten in
                place.
        """
        hidden_states[:] = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=mlp_output,
            sorted_indices=self.expanded_row_idx,
            probs=self.topk_weights)
class NativeAllGatherCommImpl(AllGatherCommImpl):
    """All-gather communication method built only from native PyTorch ops.

    It should be compatible with all scenarios. Because no NPU-specific op is
    involved, performance may not be optimal, but it serves as a fallback for
    scenarios where the NPU-specific ops are not available.
    """

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Group token rows by local expert using plain tensor ops.

        The sorted token indices and the matching weights are stored on
        ``self`` for ``unpermute``.

        Returns:
            tuple[torch.Tensor, torch.Tensor, int]: The hidden states
                gathered in sorted order, the per-expert token counts of
                shape (num_experts, ), and group list type 1 (`count` mode).
        """
        device = hidden_states.device
        top_k = self.moe_config.experts_per_token

        # One source-row index per (token, expert slot) pair, flattened.
        row_ids = torch.arange(hidden_states.shape[0],
                               device=device,
                               dtype=torch.int64)
        row_ids = row_ids.unsqueeze(1).expand(-1, top_k).reshape(-1)

        flat_weights = topk_weights.view(-1)
        flat_expert_ids = topk_ids.view(-1)

        # Translate global expert IDs to local ones when a map is supplied.
        if expert_map is None:
            local_ids = flat_expert_ids
        else:
            local_ids = expert_map[flat_expert_ids]

        # Pairs whose local ID is -1 are not handled here.
        is_local = local_ids != -1

        # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
        # So the invalid pairs get a zero weight instead of being dropped.
        # This is a workaround and should be removed after the issue is fixed
        kept_weights = torch.where(
            is_local,
            flat_weights,
            torch.zeros_like(flat_weights),
        ).to(topk_weights.dtype)
        # Invalid pairs receive the sentinel ID `num_experts`, which is larger
        # than every valid local ID and therefore sorts last.
        kept_ids = torch.where(
            is_local,
            local_ids,
            torch.full_like(local_ids, num_experts),
        ).to(topk_ids.dtype)

        # Count tokens per expert with scatter_add_ (in place of
        # torch.bincount). The extra last bucket absorbs the sentinel
        # entries and is cut off afterwards.
        histogram = torch.zeros(num_experts + 1,
                                device=device,
                                dtype=torch.int64)
        histogram.scatter_add_(0, kept_ids.to(torch.int64),
                               torch.ones_like(kept_ids, dtype=torch.int64))
        expert_tokens = histogram[:num_experts]

        # Sort by local expert ID. The IDs are reinterpreted bitwise as
        # float32 before sorting.
        # NOTE(review): this assumes topk_ids is a 32-bit integer tensor with
        # non-negative values so that the bit patterns keep their order;
        # presumably done because of argsort support on the device - confirm.
        order = torch.argsort(kept_ids.view(torch.float32))
        self.sorted_token_indices = row_ids[order]
        self.sorted_weights = kept_weights[order]

        # Gather the hidden states in expert-sorted order.
        return hidden_states[self.sorted_token_indices], expert_tokens, 1

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Scale the MLP output by its weights and sum it back per token.

        Relies on ``self.sorted_weights`` and ``self.sorted_token_indices``
        set by ``permute``; ``hidden_states`` is overwritten in place.
        """
        weighted_output = mlp_output * self.sorted_weights.unsqueeze(1)
        # Rows that belong to the same source token are accumulated together.
        hidden_states[:] = torch.zeros_like(hidden_states).index_add_(
            0, self.sorted_token_indices, weighted_output)
class MC2CommImpl(MoECommMethod):
    """MoE communication via the MC2 dispatch/combine ops.

    Intended for the following scenarios:
    1. `enable_expert_parallel=True` (`enable_expert_parallel=False` is not
       supported).
    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are
       available.

    MC2 is a communication method optimized for Communication and Computation
    parallelism on Ascend devices.
    """

    def __init__(self, moe_config: Optional[FusedMoEConfig]):
        super().__init__(moe_config)

        # NOTE: We do not need to use mc2_group's rank and world size
        # because ep_group and mc2_group basically have the same init params.
        # We only init another group because of the restriction of MC2:
        # "No other groups can be used in the same process as the MC2 group."
        npu_backend = get_mc2_group().device_group._get_backend(
            torch.device("npu"))
        self.mc2_comm_name = npu_backend.get_hccl_comm_name(
            self.moe_config.ep_rank)

        # Feature flags
        self.enable_dispatch_v2 = hasattr(torch_npu,
                                          "npu_moe_distribute_dispatch_v2")
        self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
        # The extra TP-related op arguments are passed only on A3.
        self.need_extra_args = self.is_ascend_a3
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        """Record the tensor-parallel world size and rank on the instance."""
        # NOTE: Since vLLM flatten tp across dp, we need to restore the original
        # tp_size and tp_rank.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad the inputs and, when TP size > 1, keep only this rank's chunk.

        The padding target (`padded_num_tokens`) and the MC2 mask are read
        from the forward context. With TP size > 1 the hidden states, router
        logits and mask are each split into `tp_size` chunks along dim 0 and
        the chunk at index `tp_rank` is used.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The hidden states and router
                logits to feed into the MoE computation on this rank.
        """
        self.num_tokens, _ = hidden_states.shape
        forward_context = get_forward_context()
        self.mc2_mask = forward_context.mc2_mask

        pad_size = forward_context.padded_num_tokens - self.num_tokens
        if pad_size > 0:
            # Pad only at the end of dim 0.
            padding = (0, 0, 0, pad_size)
            hidden_states = nn.functional.pad(hidden_states, padding)
            router_logits = nn.functional.pad(router_logits, padding)

        if self.tp_size > 1:
            hidden_chunks, logits_chunks, mask_chunks = (
                torch.tensor_split(tensor, self.tp_size, dim=0)
                for tensor in (hidden_states, router_logits, self.mc2_mask))
            # All chunks are kept: `finalize` uses them as all-gather buffers.
            self.split_hidden_states = hidden_chunks

            hidden_states = hidden_chunks[self.tp_rank]
            router_logits = logits_chunks[self.tp_rank]
            self.mc2_mask = mask_chunks[self.tp_rank]

        return hidden_states, router_logits

    def finalize(
        self,
        hidden_states: torch.Tensor,
        reduce_results: bool,
    ) -> torch.Tensor:
        """All-gather across TP when TP size > 1, then strip the padding.

        `reduce_results` is not used by this implementation.

        Returns:
            torch.Tensor: Hidden states with at most `self.num_tokens` rows.
        """
        if self.tp_size > 1:
            # Gather every rank's result into the chunks saved by `prepare`,
            # then stitch the chunks back together along dim 0.
            gather_buffers = list(self.split_hidden_states)
            dist.all_gather(gather_buffers,
                            hidden_states,
                            group=self.moe_config.tp_group.device_group)
            hidden_states = torch.cat(self.split_hidden_states, dim=0)

        needs_unpad = self.num_tokens < hidden_states.shape[0]
        if needs_unpad:
            hidden_states = hidden_states[:self.num_tokens]

        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Dispatch tokens to experts through the MC2 dispatch op.

        `topk_ids`, the float32 copy of `topk_weights`, and the auxiliary
        outputs of the dispatch op are stored on `self` for `unpermute`.
        `expert_map` and `num_experts` are not used here.

        Returns:
            tuple[torch.Tensor, torch.Tensor, int]: The dispatched hidden
                states, the expert token counts returned by the op, and
                group list type 1.
        """
        # Store tensors needed for post_process
        self.topk_ids = topk_ids
        self.topk_weights = topk_weights.to(torch.float32)

        dispatch_kwargs = {
            "x": hidden_states,
            "expert_ids": self.topk_ids,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.moe_config.num_experts,
            "global_bs": 0,
            "scales": None,
            "quant_mode": 0,
            "group_ep": self.mc2_comm_name,
            "ep_world_size": self.moe_config.ep_size,
            "ep_rank_id": self.moe_config.ep_rank,
        }
        if self.need_extra_args:
            dispatch_kwargs["group_tp"] = self.mc2_comm_name
            dispatch_kwargs["tp_world_size"] = 1
            dispatch_kwargs["tp_rank_id"] = 0
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            dispatch_kwargs["x_active_mask"] = self.mc2_mask

        # Use the v2 op whenever torch_npu provides it.
        if self.enable_dispatch_v2:
            dispatch_op = torch_npu.npu_moe_distribute_dispatch_v2
        else:
            dispatch_op = torch_npu.npu_moe_distribute_dispatch

        # Only the first six outputs are consumed; the second one
        # (dynamic_scale) is not used.
        dispatch_outputs = dispatch_op(**dispatch_kwargs)[:6]
        (
            permuted_hidden_states,
            _,
            self.assist_info_for_combine,
            expert_tokens,
            self.ep_recv_counts,
            self.tp_recv_counts,
        ) = dispatch_outputs

        return permuted_hidden_states, expert_tokens, 1

    def unpermute(
        self,
        mlp_output: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> None:
        """Combine the expert outputs through the MC2 combine op.

        Uses the state stored by `permute` and overwrites `hidden_states` in
        place with the result.
        """
        combine_kwargs = {
            "expand_x": mlp_output,
            "expert_ids": self.topk_ids,
            "expert_scales": self.topk_weights,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.moe_config.num_experts,
            "global_bs": 0,
            "ep_send_counts": self.ep_recv_counts,
            "group_ep": self.mc2_comm_name,
            "ep_world_size": self.moe_config.ep_size,
            "ep_rank_id": self.moe_config.ep_rank,
        }

        # The v2 and v1 ops take the dispatch bookkeeping tensor under
        # different keyword names.
        assist_key = ("assist_info_for_combine"
                      if self.enable_dispatch_v2 else "expand_idx")
        combine_kwargs[assist_key] = self.assist_info_for_combine

        if self.need_extra_args:
            combine_kwargs["tp_send_counts"] = self.tp_recv_counts
            combine_kwargs["group_tp"] = self.mc2_comm_name
            combine_kwargs["tp_world_size"] = 1
            combine_kwargs["tp_rank_id"] = 0
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            combine_kwargs["x_active_mask"] = self.mc2_mask

        if self.enable_dispatch_v2:
            combine_op = torch_npu.npu_moe_distribute_combine_v2
        else:
            combine_op = torch_npu.npu_moe_distribute_combine

        hidden_states[:] = combine_op(**combine_kwargs)