Fix FlashInfer GPU <-> CPU sync (#9409)
This commit is contained in:
@@ -1372,7 +1372,14 @@ def fast_decode_plan(
     if self.use_tensor_cores:
         # ALSO convert last_page_len to CPU
-        last_page_len_host = last_page_len.cpu()
+        if page_size == 1:
+            # When page size is 1, last_page_len is always 1.
+            # Directly construct the host tensor rather than executing a device-to-host copy.
+            last_page_len_host = torch.ones(
+                (batch_size,), dtype=torch.int32, device="cpu"
+            )
+        else:
+            last_page_len_host = last_page_len.cpu()
     kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
Reference in New Issue
Block a user