diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py
index 801e42c9..e8b94c57 100755
--- a/tests/ut/attention/test_mla_cp.py
+++ b/tests/ut/attention/test_mla_cp.py
@@ -176,12 +176,16 @@ class TestAscendMLAImpl(TestBase):
         vllm_config = MagicMock()
         speculative_config = MagicMock()
         model_config = MagicMock()
+        parallel_config = MagicMock()
+        parallel_config.prefill_context_parallel_size = 1
+        parallel_config.tensor_parallel_size = 2
         speculative_config.num_speculative_tokens = 4
         vllm_config.speculative_config = speculative_config
         model_config.dtype = torch.float16
         vllm_config.model_config = model_config
         get_current_vllm_config.return_value = vllm_config
         vllm_config.additional_config = {"refresh": True}
+        vllm_config.parallel_config = parallel_config
         init_ascend_config(vllm_config)
 
         num_heads = 256
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 6059ecab..eb5234e6 100755
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -757,12 +757,15 @@ class TestAscendMLAImpl(TestBase):
         vllm_config = MagicMock()
         speculative_config = MagicMock()
         model_config = MagicMock()
+        parallel_config = MagicMock()
+        parallel_config.prefill_context_parallel_size = 1
         speculative_config.num_speculative_tokens = 4
         vllm_config.speculative_config = speculative_config
         model_config.dtype = torch.float16
         vllm_config.model_config = model_config
         get_current_vllm_config.return_value = vllm_config
         vllm_config.additional_config = {"refresh": True}
+        vllm_config.parallel_config = parallel_config
         init_ascend_config(vllm_config)
 
         num_heads = 256
diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py
index aca95244..015779a1 100644
--- a/tests/ut/attention/test_sfa_v1.py
+++ b/tests/ut/attention/test_sfa_v1.py
@@ -5,6 +5,7 @@ import torch
 
 from tests.ut.base import TestBase
+from vllm.distributed.parallel_state import GroupCoordinator
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 
 if 'torch_npu._inductor' not in sys.modules:
     sys.modules['torch_npu._inductor'] = MagicMock()
@@ -81,7 +82,13 @@
 
 
 class TestAscendSFAMetadataBuilder(TestBase):
 
-    def setUp(self):
+    @patch('vllm.distributed.parallel_state._TP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
+    def setUp(self, mock_tp):
+        mock_tp.world_size = 2
+        mock_tp.rank_in_group = MagicMock()
+        mock_tp.device_group = MagicMock()
+
         self.mock_cfg = MagicMock()
         self.mock_cfg.parallel_config = MagicMock()
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 0bb39fe9..06299fd5 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING
 
 from vllm.logger import logger
 from vllm.triton_utils import HAS_TRITON
+from vllm.utils.math_utils import cdiv
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -62,10 +63,25 @@ class AscendConfig:
             additional_config.get("enable_shared_expert_dp", False)
             and vllm_config.parallel_config.enable_expert_parallel
         )
-        if self.enable_shared_expert_dp:
-            from vllm_ascend.utils import enable_sp
+        from vllm_ascend.utils import enable_sp
+        if self.enable_shared_expert_dp:
             assert enable_sp(vllm_config=vllm_config,
                              enable_shared_expert_dp=True)
+
+        if vllm_config.parallel_config.prefill_context_parallel_size > 1 and enable_sp(vllm_config=vllm_config):
+            tp_pcp_size = (
+                vllm_config.parallel_config.tensor_parallel_size
+                * vllm_config.parallel_config.prefill_context_parallel_size
+            )
+            if vllm_config.scheduler_config.max_num_batched_tokens % tp_pcp_size != 0:
+                vllm_config.scheduler_config.max_num_batched_tokens = (
+                    cdiv(vllm_config.scheduler_config.max_num_batched_tokens, tp_pcp_size) * tp_pcp_size
+                )
+                logger.warning_once(
+                    "When FLASHCOMM1 is enabled, max_num_batched_tokens must be divisible "
+                    f"by tp_size * pcp_size ({tp_pcp_size}); it has been adjusted to "
+                    f"{vllm_config.scheduler_config.max_num_batched_tokens}."
+                )
         self.multistream_overlap_shared_expert = additional_config.get("multistream_overlap_shared_expert", False)
         self.multistream_overlap_gate = additional_config.get("multistream_overlap_gate", False)
         self.recompute_scheduler_enable = additional_config.get("recompute_scheduler_enable", False)
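
Note on the ascend_config.py change above: a minimal sketch (plain Python, outside the patch; the tp/pcp values are illustrative, not from the patch) of the rounding rule the new branch applies. vllm's cdiv is ceiling division, so cdiv(n, k) * k rounds n up to the next multiple of k:

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, matching vllm.utils.math_utils.cdiv.
        return -(a // -b)

    # The patch requires max_num_batched_tokens to be a multiple of
    # tp_size * pcp_size when FLASHCOMM1 (sequence parallelism) is active,
    # and rounds it up rather than rejecting the config.
    tp_size, pcp_size = 2, 4            # illustrative values
    tp_pcp_size = tp_size * pcp_size    # 8
    max_num_batched_tokens = 8190       # not divisible by 8
    adjusted = cdiv(max_num_batched_tokens, tp_pcp_size) * tp_pcp_size
    print(adjusted)                     # 8192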