From b32ef53b3bfe67891ab107d731bf085c0db75f64 Mon Sep 17 00:00:00 2001
From: LookAround0301
Date: Fri, 5 Dec 2025 10:31:49 +0800
Subject: [PATCH] [long_seq] remove long_seq env (#4660)

### What this PR does / why we need it?
remove env VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL

- vLLM version: v0.12.0

---------

Signed-off-by: LookAround
Signed-off-by: ZhangMingWei716 <2894054457@qq.com>
Co-authored-by: ZhangMingWei716 <2894054457@qq.com>
Co-authored-by: wangxiyuan
---
 tests/ut/attention/test_attention_v1.py       |  26 +++-
 tests/ut/attention/test_mla_v1.py             | 127 +++++++++++++++---
 .../kv_connector/test_mooncake_connector.py   |   9 +-
 tests/ut/test_platform.py                     |  18 +++
 vllm_ascend/attention/attention_v1.py         |  36 ++---
 vllm_ascend/attention/mla_v1.py               |  29 ++--
 vllm_ascend/distributed/kvpool/pool_worker.py |  17 +--
 .../llmdatadist_c_mgr_connector.py            |  23 +---
 vllm_ascend/distributed/mooncake_connector.py |  23 +---
 vllm_ascend/distributed/parallel_state.py     |  16 +--
 vllm_ascend/ops/fused_moe/prepare_finalize.py |   5 +-
 vllm_ascend/platform.py                       |  12 +-
 vllm_ascend/spec_decode/mtp_proposer.py       |  10 +-
 vllm_ascend/worker/block_table.py             |  13 +-
 vllm_ascend/worker/model_runner_v1.py         |  25 ++--
 vllm_ascend/worker/worker_v1.py               |  17 +--
 16 files changed, 230 insertions(+), 176 deletions(-)

diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index 3a94e9e8..af179b37 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -54,12 +54,16 @@ class TestAscendAttentionBackend(TestBase):
 
 class TestAscendAttentionMetadataBuilder(TestBase):
 
+    @patch('vllm.distributed.parallel_state.get_pcp_group')
+    @patch('vllm.distributed.parallel_state._PCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch('vllm.distributed.parallel_state.get_dcp_group')
     @patch('vllm.distributed.parallel_state._DCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
-    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
+    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
+              mock_get_pcp_group):
         mock_dcp.world_size = 1
         dcp_group = MagicMock(spec=GroupCoordinator)
         dcp_group.rank_in_group = 0
@@ -67,6 +71,13 @@ class TestAscendAttentionMetadataBuilder(TestBase):
         dcp_group.device_group = MagicMock()
         mock_get_dcp_group.return_value = dcp_group
 
+        mock_pcp.world_size = 1
+        pcp_group = MagicMock(spec=GroupCoordinator)
+        pcp_group.rank_in_group = 0
+        pcp_group.world_size = 1
+        pcp_group.device_group = MagicMock()
+        mock_get_pcp_group.return_value = pcp_group
+
         self.mock_vllm_config = MagicMock()
         self.mock_vllm_config.speculative_config = None
         self.mock_vllm_config.model_config.max_model_len = 640
@@ -117,12 +128,16 @@ class TestAscendAttentionMetadataBuilder(TestBase):
 
 class TestAscendAttentionBackendImpl(TestBase):
 
+    @patch('vllm.distributed.parallel_state.get_pcp_group')
+    @patch('vllm.distributed.parallel_state._PCP',
+           new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch('vllm.distributed.parallel_state.get_dcp_group')
     @patch('vllm.distributed.parallel_state._DCP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
            return_value=1)
-    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
+    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
+              mock_get_pcp_group):
         mock_dcp.world_size = 1
         dcp_group =
MagicMock(spec=GroupCoordinator) dcp_group.rank_in_group = 0 @@ -130,6 +145,13 @@ class TestAscendAttentionBackendImpl(TestBase): dcp_group.device_group = MagicMock() mock_get_dcp_group.return_value = dcp_group + mock_pcp.world_size = 1 + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.rank_in_group = 0 + pcp_group.world_size = 1 + pcp_group.device_group = MagicMock() + mock_get_pcp_group.return_value = pcp_group + self.layer = MagicMock() self.layer.layer_name = "test_layer" self.layer._k_scale_float = 1.0 diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 1babb728..fbb90aa3 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -177,13 +177,17 @@ class TestAscendMLAMetadata(TestBase): class TestAscendMLAMetadataBuilder(TestBase): + @patch('vllm.distributed.parallel_state.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size, - mock_dcp, mock_get_dcp_group): + mock_dcp, mock_get_dcp_group, + mock_pcp, mock_get_pcp_group): mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.get_head_size.return_value = 64 @@ -201,6 +205,13 @@ class TestAscendMLAMetadataBuilder(TestBase): dcp_group.device_group = MagicMock() mock_get_dcp_group.return_value = dcp_group + mock_pcp.world_size = 1 + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.rank_in_group = 0 + pcp_group.world_size = 1 + pcp_group.device_group = MagicMock() + mock_get_pcp_group.return_value = pcp_group + mock_vllm_config.speculative_config = None ascend_config = MagicMock() @@ -215,6 +226,9 @@ class TestAscendMLAMetadataBuilder(TestBase): builder.chunked_prefill_enabled, mock_vllm_config.scheduler_config.enable_chunked_prefill) + @patch('vllm.distributed.parallel_state.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @@ -222,7 +236,9 @@ class TestAscendMLAMetadataBuilder(TestBase): return_value=1) def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size, mock_dcp, - mock_get_dcp_group): + mock_get_dcp_group, + mock_pcp, + mock_get_pcp_group): mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.get_head_size.return_value = 64 @@ -240,6 +256,13 @@ class TestAscendMLAMetadataBuilder(TestBase): dcp_group.device_group = MagicMock() mock_get_dcp_group.return_value = dcp_group + mock_pcp.world_size = 1 + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.rank_in_group = 0 + pcp_group.world_size = 1 + pcp_group.device_group = MagicMock() + mock_get_pcp_group.return_value = pcp_group + mock_spec_config = MagicMock() mock_spec_config.num_speculative_tokens = 3 mock_vllm_config.speculative_config = mock_spec_config @@ -256,13 +279,17 @@ class TestAscendMLAMetadataBuilder(TestBase): builder.chunked_prefill_enabled, mock_vllm_config.scheduler_config.enable_chunked_prefill) + 
@patch('vllm.distributed.parallel_state.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) def test_ascend_mla_metadata_builder_build_full_graph( - self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group): + self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp, + mock_get_pcp_group): mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.get_head_size.return_value = 64 @@ -280,6 +307,13 @@ class TestAscendMLAMetadataBuilder(TestBase): dcp_group.device_group = MagicMock() mock_get_dcp_group.return_value = dcp_group + mock_pcp.world_size = 1 + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.rank_in_group = 0 + pcp_group.world_size = 1 + pcp_group.device_group = MagicMock() + mock_get_pcp_group.return_value = pcp_group + mock_spec_config = MagicMock() mock_spec_config.num_speculative_tokens = 1 mock_spec_config.disable_padded_drafter_batch = True @@ -307,13 +341,16 @@ class TestAscendMLAMetadataBuilder(TestBase): [1, 2, 4, 5, 6, 6, 7, 8]) self.assertEqual(metadata.decode.block_table.shape[0], 8) + @patch('vllm.distributed.parallel_state.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) def test_reorder_batch(self, mock_get_dcp_size, mock_dcp, - mock_get_dcp_group): + mock_get_dcp_group, mock_pcp, mock_get_pcp_group): ascend_config = MagicMock() mock_vllm_config = MagicMock() @@ -331,6 +368,13 @@ class TestAscendMLAMetadataBuilder(TestBase): dcp_group.device_group = MagicMock() mock_get_dcp_group.return_value = dcp_group + mock_pcp.world_size = 1 + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.rank_in_group = 0 + pcp_group.world_size = 1 + pcp_group.device_group = MagicMock() + mock_get_pcp_group.return_value = pcp_group + mock_vllm_config.speculative_config = None with patch("vllm_ascend.attention.mla_v1.get_ascend_config", @@ -358,6 +402,9 @@ class TestAscendMLAMetadataBuilder(TestBase): self.assertTrue(modified) input_batch.swap_states.assert_called_once_with(1, 2) + @patch('vllm.distributed.parallel_state.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @@ -365,7 +412,9 @@ class TestAscendMLAMetadataBuilder(TestBase): return_value=1) def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_get_dcp_size, mock_dcp, - mock_get_dcp_group): + mock_get_dcp_group, + mock_pcp, + mock_get_pcp_group): mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.get_head_size.return_value = 64 @@ -384,6 +433,13 @@ class TestAscendMLAMetadataBuilder(TestBase): dcp_group.device_group = MagicMock() mock_get_dcp_group.return_value = dcp_group + mock_pcp.world_size = 1 + pcp_group = MagicMock(spec=GroupCoordinator) + 
pcp_group.rank_in_group = 0 + pcp_group.world_size = 1 + pcp_group.device_group = MagicMock() + mock_get_pcp_group.return_value = pcp_group + builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, mock_device) input_seq_lens = [1, 2, 4, 5] @@ -394,14 +450,18 @@ class TestAscendMLAMetadataBuilder(TestBase): num_reqs_pad_size, num_reqs, input_seq_lens) self.assertEqual(output_seq_lens, expect_output) + @patch('vllm.distributed.parallel_state.get_pcp_group') + @patch('vllm.distributed.parallel_state._PCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1) def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_get_dcp_size, - mock_dcp, - mock_get_dcp_group): + mock_dcp, mock_get_dcp_group, + mock_pcp, + mock_get_pcp_group): mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.get_head_size.return_value = 64 @@ -419,6 +479,14 @@ class TestAscendMLAMetadataBuilder(TestBase): dcp_group.world_size = 1 dcp_group.device_group = MagicMock() mock_get_dcp_group.return_value = dcp_group + + mock_pcp.world_size = 1 + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.rank_in_group = 0 + pcp_group.world_size = 1 + pcp_group.device_group = MagicMock() + mock_get_pcp_group.return_value = pcp_group + common_metadata = MagicMock() common_metadata.actual_seq_lengths_q = [2, 4, 6, 8] @@ -452,6 +520,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): self.kv_cache_spec.head_size = 128 self.kv_cache_spec.num_heads = 32 + @patch("vllm_ascend.attention.mla_v1.get_pcp_group") @patch( "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size" ) @@ -461,9 +530,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): @patch("torch.npu.is_available") def test_build_prefix_no_cache_metadata(self, mock_npu_available, mock_zeros, mock_get_ascend_config, - mock_dcp_world_size): + mock_dcp_world_size, + mock_get_pcp_group): mock_npu_available.return_value = False mock_dcp_world_size.return_value = 1 + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.world_size = 1 + mock_get_pcp_group.return_value = pcp_group def zeros_override(*args, **kwargs): kwargs.pop('pin_memory', None) @@ -512,6 +585,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) + @patch("vllm_ascend.attention.mla_v1.get_pcp_group") @patch( "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size" ) @@ -521,9 +595,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): @patch("torch.npu.is_available") def test_build_chunked_prefix_metadata(self, mock_npu_available, mock_zeros, mock_get_ascend_config, - mock_dcp_world_size): + mock_dcp_world_size, + mock_get_pcp_group): mock_npu_available.return_value = False mock_dcp_world_size.return_value = 1 + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.world_size = 1 + mock_get_pcp_group.return_value = pcp_group def zeros_override(*args, **kwargs): kwargs.pop('pin_memory', None) @@ -573,14 +651,18 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) + 
@patch("vllm_ascend.attention.mla_v1.get_pcp_group") @patch( "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size" ) @patch("vllm_ascend.attention.mla_v1.get_ascend_config") def test_build_decode_only_metadata(self, mock_get_ascend_config, - mock_dcp_world_size): + mock_dcp_world_size, + mock_get_pcp_group): mock_dcp_world_size.return_value = 1 - + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.world_size = 1 + mock_get_pcp_group.return_value = pcp_group common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=torch.tensor([0, 1, 2, 3]), query_start_loc_cpu=torch.tensor([0, 1, 2, 3]), @@ -622,14 +704,18 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) + @patch("vllm_ascend.attention.mla_v1.get_pcp_group") @patch( "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size" ) @patch("vllm_ascend.attention.mla_v1.get_ascend_config") def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config, - mock_dcp_world_size): + mock_dcp_world_size, + mock_get_pcp_group): mock_dcp_world_size.return_value = 1 - + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.world_size = 1 + mock_get_pcp_group.return_value = pcp_group common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=torch.tensor([0, 1, 2, 3]), query_start_loc_cpu=torch.tensor([0, 1, 2, 3]), @@ -672,14 +758,18 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): torch.all(metadata.slot_mapping == base_inputs["slot_mapping"])) self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) + @patch("vllm_ascend.attention.mla_v1.get_pcp_group") @patch( "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size" ) @patch("vllm_ascend.attention.mla_v1.get_ascend_config") def test_build_for_graph_capture_prefill(self, mock_get_ascend_config, - mock_dcp_world_size): + mock_dcp_world_size, + mock_get_pcp_group): mock_dcp_world_size.return_value = 1 - + pcp_group = MagicMock(spec=GroupCoordinator) + pcp_group.world_size = 1 + mock_get_pcp_group.return_value = pcp_group common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=torch.tensor([0, 3, 7]), query_start_loc_cpu=torch.tensor([0, 3, 7]), @@ -716,6 +806,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): class TestAscendMLAImpl(TestBase): + @patch('vllm.distributed.parallel_state._PCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch("vllm.distributed.get_decode_context_model_parallel_world_size", @@ -727,13 +819,16 @@ class TestAscendMLAImpl(TestBase): @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config") @patch("vllm_ascend.attention.mla_v1.get_ascend_config") def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size, - mock_tp, mock_get_dcp_size, mock_dcp): + mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp): mock_tp.world_size = 2 mock_tp.rank_in_group = MagicMock() mock_tp.device_group = MagicMock() mock_dcp.world_size = 1 mock_dcp.rank_in_group = MagicMock() mock_dcp.device_group = MagicMock() + mock_pcp.world_size = 1 + mock_pcp.rank_in_group = MagicMock() + mock_pcp.device_group = MagicMock() vllm_config = MagicMock() speculative_config = MagicMock() model_config = MagicMock() diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py 
index a0edff8e..1179e328 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -1031,6 +1031,11 @@ class TestMooncakeConnectorWorker(unittest.TestCase): self.mock_dcp = MagicMock() self.mock_dcp.world_size = 1 + self.mock_pcp_group = MagicMock(spec=GroupCoordinator) + self.mock_pcp_group.rank_in_group = 0 + self.mock_pcp_group.world_size = 1 + self.mock_pcp_group.device_group = MagicMock() + self.patches = [ patch( 'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES', @@ -1069,7 +1074,9 @@ class TestMooncakeConnectorWorker(unittest.TestCase): return_value=self.mock_dcp), patch( 'vllm.distributed.get_decode_context_model_parallel_world_size', - return_value=1) + return_value=1), + patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group', + return_value=self.mock_pcp_group), ] for p in self.patches: diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index d51981dc..35374dd8 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -245,6 +245,10 @@ class TestNPUPlatform(TestBase): ) vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.parallel_config.enable_expert_parallel = False + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 vllm_config.parallel_config.tensor_parallel_size = 1 mock_init_recompute.return_value = MagicMock() vllm_config.scheduler_config = MagicMock() @@ -275,6 +279,8 @@ class TestNPUPlatform(TestBase): ) vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.model_config = None + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 vllm_config.parallel_config.tensor_parallel_size = 1 mock_init_recompute.return_value = MagicMock() vllm_config.scheduler_config = MagicMock() @@ -300,6 +306,8 @@ class TestNPUPlatform(TestBase): ) vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.model_config.enforce_eager = True + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 vllm_config.parallel_config.tensor_parallel_size = 1 mock_init_recompute.return_value = MagicMock() vllm_config.scheduler_config = MagicMock() @@ -338,6 +346,8 @@ class TestNPUPlatform(TestBase): ) vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.model_config.enforce_eager = False + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 vllm_config.parallel_config.tensor_parallel_size = 1 mock_init_recompute.return_value = MagicMock() vllm_config.scheduler_config = MagicMock() @@ -409,6 +419,8 @@ class TestNPUPlatform(TestBase): mock_init_ascend.return_value = mock_ascend_config vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.model_config.enforce_eager = False + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 vllm_config.parallel_config.tensor_parallel_size = 1 mock_init_recompute.return_value = MagicMock() vllm_config.scheduler_config = MagicMock() @@ -446,6 +458,8 @@ class TestNPUPlatform(TestBase): vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.cache_config.block_size = None vllm_config.cache_config.enable_prefix_caching = True 
+ vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 vllm_config.parallel_config.tensor_parallel_size = 1 mock_init_recompute.return_value = MagicMock() vllm_config.scheduler_config = MagicMock() @@ -472,6 +486,8 @@ class TestNPUPlatform(TestBase): ) vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.parallel_config.worker_cls = "auto" + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 vllm_config.parallel_config.tensor_parallel_size = 1 mock_init_recompute.return_value = MagicMock() vllm_config.scheduler_config = MagicMock() @@ -510,6 +526,8 @@ class TestNPUPlatform(TestBase): ) vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.compilation_config.custom_ops = [] + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 vllm_config.parallel_config.tensor_parallel_size = 1 mock_init_recompute.return_value = MagicMock() diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index b524e648..28a8e78b 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -26,10 +26,13 @@ import torch.nn as nn import torch_npu from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) +from vllm.attention.backends.registry import (AttentionBackendEnum, + register_backend) from vllm.config import VllmConfig from vllm.distributed import (get_dcp_group, get_decode_context_model_parallel_rank, - get_decode_context_model_parallel_world_size) + get_decode_context_model_parallel_world_size, + get_pcp_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.utils import AttentionCGSupport @@ -41,19 +44,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, split_decodes_and_prefills) from vllm_ascend.compilation.acl_graph import (get_graph_params, update_graph_params_workspaces) -from vllm_ascend.utils import prefill_context_parallel_enable, weak_ref_tensors - -# isort: off -if prefill_context_parallel_enable(): - from vllm.distributed import (get_pcp_group, - get_prefill_context_model_parallel_rank, - get_prefill_context_model_parallel_world_size - ) - -# isort: on - -from vllm.attention.backends.registry import (AttentionBackendEnum, - register_backend) +from vllm_ascend.utils import weak_ref_tensors @register_backend(AttentionBackendEnum.CUSTOM, "ASCEND") @@ -255,10 +246,9 @@ class AscendAttentionMetadataBuilder: vllm_config.scheduler_config.max_num_batched_tokens, dtype=torch.uint8, device=device) - self.pcp_size = get_prefill_context_model_parallel_world_size( - ) if prefill_context_parallel_enable() else 1 - self.pcp_rank = get_prefill_context_model_parallel_rank( - ) if self.pcp_size > 1 else 0 + self.pcp_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group( + ).rank_in_group if self.pcp_size > 1 else 0 self.dcp_size = get_decode_context_model_parallel_world_size() self.dcp_rank = get_decode_context_model_parallel_rank( ) if self.dcp_size > 1 else 0 @@ -350,8 +340,7 @@ class AscendAttentionMetadataBuilder: context_lens_cpu = num_computed_tokens_cpu[ num_decodes:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() - pcp_size = get_prefill_context_model_parallel_world_size( - ) if prefill_context_parallel_enable() else 
1 + pcp_size = get_pcp_group().world_size if self.chunked_prefill_enabled and max_context_len_cpu > 0: local_context_lens_allranks = torch.tensor( num_computed_tokens_of_pcp_dcp @@ -539,10 +528,9 @@ class AscendAttentionBackendImpl(AttentionImpl): self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.key_cache = None self.value_cache = None - self.pcp_size = get_prefill_context_model_parallel_world_size( - ) if prefill_context_parallel_enable() else 1 - self.pcp_rank = get_prefill_context_model_parallel_rank( - ) if self.pcp_size > 1 else 0 + self.pcp_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group( + ).rank_in_group if self.pcp_size > 1 else 0 self.pcp_group = get_pcp_group( ).device_group if self.pcp_size > 1 else None diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 623b2712..610a6c2a 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -13,7 +13,7 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import (get_dcp_group, get_decode_context_model_parallel_rank, get_decode_context_model_parallel_world_size, - get_tensor_model_parallel_rank, + get_pcp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) from vllm.forward_context import ForwardContext, get_forward_context @@ -37,17 +37,9 @@ from vllm_ascend.compilation.acl_graph import (get_graph_params, from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - is_enable_nz, prefill_context_parallel_enable, - weak_ref_tensors) + is_enable_nz, weak_ref_tensors) from vllm_ascend.worker.npu_input_batch import InputBatch -# isort: off -if prefill_context_parallel_enable(): - from vllm.distributed import (get_pcp_group, - get_prefill_context_model_parallel_rank, - get_prefill_context_model_parallel_world_size - ) -# isort: on if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -265,15 +257,13 @@ class AscendMLAMetadataBuilder: self.cos_cache = None self.sin_cache = None - self.pcp_size = get_prefill_context_model_parallel_world_size( - ) if prefill_context_parallel_enable() else 1 - self.pcp_rank = get_prefill_context_model_parallel_rank( - ) if self.pcp_size > 1 else 0 + self.pcp_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group( + ).rank_in_group if self.pcp_size > 1 else 0 self.dcp_size = get_decode_context_model_parallel_world_size() self.dcp_rank = get_decode_context_model_parallel_rank( ) if self.dcp_size > 1 else 0 - self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size if prefill_context_parallel_enable( - ) else 1 + self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', 0) @@ -868,10 +858,9 @@ class AscendMLAImpl(MLAAttentionImpl): self.speculative_config = vllm_config.speculative_config self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO - self.pcp_size = get_prefill_context_model_parallel_world_size( - ) if prefill_context_parallel_enable() else 1 - self.pcp_rank = get_prefill_context_model_parallel_rank( - ) if self.pcp_size > 1 else 0 + self.pcp_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group( + ).rank_in_group if self.pcp_size > 1 else 0 self.pcp_group 
= get_pcp_group( ).device_group if self.pcp_size > 1 else None diff --git a/vllm_ascend/distributed/kvpool/pool_worker.py b/vllm_ascend/distributed/kvpool/pool_worker.py index b1dc53c3..09cf94be 100644 --- a/vllm_ascend/distributed/kvpool/pool_worker.py +++ b/vllm_ascend/distributed/kvpool/pool_worker.py @@ -6,7 +6,7 @@ import torch from vllm.config import VllmConfig from vllm.distributed import (get_decode_context_model_parallel_rank, get_decode_context_model_parallel_world_size, - get_tensor_model_parallel_rank, + get_pcp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.logger import logger from vllm.v1.core.kv_cache_utils import BlockHash @@ -22,14 +22,6 @@ from vllm_ascend.distributed.kvpool.config_data import ( from vllm_ascend.distributed.kvpool.kv_transfer import ( KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) -from vllm_ascend.utils import prefill_context_parallel_enable - -if prefill_context_parallel_enable(): - # isort: off - from vllm.distributed import (get_prefill_context_model_parallel_rank, - get_prefill_context_model_parallel_world_size - ) - # isort: on backend_map: Dict[str, Type[Backend]] = { "mooncake": MooncakeBackend, @@ -57,10 +49,9 @@ class KVPoolWorker: self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() - self.pcp_size = get_prefill_context_model_parallel_world_size( - ) if prefill_context_parallel_enable() else 1 - self.pcp_rank = get_prefill_context_model_parallel_rank( - ) if self.pcp_size > 1 else 0 + self.pcp_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group( + ).rank_in_group if self.pcp_size > 1 else 0 self.dcp_size = get_decode_context_model_parallel_world_size() self.dcp_rank = get_decode_context_model_parallel_rank( ) if self.dcp_size > 1 else 0 diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index e5e253c9..af54d9b2 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -22,10 +22,11 @@ from vllm import envs from vllm.config import KVTransferConfig, VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.distributed.parallel_state import (get_dcp_group, get_tp_group, - get_world_group) +from vllm.distributed.parallel_state import (get_dcp_group, get_pcp_group, + get_tp_group, get_world_group) from vllm.forward_context import ForwardContext from vllm.logger import logger +from vllm.utils.network_utils import get_ip from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig @@ -33,14 +34,7 @@ from vllm.v1.request import Request, RequestStatus import vllm_ascend.envs as envs_ascend from vllm_ascend.distributed.utils import get_transfer_timeout_value -from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, - prefill_context_parallel_enable) - -if prefill_context_parallel_enable(): - from vllm.distributed.parallel_state import \ - get_prefill_context_model_parallel_rank - -from vllm.utils.network_utils import get_ip +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type TORCH_DTYPE_TO_NPU_DTYPE = { torch.half: llm_datadist.DataType.DT_FLOAT16, @@ -203,8 +197,7 @@ class 
LLMDataDistCMgrConnectorScheduler(): else: dp_rank_local = vllm_config.parallel_config.data_parallel_rank_local tp_size = self.vllm_config.parallel_config.tensor_parallel_size - self.pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size if prefill_context_parallel_enable( - ) else 1 + self.pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size self.port = dp_rank_local * self.pcp_size * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT @@ -345,10 +338,8 @@ class LLMDataDistCMgrConnectorWorker(): self.tp_size = vllm_config.parallel_config.tensor_parallel_size self.tp_rank = get_tp_group().rank_in_group self.rank = get_world_group().rank - self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size if prefill_context_parallel_enable( - ) else 1 - self.pcp_rank = get_prefill_context_model_parallel_rank( - ) if prefill_context_parallel_enable() else 0 + self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size + self.pcp_rank = get_pcp_group().rank_in_group self.dcp_size = get_dcp_group().world_size self.local_ip = get_ip() self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index d978533b..2d376058 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -27,9 +27,10 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.distributed.parallel_state import ( get_decode_context_model_parallel_rank, - get_decode_context_model_parallel_world_size, + get_decode_context_model_parallel_world_size, get_pcp_group, get_tensor_model_parallel_rank, get_tp_group) from vllm.logger import logger +from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import RequestStatus @@ -38,16 +39,6 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.distributed.mooncake_transfer_engine import global_te from vllm_ascend.distributed.utils import get_transfer_timeout_value -from vllm_ascend.utils import prefill_context_parallel_enable - -# isort: off -if prefill_context_parallel_enable(): - from vllm.distributed import (get_prefill_context_model_parallel_rank, - get_prefill_context_model_parallel_world_size - ) -# isort: on - -from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -730,8 +721,7 @@ class MooncakeConnectorScheduler: logger.info("Initializing Mooncake Scheduler %s", engine_id) self.side_channel_host = get_ip() - self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size \ - if prefill_context_parallel_enable() else 1 + self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \ vllm_config.parallel_config.data_parallel_size * \ @@ -898,10 +888,9 @@ class MooncakeConnectorWorker: self.dp_size = 
vllm_config.parallel_config.data_parallel_size_local self.kv_caches: dict[str, torch.Tensor] = {} self.side_channel_host = get_ip() - self.pcp_size = get_prefill_context_model_parallel_world_size( - ) if prefill_context_parallel_enable() else 1 - self.pcp_rank = get_prefill_context_model_parallel_rank( - ) if self.pcp_size > 1 else 0 + self.pcp_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group( + ).rank_in_group if self.pcp_size > 1 else 0 self.dcp_size = get_decode_context_model_parallel_world_size() self.dcp_rank = get_decode_context_model_parallel_rank( ) if self.dcp_size > 1 else 0 diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 00de0627..96d403f9 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -9,8 +9,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group, import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import (flashcomm2_enable, - prefill_context_parallel_enable) +from vllm_ascend.utils import flashcomm2_enable # Currently, mc2 op need their own group coordinator. _MC2: Optional[GroupCoordinator] = None @@ -74,15 +73,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): # The layout of all ranks: ExternalDP * EP # ExternalDP is the data parallel group that is not part of the model, # every dp rank can generate independently (in verl integration). - if prefill_context_parallel_enable(): - all_ranks = torch.arange(world_size).reshape( - -1, parallel_config.data_parallel_size * - parallel_config.prefill_context_parallel_size * - parallel_config.tensor_parallel_size) - else: - all_ranks = torch.arange(world_size).reshape( - -1, parallel_config.data_parallel_size * - parallel_config.tensor_parallel_size) + all_ranks = torch.arange(world_size).reshape( + -1, parallel_config.data_parallel_size * + parallel_config.prefill_context_parallel_size * + parallel_config.tensor_parallel_size) pd_tp_ratio = get_ascend_config().pd_tp_ratio pd_head_ratio = get_ascend_config().pd_head_ratio diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py index 48350ea8..b3b907b0 100644 --- a/vllm_ascend/ops/fused_moe/prepare_finalize.py +++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py @@ -24,16 +24,13 @@ import torch.nn as nn import torch_npu from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( - get_dp_group, get_tensor_model_parallel_rank, + get_dp_group, get_pcp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable -if prefill_context_parallel_enable(): - from vllm.distributed import get_pcp_group - class QuantType(Enum): NONE = 0 diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 9958d06f..2eafd61b 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -33,11 +33,12 @@ from vllm_ascend.torchair.utils import (check_torchair_cache_exist, from vllm_ascend.utils import refresh_block_size # isort: off -from vllm_ascend.utils import ( - ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType, - enable_sp, get_ascend_device_type, is_vl_model, - prefill_context_parallel_enable, update_aclgraph_sizes, - 
update_cudagraph_capture_sizes, update_default_aclgraph_sizes) +from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, + COMPRESSED_TENSORS_METHOD, AscendDeviceType, + enable_sp, get_ascend_device_type, is_vl_model, + update_aclgraph_sizes, + update_cudagraph_capture_sizes, + update_default_aclgraph_sizes) if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -329,7 +330,6 @@ class NPUPlatform(Platform): vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch if vllm_config.kv_transfer_config is not None and \ - prefill_context_parallel_enable() and \ cache_config.block_size != parallel_config.cp_kv_cache_interleave_size and \ parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1: raise AssertionError( diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index d6bf784f..89425c6a 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -7,6 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from vllm.config import (CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, set_current_vllm_config) +from vllm.distributed import get_pcp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -16,6 +17,8 @@ from vllm.model_executor.model_loader.utils import \ from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.utils.math_utils import cdiv +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import set_default_torch_dtype from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput @@ -32,15 +35,8 @@ from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, update_mla_attn_params) from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, - prefill_context_parallel_enable, shared_expert_dp_enabled) -if prefill_context_parallel_enable(): - from vllm.distributed import get_pcp_group - -from vllm.utils.platform_utils import is_pin_memory_available -from vllm.utils.torch_utils import set_default_torch_dtype - logger = init_logger(__name__) PADDING_SLOT_ID = -1 diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 3317a237..2d9e9569 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -2,14 +2,9 @@ from typing import Optional, Union import numpy as np import torch -from vllm.distributed import get_dcp_group +from vllm.distributed import get_dcp_group, get_pcp_group from vllm.utils.math_utils import cdiv -from vllm_ascend.utils import prefill_context_parallel_enable - -if prefill_context_parallel_enable(): - from vllm.distributed import get_pcp_group - class BlockTable: @@ -31,8 +26,7 @@ class BlockTable: self.physical_block_size = block_size try: - self.pcp_world_size = get_pcp_group( - ).world_size if prefill_context_parallel_enable() else 1 + self.pcp_world_size = get_pcp_group().world_size self.pcp_rank = get_pcp_group( ).rank_in_group if self.pcp_world_size > 1 else 0 self.dcp_world_size = get_dcp_group().world_size @@ -279,8 +273,7 @@ class MultiGroupBlockTable: # must be multiplied 
by dcp_world_size. try: dcp_world_size = get_dcp_group().world_size - pcp_world_size = get_pcp_group( - ).world_size if prefill_context_parallel_enable() else 1 + pcp_world_size = get_pcp_group().world_size except AssertionError: # DCP might not be initialized in testing dcp_world_size = 1 diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6ef7bab1..c66914c7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -52,7 +52,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group, - get_pp_group, get_tp_group, + get_pcp_group, get_pp_group, + get_tp_group, is_global_first_rank) from vllm.forward_context import get_forward_context from vllm.logger import logger @@ -145,16 +146,9 @@ from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, ProfileExecuteDuration, enable_sp, get_ascend_device_type, is_enable_nz, - is_moe_model, lmhead_tp_enable, - prefill_context_parallel_enable) + is_moe_model, lmhead_tp_enable) from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch -if prefill_context_parallel_enable(): - from vllm.distributed import get_pcp_group - from vllm.distributed.parallel_state import ( - get_prefill_context_model_parallel_rank, - get_prefill_context_model_parallel_world_size) - if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput @@ -290,10 +284,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dcp_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group - self.pcp_size = get_prefill_context_model_parallel_world_size( - ) if prefill_context_parallel_enable() else 1 - self.pcp_rank = get_prefill_context_model_parallel_rank( - ) if self.pcp_size > 1 else 0 + self.pcp_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group( + ).rank_in_group if self.pcp_size > 1 else 0 decode_max_num_seqs = getattr(self.scheduler_config, 'decode_max_num_seqs', 0) self.max_num_reqs = max(self.scheduler_config.max_num_seqs, @@ -602,8 +595,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): if self.vllm_config.speculative_config else 0), kernel_block_sizes=[[self.vllm_config.cache_config.block_size]], cp_kv_cache_interleave_size=self.parallel_config. 
-            cp_kv_cache_interleave_size
-            if prefill_context_parallel_enable() else 1,
+            cp_kv_cache_interleave_size,
         )
         self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                      dtype=torch.int64)
@@ -2742,8 +2734,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                 device=self.device)
             long_seq_metadata = self._generate_pcp_metadata(num_tokens)
             if long_seq_metadata is not None:
-                pcp_world_size = get_pcp_group(
-                ).world_size if prefill_context_parallel_enable() else 1
+                pcp_world_size = get_pcp_group().world_size
                 dcp_world_size = get_dcp_group().world_size
                 num_computed_tokens_of_pcp_dcp = [[
                     [0] * dcp_world_size for _ in range(pcp_world_size)
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 41b6abb9..f8b39fcf 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -53,7 +53,6 @@ from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz,
-                               prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -405,17 +404,11 @@ class NPUWorker(WorkerBase):
         init_distributed_environment(self.parallel_config.world_size,
                                      self.rank, self.distributed_init_method,
                                      self.local_rank, "hccl")
-        if prefill_context_parallel_enable():
-            ensure_model_parallel_initialized(
-                self.parallel_config.tensor_parallel_size,
-                self.parallel_config.pipeline_parallel_size,
-                self.parallel_config.prefill_context_parallel_size,
-                self.parallel_config.decode_context_parallel_size)
-        else:
-            ensure_model_parallel_initialized(
-                self.parallel_config.tensor_parallel_size,
-                self.parallel_config.pipeline_parallel_size,
-                self.parallel_config.decode_context_parallel_size)
+        ensure_model_parallel_initialized(
+            self.parallel_config.tensor_parallel_size,
+            self.parallel_config.pipeline_parallel_size,
+            self.parallel_config.prefill_context_parallel_size,
+            self.parallel_config.decode_context_parallel_size)
         init_ascend_model_parallel(self.parallel_config)
         ensure_kv_transfer_initialized(self.vllm_config)
         ensure_ec_transfer_initialized(self.vllm_config)
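
Taken together, the change repeated across these files is mechanical: every call site drops the prefill_context_parallel_enable() guard (and with it the VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL env switch the PR removes) and reads the prefill-context-parallel size and rank directly from the group coordinator, which reports a world size of 1 when the feature is off. The snippet below is a minimal sketch, not part of the patch: it replaces vLLM's GroupCoordinator with a small stub so the before/after pattern can run on its own; only get_pcp_group, world_size, and rank_in_group are names taken from the diff, and the stub class is purely illustrative.

# Minimal sketch, not from the patch: a stand-in for vLLM's GroupCoordinator so
# the pattern this PR standardizes on can be exercised outside of vLLM.
from dataclasses import dataclass


@dataclass
class StubPcpGroup:
    world_size: int = 1      # 1 == prefill context parallelism disabled
    rank_in_group: int = 0


def get_pcp_group() -> StubPcpGroup:
    # In the real code this is vllm.distributed.get_pcp_group(); stubbed here.
    return StubPcpGroup()


# Old pattern (removed by this PR), kept only as a comment for comparison:
#   pcp_size = get_prefill_context_model_parallel_world_size() \
#       if prefill_context_parallel_enable() else 1
#   pcp_rank = get_prefill_context_model_parallel_rank() if pcp_size > 1 else 0

# New pattern used throughout the diff: query the group unconditionally.
pcp_size = get_pcp_group().world_size
pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
print(pcp_size, pcp_rank)  # prints "1 0" when PCP is not enabled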