[Fix] Fix FIA query and query_start_loc shape mismatch error (#4518)
### What this PR does / why we need it? The FIA operator requires that **query.shape[0]** match **actual_seq_len[-1]**; in graph mode and multi-DP scenarios, the query is padded to the size of **num_input_token**. This leads to validation errors during tiling in the operator. However, since the padding is applied at the end of the query, it does not affect the actual execution result of the operator, and the precision remains unaffected. <img width="2434" height="49" alt="image" src="https://github.com/user-attachments/assets/63520816-fbc3-4382-82b9-89dbb1492f6c" /> Our modification pads both **actual_seq_len** and **actual_seq_len_kv** to resolve the validation issue in the operator. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com> Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -183,7 +183,6 @@ class AscendMetadataForDecode:
|
|||||||
class AscendMetadata:
|
class AscendMetadata:
|
||||||
# **************************** Basic Properties ************************** #
|
# **************************** Basic Properties ************************** #
|
||||||
attn_mask: Optional[torch.Tensor] = None
|
attn_mask: Optional[torch.Tensor] = None
|
||||||
fia_attn_mask: Optional[torch.Tensor] = None
|
|
||||||
# Current state of this attention run.
|
# Current state of this attention run.
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||||
|
|
||||||
@@ -312,21 +311,18 @@ class AscendAttentionMetadataBuilder:
|
|||||||
num_actual_tokens_pcp_padded]
|
num_actual_tokens_pcp_padded]
|
||||||
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||||
attn_mask = common_attn_metadata.attn_mask
|
attn_mask = common_attn_metadata.attn_mask
|
||||||
fia_attn_mask = common_attn_metadata.fia_attn_mask
|
|
||||||
attn_state = common_attn_metadata.attn_state
|
attn_state = common_attn_metadata.attn_state
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||||
num_reqs
|
num_reqs
|
||||||
+ 1]
|
+ 1]
|
||||||
num_computed_tokens_cpu = (seq_lens - query_lens)
|
num_computed_tokens_cpu = (seq_lens - query_lens)
|
||||||
|
|
||||||
if attn_state == AscendAttentionState.DecodeOnly and \
|
if common_attn_metadata.num_input_tokens > num_actual_tokens:
|
||||||
common_attn_metadata.num_input_tokens > num_actual_tokens:
|
|
||||||
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
|
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
|
||||||
seq_lens = torch.cat([
|
seq_lens = torch.cat([
|
||||||
seq_lens,
|
seq_lens,
|
||||||
torch.ones(padded_num_tokens,
|
torch.tensor([padded_num_tokens
|
||||||
dtype=seq_lens.dtype,
|
]).to(seq_lens.device).to(seq_lens.dtype)
|
||||||
device=seq_lens.device)
|
|
||||||
])
|
])
|
||||||
block_table_padding = torch.zeros(
|
block_table_padding = torch.zeros(
|
||||||
(padded_num_tokens, ) + block_table.shape[1:],
|
(padded_num_tokens, ) + block_table.shape[1:],
|
||||||
@@ -335,10 +331,8 @@ class AscendAttentionMetadataBuilder:
|
|||||||
block_table = torch.cat([block_table, block_table_padding], dim=0)
|
block_table = torch.cat([block_table, block_table_padding], dim=0)
|
||||||
query_start_loc_cpu = torch.cat([
|
query_start_loc_cpu = torch.cat([
|
||||||
query_start_loc_cpu,
|
query_start_loc_cpu,
|
||||||
torch.arange(query_start_loc_cpu[-1] + 1,
|
torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
|
||||||
query_start_loc_cpu[-1] + padded_num_tokens,
|
query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
|
||||||
dtype=query_start_loc_cpu.dtype,
|
|
||||||
device=query_start_loc_cpu.device)
|
|
||||||
])
|
])
|
||||||
|
|
||||||
query_start_loc = query_start_loc_cpu.to(self.device,
|
query_start_loc = query_start_loc_cpu.to(self.device,
|
||||||
@@ -471,7 +465,6 @@ class AscendAttentionMetadataBuilder:
|
|||||||
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
|
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
fia_attn_mask=fia_attn_mask,
|
|
||||||
attn_state=attn_state,
|
attn_state=attn_state,
|
||||||
num_prefills=num_prefills,
|
num_prefills=num_prefills,
|
||||||
num_decodes=num_decodes,
|
num_decodes=num_decodes,
|
||||||
@@ -604,7 +597,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||||
|
|
||||||
num_tokens = attn_metadata.query_start_loc_list[-1]
|
num_tokens = attn_metadata.query_start_loc_list[-1]
|
||||||
query = query[:num_tokens]
|
|
||||||
graph_params = get_graph_params()
|
graph_params = get_graph_params()
|
||||||
query_start_loc = attn_metadata.query_start_loc_list
|
query_start_loc = attn_metadata.query_start_loc_list
|
||||||
# Prepare tensors for attention output
|
# Prepare tensors for attention output
|
||||||
@@ -618,7 +610,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
query=query,
|
query=query,
|
||||||
key=key,
|
key=key,
|
||||||
value=value,
|
value=value,
|
||||||
atten_mask=attn_metadata.fia_attn_mask,
|
atten_mask=attn_metadata.attn_mask,
|
||||||
block_table=block_table,
|
block_table=block_table,
|
||||||
input_layout="TND",
|
input_layout="TND",
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
@@ -641,7 +633,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
graph_params.attn_params[num_tokens].append(
|
graph_params.attn_params[num_tokens].append(
|
||||||
(weak_ref_tensors(query), weak_ref_tensors(key),
|
(weak_ref_tensors(query), weak_ref_tensors(key),
|
||||||
weak_ref_tensors(value), weak_ref_tensors(block_table),
|
weak_ref_tensors(value), weak_ref_tensors(block_table),
|
||||||
weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
|
weak_ref_tensors(attn_metadata.attn_mask), block_size,
|
||||||
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
|
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
|
||||||
self.num_heads, self.scale, weak_ref_tensors(output),
|
self.num_heads, self.scale, weak_ref_tensors(output),
|
||||||
weak_ref_tensors(softmax_lse)))
|
weak_ref_tensors(softmax_lse)))
|
||||||
@@ -651,7 +643,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
query=query,
|
query=query,
|
||||||
key=key,
|
key=key,
|
||||||
value=value,
|
value=value,
|
||||||
atten_mask=attn_metadata.fia_attn_mask,
|
atten_mask=attn_metadata.attn_mask,
|
||||||
block_table=block_table,
|
block_table=block_table,
|
||||||
input_layout="TND",
|
input_layout="TND",
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
|
|||||||
@@ -321,7 +321,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.attn_groups: list[list[AttentionGroup]] = []
|
self.attn_groups: list[list[AttentionGroup]] = []
|
||||||
self.encoder_cache: Dict[str, torch.Tensor] = {}
|
self.encoder_cache: Dict[str, torch.Tensor] = {}
|
||||||
self.attn_mask = None
|
self.attn_mask = None
|
||||||
self.fia_attn_mask = None
|
|
||||||
self.attn_state = None
|
self.attn_state = None
|
||||||
self.requests: Dict[str, CachedRequestState] = {}
|
self.requests: Dict[str, CachedRequestState] = {}
|
||||||
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
||||||
@@ -982,23 +981,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Pooling situation.
|
# Pooling situation.
|
||||||
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
|
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
|
||||||
return self.attn_mask_builder.get_pooling_mask(self.device)
|
return self.attn_mask_builder.get_pooling_mask(self.device)
|
||||||
# fia prefill situation.
|
|
||||||
if attn_state in [
|
|
||||||
AscendAttentionState.PrefillNoCache,
|
|
||||||
AscendAttentionState.PrefillCacheHit,
|
|
||||||
AscendAttentionState.ChunkedPrefill
|
|
||||||
]:
|
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
|
||||||
|
|
||||||
# Decode-only situation.
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _make_fia_attention_mask(self) -> torch.Tensor:
|
|
||||||
# pcp situation.
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
return None
|
|
||||||
if self.attn_mask_builder is None:
|
|
||||||
raise ValueError("Attn mask builder is None")
|
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||||
|
|
||||||
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
|
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
|
||||||
@@ -1579,7 +1561,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
||||||
position=positions_cpu,
|
position=positions_cpu,
|
||||||
attn_state=attn_state)
|
attn_state=attn_state)
|
||||||
self.fia_attn_mask = self._make_fia_attention_mask()
|
|
||||||
self.attn_state = attn_state # type: ignore
|
self.attn_state = attn_state # type: ignore
|
||||||
|
|
||||||
self.with_prefill = with_prefill
|
self.with_prefill = with_prefill
|
||||||
@@ -1804,7 +1785,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||||
positions=self.positions,
|
positions=self.positions,
|
||||||
attn_mask=self.attn_mask,
|
attn_mask=self.attn_mask,
|
||||||
fia_attn_mask=self.fia_attn_mask,
|
|
||||||
spec_attn_mask=self.spec_attn_mask,
|
spec_attn_mask=self.spec_attn_mask,
|
||||||
attn_state=self.attn_state,
|
attn_state=self.attn_state,
|
||||||
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
|
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
|
||||||
@@ -2723,7 +2703,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||||
|
|
||||||
assigned_mask_dim = 2048
|
assigned_mask_dim = 2048
|
||||||
self.fia_attn_mask = torch.triu(torch.ones(assigned_mask_dim,
|
self.attn_mask = torch.triu(torch.ones(assigned_mask_dim,
|
||||||
assigned_mask_dim),
|
assigned_mask_dim),
|
||||||
diagonal=1).to(torch.int8).to(
|
diagonal=1).to(torch.int8).to(
|
||||||
self.device)
|
self.device)
|
||||||
@@ -2770,7 +2750,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||||
positions=self.positions,
|
positions=self.positions,
|
||||||
attn_mask=self.attn_mask,
|
attn_mask=self.attn_mask,
|
||||||
fia_attn_mask=self.fia_attn_mask,
|
|
||||||
spec_attn_mask=self.spec_attn_mask,
|
spec_attn_mask=self.spec_attn_mask,
|
||||||
attn_state=self.attn_state,
|
attn_state=self.attn_state,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
|
|||||||
Reference in New Issue
Block a user