From 30c5d947c3e5183b1d46bb7dd34732498cadf1e9 Mon Sep 17 00:00:00 2001
From: Wang Yixuan <88923622+hust17yixuan@users.noreply.github.com>
Date: Thu, 9 Oct 2025 19:00:32 +0800
Subject: [PATCH] [bugfix]fix multistream moe in torchair (#3164)

### What this PR does / why we need it?
The multistream MoE path in torchair is only validated for decode and cannot be applied to chunked prefill, so this PR adds guards to isolate that scenario (a simplified sketch of the gating logic follows the diff).

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: hust17yixuan <303660421@qq.com>
---
 tests/ut/torchair/models/test_torchair_deepseek_v2.py | 4 ++--
 vllm_ascend/torchair/models/torchair_deepseek_v2.py   | 7 ++++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/tests/ut/torchair/models/test_torchair_deepseek_v2.py b/tests/ut/torchair/models/test_torchair_deepseek_v2.py
index 5a7c2a2..0231ab9 100644
--- a/tests/ut/torchair/models/test_torchair_deepseek_v2.py
+++ b/tests/ut/torchair/models/test_torchair_deepseek_v2.py
@@ -176,7 +176,7 @@ def test_torchair_deepseek_v2_merged_replicated_linear(mock_distributed):
         TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
         TorchairDeepseekV2RowParallelLinear
     ])
-def test_row_parallel_linear(cls, mock_distributed):
+def test_row_parallel_linear(cls, mock_distributed, mock_forward_context):
     linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
     linear.quant_method = Mock()
     linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
@@ -282,7 +282,7 @@ def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
                                             mock_maybe_wait_prefetch_done,
                                             mock_rms_norm, mock_add_norm,
                                             mock_distributed, base_config,
-                                            vllm_config):
+                                            vllm_config, mock_forward_context):
     mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
     mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
                                   torch.randn(2, 128))
diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py
index 8cf6e24..371b1c9 100644
--- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py
+++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py
@@ -161,12 +161,13 @@ class TorchairDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
         output_parallel = self.quant_method.apply(self,
                                                   input_parallel,
                                                   bias=bias_)
+        forward_context = get_forward_context()
         if self.reduce_results and self.tp_size > 1:
             num_tokens = output_parallel.shape[0]
             if is_force_scatter and num_tokens % self.tp_size:
                 output_parallel = nn.functional.pad(
                     output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
-            if is_force_scatter or (not is_prefill
+            if is_force_scatter or (not forward_context.with_prefill
                                     and output_parallel.shape[0] %
                                     self.tp_size == 0):
                 output = tensor_model_parallel_reduce_scatter(output_parallel,
@@ -945,15 +946,15 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         replace_allreduce: bool = False,
     ) -> torch.Tensor:
         # Self Attention
+        forward_context = get_forward_context()
         if attn_metadata is not None:
             decoding_condition_met = (
                 not attn_metadata.is_prefill if self.use_sfa else
-                attn_metadata.num_decodes > 0 if self.use_mla else False)
+                not forward_context.with_prefill if self.use_mla else False)
             mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
         else:
             mla_moe_communication = False

-        forward_context = get_forward_context()
         if (envs.VLLM_ASCEND_ENABLE_MLAPO and
                 isinstance(self.self_attn, TorchairDeepseekV2SFAAttention) and attn_metadata is not None
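
For reviewers, here is a minimal, self-contained sketch of the gating change. `ForwardContext`, `AttnMetadata`, and `decoding_condition_met` below are simplified illustrative stand-ins, not the real vLLM / vllm-ascend classes; only the decision logic mirrors the patched expression in `TorchairDeepseekV2DecoderLayer.forward`.

```python
# Illustrative sketch only: ForwardContext and AttnMetadata are simplified
# stand-ins for the real vLLM structures; decoding_condition_met mirrors the
# inline conditional expression in TorchairDeepseekV2DecoderLayer.forward.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ForwardContext:
    # True whenever the step contains prefill tokens, including chunked
    # prefill; False only for pure decode steps.
    with_prefill: bool


@dataclass
class AttnMetadata:
    is_prefill: bool
    num_decodes: int


def decoding_condition_met(forward_context: ForwardContext,
                           attn_metadata: Optional[AttnMetadata],
                           use_sfa: bool, use_mla: bool) -> bool:
    """Decide whether the decode-only multistream MoE path may run."""
    if attn_metadata is None:
        return False
    if use_sfa:
        return not attn_metadata.is_prefill
    if use_mla:
        # Before this patch the MLA branch checked
        # attn_metadata.num_decodes > 0, which is also true for chunked
        # prefill batches that mix decode requests with prefill chunks.
        # Gating on with_prefill restricts the multistream MoE path to
        # pure decode steps.
        return not forward_context.with_prefill
    return False


# A chunked-prefill batch has decodes in flight but must not take the path.
mixed = AttnMetadata(is_prefill=False, num_decodes=4)
assert not decoding_condition_met(ForwardContext(with_prefill=True), mixed,
                                  use_sfa=False, use_mla=True)
# A pure decode step still qualifies.
assert decoding_condition_met(ForwardContext(with_prefill=False), mixed,
                              use_sfa=False, use_mla=True)
print("gating behaves as expected")
```

The same `with_prefill` flag drives the reduce-scatter guard in `TorchairDeepseekV2RowParallelLinearReplaceAllreduce`, so both call sites now agree on what counts as a decode-only step.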