support cp&dcp (#3260)
### What this PR does / why we need it?

This PR adds the Prefill Context Parallelism (PCP) feature, the prefill-stage counterpart of Decode Context Parallelism (DCP). For specific implementation details, please refer to the RFC: https://github.com/vllm-project/vllm/issues/25749.

TL;DR: PCP enhances long-sequence inference capabilities by partitioning the sequence dimension during the prefill stage.

### Does this PR introduce _any_ user-facing change?

The current implementation primarily includes the following changes:

- Modified ModelRunner.py to add the CP partitioning logic for tokens.
- Modified attention_v1.py and mla_v1.py to adapt the GQA/MLA backends to PCP.
- Modified block_tables.py to extend the KV cache storage based on DCP and PCP.
- Added the necessary command-line arguments to control parallelism for PCP.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: chenjie <chenjie137@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: Feng Liu <liufeng248@huawei.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Signed-off-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: chenjie <chenjie137@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Co-authored-by: Feng Liu <liufeng248@huawei.com>
Co-authored-by: gaojc <1055866782@qq.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: w00896881 <wangzixuan40@huawei.com>
This commit is contained in:
@@ -130,6 +130,7 @@ class TestAscendMLADecodeMetadata(TestBase):
|
||||
class TestAscendMLAMetadata(TestBase):
|
||||
|
||||
def test_ascend_mla_metadata_default(self):
|
||||
num_actual_tokens_pcp_padded = 100
|
||||
num_actual_tokens = 100
|
||||
slot_mapping = torch.randn(100, 4, 1024)
|
||||
query_start_loc = torch.tensor([1, 2, 3, 4])
|
||||
@@ -150,12 +151,11 @@ class TestAscendMLAMetadata(TestBase):
|
||||
decode = None
|
||||
prefill = None
|
||||
|
||||
metadata = AscendMLAMetadata(num_actual_tokens, slot_mapping,
|
||||
query_start_loc, seq_lens, block_tables,
|
||||
num_decodes, num_decode_tokens,
|
||||
num_prefills, num_input_tokens,
|
||||
query_lens, head_dim, attn_mask,
|
||||
attn_state, decode, prefill)
|
||||
metadata = AscendMLAMetadata(
|
||||
num_actual_tokens_pcp_padded, num_actual_tokens, slot_mapping,
|
||||
query_start_loc, seq_lens, block_tables, num_decodes,
|
||||
num_decode_tokens, num_prefills, num_input_tokens, query_lens,
|
||||
head_dim, attn_mask, attn_state, decode, prefill)
|
||||
|
||||
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
|
||||
self.assertIs(metadata.slot_mapping, slot_mapping)
|
||||
@@ -266,6 +266,10 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
|
||||
class TestAscendMLAImpl(TestBase):
|
||||
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch('vllm.distributed.parallel_state._TP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
|
||||
@@ -273,8 +277,13 @@ class TestAscendMLAImpl(TestBase):
|
||||
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
|
||||
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
|
||||
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
|
||||
mock_tp):
|
||||
mock_tp, mock_get_dcp_size, mock_dcp):
|
||||
mock_tp.world_size = 2
|
||||
mock_tp.rank_in_group = MagicMock()
|
||||
mock_tp.device_group = MagicMock()
|
||||
mock_dcp.world_size = 1
|
||||
mock_dcp.rank_in_group = MagicMock()
|
||||
mock_dcp.device_group = MagicMock()
|
||||
vllm_config = MagicMock()
|
||||
speculative_config = MagicMock()
|
||||
model_config = MagicMock()
|
||||
|
||||
Reference in New Issue
Block a user