support cp&dcp (#3260)
### What this PR does / why we need it?

This PR adds the Prefill Context Parallelism (PCP) feature, the prefill-stage counterpart of Decode Context Parallelism (DCP). For implementation details, please refer to the RFC: https://github.com/vllm-project/vllm/issues/25749. TL;DR: PCP improves long-sequence inference by partitioning the sequence dimension during the prefill stage.

### Does this PR introduce _any_ user-facing change?

The current implementation primarily includes the following changes:

- Modified ModelRunner.py to add the CP partitioning logic for tokens.
- Modified attention_v1.py and mla_v1.py to adapt the GQA/MLA backends to PCP.
- Modified block_tables.py to extend the KV cache storage layout for DCP & PCP.
- Added the command-line arguments needed to control PCP parallelism.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: chenjie <chenjie137@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: Feng Liu <liufeng248@huawei.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Signed-off-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: chenjie <chenjie137@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Co-authored-by: Feng Liu <liufeng248@huawei.com>
Co-authored-by: gaojc <1055866782@qq.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: w00896881 <wangzixuan40@huawei.com>
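The diff below extends `BlockTable` so the KV cache is stored in an interleaved style across the combined DCP × PCP ranks: tokens are dealt out in chunks of `cp_kv_cache_interleave_size`, and a chunk starting at token `i` is owned by the rank whose combined index equals `(i // interleave) % (dcp_world_size * pcp_world_size)`. A minimal, self-contained sketch of that layout, assuming the interleave size divides the block size; the helper name `kv_owner_rank` is illustrative and not part of this patch:

```python
# Illustrative only: a standalone restatement of the interleaved KV-cache
# layout used by this patch; the helper name is hypothetical.
def kv_owner_rank(token_idx: int,
                  dcp_world_size: int,
                  pcp_world_size: int,
                  interleave_size: int = 1) -> int:
    """Combined CP rank (dcp_world_size * pcp_rank + dcp_rank) that stores
    the KV cache for token `token_idx`."""
    total_cp_ranks = dcp_world_size * pcp_world_size
    # Tokens are dealt out in chunks of `interleave_size`, round-robin
    # across all DCP x PCP ranks.
    return (token_idx // interleave_size) % total_cp_ranks

# With dcp_world_size=2, pcp_world_size=2 and interleave_size=1,
# tokens 0,1,2,3,4,5 land on combined ranks 0,1,2,3,0,1.
assert [kv_owner_rank(i, 2, 2) for i in range(6)] == [0, 1, 2, 3, 0, 1]
```

With `pcp_world_size == 1` this reduces to the existing DCP behavior, which is why the `dcp_world_size > 1` checks in the diff become `dcp_world_size * pcp_world_size > 1`.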
@@ -5,6 +5,11 @@ import torch
 from vllm.distributed import get_dcp_group
 from vllm.utils import cdiv
 
+from vllm_ascend.utils import prefill_context_parallel_enable
+
+if prefill_context_parallel_enable():
+    from vllm.distributed import get_pcp_group
+
 
 class BlockTable:
 
@@ -15,7 +20,8 @@ class BlockTable:
                  max_num_batched_tokens: int,
                  pin_memory: bool,
                  device: torch.device,
-                 kernel_sizes: Union[list[int], None] = None):
+                 kernel_sizes: Union[list[int], None] = None,
+                 cp_kv_cache_interleave_size: int = 1):
         self.max_num_reqs = max_num_reqs
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.max_num_batched_tokens = max_num_batched_tokens
@@ -80,13 +86,20 @@ class BlockTable:
                                          dtype=torch.int64,
                                          device=self.device)
         try:
+            self.pcp_world_size = get_pcp_group(
+            ).world_size if prefill_context_parallel_enable() else 1
+            self.pcp_rank = get_pcp_group(
+            ).rank_in_group if self.pcp_world_size > 1 else 0
             self.dcp_world_size = get_dcp_group().world_size
             self.dcp_rank = get_dcp_group().rank_in_group
         except AssertionError:
             # DCP might not be initialized in testing
             self.dcp_world_size = 1
             self.dcp_rank = 0
+            self.pcp_world_size = 1
+            self.pcp_rank = 0
         self.kernel_sizes = kernel_sizes
+        self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
 
     def append_row(
         self,
@@ -132,14 +145,14 @@ class BlockTable:
         # here because M (max_model_len) is not necessarily divisible by
         # block_size.
 
-        if self.dcp_world_size > 1:
+        if self.dcp_world_size * self.pcp_world_size > 1:
             # Note(hc): The DCP implement store kvcache with an interleave
             # style, the kvcache for the token whose token_idx is i is
             # always stored on the GPU whose dcp_rank equals i % cp_world_size:
 
             # Use a "virtual block" which equals to world_size * block_size
             # for block_table_indices calculation.
-            virtual_block_size = self.block_size * self.dcp_world_size
+            virtual_block_size = self.block_size * self.dcp_world_size * self.pcp_world_size
 
             # IMPORTANT: In hybrid mode, positions are in logical block space,
             # but we need to map them to the correct logical block table indices
@@ -157,9 +170,14 @@ class BlockTable:
             # Use virtual_block_size for mask calculation, which marks local
             # tokens.
             virtual_block_offsets = positions % virtual_block_size
-            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
+            self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank
+            mask = (virtual_block_offsets // self.cp_kv_cache_interleave_size %
+                    (self.dcp_world_size *
+                     self.pcp_world_size) == self.current_rank)
             # Calculate local block_offsets
-            block_offsets = virtual_block_offsets // self.dcp_world_size
+            block_offsets = virtual_block_offsets \
+                // (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size) \
+                * self.cp_kv_cache_interleave_size + virtual_block_offsets % self.cp_kv_cache_interleave_size
             # Calculate slot_mapping
             slot_mapping = block_numbers * self.block_size + block_offsets
             # Write final slots, use -1 for not-local
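To make the new mask and offset arithmetic concrete, here is a worked example with assumed sizes (not part of the patch; variable names follow the hunk above):

```python
# Assumed sizes, for illustration only; mirrors the formulas added above.
block_size = 4
dcp_world_size, pcp_world_size = 2, 2
cp_kv_cache_interleave_size = 1
virtual_block_size = block_size * dcp_world_size * pcp_world_size   # 16

position = 21                                    # global token index
vbo = position % virtual_block_size              # 21 % 16 = 5
total_cp_ranks = dcp_world_size * pcp_world_size
owner = vbo // cp_kv_cache_interleave_size % total_cp_ranks
# owner == 1: only the rank with dcp_world_size * pcp_rank + dcp_rank == 1
# passes the mask and writes this token's KV cache.

block_offset = (vbo // (total_cp_ranks * cp_kv_cache_interleave_size)
                * cp_kv_cache_interleave_size
                + vbo % cp_kv_cache_interleave_size)
# block_offset == 1: on the owning rank the token occupies local slot 1, so
# each rank keeps only block_size of every virtual_block_size positions.
# slot_mapping is then block_number * block_size + block_offset, and every
# other rank records -1 (not local).
```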
@@ -242,16 +260,20 @@ class MultiGroupBlockTable:
                  device: torch.device,
                  block_sizes: list[int],
                  num_speculative_tokens: int = 0,
-                 kernel_sizes: Optional[list[list[int]]] = None) -> None:
+                 kernel_sizes: Optional[list[list[int]]] = None,
+                 cp_kv_cache_interleave_size: int = 1) -> None:
         # Note(hc): each dcp rank only store
         # (max_model_len//dcp_world_size) tokens in kvcache,
         # so the block_size which used for calc max_num_blocks_per_req
         # must be multiplied by dcp_world_size.
         try:
             dcp_world_size = get_dcp_group().world_size
+            cp_world_size = get_pcp_group(
+            ).world_size if prefill_context_parallel_enable() else 1
         except AssertionError:
             # DCP might not be initialized in testing
             dcp_world_size = 1
+            cp_world_size = 1
 
         if kernel_sizes is None:
             kernel_sizes = [[0]] * len(block_sizes)
@@ -267,9 +289,12 @@ class MultiGroupBlockTable:
         self.block_tables = [
             BlockTable(
                 block_size, max_num_reqs,
-                max(cdiv(max_model_len, block_size * dcp_world_size),
+                max(
+                    cdiv(max_model_len,
+                         block_size * dcp_world_size * cp_world_size),
                     1 + num_speculative_tokens), max_num_batched_tokens,
-                pin_memory, device, kernel_size_list)
+                pin_memory, device, kernel_size_list,
+                cp_kv_cache_interleave_size)
             for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
         ]
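A quick sanity check on the block-table sizing above, with assumed configuration values (`cdiv` is re-implemented locally so the snippet is self-contained):

```python
# Assumed configuration values, for illustration only.
def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used when sizing the block table."""
    return -(a // -b)

max_model_len = 131072
block_size = 128
dcp_world_size, pcp_world_size = 2, 4

# Without context parallelism every rank must index the full sequence:
print(cdiv(max_model_len, block_size))                                    # 1024 blocks/request
# With DCP x PCP = 8 ranks sharing the KV cache, each rank only needs:
print(cdiv(max_model_len, block_size * dcp_world_size * pcp_world_size))  # 128 blocks/request
```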