Fix flashinfer version in sgl-kernel (#10135)

This commit is contained in:
Lianmin Zheng
2025-09-07 12:54:07 -07:00
committed by GitHub
parent e719bb0e84
commit 76a2c86b88
2 changed files with 6 additions and 2 deletions

View File

@@ -1187,7 +1187,7 @@ class FlashInferMultiStepDraftBackend:
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps, max_bs * self.max_context_len),
+            (self.speculative_num_steps, max_bs * self.topk * self.max_context_len),
dtype=torch.int32,
device="cuda",
)
@@ -1349,6 +1349,10 @@ def fast_decode_plan(
self.device, non_blocking=non_blocking
)
# TODO:
# We want to cache `empty_q_data`, `empty_kv_cache`, `last_page_len_host` (if it is ones) in the wrapper
# so that we do not need to create them every time.
# Create empty tensors for dtype info if needed
empty_q_data = torch.empty(
0,

View File

@@ -81,7 +81,7 @@ FetchContent_Populate(repo-triton)
FetchContent_Declare(
repo-flashinfer
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
-    GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
+    GIT_TAG 1a85c439a064c1609568675aa580a409a53fb183
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-flashinfer)