From 15dc01f0502f7ee2f3efe38c279927988f9001d1 Mon Sep 17 00:00:00 2001
From: XiaoxinWang <963372609@qq.com>
Date: Wed, 3 Dec 2025 17:33:31 +0800
Subject: [PATCH] [Fix] Fix FIA `query` and `query_start_loc` shape mismatch
error (#4518)
### What this PR does / why we need it?
The FIA operator requires that **query.shape[0]** match **actual_seq_len[-1]**. In graph mode and multi-DP scenarios, however, the query is padded up to **num_input_tokens**, which causes validation errors during tiling inside the operator. Because the padding is appended at the end of the query, it does not change the operator's actual result, and precision is unaffected.
Our fix pads both **actual_seq_len** and **actual_seq_len_kv** accordingly to resolve the operator's validation issue.
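A minimal sketch of the idea, with illustrative sizes and tensor names (not the exact code in this diff):

```python
import torch

# Illustrative values: 6 tokens were actually scheduled, but in graph mode /
# multi-DP the query tensor is padded up to 8 tokens.
num_actual_tokens = 6
num_input_tokens = 8
padded = num_input_tokens - num_actual_tokens

query_start_loc = torch.tensor([0, 2, 6])  # per-request query boundaries (q)
seq_lens = torch.tensor([4, 7])            # per-request KV lengths (kv)

# FIA checks that query.shape[0] == actual_seq_len_q[-1], so the padded tail
# is accounted for as one extra "dummy request" of length `padded`.
query_start_loc = torch.cat([query_start_loc, query_start_loc[-1:] + padded])
seq_lens = torch.cat([seq_lens, torch.tensor([padded])])

assert query_start_loc[-1].item() == num_input_tokens  # 8 == 8
```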
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2
Signed-off-by: wangxiaoxin-sherie
Co-authored-by: wangxiaoxin-sherie
---
vllm_ascend/attention/attention_v1.py | 24 ++++++++--------------
vllm_ascend/worker/model_runner_v1.py | 29 ++++-----------------------
2 files changed, 12 insertions(+), 41 deletions(-)
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index ff0240bb..b524e648 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -183,7 +183,6 @@ class AscendMetadataForDecode:
class AscendMetadata:
# **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None
- fia_attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -312,21 +311,18 @@ class AscendAttentionMetadataBuilder:
num_actual_tokens_pcp_padded]
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
- fia_attn_mask = common_attn_metadata.fia_attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
num_computed_tokens_cpu = (seq_lens - query_lens)
- if attn_state == AscendAttentionState.DecodeOnly and \
- common_attn_metadata.num_input_tokens > num_actual_tokens:
+ if common_attn_metadata.num_input_tokens > num_actual_tokens:
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
seq_lens = torch.cat([
seq_lens,
- torch.ones(padded_num_tokens,
- dtype=seq_lens.dtype,
- device=seq_lens.device)
+ torch.tensor([padded_num_tokens
+ ]).to(seq_lens.device).to(seq_lens.dtype)
])
block_table_padding = torch.zeros(
(padded_num_tokens, ) + block_table.shape[1:],
@@ -335,10 +331,8 @@ class AscendAttentionMetadataBuilder:
block_table = torch.cat([block_table, block_table_padding], dim=0)
query_start_loc_cpu = torch.cat([
query_start_loc_cpu,
- torch.arange(query_start_loc_cpu[-1] + 1,
- query_start_loc_cpu[-1] + padded_num_tokens,
- dtype=query_start_loc_cpu.dtype,
- device=query_start_loc_cpu.device)
+ torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
+ query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
])
query_start_loc = query_start_loc_cpu.to(self.device,
@@ -471,7 +465,6 @@ class AscendAttentionMetadataBuilder:
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
slot_mapping=slot_mapping,
attn_mask=attn_mask,
- fia_attn_mask=fia_attn_mask,
attn_state=attn_state,
num_prefills=num_prefills,
num_decodes=num_decodes,
@@ -604,7 +597,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
actual_seq_lengths_kv = attn_metadata.seq_lens_list
num_tokens = attn_metadata.query_start_loc_list[-1]
- query = query[:num_tokens]
graph_params = get_graph_params()
query_start_loc = attn_metadata.query_start_loc_list
# Prepare tensors for attention output
@@ -618,7 +610,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
query=query,
key=key,
value=value,
- atten_mask=attn_metadata.fia_attn_mask,
+ atten_mask=attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
@@ -641,7 +633,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(query), weak_ref_tensors(key),
weak_ref_tensors(value), weak_ref_tensors(block_table),
- weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
+ weak_ref_tensors(attn_metadata.attn_mask), block_size,
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
self.num_heads, self.scale, weak_ref_tensors(output),
weak_ref_tensors(softmax_lse)))
@@ -651,7 +643,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
query=query,
key=key,
value=value,
- atten_mask=attn_metadata.fia_attn_mask,
+ atten_mask=attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 1f46b9d4..9e28e117 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -321,7 +321,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_groups: list[list[AttentionGroup]] = []
self.encoder_cache: Dict[str, torch.Tensor] = {}
self.attn_mask = None
- self.fia_attn_mask = None
self.attn_state = None
self.requests: Dict[str, CachedRequestState] = {}
self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -982,23 +981,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Pooling situation.
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
return self.attn_mask_builder.get_pooling_mask(self.device)
- # fia prefill situation.
- if attn_state in [
- AscendAttentionState.PrefillNoCache,
- AscendAttentionState.PrefillCacheHit,
- AscendAttentionState.ChunkedPrefill
- ]:
- return self.attn_mask_builder.get_splitfuse_attn_mask()
-
- # Decode-only situation.
- return None
-
- def _make_fia_attention_mask(self) -> torch.Tensor:
- # pcp situation.
- if self.pcp_size > 1:
- return None
- if self.attn_mask_builder is None:
- raise ValueError("Attn mask builder is None")
return self.attn_mask_builder.get_splitfuse_attn_mask()
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
@@ -1579,7 +1561,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu,
attn_state=attn_state)
- self.fia_attn_mask = self._make_fia_attention_mask()
self.attn_state = attn_state # type: ignore
self.with_prefill = with_prefill
@@ -1804,7 +1785,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
- fia_attn_mask=self.fia_attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
@@ -2723,10 +2703,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.query_lens = torch.from_numpy(num_scheduled_tokens)
assigned_mask_dim = 2048
- self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
- assigned_mask_dim),
- diagonal=1).to(torch.int8).to(
- self.device)
+ self.attn_mask = torch.triu(torch.ones(assigned_mask_dim,
+ assigned_mask_dim),
+ diagonal=1).to(torch.int8).to(
+ self.device)
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -2770,7 +2750,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
- fia_attn_mask=self.fia_attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
max_query_len=max_query_len,