xc-llm-ascend/vllm_ascend/distributed/moe_comm_method.py
yiz-liu dfc7eb39ad [Fix] Fix DP-related padding logic (#2582)
### What this PR does / why we need it?
The determination of attention state, padding, and other forward
metadata has been moved to an earlier stage of the input preparation
process, so that a single all-reduce operation can synchronize
everything as early as possible.

The logic for synchronizing metadata—such as the number of tokens,
prefill status, and DBO status—across data parallel (DP) ranks has now
been unified and simplified.

For better performance, the all-reduce operation has been switched
from the `gloo` backend to the `npu` backend, which shaves several
milliseconds off each step (**approximately 10% performance gain for
TPOT!**).
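
Roughly, the idea is to pack every DP-synchronized field into one device
tensor and issue a single all-reduce, instead of several per-field `gloo`
round-trips. A minimal sketch with hypothetical field names (the real
layout lives in the input-preparation code):

```python
import torch
import torch.distributed as dist


def sync_metadata_across_dp(num_tokens: int, with_prefill: bool,
                            enable_dbo: bool,
                            dp_group: dist.ProcessGroup) -> torch.Tensor:
    # Pack all DP-synchronized metadata into one NPU tensor so that a
    # single all-reduce (max) covers every field at once.
    packed = torch.tensor([num_tokens, int(with_prefill), int(enable_dbo)],
                          dtype=torch.int32, device="npu")
    dist.all_reduce(packed, op=dist.ReduceOp.MAX, group=dp_group)
    return packed  # [max_tokens_across_dp, any_prefill, any_dbo]
```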

Additionally, the multi-DP server hang issue has been resolved, ensuring
no more hangs occur when `num_requests < dp_size`. At last, a relief.

Finally, the miscalculated memory usage issue has been addressed by
removing the unnecessary `DummyCommImpl`, allowing the system to use the
real communication method when determining available memory.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Maybe we should add a test case for the multi-DP online server?
@MengqingCao


- vLLM version: v0.10.1.1
- vLLM main: c5d004aaaf

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
2025-08-28 19:39:58 +08:00


from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig

from vllm_ascend.distributed.communication_op import \
    data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version


class MoECommMethod(ABC):
"""Base class for MoE communication methods."""
def __init__(self, moe_config: FusedMoEConfig):
        self.moe_config = moe_config

    @abstractmethod
def prepare(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare the MoE communication method.
This method is called before quant_method.apply to prepare the
communication method. It can be used to initialize any necessary
resources or configurations.
"""
        pass

    @abstractmethod
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""Finalize the MoE communication method.
This method is called after quant_method.apply to finalize the
communication method. It can be used to clean up any resources or
configurations.
"""
        pass

    @abstractmethod
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Pre-process before MLP.
Args:
hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
Mapping from global expert IDs to local expert IDs.
num_experts (int): Number of local experts (experts on this device).
Returns:
tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
- permuted_hidden_states (torch.Tensor): Tensor of shape
(num_tokens * top_k_num, hidden_size) after permuting
hidden_states based on topk_ids.
- expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
Number of tokens assigned to each expert.
- group_list_type (int): Type of group list, 0 for `cumsum`
and 1 for `count`. This is mainly for `npu_grouped_matmul`
to determine how to handle the output.
Raises:
NotImplementedError: If the method is not implemented in the subclass.
"""
pass
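    # For example, with three local experts receiving token counts [2, 0, 3]:
    #   group_list_type == 1 ("count")  -> expert_tokens = [2, 0, 3]
    #   group_list_type == 0 ("cumsum") -> expert_tokens = [2, 2, 5]
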
@abstractmethod
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
"""Post-process after MLP.
Args:
mlp_output (torch.Tensor): Tensor of shape
(num_tokens * top_k_num, hidden_size) after MLP.
hidden_states (torch.Tensor): Tensor of shape
(num_tokens, hidden_size) to be updated with the final output.
"""
pass
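
# NOTE: A minimal illustrative sketch (not part of the original module) of
# how a caller is expected to drive these hooks around quant_method.apply;
# `mlp` stands in for the grouped expert MLP and is a hypothetical callable.
def _example_moe_forward(comm: MoECommMethod, hidden_states, router_logits,
                         topk_ids, topk_weights, expert_map,
                         num_local_experts, mlp):
    hidden_states, router_logits = comm.prepare(hidden_states, router_logits)
    permuted, expert_tokens, group_list_type = comm.permute(
        hidden_states, topk_ids, topk_weights, expert_map, num_local_experts)
    mlp_output = mlp(permuted, expert_tokens, group_list_type)
    comm.unpermute(mlp_output, hidden_states)  # writes into hidden_states
    return comm.finalize(hidden_states, reduce_results=True)
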
class AllGatherCommImpl(MoECommMethod):
"""This implementation is the same as NativeAllGatherCommImpl,
but uses NPU-specific ops for better performance.
This implementation should be compatible with all scenarios, and
thus it is the default implementation for MoE communication methods.
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
and `torch_npu.npu_moe_token_unpermute` for post-processing
to handle the token-to-expert mapping and communication efficiently.
    NOTE(Yizhou): TBH, it is really weird that we are expected to pair
    `torch_npu.npu_moe_init_routing_v2` with `torch_npu.npu_moe_finalize_routing`,
    or `torch_npu.npu_moe_token_permute` with `torch_npu.npu_moe_token_unpermute`,
    for pre-processing and post-processing, respectively.
    But `npu_moe_finalize_routing` leads to accuracy issues, so we have to
    mix the pairs and use `torch_npu.npu_moe_token_unpermute` instead.
    This is a workaround and should be removed after the issue is fixed.
"""
def prepare(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""When DP size > 1, pad the hidden states and router logits for communication."""
if self.moe_config.dp_size > 1:
forward_context = get_forward_context()
max_tokens_across_dp = forward_context.max_tokens_across_dp
self.num_tokens = hidden_states.shape[0]
pad_size = max_tokens_across_dp - self.num_tokens
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
hidden_states = self.moe_config.dp_group.all_gather(
hidden_states, 0)
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
return hidden_states, router_logits
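
    # Example (dp_size = 2): if rank 0 holds 3 tokens and rank 1 holds 5,
    # max_tokens_across_dp is 5, so rank 0 pads to (5, hidden_size); the
    # all-gather then gives every rank a (10, hidden_size) tensor, and
    # finalize() reduce-scatters back to each rank's original slice.
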
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""When DP size > 1, reduce-scatter the hidden states to get the final output.
When TP size > 1, all-reduce the hidden states to get the final output.
"""
if self.moe_config.dp_size > 1:
hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0)
hidden_states = hidden_states[:self.num_tokens]
if reduce_results and (self.moe_config.tp_size > 1
or self.moe_config.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
        return hidden_states

    def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor, # noqa: F841
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
num_tokens = hidden_states.shape[0]
self.topk_weights = topk_weights
self.topk_ids = topk_ids
first_expert_idx = 0
if expert_map is not None:
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
# So we need to filter out invalid tokens by zeroing their weights.
# This is a workaround and should be removed after the issue is fixed
mask = expert_map[topk_ids] != -1
# NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
# but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
self.topk_weights = torch.where(mask, topk_weights, 0.0)
first_expert_idx = self.moe_config.ep_rank * num_experts
last_expert_idx = first_expert_idx + num_experts
permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
active_num=num_tokens * self.moe_config.experts_per_token,
expert_num=self.moe_config.num_experts,
expert_tokens_num_type=1, # Only support `count` mode now
expert_tokens_num_flag=True, # Output `expert_tokens`
active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=-1,
))
self.expanded_row_idx = expanded_row_idx
group_list_type = 1 # `count` mode
        return permuted_hidden_states, expert_tokens, group_list_type

    def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
hidden_states[:] = torch_npu.npu_moe_token_unpermute(
permuted_tokens=mlp_output,
sorted_indices=self.expanded_row_idx,
            probs=self.topk_weights)


class NativeAllGatherCommImpl(AllGatherCommImpl):
"""This implementation should be compatible with all scenarios.
Note that this implementation purely consists of native PyTorch ops
and does not use any NPU-specific ops. So the performance may not be optimal.
But it is a good fallback for scenarios where NPU-specific ops are not available.
"""
def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
num_tokens = hidden_states.shape[0]
# Generate token indices and flatten
token_indices = torch.arange(num_tokens,
device=hidden_states.device,
dtype=torch.int64)
token_indices = (token_indices.unsqueeze(1).expand(
-1, self.moe_config.experts_per_token).reshape(-1))
# Flatten token-to-expert mappings and map to local experts
weights_flat = topk_weights.view(-1)
experts_flat = topk_ids.view(-1)
local_experts_flat = (expert_map[experts_flat]
if expert_map is not None else experts_flat)
# Filter valid token-expert pairs
mask = local_experts_flat != -1
# FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
# So we need to filter out invalid tokens by zeroing their weights.
# This is a workaround and should be removed after the issue is fixed
filtered_weights = torch.where(mask, weights_flat,
torch.zeros_like(weights_flat)).to(
topk_weights.dtype)
filtered_experts = torch.where(
mask,
local_experts_flat,
torch.full_like(local_experts_flat, num_experts),
).to(topk_ids.dtype)
# Sort by local expert IDs
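        # NOTE: Viewing the integer expert IDs as float32 bit patterns keeps
        # their relative order for non-negative values (assuming int32 IDs,
        # so the view preserves the tensor's shape) while letting argsort
        # take the float path.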
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
self.sorted_token_indices = token_indices[sort_indices]
self.sorted_weights = filtered_weights[sort_indices]
# Compute token counts with minlength of num_experts
# This is equivalent to but faster than:
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
token_counts = torch.zeros(num_experts + 1,
device=hidden_states.device,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
expert_tokens = token_counts[:num_experts]
# Rearrange hidden_states
permuted_hidden_states = hidden_states[self.sorted_token_indices]
group_list_type = 1 # `count` mode
return permuted_hidden_states, expert_tokens, group_list_type
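
    # Worked example (2 tokens, experts_per_token = 2, num_experts = 2,
    # identity expert_map): topk_ids = [[0, 1], [1, 1]] flattens to
    # experts_flat = [0, 1, 1, 1] with token_indices = [0, 0, 1, 1]; sorting
    # by expert yields sorted_token_indices = [0, 0, 1, 1] and
    # expert_tokens = [1, 3], i.e. expert 0 sees token 0 once and expert 1
    # sees token 0 once and token 1 twice.
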
def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
final_hidden_states = torch.zeros_like(hidden_states)
final_hidden_states.index_add_(0, self.sorted_token_indices,
mlp_output)
        hidden_states[:] = final_hidden_states


class MC2CommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True`.
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
3. `enable_expert_parallel=False` is not supported.
This implementation uses the MC2 communication method, which is optimized for
Communication and Computation parallelism on Ascend devices.
"""
def __init__(self, moe_config: Optional[FusedMoEConfig]):
super().__init__(moe_config)
# NOTE: We do not need to use mc2_group's rank and world size
# because ep_group and mc2_group basically have the same init params.
# We only init another group because of the restriction of MC2:
# "No other groups can be used in the same process as the MC2 group."
self.mc2_comm_name = get_mc2_group().device_group._get_backend(
torch.device("npu")).get_hccl_comm_name(self.moe_config.ep_rank)
# Feature flags
self.enable_dispatch_v2 = hasattr(torch_npu,
"npu_moe_distribute_dispatch_v2")
self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
self.need_extra_args = self.is_ascend_a3
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        # NOTE: Since vLLM flattens TP across DP, we need to restore the
        # original tp_size and tp_rank.
self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """The target_pad_length is calculated in forward_context; here we pad
        the hidden states and router logits. If TP size > 1, we also split the
        tensors accordingly.
        """
self.num_tokens, _ = hidden_states.shape
forward_context = get_forward_context()
self.mc2_mask = forward_context.mc2_mask
target_pad_length = forward_context.padded_num_tokens
pad_size = target_pad_length - self.num_tokens
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
split_mc2_mask = torch.tensor_split(self.mc2_mask,
self.tp_size,
dim=0)
self.split_hidden_states = split_hidden_states
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
self.mc2_mask = split_mc2_mask[self.tp_rank]
return hidden_states, router_logits
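
    # Example (tp_size = 2, padded_num_tokens = 8): hidden_states is padded
    # to (8, hidden_size) and split into two (4, hidden_size) chunks; each
    # TP rank dispatches only its own chunk, and finalize() all-gathers the
    # chunks back before unpadding.
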
def finalize(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""If TP size > 1, all-gather the hidden states to get the final output.
Also, unpad the hidden states if needed.
"""
if self.tp_size > 1:
dist.all_gather(list(self.split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = torch.cat(self.split_hidden_states, dim=0)
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
        return hidden_states

    def permute(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
# Store tensors needed for post_process
self.topk_ids = topk_ids
self.topk_weights = topk_weights.to(torch.float32)
dispatch_kwargs = {
"x": hidden_states,
"expert_ids": self.topk_ids,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": self.moe_config.num_experts,
"global_bs": 0,
"scales": None,
"quant_mode": 0,
"group_ep": self.mc2_comm_name,
"ep_world_size": self.moe_config.ep_size,
"ep_rank_id": self.moe_config.ep_rank,
}
if self.need_extra_args:
dispatch_kwargs.update({
"group_tp": self.mc2_comm_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.is_ascend_a3 and self.enable_dispatch_v2:
dispatch_kwargs.update({
"x_active_mask": self.mc2_mask,
})
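        # Prefer the v2 dispatch op when the installed torch_npu provides it;
        # the v1/v2 choice also changes the combine-side index argument
        # (see `unpermute` below).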
dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch
(
permuted_hidden_states,
_, # dynamic_scale is not used
self.assist_info_for_combine,
expert_tokens,
self.ep_recv_counts,
self.tp_recv_counts,
) = dispatch(**dispatch_kwargs)[:6]
group_list_type = 1
        return permuted_hidden_states, expert_tokens, group_list_type

    def unpermute(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
combine_kwargs = {
"expand_x": mlp_output,
"expert_ids": self.topk_ids,
"expert_scales": self.topk_weights,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": self.moe_config.num_experts,
"global_bs": 0,
"ep_send_counts": self.ep_recv_counts,
"group_ep": self.mc2_comm_name,
"ep_world_size": self.moe_config.ep_size,
"ep_rank_id": self.moe_config.ep_rank,
}
if self.enable_dispatch_v2:
combine_kwargs[
"assist_info_for_combine"] = self.assist_info_for_combine
else:
combine_kwargs["expand_idx"] = self.assist_info_for_combine
if self.need_extra_args:
combine_kwargs.update({
"tp_send_counts": self.tp_recv_counts,
"group_tp": self.mc2_comm_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.is_ascend_a3 and self.enable_dispatch_v2:
combine_kwargs.update({
"x_active_mask": self.mc2_mask,
})
combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
hidden_states[:] = combine(**combine_kwargs)