[Feature] Support moe multi-stream for aclgraph. (#2946)
This PR puts the computation of the shared experts on a separate stream, overlapping it with the routed experts; a sketch of the overlap pattern follows the sign-off below.
- vLLM version: v0.10.2
- vLLM main: fbd6523ac0
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
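The overlap pattern itself is small: the shared-expert MLP is launched on a side NPU stream while the routed experts run on the default stream, with `wait_stream` calls on both ends so neither stream reads data before it is ready. Below is a minimal illustrative sketch (not code from this diff); it assumes an Ascend NPU with `torch_npu` installed, and `shared_experts`/`routed_experts` are placeholder callables standing in for the real modules.

```python
# Illustrative sketch of the multi-stream overlap used by AscendSharedFusedMoE.
# Assumes torch_npu is installed and an NPU device is available; the two
# callables passed in are placeholders for the real shared/routed expert modules.
import torch
import torch_npu  # noqa: F401  (registers the torch.npu backend)


def overlapped_forward(shared_experts, routed_experts, hidden_states,
                       shared_stream: torch.npu.Stream):
    # The side stream must not start before hidden_states is ready on the
    # default stream.
    shared_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(shared_stream):
        shared_out = shared_experts(hidden_states)  # runs on the side stream
    fused_out = routed_experts(hidden_states)       # runs on the default stream
    # The default stream must not consume shared_out before the side stream
    # has finished producing it.
    torch.npu.current_stream().wait_stream(shared_stream)
    return shared_out, fused_out
```

This is the same ordering that `AscendSharedFusedMoE.forward` expresses through `npu_stream_switch(..., enabled=...)`, which falls back to a `nullcontext()` so the extra stream is only used when `multistream_overlap_shared_expert` is enabled.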
@@ -61,6 +61,8 @@ class AscendConfig:
         self.enable_shared_expert_dp = additional_config.get(
             "enable_shared_expert_dp", False
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
+        self.multistream_overlap_shared_expert = additional_config.get(
+            "multistream_overlap_shared_expert", False)
         self.enable_prefetch = additional_config.get("enable_prefetch", False)
         self.lmhead_tensor_parallel_size = additional_config.get(
             "lmhead_tensor_parallel_size", None)
@@ -110,8 +112,6 @@ class TorchairGraphConfig:
             "graph_batch_sizes_init", False)
         self.enable_multistream_mla = torchair_graph_config.get(
             "enable_multistream_mla", False)
-        self.enable_multistream_moe = torchair_graph_config.get(
-            "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
             "enable_view_optimize", True)
         self.enable_frozen_parameter = torchair_graph_config.get(
@@ -148,10 +148,6 @@ class TorchairGraphConfig:
                 raise RuntimeError(
                     "enable_multistream_mla is valid only when Torchair graph mode is enabled"
                 )
-            if self.enable_multistream_moe:
-                raise RuntimeError(
-                    "enable_multistream_moe is valid only when Torchair graph mode is enabled"
-                )
             if self.enable_kv_nz:
                 raise RuntimeError(
                     "enable_kv_nz is valid only when Torchair graph mode is enabled"
@@ -37,7 +37,7 @@ from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
                                                  AlltoAllCommImpl, MC2CommImpl,
                                                  NaiveMulticastCommImpl)
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
@@ -426,24 +426,39 @@ class AscendSharedFusedMoE(AscendFusedMoE):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
         self.use_overlapped = use_overlapped
+        self.shared_expert_stream = None
+        ascend_config = get_ascend_config()
+        self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
+        if self.multistream_overlap_shared_expert:
+            self.shared_expert_stream = torch.npu.Stream()
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        shared_out = self._shared_experts(hidden_states)
+        # Make sure the shared experts stream begins after hidden_states are ready.
+        if self.multistream_overlap_shared_expert:
+            self.shared_expert_stream.wait_stream(  # type: ignore
+                torch.npu.current_stream())
+        with npu_stream_switch(self.shared_expert_stream,
+                               enabled=self.multistream_overlap_shared_expert):
+            # Use a separate stream to run shared experts.
+            shared_out = self._shared_experts(hidden_states)
 
-        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-        forward_context = get_forward_context()
-        moe_comm_method_name = forward_context.moe_comm_method_name
-        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
-            shared_out = tensor_model_parallel_all_reduce(shared_out)
+            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+            forward_context = get_forward_context()
+            moe_comm_method_name = forward_context.moe_comm_method_name
+            if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
+                shared_out = tensor_model_parallel_all_reduce(shared_out)
 
         fused_out = super().forward(
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
+        # Make sure the default stream waits for the shared experts stream to finish.
+        if self.multistream_overlap_shared_expert:
+            torch.npu.current_stream().wait_stream(self.shared_expert_stream)
         return shared_out, fused_out
 
 
@@ -322,8 +322,8 @@ class TorchairDeepseekV2MoE(nn.Module):
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and \
+        self.multistream_overlap_shared_expert = \
+            ascend_config.multistream_overlap_shared_expert and \
             self.torchair_graph_enabled
 
         self.gate = ReplicatedLinear(config.hidden_size,
@@ -364,7 +364,7 @@ class TorchairDeepseekV2MoE(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=reduce_results,
-                force_replicate=self.enable_multistream_moe
+                force_replicate=self.multistream_overlap_shared_expert
                 or enable_shared_expert_dp,
                 prefix=f"{prefix}.shared_experts",
             )
@@ -406,7 +406,7 @@ class TorchairDeepseekV2MoE(nn.Module):
 
         # router_logits: (num_tokens, n_experts)
         router_logits = None
-        if not self.rm_router_logits and not self.enable_multistream_moe:
+        if not self.rm_router_logits and not self.multistream_overlap_shared_expert:
             router_logits, _ = self.gate(hidden_states)
 
         experts_hidden_states = self.experts(
@@ -524,7 +524,7 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
         elif (config.n_routed_experts is not None
               and self.debug_layer_idx >= config.first_k_dense_replace
               and self.debug_layer_idx % config.moe_layer_freq == 0
-              and (ascend_config.torchair_graph_config.enable_multistream_moe
+              and (ascend_config.multistream_overlap_shared_expert
                    or self.enable_shared_expert_dp)):
             self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
                 self.num_heads * self.v_head_dim,
@@ -697,7 +697,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
             )
-            self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
+            self.mla_moe_communication = ascend_config.multistream_overlap_shared_expert \
                 and model_config.use_mla and self.tp_size > 1
         else:
             self.mlp = TorchairDeepseekV2MLP(
@@ -1049,8 +1049,8 @@ class TorchairAscendFusedMoE(FusedMoE):
         self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
 
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = \
-            ascend_config.torchair_graph_config.enable_multistream_moe and \
+        self.multistream_overlap_shared_expert = \
+            ascend_config.multistream_overlap_shared_expert and \
             self.torchair_graph_enabled
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
@@ -1148,7 +1148,7 @@ class TorchairAscendFusedMoE(FusedMoE):
         quantized_x_for_share, dynamic_scale_for_share = None, None
         from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
             TorchairAscendW8A8DynamicFusedMoEMethod
-        if self.enable_multistream_moe:
+        if self.multistream_overlap_shared_expert:
             if not self.rm_router_logits:
                 router_logits, _ = gate(hidden_states)
             if hasattr(self.quant_method, "quant_method") and \
@@ -1160,7 +1160,7 @@ class TorchairAscendFusedMoE(FusedMoE):
                 hidden_states)
 
         if shared_experts:
-            if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
+            if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2:
                 # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
                 shared_hidden_states = shared_experts(hidden_states)
 
@@ -1256,7 +1256,8 @@ class TorchairAscendFusedMoE(FusedMoE):
                 log2phy=self.log2phy,
                 global_redundant_expert_num=self.global_redundant_expert_num,
                 shared_experts=shared_experts if self.torchair_graph_enabled
-                and self.enable_multistream_moe and not is_prefill else None,
+                and self.multistream_overlap_shared_expert and not is_prefill else
+                None,
                 mc2_mask=mc2_mask,
                 quantized_x_for_share=quantized_x_for_share,
                 dynamic_scale_for_share=dynamic_scale_for_share,
@@ -21,7 +21,7 @@ import atexit
 import functools
 import math
 import os
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from enum import Enum
 from threading import Lock
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -321,7 +321,9 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV':
         # TODO: Find out whether we need to take into account the pp_size
         parallel_factor = 1 + num_comm_groups + int(
-            parallel_config.enable_expert_parallel)
+            parallel_config.enable_expert_parallel) + int(
+                vllm_config.additional_config.get(
+                    "multistream_overlap_shared_expert", False))
         if is_moe_model(vllm_config):
             parallel_factor += (parallel_config.data_parallel_size > 1)
         # Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device
@@ -617,3 +619,16 @@ def weak_ref_tensors(
     if isinstance(tensors, tuple):
         return tuple(weak_ref_tensor(t) for t in tensors)
     raise ValueError("Invalid type for tensors")
+
+
+def npu_stream_switch(target_stream: torch.npu.Stream,
+                      *,
+                      enabled: bool = True):
+    """
+    Switch to the target stream if enabled is True.
+    Otherwise, do nothing.
+    """
+    if not enabled:
+        return nullcontext()
+    assert target_stream is not None
+    return torch.npu.stream(target_stream)