Fix max_seq_len_k in trtllm_mha attention backend (#9416)
@@ -127,7 +127,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
 
         # Precompute maximum sequence length
-        metadata.max_seq_len_k = self.max_context_len
+        metadata.max_seq_len_k = seq_lens[:bs].max().item()
 
         # Precompute page table
         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
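
In the capture path above, the precomputed metadata now records the true maximum of the current batch instead of the worst-case self.max_context_len. A minimal sketch of the idiom with illustrative values (not taken from the commit); note that .max().item() forces a device-to-host sync, which is why it runs while the metadata is being prepared rather than inside any captured region:

    import torch

    bs = 4
    seq_lens = torch.tensor([3, 9, 5, 7, 1, 1], dtype=torch.int32, device="cuda")
    # True maximum over the active batch, materialized as a host-side int.
    max_seq_len_k = seq_lens[:bs].max().item()  # -> 9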

@@ -156,7 +156,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         metadata = self.decode_cuda_graph_metadata[bs]
         max_len = seq_lens_cpu.max().item()
         max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-        metadata.max_seq_len_k = self.max_context_len
+        metadata.max_seq_len_k = max_len
 
         metadata.cache_seqlens_int32.copy_(seq_lens)
         page_indices = self.req_to_token[
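
The replay-path update is the same fix: max_seq_len_k now tracks max_len, the longest sequence in the batch, rather than the context ceiling. The ceiling division on the line before it converts that length into a KV-cache page count; a worked example with assumed numbers (page_size = 16 is illustrative only):

    page_size = 16
    max_len = 100
    # Integer ceiling division: ceil(100 / 16) == 7 pages.
    max_seq_pages = (max_len + page_size - 1) // page_size
    assert max_seq_pages == 7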

@@ -265,7 +265,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             workspace_buffer=self.workspace_buffer,
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
-            max_seq_len=self.forward_metadata.max_seq_len_k,
+            max_seq_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             window_left=layer.sliding_window_size,
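
The decode call site moves in the opposite direction: the kernel now receives the static self.max_context_len rather than the per-batch metadata value. A plausible reading (our assumption; the commit message does not say so) is CUDA-graph safety: max_seq_len is a Python scalar, and scalar arguments are frozen into a graph at capture time, so only a bound that holds for every future batch can be captured. A self-contained toy sketch of that pitfall, unrelated to the real kernel:

    import torch

    # Toy demonstration (not the sglang/flashinfer code): Python-scalar
    # kernel arguments are frozen into a CUDA graph at capture time;
    # only tensor contents can be refreshed before each replay.
    seq_lens = torch.zeros(8, dtype=torch.int32, device="cuda")
    out = torch.zeros(8, dtype=torch.float32, device="cuda")

    def step(max_seq_len: int) -> None:
        # max_seq_len is a host-side int; its value is baked into the graph.
        out.copy_(seq_lens.clamp(max=max_seq_len).float())

    step(max_seq_len=4)          # warm-up run outside capture
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    seq_lens.fill_(4)
    with torch.cuda.graph(g):
        step(max_seq_len=4)      # the literal 4 is captured for good

    seq_lens.fill_(100)          # refreshing tensor contents is fine...
    g.replay()
    print(out)                   # ...but the clamp still caps at the captured 4

On replay, out stays capped at the captured 4 even though seq_lens now holds 100, which is exactly the failure mode a static upper bound avoids.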

@@ -320,7 +320,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
             max_q_len=self.forward_metadata.max_seq_len_q,
-            max_kv_len=self.forward_metadata.max_seq_len_k,
+            max_kv_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             batch_size=forward_batch.batch_size,
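
The context/extend call site gets the same substitution: max_kv_len becomes the static self.max_context_len, while max_q_len continues to come from the per-batch metadata, consistent with the decode-path reasoning sketched above.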