[Fix] Fix FIA query and query_start_loc shape mismatch error (#4518)

### What this PR does / why we need it?
Due to the requirement of the FIA operator that the **query.shape[0]**
must match **actual_seq_len[-1]**, in graph mode and multi-DP scenarios,
the query is padded to the size of **num_input_token**. This leads to
validation errors during tiling in the operator. However, since the
padding is applied at the end of the query, it does not affect the
actual execution result of the operator, and the precision remains
unaffected.
<img width="2434" height="49" alt="image"
src="https://github.com/user-attachments/assets/63520816-fbc3-4382-82b9-89dbb1492f6c"
/>
Our modification padding both **actual_seq_len** and
**actual_seq_len_kv** to resolve the validation issue in the operator.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
XiaoxinWang
2025-12-03 17:33:31 +08:00
committed by GitHub
parent 7271f0d536
commit 15dc01f050
2 changed files with 12 additions and 41 deletions

View File

@@ -183,7 +183,6 @@ class AscendMetadataForDecode:
class AscendMetadata:
# **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None
fia_attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -312,21 +311,18 @@ class AscendAttentionMetadataBuilder:
num_actual_tokens_pcp_padded]
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
fia_attn_mask = common_attn_metadata.fia_attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
num_computed_tokens_cpu = (seq_lens - query_lens)
if attn_state == AscendAttentionState.DecodeOnly and \
common_attn_metadata.num_input_tokens > num_actual_tokens:
if common_attn_metadata.num_input_tokens > num_actual_tokens:
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
seq_lens = torch.cat([
seq_lens,
torch.ones(padded_num_tokens,
dtype=seq_lens.dtype,
device=seq_lens.device)
torch.tensor([padded_num_tokens
]).to(seq_lens.device).to(seq_lens.dtype)
])
block_table_padding = torch.zeros(
(padded_num_tokens, ) + block_table.shape[1:],
@@ -335,10 +331,8 @@ class AscendAttentionMetadataBuilder:
block_table = torch.cat([block_table, block_table_padding], dim=0)
query_start_loc_cpu = torch.cat([
query_start_loc_cpu,
torch.arange(query_start_loc_cpu[-1] + 1,
query_start_loc_cpu[-1] + padded_num_tokens,
dtype=query_start_loc_cpu.dtype,
device=query_start_loc_cpu.device)
torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
])
query_start_loc = query_start_loc_cpu.to(self.device,
@@ -471,7 +465,6 @@ class AscendAttentionMetadataBuilder:
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
slot_mapping=slot_mapping,
attn_mask=attn_mask,
fia_attn_mask=fia_attn_mask,
attn_state=attn_state,
num_prefills=num_prefills,
num_decodes=num_decodes,
@@ -604,7 +597,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
actual_seq_lengths_kv = attn_metadata.seq_lens_list
num_tokens = attn_metadata.query_start_loc_list[-1]
query = query[:num_tokens]
graph_params = get_graph_params()
query_start_loc = attn_metadata.query_start_loc_list
# Prepare tensors for attention output
@@ -618,7 +610,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
query=query,
key=key,
value=value,
atten_mask=attn_metadata.fia_attn_mask,
atten_mask=attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
@@ -641,7 +633,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(query), weak_ref_tensors(key),
weak_ref_tensors(value), weak_ref_tensors(block_table),
weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
weak_ref_tensors(attn_metadata.attn_mask), block_size,
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
self.num_heads, self.scale, weak_ref_tensors(output),
weak_ref_tensors(softmax_lse)))
@@ -651,7 +643,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
query=query,
key=key,
value=value,
atten_mask=attn_metadata.fia_attn_mask,
atten_mask=attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,