diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index a5b207c77..6b66e12d6 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -1187,7 +1187,7 @@ class FlashInferMultiStepDraftBackend: def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): self.cuda_graph_kv_indices = torch.zeros( - (self.speculative_num_steps, max_bs * self.max_context_len), + (self.speculative_num_steps, max_bs * self.topk * self.max_context_len), dtype=torch.int32, device="cuda", ) @@ -1349,6 +1349,10 @@ def fast_decode_plan( self.device, non_blocking=non_blocking ) + # TODO: + # Cache `empty_q_data`, `empty_kv_cache`, and `last_page_len_host` (when it is all ones) in the wrapper + # so that they do not have to be re-created on every call. + # Create empty tensors for dtype info if needed empty_q_data = torch.empty( 0, diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 13ef9ce49..7fa1c723c 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -81,7 +81,7 @@ FetchContent_Populate(repo-triton) FetchContent_Declare( repo-flashinfer GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git - GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3 + GIT_TAG 1a85c439a064c1609568675aa580a409a53fb183 GIT_SHALLOW OFF ) FetchContent_Populate(repo-flashinfer)