Fix flashinfer version in sgl-kernel (#10135)
This commit is contained in:
@@ -1187,7 +1187,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(self.speculative_num_steps, max_bs * self.max_context_len),
|
||||
(self.speculative_num_steps, max_bs * self.topk * self.max_context_len),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
@@ -1349,6 +1349,10 @@ def fast_decode_plan(
|
||||
self.device, non_blocking=non_blocking
|
||||
)
|
||||
|
||||
# TODO:
|
||||
# We want to cache `empty_q_data`, `empty_kv_cache`, `last_page_len_host` (if it is ones) in the wrapper
|
||||
# so that we do not need to create them every time.
|
||||
|
||||
# Create empty tensors for dtype info if needed
|
||||
empty_q_data = torch.empty(
|
||||
0,
|
||||
|
||||
Reference in New Issue
Block a user