support cp&dcp (#3260)

### What this PR does / why we need it?
This PR adds the Prefill Context Parallelism (PCP) feature, the prefill-stage counterpart of Decode Context Parallelism (DCP). For implementation details, please refer to the RFC: https://github.com/vllm-project/vllm/issues/25749.
TL;DR: PCP improves long-sequence inference by partitioning the sequence dimension across PCP ranks during the prefill stage.
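To make the idea concrete, here is a minimal sketch (not the code in this PR) of how a prefill prompt can be sliced along the sequence dimension; the helper name `split_prefill_tokens` and the `pcp_size`/`pcp_rank` parameters are illustrative assumptions.

```python
# Minimal sketch of sequence-dimension partitioning for PCP (illustrative
# only; the real logic lives in the modified model runner). The prompt is
# padded so it splits evenly, which mirrors the num_actual_tokens_pcp_padded
# field added to the MLA metadata by this PR.
import torch


def split_prefill_tokens(token_ids: torch.Tensor, pcp_size: int,
                         pcp_rank: int) -> torch.Tensor:
    """Return the contiguous chunk of the prompt owned by one PCP rank."""
    num_tokens = token_ids.shape[0]
    chunk = (num_tokens + pcp_size - 1) // pcp_size        # ceil-divide
    pad = chunk * pcp_size - num_tokens
    padded = torch.nn.functional.pad(token_ids, (0, pad))  # pad with zeros
    return padded[pcp_rank * chunk:(pcp_rank + 1) * chunk]


# Example: a 10-token prompt split across 4 PCP ranks (last rank holds padding).
tokens = torch.arange(10)
print([split_prefill_tokens(tokens, 4, r).tolist() for r in range(4)])
```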
### Does this PR introduce _any_ user-facing change?
The current implementation primarily includes the following changes:

- Modified `ModelRunner.py` to add the CP partitioning logic for tokens.
- Modified `attention_v1.py` and `mla_v1.py` to adapt the GQA/MLA backends to PCP.
- Modified `block_tables.py` to extend the KV cache storage for DCP & PCP.
- Added the command-line arguments needed to control PCP parallelism (see the usage sketch below).
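For illustration only, enabling the new knobs might look like the snippet below. The model name is just an example, and `prefill_context_parallel_size` is an assumed argument name mirrored from vLLM's existing `decode_context_parallel_size` (DCP) engine argument; check the merged code for the actual flag.

```python
# Hypothetical engine configuration; prefill_context_parallel_size is an
# assumed name, not confirmed by this PR. decode_context_parallel_size is
# vLLM's existing DCP engine argument.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    tensor_parallel_size=4,
    decode_context_parallel_size=2,    # DCP: shard the KV cache across decode CP ranks
    prefill_context_parallel_size=2,   # assumed PCP knob added by this PR
)
```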
### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: chenjie <chenjie137@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: Feng Liu <liufeng248@huawei.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Signed-off-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: chenjie <chenjie137@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Co-authored-by: Feng Liu <liufeng248@huawei.com>
Co-authored-by: gaojc <1055866782@qq.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: w00896881 <wangzixuan40@huawei.com>


@@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
@@ -175,7 +176,19 @@ class TestAscendAttentionMetadataBuilder(TestBase):
class TestAscendAttentionBackendImpl(TestBase):
def setUp(self):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
self.layer = MagicMock()
self.layer.layer_name = "test_layer"
self.layer._k_scale_float = 1.0
@@ -328,6 +341,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.seq_lens = torch.tensor([10])
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
# layer.quant_method.apply.return_value = metadata
print(self.layer_no_quant._v_scale_float)
@@ -360,6 +375,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
output = self.impl.forward(layer,
@@ -390,6 +407,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -496,6 +515,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=True)
@@ -527,6 +548,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 100
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
@@ -560,6 +583,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
@@ -579,11 +604,13 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64)
@patch('torch.version')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
def test_forward_head_size_192(self, mock_vanilla_prefill,
mock_npu_reshape_and_cache, mock_is_310p):
mock_npu_reshape_and_cache, mock_is_310p,
mock_version):
"""Test forward pass when head_size is 192"""
self.impl.head_size = 192
@@ -598,7 +625,10 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_version.cann = "8.4.RC1"
mock_vanilla_prefill.return_value = MagicMock()
output = self.impl_192.forward(layer,
@@ -612,10 +642,12 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_vanilla_prefill.assert_called_once()
assert output.shape == (10, 8 * 192)
@patch('torch.version')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention_splitfuse')
def test_forward_normal_v1_situation(self, mock_paged_attention,
mock_npu_reshape_and_cache):
mock_npu_reshape_and_cache,
mock_version):
"""Test forward pass in normal V1 situation"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -628,8 +660,12 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_version.cann = "8.4.RC1"
output = self.impl.forward(layer,
query,
key,
@@ -641,13 +677,14 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch.version')
@patch('torch_npu.npu_format_cast')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention_splitfuse')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
def test_forward_310p_device(self, mock_is_310p, mock_paged_attention,
mock_npu_reshape_and_cache,
mock_npu_format_cast):
mock_npu_format_cast, mock_version):
"""Test forward pass on 310P device"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -660,9 +697,12 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_npu_format_cast.return_value = metadata.attn_mask
mock_version.cann = "8.4.RC1"
output = self.impl.forward(layer,
query,
key,
@@ -687,6 +727,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
with self.assertRaises(NotImplementedError):


@@ -130,6 +130,7 @@ class TestAscendMLADecodeMetadata(TestBase):
class TestAscendMLAMetadata(TestBase):
def test_ascend_mla_metadata_default(self):
num_actual_tokens_pcp_padded = 100
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
query_start_loc = torch.tensor([1, 2, 3, 4])
@@ -150,12 +151,11 @@ class TestAscendMLAMetadata(TestBase):
decode = None
prefill = None
metadata = AscendMLAMetadata(num_actual_tokens, slot_mapping,
query_start_loc, seq_lens, block_tables,
num_decodes, num_decode_tokens,
num_prefills, num_input_tokens,
query_lens, head_dim, attn_mask,
attn_state, decode, prefill)
metadata = AscendMLAMetadata(
num_actual_tokens_pcp_padded, num_actual_tokens, slot_mapping,
query_start_loc, seq_lens, block_tables, num_decodes,
num_decode_tokens, num_prefills, num_input_tokens, query_lens,
head_dim, attn_mask, attn_state, decode, prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
@@ -266,6 +266,10 @@ class TestAscendMLAMetadataBuilder(TestBase):
class TestAscendMLAImpl(TestBase):
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
@@ -273,8 +277,13 @@ class TestAscendMLAImpl(TestBase):
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
mock_tp):
mock_tp, mock_get_dcp_size, mock_dcp):
mock_tp.world_size = 2
mock_tp.rank_in_group = MagicMock()
mock_tp.device_group = MagicMock()
mock_dcp.world_size = 1
mock_dcp.rank_in_group = MagicMock()
mock_dcp.device_group = MagicMock()
vllm_config = MagicMock()
speculative_config = MagicMock()
model_config = MagicMock()


@@ -80,6 +80,8 @@ def test_read_agent_metadata():
worker.local_ip = worker_local_ip
worker.tp_rank = worker_tp_rank
worker.llm_datadist_role = LLMRole.PROMPT
worker.pcp_rank = 0
worker.tp_size = worker_tp_rank + 1
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = worker_visible_devices
agent_metadata = LLMDataDistCMgrConnectorWorker.read_agent_metadata(
worker, rank_table)


@@ -149,7 +149,9 @@ def create_request(
range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
remote_tp_size=1)
remote_tp_size=1,
remote_cp_size=1,
remote_dcp_size=1)
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)


@@ -13,7 +13,7 @@
# This file is a part of the vllm-ascend project.
#
from types import SimpleNamespace
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
import torch
@@ -100,6 +100,11 @@ def mock_distributed():
pp_group.rank_in_group = 0
pp_group.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mlp_tp_group = Mock(spec=GroupCoordinator)
mlp_tp_group.rank_in_group = 0
mlp_tp_group.world_size = 1
@@ -117,6 +122,9 @@ def mock_distributed():
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch('vllm.distributed.parallel_state.get_dcp_group', return_value=dcp_group), \
patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)), \
patch("vllm.distributed.get_decode_context_model_parallel_world_size", return_value=1),\
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \