From af1973b871e9e81825a243bfc7de99ca469c1df1 Mon Sep 17 00:00:00 2001 From: Qiaolin Yu Date: Wed, 20 Aug 2025 19:17:13 -0700 Subject: [PATCH] Fix max_seq_len_k in trtllm_mha attention backend (#9416) --- python/sglang/srt/layers/attention/trtllm_mha_backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index d8cb8aa0b..b737d96e7 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -127,7 +127,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend): metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32) # Precompute maximum sequence length - metadata.max_seq_len_k = self.max_context_len + metadata.max_seq_len_k = seq_lens[:bs].max().item() # Precompute page table metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :] @@ -156,7 +156,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend): metadata = self.decode_cuda_graph_metadata[bs] max_len = seq_lens_cpu.max().item() max_seq_pages = (max_len + self.page_size - 1) // self.page_size - metadata.max_seq_len_k = self.max_context_len + metadata.max_seq_len_k = max_len metadata.cache_seqlens_int32.copy_(seq_lens) page_indices = self.req_to_token[ @@ -265,7 +265,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend): workspace_buffer=self.workspace_buffer, block_tables=self.forward_metadata.page_table, seq_lens=self.forward_metadata.cache_seqlens_int32, - max_seq_len=self.forward_metadata.max_seq_len_k, + max_seq_len=self.max_context_len, bmm1_scale=bmm1_scale, bmm2_scale=bmm2_scale, window_left=layer.sliding_window_size, @@ -320,7 +320,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend): block_tables=self.forward_metadata.page_table, seq_lens=self.forward_metadata.cache_seqlens_int32, max_q_len=self.forward_metadata.max_seq_len_q, - max_kv_len=self.forward_metadata.max_seq_len_k, + max_kv_len=self.max_context_len, bmm1_scale=bmm1_scale, bmm2_scale=bmm2_scale, batch_size=forward_batch.batch_size,