compatible with flashinfer v0.2 (#3235)

This commit is contained in:
Yineng Zhang
2025-02-01 01:32:18 +08:00
committed by GitHub
parent 656f7fc1bc
commit 7811bfdaa7

View File

@@ -800,7 +800,9 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
paged_kernel_lens_sum + 256,
dtype=torch.int32,
device=req_pool_indices.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,