[feat]ds3.2 pcp support mtp and chunkprefill (#6917)

### What this PR does / why we need it?
ds3.2 pcp supports the combination of MTP and chunkprefill features.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2026-03-03 19:03:50 +08:00
committed by GitHub
parent b771ca9a47
commit 5b05b3a090
3 changed files with 95 additions and 60 deletions

View File

@@ -45,6 +45,14 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd(
self.block_size, self.cp_virtual_block_size
)
self.slot_mapping_buf = torch.empty(
(
vllm_config.scheduler_config.max_num_batched_tokens
+ 2 * self.pcp_size * vllm_config.scheduler_config.max_num_seqs,
),
dtype=torch.int32,
device=device,
)
def build(
self,
@@ -82,15 +90,31 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert long_seq_metadata is not None
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
self.slot_mapping_buf[:num_actual_tokens_pcp_padded].copy_(
common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded], non_blocking=True
)
if self.enable_mlapo:
slot_mapping[:num_decode_tokens] = slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]
slot_mapping[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
metadata_cls.slot_mapping = slot_mapping
self.slot_mapping_buf[:num_decode_tokens] = self.slot_mapping_buf[
: num_decode_tokens * self.pcp_size : self.pcp_size
]
self.slot_mapping_buf[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
elif self.speculative_config is not None and num_decodes > 0:
# when mtp, pcp_allgather_restore_idx=[696,-1,697,-1,560,-1,561,-1,100,101,102],
# slot_mapping should be [696,697,-1,-1,560,561,-1,-1,100,101,102]
num_tokens_per_request = num_decode_tokens // num_decodes
decode_slot_mapping = self.slot_mapping_buf[: num_decode_tokens * self.pcp_size].reshape(
num_decodes, -1
)
decode_slot_mapping[:, :num_tokens_per_request] = decode_slot_mapping[
:, : num_tokens_per_request * self.pcp_size : self.pcp_size
]
decode_slot_mapping[:, num_tokens_per_request : num_tokens_per_request * self.pcp_size].fill_(-1)
self.slot_mapping_buf[: num_decode_tokens * self.pcp_size] = decode_slot_mapping.flatten()
metadata_cls.slot_mapping = self.slot_mapping_buf[:num_actual_tokens_pcp_padded]
actual_seq_lengths_query = metadata_cls.cum_query_lens
if num_prefills > 0 and num_decode_tokens > 0:
prefill_q_cum_seqlens = (
actual_seq_lengths_query[num_decode_tokens:] - actual_seq_lengths_query[num_decode_tokens - 1]
actual_seq_lengths_query[num_decodes:] - actual_seq_lengths_query[num_decodes - 1]
)
else:
prefill_q_cum_seqlens = actual_seq_lengths_query
@@ -108,8 +132,9 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
) -> AscendPCPMetadata | None:
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert common_long_seq_metadata is not None
q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1)
q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(seq_lens.device)
q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1) + num_computed_tokens
q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank + num_computed_tokens
return AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -181,6 +206,7 @@ class AscendSFACPImpl(AscendSFAImpl):
return self._execute_sparse_flash_attention(
ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
)
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefills = attn_metadata.num_prefills
decode_attn_out = None
@@ -190,10 +216,10 @@ class AscendSFACPImpl(AscendSFAImpl):
q_pe[:num_decode_tokens],
kv,
key_rope,
block_table[:num_decode_tokens],
block_table[:num_decodes],
topk_indices[:num_decode_tokens],
actual_seq_lengths_query[:num_decode_tokens],
actual_seq_lengths_key[:num_decode_tokens],
actual_seq_lengths_query[:num_decodes],
actual_seq_lengths_key[:num_decodes],
)
if num_prefills < 1:
@@ -205,10 +231,10 @@ class AscendSFACPImpl(AscendSFAImpl):
ql_nope = ql_nope[num_decode_tokens:]
q_pe = q_pe[num_decode_tokens:]
topk_indices = topk_indices[num_decode_tokens:]
block_table = block_table[num_decode_tokens:]
block_table = block_table[num_decodes:]
# q head compute
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
q_head_output = self._execute_sparse_flash_attention(
torch.index_select(ql_nope, 0, q_head_idx),
torch.index_select(q_pe, 0, q_head_idx),
@@ -221,7 +247,7 @@ class AscendSFACPImpl(AscendSFAImpl):
)
# q tail compute
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
q_tail_output = self._execute_sparse_flash_attention(
torch.index_select(ql_nope, 0, q_tail_idx),
torch.index_select(q_pe, 0, q_tail_idx),
@@ -321,6 +347,7 @@ class AscendSFACPImpl(AscendSFAImpl):
)
# decode compute
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefills = attn_metadata.num_prefills
decode_topk_indices = None
@@ -329,9 +356,9 @@ class AscendSFACPImpl(AscendSFAImpl):
q[:num_decode_tokens],
key,
weights[:num_decode_tokens],
actual_seq_lengths_query[:num_decode_tokens],
actual_seq_lengths_key[:num_decode_tokens],
block_table[:num_decode_tokens],
actual_seq_lengths_query[:num_decodes],
actual_seq_lengths_key[:num_decodes],
block_table[:num_decodes],
)
# prefill compute
@@ -339,14 +366,14 @@ class AscendSFACPImpl(AscendSFAImpl):
return decode_topk_indices
q = q[num_decode_tokens:]
weights = weights[num_decode_tokens:]
actual_seq_lengths_key = actual_seq_lengths_key[num_decode_tokens:]
block_table = block_table[num_decode_tokens:]
actual_seq_lengths_key = actual_seq_lengths_key[num_decodes:]
block_table = block_table[num_decodes:]
# pcp split for head and tail
q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx
# q head compute
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
q_head_topk_indices = self._execute_indexer_select(
q=torch.index_select(q, 0, q_head_idx),
key=key,
@@ -357,7 +384,7 @@ class AscendSFACPImpl(AscendSFAImpl):
)
# q tail compute
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
q_tail_topk_indices = self._execute_indexer_select(
q=torch.index_select(q, 0, q_tail_idx),
key=key,

View File

@@ -246,36 +246,7 @@ class NPUModelRunner(GPUModelRunner):
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
try:
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
except Exception:
self.dcp_size = 1
self.dcp_rank = 0
self.pcp_size = 1
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
max_buffer_num_tokens = self.max_num_tokens
if self.pcp_size * self.dcp_size > 1:
max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
self.pcp_manager = PCPManager(
self.pcp_size,
self.pcp_rank,
self.dcp_size,
self.dcp_rank,
max_buffer_num_tokens,
self.max_num_reqs,
self.device,
self.vllm_config,
self.use_async_scheduling,
self.pin_memory,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
self.sampler = AscendSampler()
self.attn_state: AscendAttentionState | None = None
@@ -310,6 +281,38 @@ class NPUModelRunner(GPUModelRunner):
use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm,
)
try:
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
except Exception:
self.dcp_size = 1
self.dcp_rank = 0
self.pcp_size = 1
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
max_buffer_num_tokens = self.max_num_tokens
if self.pcp_size * self.dcp_size > 1:
max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
self.pcp_manager = PCPManager(
self.pcp_size,
self.pcp_rank,
self.dcp_size,
self.dcp_rank,
max_buffer_num_tokens,
self.max_num_reqs,
self.device,
self.vllm_config,
self.use_async_scheduling,
self.pin_memory,
self.use_sparse,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
self._set_up_drafter()
# kv role

View File

@@ -56,6 +56,7 @@ class PCPManager:
vllm_config: VllmConfig,
use_async_scheduling: bool,
pin_memory: bool = False,
use_sparse: bool = False,
) -> None:
self.pcp_world_size = pcp_world_size
self.pcp_world_rank = pcp_rank
@@ -97,6 +98,7 @@ class PCPManager:
+ self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
)
)
self.use_sparse = use_sparse
if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
self.input_ids_pcp_full = CpuGpuBuffer(
self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
@@ -784,16 +786,19 @@ class PCPManager:
num_prefill_reqs = self.num_prefill_reqs
num_decode_reqs = self.num_decode_reqs
num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
)
block_table_tensor[:num_decode_reqs_flatten].copy_(
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
)
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
if num_reqs_padded > num_reqs:
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
if not self.use_sparse:
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
)
block_table_tensor[:num_decode_reqs_flatten].copy_(
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
)
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
if num_reqs_padded > num_reqs:
pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full(
[pad_size], ori_query_lens_cpu[-pad_size - 1].item()
)
pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length]
long_seq_metadata = AscendPrefillContextParallelMetadata(
pcp_use_hybrid_attn=self.pcp_use_hybrid_attn,