[feat] ds3.2 pcp supports MTP and chunked prefill (#6917)

### What this PR does / why we need it?
With prefill context parallelism (PCP) enabled for DeepSeek V3.2 (ds3.2), this PR supports combining it with the MTP (multi-token prediction) and chunked-prefill features.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Author: weiguihua2
Date: 2026-03-03 19:03:50 +08:00
Committed by: GitHub
Parent: b771ca9a47
Commit: 5b05b3a090

3 changed files with 95 additions and 60 deletions


@@ -45,6 +45,14 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
         self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd(
             self.block_size, self.cp_virtual_block_size
         )
+        self.slot_mapping_buf = torch.empty(
+            (
+                vllm_config.scheduler_config.max_num_batched_tokens
+                + 2 * self.pcp_size * vllm_config.scheduler_config.max_num_seqs,
+            ),
+            dtype=torch.int32,
+            device=device,
+        )
 
     def build(
         self,
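The two sizing rules in this hunk can be read as follows: `a * b // gcd(a, b)` is the least common multiple of the two block sizes, and the new persistent slot-mapping buffer reserves room for every batched token plus what appears to be per-sequence headroom for PCP padding. A minimal sketch with hypothetical config values (none of these numbers come from the repository):

```python
import numpy as np

# Hypothetical config values, for illustration only.
block_size, cp_virtual_block_size = 128, 192
max_num_batched_tokens, max_num_seqs, pcp_size = 8192, 128, 2

# a * b // gcd(a, b) == lcm(a, b): the adjusted block size stays divisible by
# both the KV-cache block size and the CP virtual block size.
effective_block_size = (block_size * cp_virtual_block_size) // np.gcd(block_size, cp_virtual_block_size)
assert effective_block_size == 384

# Persistent slot-mapping buffer: all batched tokens plus (presumably) headroom
# of 2 * pcp_size extra slots per sequence for PCP padding.
buf_len = max_num_batched_tokens + 2 * pcp_size * max_num_seqs
assert buf_len == 8704
```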
@@ -82,15 +90,31 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert long_seq_metadata is not None
         num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
-        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
+        self.slot_mapping_buf[:num_actual_tokens_pcp_padded].copy_(
+            common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded], non_blocking=True
+        )
         if self.enable_mlapo:
-            slot_mapping[:num_decode_tokens] = slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]
-            slot_mapping[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
-        metadata_cls.slot_mapping = slot_mapping
+            self.slot_mapping_buf[:num_decode_tokens] = self.slot_mapping_buf[
+                : num_decode_tokens * self.pcp_size : self.pcp_size
+            ]
+            self.slot_mapping_buf[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
+        elif self.speculative_config is not None and num_decodes > 0:
+            # when mtp, pcp_allgather_restore_idx=[696,-1,697,-1,560,-1,561,-1,100,101,102],
+            # slot_mapping should be [696,697,-1,-1,560,561,-1,-1,100,101,102]
+            num_tokens_per_request = num_decode_tokens // num_decodes
+            decode_slot_mapping = self.slot_mapping_buf[: num_decode_tokens * self.pcp_size].reshape(
+                num_decodes, -1
+            )
+            decode_slot_mapping[:, :num_tokens_per_request] = decode_slot_mapping[
+                :, : num_tokens_per_request * self.pcp_size : self.pcp_size
+            ]
+            decode_slot_mapping[:, num_tokens_per_request : num_tokens_per_request * self.pcp_size].fill_(-1)
+            self.slot_mapping_buf[: num_decode_tokens * self.pcp_size] = decode_slot_mapping.flatten()
+        metadata_cls.slot_mapping = self.slot_mapping_buf[:num_actual_tokens_pcp_padded]
         actual_seq_lengths_query = metadata_cls.cum_query_lens
         if num_prefills > 0 and num_decode_tokens > 0:
             prefill_q_cum_seqlens = (
-                actual_seq_lengths_query[num_decode_tokens:] - actual_seq_lengths_query[num_decode_tokens - 1]
+                actual_seq_lengths_query[num_decodes:] - actual_seq_lengths_query[num_decodes - 1]
             )
         else:
             prefill_q_cum_seqlens = actual_seq_lengths_query
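The new `elif` branch compacts the decode portion of `slot_mapping_buf` when MTP is active: per request, the valid decode slots (which appear interleaved with padding after the PCP all-gather) are moved to the front of that request's segment and the rest is filled with -1. A minimal standalone sketch reproducing the example from the in-code comment, assuming `pcp_size=2`, two decode requests with two tokens each, and three trailing prefill tokens:

```python
import torch

pcp_size = 2
num_decodes = 2            # decode requests in the batch
num_decode_tokens = 4      # 2 requests x 2 tokens each (MTP)
num_tokens_per_request = num_decode_tokens // num_decodes

# Values taken from the in-code comment; prefill slots [100, 101, 102] are untouched.
slot_mapping_buf = torch.tensor([696, -1, 697, -1, 560, -1, 561, -1, 100, 101, 102], dtype=torch.int32)

decode = slot_mapping_buf[: num_decode_tokens * pcp_size].reshape(num_decodes, -1)
# Gather every pcp_size-th column (the valid slots) to the front of each request...
decode[:, :num_tokens_per_request] = decode[:, : num_tokens_per_request * pcp_size : pcp_size]
# ...and pad the remainder of each request's segment with -1.
decode[:, num_tokens_per_request : num_tokens_per_request * pcp_size].fill_(-1)
slot_mapping_buf[: num_decode_tokens * pcp_size] = decode.flatten()

print(slot_mapping_buf.tolist())
# [696, 697, -1, -1, 560, 561, -1, -1, 100, 101, 102]
```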
@@ -108,8 +132,9 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
     ) -> AscendPCPMetadata | None:
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert common_long_seq_metadata is not None
-        q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1)
-        q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank
+        num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(seq_lens.device)
+        q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1) + num_computed_tokens
+        q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank + num_computed_tokens
         return AscendPCPMetadata(
             q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
             q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
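With chunked prefill, tokens from earlier chunks have presumably already been written to the KV cache, so the per-rank head/tail KV lengths now also cover `num_computed_tokens`. A quick numeric check of the two formulas with hypothetical values (one request, a `seq_lens` entry of 8 for this step, 16 tokens computed in earlier chunks, `pcp_size=2`, this rank is `pcp_rank=0`):

```python
# Hypothetical single-request example; values are illustrative only.
seq_lens, num_computed_tokens = 8, 16
pcp_size, pcp_rank = 2, 0

q_head_kv_lens = (seq_lens // 2) * (pcp_rank + 1) + num_computed_tokens                  # 4 * 1 + 16 = 20
q_tail_kv_lens = seq_lens * pcp_size - (seq_lens // 2) * pcp_rank + num_computed_tokens  # 16 - 0 + 16 = 32
```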
@@ -181,6 +206,7 @@ class AscendSFACPImpl(AscendSFAImpl):
             return self._execute_sparse_flash_attention(
                 ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
             )
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefills = attn_metadata.num_prefills
         decode_attn_out = None
@@ -190,10 +216,10 @@ class AscendSFACPImpl(AscendSFAImpl):
                 q_pe[:num_decode_tokens],
                 kv,
                 key_rope,
-                block_table[:num_decode_tokens],
+                block_table[:num_decodes],
                 topk_indices[:num_decode_tokens],
-                actual_seq_lengths_query[:num_decode_tokens],
-                actual_seq_lengths_key[:num_decode_tokens],
+                actual_seq_lengths_query[:num_decodes],
+                actual_seq_lengths_key[:num_decodes],
             )
         if num_prefills < 1:
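From this hunk onward, per-request arrays are sliced with `num_decodes` while per-token tensors keep using `num_decode_tokens`: with MTP a decode request contributes more than one query token, so the two counts no longer coincide. A rough sketch of the distinction, with hypothetical values:

```python
# Hypothetical batch: two decode requests, each carrying e.g. one verified token
# plus one MTP draft token, so every request contributes two query rows.
num_decodes = 2
num_tokens_per_request = 2                                 # num_decode_tokens // num_decodes
num_decode_tokens = num_decodes * num_tokens_per_request   # 4

# Per-token tensors (ql_nope, q_pe, topk_indices) are sliced with [:num_decode_tokens];
# per-request tensors (block_table, actual_seq_lengths_query/key) with [:num_decodes].
```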
@@ -205,10 +231,10 @@ class AscendSFACPImpl(AscendSFAImpl):
         ql_nope = ql_nope[num_decode_tokens:]
         q_pe = q_pe[num_decode_tokens:]
         topk_indices = topk_indices[num_decode_tokens:]
-        block_table = block_table[num_decode_tokens:]
+        block_table = block_table[num_decodes:]
         # q head compute
-        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
+        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
         q_head_output = self._execute_sparse_flash_attention(
             torch.index_select(ql_nope, 0, q_head_idx),
             torch.index_select(q_pe, 0, q_head_idx),
@@ -221,7 +247,7 @@ class AscendSFACPImpl(AscendSFAImpl):
         )
         # q tail compute
-        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
+        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
         q_tail_output = self._execute_sparse_flash_attention(
             torch.index_select(ql_nope, 0, q_tail_idx),
             torch.index_select(q_pe, 0, q_tail_idx),
@@ -321,6 +347,7 @@ class AscendSFACPImpl(AscendSFAImpl):
         )
         # decode compute
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefills = attn_metadata.num_prefills
         decode_topk_indices = None
@@ -329,9 +356,9 @@ class AscendSFACPImpl(AscendSFAImpl):
                 q[:num_decode_tokens],
                 key,
                 weights[:num_decode_tokens],
-                actual_seq_lengths_query[:num_decode_tokens],
-                actual_seq_lengths_key[:num_decode_tokens],
-                block_table[:num_decode_tokens],
+                actual_seq_lengths_query[:num_decodes],
+                actual_seq_lengths_key[:num_decodes],
+                block_table[:num_decodes],
             )
         # prefill compute
@@ -339,14 +366,14 @@ class AscendSFACPImpl(AscendSFAImpl):
             return decode_topk_indices
         q = q[num_decode_tokens:]
         weights = weights[num_decode_tokens:]
-        actual_seq_lengths_key = actual_seq_lengths_key[num_decode_tokens:]
-        block_table = block_table[num_decode_tokens:]
+        actual_seq_lengths_key = actual_seq_lengths_key[num_decodes:]
+        block_table = block_table[num_decodes:]
         # pcp split for head and tail
         q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
         q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx
         # q head compute
-        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
+        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
         q_head_topk_indices = self._execute_indexer_select(
             q=torch.index_select(q, 0, q_head_idx),
             key=key,
@@ -357,7 +384,7 @@ class AscendSFACPImpl(AscendSFAImpl):
         )
         # q tail compute
-        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
+        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
         q_tail_topk_indices = self._execute_indexer_select(
             q=torch.index_select(q, 0, q_tail_idx),
             key=key,