diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 27c289b..7dc309b 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -140,7 +140,13 @@ class AscendMetadata: # The sequence length per sequence. Sequence length means the computed # tokens + new tokens (is None if it is a decoding). # (batch_size,) + # TODO(Angazenn): The following parameters are quite redundant and + # contains similar information (such as seq_lens seq_lens_list). We + # should simplified these parameters once attention schema in vLLM-Ascend + # is unified. seq_lens: torch.Tensor = None + seq_lens_list: List[int] = None # type: ignore + actual_seq_lengths_q: List[int] = None # type: ignore query_start_loc: torch.Tensor = None query_lens: torch.Tensor = None @@ -229,7 +235,9 @@ class AscendAttentionMetadataBuilder: query_start_loc=query_start_loc, query_lens=query_lens, seq_lens=seq_lens, + seq_lens_list=seq_lens.tolist(), max_query_len=common_attn_metadata.max_query_len, + actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(), slot_mapping=slot_mapping, attn_mask=attn_mask, attn_state=attn_state, @@ -528,8 +536,8 @@ class AscendAttentionBackendImpl(AttentionImpl): block_table=attn_metadata.block_tables, input_layout="TND", block_size=block_size, - actual_seq_lengths=attn_metadata.query_start_loc[1:], - actual_seq_lengths_kv=attn_metadata.seq_lens, + actual_seq_lengths=attn_metadata.actual_seq_lengths_q, + actual_seq_lengths_kv=attn_metadata.seq_lens_list, num_key_value_heads=self.num_kv_heads, num_heads=self.num_heads, scale=self.scale,