From df777e9faa4338a5dac6c20f1088fed646f05d19 Mon Sep 17 00:00:00 2001
From: zhangsicheng5
Date: Wed, 19 Nov 2025 11:21:46 +0800
Subject: [PATCH] [bugfix] pcp + mtp acl graph bugfix (#4221)

Fix a pcp + mtp bug when using acl graph.

When using pcp + mtp, we need to flatten block_table to avoid an irregular
attn mask shape. This was previously done in the mla attn_metadata builder,
but it changes the block_table address and leads to incorrect results when
acl graph is enabled. To fix this, we enlarge the block_table buffer and
flatten block_table in model_runner prepare_inputs instead, so the
block_table address is not affected.

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379

Signed-off-by: zhangsicheng5
---
 vllm_ascend/attention/mla_v1.py       | 21 ++++++++++----
 vllm_ascend/worker/block_table.py     | 41 +++++++++++++++------------
 vllm_ascend/worker/model_runner_v1.py | 31 ++++++++++++++++++++
 3 files changed, 69 insertions(+), 24 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 62ed95c5..47bcc494 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -369,6 +369,12 @@ class AscendMLAMetadataBuilder:
         device = self.device
         block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
+        if self.pcp_size > 1:
+            num_decodes_flatten = num_decodes * self.decode_threshold
+            block_table = common_attn_metadata.block_table_tensor[:
+                                                                  num_decodes_flatten
+                                                                  +
+                                                                  num_prefills]

         if num_actual_tokens_pcp_padded is None:
             num_actual_tokens_pcp_padded = num_actual_tokens
@@ -546,6 +552,9 @@ class AscendMLAMetadataBuilder:
                 cos=cos,
                 pcp_metadata=pcp_metadata,
             )
+            if self.pcp_size > 1:
+                prefill_metadata.block_table = block_table[
+                    num_decodes_flatten:, ...]

         decode_metadata = None
         if num_decodes > 0:
@@ -556,12 +565,12 @@ class AscendMLAMetadataBuilder:
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            block_table = block_table[:num_decodes, ...]
-            # For pcp + spec decode, we flatten seq_lens and block_table
-            # to avoid irregular spec_attn_mask shape
-            if self.pcp_size > 1 and self.decode_threshold > 1:
-                block_table = block_table.repeat_interleave(
-                    self.decode_threshold, dim=0)
+            if self.pcp_size > 1:
+                # For pcp + spec decode, we flatten seq_lens and block_table
+                # to avoid irregular spec_attn_mask shape
+                block_table = block_table[:num_decodes_flatten, ...]
+            else:
+                block_table = block_table[:num_decodes, ...]

             seq_lens_list = seq_lens.tolist()
             if num_computed_tokens_of_pcp_dcp is not None:
diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py
index da0cb543..579a051a 100644
--- a/vllm_ascend/worker/block_table.py
+++ b/vllm_ascend/worker/block_table.py
@@ -27,13 +27,29 @@ class BlockTable:
                  pin_memory: bool,
                  device: torch.device,
                  kernel_sizes: Union[list[int], None] = None,
-                 cp_kv_cache_interleave_size: int = 1):
+                 cp_kv_cache_interleave_size: int = 1,
+                 num_speculative_tokens: int = 0):
         self.max_num_reqs = max_num_reqs
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.max_num_batched_tokens = max_num_batched_tokens
         self.pin_memory = pin_memory
         self.device = device
         self.physical_block_size = block_size
+
+        try:
+            self.pcp_world_size = get_pcp_group(
+            ).world_size if prefill_context_parallel_enable() else 1
+            self.pcp_rank = get_pcp_group(
+            ).rank_in_group if self.pcp_world_size > 1 else 0
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
+            self.pcp_world_size = 1
+            self.pcp_rank = 0
+
         # If kernel_sizes is None or [0], use physical block size (no splitting)
         if kernel_sizes is None or kernel_sizes == [0]:
             self.block_size = block_size
@@ -69,13 +85,16 @@ class BlockTable:
         else:
             logical_table_size = max_num_blocks_per_req
+        duplicate_size = 1
+        if self.pcp_world_size > 1:
+            duplicate_size += num_speculative_tokens
         self.block_table = torch.zeros(
-            (max_num_reqs, logical_table_size),
+            (max_num_reqs * duplicate_size, logical_table_size),
             device=self.device,
             dtype=torch.int32,
         )
         self.block_table_cpu = torch.zeros(
-            (max_num_reqs, logical_table_size),
+            (max_num_reqs * duplicate_size, logical_table_size),
             device="cpu",
             dtype=torch.int32,
             pin_memory=pin_memory,
         )
@@ -83,20 +102,6 @@
         self.block_table_np = self.block_table_cpu.numpy()
         self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

-        try:
-            self.pcp_world_size = get_pcp_group(
-            ).world_size if prefill_context_parallel_enable() else 1
-            self.pcp_rank = get_pcp_group(
-            ).rank_in_group if self.pcp_world_size > 1 else 0
-            self.dcp_world_size = get_dcp_group().world_size
-            self.dcp_rank = get_dcp_group().rank_in_group
-        except AssertionError:
-            # DCP might not be initialized in testing
-            self.dcp_world_size = 1
-            self.dcp_rank = 0
-            self.pcp_world_size = 1
-            self.pcp_rank = 0
-
         self.slot_mapping_cpu = torch.zeros(
             self.max_num_batched_tokens +
             2 * self.pcp_world_size * self.max_num_reqs,
@@ -306,7 +311,7 @@ class MultiGroupBlockTable:
                          block_size * dcp_world_size * pcp_world_size),
                     1 + num_speculative_tokens), max_num_batched_tokens,
                 pin_memory, device, kernel_size_list,
-                cp_kv_cache_interleave_size)
+                cp_kv_cache_interleave_size, num_speculative_tokens)
            for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
        ]
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 20891e50..2b0481c0 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -596,6 +596,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.is_pooling_model,
                 self.vllm_config.model_config.logits_processors),
             is_pooling_model=self.is_pooling_model,
+            num_speculative_tokens=(
+                self.vllm_config.speculative_config.num_speculative_tokens
+                if self.vllm_config.speculative_config else 0),
             kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
             cp_kv_cache_interleave_size=self.parallel_config.
             cp_kv_cache_interleave_size
@@ -1922,6 +1925,31 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             prefill_context_parallel_metadata=long_seq_metadata,
         )

+        if self.speculative_config and self.pcp_size > 1:
+            # For pcp + spec decode, we flatten block_table
+            # to avoid irregular spec_attn_mask shape, e.g.,
+            # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
+            # ori block_table: # [d0, d1, p0, p1, p2]
+            # (num_reqs_d + num_reqs_p, max_num_blocks),
+            # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
+            # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
+            ori_query_lens = self.query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
+                self.query_start_loc_pcp_full_cpu[:num_reqs]
+            num_prefill_reqs = (ori_query_lens
+                                > self.decode_threshold).sum().item()
+            num_decode_reqs = num_reqs - num_prefill_reqs
+            num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold
+            blk_table_tensor[
+                num_decode_reqs_flatten:num_decode_reqs_flatten +
+                num_prefill_reqs].copy_(
+                    blk_table_tensor[num_decode_reqs:num_decode_reqs +
+                                     num_prefill_reqs].clone())
+            blk_table_tensor[:num_decode_reqs_flatten].copy_(
+                blk_table_tensor[:num_decode_reqs].repeat_interleave(
+                    self.decode_threshold, dim=0))
+            common_attn_metadata.block_table_tensor = \
+                blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
+
         if self.speculative_config and \
                 self.spec_decode_common_attn_metadata is None:
             self.spec_decode_common_attn_metadata = common_attn_metadata
@@ -2831,6 +2859,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 sin=self.sin,
                 prefill_context_parallel_metadata=long_seq_metadata,
             )
+            if self.pcp_size > 1:
+                common_attn_metadata.block_table_tensor = \
+                    block_table_tensor[:num_reqs * self.decode_threshold]
             attn_state = AscendAttentionState.DecodeOnly
             if self.speculative_config and \
                     self.speculative_config.method == "deepseek_mtp":
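
For reference, a minimal standalone sketch of the flattening trick in isolation
(plain PyTorch; the helper name flatten_block_table_inplace and the demo values
are illustrative, not part of the patch). It mirrors what prepare_inputs now
does: decode rows are duplicated decode_threshold times inside the enlarged
buffer, so the flattened table is just a slice of the same storage and the
block_table address seen by the captured acl graph does not change.

import torch


def flatten_block_table_inplace(blk_table_buf: torch.Tensor,
                                num_decode_reqs: int,
                                num_prefill_reqs: int,
                                decode_threshold: int) -> torch.Tensor:
    """Duplicate each decode row decode_threshold times inside an oversized
    buffer and return a view that shares the buffer's storage."""
    num_decode_flat = num_decode_reqs * decode_threshold
    # Move the prefill rows out of the way first; the expanded decode region
    # would otherwise overwrite them (clone() because the ranges may overlap).
    blk_table_buf[num_decode_flat:num_decode_flat + num_prefill_reqs].copy_(
        blk_table_buf[num_decode_reqs:num_decode_reqs +
                      num_prefill_reqs].clone())
    # Duplicate decode rows: [d0, d1] -> [d0, d0, d1, d1] for threshold=2.
    blk_table_buf[:num_decode_flat].copy_(
        blk_table_buf[:num_decode_reqs].repeat_interleave(decode_threshold,
                                                          dim=0))
    # Slicing keeps the same storage, so the buffer address stays stable.
    return blk_table_buf[:num_decode_flat + num_prefill_reqs]


if __name__ == "__main__":
    max_num_reqs, max_blocks, decode_threshold = 5, 4, 2
    # Buffer enlarged as in BlockTable: room for duplicated decode rows.
    buf = torch.zeros((max_num_reqs * decode_threshold, max_blocks),
                      dtype=torch.int32)
    buf[:5] = torch.arange(5).unsqueeze(1)  # rows: d0, d1, p0, p1, p2
    flat = flatten_block_table_inplace(buf,
                                       num_decode_reqs=2,
                                       num_prefill_reqs=3,
                                       decode_threshold=decode_threshold)
    print(flat[:, 0].tolist())  # [0, 0, 1, 1, 2, 3, 4]
    assert flat.data_ptr() == buf.data_ptr()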