[feat] ds3.2 PCP: support MTP and chunked prefill (#6917)

### What this PR does / why we need it?
Enables ds3.2 PCP (prefill context parallelism) to work in combination with the MTP and chunked-prefill features.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
weiguihua2
2026-03-03 19:03:50 +08:00
committed by GitHub
parent b771ca9a47
commit 5b05b3a090
3 changed files with 95 additions and 60 deletions

View File

@@ -45,6 +45,14 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd( self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd(
self.block_size, self.cp_virtual_block_size self.block_size, self.cp_virtual_block_size
) )
self.slot_mapping_buf = torch.empty(
(
vllm_config.scheduler_config.max_num_batched_tokens
+ 2 * self.pcp_size * vllm_config.scheduler_config.max_num_seqs,
),
dtype=torch.int32,
device=device,
)
def build( def build(
self, self,
@@ -82,15 +90,31 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert long_seq_metadata is not None assert long_seq_metadata is not None
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded] self.slot_mapping_buf[:num_actual_tokens_pcp_padded].copy_(
common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded], non_blocking=True
)
if self.enable_mlapo: if self.enable_mlapo:
slot_mapping[:num_decode_tokens] = slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size] self.slot_mapping_buf[:num_decode_tokens] = self.slot_mapping_buf[
slot_mapping[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1) : num_decode_tokens * self.pcp_size : self.pcp_size
metadata_cls.slot_mapping = slot_mapping ]
self.slot_mapping_buf[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
elif self.speculative_config is not None and num_decodes > 0:
# when mtp, pcp_allgather_restore_idx=[696,-1,697,-1,560,-1,561,-1,100,101,102],
# slot_mapping should be [696,697,-1,-1,560,561,-1,-1,100,101,102]
num_tokens_per_request = num_decode_tokens // num_decodes
decode_slot_mapping = self.slot_mapping_buf[: num_decode_tokens * self.pcp_size].reshape(
num_decodes, -1
)
decode_slot_mapping[:, :num_tokens_per_request] = decode_slot_mapping[
:, : num_tokens_per_request * self.pcp_size : self.pcp_size
]
decode_slot_mapping[:, num_tokens_per_request : num_tokens_per_request * self.pcp_size].fill_(-1)
self.slot_mapping_buf[: num_decode_tokens * self.pcp_size] = decode_slot_mapping.flatten()
metadata_cls.slot_mapping = self.slot_mapping_buf[:num_actual_tokens_pcp_padded]
actual_seq_lengths_query = metadata_cls.cum_query_lens actual_seq_lengths_query = metadata_cls.cum_query_lens
if num_prefills > 0 and num_decode_tokens > 0: if num_prefills > 0 and num_decode_tokens > 0:
prefill_q_cum_seqlens = ( prefill_q_cum_seqlens = (
actual_seq_lengths_query[num_decode_tokens:] - actual_seq_lengths_query[num_decode_tokens - 1] actual_seq_lengths_query[num_decodes:] - actual_seq_lengths_query[num_decodes - 1]
) )
else: else:
prefill_q_cum_seqlens = actual_seq_lengths_query prefill_q_cum_seqlens = actual_seq_lengths_query
@@ -108,8 +132,9 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
) -> AscendPCPMetadata | None: ) -> AscendPCPMetadata | None:
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
assert common_long_seq_metadata is not None assert common_long_seq_metadata is not None
q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1) num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(seq_lens.device)
q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1) + num_computed_tokens
q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank + num_computed_tokens
return AscendPCPMetadata( return AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor, q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor, q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -181,6 +206,7 @@ class AscendSFACPImpl(AscendSFAImpl):
return self._execute_sparse_flash_attention( return self._execute_sparse_flash_attention(
ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
) )
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
num_prefills = attn_metadata.num_prefills num_prefills = attn_metadata.num_prefills
decode_attn_out = None decode_attn_out = None
@@ -190,10 +216,10 @@ class AscendSFACPImpl(AscendSFAImpl):
q_pe[:num_decode_tokens], q_pe[:num_decode_tokens],
kv, kv,
key_rope, key_rope,
block_table[:num_decode_tokens], block_table[:num_decodes],
topk_indices[:num_decode_tokens], topk_indices[:num_decode_tokens],
actual_seq_lengths_query[:num_decode_tokens], actual_seq_lengths_query[:num_decodes],
actual_seq_lengths_key[:num_decode_tokens], actual_seq_lengths_key[:num_decodes],
) )
if num_prefills < 1: if num_prefills < 1:
@@ -205,10 +231,10 @@ class AscendSFACPImpl(AscendSFAImpl):
ql_nope = ql_nope[num_decode_tokens:] ql_nope = ql_nope[num_decode_tokens:]
q_pe = q_pe[num_decode_tokens:] q_pe = q_pe[num_decode_tokens:]
topk_indices = topk_indices[num_decode_tokens:] topk_indices = topk_indices[num_decode_tokens:]
block_table = block_table[num_decode_tokens:] block_table = block_table[num_decodes:]
# q head compute # q head compute
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:] q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
q_head_output = self._execute_sparse_flash_attention( q_head_output = self._execute_sparse_flash_attention(
torch.index_select(ql_nope, 0, q_head_idx), torch.index_select(ql_nope, 0, q_head_idx),
torch.index_select(q_pe, 0, q_head_idx), torch.index_select(q_pe, 0, q_head_idx),
@@ -221,7 +247,7 @@ class AscendSFACPImpl(AscendSFAImpl):
) )
# q tail compute # q tail compute
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:] q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
q_tail_output = self._execute_sparse_flash_attention( q_tail_output = self._execute_sparse_flash_attention(
torch.index_select(ql_nope, 0, q_tail_idx), torch.index_select(ql_nope, 0, q_tail_idx),
torch.index_select(q_pe, 0, q_tail_idx), torch.index_select(q_pe, 0, q_tail_idx),
@@ -321,6 +347,7 @@ class AscendSFACPImpl(AscendSFAImpl):
) )
# decode compute # decode compute
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
num_prefills = attn_metadata.num_prefills num_prefills = attn_metadata.num_prefills
decode_topk_indices = None decode_topk_indices = None
@@ -329,9 +356,9 @@ class AscendSFACPImpl(AscendSFAImpl):
q[:num_decode_tokens], q[:num_decode_tokens],
key, key,
weights[:num_decode_tokens], weights[:num_decode_tokens],
actual_seq_lengths_query[:num_decode_tokens], actual_seq_lengths_query[:num_decodes],
actual_seq_lengths_key[:num_decode_tokens], actual_seq_lengths_key[:num_decodes],
block_table[:num_decode_tokens], block_table[:num_decodes],
) )
# prefill compute # prefill compute
@@ -339,14 +366,14 @@ class AscendSFACPImpl(AscendSFAImpl):
return decode_topk_indices return decode_topk_indices
q = q[num_decode_tokens:] q = q[num_decode_tokens:]
weights = weights[num_decode_tokens:] weights = weights[num_decode_tokens:]
actual_seq_lengths_key = actual_seq_lengths_key[num_decode_tokens:] actual_seq_lengths_key = actual_seq_lengths_key[num_decodes:]
block_table = block_table[num_decode_tokens:] block_table = block_table[num_decodes:]
# pcp split for head and tail # pcp split for head and tail
q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx
# q head compute # q head compute
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:] q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
q_head_topk_indices = self._execute_indexer_select( q_head_topk_indices = self._execute_indexer_select(
q=torch.index_select(q, 0, q_head_idx), q=torch.index_select(q, 0, q_head_idx),
key=key, key=key,
@@ -357,7 +384,7 @@ class AscendSFACPImpl(AscendSFAImpl):
) )
# q tail compute # q tail compute
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:] q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
q_tail_topk_indices = self._execute_indexer_select( q_tail_topk_indices = self._execute_indexer_select(
q=torch.index_select(q, 0, q_tail_idx), q=torch.index_select(q, 0, q_tail_idx),
key=key, key=key,

View File

@@ -246,36 +246,7 @@ class NPUModelRunner(GPUModelRunner):
self.max_num_reqs = self.scheduler_config.max_num_seqs self.max_num_reqs = self.scheduler_config.max_num_seqs
self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
try:
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
except Exception:
self.dcp_size = 1
self.dcp_rank = 0
self.pcp_size = 1
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
max_buffer_num_tokens = self.max_num_tokens
if self.pcp_size * self.dcp_size > 1:
max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
self.pcp_manager = PCPManager(
self.pcp_size,
self.pcp_rank,
self.dcp_size,
self.dcp_rank,
max_buffer_num_tokens,
self.max_num_reqs,
self.device,
self.vllm_config,
self.use_async_scheduling,
self.pin_memory,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
self.sampler = AscendSampler() self.sampler = AscendSampler()
self.attn_state: AscendAttentionState | None = None self.attn_state: AscendAttentionState | None = None
@@ -310,6 +281,38 @@ class NPUModelRunner(GPUModelRunner):
use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm, use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm,
) )
try:
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.pcp_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
except Exception:
self.dcp_size = 1
self.dcp_rank = 0
self.pcp_size = 1
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
max_buffer_num_tokens = self.max_num_tokens
if self.pcp_size * self.dcp_size > 1:
max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
self.pcp_manager = PCPManager(
self.pcp_size,
self.pcp_rank,
self.dcp_size,
self.dcp_rank,
max_buffer_num_tokens,
self.max_num_reqs,
self.device,
self.vllm_config,
self.use_async_scheduling,
self.pin_memory,
self.use_sparse,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
self._set_up_drafter() self._set_up_drafter()
# kv role # kv role

View File

@@ -56,6 +56,7 @@ class PCPManager:
vllm_config: VllmConfig, vllm_config: VllmConfig,
use_async_scheduling: bool, use_async_scheduling: bool,
pin_memory: bool = False, pin_memory: bool = False,
use_sparse: bool = False,
) -> None: ) -> None:
self.pcp_world_size = pcp_world_size self.pcp_world_size = pcp_world_size
self.pcp_world_rank = pcp_rank self.pcp_world_rank = pcp_rank
@@ -97,6 +98,7 @@ class PCPManager:
+ self.pcp_world_size * self.dcp_world_size * self.max_num_reqs + self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
) )
) )
self.use_sparse = use_sparse
if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1: if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
self.input_ids_pcp_full = CpuGpuBuffer( self.input_ids_pcp_full = CpuGpuBuffer(
self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
@@ -784,16 +786,19 @@ class PCPManager:
num_prefill_reqs = self.num_prefill_reqs num_prefill_reqs = self.num_prefill_reqs
num_decode_reqs = self.num_decode_reqs num_decode_reqs = self.num_decode_reqs
num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item() num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_( if not self.use_sparse:
block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone() block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
) block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
block_table_tensor[:num_decode_reqs_flatten].copy_( )
block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0) block_table_tensor[:num_decode_reqs_flatten].copy_(
) block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs] )
if num_reqs_padded > num_reqs: block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
pad_size = num_reqs_padded - num_reqs if num_reqs_padded > num_reqs:
ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item()) pad_size = num_reqs_padded - num_reqs
ori_query_lens_cpu[-pad_size:] = torch.full(
[pad_size], ori_query_lens_cpu[-pad_size - 1].item()
)
pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length] pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length]
long_seq_metadata = AscendPrefillContextParallelMetadata( long_seq_metadata = AscendPrefillContextParallelMetadata(
pcp_use_hybrid_attn=self.pcp_use_hybrid_attn, pcp_use_hybrid_attn=self.pcp_use_hybrid_attn,