[CI]Moe alltoall communication optimization (#1067)
[CI]Moe alltoall communication optimization The DeepSeek V3/R1 model has 256 routing experts. During parallel inference, if the load of an EP rank is high, the overall communication and computing time is slowed down, which becomes a weakness of parallel inference because the load is unevenly distributed. However, the data volume in the prefill phase is large, and the inter-card communication time consumption/calculation time consumption and the data volume are closely related to each other. Therefore, less non-linear precision loss can be used to obtain a near-linear performance improvement. During parallel inference, global synchronization occurs during communication. As a result, the card with low load completes the calculation first and waits for the card with the highest load to complete the calculation. Therefore, if the load is unbalanced, the card with high load slows down the overall time consumption. Significant performance gains can be achieved by discarding a small number of tokens, which is unacceptable in some precision-sensitive scenarios. However, similar to quantification, it is a solution that uses an acceptable precision loss in some scenarios for performance. In addition, a trade-off between performance and precision can be achieved by configuring a proportion of discarded tokens. Perform the test on A3. The batch size is 8 (B), the prompt length is 3.5K tokens (S), and the parallel configuration is as follows: AttnDP=2, AttnTP=8, MoeTP=1, and MoeEP=16. In this sence, we got a 10%-15% performance gain. Plus, the next version, we'll have an alltoallv moe. --------- Signed-off-by: weijinqian_v1 <weijinqian@huawei.com> Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
This commit is contained in:
@@ -112,6 +112,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
|
||||
),
|
||||
# MOE_ALL2ALL_BUFFER:
|
||||
# 0: default, normal init.
|
||||
# 1: enable moe_all2all_buffer.
|
||||
"MOE_ALL2ALL_BUFFER":
|
||||
lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
|
||||
# VLLM_ASCEND_ACL_OP_INIT_MODE:
|
||||
# 0: default, normal init.
|
||||
# 1: delay init until launch aclops.
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -37,6 +37,71 @@ from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
||||
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
|
||||
|
||||
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
||||
max_row_per_ep_rank: int, num_tokens: int,
|
||||
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
original_total_elements = num_tokens * top_k
|
||||
device = topk_ids.device
|
||||
original_dtype = topk_ids.dtype
|
||||
|
||||
if original_total_elements == 0:
|
||||
output_len = ep_size * max_row_per_ep_rank
|
||||
topk_ids_pad = torch.full((output_len, ),
|
||||
expert_num,
|
||||
dtype=original_dtype,
|
||||
device=device)
|
||||
unpad_indices = torch.full((original_total_elements, ),
|
||||
-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
return topk_ids_pad, unpad_indices
|
||||
|
||||
experts_per_ep_rank_val = expert_num // ep_size
|
||||
if experts_per_ep_rank_val == 0:
|
||||
raise ValueError(
|
||||
"expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
|
||||
"Ensure expert_num >= ep_size.")
|
||||
|
||||
assigned_ep_rank = (topk_ids.float() /
|
||||
experts_per_ep_rank_val).to(original_dtype)
|
||||
indices_arange = torch.arange(topk_ids.shape[0], device=device)
|
||||
|
||||
is_new_segment = torch.cat((torch.tensor([True], device=device),
|
||||
assigned_ep_rank[1:] != assigned_ep_rank[:-1]))
|
||||
temp_start_markers = torch.full_like(indices_arange,
|
||||
-1,
|
||||
dtype=indices_arange.dtype)
|
||||
temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
|
||||
start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
|
||||
token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
|
||||
is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
|
||||
cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
|
||||
indices_in_rec_cond_list_for_all = cumsum_kept - 1
|
||||
unpad_indices = torch.where(
|
||||
is_kept_mask, indices_in_rec_cond_list_for_all,
|
||||
torch.tensor(-1, device=device, dtype=torch.long))
|
||||
output_len = ep_size * max_row_per_ep_rank
|
||||
topk_ids_pad = torch.full((output_len, ),
|
||||
expert_num,
|
||||
dtype=original_dtype,
|
||||
device=device)
|
||||
if topk_ids.shape[0] > 0:
|
||||
all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
|
||||
temp_pad_buffer = torch.full((output_len + 1, ),
|
||||
expert_num,
|
||||
dtype=original_dtype,
|
||||
device=device)
|
||||
output_len_tensor = torch.tensor(output_len,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
scatter_indices = torch.where(is_kept_mask, all_destination_indices,
|
||||
output_len_tensor)
|
||||
temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
|
||||
topk_ids_pad = temp_pad_buffer[:output_len]
|
||||
return topk_ids_pad, unpad_indices
|
||||
|
||||
|
||||
def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||
@@ -146,8 +211,62 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||
return hidden_states
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
|
||||
# is under-optimized.
|
||||
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1) -> torch.Tensor:
|
||||
"""
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
|
||||
Args:
|
||||
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
|
||||
w1: expert weights1 with shape
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w2: expert weights2 with shape
|
||||
(num_experts, intermediate_size, hidden_size)
|
||||
group_list: number of tokens for each expert, follow cumsum mode, and
|
||||
with shape (num_experts).
|
||||
transpose_weight:
|
||||
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w2: (num_experts, hidden_size, intermediate_size) ->
|
||||
(num_experts, intermediate_size, hidden_size)
|
||||
|
||||
Returns:
|
||||
hidden_states: output hidden states after MLP.
|
||||
"""
|
||||
|
||||
assert len(hidden_states_wrapper) == 1
|
||||
hidden_states = hidden_states_wrapper.pop()
|
||||
|
||||
w1 = w1.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat(hidden_states, dim=0)
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat(hidden_states, dim=0)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts_with_all2all(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -283,6 +402,133 @@ def fused_experts_with_all2all(
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
|
||||
# is under-optimized.
|
||||
def fused_experts_with_all2all_buffer(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
max_model_len: int,
|
||||
global_batch_size: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
ep_group: GroupCoordinator = None,
|
||||
):
|
||||
original_shape = hidden_states.shape
|
||||
if len(original_shape) == 3:
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
device = hidden_states.device
|
||||
|
||||
global_num_experts = len(expert_map)
|
||||
local_num_experts = global_num_experts // ep_group.world_size
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
|
||||
device=device).view(top_k,
|
||||
-1).permute(1, 0).contiguous())
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=num_tokens)
|
||||
|
||||
max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
|
||||
max_model_len // ep_group.world_size +
|
||||
1) * top_k * 2
|
||||
expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
|
||||
expanded_expert_idx, global_num_experts, ep_group.world_size,
|
||||
max_row_per_ep_rank, num_tokens, top_k)
|
||||
hidden_states_pad_idx = torch.zeros(
|
||||
expert_idx_buffer_scatter.shape,
|
||||
dtype=expert_idx_buffer_scatter.dtype,
|
||||
device=expert_idx_buffer_scatter.device)
|
||||
non_pad_len = torch.sum(
|
||||
(expert_idx_buffer_scatter != global_num_experts).to(torch.int32))
|
||||
hidden_states_pad_idx[
|
||||
expert_idx_buffer_scatter != global_num_experts] = torch.arange(
|
||||
non_pad_len,
|
||||
dtype=expert_idx_buffer_scatter.dtype,
|
||||
device=hidden_states.device)
|
||||
|
||||
hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
|
||||
expert_idx_buffer_gather = torch.empty_like(
|
||||
expert_idx_buffer_scatter,
|
||||
dtype=expert_idx_buffer_scatter.dtype,
|
||||
device=expert_idx_buffer_scatter.device)
|
||||
hidden_states_buffer_gather = torch.empty_like(
|
||||
hidden_states_buffer_scatter,
|
||||
dtype=hidden_states_buffer_scatter.dtype,
|
||||
device=hidden_states_buffer_scatter.device)
|
||||
dist.all_to_all_single(expert_idx_buffer_gather,
|
||||
expert_idx_buffer_scatter,
|
||||
group=ep_group.device_group)
|
||||
dist.all_to_all_single(hidden_states_buffer_gather,
|
||||
hidden_states_buffer_scatter,
|
||||
group=ep_group.device_group)
|
||||
mask = expert_idx_buffer_gather != global_num_experts
|
||||
local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
|
||||
global_num_experts // ep_group.world_size)
|
||||
hidden_states = hidden_states_buffer_gather[mask]
|
||||
idx_type = local_expert_idx.dtype
|
||||
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
|
||||
sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
||||
hidden_states = hidden_states[sorted_idx]
|
||||
group_list_type = 0
|
||||
|
||||
hidden_states_wrapper = [hidden_states]
|
||||
del hidden_states
|
||||
|
||||
hidden_states = apply_mlp(hidden_states_wrapper,
|
||||
w1,
|
||||
w2,
|
||||
expert_tokens,
|
||||
group_list_type=group_list_type)
|
||||
|
||||
resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
|
||||
hidden_states = hidden_states[resorted_idx]
|
||||
hidden_states_scatter = torch.zeros(
|
||||
(mask.shape[0], hidden_states.shape[1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
hidden_states_scatter[mask] = hidden_states
|
||||
hidden_states_gatter = torch.empty_like(
|
||||
hidden_states_scatter,
|
||||
dtype=hidden_states_scatter.dtype,
|
||||
device=hidden_states_scatter.device)
|
||||
dist.all_to_all_single(hidden_states_gatter,
|
||||
hidden_states_scatter,
|
||||
group=ep_group.device_group)
|
||||
hidden_states_gatter = hidden_states_gatter[
|
||||
expert_idx_buffer_scatter != global_num_experts]
|
||||
if hidden_states_gatter.shape[0] != row_idx_len:
|
||||
hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
hidden_states[unpad_indices != -1] = hidden_states_gatter
|
||||
else:
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
hidden_states = hidden_states_gatter
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=topk_weights,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
|
||||
if len(original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(original_shape)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -585,6 +831,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
self.ep_size = ep_group.world_size
|
||||
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.local_batch_size = self.global_batch_size // self.ep_size
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
@@ -613,21 +860,22 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = False,
|
||||
enable_force_load_balance: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
@@ -683,11 +931,19 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
elif MOE_ALL2ALL_BUFFER:
|
||||
return fused_experts_with_all2all_buffer(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
max_model_len=self.max_model_len,
|
||||
global_batch_size=self.global_batch_size,
|
||||
expert_map=expert_map,
|
||||
ep_group=get_ep_group())
|
||||
else:
|
||||
# The current implementation of deepseek moe splits hidden_states
|
||||
# according to tp_size before they are feed into fused_moe module.
|
||||
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
||||
# dispatch/combine tokens.
|
||||
return fused_experts_with_all2all(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
|
||||
Reference in New Issue
Block a user