diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 6b66e12d6..a5b207c77 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -1187,7 +1187,7 @@ class FlashInferMultiStepDraftBackend: def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): self.cuda_graph_kv_indices = torch.zeros( - (self.speculative_num_steps, max_bs * self.topk * self.max_context_len), + (self.speculative_num_steps, max_bs * self.max_context_len), dtype=torch.int32, device="cuda", ) @@ -1349,10 +1349,6 @@ def fast_decode_plan( self.device, non_blocking=non_blocking ) - # TODO: - # We want to cache `empty_q_data`, `empty_kv_cache`, `last_page_len_host` (if it is ones) in the wrapper - # so that we do not need to create them every time. - # Create empty tensors for dtype info if needed empty_q_data = torch.empty( 0, diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 5867f95f5..80f29921f 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -81,7 +81,7 @@ FetchContent_Populate(repo-triton) FetchContent_Declare( repo-flashinfer GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git - GIT_TAG 1a85c439a064c1609568675aa580a409a53fb183 + GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3 GIT_SHALLOW OFF ) FetchContent_Populate(repo-flashinfer)