support cp&dcp (#3260)

### What this PR does / why we need it?
This PR adds the Prefill Context Parallelism (PCP) feature, the
prefill-stage counterpart of Decode Context Parallelism (DCP). For
implementation details, please refer to the RFC
https://github.com/vllm-project/vllm/issues/25749.
TL;DR: PCP improves long-sequence inference by partitioning the sequence
dimension during the prefill stage.
### Does this PR introduce _any_ user-facing change?
The current implementation primarily includes the following changes:

- Modified ModelRunner.py to add the CP partitioning logic for tokens;
- Modified attention_v1.py and mla_v1.py to adapt the GQA/MLA backends
  to PCP;
- Modified block_tables.py to extend the KV cache storage based on
  DCP & PCP;
- Added the necessary command-line arguments to control parallelism for PCP.
### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: chenjie <chenjie137@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: Feng Liu <liufeng248@huawei.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Signed-off-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: chenjie <chenjie137@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Co-authored-by: Feng Liu <liufeng248@huawei.com>
Co-authored-by: gaojc <1055866782@qq.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: w00896881 <wangzixuan40@huawei.com>
This commit is contained in:
LookAround0301
2025-10-24 10:32:01 +08:00
committed by GitHub
parent 2bcadcb9d5
commit b54d44e664
18 changed files with 1729 additions and 211 deletions

View File

@@ -5,6 +5,11 @@ import torch
from vllm.distributed import get_dcp_group
from vllm.utils import cdiv
from vllm_ascend.utils import prefill_context_parallel_enable
if prefill_context_parallel_enable():
from vllm.distributed import get_pcp_group
class BlockTable:
@@ -15,7 +20,8 @@ class BlockTable:
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_sizes: Union[list[int], None] = None):
kernel_sizes: Union[list[int], None] = None,
cp_kv_cache_interleave_size: int = 1):
self.max_num_reqs = max_num_reqs
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens
@@ -80,13 +86,20 @@ class BlockTable:
dtype=torch.int64,
device=self.device)
try:
self.pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
self.pcp_rank = get_pcp_group(
).rank_in_group if self.pcp_world_size > 1 else 0
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.pcp_world_size = 1
self.pcp_rank = 0
self.kernel_sizes = kernel_sizes
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
def append_row(
self,
@@ -132,14 +145,14 @@ class BlockTable:
# here because M (max_model_len) is not necessarily divisible by
# block_size.
if self.dcp_world_size > 1:
if self.dcp_world_size * self.pcp_world_size > 1:
# Note(hc): The DCP implementation stores the kvcache in an interleaved
# style: the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size.
# Use a "virtual block", which equals world_size * block_size,
# for the block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size
virtual_block_size = self.block_size * self.dcp_world_size * self.pcp_world_size
# IMPORTANT: In hybrid mode, positions are in logical block space,
# but we need to map them to the correct logical block table indices
@@ -157,9 +170,14 @@ class BlockTable:
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank
mask = (virtual_block_offsets // self.cp_kv_cache_interleave_size %
(self.dcp_world_size *
self.pcp_world_size) == self.current_rank)
# Calculate local block_offsets
block_offsets = virtual_block_offsets // self.dcp_world_size
block_offsets = virtual_block_offsets \
// (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size) \
* self.cp_kv_cache_interleave_size + virtual_block_offsets % self.cp_kv_cache_interleave_size
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
@@ -242,16 +260,20 @@ class MultiGroupBlockTable:
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0,
kernel_sizes: Optional[list[list[int]]] = None) -> None:
kernel_sizes: Optional[list[list[int]]] = None,
cp_kv_cache_interleave_size: int = 1) -> None:
# Note(hc): each dcp rank only stores
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size used to calculate max_num_blocks_per_req
# must be multiplied by dcp_world_size.
try:
dcp_world_size = get_dcp_group().world_size
cp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
cp_world_size = 1
if kernel_sizes is None:
kernel_sizes = [[0]] * len(block_sizes)
@@ -267,9 +289,12 @@ class MultiGroupBlockTable:
self.block_tables = [
BlockTable(
block_size, max_num_reqs,
max(cdiv(max_model_len, block_size * dcp_world_size),
max(
cdiv(max_model_len,
block_size * dcp_world_size * cp_world_size),
1 + num_speculative_tokens), max_num_batched_tokens,
pin_memory, device, kernel_size_list)
pin_memory, device, kernel_size_list,
cp_kv_cache_interleave_size)
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
]