diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index a662265..c8379b7 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -79,7 +79,7 @@ class AscendMLAPrefillMetadata:
         chunk_seq_lens: torch.Tensor
 
     attn_mask: torch.Tensor
-    query_lens: list[int]
+    query_lens: torch.Tensor
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
@@ -380,7 +380,7 @@ class AscendMLAMetadataBuilder:
                 1).unsqueeze(2)
             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=common_attn_metadata.attn_mask,
-                query_lens=query_lens[reqs_start:],
+                query_lens=query_lens[reqs_start:].to(torch.int32),
                 seq_lens=seq_lens,
                 context_lens=seq_lens[reqs_start:],
                 input_positions=prefill_input_positions,
@@ -837,9 +837,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                 k_rope=k_pe,
                 value=value,
                 mask=self.prefill_mask,
-                seqlen=torch.tensor(
-                    attn_metadata.prefill.query_lens,
-                    dtype=torch.int32),
+                seqlen=attn_metadata.prefill.query_lens,
                 head_num=self.num_heads,
                 kv_head_num=self.num_heads,
                 pre_out=None,
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index ed14fed..4269727 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -74,7 +74,7 @@ class AscendMLATorchairPrefillMetadata:
         chunk_seq_lens: torch.Tensor
 
     attn_mask: torch.Tensor
-    query_lens: list[int]
+    query_lens: torch.Tensor
     seq_lens: list[int]
     context_lens: torch.Tensor
     input_positions: torch.Tensor
@@ -473,7 +473,7 @@ class AscendMLATorchairMetadataBuilder:
                 1).unsqueeze(2)
             prefill_metadata = AscendMLATorchairPrefillMetadata(
                 attn_mask=common_attn_metadata.attn_mask,
-                query_lens=query_lens[tokens_start:],
+                query_lens=query_lens[tokens_start:].to(torch.int32),
                 seq_lens=seq_lens,
                 context_lens=seq_lens[tokens_start:],
                 input_positions=prefill_input_positions,
@@ -880,9 +880,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                 k_rope=k_pe,
                 value=value,
                 mask=self.prefill_mask,
-                seqlen=torch.tensor(
-                    attn_metadata.prefill.query_lens,
-                    dtype=torch.int32),
+                seqlen=attn_metadata.prefill.query_lens,
                 head_num=self.num_heads,
                 kv_head_num=self.num_heads,
                 pre_out=None,
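
For context, a minimal sketch of the pattern both files adopt: slice and cast `query_lens` to int32 once when the prefill metadata is built, rather than re-materializing a tensor with `torch.tensor(...)` on every attention forward. This assumes `query_lens` in the builder is already a `torch.Tensor`; the helper name below is illustrative, not part of the vllm-ascend API.

```python
import torch

def build_prefill_seqlen(query_lens: torch.Tensor,
                         reqs_start: int) -> torch.Tensor:
    # One slice + dtype cast at metadata-build time; the result is
    # stored on the metadata object (query_lens: torch.Tensor).
    return query_lens[reqs_start:].to(torch.int32)

# At the attention call site the stored tensor is passed through as-is:
#   seqlen=attn_metadata.prefill.query_lens
# instead of being rebuilt each step with:
#   seqlen=torch.tensor(attn_metadata.prefill.query_lens, dtype=torch.int32)
```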