From 15dc01f0502f7ee2f3efe38c279927988f9001d1 Mon Sep 17 00:00:00 2001 From: XiaoxinWang <963372609@qq.com> Date: Wed, 3 Dec 2025 17:33:31 +0800 Subject: [PATCH] [Fix] Fix FIA `query` and `query_start_loc` shape mismatch error (#4518) ### What this PR does / why we need it? Due to the requirement of the FIA operator that the **query.shape[0]** must match **actual_seq_len[-1]**, in graph mode and multi-DP scenarios, the query is padded to the size of **num_input_token**. This leads to validation errors during tiling in the operator. However, since the padding is applied at the end of the query, it does not affect the actual execution result of the operator, and the precision remains unaffected. image Our modification padding both **actual_seq_len** and **actual_seq_len_kv** to resolve the validation issue in the operator. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: wangxiaoxin-sherie Co-authored-by: wangxiaoxin-sherie --- vllm_ascend/attention/attention_v1.py | 24 ++++++++-------------- vllm_ascend/worker/model_runner_v1.py | 29 ++++----------------------- 2 files changed, 12 insertions(+), 41 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index ff0240bb..b524e648 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -183,7 +183,6 @@ class AscendMetadataForDecode: class AscendMetadata: # **************************** Basic Properties ************************** # attn_mask: Optional[torch.Tensor] = None - fia_attn_mask: Optional[torch.Tensor] = None # Current state of this attention run. attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill @@ -312,21 +311,18 @@ class AscendAttentionMetadataBuilder: num_actual_tokens_pcp_padded] # slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] attn_mask = common_attn_metadata.attn_mask - fia_attn_mask = common_attn_metadata.fia_attn_mask attn_state = common_attn_metadata.attn_state query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1] num_computed_tokens_cpu = (seq_lens - query_lens) - if attn_state == AscendAttentionState.DecodeOnly and \ - common_attn_metadata.num_input_tokens > num_actual_tokens: + if common_attn_metadata.num_input_tokens > num_actual_tokens: padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens seq_lens = torch.cat([ seq_lens, - torch.ones(padded_num_tokens, - dtype=seq_lens.dtype, - device=seq_lens.device) + torch.tensor([padded_num_tokens + ]).to(seq_lens.device).to(seq_lens.dtype) ]) block_table_padding = torch.zeros( (padded_num_tokens, ) + block_table.shape[1:], @@ -335,10 +331,8 @@ class AscendAttentionMetadataBuilder: block_table = torch.cat([block_table, block_table_padding], dim=0) query_start_loc_cpu = torch.cat([ query_start_loc_cpu, - torch.arange(query_start_loc_cpu[-1] + 1, - query_start_loc_cpu[-1] + padded_num_tokens, - dtype=query_start_loc_cpu.dtype, - device=query_start_loc_cpu.device) + torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to( + query_start_loc_cpu.device).to(query_start_loc_cpu.dtype) ]) query_start_loc = query_start_loc_cpu.to(self.device, @@ -471,7 +465,6 @@ class AscendAttentionMetadataBuilder: actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(), slot_mapping=slot_mapping, attn_mask=attn_mask, - fia_attn_mask=fia_attn_mask, attn_state=attn_state, num_prefills=num_prefills, num_decodes=num_decodes, @@ -604,7 +597,6 @@ class AscendAttentionBackendImpl(AttentionImpl): actual_seq_lengths_kv = attn_metadata.seq_lens_list num_tokens = attn_metadata.query_start_loc_list[-1] - query = query[:num_tokens] graph_params = get_graph_params() query_start_loc = attn_metadata.query_start_loc_list # Prepare tensors for attention output @@ -618,7 +610,7 @@ class AscendAttentionBackendImpl(AttentionImpl): query=query, key=key, value=value, - atten_mask=attn_metadata.fia_attn_mask, + atten_mask=attn_metadata.attn_mask, block_table=block_table, input_layout="TND", block_size=block_size, @@ -641,7 +633,7 @@ class AscendAttentionBackendImpl(AttentionImpl): graph_params.attn_params[num_tokens].append( (weak_ref_tensors(query), weak_ref_tensors(key), weak_ref_tensors(value), weak_ref_tensors(block_table), - weak_ref_tensors(attn_metadata.fia_attn_mask), block_size, + weak_ref_tensors(attn_metadata.attn_mask), block_size, actual_seq_lengths_kv, query_start_loc, self.num_kv_heads, self.num_heads, self.scale, weak_ref_tensors(output), weak_ref_tensors(softmax_lse))) @@ -651,7 +643,7 @@ class AscendAttentionBackendImpl(AttentionImpl): query=query, key=key, value=value, - atten_mask=attn_metadata.fia_attn_mask, + atten_mask=attn_metadata.attn_mask, block_table=block_table, input_layout="TND", block_size=block_size, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1f46b9d4..9e28e117 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -321,7 +321,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.attn_groups: list[list[AttentionGroup]] = [] self.encoder_cache: Dict[str, torch.Tensor] = {} self.attn_mask = None - self.fia_attn_mask = None self.attn_state = None self.requests: Dict[str, CachedRequestState] = {} self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -982,23 +981,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Pooling situation. if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS": return self.attn_mask_builder.get_pooling_mask(self.device) - # fia prefill situation. - if attn_state in [ - AscendAttentionState.PrefillNoCache, - AscendAttentionState.PrefillCacheHit, - AscendAttentionState.ChunkedPrefill - ]: - return self.attn_mask_builder.get_splitfuse_attn_mask() - - # Decode-only situation. - return None - - def _make_fia_attention_mask(self) -> torch.Tensor: - # pcp situation. - if self.pcp_size > 1: - return None - if self.attn_mask_builder is None: - raise ValueError("Attn mask builder is None") return self.attn_mask_builder.get_splitfuse_attn_mask() def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): @@ -1579,7 +1561,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state) - self.fia_attn_mask = self._make_fia_attention_mask() self.attn_state = attn_state # type: ignore self.with_prefill = with_prefill @@ -1804,7 +1785,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_computed_tokens_cpu=num_computed_tokens_cpu, positions=self.positions, attn_mask=self.attn_mask, - fia_attn_mask=self.fia_attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, is_only_prefill=bool(np.all(num_valid_tokens != 1)), @@ -2723,10 +2703,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.query_lens = torch.from_numpy(num_scheduled_tokens) assigned_mask_dim = 2048 - self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim, - assigned_mask_dim), - diagonal=1).to(torch.int8).to( - self.device) + self.attn_mask = torch.triu(torch.ones(assigned_mask_dim, + assigned_mask_dim), + diagonal=1).to(torch.int8).to( + self.device) num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) @@ -2770,7 +2750,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_computed_tokens_cpu=num_computed_tokens_cpu, positions=self.positions, attn_mask=self.attn_mask, - fia_attn_mask=self.fia_attn_mask, spec_attn_mask=self.spec_attn_mask, attn_state=self.attn_state, max_query_len=max_query_len,