[feat] ds3.2 PCP: support MTP and chunked prefill (#6917)

### What this PR does / why we need it?
For DeepSeek V3.2, PCP (prefill context parallelism) now supports the combination of the MTP (multi-token prediction) and chunked-prefill features.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2026-03-03 19:03:50 +08:00
committed by GitHub
parent b771ca9a47
commit 5b05b3a090
3 changed files with 95 additions and 60 deletions

View File

@@ -246,36 +246,7 @@ class NPUModelRunner(GPUModelRunner):
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
try:
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
except Exception:
self.dcp_size = 1
self.dcp_rank = 0
self.pcp_size = 1
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
max_buffer_num_tokens = self.max_num_tokens
if self.pcp_size * self.dcp_size > 1:
max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
self.pcp_manager = PCPManager(
self.pcp_size,
self.pcp_rank,
self.dcp_size,
self.dcp_rank,
max_buffer_num_tokens,
self.max_num_reqs,
self.device,
self.vllm_config,
self.use_async_scheduling,
self.pin_memory,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
self.sampler = AscendSampler()
self.attn_state: AscendAttentionState | None = None
@@ -310,6 +281,38 @@ class NPUModelRunner(GPUModelRunner):
use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm,
)
try:
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
except Exception:
self.dcp_size = 1
self.dcp_rank = 0
self.pcp_size = 1
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
max_buffer_num_tokens = self.max_num_tokens
if self.pcp_size * self.dcp_size > 1:
max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
self.pcp_manager = PCPManager(
self.pcp_size,
self.pcp_rank,
self.dcp_size,
self.dcp_rank,
max_buffer_num_tokens,
self.max_num_reqs,
self.device,
self.vllm_config,
self.use_async_scheduling,
self.pin_memory,
self.use_sparse,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
self._set_up_drafter()
# kv role

View File

@@ -56,6 +56,7 @@ class PCPManager:
vllm_config: VllmConfig,
use_async_scheduling: bool,
pin_memory: bool = False,
use_sparse: bool = False,
) -> None:
self.pcp_world_size = pcp_world_size
self.pcp_world_rank = pcp_rank
@@ -97,6 +98,7 @@ class PCPManager:
+ self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
)
)
self.use_sparse = use_sparse
if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
self.input_ids_pcp_full = CpuGpuBuffer(
self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
@@ -784,16 +786,19 @@ class PCPManager:
num_prefill_reqs = self.num_prefill_reqs
num_decode_reqs = self.num_decode_reqs
num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
)
block_table_tensor[:num_decode_reqs_flatten].copy_(
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
)
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
if num_reqs_padded > num_reqs:
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
if not self.use_sparse:
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
)
block_table_tensor[:num_decode_reqs_flatten].copy_(
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
)
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
if num_reqs_padded > num_reqs:
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full(
[pad_size], ori_query_lens_cpu[-pad_size - 1].item()
)
pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length]
long_seq_metadata = AscendPrefillContextParallelMetadata(
pcp_use_hybrid_attn=self.pcp_use_hybrid_attn,