diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 57179c9f..98409cb9 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -426,8 +426,8 @@ class AscendMLATorchairMetadataBuilder: if num_prefills > 0: reqs_start = num_decodes # prefill_start tokens_start = num_decode_tokens - max_query_len = query_lens[tokens_start:].max().item() - max_seq_lens = seq_lens[tokens_start:].max().item() + max_query_len = query_lens[reqs_start:].max().item() + max_seq_lens = seq_lens[reqs_start:].max().item() prefill_query_start_loc = query_start_loc[ reqs_start:] - query_start_loc[reqs_start] @@ -473,9 +473,9 @@ class AscendMLATorchairMetadataBuilder: 1).unsqueeze(2) prefill_metadata = AscendMLATorchairPrefillMetadata( attn_mask=common_attn_metadata.attn_mask, - query_lens=query_lens[tokens_start:].to(torch.int32), + query_lens=query_lens[reqs_start:].to(torch.int32), seq_lens=seq_lens, - context_lens=seq_lens[tokens_start:], + context_lens=seq_lens[reqs_start:], input_positions=prefill_input_positions, block_table=block_table[reqs_start:, ...], max_query_len=max_query_len,