[feat]ds3.2 pcp support mtp and chunkprefill (#6917)
### What this PR does / why we need it?
ds3.2 pcp supports the combination of MTP and chunkprefill features.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -56,6 +56,7 @@ class PCPManager:
|
||||
vllm_config: VllmConfig,
|
||||
use_async_scheduling: bool,
|
||||
pin_memory: bool = False,
|
||||
use_sparse: bool = False,
|
||||
) -> None:
|
||||
self.pcp_world_size = pcp_world_size
|
||||
self.pcp_world_rank = pcp_rank
|
||||
@@ -97,6 +98,7 @@ class PCPManager:
|
||||
+ self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
|
||||
)
|
||||
)
|
||||
self.use_sparse = use_sparse
|
||||
if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
|
||||
self.input_ids_pcp_full = CpuGpuBuffer(
|
||||
self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
|
||||
@@ -784,16 +786,19 @@ class PCPManager:
|
||||
num_prefill_reqs = self.num_prefill_reqs
|
||||
num_decode_reqs = self.num_decode_reqs
|
||||
num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
|
||||
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
|
||||
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
|
||||
)
|
||||
block_table_tensor[:num_decode_reqs_flatten].copy_(
|
||||
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
|
||||
)
|
||||
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
|
||||
if num_reqs_padded > num_reqs:
|
||||
pad_size = num_reqs_padded - num_reqs
|
||||
ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
|
||||
if not self.use_sparse:
|
||||
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
|
||||
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
|
||||
)
|
||||
block_table_tensor[:num_decode_reqs_flatten].copy_(
|
||||
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
|
||||
)
|
||||
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
|
||||
if num_reqs_padded > num_reqs:
|
||||
pad_size = num_reqs_padded - num_reqs
|
||||
ori_query_lens_cpu[-pad_size:] = torch.full(
|
||||
[pad_size], ori_query_lens_cpu[-pad_size - 1].item()
|
||||
)
|
||||
pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length]
|
||||
long_seq_metadata = AscendPrefillContextParallelMetadata(
|
||||
pcp_use_hybrid_attn=self.pcp_use_hybrid_attn,
|
||||
|
||||
Reference in New Issue
Block a user