Fix flashinfer version in sgl-kernel (#10135)
This commit is contained in:
@@ -1187,7 +1187,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(self.speculative_num_steps, max_bs * self.max_context_len),
|
||||
(self.speculative_num_steps, max_bs * self.topk * self.max_context_len),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
@@ -1349,6 +1349,10 @@ def fast_decode_plan(
|
||||
self.device, non_blocking=non_blocking
|
||||
)
|
||||
|
||||
# TODO:
|
||||
# We want to cache `empty_q_data`, `empty_kv_cache`, `last_page_len_host` (if it is ones) in the wrapper
|
||||
# so that we do not need to create them every time.
|
||||
|
||||
# Create empty tensors for dtype info if needed
|
||||
empty_q_data = torch.empty(
|
||||
0,
|
||||
|
||||
@@ -81,7 +81,7 @@ FetchContent_Populate(repo-triton)
|
||||
FetchContent_Declare(
|
||||
repo-flashinfer
|
||||
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
|
||||
GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
|
||||
GIT_TAG 1a85c439a064c1609568675aa580a409a53fb183
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-flashinfer)
|
||||
|
||||
Reference in New Issue
Block a user