[BugFix]Move to_list in foward_v1 with FIA earlier to build (#3185)
### What this PR does / why we need it? The current implementation of FIA will introduce an `to_list` operation for actual_seq_lengths_q and seq_lens,which comsumes extra time. These operation can be moved earlier into `build` operation of attention metadata. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
@@ -140,7 +140,13 @@ class AscendMetadata:
|
||||
# The sequence length per sequence. Sequence length means the computed
|
||||
# tokens + new tokens (is None if it is a decoding).
|
||||
# (batch_size,)
|
||||
# TODO(Angazenn): The following parameters are quite redundant and
|
||||
# contains similar information (such as seq_lens seq_lens_list). We
|
||||
# should simplified these parameters once attention schema in vLLM-Ascend
|
||||
# is unified.
|
||||
seq_lens: torch.Tensor = None
|
||||
seq_lens_list: List[int] = None # type: ignore
|
||||
actual_seq_lengths_q: List[int] = None # type: ignore
|
||||
|
||||
query_start_loc: torch.Tensor = None
|
||||
query_lens: torch.Tensor = None
|
||||
@@ -229,7 +235,9 @@ class AscendAttentionMetadataBuilder:
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
@@ -528,8 +536,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
block_table=attn_metadata.block_tables,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=attn_metadata.query_start_loc[1:],
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
||||
actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
|
||||
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
|
||||
Reference in New Issue
Block a user