[Perf][MoE] Improve MoE multistream parallel performance. (#1891)

This PR refines the shared-expert multi-stream parallelism of the
w8a8-dynamic-quantized MoE stage to achieve better performance: the shared
expert's quantization, gate_up projection, and activation are issued on a
secondary stream so they overlap the gate, top-k routing, and routed-expert
MLP on the main stream, as sketched below.

- vLLM version: v0.10.0
- vLLM main:
2cc571199b
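
At a high level, the commit relies on the dual-stream overlap pattern sketched below. This is a minimal illustration, not code from the PR: it assumes an Ascend NPU environment with `torch_npu` installed, uses the `npu_stream_switch`/`npu_wait_tensor` helpers from `vllm_ascend.utils`, and `gate`, `shared_experts`, and `routed_experts` are placeholder callables.

```python
import torch  # noqa: F401
import torch_npu  # noqa: F401  # registers the NPU device backend

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def moe_forward_with_overlap(hidden_states, gate, shared_experts, routed_experts):
    # Main stream: router logits for the routed experts.
    router_logits, _ = gate(hidden_states)

    with npu_stream_switch("moe_secondary", 0):
        # Secondary stream: fence on router_logits so the shared expert starts
        # only after the gate matmul, then runs concurrently with the routed
        # experts issued on the main stream below.
        npu_wait_tensor(hidden_states, router_logits)
        shared_output = shared_experts(hidden_states)

    # Main stream: routed-expert dispatch/MLP overlaps the shared expert.
    routed_output = routed_experts(hidden_states, router_logits)
    return routed_output + shared_output
```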

Signed-off-by: whx-sjtu <2952154980@qq.com>
Author: whx
Date: 2025-07-29 23:53:19 +08:00 (committed by GitHub)
Parent: 4df8e0027c
Commit: b6a7f07c70
3 changed files with 124 additions and 14 deletions

File: vllm_ascend/models/deepseek_v2.py

@@ -393,7 +393,7 @@ class CustomDeepseekV2MoE(nn.Module):

         # router_logits: (num_tokens, n_experts)
         router_logits = None
-        if not self.rm_router_logits:
+        if not self.rm_router_logits and not self.enable_multistream_moe:
             router_logits, _ = self.gate(hidden_states)

         experts_hidden_states = self.experts(

File: vllm_ascend/ops/fused_moe.py

@@ -1334,6 +1334,21 @@ class AscendFusedMoE(FusedMoE):
         forward_context = get_forward_context()
         fused_moe_state = forward_context.fused_moe_state
         mc2_mask = forward_context.mc2_mask
+        # For w8a8 dynamic, npu_dynamic_quant and the gate can run in parallel.
+        quantized_x_for_share, dynamic_scale_for_share = None, None
+        from vllm_ascend.quantization.w8a8_dynamic import \
+            AscendW8A8DynamicFusedMoEMethod
+        if self.enable_multistream_moe:
+            if not self.rm_router_logits:
+                router_logits, _ = gate(hidden_states)
+            if hasattr(self.quant_method, "quant_method") and \
+                isinstance(self.quant_method.quant_method,
+                           AscendW8A8DynamicFusedMoEMethod
+                           ) and fused_moe_state == FusedMoEState.MC2:
+                with npu_stream_switch("moe_secondary", 0):
+                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                        hidden_states)
+
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 # When all_reduce_merge is enabled, shared_experts skips the all_reduce in its MLP; a single all_reduce runs after both shared_experts and router_experts have finished.
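A note on the guard above: pre-quantizing on the secondary stream only pays off when the routed experts will actually consume int8 activations and the MC2 dispatch path is taken. Below is a minimal restatement of that predicate, where `want_pre_quant` is a hypothetical helper (not part of the PR) and the import paths are assumptions:

```python
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod


def want_pre_quant(quant_method, fused_moe_state) -> bool:
    # (a) the layer is w8a8-dynamic, so the int8 activations are consumed, and
    # (b) the MC2 path is active, which is where the overlap helps.
    inner = getattr(quant_method, "quant_method", None)
    return (isinstance(inner, AscendW8A8DynamicFusedMoEMethod)
            and fused_moe_state == FusedMoEState.MC2)
```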
@@ -1419,6 +1434,8 @@ class AscendFusedMoE(FusedMoE):
             shared_experts=shared_experts if self.torchair_graph_enabled
             and self.enable_multistream_moe and not is_prefill else None,
             mc2_mask=mc2_mask,
+            quantized_x_for_share=quantized_x_for_share,
+            dynamic_scale_for_share=dynamic_scale_for_share,
         )

         if shared_experts:

File: vllm_ascend/quantization/w8a8_dynamic.py

@@ -33,6 +33,82 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version)


+def apply_mlp_decode(hidden_states: torch.Tensor,
+                     w1: torch.Tensor,
+                     w1_scale: torch.Tensor,
+                     w2: torch.Tensor,
+                     w2_scale: torch.Tensor,
+                     group_list: torch.Tensor,
+                     dynamic_scale: Optional[torch.Tensor] = None,
+                     group_list_type: int = 1) -> torch.Tensor:
+    """
+    apply MLP: gate_up_proj -> swiglu -> down_proj
+
+    Args:
+        hidden_states: input hidden states with shape (num_tokens, hidden_size).
+        w1: expert weights1 with shape
+            (num_experts, hidden_size, intermediate_size * 2).
+        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2).
+        w2: expert weights2 with shape
+            (num_experts, intermediate_size, hidden_size).
+        w2_scale: weights2 scale with shape (num_experts, hidden_size).
+        group_list: number of tokens for each expert, with shape (num_experts,);
+            interpreted according to group_list_type.
+        dynamic_scale: per-token quantization scale of hidden_states; if None,
+            hidden_states is quantized here via npu_dynamic_quant.
+        group_list_type: forwarded to npu_grouped_matmul.
+
+        transpose_weight:
+            w1: (num_experts, intermediate_size * 2, hidden_size) ->
+                (num_experts, hidden_size, intermediate_size * 2)
+            w2: (num_experts, hidden_size, intermediate_size) ->
+                (num_experts, intermediate_size, hidden_size)
+
+    Returns:
+        hidden_states: output hidden states after MLP.
+    """
+    if dynamic_scale is None:
+        unquantized_hidden_states = hidden_states
+        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+            hidden_states)
+        # Dispose of the original unquantized hidden states
+        # to save NPU memory because they're no longer used.
+        dispose_tensor(unquantized_hidden_states)
+    else:
+        pertoken_scale = dynamic_scale
+
+    # gmm1: gate_up_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        split_item=3,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=torch.int32)[0]
+
+    # act_fn: swiglu
+    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=hidden_states,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )
+
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale],
+        per_token_scale=[swiglu_out_scale],
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+        output_dtype=w2_scale.dtype)[0]
+    return hidden_states
+
+
 def apply_mlp(hidden_states: torch.Tensor,
               w1: torch.Tensor,
               w1_scale: torch.Tensor,
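For orientation, here is a hypothetical shape walkthrough of `apply_mlp_decode`. All dimensions are made up, the per-expert-count reading of `group_list` under the default `group_list_type=1` is an assumption, and running it requires an Ascend NPU with `torch_npu`:

```python
import torch
import torch_npu  # noqa: F401

from vllm_ascend.quantization.w8a8_dynamic import apply_mlp_decode

E, T, H, I = 8, 64, 1024, 2048  # experts, tokens, hidden, intermediate

hidden_states = torch.randn(T, H, dtype=torch.bfloat16).npu()
w1 = torch.randint(-128, 127, (E, H, 2 * I), dtype=torch.int8).npu()  # gate_up, pre-transposed
w1_scale = torch.rand(E, 2 * I, dtype=torch.float32).npu()
w2 = torch.randint(-128, 127, (E, I, H), dtype=torch.int8).npu()  # down, pre-transposed
w2_scale = torch.rand(E, H, dtype=torch.bfloat16).npu()  # also selects the output dtype
group_list = torch.full((E,), T // E, dtype=torch.int64).npu()  # tokens per expert

out = apply_mlp_decode(hidden_states, w1, w1_scale, w2, w2_scale, group_list)
assert out.shape == (T, H) and out.dtype == w2_scale.dtype
```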
@@ -124,6 +200,8 @@ def fused_experts_with_mc2(
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
     mc2_mask: Optional[torch.Tensor] = None,
+    shared_gate_up: Optional[Any] = None,
+    shared_dequant_scale: Optional[Any] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     assert mc2_mask is not None
     if log2phy is not None:
@@ -186,18 +264,19 @@ def fused_experts_with_mc2(

     if shared_experts is not None:
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(quantized_x_for_share, expand_x)
+            npu_wait_tensor(shared_gate_up, expand_x)
             shared_act_out = shared_experts.act_fn(
-                (quantized_x_for_share, dynamic_scale_for_share))
+                (shared_gate_up, shared_dequant_scale))
             shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]

-    down_out_list = apply_mlp(expand_x,
-                              w1,
-                              w1_scale,
-                              w2,
-                              w2_scale,
-                              expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+    # `expand_x` will be disposed inside `apply_mlp_decode`.
+    down_out_list = apply_mlp_decode(expand_x,
+                                     w1,
+                                     w1_scale,
+                                     w2,
+                                     w2_scale,
+                                     expert_token_nums,
+                                     dynamic_scale=dynamic_scale)

     # moeCombine
     kwargs_mc2 = {
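This hunk is the last stage of a three-stage shared-expert pipeline (quantize -> gate_up_proj -> act_fn); the first two stages live in the `fused_moe.py` hunk above and the `apply` hunk below. Each stage is issued on the "moe_secondary" stream and fenced with `npu_wait_tensor` so it starts only once the main-stream tensor it overlaps with exists. A condensed, hypothetical view of the whole pipeline (the real code spreads these stages across functions):

```python
import torch_npu

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def shared_expert_pipeline(hidden_states, router_logits, expand_x, shared_experts):
    with npu_stream_switch("moe_secondary", 0):
        # Stage 1: per-token int8 quantization, overlapping the main-stream gate.
        q_x, q_scale = torch_npu.npu_dynamic_quant(hidden_states)

        # Stage 2: quantized gate_up projection, overlapping top-k routing;
        # fenced on router_logits so it starts after the gate matmul.
        npu_wait_tensor(q_x, router_logits)
        up_out, _ = shared_experts.gate_up_proj((q_x, q_scale))

        # Stage 3: SwiGLU activation, overlapping the routed experts' grouped
        # MLP; fenced on expand_x (the MC2-dispatched routed activations).
        npu_wait_tensor(up_out[0], expand_x)
        act_out = shared_experts.act_fn((up_out[0], up_out[1]))

    return act_out[0], act_out[1]  # shared_act, swiglu_out_scale
```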
@@ -745,6 +824,8 @@ class AscendW8A8DynamicFusedMoEMethod:
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
         shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -781,6 +862,16 @@ class AscendW8A8DynamicFusedMoEMethod:
             e_score_correction_bias=e_score_correction_bias,
         )

+        fused_moe_state = get_forward_context().fused_moe_state
+        shared_gate_up, shared_dequant_scale = None, None
+        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
+            with npu_stream_switch("moe_secondary", 0):
+                npu_wait_tensor(quantized_x_for_share, router_logits)
+                share_up_out, _ = shared_experts.gate_up_proj(
+                    (quantized_x_for_share, dynamic_scale_for_share))
+                shared_gate_up, shared_dequant_scale = share_up_out[
+                    0], share_up_out[1]
+
         # This is a naive implementation of expert load balancing, so as
         # to avoid accumulating too many tokens on a single rank.
         # Currently it is only activated during profile runs.
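What makes this split possible is the calling convention of the Ascend w8a8 layers: `gate_up_proj`, `act_fn`, and `down_proj` accept an already-quantized `(int8_tensor, per_token_scale)` pair, so each stage hands its quantized output straight to the next without a dequantize round-trip, even across function boundaries (here via the new `shared_gate_up`/`shared_dequant_scale` kwargs). A sketch under that assumption, with `hidden_states` and `shared_experts` as in the surrounding diff:

```python
import torch_npu

x_q, x_scale = torch_npu.npu_dynamic_quant(hidden_states)    # stage 1
up_out, _ = shared_experts.gate_up_proj((x_q, x_scale))      # stage 2
act_out = shared_experts.act_fn((up_out[0], up_out[1]))      # stage 3
out, _ = shared_experts.down_proj((act_out[0], act_out[1]))  # after MC2 combine
```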
@@ -788,8 +879,6 @@ class AscendW8A8DynamicFusedMoEMethod:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
             topk_weights = topk_weights.to(x.dtype)

-        fused_moe_state = get_forward_context().fused_moe_state
-
         if fused_moe_state == FusedMoEState.AllGatherEP:
             return fused_experts_with_allgather(
                 hidden_states=x,
@@ -806,7 +895,7 @@ class AscendW8A8DynamicFusedMoEMethod:
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                w1_scale=layer.w13_weight_scale,
+                w1_scale=layer.w13_weight_scale_fp32,
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
@@ -817,7 +906,9 @@ class AscendW8A8DynamicFusedMoEMethod:
                 global_redundant_expert_num=global_redundant_expert_num,
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
-                mc2_mask=kwargs.get("mc2_mask", None))
+                mc2_mask=kwargs.get("mc2_mask", None),
+                shared_gate_up=shared_gate_up,
+                shared_dequant_scale=shared_dequant_scale)
         elif fused_moe_state in [
                 FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
         ]:
@@ -860,6 +951,8 @@ class AscendW8A8DynamicFusedMoEMethod:
             torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
             layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
                 layer.w13_weight_scale.data.shape[0], -1)
+            layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
+                torch.float32)
             layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
                 layer.w13_weight_offset.data.shape[0], -1)
             layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
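
A likely motivation for the cached `w13_weight_scale_fp32` (an inference from how it is consumed, not stated in the PR): `apply_mlp_decode` feeds `w1_scale` to `npu_dequant_swiglu_quant` as `weight_scale`, which works in float32, and casting the `(num_experts, 2 * intermediate_size)` scale on every decode step would put an extra cast kernel on the hot path. Precomputing the cast once at weight-processing time removes it. A minimal, CPU-runnable sketch of the pattern with a stand-in layer:

```python
import torch


class Layer:
    """Stand-in for the quantized MoE layer (illustration only)."""

    def __init__(self, num_experts: int, intermediate_size: int):
        # Checkpoint scale as loaded, typically bf16.
        self.w13_weight_scale = torch.nn.Parameter(
            torch.rand(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
            requires_grad=False)


def process_weights_after_loading(layer: Layer) -> None:
    # One-time cast: keep an fp32 copy alongside the original scale so the
    # per-step kernels consume it directly instead of re-casting each step.
    layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(torch.float32)


layer = Layer(num_experts=8, intermediate_size=2048)
process_weights_after_loading(layer)
assert layer.w13_weight_scale_fp32.dtype == torch.float32
```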