From 749e24f81eeda8b6750fe98728af2487cf9a564d Mon Sep 17 00:00:00 2001 From: Qiu Date: Fri, 23 Jan 2026 14:19:49 +0800 Subject: [PATCH] [bugfix] align max_num_batched_tokens with tp*pcp when using FLASHCOMM1 (#6000) ### What this PR does / why we need it? Align max_num_batched_tokens with tp*pcp when using FLASHCOMM1 to avoid assert error in `NPUModelRunner._dummy_run`. - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060 --------- Signed-off-by: QiuChunshuo --- tests/ut/attention/test_mla_cp.py | 4 ++++ tests/ut/attention/test_mla_v1.py | 3 +++ tests/ut/attention/test_sfa_v1.py | 9 ++++++++- vllm_ascend/ascend_config.py | 20 ++++++++++++++++++-- 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py index 801e42c9..e8b94c57 100755 --- a/tests/ut/attention/test_mla_cp.py +++ b/tests/ut/attention/test_mla_cp.py @@ -176,12 +176,16 @@ class TestAscendMLAImpl(TestBase): vllm_config = MagicMock() speculative_config = MagicMock() model_config = MagicMock() + parallel_config = MagicMock() + parallel_config.prefill_context_parallel_size = 1 + parallel_config.tensor_parallel_size = 2 speculative_config.num_speculative_tokens = 4 vllm_config.speculative_config = speculative_config model_config.dtype = torch.float16 vllm_config.model_config = model_config get_current_vllm_config.return_value = vllm_config vllm_config.additional_config = {"refresh": True} + vllm_config.parallel_config = parallel_config init_ascend_config(vllm_config) num_heads = 256 diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 6059ecab..eb5234e6 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -757,12 +757,15 @@ class TestAscendMLAImpl(TestBase): vllm_config = MagicMock() speculative_config = MagicMock() model_config = MagicMock() + parallel_config = MagicMock() + 
parallel_config.prefill_context_parallel_size = 1 speculative_config.num_speculative_tokens = 4 vllm_config.speculative_config = speculative_config model_config.dtype = torch.float16 vllm_config.model_config = model_config get_current_vllm_config.return_value = vllm_config vllm_config.additional_config = {"refresh": True} + vllm_config.parallel_config = parallel_config init_ascend_config(vllm_config) num_heads = 256 diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py index aca95244..015779a1 100644 --- a/tests/ut/attention/test_sfa_v1.py +++ b/tests/ut/attention/test_sfa_v1.py @@ -5,6 +5,7 @@ import torch from tests.ut.base import TestBase from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm.distributed.parallel_state import GroupCoordinator if 'torch_npu._inductor' not in sys.modules: sys.modules['torch_npu._inductor'] = MagicMock() @@ -81,7 +82,13 @@ class TestAscendSFAMetadata(TestBase): class TestAscendSFAMetadataBuilder(TestBase): - def setUp(self): + @patch('vllm.distributed.parallel_state._TP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + def setUp(self, mock_tp): + mock_tp.world_size = 2 + mock_tp.rank_in_group = MagicMock() + mock_tp.device_group = MagicMock() + self.mock_cfg = MagicMock() self.mock_cfg.parallel_config = MagicMock() diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 0bb39fe9..06299fd5 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING from vllm.logger import logger from vllm.triton_utils import HAS_TRITON +from vllm.utils.math_utils import cdiv if TYPE_CHECKING: from vllm.config import VllmConfig @@ -62,10 +63,25 @@ class AscendConfig: additional_config.get("enable_shared_expert_dp", False) and vllm_config.parallel_config.enable_expert_parallel ) - if self.enable_shared_expert_dp: - from vllm_ascend.utils import enable_sp + from vllm_ascend.utils import enable_sp + 
if self.enable_shared_expert_dp: assert enable_sp(vllm_config=vllm_config, enable_shared_expert_dp=True) + + if vllm_config.parallel_config.prefill_context_parallel_size > 1 and enable_sp(vllm_config=vllm_config): + tp_pcp_size = ( + vllm_config.parallel_config.tensor_parallel_size + * vllm_config.parallel_config.prefill_context_parallel_size + ) + if vllm_config.scheduler_config.max_num_batched_tokens % tp_pcp_size != 0: + vllm_config.scheduler_config.max_num_batched_tokens = ( + cdiv(vllm_config.scheduler_config.max_num_batched_tokens, tp_pcp_size) * tp_pcp_size + ) + logger.warning_once( + f"When using FLASHCOMM1, the max_num_batched_tokens should be divisible " + f"by tp_size * pcp_size ({tp_pcp_size}). It has been adjusted to " + f"{vllm_config.scheduler_config.max_num_batched_tokens}." + ) self.multistream_overlap_shared_expert = additional_config.get("multistream_overlap_shared_expert", False) self.multistream_overlap_gate = additional_config.get("multistream_overlap_gate", False) self.recompute_scheduler_enable = additional_config.get("recompute_scheduler_enable", False)