[feature] chunkprefill support pcp&dcp (#3801)
### What this PR does / why we need it?
ChunkPrefill now supports the long-sequence features PCP and DCP.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI tests passed, along with self-testing.
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
This commit is contained in:
@@ -77,14 +77,6 @@ class BlockTable:
|
||||
self.block_table_np = self.block_table_cpu.numpy()
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
try:
|
||||
self.pcp_world_size = get_pcp_group(
|
||||
).world_size if prefill_context_parallel_enable() else 1
|
||||
@@ -98,6 +90,20 @@ class BlockTable:
|
||||
self.dcp_rank = 0
|
||||
self.pcp_world_size = 1
|
||||
self.pcp_rank = 0
|
||||
|
||||
self.slot_mapping_cpu = torch.zeros(
|
||||
self.max_num_batched_tokens +
|
||||
2 * self.pcp_world_size * self.max_num_reqs,
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
self.slot_mapping = torch.zeros(
|
||||
self.max_num_batched_tokens +
|
||||
2 * self.pcp_world_size * self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
self.kernel_sizes = kernel_sizes
|
||||
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
||||
|
||||
@@ -148,7 +154,7 @@ class BlockTable:
|
||||
if self.dcp_world_size * self.pcp_world_size > 1:
|
||||
# Note(hc): The DCP implement store kvcache with an interleave
|
||||
# style, the kvcache for the token whose token_idx is i is
|
||||
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
|
||||
# always stored on the GPU whose dcp_rank equals i % pcp_world_size:
|
||||
|
||||
# Use a "virtual block" which equals to world_size * block_size
|
||||
# for block_table_indices calculation.
|
||||
@@ -268,12 +274,12 @@ class MultiGroupBlockTable:
|
||||
# must be multiplied by dcp_world_size.
|
||||
try:
|
||||
dcp_world_size = get_dcp_group().world_size
|
||||
cp_world_size = get_pcp_group(
|
||||
pcp_world_size = get_pcp_group(
|
||||
).world_size if prefill_context_parallel_enable() else 1
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
dcp_world_size = 1
|
||||
cp_world_size = 1
|
||||
pcp_world_size = 1
|
||||
|
||||
if kernel_sizes is None:
|
||||
kernel_sizes = [[0]] * len(block_sizes)
|
||||
@@ -291,7 +297,7 @@ class MultiGroupBlockTable:
|
||||
block_size, max_num_reqs,
|
||||
max(
|
||||
cdiv(max_model_len,
|
||||
block_size * dcp_world_size * cp_world_size),
|
||||
block_size * dcp_world_size * pcp_world_size),
|
||||
1 + num_speculative_tokens), max_num_batched_tokens,
|
||||
pin_memory, device, kernel_size_list,
|
||||
cp_kv_cache_interleave_size)
|
||||
|
||||
Reference in New Issue
Block a user