Fix max_seq_len_k in trtllm_mha attention backend (#9416)

This commit is contained in:
Qiaolin Yu
2025-08-20 19:17:13 -07:00
committed by GitHub
parent 5cfbb4c136
commit af1973b871

View File

@@ -127,7 +127,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32) metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
# Precompute maximum sequence length # Precompute maximum sequence length
metadata.max_seq_len_k = self.max_context_len metadata.max_seq_len_k = seq_lens[:bs].max().item()
# Precompute page table # Precompute page table
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :] metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
@@ -156,7 +156,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
metadata = self.decode_cuda_graph_metadata[bs] metadata = self.decode_cuda_graph_metadata[bs]
max_len = seq_lens_cpu.max().item() max_len = seq_lens_cpu.max().item()
max_seq_pages = (max_len + self.page_size - 1) // self.page_size max_seq_pages = (max_len + self.page_size - 1) // self.page_size
metadata.max_seq_len_k = self.max_context_len metadata.max_seq_len_k = max_len
metadata.cache_seqlens_int32.copy_(seq_lens) metadata.cache_seqlens_int32.copy_(seq_lens)
page_indices = self.req_to_token[ page_indices = self.req_to_token[
@@ -265,7 +265,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
workspace_buffer=self.workspace_buffer, workspace_buffer=self.workspace_buffer,
block_tables=self.forward_metadata.page_table, block_tables=self.forward_metadata.page_table,
seq_lens=self.forward_metadata.cache_seqlens_int32, seq_lens=self.forward_metadata.cache_seqlens_int32,
max_seq_len=self.forward_metadata.max_seq_len_k, max_seq_len=self.max_context_len,
bmm1_scale=bmm1_scale, bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale, bmm2_scale=bmm2_scale,
window_left=layer.sliding_window_size, window_left=layer.sliding_window_size,
@@ -320,7 +320,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
block_tables=self.forward_metadata.page_table, block_tables=self.forward_metadata.page_table,
seq_lens=self.forward_metadata.cache_seqlens_int32, seq_lens=self.forward_metadata.cache_seqlens_int32,
max_q_len=self.forward_metadata.max_seq_len_q, max_q_len=self.forward_metadata.max_seq_len_q,
max_kv_len=self.forward_metadata.max_seq_len_k, max_kv_len=self.max_context_len,
bmm1_scale=bmm1_scale, bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale, bmm2_scale=bmm2_scale,
batch_size=forward_batch.batch_size, batch_size=forward_batch.batch_size,