Fix flashinfer version in sgl-kernel (#10135)
@@ -1187,7 +1187,7 @@ class FlashInferMultiStepDraftBackend:
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps, max_bs * self.max_context_len),
+            (self.speculative_num_steps, max_bs * self.topk * self.max_context_len),
             dtype=torch.int32,
             device="cuda",
         )
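Note on the first hunk: in the multi-step draft backend every request carries `topk` speculative branches, so the per-step KV-index buffer used for CUDA graph capture has to cover `max_bs * topk * max_context_len` slots; the old `max_bs * max_context_len` sizing is too small whenever `topk > 1`. A minimal sizing sketch follows (the concrete numbers are made up for illustration; in the real backend `speculative_num_steps`, `topk`, and `max_context_len` come from the server configuration):

    import torch

    # Illustrative values only.
    speculative_num_steps = 3   # draft steps captured per CUDA graph
    max_bs = 8                  # max batch size for graph capture
    topk = 4                    # speculative branches kept per request
    max_context_len = 2048

    # One row of KV indices per draft step; each row must hold indices for
    # every (request, branch) pair, hence the extra factor of topk.
    cuda_graph_kv_indices = torch.zeros(
        (speculative_num_steps, max_bs * topk * max_context_len),
        dtype=torch.int32,
    )
    print(cuda_graph_kv_indices.shape)  # torch.Size([3, 65536])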
@@ -1349,6 +1349,10 @@ def fast_decode_plan(
             self.device, non_blocking=non_blocking
         )
 
+    # TODO:
+    # We want to cache `empty_q_data`, `empty_kv_cache`, `last_page_len_host` (if it is ones) in the wrapper
+    # so that we do not need to create them every time.
+
     # Create empty tensors for dtype info if needed
     empty_q_data = torch.empty(
         0,
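Note on the second hunk: the new lines only record a TODO. `empty_q_data` (and, per the comment, `empty_kv_cache` and `last_page_len_host`) are zero-element tensors whose only job is to carry dtype information into flashinfer's plan call, so they could be created once and cached on the wrapper instead of being rebuilt on every `fast_decode_plan` invocation. A rough sketch of that caching pattern, assuming a long-lived wrapper object (the helper name and the `_cached_empty_q_data` attribute are made up for illustration, not part of this change):

    import torch

    class _Wrapper:
        """Stand-in for the flashinfer decode wrapper object."""
        pass

    def get_empty_q_data(wrapper, q_dtype: torch.dtype) -> torch.Tensor:
        # Reuse a zero-element tensor that only conveys dtype instead of
        # allocating a new one on every plan call.
        cached = getattr(wrapper, "_cached_empty_q_data", None)
        if cached is None or cached.dtype != q_dtype:
            cached = torch.empty(0, dtype=q_dtype)
            wrapper._cached_empty_q_data = cached
        return cached

    w = _Wrapper()
    a = get_empty_q_data(w, torch.float16)
    b = get_empty_q_data(w, torch.float16)
    assert a is b  # second call reuses the cached tensor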
@@ -81,7 +81,7 @@ FetchContent_Populate(repo-triton)
 FetchContent_Declare(
     repo-flashinfer
     GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
-    GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
+    GIT_TAG 1a85c439a064c1609568675aa580a409a53fb183
     GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-flashinfer)
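Note on the third hunk: the sgl-kernel build vendors flashinfer through CMake's FetchContent and pins it by commit SHA, so this change moves the checkout from 018b5518... to 1a85c439... GIT_SHALLOW stays OFF, presumably because a shallow clone only fetches branch and tag tips and cannot reliably check out an arbitrary commit hash.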