diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 365a0a54f..e1f7ea76f 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -322,10 +322,13 @@ class FlashAttentionBackend(AttentionBackend): torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0) ) - metadata.page_table = self.req_to_token[ - :, self.decode_cuda_graph_metadata["strided_indices"] + max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size + page_indices = self.req_to_token[ + :, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages] ] - metadata.page_table = metadata.page_table[req_pool_indices[:bs]] + page_indices = page_indices[req_pool_indices[:bs]] // self.page_size + metadata.page_table[:, :max_seq_pages].copy_(page_indices) + metadata.page_table[:, max_seq_pages:].fill_(0) self.forward_metadata = metadata def get_cuda_graph_seq_len_fill_value(self):