[Fix] Fix FIA query and query_start_loc shape mismatch error (#4518)
### What this PR does / why we need it? Due to the requirement of the FIA operator that the **query.shape[0]** must match **actual_seq_len[-1]**, in graph mode and multi-DP scenarios, the query is padded to the size of **num_input_token**. This leads to validation errors during tiling in the operator. However, since the padding is applied at the end of the query, it does not affect the actual execution result of the operator, and the precision remains unaffected. <img width="2434" height="49" alt="image" src="https://github.com/user-attachments/assets/63520816-fbc3-4382-82b9-89dbb1492f6c" /> Our modification padding both **actual_seq_len** and **actual_seq_len_kv** to resolve the validation issue in the operator. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com> Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -183,7 +183,6 @@ class AscendMetadataForDecode:
|
||||
class AscendMetadata:
|
||||
# **************************** Basic Properties ************************** #
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
fia_attn_mask: Optional[torch.Tensor] = None
|
||||
# Current state of this attention run.
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
@@ -312,21 +311,18 @@ class AscendAttentionMetadataBuilder:
|
||||
num_actual_tokens_pcp_padded]
|
||||
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
fia_attn_mask = common_attn_metadata.fia_attn_mask
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
+ 1]
|
||||
num_computed_tokens_cpu = (seq_lens - query_lens)
|
||||
|
||||
if attn_state == AscendAttentionState.DecodeOnly and \
|
||||
common_attn_metadata.num_input_tokens > num_actual_tokens:
|
||||
if common_attn_metadata.num_input_tokens > num_actual_tokens:
|
||||
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
|
||||
seq_lens = torch.cat([
|
||||
seq_lens,
|
||||
torch.ones(padded_num_tokens,
|
||||
dtype=seq_lens.dtype,
|
||||
device=seq_lens.device)
|
||||
torch.tensor([padded_num_tokens
|
||||
]).to(seq_lens.device).to(seq_lens.dtype)
|
||||
])
|
||||
block_table_padding = torch.zeros(
|
||||
(padded_num_tokens, ) + block_table.shape[1:],
|
||||
@@ -335,10 +331,8 @@ class AscendAttentionMetadataBuilder:
|
||||
block_table = torch.cat([block_table, block_table_padding], dim=0)
|
||||
query_start_loc_cpu = torch.cat([
|
||||
query_start_loc_cpu,
|
||||
torch.arange(query_start_loc_cpu[-1] + 1,
|
||||
query_start_loc_cpu[-1] + padded_num_tokens,
|
||||
dtype=query_start_loc_cpu.dtype,
|
||||
device=query_start_loc_cpu.device)
|
||||
torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
|
||||
query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
|
||||
])
|
||||
|
||||
query_start_loc = query_start_loc_cpu.to(self.device,
|
||||
@@ -471,7 +465,6 @@ class AscendAttentionMetadataBuilder:
|
||||
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
fia_attn_mask=fia_attn_mask,
|
||||
attn_state=attn_state,
|
||||
num_prefills=num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
@@ -604,7 +597,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
|
||||
num_tokens = attn_metadata.query_start_loc_list[-1]
|
||||
query = query[:num_tokens]
|
||||
graph_params = get_graph_params()
|
||||
query_start_loc = attn_metadata.query_start_loc_list
|
||||
# Prepare tensors for attention output
|
||||
@@ -618,7 +610,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
atten_mask=attn_metadata.fia_attn_mask,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
block_table=block_table,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
@@ -641,7 +633,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(weak_ref_tensors(query), weak_ref_tensors(key),
|
||||
weak_ref_tensors(value), weak_ref_tensors(block_table),
|
||||
weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
|
||||
weak_ref_tensors(attn_metadata.attn_mask), block_size,
|
||||
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
|
||||
self.num_heads, self.scale, weak_ref_tensors(output),
|
||||
weak_ref_tensors(softmax_lse)))
|
||||
@@ -651,7 +643,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
atten_mask=attn_metadata.fia_attn_mask,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
block_table=block_table,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
|
||||
Reference in New Issue
Block a user