[Feature] Support moe multi-stream for aclgraph. (#2946)

This PR puts the calculation of shared experts into a separate stream,
overlapping with routing experts.

- vLLM version: v0.10.2
- vLLM main:
fbd6523ac0

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-09-19 11:06:45 +08:00
committed by GitHub
parent 0c04bf1e36
commit 0a526768f5
14 changed files with 170 additions and 49 deletions

View File

@@ -195,8 +195,8 @@ msgid ""
msgstr "是否将MLA的向量操作放到另一个流中。此选项仅对使用MLA的模型例如DeepSeek有效。" msgstr "是否将MLA的向量操作放到另一个流中。此选项仅对使用MLA的模型例如DeepSeek有效。"
#: ../../user_guide/configuration/additional_config.md #: ../../user_guide/configuration/additional_config.md
msgid "`enable_multistream_moe`" msgid "`multistream_overlap_shared_expert`"
msgstr "`enable_multistream_moe`" msgstr "`multistream_overlap_shared_expert`"
#: ../../user_guide/configuration/additional_config.md #: ../../user_guide/configuration/additional_config.md
msgid "" msgid ""

View File

@@ -35,6 +35,7 @@ The following table lists the additional configuration options available in vLLM
| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. | | `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. | | `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
| `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. | | `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. |
| `multistream_overlap_shared_expert`| bool | `False` | Whether to enable multistream shared expert. This option only takes effect on moe models with shared experts. |
The details of each config option are as follows: The details of each config option are as follows:
@@ -45,7 +46,6 @@ The details of each config option are as follows:
| `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE are supported to use torchair graph mode | | `enabled` | bool | `False` | Whether to enable torchair graph mode. Currently only DeepSeek series models and PanguProMoE are supported to use torchair graph mode |
| `mode` | str | `None` | When using reduce-overhead mode for torchair, mode needs to be set | | `mode` | str | `None` | When using reduce-overhead mode for torchair, mode needs to be set |
| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream. This option only takes effects on models using MLA (e.g., DeepSeek). | | `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream. This option only takes effects on models using MLA (e.g., DeepSeek). |
| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert. This option only takes effects on DeepSeek moe models. |
| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization | | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
| `enable_frozen_parameter` | bool | `True` | Whether to fix the memory address of weights during inference to reduce the input address refresh time during graph execution. | | `enable_frozen_parameter` | bool | `True` | Whether to fix the memory address of weights during inference to reduce the input address refresh time during graph execution. |
| `use_cached_graph` | bool | `False` | Whether to use cached graph | | `use_cached_graph` | bool | `False` | Whether to use cached graph |
@@ -74,13 +74,13 @@ An example of additional configuration is as follows:
"use_cached_graph": True, "use_cached_graph": True,
"graph_batch_sizes": [1, 2, 4, 8], "graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": False, "graph_batch_sizes_init": False,
"enable_multistream_moe": False,
"enable_kv_nz": False "enable_kv_nz": False
}, },
"ascend_scheduler_config": { "ascend_scheduler_config": {
"enabled": True, "enabled": True,
"enable_chunked_prefill": True, "enable_chunked_prefill": True,
}, },
"multistream_overlap_shared_expert": True,
"refresh": False, "refresh": False,
} }
``` ```

View File

@@ -43,4 +43,4 @@ vllm serve model_path \
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector" "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
}' \ }' \
--additional-config \ --additional-config \
'{"ascend_scheduler_config": {"enabled": true}, "torchair_graph_config":{"enabled":true,"enable_kv_nz":false, "enable_multistream_moe":false, "graph_batch_size":[28]}, "enable_weight_nz_layout":true}' '{"ascend_scheduler_config": {"enabled": true}, "torchair_graph_config":{"enabled":true,"enable_kv_nz":false, "graph_batch_size":[28]}, "enable_weight_nz_layout":true, "enable_multistream_moe":false}'

View File

@@ -29,4 +29,4 @@ vllm serve Qwen/Qwen1.5-MoE-A2.7B \
--gpu-memory-utilization 0.9 \ --gpu-memory-utilization 0.9 \
--trust-remote-code \ --trust-remote-code \
--enforce-eager \ --enforce-eager \
--additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "enable_multistream_moe":false, "use_cached_graph":false}}' --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "use_cached_graph":false}}'

View File

@@ -66,8 +66,8 @@ def test_models_distributed_DeepSeek_multistream_moe():
additional_config={ additional_config={
"torchair_graph_config": { "torchair_graph_config": {
"enabled": True, "enabled": True,
"enable_multistream_moe": True,
}, },
"enable_multistream_moe": True,
"ascend_scheduler_config": { "ascend_scheduler_config": {
"enabled": True, "enabled": True,
}, },

View File

@@ -0,0 +1,103 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with multistream_overlap_shared_expert
enabled and disabled.
Run `pytest tests/e2e/singlecard/test_multistream_overlap_shared_expert.py`.
"""
import pytest
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
# Models exercised by this e2e test.
MODELS = ["Qwen/Qwen3-0.6B"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models_with_multistream_overlap_shared_expert(
    model: str,
    max_tokens: int,
) -> None:
    """Check that enabling multistream_overlap_shared_expert does not change
    model outputs.

    Runs three configurations — overlap enabled with eager mode, overlap
    enabled with aclgraph (non-eager) mode, and a plain eager baseline — and
    asserts all three produce identical generations.
    """
    prompts = [
        "Hello, my name is", "The president of the United States is",
        "The capital of France is", "The future of AI is"
    ]
    # temperature=0.0 -> greedy decoding, so repeated runs are deterministic
    # and directly comparable.
    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)

    def _generate(enforce_eager: bool, overlap: bool) -> list:
        """Run one vLLM instance and return (index, text) per prompt.

        When ``overlap`` is False, ``additional_config`` is omitted entirely
        (not passed as None) to keep VllmRunner's default behavior.
        """
        kwargs = {}
        if overlap:
            kwargs["additional_config"] = {
                "multistream_overlap_shared_expert": True,
            }
        with VllmRunner(
                model,
                max_model_len=1024,
                enforce_eager=enforce_eager,
                **kwargs,
        ) as runner:
            outputs = runner.model.generate(prompts, sampling_params)
        return [(out.outputs[0].index, out.outputs[0].text)
                for out in outputs]

    vllm_moe_ms_eager_outputs_list = _generate(enforce_eager=True,
                                               overlap=True)
    vllm_moe_ms_aclgraph_outputs_list = _generate(enforce_eager=False,
                                                  overlap=True)
    vllm_eager_outputs_list = _generate(enforce_eager=True, overlap=False)

    check_outputs_equal(
        outputs_0_lst=vllm_eager_outputs_list,
        outputs_1_lst=vllm_moe_ms_eager_outputs_list,
        name_0="vllm_eager_outputs",
        name_1="vllm_moe_ms_eager_outputs",
    )
    check_outputs_equal(
        outputs_0_lst=vllm_eager_outputs_list,
        outputs_1_lst=vllm_moe_ms_aclgraph_outputs_list,
        name_0="vllm_eager_outputs",
        name_1="vllm_moe_ms_aclgraph_outputs",
    )

View File

@@ -94,7 +94,8 @@ def mock_dist_env(mocker: MockerFixture):
return_value=mock_dp_and_tp_group(mocker)), \ return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config', patch('vllm_ascend.ops.fused_moe.get_ascend_config',
return_value=MagicMock( return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False), torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
expert_map_path=None expert_map_path=None
)), \ )), \
patch('vllm_ascend.ops.fused_moe.determine_expert_map', patch('vllm_ascend.ops.fused_moe.determine_expert_map',

View File

@@ -43,6 +43,7 @@ class TestAscendConfig(TestBase):
# No additional config given, check the default value here. # No additional config given, check the default value here.
ascend_config = init_ascend_config(test_vllm_config) ascend_config = init_ascend_config(test_vllm_config)
self.assertIsNone(ascend_config.expert_map_path) self.assertIsNone(ascend_config.expert_map_path)
self.assertFalse(ascend_config.multistream_overlap_shared_expert)
torchair_graph_config = ascend_config.torchair_graph_config torchair_graph_config = ascend_config.torchair_graph_config
self.assertFalse(torchair_graph_config.enabled) self.assertFalse(torchair_graph_config.enabled)
@@ -51,7 +52,6 @@ class TestAscendConfig(TestBase):
self.assertEqual(torchair_graph_config.graph_batch_sizes, []) self.assertEqual(torchair_graph_config.graph_batch_sizes, [])
self.assertFalse(torchair_graph_config.graph_batch_sizes_init) self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
self.assertFalse(torchair_graph_config.enable_multistream_mla) self.assertFalse(torchair_graph_config.enable_multistream_mla)
self.assertFalse(torchair_graph_config.enable_multistream_moe)
self.assertTrue(torchair_graph_config.enable_view_optimize) self.assertTrue(torchair_graph_config.enable_view_optimize)
self.assertTrue(torchair_graph_config.enable_frozen_parameter) self.assertTrue(torchair_graph_config.enable_frozen_parameter)
self.assertFalse(torchair_graph_config.enable_kv_nz) self.assertFalse(torchair_graph_config.enable_kv_nz)
@@ -69,11 +69,11 @@ class TestAscendConfig(TestBase):
"graph_batch_sizes": [1, 2, 4], "graph_batch_sizes": [1, 2, 4],
"graph_batch_sizes_init": False, "graph_batch_sizes_init": False,
"enable_multistream_mla": True, "enable_multistream_mla": True,
"enable_multistream_moe": True,
"enable_view_optimize": True, "enable_view_optimize": True,
"enable_frozen_parameter": True, "enable_frozen_parameter": True,
"enable_kv_nz": True "enable_kv_nz": True
}, },
"multistream_overlap_shared_expert": True,
"ascend_scheduler_config": { "ascend_scheduler_config": {
"enabled": True "enabled": True
}, },
@@ -82,6 +82,7 @@ class TestAscendConfig(TestBase):
} }
ascend_config = init_ascend_config(test_vllm_config) ascend_config = init_ascend_config(test_vllm_config)
self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
self.assertTrue(ascend_config.multistream_overlap_shared_expert)
torchair_graph_config = ascend_config.torchair_graph_config torchair_graph_config = ascend_config.torchair_graph_config
self.assertTrue(torchair_graph_config.enabled) self.assertTrue(torchair_graph_config.enabled)
@@ -89,7 +90,6 @@ class TestAscendConfig(TestBase):
self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4]) self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4])
self.assertFalse(torchair_graph_config.graph_batch_sizes_init) self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
self.assertTrue(torchair_graph_config.enable_multistream_mla) self.assertTrue(torchair_graph_config.enable_multistream_mla)
self.assertTrue(torchair_graph_config.enable_multistream_moe)
self.assertTrue(torchair_graph_config.enable_view_optimize) self.assertTrue(torchair_graph_config.enable_view_optimize)
self.assertTrue(torchair_graph_config.enable_frozen_parameter) self.assertTrue(torchair_graph_config.enable_frozen_parameter)
self.assertTrue(torchair_graph_config.enable_kv_nz) self.assertTrue(torchair_graph_config.enable_kv_nz)
@@ -306,17 +306,6 @@ class TestAscendConfig(TestBase):
} }
init_ascend_config(test_vllm_config) init_ascend_config(test_vllm_config)
# enable_multistream_moe should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"enable_multistream_moe": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# mode should not be configured without torchair graph mode # mode should not be configured without torchair graph mode
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = { test_vllm_config.additional_config = {

View File

@@ -70,7 +70,8 @@ def mock_dist_env(mocker: MockerFixture):
return_value=mock_dp_and_tp_group(mocker)), \ return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config', patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config',
return_value=MagicMock( return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False), torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
expert_map_path=None expert_map_path=None
)), \ )), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map', patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map',

View File

@@ -61,6 +61,8 @@ class AscendConfig:
self.enable_shared_expert_dp = additional_config.get( self.enable_shared_expert_dp = additional_config.get(
"enable_shared_expert_dp", False "enable_shared_expert_dp", False
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
self.multistream_overlap_shared_expert = additional_config.get(
"multistream_overlap_shared_expert", False)
self.enable_prefetch = additional_config.get("enable_prefetch", False) self.enable_prefetch = additional_config.get("enable_prefetch", False)
self.lmhead_tensor_parallel_size = additional_config.get( self.lmhead_tensor_parallel_size = additional_config.get(
"lmhead_tensor_parallel_size", None) "lmhead_tensor_parallel_size", None)
@@ -110,8 +112,6 @@ class TorchairGraphConfig:
"graph_batch_sizes_init", False) "graph_batch_sizes_init", False)
self.enable_multistream_mla = torchair_graph_config.get( self.enable_multistream_mla = torchair_graph_config.get(
"enable_multistream_mla", False) "enable_multistream_mla", False)
self.enable_multistream_moe = torchair_graph_config.get(
"enable_multistream_moe", False)
self.enable_view_optimize = torchair_graph_config.get( self.enable_view_optimize = torchair_graph_config.get(
"enable_view_optimize", True) "enable_view_optimize", True)
self.enable_frozen_parameter = torchair_graph_config.get( self.enable_frozen_parameter = torchair_graph_config.get(
@@ -148,10 +148,6 @@ class TorchairGraphConfig:
raise RuntimeError( raise RuntimeError(
"enable_multistream_mla is valid only when Torchair graph mode is enabled" "enable_multistream_mla is valid only when Torchair graph mode is enabled"
) )
if self.enable_multistream_moe:
raise RuntimeError(
"enable_multistream_moe is valid only when Torchair graph mode is enabled"
)
if self.enable_kv_nz: if self.enable_kv_nz:
raise RuntimeError( raise RuntimeError(
"enable_kv_nz is valid only when Torchair graph mode is enabled" "enable_kv_nz is valid only when Torchair graph mode is enabled"

View File

@@ -37,7 +37,7 @@ from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl, from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl, MC2CommImpl, AlltoAllCommImpl, MC2CommImpl,
NaiveMulticastCommImpl) NaiveMulticastCommImpl)
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -426,24 +426,39 @@ class AscendSharedFusedMoE(AscendFusedMoE):
super().__init__(**kwargs) super().__init__(**kwargs)
self._shared_experts = shared_experts self._shared_experts = shared_experts
self.use_overlapped = use_overlapped self.use_overlapped = use_overlapped
self.shared_expert_stream = None
ascend_config = get_ascend_config()
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
if self.multistream_overlap_shared_expert:
self.shared_expert_stream = torch.npu.Stream()
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = self._shared_experts(hidden_states) # Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
self.shared_expert_stream.wait_stream( # type: ignore
torch.npu.current_stream())
with npu_stream_switch(self.shared_expert_stream,
enabled=self.multistream_overlap_shared_expert):
# Use a separate stream to run shared experts.
shared_out = self._shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context() forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name moe_comm_method_name = forward_context.moe_comm_method_name
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}: if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
shared_out = tensor_model_parallel_all_reduce(shared_out) shared_out = tensor_model_parallel_all_reduce(shared_out)
fused_out = super().forward( fused_out = super().forward(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
) )
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
return shared_out, fused_out return shared_out, fused_out

View File

@@ -322,8 +322,8 @@ class TorchairDeepseekV2MoE(nn.Module):
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_moe = \ self.multistream_overlap_shared_expert = \
ascend_config.torchair_graph_config.enable_multistream_moe and \ ascend_config.multistream_overlap_shared_expert and \
self.torchair_graph_enabled self.torchair_graph_enabled
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
@@ -364,7 +364,7 @@ class TorchairDeepseekV2MoE(nn.Module):
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results, reduce_results=reduce_results,
force_replicate=self.enable_multistream_moe force_replicate=self.multistream_overlap_shared_expert
or enable_shared_expert_dp, or enable_shared_expert_dp,
prefix=f"{prefix}.shared_experts", prefix=f"{prefix}.shared_experts",
) )
@@ -406,7 +406,7 @@ class TorchairDeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = None router_logits = None
if not self.rm_router_logits and not self.enable_multistream_moe: if not self.rm_router_logits and not self.multistream_overlap_shared_expert:
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
experts_hidden_states = self.experts( experts_hidden_states = self.experts(
@@ -524,7 +524,7 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
elif (config.n_routed_experts is not None elif (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0 and self.debug_layer_idx % config.moe_layer_freq == 0
and (ascend_config.torchair_graph_config.enable_multistream_moe and (ascend_config.multistream_overlap_shared_expert
or self.enable_shared_expert_dp)): or self.enable_shared_expert_dp)):
self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce( self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim, self.num_heads * self.v_head_dim,
@@ -697,7 +697,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
) )
self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \ self.mla_moe_communication = ascend_config.multistream_overlap_shared_expert \
and model_config.use_mla and self.tp_size > 1 and model_config.use_mla and self.tp_size > 1
else: else:
self.mlp = TorchairDeepseekV2MLP( self.mlp = TorchairDeepseekV2MLP(

View File

@@ -1049,8 +1049,8 @@ class TorchairAscendFusedMoE(FusedMoE):
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64) self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_moe = \ self.multistream_overlap_shared_expert = \
ascend_config.torchair_graph_config.enable_multistream_moe and \ ascend_config.multistream_overlap_shared_expert and \
self.torchair_graph_enabled self.torchair_graph_enabled
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
@@ -1148,7 +1148,7 @@ class TorchairAscendFusedMoE(FusedMoE):
quantized_x_for_share, dynamic_scale_for_share = None, None quantized_x_for_share, dynamic_scale_for_share = None, None
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
TorchairAscendW8A8DynamicFusedMoEMethod TorchairAscendW8A8DynamicFusedMoEMethod
if self.enable_multistream_moe: if self.multistream_overlap_shared_expert:
if not self.rm_router_logits: if not self.rm_router_logits:
router_logits, _ = gate(hidden_states) router_logits, _ = gate(hidden_states)
if hasattr(self.quant_method, "quant_method") and \ if hasattr(self.quant_method, "quant_method") and \
@@ -1160,7 +1160,7 @@ class TorchairAscendFusedMoE(FusedMoE):
hidden_states) hidden_states)
if shared_experts: if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2:
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states) shared_hidden_states = shared_experts(hidden_states)
@@ -1256,7 +1256,8 @@ class TorchairAscendFusedMoE(FusedMoE):
log2phy=self.log2phy, log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num, global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=shared_experts if self.torchair_graph_enabled shared_experts=shared_experts if self.torchair_graph_enabled
and self.enable_multistream_moe and not is_prefill else None, and self.multistream_overlap_shared_expert and not is_prefill else
None,
mc2_mask=mc2_mask, mc2_mask=mc2_mask,
quantized_x_for_share=quantized_x_for_share, quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share, dynamic_scale_for_share=dynamic_scale_for_share,

View File

@@ -21,7 +21,7 @@ import atexit
import functools import functools
import math import math
import os import os
from contextlib import contextmanager from contextlib import contextmanager, nullcontext
from enum import Enum from enum import Enum
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
@@ -321,7 +321,9 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV': if os.getenv("HCCL_OP_EXPANSION_MODE") == 'AIV':
# TODO: Find out whether we need to take into account the pp_size # TODO: Find out whether we need to take into account the pp_size
parallel_factor = 1 + num_comm_groups + int( parallel_factor = 1 + num_comm_groups + int(
parallel_config.enable_expert_parallel) parallel_config.enable_expert_parallel) + int(
vllm_config.additional_config.get(
"multistream_overlap_shared_expert", False))
if is_moe_model(vllm_config): if is_moe_model(vllm_config):
parallel_factor += (parallel_config.data_parallel_size > 1) parallel_factor += (parallel_config.data_parallel_size > 1)
# Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device # Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device
@@ -617,3 +619,16 @@ def weak_ref_tensors(
if isinstance(tensors, tuple): if isinstance(tensors, tuple):
return tuple(weak_ref_tensor(t) for t in tensors) return tuple(weak_ref_tensor(t) for t in tensors)
raise ValueError("Invalid type for tensors") raise ValueError("Invalid type for tensors")
def npu_stream_switch(target_stream: "torch.npu.Stream",
                      *,
                      enabled: bool = True):
    """Return a context manager that switches to ``target_stream``.

    When ``enabled`` is False this is a no-op (``nullcontext``), letting
    callers unconditionally wrap code in ``with npu_stream_switch(...)`` and
    toggle the multi-stream behavior with a single flag.

    Args:
        target_stream: The NPU stream to switch to. Must not be None when
            ``enabled`` is True.
        enabled: Whether the stream switch should actually happen.

    Returns:
        A context manager: ``torch.npu.stream(target_stream)`` when enabled,
        ``contextlib.nullcontext()`` otherwise.

    Raises:
        ValueError: If ``enabled`` is True and ``target_stream`` is None.

    Note:
        The annotation is a string so the module can be imported before
        ``torch.npu`` is available — TODO confirm against the plugin's
        import order.
    """
    if not enabled:
        return nullcontext()
    if target_stream is None:
        # Explicit error instead of `assert`: asserts are stripped under
        # `python -O`, which would turn this misuse into an opaque failure
        # inside torch.npu.stream().
        raise ValueError("target_stream must not be None when enabled=True")
    return torch.npu.stream(target_stream)