compatible with flashinfer v0.2 (#3235)
This commit is contained in:
@@ -800,7 +800,9 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
kv_indices = torch.empty(
|
kv_indices = torch.empty(
|
||||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
paged_kernel_lens_sum + 256,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=req_pool_indices.device,
|
||||||
)
|
)
|
||||||
create_flashinfer_kv_indices_triton[(bs,)](
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
|
|||||||
Reference in New Issue
Block a user