From 5b05b3a09070f6f4bd86f25fa623206f87f6bda4 Mon Sep 17 00:00:00 2001
From: weiguihua2
Date: Tue, 3 Mar 2026 19:03:50 +0800
Subject: [PATCH] [feat]ds3.2 pcp support mtp and chunkprefill (#6917)

### What this PR does / why we need it?
ds3.2 PCP (prefill context parallelism) now supports running the MTP (multi-token prediction) and chunked-prefill features together.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7

---------

Signed-off-by: weiguihua2
---
 .../attention/context_parallel/sfa_cp.py | 67 +++++++++++++------
 vllm_ascend/worker/model_runner_v1.py    | 63 ++++++++---------
 vllm_ascend/worker/pcp_utils.py          | 25 ++++---
 3 files changed, 95 insertions(+), 60 deletions(-)

diff --git a/vllm_ascend/attention/context_parallel/sfa_cp.py b/vllm_ascend/attention/context_parallel/sfa_cp.py
index 7fd3e739..73366682 100644
--- a/vllm_ascend/attention/context_parallel/sfa_cp.py
+++ b/vllm_ascend/attention/context_parallel/sfa_cp.py
@@ -45,6 +45,14 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
         self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd(
             self.block_size, self.cp_virtual_block_size
         )
+        self.slot_mapping_buf = torch.empty(
+            (
+                vllm_config.scheduler_config.max_num_batched_tokens
+                + 2 * self.pcp_size * vllm_config.scheduler_config.max_num_seqs,
+            ),
+            dtype=torch.int32,
+            device=device,
+        )

     def build(
         self,
@@ -82,15 +90,31 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert long_seq_metadata is not None
         num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
-        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
+        self.slot_mapping_buf[:num_actual_tokens_pcp_padded].copy_(
+            common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded], non_blocking=True
+        )
         if self.enable_mlapo:
-            slot_mapping[:num_decode_tokens] = slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]
-            slot_mapping[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
-        metadata_cls.slot_mapping = slot_mapping
+            self.slot_mapping_buf[:num_decode_tokens] = self.slot_mapping_buf[
+                : num_decode_tokens * self.pcp_size : self.pcp_size
+            ]
+            self.slot_mapping_buf[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
+        elif self.speculative_config is not None and num_decodes > 0:
+            # when mtp, pcp_allgather_restore_idx=[696,-1,697,-1,560,-1,561,-1,100,101,102],
+            # slot_mapping should be [696,697,-1,-1,560,561,-1,-1,100,101,102]
+            num_tokens_per_request = num_decode_tokens // num_decodes
+            decode_slot_mapping = self.slot_mapping_buf[: num_decode_tokens * self.pcp_size].reshape(
+                num_decodes, -1
+            )
+            decode_slot_mapping[:, :num_tokens_per_request] = decode_slot_mapping[
+                :, : num_tokens_per_request * self.pcp_size : self.pcp_size
+            ]
+            decode_slot_mapping[:, num_tokens_per_request : num_tokens_per_request * self.pcp_size].fill_(-1)
+            self.slot_mapping_buf[: num_decode_tokens * self.pcp_size] = decode_slot_mapping.flatten()
+        metadata_cls.slot_mapping = self.slot_mapping_buf[:num_actual_tokens_pcp_padded]
         actual_seq_lengths_query = metadata_cls.cum_query_lens
         if num_prefills > 0 and num_decode_tokens > 0:
             prefill_q_cum_seqlens = (
-                actual_seq_lengths_query[num_decode_tokens:] - actual_seq_lengths_query[num_decode_tokens - 1]
+                actual_seq_lengths_query[num_decodes:] - actual_seq_lengths_query[num_decodes - 1]
             )
         else:
             prefill_q_cum_seqlens = actual_seq_lengths_query
@@ -108,8 +132,9 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
     ) -> AscendPCPMetadata | None:
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert common_long_seq_metadata is not None
-        q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1)
-        q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank
+        num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(seq_lens.device)
+        q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1) + num_computed_tokens
+        q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank + num_computed_tokens
         return AscendPCPMetadata(
             q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
             q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -181,6 +206,7 @@ class AscendSFACPImpl(AscendSFAImpl):
             return self._execute_sparse_flash_attention(
                 ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
             )
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefills = attn_metadata.num_prefills
         decode_attn_out = None
@@ -190,10 +216,10 @@ class AscendSFACPImpl(AscendSFAImpl):
                 q_pe[:num_decode_tokens],
                 kv,
                 key_rope,
-                block_table[:num_decode_tokens],
+                block_table[:num_decodes],
                 topk_indices[:num_decode_tokens],
-                actual_seq_lengths_query[:num_decode_tokens],
-                actual_seq_lengths_key[:num_decode_tokens],
+                actual_seq_lengths_query[:num_decodes],
+                actual_seq_lengths_key[:num_decodes],
             )

         if num_prefills < 1:
@@ -205,10 +231,10 @@ class AscendSFACPImpl(AscendSFAImpl):
         ql_nope = ql_nope[num_decode_tokens:]
         q_pe = q_pe[num_decode_tokens:]
         topk_indices = topk_indices[num_decode_tokens:]
-        block_table = block_table[num_decode_tokens:]
+        block_table = block_table[num_decodes:]

         # q head compute
-        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
+        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
         q_head_output = self._execute_sparse_flash_attention(
             torch.index_select(ql_nope, 0, q_head_idx),
             torch.index_select(q_pe, 0, q_head_idx),
@@ -221,7 +247,7 @@ class AscendSFACPImpl(AscendSFAImpl):
         )

         # q tail compute
-        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
+        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
         q_tail_output = self._execute_sparse_flash_attention(
             torch.index_select(ql_nope, 0, q_tail_idx),
             torch.index_select(q_pe, 0, q_tail_idx),
@@ -321,6 +347,7 @@ class AscendSFACPImpl(AscendSFAImpl):
         )

         # decode compute
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefills = attn_metadata.num_prefills
         decode_topk_indices = None
@@ -329,9 +356,9 @@ class AscendSFACPImpl(AscendSFAImpl):
                 q[:num_decode_tokens],
                 key,
                 weights[:num_decode_tokens],
-                actual_seq_lengths_query[:num_decode_tokens],
-                actual_seq_lengths_key[:num_decode_tokens],
-                block_table[:num_decode_tokens],
+                actual_seq_lengths_query[:num_decodes],
+                actual_seq_lengths_key[:num_decodes],
+                block_table[:num_decodes],
             )

         # prefill compute
@@ -339,14 +366,14 @@ class AscendSFACPImpl(AscendSFAImpl):
         if num_prefills < 1:
             return decode_topk_indices
         q = q[num_decode_tokens:]
         weights = weights[num_decode_tokens:]
-        actual_seq_lengths_key = actual_seq_lengths_key[num_decode_tokens:]
-        block_table = block_table[num_decode_tokens:]
+        actual_seq_lengths_key = actual_seq_lengths_key[num_decodes:]
+        block_table = block_table[num_decodes:]
         # pcp split for head and tail
         q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
         q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx
         # q head compute
-        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
+        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
         q_head_topk_indices = self._execute_indexer_select(
             q=torch.index_select(q, 0, q_head_idx),
             key=key,
@@ -357,7 +384,7 @@ class AscendSFACPImpl(AscendSFAImpl):
         )

         # q tail compute
-        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
+        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
         q_tail_topk_indices = self._execute_indexer_select(
             q=torch.index_select(q, 0, q_tail_idx),
             key=key,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 4ed3f587..0b90183e 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -246,36 +246,7 @@ class NPUModelRunner(GPUModelRunner):
         self.max_num_reqs = self.scheduler_config.max_num_seqs
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-        try:
-            self.dcp_size = get_dcp_group().world_size
-            self.dcp_rank = get_dcp_group().rank_in_group
-            self.pcp_size = get_pcp_group().world_size
-            self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
-        except Exception:
-            self.dcp_size = 1
-            self.dcp_rank = 0
-            self.pcp_size = 1
-            self.pcp_rank = 0
-        if self.pcp_size > 1:
-            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
-        max_buffer_num_tokens = self.max_num_tokens
-        if self.pcp_size * self.dcp_size > 1:
-            max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
-        self.pcp_manager = PCPManager(
-            self.pcp_size,
-            self.pcp_rank,
-            self.dcp_size,
-            self.dcp_rank,
-            max_buffer_num_tokens,
-            self.max_num_reqs,
-            self.device,
-            self.vllm_config,
-            self.use_async_scheduling,
-            self.pin_memory,
-        )
-        # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
-        self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
-        self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
+        self.sampler = AscendSampler()

         self.attn_state: AscendAttentionState | None = None

@@ -310,6 +281,38 @@ class NPUModelRunner(GPUModelRunner):
             use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm,
         )

+        try:
+            self.dcp_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+            self.pcp_size = get_pcp_group().world_size
+            self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
+        except Exception:
+            self.dcp_size = 1
+            self.dcp_rank = 0
+            self.pcp_size = 1
+            self.pcp_rank = 0
+        if self.pcp_size > 1:
+            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
+        max_buffer_num_tokens = self.max_num_tokens
+        if self.pcp_size * self.dcp_size > 1:
+            max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
+        self.pcp_manager = PCPManager(
+            self.pcp_size,
+            self.pcp_rank,
+            self.dcp_size,
+            self.dcp_rank,
+            max_buffer_num_tokens,
+            self.max_num_reqs,
+            self.device,
+            self.vllm_config,
+            self.use_async_scheduling,
+            self.pin_memory,
+            self.use_sparse,
+        )
+        # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
+        self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
+        self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
+
         self._set_up_drafter()

         # kv role
diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py
index 61fdf0c3..b46fc52a 100644
--- a/vllm_ascend/worker/pcp_utils.py
+++ b/vllm_ascend/worker/pcp_utils.py
@@ -56,6 +56,7 @@ class PCPManager:
         vllm_config: VllmConfig,
         use_async_scheduling: bool,
         pin_memory: bool = False,
+        use_sparse: bool = False,
     ) -> None:
         self.pcp_world_size = pcp_world_size
         self.pcp_world_rank = pcp_rank
@@ -97,6 +98,7 @@ class PCPManager:
                 + self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
             )
         )
+        self.use_sparse = use_sparse
         if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
             self.input_ids_pcp_full = CpuGpuBuffer(
                 self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
@@ -784,16 +786,19 @@ class PCPManager:
         num_prefill_reqs = self.num_prefill_reqs
         num_decode_reqs = self.num_decode_reqs
         num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
-        block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
-            block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
-        )
-        block_table_tensor[:num_decode_reqs_flatten].copy_(
-            block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
-        )
-        block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
-        if num_reqs_padded > num_reqs:
-            pad_size = num_reqs_padded - num_reqs
-            ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
+        if not self.use_sparse:
+            block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
+                block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
+            )
+            block_table_tensor[:num_decode_reqs_flatten].copy_(
+                block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
+            )
+            block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
+            if num_reqs_padded > num_reqs:
+                pad_size = num_reqs_padded - num_reqs
+                ori_query_lens_cpu[-pad_size:] = torch.full(
+                    [pad_size], ori_query_lens_cpu[-pad_size - 1].item()
+                )
         pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length]
         long_seq_metadata = AscendPrefillContextParallelMetadata(
             pcp_use_hybrid_attn=self.pcp_use_hybrid_attn,