[long_seq] remove long_seq env (#4660)
### What this PR does / why we need it?
Remove the `VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL` environment variable. Prefill context parallel (PCP) state is now read directly from `get_pcp_group()` instead of being gated behind `prefill_context_parallel_enable()`.

- vLLM version: v0.12.0

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: ZhangMingWei716 <2894054457@qq.com>
Co-authored-by: ZhangMingWei716 <2894054457@qq.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
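For context, a minimal sketch of the pattern this change applies throughout the attention backends, KV connectors and model runner (the helper `read_pcp_dcp_state` is hypothetical, for illustration only; the imported functions are the ones used in the diff below):

```python
# Minimal sketch, assuming vLLM's distributed groups are already initialized.
# Before this PR the PCP lookups were guarded by prefill_context_parallel_enable(),
# which read the now-removed VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL env var; the PCP
# group is now always initialized, so its state is read unconditionally.
from vllm.distributed import (get_decode_context_model_parallel_rank,
                              get_decode_context_model_parallel_world_size,
                              get_pcp_group)


def read_pcp_dcp_state():
    # Hypothetical helper; the real code inlines these assignments in each
    # builder/impl __init__.
    pcp_size = get_pcp_group().world_size
    pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
    dcp_size = get_decode_context_model_parallel_world_size()
    dcp_rank = get_decode_context_model_parallel_rank() if dcp_size > 1 else 0
    return pcp_size, pcp_rank, dcp_size, dcp_rank
```

The unit tests below follow the same move: they patch `vllm.distributed.parallel_state.get_pcp_group` and `_PCP` with `MagicMock(spec=GroupCoordinator)` objects, mirroring what was already done for the DCP group.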
@@ -54,12 +54,16 @@ class TestAscendAttentionBackend(TestBase):

class TestAscendAttentionMetadataBuilder(TestBase):

@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
@@ -67,6 +71,13 @@ class TestAscendAttentionMetadataBuilder(TestBase):
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group

mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group

self.mock_vllm_config = MagicMock()
self.mock_vllm_config.speculative_config = None
self.mock_vllm_config.model_config.max_model_len = 640
@@ -117,12 +128,16 @@ class TestAscendAttentionMetadataBuilder(TestBase):

class TestAscendAttentionBackendImpl(TestBase):

@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
@@ -130,6 +145,13 @@ class TestAscendAttentionBackendImpl(TestBase):
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group

mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group

self.layer = MagicMock()
self.layer.layer_name = "test_layer"
self.layer._k_scale_float = 1.0

@@ -177,13 +177,17 @@ class TestAscendMLAMetadata(TestBase):

class TestAscendMLAMetadataBuilder(TestBase):

@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_ascend_mla_metadata_builder_default(self, mock_get_dcp_size,
mock_dcp, mock_get_dcp_group):
mock_dcp, mock_get_dcp_group,
mock_pcp, mock_get_pcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -201,6 +205,13 @@ class TestAscendMLAMetadataBuilder(TestBase):
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group

mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group

mock_vllm_config.speculative_config = None

ascend_config = MagicMock()
@@ -215,6 +226,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)

@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@@ -222,7 +236,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
return_value=1)
def test_ascend_mla_metadata_builder_spec_decode(self, mock_get_dcp_size,
mock_dcp,
mock_get_dcp_group):
mock_get_dcp_group,
mock_pcp,
mock_get_pcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -240,6 +256,13 @@ class TestAscendMLAMetadataBuilder(TestBase):
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group

mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group

mock_spec_config = MagicMock()
mock_spec_config.num_speculative_tokens = 3
mock_vllm_config.speculative_config = mock_spec_config
@@ -256,13 +279,17 @@ class TestAscendMLAMetadataBuilder(TestBase):
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)

@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_ascend_mla_metadata_builder_build_full_graph(
self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -280,6 +307,13 @@ class TestAscendMLAMetadataBuilder(TestBase):
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group

mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group

mock_spec_config = MagicMock()
mock_spec_config.num_speculative_tokens = 1
mock_spec_config.disable_padded_drafter_batch = True
@@ -307,13 +341,16 @@ class TestAscendMLAMetadataBuilder(TestBase):
[1, 2, 4, 5, 6, 6, 7, 8])
self.assertEqual(metadata.decode.block_table.shape[0], 8)

@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_reorder_batch(self, mock_get_dcp_size, mock_dcp,
mock_get_dcp_group):
mock_get_dcp_group, mock_pcp, mock_get_pcp_group):
ascend_config = MagicMock()

mock_vllm_config = MagicMock()
@@ -331,6 +368,13 @@ class TestAscendMLAMetadataBuilder(TestBase):
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group

mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group

mock_vllm_config.speculative_config = None

with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
@@ -358,6 +402,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
self.assertTrue(modified)
input_batch.swap_states.assert_called_once_with(1, 2)

@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@@ -365,7 +412,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
return_value=1)
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_get_dcp_size,
mock_dcp,
mock_get_dcp_group):
mock_get_dcp_group,
mock_pcp,
mock_get_pcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -384,6 +433,13 @@ class TestAscendMLAMetadataBuilder(TestBase):
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group

mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group

builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
mock_device)
input_seq_lens = [1, 2, 4, 5]
@@ -394,14 +450,18 @@ class TestAscendMLAMetadataBuilder(TestBase):
num_reqs_pad_size, num_reqs, input_seq_lens)
self.assertEqual(output_seq_lens, expect_output)

@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_get_dcp_size,
mock_dcp,
mock_get_dcp_group):
mock_dcp, mock_get_dcp_group,
mock_pcp,
mock_get_pcp_group):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -419,6 +479,14 @@ class TestAscendMLAMetadataBuilder(TestBase):
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group

mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group

common_metadata = MagicMock()
common_metadata.actual_seq_lengths_q = [2, 4, 6, 8]

@@ -452,6 +520,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.kv_cache_spec.head_size = 128
self.kv_cache_spec.num_heads = 32

@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@@ -461,9 +530,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
@patch("torch.npu.is_available")
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
mock_zeros, mock_get_ascend_config,
mock_dcp_world_size):
mock_dcp_world_size,
mock_get_pcp_group):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group

def zeros_override(*args, **kwargs):
kwargs.pop('pin_memory', None)
@@ -512,6 +585,7 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@@ -521,9 +595,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
@patch("torch.npu.is_available")
def test_build_chunked_prefix_metadata(self, mock_npu_available,
mock_zeros, mock_get_ascend_config,
mock_dcp_world_size):
mock_dcp_world_size,
mock_get_pcp_group):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group

def zeros_override(*args, **kwargs):
kwargs.pop('pin_memory', None)
@@ -573,14 +651,18 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_decode_only_metadata(self, mock_get_ascend_config,
mock_dcp_world_size):
mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1

pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
@@ -622,14 +704,18 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config,
mock_dcp_world_size):
mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1

pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
@@ -672,14 +758,18 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)

@patch("vllm_ascend.attention.mla_v1.get_pcp_group")
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_for_graph_capture_prefill(self, mock_get_ascend_config,
mock_dcp_world_size):
mock_dcp_world_size,
mock_get_pcp_group):
mock_dcp_world_size.return_value = 1

pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 3, 7]),
query_start_loc_cpu=torch.tensor([0, 3, 7]),
@@ -716,6 +806,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):

class TestAscendMLAImpl(TestBase):

@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
@@ -727,13 +819,16 @@ class TestAscendMLAImpl(TestBase):
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
mock_tp, mock_get_dcp_size, mock_dcp):
mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp):
mock_tp.world_size = 2
mock_tp.rank_in_group = MagicMock()
mock_tp.device_group = MagicMock()
mock_dcp.world_size = 1
mock_dcp.rank_in_group = MagicMock()
mock_dcp.device_group = MagicMock()
mock_pcp.world_size = 1
mock_pcp.rank_in_group = MagicMock()
mock_pcp.device_group = MagicMock()
vllm_config = MagicMock()
speculative_config = MagicMock()
model_config = MagicMock()

@@ -1031,6 +1031,11 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
self.mock_dcp = MagicMock()
self.mock_dcp.world_size = 1

self.mock_pcp_group = MagicMock(spec=GroupCoordinator)
self.mock_pcp_group.rank_in_group = 0
self.mock_pcp_group.world_size = 1
self.mock_pcp_group.device_group = MagicMock()

self.patches = [
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
@@ -1069,7 +1074,9 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
return_value=self.mock_dcp),
patch(
'vllm.distributed.get_decode_context_model_parallel_world_size',
return_value=1)
return_value=1),
patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
return_value=self.mock_pcp_group),
]

for p in self.patches:

@@ -245,6 +245,10 @@ class TestNPUPlatform(TestBase):
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.parallel_config.enable_expert_parallel = False
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
vllm_config.scheduler_config = MagicMock()
@@ -275,6 +279,8 @@ class TestNPUPlatform(TestBase):
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config = None
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
vllm_config.scheduler_config = MagicMock()
@@ -300,6 +306,8 @@ class TestNPUPlatform(TestBase):
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = True
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
vllm_config.scheduler_config = MagicMock()
@@ -338,6 +346,8 @@ class TestNPUPlatform(TestBase):
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
vllm_config.scheduler_config = MagicMock()
@@ -409,6 +419,8 @@ class TestNPUPlatform(TestBase):
mock_init_ascend.return_value = mock_ascend_config
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
vllm_config.scheduler_config = MagicMock()
@@ -446,6 +458,8 @@ class TestNPUPlatform(TestBase):
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.cache_config.block_size = None
vllm_config.cache_config.enable_prefix_caching = True
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
vllm_config.scheduler_config = MagicMock()
@@ -472,6 +486,8 @@ class TestNPUPlatform(TestBase):
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.parallel_config.worker_cls = "auto"
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()
vllm_config.scheduler_config = MagicMock()
@@ -510,6 +526,8 @@ class TestNPUPlatform(TestBase):
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.compilation_config.custom_ops = []
vllm_config.parallel_config.decode_context_parallel_size = 1
vllm_config.parallel_config.prefill_context_parallel_size = 1
vllm_config.parallel_config.tensor_parallel_size = 1
mock_init_recompute.return_value = MagicMock()

@@ -26,10 +26,13 @@ import torch.nn as nn
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.registry import (AttentionBackendEnum,
register_backend)
from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size)
get_decode_context_model_parallel_world_size,
get_pcp_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -41,19 +44,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.utils import prefill_context_parallel_enable, weak_ref_tensors

# isort: off
if prefill_context_parallel_enable():
from vllm.distributed import (get_pcp_group,
get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)

# isort: on

from vllm.attention.backends.registry import (AttentionBackendEnum,
register_backend)
from vllm_ascend.utils import weak_ref_tensors

@register_backend(AttentionBackendEnum.CUSTOM, "ASCEND")
@@ -255,10 +246,9 @@ class AscendAttentionMetadataBuilder:
vllm_config.scheduler_config.max_num_batched_tokens,
dtype=torch.uint8,
device=device)
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
@@ -350,8 +340,7 @@ class AscendAttentionMetadataBuilder:
context_lens_cpu = num_computed_tokens_cpu[
num_decodes:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
pcp_size = get_pcp_group().world_size
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
local_context_lens_allranks = torch.tensor(
num_computed_tokens_of_pcp_dcp
@@ -539,10 +528,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None
self.value_cache = None
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group(
).device_group if self.pcp_size > 1 else None

@@ -13,7 +13,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_pcp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.forward_context import ForwardContext, get_forward_context
@@ -37,17 +37,9 @@ from vllm_ascend.compilation.acl_graph import (get_graph_params,
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_enable_nz, prefill_context_parallel_enable,
weak_ref_tensors)
is_enable_nz, weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import InputBatch

# isort: off
if prefill_context_parallel_enable():
from vllm.distributed import (get_pcp_group,
get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort: on
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput

@@ -265,15 +257,13 @@ class AscendMLAMetadataBuilder:
self.cos_cache = None
self.sin_cache = None

self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size if prefill_context_parallel_enable(
) else 1
self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size
self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
@@ -868,10 +858,9 @@ class AscendMLAImpl(MLAAttentionImpl):
self.speculative_config = vllm_config.speculative_config
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group(
).device_group if self.pcp_size > 1 else None

@@ -6,7 +6,7 @@ import torch
from vllm.config import VllmConfig
from vllm.distributed import (get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_tensor_model_parallel_rank,
get_pcp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import logger
from vllm.v1.core.kv_cache_utils import BlockHash
@@ -22,14 +22,6 @@ from vllm_ascend.distributed.kvpool.config_data import (
from vllm_ascend.distributed.kvpool.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_ascend.utils import prefill_context_parallel_enable

if prefill_context_parallel_enable():
# isort: off
from vllm.distributed import (get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort: on

backend_map: Dict[str, Type[Backend]] = {
"mooncake": MooncakeBackend,
@@ -57,10 +49,9 @@ class KVPoolWorker:
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()

self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0

@@ -22,10 +22,11 @@ from vllm import envs
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (get_dcp_group, get_tp_group,
get_world_group)
from vllm.distributed.parallel_state import (get_dcp_group, get_pcp_group,
get_tp_group, get_world_group)
from vllm.forward_context import ForwardContext
from vllm.logger import logger
from vllm.utils.network_utils import get_ip
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -33,14 +34,7 @@ from vllm.v1.request import Request, RequestStatus

import vllm_ascend.envs as envs_ascend
from vllm_ascend.distributed.utils import get_transfer_timeout_value
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
prefill_context_parallel_enable)

if prefill_context_parallel_enable():
from vllm.distributed.parallel_state import \
get_prefill_context_model_parallel_rank

from vllm.utils.network_utils import get_ip
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type

TORCH_DTYPE_TO_NPU_DTYPE = {
torch.half: llm_datadist.DataType.DT_FLOAT16,
@@ -203,8 +197,7 @@ class LLMDataDistCMgrConnectorScheduler():
else:
dp_rank_local = vllm_config.parallel_config.data_parallel_rank_local
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
self.pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size if prefill_context_parallel_enable(
) else 1
self.pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size

self.port = dp_rank_local * self.pcp_size * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
@@ -345,10 +338,8 @@ class LLMDataDistCMgrConnectorWorker():
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.tp_rank = get_tp_group().rank_in_group
self.rank = get_world_group().rank
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size if prefill_context_parallel_enable(
) else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if prefill_context_parallel_enable() else 0
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.pcp_rank = get_pcp_group().rank_in_group
self.dcp_size = get_dcp_group().world_size
self.local_ip = get_ip()
self.kv_transfer_config: KVTransferConfig = vllm_config.kv_transfer_config

@@ -27,9 +27,10 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_decode_context_model_parallel_world_size, get_pcp_group,
get_tensor_model_parallel_rank, get_tp_group)
from vllm.logger import logger
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import RequestStatus
@@ -38,16 +39,6 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
from vllm_ascend.distributed.utils import get_transfer_timeout_value
from vllm_ascend.utils import prefill_context_parallel_enable

# isort: off
if prefill_context_parallel_enable():
from vllm.distributed import (get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort: on

from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -730,8 +721,7 @@ class MooncakeConnectorScheduler:
logger.info("Initializing Mooncake Scheduler %s", engine_id)

self.side_channel_host = get_ip()
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size \
if prefill_context_parallel_enable() else 1
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
vllm_config.parallel_config.data_parallel_size * \
@@ -898,10 +888,9 @@ class MooncakeConnectorWorker:
self.dp_size = vllm_config.parallel_config.data_parallel_size_local
self.kv_caches: dict[str, torch.Tensor] = {}
self.side_channel_host = get_ip()
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0

@@ -9,8 +9,7 @@ from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (flashcomm2_enable,
prefill_context_parallel_enable)
from vllm_ascend.utils import flashcomm2_enable

# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
@@ -74,15 +73,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
if prefill_context_parallel_enable():
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.prefill_context_parallel_size *
parallel_config.tensor_parallel_size)
else:
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.prefill_context_parallel_size *
parallel_config.tensor_parallel_size)

pd_tp_ratio = get_ascend_config().pd_tp_ratio
pd_head_ratio = get_ascend_config().pd_head_ratio

@@ -24,16 +24,13 @@ import torch.nn as nn
import torch_npu
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_dp_group, get_tensor_model_parallel_rank,
get_dp_group, get_pcp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig

from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable

if prefill_context_parallel_enable():
from vllm.distributed import get_pcp_group

class QuantType(Enum):
NONE = 0

@@ -33,11 +33,12 @@ from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
from vllm_ascend.utils import refresh_block_size

# isort: off
from vllm_ascend.utils import (
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType,
enable_sp, get_ascend_device_type, is_vl_model,
prefill_context_parallel_enable, update_aclgraph_sizes,
update_cudagraph_capture_sizes, update_default_aclgraph_sizes)
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD,
COMPRESSED_TENSORS_METHOD, AscendDeviceType,
enable_sp, get_ascend_device_type, is_vl_model,
update_aclgraph_sizes,
update_cudagraph_capture_sizes,
update_default_aclgraph_sizes)

if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
@@ -329,7 +330,6 @@ class NPUPlatform(Platform):
vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch

if vllm_config.kv_transfer_config is not None and \
prefill_context_parallel_enable() and \
cache_config.block_size != parallel_config.cp_kv_cache_interleave_size and \
parallel_config.decode_context_parallel_size * parallel_config.prefill_context_parallel_size > 1:
raise AssertionError(

@@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm.config import (CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
from vllm.distributed import get_pcp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -16,6 +17,8 @@ from vllm.model_executor.model_loader.utils import \
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import set_default_torch_dtype
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.sched.output import SchedulerOutput
@@ -32,15 +35,8 @@ from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
update_mla_attn_params)
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
prefill_context_parallel_enable,
shared_expert_dp_enabled)

if prefill_context_parallel_enable():
from vllm.distributed import get_pcp_group

from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)

PADDING_SLOT_ID = -1

@@ -2,14 +2,9 @@ from typing import Optional, Union

import numpy as np
import torch
from vllm.distributed import get_dcp_group
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.utils.math_utils import cdiv

from vllm_ascend.utils import prefill_context_parallel_enable

if prefill_context_parallel_enable():
from vllm.distributed import get_pcp_group

class BlockTable:

@@ -31,8 +26,7 @@ class BlockTable:
self.physical_block_size = block_size

try:
self.pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_world_size > 1 else 0
self.dcp_world_size = get_dcp_group().world_size
@@ -279,8 +273,7 @@ class MultiGroupBlockTable:
# must be multiplied by dcp_world_size.
try:
dcp_world_size = get_dcp_group().world_size
pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
pcp_world_size = get_pcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1

@@ -52,7 +52,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dcp_group, get_dp_group,
get_pp_group, get_tp_group,
get_pcp_group, get_pp_group,
get_tp_group,
is_global_first_rank)
from vllm.forward_context import get_forward_context
from vllm.logger import logger
@@ -145,16 +146,9 @@ from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendDeviceType, ProfileExecuteDuration,
enable_sp, get_ascend_device_type, is_enable_nz,
is_moe_model, lmhead_tp_enable,
prefill_context_parallel_enable)
is_moe_model, lmhead_tp_enable)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

if prefill_context_parallel_enable():
from vllm.distributed import get_pcp_group
from vllm.distributed.parallel_state import (
get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size)

if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -290,10 +284,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_size > 1 else 0
decode_max_num_seqs = getattr(self.scheduler_config,
'decode_max_num_seqs', 0)
self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
@@ -602,8 +595,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
if self.vllm_config.speculative_config else 0),
kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
cp_kv_cache_interleave_size=self.parallel_config.
cp_kv_cache_interleave_size
if prefill_context_parallel_enable() else 1,
cp_kv_cache_interleave_size,
)
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
@@ -2742,8 +2734,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
device=self.device)
long_seq_metadata = self._generate_pcp_metadata(num_tokens)
if long_seq_metadata is not None:
pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
pcp_world_size = get_pcp_group().world_size
dcp_world_size = get_dcp_group().world_size
num_computed_tokens_of_pcp_dcp = [[
[0] * dcp_world_size for _ in range(pcp_world_size)

@@ -53,7 +53,6 @@ from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz,
prefill_context_parallel_enable,
register_ascend_customop, sleep_mode_enabled,
try_register_lib)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -405,17 +404,11 @@ class NPUWorker(WorkerBase):
init_distributed_environment(self.parallel_config.world_size,
self.rank, self.distributed_init_method,
self.local_rank, "hccl")
if prefill_context_parallel_enable():
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size,
self.parallel_config.prefill_context_parallel_size,
self.parallel_config.decode_context_parallel_size)
else:
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size,
self.parallel_config.decode_context_parallel_size)
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size,
self.parallel_config.prefill_context_parallel_size,
self.parallel_config.decode_context_parallel_size)
init_ascend_model_parallel(self.parallel_config)
ensure_kv_transfer_initialized(self.vllm_config)
ensure_ec_transfer_initialized(self.vllm_config)