Fix FlashInfer GPU <-> CPU sync (#9409)
This commit is contained in:
@@ -1372,7 +1372,14 @@ def fast_decode_plan(
|
||||
|
||||
if self.use_tensor_cores:
|
||||
# ALSO convert last_page_len to CPU
|
||||
last_page_len_host = last_page_len.cpu()
|
||||
if page_size == 1:
|
||||
# When page size is 1, last_page_len is always 1.
|
||||
# Directly construct the host tensor rather than executing a device-to-host copy.
|
||||
last_page_len_host = torch.ones(
|
||||
(batch_size,), dtype=torch.int32, device="cpu"
|
||||
)
|
||||
else:
|
||||
last_page_len_host = last_page_len.cpu()
|
||||
|
||||
kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user