Support multistream of shared experts in FusedMoE (#997)

Includes the changes from #1111 for completeness.

### What this PR does / why we need it?
Implement multi-stream parallelism for MoE layers with shared experts, so
that the shared-expert computation is overlapped with the expert token
dispatch and combine phases. In addition, when multi-stream is enabled, the
weights of the shared experts are forced to be replicated across all cards,
regardless of any tensor parallelism configuration, to avoid AllReduce
operations.

The expected overlapping is:
```
| shared gate_up | shared act |              | shared down |
|    dispatch    | routed gate_up, act, down |   combine   |
```
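In code terms, this schedule is realized with the `npu_stream_switch`/`npu_wait_tensor` helpers from `vllm_ascend.utils` (see the hunks below). A condensed sketch of the MC2 path, simplified from the actual diff, where `dispatch`, `routed_experts` and `combine` are stand-ins for `torch_npu.npu_moe_distribute_dispatch`, the grouped matmuls, and `torch_npu.npu_moe_distribute_combine`:

```python
# Main stream: dispatch -> routed gate_up/act/down -> combine.
# Secondary stream: shared gate_up/act overlap dispatch; shared down overlaps combine.
output = dispatch(hidden_states, topk_weights, topk_ids)      # main stream
with npu_stream_switch("moe_secondary", 0):
    npu_wait_tensor(hidden_states, topk_weights)              # router output is ready
    shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
    shared_act = shared_experts.act_fn(shared_gate_up)
down_out_list = routed_experts(output)                        # main stream
hidden_states = combine(down_out_list)                        # main stream
with npu_stream_switch("moe_secondary", 0):
    npu_wait_tensor(shared_act, down_out_list)                # don't race the routed matmuls
    shared_out, _ = shared_experts.down_proj(shared_act)
return hidden_states, shared_out
```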


### Does this PR introduce _any_ user-facing change?
No.


### How was this patch tested?
Tested on a single 910 node with 16 devices (1x16), using a tailored
two-layer DSKv2 (DeepSeek-V2) model.

---------

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
Author: sdmyzlp
Date: 2025-06-11 09:18:38 +08:00 (committed by GitHub)
Commit: 7bdc606677 (parent: 04abfd8721)
11 changed files with 296 additions and 308 deletions


```diff
@@ -16,7 +16,7 @@
 # Adapted from vllm/tests/kernels/test_moe.py
 import os
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
```
```diff
@@ -36,6 +36,7 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
```
```diff
@@ -106,15 +107,17 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
     return topk_ids_pad, unpad_indices
 
 
-def fused_experts_with_mc2(hidden_states: torch.Tensor,
-                           w1: torch.Tensor,
-                           w2: torch.Tensor,
-                           topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor,
-                           top_k: int,
-                           expert_map: torch.Tensor = None,
-                           moe_all_to_all_group_name: Optional[str] = None,
-                           **kwargs) -> torch.Tensor:
+def fused_experts_with_mc2(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        top_k: int,
+        expert_map: torch.Tensor = None,
+        moe_all_to_all_group_name: Optional[str] = None,
+        shared_experts: Optional[Any] = None
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     global_bs = 0
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
```
```diff
@@ -154,6 +157,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
+    if shared_experts is not None:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(hidden_states, topk_weights)
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
 
     group_list = expert_token_nums.to(torch.int64)
```
```diff
@@ -210,7 +220,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
-    return hidden_states
+    if shared_experts is None:
+        return hidden_states
+    else:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+        return hidden_states, shared_hidden_states
 
 
 def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
```
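For readers unfamiliar with the two helpers: `npu_stream_switch` enters a named secondary stream, and `npu_wait_tensor(a, b)` orders subsequent work on the current stream behind the producer of `b`. The generic pattern, illustrated with the stock `torch.cuda` stream API (an analogy only; the NPU helpers mirror it, with `npu_wait_tensor` giving tensor-level rather than stream-level granularity):

```python
import torch

main = torch.cuda.current_stream()
side = torch.cuda.Stream()

a = torch.randn(4096, 4096, device="cuda")
w = torch.randn(4096, 4096, device="cuda")

b = a @ w                  # "dispatch"-like work queued on the main stream
side.wait_stream(main)     # analogous to npu_wait_tensor: don't outrun b's producer
with torch.cuda.stream(side):
    c = a @ w              # "shared expert" work, overlapping with ...
d = b @ w                  # ... this "routed expert" work on the main stream
main.wait_stream(side)     # join before consuming the side-stream result
out = c + d
```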
```diff
@@ -875,6 +891,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
         enable_force_load_balance: bool = False,
+        shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
```
```diff
@@ -924,7 +941,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                **kwargs)
+                shared_experts=shared_experts)
         elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
```
```diff
@@ -1053,9 +1070,6 @@ class AscendFusedMoE(FusedMoE):
         self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
-        self.enable_multistream_shared_expert = \
-            ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
```
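The flag removed here is not gone from the config; the decision is just moved out of `AscendFusedMoE`. The calling model code now opts in by passing `shared_experts` to `forward()`. A hypothetical caller-side sketch (the real integration lives in the DeepSeek modeling code, outside these hunks; all names below are illustrative):

```python
# Hypothetical caller, e.g. a DeepSeek-style MoE block.
ascend_config = get_ascend_config()
use_multistream = (
    ascend_config.torchair_graph_config.enable_multistream_shared_expert
    and VLLM_ENABLE_MC2)

experts_out = self.experts(
    hidden_states,
    router_logits,
    is_prefill=is_prefill,
    top_k=self.top_k,
    shared_experts=self.shared_experts if use_multistream else None,
)
```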
```diff
@@ -1102,8 +1116,8 @@ class AscendFusedMoE(FusedMoE):
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None,
-                **kwargs):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[Any] = None):
         assert self.quant_method is not None
 
         if top_k:
```
```diff
@@ -1132,7 +1146,7 @@ class AscendFusedMoE(FusedMoE):
                 hidden_states, router_logits)
 
         # Matrix multiply.
-        hidden_states = self.quant_method.apply(
+        e_hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
```
```diff
@@ -1150,36 +1164,39 @@ class AscendFusedMoE(FusedMoE):
             enable_force_load_balance=enable_force_load_balance,
             log2phy=self.log2phy,
             global_redundant_expert_num=self.global_redundant_expert_num,
-            **kwargs)
+            shared_experts=shared_experts,
+        )
 
-        if self.enable_multistream_shared_expert and not is_prefill:
-            hidden_states, shared_output = hidden_states
+        if shared_experts is not None:
+            # Provide dummy implementation of "non-separated" shared experts.
+            if not isinstance(e_hidden_states, tuple):
+                return e_hidden_states, shared_experts(hidden_states)
+            else:
+                return e_hidden_states
 
         if self.dp_size > 1:
             if VLLM_ENABLE_MC2 and not is_prefill:
                 ...
             elif self.torchair_graph_enabled:
                 if USING_LCCL_COM:  # type: ignore
-                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                        hidden_states,
+                    e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        e_hidden_states,
                         "sum",
                         scatter_dim=0,
                         group=get_dp_group().device_group)
                 elif self.torchair_graph_enabled and not is_prefill:
-                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                        hidden_states,
+                    e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        e_hidden_states,
                         "sum",
                         scatter_dim=0,
                         group=get_dp_group().device_group)
                 else:
-                    hidden_states = get_ep_group().combine(hidden_states)
+                    e_hidden_states = get_ep_group().combine(e_hidden_states)
 
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+            e_hidden_states = tensor_model_parallel_all_reduce(e_hidden_states)
 
-        if self.enable_multistream_shared_expert and not is_prefill:
-            return hidden_states, shared_output
-        return hidden_states
+        return e_hidden_states
 
 
 # ----------------------------------------- TBO-related --------------------------------------------
```
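Taken together, the new return contract of `forward()` is: a plain tensor when `shared_experts` is `None`, and a `(routed, shared)` tuple otherwise, whether produced by the overlapped MC2 path or by the sequential fallback above. A consumer-side sketch (names hypothetical):

```python
out = moe(hidden_states, router_logits, is_prefill=False,
          shared_experts=shared_experts)
if shared_experts is None:
    final_hidden_states = out
else:
    routed, shared = out                   # always a 2-tuple in this branch
    final_hidden_states = routed + shared  # combined by the calling model
```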