### What this PR does / why we need it?
The determination of attention state, padding, and other forward
metadata has been moved to an earlier stage within the input preparation
process. This change enables us to utilize a single all-reduce
operation, maximizing synchronization efficiency as early as possible.
The logic for synchronizing metadata—such as the number of tokens,
prefill status, and DBO status—across data parallel (DP) ranks has now
been unified and simplified.
For performance improvements, the all-reduce operation has been switched
from the `gloo` backend to the `npu` backend, which results in a
reduction of several milliseconds per step (**approximately 10%
performance gain for TPOT!**).
Additionally, the multi-DP server hang issue has been resolved, ensuring
no more hangs occur when `num_requests < dp_size`. At last, a relief.
Finally, the miscalculated memory usage issue has been addressed by
removing the unnecessary `DummyCommImpl`, allowing the system to use the
real communication method when determining available memory.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
Maybe we should add a test case for the multi-DP online server?
@MengqingCao
- vLLM version: v0.10.1.1
- vLLM main:
c5d004aaaf
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
460 lines
19 KiB
Python
460 lines
19 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
import torch_npu
|
|
from vllm.distributed import tensor_model_parallel_all_reduce
|
|
from vllm.distributed.parallel_state import (
|
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
|
|
|
from vllm_ascend.distributed.communication_op import \
|
|
data_parallel_reduce_scatter
|
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
|
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
|
|
|
|
|
class MoECommMethod(ABC):
    """Abstract interface for MoE communication strategies.

    Concrete implementations provide four hooks that bracket the expert
    MLP computation:

    * ``prepare`` / ``finalize`` handle cross-device movement of the
      hidden states before and after the fused-MoE layer.
    * ``permute`` / ``unpermute`` handle the token-to-expert reordering
      around the grouped matmul.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        # Parallelism configuration (dp/tp/ep sizes and ranks, expert
        # counts, ...) shared by every communication strategy.
        self.moe_config = moe_config

    @abstractmethod
    def prepare(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Set up communication state before quant_method.apply.

        Called once per forward pass, ahead of the expert computation; an
        implementation may pad and/or gather the inputs here and stash any
        bookkeeping that ``finalize`` will need later.

        Returns:
            The (possibly padded/gathered) hidden states and router logits.
        """
        pass

    @abstractmethod
    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """Tear down communication state after quant_method.apply.

        Undoes whatever ``prepare`` did (e.g. unpadding, scattering) and,
        depending on ``reduce_results``, reduces the output across the
        relevant parallel groups.

        Returns:
            The final hidden states for this rank.
        """
        pass

    @abstractmethod
    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Reorder tokens by expert before the MLP.

        Args:
            hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
            topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
            topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
            expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
                mapping global expert IDs to local expert IDs.
            num_experts (int): Number of local experts (experts on this device).

        Returns:
            tuple[torch.Tensor, torch.Tensor, int]: A tuple containing:
                - permuted_hidden_states (torch.Tensor): Tensor of shape
                  (num_tokens * top_k_num, hidden_size), ``hidden_states``
                  reordered according to ``topk_ids``.
                - expert_tokens (torch.Tensor): Tensor of shape
                  (num_experts, ), the number of tokens assigned to each
                  local expert.
                - group_list_type (int): 0 for `cumsum`, 1 for `count`;
                  tells `npu_grouped_matmul` how to interpret
                  ``expert_tokens``.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        pass

    @abstractmethod
    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Scatter MLP results back into token order after the MLP.

        Args:
            mlp_output (torch.Tensor): Tensor of shape
                (num_tokens * top_k_num, hidden_size) produced by the MLP.
            hidden_states (torch.Tensor): Tensor of shape
                (num_tokens, hidden_size); updated in place with the final
                output.
        """
        pass
|
|
|
|
|
|
class AllGatherCommImpl(MoECommMethod):
    """All-gather based MoE communication using NPU-specific ops.

    Same algorithm as NativeAllGatherCommImpl, but backed by
    `torch_npu.npu_moe_init_routing_v2` (pre-processing) and
    `torch_npu.npu_moe_token_unpermute` (post-processing) for better
    performance. This implementation should be compatible with all
    scenarios, and thus it is the default implementation for MoE
    communication methods.

    NOTE(Yizhou): TBH, it is really weird that we were supposed to use
    `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
    or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
    for pre-processing and post-processing, respectively.
    But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
    use `torch_npu.npu_moe_token_unpermute` instead.
    This is a workaround and should be removed after the issue is fixed.
    """

    def prepare(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """When DP size > 1, pad both inputs to the DP-wide max token count
        and all-gather them along dim 0; otherwise pass them through.

        Stores the original (pre-padding) token count in `self.num_tokens`
        so that `finalize` can strip the padding again.
        """
        if self.moe_config.dp_size > 1:
            forward_context = get_forward_context()
            max_tokens_across_dp = forward_context.max_tokens_across_dp

            self.num_tokens = hidden_states.shape[0]
            pad_size = max_tokens_across_dp - self.num_tokens
            if pad_size > 0:
                # Pad only the token dimension (last pad pair in
                # nn.functional.pad's reversed-dims convention).
                hidden_states = nn.functional.pad(hidden_states,
                                                  (0, 0, 0, pad_size))
                router_logits = nn.functional.pad(router_logits,
                                                  (0, 0, 0, pad_size))

            hidden_states = self.moe_config.dp_group.all_gather(
                hidden_states, 0)
            router_logits = self.moe_config.dp_group.all_gather(
                router_logits, 0)

        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """When DP size > 1, reduce-scatter the hidden states back to this
        rank's shard and drop the padding added by `prepare`.

        When `reduce_results` is set and TP size > 1 (or EP size > 1),
        additionally all-reduce across the tensor-parallel group.
        """
        if self.moe_config.dp_size > 1:
            hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0)
            # Strip the padding added in `prepare`.
            hidden_states = hidden_states[:self.num_tokens]

        if reduce_results and (self.moe_config.tp_size > 1
                               or self.moe_config.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,  # noqa: F841
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Route tokens to local experts with `npu_moe_init_routing_v2`.

        See `MoECommMethod.permute` for the argument/return contract.
        """
        num_tokens = hidden_states.shape[0]

        # Saved for `unpermute`.
        self.topk_weights = topk_weights
        self.topk_ids = topk_ids

        first_expert_idx = 0
        if expert_map is not None:
            # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
            # So we need to filter out invalid tokens by zeroing their weights.
            # This is a workaround and should be removed after the issue is fixed
            mask = expert_map[topk_ids] != -1
            # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
            # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
            self.topk_weights = torch.where(mask, topk_weights, 0.0)

            first_expert_idx = self.moe_config.ep_rank * num_experts
        last_expert_idx = first_expert_idx + num_experts

        permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
            torch_npu.npu_moe_init_routing_v2(
                hidden_states,
                topk_ids,
                active_num=num_tokens * self.moe_config.experts_per_token,
                expert_num=self.moe_config.num_experts,
                expert_tokens_num_type=1,  # Only support `count` mode now
                expert_tokens_num_flag=True,  # Output `expert_tokens`
                active_expert_range=[first_expert_idx, last_expert_idx],
                quant_mode=-1,
            ))
        # Saved for `unpermute`. (The original code also had a no-op
        # self-assignment of permuted_hidden_states here; removed.)
        self.expanded_row_idx = expanded_row_idx

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Write the weighted, un-permuted MLP output into `hidden_states`
        in place."""
        hidden_states[:] = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=mlp_output,
            sorted_indices=self.expanded_row_idx,
            probs=self.topk_weights)
|
|
|
|
|
|
class NativeAllGatherCommImpl(AllGatherCommImpl):
    """This implementation should be compatible with all scenarios.

    Note that this implementation purely consists of native PyTorch ops
    and does not use any NPU-specific ops. So the performance may not be optimal.
    But it is a good fallback for scenarios where NPU-specific ops are not available.

    Inherits `prepare`/`finalize` from AllGatherCommImpl and only replaces
    the permute/unpermute pair with native-op equivalents.
    """

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        # Pure-PyTorch token-to-expert routing; see MoECommMethod.permute
        # for the argument/return contract.
        num_tokens = hidden_states.shape[0]

        # Generate token indices and flatten: each token index is repeated
        # experts_per_token times so it lines up with the flattened
        # (num_tokens * top_k,) expert/weight arrays built below.
        token_indices = torch.arange(num_tokens,
                                     device=hidden_states.device,
                                     dtype=torch.int64)
        token_indices = (token_indices.unsqueeze(1).expand(
            -1, self.moe_config.experts_per_token).reshape(-1))

        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.view(-1)
        experts_flat = topk_ids.view(-1)
        local_experts_flat = (expert_map[experts_flat]
                              if expert_map is not None else experts_flat)

        # Filter valid token-expert pairs (expert_map marks experts not
        # hosted on this rank with -1).
        mask = local_experts_flat != -1
        # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
        # So we need to filter out invalid tokens by zeroing their weights.
        # This is a workaround and should be removed after the issue is fixed
        filtered_weights = torch.where(mask, weights_flat,
                                       torch.zeros_like(weights_flat)).to(
                                           topk_weights.dtype)
        # Invalid pairs get a dummy expert index of `num_experts`, which is
        # dropped when slicing token_counts[:num_experts] further down.
        filtered_experts = torch.where(
            mask,
            local_experts_flat,
            torch.full_like(local_experts_flat, num_experts),
        ).to(topk_ids.dtype)

        # Sort by local expert IDs
        # NOTE(review): `.view(torch.float32)` reinterprets the integer bits
        # as float32 before argsort. This presumably relies on non-negative
        # integer bit patterns sorting in the same order when read as
        # floats — confirm this invariant (IDs here are >= 0 and <= 
        # num_experts) and whether the float-view argsort is actually
        # faster than sorting the integer tensor directly.
        sort_indices = torch.argsort(filtered_experts.view(torch.float32))
        # Saved for `unpermute`.
        self.sorted_token_indices = token_indices[sort_indices]
        self.sorted_weights = filtered_weights[sort_indices]

        # Compute token counts with minlength of num_experts
        # This is equivalent to but faster than:
        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
        token_counts = torch.zeros(num_experts + 1,
                                   device=hidden_states.device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
        expert_tokens = token_counts[:num_experts]

        # Rearrange hidden_states
        permuted_hidden_states = hidden_states[self.sorted_token_indices]

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        # Weight each expert output row by its (sorted) routing weight, then
        # accumulate rows back into token order; `hidden_states` is updated
        # in place, matching the base-class contract.
        mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)

        final_hidden_states = torch.zeros_like(hidden_states)
        final_hidden_states.index_add_(0, self.sorted_token_indices,
                                       mlp_output)

        hidden_states[:] = final_hidden_states
|
|
|
|
|
|
class MC2CommImpl(MoECommMethod):
    """MoE communication via the Ascend MC2 dispatch/combine kernels.

    Requirements:
    1. `enable_expert_parallel=True` (`enable_expert_parallel=False` is
       not supported).
    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.

    This implementation uses the MC2 communication method, which is optimized for
    Communication and Computation parallelism on Ascend devices.
    """

    # Annotation tightened from Optional[FusedMoEConfig]: the body
    # dereferences moe_config unconditionally, so None was never valid.
    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)

        # NOTE: We do not need to use mc2_group's rank and world size
        # because ep_group and mc2_group basically have the same init params.
        # We only init another group because of the restriction of MC2:
        # "No other groups can be used in the same process as the MC2 group."
        self.mc2_comm_name = get_mc2_group().device_group._get_backend(
            torch.device("npu")).get_hccl_comm_name(self.moe_config.ep_rank)

        # Feature flags
        # v2 dispatch/combine kernels only exist in newer torch_npu builds.
        self.enable_dispatch_v2 = hasattr(torch_npu,
                                          "npu_moe_distribute_dispatch_v2")
        self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
        # A3 SoCs need the extra tp_* arguments in dispatch/combine calls.
        self.need_extra_args = self.is_ascend_a3
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        # NOTE: Since vLLM flatten tp across dp, we need to restore the original
        # tp_size and tp_rank.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """The target_pad_length is calculated in forward_context, here we pad the
        hidden states and router logits. And if TP size > 1, we also need to split
        the tensors accordingly.

        Saves `self.num_tokens` (pre-padding count) for unpadding in
        `finalize`, `self.mc2_mask` for the dispatch/combine kernels, and
        `self.split_hidden_states` as the gather target in `finalize`.
        """
        self.num_tokens, _ = hidden_states.shape
        forward_context = get_forward_context()
        self.mc2_mask = forward_context.mc2_mask
        target_pad_length = forward_context.padded_num_tokens
        pad_size = target_pad_length - self.num_tokens

        if pad_size > 0:
            # Pad only the token dimension.
            hidden_states = nn.functional.pad(hidden_states,
                                              (0, 0, 0, pad_size))
            router_logits = nn.functional.pad(router_logits,
                                              (0, 0, 0, pad_size))

        if self.tp_size > 1:
            # Each TP rank processes only its shard; the full set of shard
            # views is kept so `finalize` can all-gather back into them.
            split_hidden_states = torch.tensor_split(hidden_states,
                                                     self.tp_size,
                                                     dim=0)
            split_router_logits = torch.tensor_split(router_logits,
                                                     self.tp_size,
                                                     dim=0)
            split_mc2_mask = torch.tensor_split(self.mc2_mask,
                                                self.tp_size,
                                                dim=0)
            self.split_hidden_states = split_hidden_states

            hidden_states = split_hidden_states[self.tp_rank]
            router_logits = split_router_logits[self.tp_rank]
            self.mc2_mask = split_mc2_mask[self.tp_rank]

        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """If TP size > 1, all-gather the hidden states to get the final output.

        Also, unpad the hidden states if needed.
        """
        if self.tp_size > 1:
            # all_gather fills the shard tensors saved in `prepare` in
            # place; list(...) only copies the container, not the tensors,
            # so `self.split_hidden_states` sees the gathered data.
            dist.all_gather(list(self.split_hidden_states), hidden_states,
                            self.moe_config.tp_group.device_group)
            hidden_states = torch.cat(self.split_hidden_states, dim=0)

        # Drop the padding added in `prepare`, if any.
        if self.num_tokens < hidden_states.shape[0]:
            hidden_states = hidden_states[:self.num_tokens]

        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        # Dispatch tokens to their experts across EP ranks via MC2.
        # See MoECommMethod.permute for the argument/return contract.
        # NOTE: expert_map and num_experts are unused here; routing is done
        # entirely by the dispatch kernel from topk_ids.

        # Store tensors needed for post_process
        self.topk_ids = topk_ids
        # Combine kernel expects float32 scales.
        self.topk_weights = topk_weights.to(torch.float32)

        dispatch_kwargs = {
            "x": hidden_states,
            "expert_ids": self.topk_ids,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.moe_config.num_experts,
            "global_bs": 0,
            "scales": None,
            "quant_mode": 0,
            "group_ep": self.mc2_comm_name,
            "ep_world_size": self.moe_config.ep_size,
            "ep_rank_id": self.moe_config.ep_rank,
        }

        if self.need_extra_args:
            # A3: the kernel also wants TP group info; TP is handled
            # outside MC2 here, hence world size 1.
            dispatch_kwargs.update({
                "group_tp": self.mc2_comm_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            dispatch_kwargs.update({
                "x_active_mask": self.mc2_mask,
            })

        dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch

        # Outputs beyond the first six are ignored; the recv-count tensors
        # are kept for the matching combine call in `unpermute`.
        (
            permuted_hidden_states,
            _,  # dynamic_scale is not used
            self.assist_info_for_combine,
            expert_tokens,
            self.ep_recv_counts,
            self.tp_recv_counts,
        ) = dispatch(**dispatch_kwargs)[:6]

        group_list_type = 1

        return permuted_hidden_states, expert_tokens, group_list_type

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        # Combine expert outputs back across EP ranks and write the result
        # into `hidden_states` in place, mirroring the dispatch in `permute`.
        combine_kwargs = {
            "expand_x": mlp_output,
            "expert_ids": self.topk_ids,
            "expert_scales": self.topk_weights,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.moe_config.num_experts,
            "global_bs": 0,
            "ep_send_counts": self.ep_recv_counts,
            "group_ep": self.mc2_comm_name,
            "ep_world_size": self.moe_config.ep_size,
            "ep_rank_id": self.moe_config.ep_rank,
        }

        # v1 and v2 kernels name the dispatch-assist argument differently.
        if self.enable_dispatch_v2:
            combine_kwargs[
                "assist_info_for_combine"] = self.assist_info_for_combine
        else:
            combine_kwargs["expand_idx"] = self.assist_info_for_combine

        if self.need_extra_args:
            combine_kwargs.update({
                "tp_send_counts": self.tp_recv_counts,
                "group_tp": self.mc2_comm_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            combine_kwargs.update({
                "x_active_mask": self.mc2_mask,
            })

        combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine

        hidden_states[:] = combine(**combine_kwargs)
|