[Fix] Fix FIA query and query_start_loc shape mismatch error (#4518)

### What this PR does / why we need it?
The FIA operator requires that **query.shape[0]** match
**actual_seq_len[-1]**. In graph mode and multi-DP scenarios, however,
the query is padded up to **num_input_tokens**, so the two values no
longer match and the operator fails its tiling validation. Because the
padding is appended at the end of the query, it does not change the
operator's actual output, and precision is unaffected.
<img width="2434" height="49" alt="image"
src="https://github.com/user-attachments/assets/63520816-fbc3-4382-82b9-89dbb1492f6c"
/>
Our fix pads both **actual_seq_len** and **actual_seq_len_kv** as well,
which resolves the operator's validation failure; a minimal sketch of
the restored invariant follows.
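
To make the invariant concrete, here is a minimal self-contained sketch. All sizes are hypothetical and the tensors stand in for the real metadata; this is an illustration, not the vllm-ascend code itself:

```python
import torch

# The FIA operator validates query.shape[0] == actual_seq_lengths_q[-1].
num_actual_tokens = 6   # tokens actually scheduled this step
num_input_tokens = 8    # padded size in graph mode / multi-DP
query = torch.randn(num_input_tokens, 8, 128)  # query padded at the tail

# Cumulative query lengths for the real requests only: the last entry is
# 6 while query.shape[0] is 8, which trips the tiling validation.
actual_seq_lengths_q = torch.tensor([3, 6])

# The fix: append one entry that covers the padded tail, so the last
# cumulative length matches the padded query again. The padded rows sit
# at the end, so the operator's real output is unchanged.
padded = num_input_tokens - num_actual_tokens
actual_seq_lengths_q = torch.cat(
    [actual_seq_lengths_q, actual_seq_lengths_q[-1:] + padded])
assert actual_seq_lengths_q[-1].item() == query.shape[0]
```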
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Commit 15dc01f050 (parent 7271f0d536), authored by XiaoxinWang on 2025-12-03 17:33:31 +08:00, committed by GitHub.
2 changed files with 12 additions and 41 deletions

**File 1 of 2** — Ascend attention backend: drop the separate `fia_attn_mask` field, pad whenever the input is padded (not only in the decode-only state), and pad `seq_lens`/`query_start_loc` with a single entry:

```diff
@@ -183,7 +183,6 @@ class AscendMetadataForDecode:
 class AscendMetadata:
     # **************************** Basic Properties ************************** #
     attn_mask: Optional[torch.Tensor] = None
-    fia_attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
@@ -312,21 +311,18 @@ class AscendAttentionMetadataBuilder:
                                                      num_actual_tokens_pcp_padded]
         # slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         attn_mask = common_attn_metadata.attn_mask
-        fia_attn_mask = common_attn_metadata.fia_attn_mask
         attn_state = common_attn_metadata.attn_state
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
         num_computed_tokens_cpu = (seq_lens - query_lens)
-        if attn_state == AscendAttentionState.DecodeOnly and \
-                common_attn_metadata.num_input_tokens > num_actual_tokens:
+        if common_attn_metadata.num_input_tokens > num_actual_tokens:
             padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
             seq_lens = torch.cat([
                 seq_lens,
-                torch.ones(padded_num_tokens,
-                           dtype=seq_lens.dtype,
-                           device=seq_lens.device)
+                torch.tensor([padded_num_tokens
+                              ]).to(seq_lens.device).to(seq_lens.dtype)
             ])
             block_table_padding = torch.zeros(
                 (padded_num_tokens, ) + block_table.shape[1:],
@@ -335,10 +331,8 @@ class AscendAttentionMetadataBuilder:
             block_table = torch.cat([block_table, block_table_padding], dim=0)
             query_start_loc_cpu = torch.cat([
                 query_start_loc_cpu,
-                torch.arange(query_start_loc_cpu[-1] + 1,
-                             query_start_loc_cpu[-1] + padded_num_tokens,
-                             dtype=query_start_loc_cpu.dtype,
-                             device=query_start_loc_cpu.device)
+                torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
+                    query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
             ])
         query_start_loc = query_start_loc_cpu.to(self.device,
```
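
For context, a self-contained sketch of what the `query_start_loc_cpu` change does, with hypothetical offsets (not the vllm-ascend code). The old `torch.arange(last + 1, last + padded_num_tokens)` appended `padded_num_tokens - 1` single-token entries whose final value still fell one short of the padded query length; the new code appends a single dummy entry that lands exactly on it:

```python
import torch

# Real tokens end at offset 6; two more padded tokens must be covered,
# so the padded query length is 8.
query_start_loc_cpu = torch.tensor([0, 3, 6])
padded_num_tokens = 2

# Old: arange(6 + 1, 6 + 2) == [7], so the last entry was 7, one short
# of the padded query length of 8.
old = torch.cat([
    query_start_loc_cpu,
    torch.arange(query_start_loc_cpu[-1] + 1,
                 query_start_loc_cpu[-1] + padded_num_tokens)
])
print(old)  # tensor([0, 3, 6, 7])

# New: one dummy "request" owning every padded token; the last entry is
# exactly the padded query length.
new = torch.cat([
    query_start_loc_cpu,
    torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens])
])
print(new)  # tensor([0, 3, 6, 8])
```

The same single-entry scheme is applied to `seq_lens`, which feeds `actual_seq_lengths_kv`, so both of the lists named in the description stay consistent with the padded query.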
**File 1 of 2** (continued) — the FIA call sites in `AscendAttentionBackendImpl` now take the regular `attn_mask`, and the query is no longer trimmed:

```diff
@@ -471,7 +465,6 @@ class AscendAttentionMetadataBuilder:
             actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
-            fia_attn_mask=fia_attn_mask,
             attn_state=attn_state,
             num_prefills=num_prefills,
             num_decodes=num_decodes,
@@ -604,7 +597,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
             actual_seq_lengths_kv = attn_metadata.seq_lens_list
             num_tokens = attn_metadata.query_start_loc_list[-1]
-            query = query[:num_tokens]
             graph_params = get_graph_params()
             query_start_loc = attn_metadata.query_start_loc_list
             # Prepare tensors for attention output
@@ -618,7 +610,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     query=query,
                     key=key,
                     value=value,
-                    atten_mask=attn_metadata.fia_attn_mask,
+                    atten_mask=attn_metadata.attn_mask,
                     block_table=block_table,
                     input_layout="TND",
                     block_size=block_size,
@@ -641,7 +633,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 graph_params.attn_params[num_tokens].append(
                     (weak_ref_tensors(query), weak_ref_tensors(key),
                      weak_ref_tensors(value), weak_ref_tensors(block_table),
-                     weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
+                     weak_ref_tensors(attn_metadata.attn_mask), block_size,
                      actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
                      self.num_heads, self.scale, weak_ref_tensors(output),
                      weak_ref_tensors(softmax_lse)))
@@ -651,7 +643,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                     query=query,
                     key=key,
                     value=value,
-                    atten_mask=attn_metadata.fia_attn_mask,
+                    atten_mask=attn_metadata.attn_mask,
                     block_table=block_table,
                     input_layout="TND",
                     block_size=block_size,
```
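One plausible reading of the `query = query[:num_tokens]` removal (an inference, not stated in the PR): slicing gave the query a dynamic leading dimension, while after the builder change `query_start_loc_list[-1]` already equals the padded length, so the untrimmed query matches the metadata as-is. A toy sketch under those assumptions:

```python
import torch

num_input_tokens = 8                   # fixed padded size in graph mode
query = torch.randn(num_input_tokens, 8, 128)  # TND: tokens, heads, dim

# With the padding above, the metadata already ends at the padded length,
# so no trim is needed before invoking the operator.
query_start_loc_list = [3, 6, 8]       # last entry == num_input_tokens
num_tokens = query_start_loc_list[-1]
assert query.shape[0] == num_tokens    # shapes agree without slicing
```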
**File 2 of 2** — `NPUModelRunner`: remove the `fia_attn_mask` state, its builder method, and its plumbing; the triu mask is assigned to `attn_mask` instead:

```diff
@@ -321,7 +321,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.attn_groups: list[list[AttentionGroup]] = []
         self.encoder_cache: Dict[str, torch.Tensor] = {}
         self.attn_mask = None
-        self.fia_attn_mask = None
         self.attn_state = None
         self.requests: Dict[str, CachedRequestState] = {}
         self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -982,23 +981,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
             return self.attn_mask_builder.get_pooling_mask(self.device)
-        # fia prefill situation.
-        if attn_state in [
-                AscendAttentionState.PrefillNoCache,
-                AscendAttentionState.PrefillCacheHit,
-                AscendAttentionState.ChunkedPrefill
-        ]:
-            return self.attn_mask_builder.get_splitfuse_attn_mask()
-        # Decode-only situation.
-        return None
-
-    def _make_fia_attention_mask(self) -> torch.Tensor:
-        # pcp situation.
-        if self.pcp_size > 1:
-            return None
-        if self.attn_mask_builder is None:
-            raise ValueError("Attn mask builder is None")
         return self.attn_mask_builder.get_splitfuse_attn_mask()

     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
@@ -1579,7 +1561,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                    position=positions_cpu,
                                                    attn_state=attn_state)
-        self.fia_attn_mask = self._make_fia_attention_mask()
         self.attn_state = attn_state  # type: ignore
         self.with_prefill = with_prefill
@@ -1804,7 +1785,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
-            fia_attn_mask=self.fia_attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
             is_only_prefill=bool(np.all(num_valid_tokens != 1)),
@@ -2723,10 +2703,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
         assigned_mask_dim = 2048
-        self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
-                                                   assigned_mask_dim),
-                                        diagonal=1).to(torch.int8).to(
-                                            self.device)
+        self.attn_mask = torch.triu(torch.ones(assigned_mask_dim,
+                                               assigned_mask_dim),
+                                    diagonal=1).to(torch.int8).to(
+                                        self.device)
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -2770,7 +2750,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             num_computed_tokens_cpu=num_computed_tokens_cpu,
             positions=self.positions,
             attn_mask=self.attn_mask,
-            fia_attn_mask=self.fia_attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             attn_state=self.attn_state,
             max_query_len=max_query_len,
```
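
For reference, this is the causal mask both paths now share; `assigned_mask_dim` is 2048 in the diff but shrunk to 4 here so the values are visible (a sketch, not the runner code):

```python
import torch

assigned_mask_dim = 4  # 2048 in the actual runner
attn_mask = torch.triu(torch.ones(assigned_mask_dim, assigned_mask_dim),
                       diagonal=1).to(torch.int8)
print(attn_mask)
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int8)
# A 1 marks a future position to mask out; with one shared mask the
# separate fia_attn_mask copy becomes redundant.
```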