Fix FlashInfer GPU <-> CPU sync (#9409)

This commit is contained in:
Nathan Wang
2025-08-20 18:26:12 -04:00
committed by GitHub
parent a91e90d9a3
commit 24eaebeb4b

View File

@@ -1372,7 +1372,14 @@ def fast_decode_plan(
if self.use_tensor_cores:
# ALSO convert last_page_len to CPU
last_page_len_host = last_page_len.cpu()
if page_size == 1:
# When page size is 1, last_page_len is always 1.
# Directly construct the host tensor rather than executing a device-to-host copy.
last_page_len_host = torch.ones(
(batch_size,), dtype=torch.int32, device="cpu"
)
else:
last_page_len_host = last_page_len.cpu()
kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)