diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 9e6365cbf..2ab7f6cb5 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -977,10 +977,12 @@ class FlashAttentionBackend(AttentionBackend): metadata.max_seq_len_k + self.page_size - 1 ) // self.page_size page_indices = self.req_to_token[ - :, - self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages], + req_pool_indices[:, None], + self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][ + None, : + ], ] - page_indices = page_indices[req_pool_indices] // self.page_size + page_indices //= self.page_size metadata.page_table[:, :max_seq_pages].copy_(page_indices) metadata.page_table[:, max_seq_pages:].fill_(0)