[Cherry-pick] Port MoE multi-stream fix to v0.11.0-dev (#3753)

This PR moves the communication operation of the shared experts out of the extra
stream, because I found that it can cause rtMemcpy-related errors when running
the shared-experts multistream together with aclgraph.

Furthermore, I use a global variable as the extra stream object, so that
full-graph mode does not allocate a separate stream for every layer.
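As a minimal sketch of the overall pattern (the helper names below are illustrative, not the actual vllm-ascend API): shared-expert compute runs on one lazily created, process-wide side stream, while the all-reduce stays on the default stream.

```python
# Illustrative sketch only; assumes a torch_npu environment where the
# torch.npu stream APIs are available. Names do not match vllm-ascend.
import torch
import torch_npu  # noqa: F401  # registers the torch.npu backend

_SIDE_STREAM = None


def _side_stream() -> torch.npu.Stream:
    # One process-wide stream instead of one per MoE layer, which matters
    # when every layer is captured into a single full graph.
    global _SIDE_STREAM
    if _SIDE_STREAM is None:
        _SIDE_STREAM = torch.npu.Stream()
    return _SIDE_STREAM


def shared_expert_forward(shared_experts, hidden_states, all_reduce):
    # The side stream must not start before hidden_states are produced.
    _side_stream().wait_stream(torch.npu.current_stream())
    with torch.npu.stream(_side_stream()):
        # Compute only: communication launched from this stream is what
        # triggered the rtMemcpy errors under aclgraph.
        out = shared_experts(hidden_states)
    # The default stream waits for the side stream before touching `out`.
    torch.npu.current_stream().wait_stream(_side_stream())
    # Communication runs back on the default stream.
    return all_reduce(out)
```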

Signed-off-by: whx-sjtu <2952154980@qq.com>
whx
2025-10-25 15:51:43 +08:00
committed by GitHub
parent 1bc61031e5
commit a58ff9e92f
3 changed files with 25 additions and 13 deletions

View File

@@ -28,7 +28,7 @@ from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
MODELS = [
"Qwen/Qwen3-0.6B",
"vllm-ascend/DeepSeek-V2-Lite-W8A8",
]

View File

@@ -40,7 +40,8 @@ from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
                                is_enable_nz, npu_stream_switch,
-                               shared_expert_dp_enabled)
+                               shared_expert_dp_enabled,
+                               shared_experts_compute_stream)
 
 
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -419,8 +420,6 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         self.shared_expert_stream = None
         ascend_config = get_ascend_config()
         self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
-        if self.multistream_overlap_shared_expert:
-            self.shared_expert_stream = torch.npu.Stream()
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -442,19 +441,15 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
                 router_logits: torch.Tensor):
         # Make sure the shared experts stream begins after hidden_states are ready.
         if self.multistream_overlap_shared_expert:
-            self.shared_expert_stream.wait_stream(  # type: ignore
+            shared_experts_compute_stream().wait_stream(  # type: ignore
                 torch.npu.current_stream())
-        with npu_stream_switch(self.shared_expert_stream,
+        with npu_stream_switch(shared_experts_compute_stream(),
                                enabled=self.multistream_overlap_shared_expert):
             # Use a separate stream to run shared experts.
+            # Note that currently we only support calculations in separate streams with aclgraph.
+            # Communication operations in another stream might cause unknown errors.
             shared_out = self._shared_experts(hidden_states)
-            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-            forward_context = get_forward_context()
-            moe_comm_type = forward_context.moe_comm_type
-            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
-                    and not shared_expert_dp_enabled():
-                shared_out = tensor_model_parallel_all_reduce(shared_out)
 
         fused_output = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
@@ -462,5 +457,12 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         )
         # Make sure the default stream waits for the shared experts stream to finish.
         if self.multistream_overlap_shared_expert:
-            torch.npu.current_stream().wait_stream(self.shared_expert_stream)
+            torch.npu.current_stream().wait_stream(
+                shared_experts_compute_stream())
+        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+        forward_context = get_forward_context()
+        moe_comm_type = forward_context.moe_comm_type
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                and not shared_expert_dp_enabled():
+            shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_output
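The hunk above relies on npu_stream_switch(stream, enabled=...) from vllm_ascend.utils, whose body is not part of this diff. A rough sketch of what such a helper might look like (an assumption, not the actual implementation) is a context manager that no-ops when overlap is disabled:

```python
# Hypothetical stand-in for vllm_ascend.utils.npu_stream_switch; the real
# helper may differ. Assumes torch_npu, so torch.npu.stream() exists.
from contextlib import contextmanager

import torch


@contextmanager
def npu_stream_switch_sketch(stream, enabled: bool = True):
    if not enabled or stream is None:
        # Overlap disabled: stay on the current stream, behaviour unchanged.
        yield
        return
    # Run the enclosed ops on the side stream; callers still issue the
    # wait_stream() calls before and after, as the diff above does.
    with torch.npu.stream(stream):
        yield
```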

View File

@@ -52,6 +52,7 @@ _IS_310P = None
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None
+_SHARED_EXPERTS_COMPUTE_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
@@ -259,6 +260,15 @@ def prefetch_stream() -> torch.npu.Stream:
     return _PREFETCH_STREAM
 
 
+def shared_experts_compute_stream() -> torch.npu.Stream:
+    global _SHARED_EXPERTS_COMPUTE_STREAM
+    if _SHARED_EXPERTS_COMPUTE_STREAM is None:
+        # Lazily create a single global stream so that full-graph mode
+        # does not allocate a separate stream for every MoE layer.
+        _SHARED_EXPERTS_COMPUTE_STREAM = torch_npu.npu.Stream()
+    return _SHARED_EXPERTS_COMPUTE_STREAM
+
+
 def adapt_patch(is_global_patch: bool = False):
     if is_global_patch:
         from vllm_ascend.patch import platform  # noqa: F401
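With the accessor in place, every MoE layer asks for the same stream object, so full-graph capture does not multiply stream allocations. A quick usage check (hypothetical, assuming an Ascend/torch_npu runtime) would be:

```python
# Hypothetical usage check on an Ascend/torch_npu machine.
from vllm_ascend.utils import shared_experts_compute_stream

s1 = shared_experts_compute_stream()
s2 = shared_experts_compute_stream()
assert s1 is s2  # one lazily created stream is shared by all layers
```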