[feat] DS3.2 PCP supports MTP and chunked prefill (#6917)

### What this PR does / why we need it?
DS3.2 PCP supports combining the MTP (multi-token prediction) and chunked-prefill features.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2026-03-03 19:03:50 +08:00
committed by GitHub
parent b771ca9a47
commit 5b05b3a090
3 changed files with 95 additions and 60 deletions

View File

@@ -56,6 +56,7 @@ class PCPManager:
vllm_config: VllmConfig,
use_async_scheduling: bool,
pin_memory: bool = False,
use_sparse: bool = False,
) -> None:
self.pcp_world_size = pcp_world_size
self.pcp_world_rank = pcp_rank
@@ -97,6 +98,7 @@ class PCPManager:
+ self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
)
)
self.use_sparse = use_sparse
if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
self.input_ids_pcp_full = CpuGpuBuffer(
self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
@@ -784,16 +786,19 @@ class PCPManager:
num_prefill_reqs = self.num_prefill_reqs
num_decode_reqs = self.num_decode_reqs
num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
)
block_table_tensor[:num_decode_reqs_flatten].copy_(
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
)
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
if num_reqs_padded > num_reqs:
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
if not self.use_sparse:
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
)
block_table_tensor[:num_decode_reqs_flatten].copy_(
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
)
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
if num_reqs_padded > num_reqs:
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full(
[pad_size], ori_query_lens_cpu[-pad_size - 1].item()
)
pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length]
long_seq_metadata = AscendPrefillContextParallelMetadata(
pcp_use_hybrid_attn=self.pcp_use_hybrid_attn,