support cp&dcp (#3260)

### What this PR does / why we need it?
This PR adds the Prefill Context Parallelism (PCP) feature, which
corresponds to DCP. For specific implementation details, please refer to
the RFC https://github.com/vllm-project/vllm/issues/25749.
TL;DR: PCP enhances long-sequence inference capabilities by partitioning
the sequence dimension during the prefill stage.
### Does this PR introduce _any_ user-facing change?
The current implementation primarily includes the following changes:

Modified ModelRunner.py for CP partitioning logic for tokens;
Modified attention_v1.py and mla_v1.py to adapt the GQA/MLA backend to
PCP.
Modified block_tables.py to extend the KV cache storage based on
DCP&PCP;
Added necessary command-line arguments to control parallelism for PCP;
### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: chenjie <chenjie137@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: Feng Liu <liufeng248@huawei.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Signed-off-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: chenjie <chenjie137@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Co-authored-by: Feng Liu <liufeng248@huawei.com>
Co-authored-by: gaojc <1055866782@qq.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: w00896881 <wangzixuan40@huawei.com>
This commit is contained in:
LookAround0301
2025-10-24 10:32:01 +08:00
committed by GitHub
parent 2bcadcb9d5
commit b54d44e664
18 changed files with 1729 additions and 211 deletions

View File

@@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
@@ -175,7 +176,19 @@ class TestAscendAttentionMetadataBuilder(TestBase):
class TestAscendAttentionBackendImpl(TestBase):
def setUp(self):
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
self.layer = MagicMock()
self.layer.layer_name = "test_layer"
self.layer._k_scale_float = 1.0
@@ -328,6 +341,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.seq_lens = torch.tensor([10])
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
# layer.quant_method.apply.return_value = metadata
print(self.layer_no_quant._v_scale_float)
@@ -360,6 +375,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
output = self.impl.forward(layer,
@@ -390,6 +407,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=False)
@@ -496,6 +515,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=True)
@@ -527,6 +548,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 100
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
@@ -560,6 +583,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
@@ -579,11 +604,13 @@ class TestAscendAttentionBackendImpl(TestBase):
assert output.shape == (10, 8 * 64)
@patch('torch.version')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
def test_forward_head_size_192(self, mock_vanilla_prefill,
mock_npu_reshape_and_cache, mock_is_310p):
mock_npu_reshape_and_cache, mock_is_310p,
mock_version):
"""Test forward pass when head_size is 192"""
self.impl.head_size = 192
@@ -598,7 +625,10 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 10
metadata.num_prefills = 0
layer = self.layer_no_quant
mock_version.cann = "8.4.RC1"
mock_vanilla_prefill.return_value = MagicMock()
output = self.impl_192.forward(layer,
@@ -612,10 +642,12 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_vanilla_prefill.assert_called_once()
assert output.shape == (10, 8 * 192)
@patch('torch.version')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention_splitfuse')
def test_forward_normal_v1_situation(self, mock_paged_attention,
mock_npu_reshape_and_cache):
mock_npu_reshape_and_cache,
mock_version):
"""Test forward pass in normal V1 situation"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -628,8 +660,12 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_version.cann = "8.4.RC1"
output = self.impl.forward(layer,
query,
key,
@@ -641,13 +677,14 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch.version')
@patch('torch_npu.npu_format_cast')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention_splitfuse')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
def test_forward_310p_device(self, mock_is_310p, mock_paged_attention,
mock_npu_reshape_and_cache,
mock_npu_format_cast):
mock_npu_format_cast, mock_version):
"""Test forward pass on 310P device"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
@@ -660,9 +697,12 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
mock_npu_format_cast.return_value = metadata.attn_mask
mock_version.cann = "8.4.RC1"
output = self.impl.forward(layer,
query,
key,
@@ -687,6 +727,8 @@ class TestAscendAttentionBackendImpl(TestBase):
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
metadata.num_decodes = 0
metadata.num_prefills = 10
layer = self.layer_no_quant
with self.assertRaises(NotImplementedError):