[PD] Fix edge case and simplify large page size + chunked prefill (#5589)
This commit is contained in:
@@ -287,8 +287,16 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
"""
|
||||
Send a prefilled chunk to the decode server
|
||||
"""
|
||||
page_size = self.token_to_kv_pool_allocator.page_size
|
||||
start_idx = req.start_send_idx
|
||||
end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
|
||||
last_chunk = token_id is not None
|
||||
|
||||
if (not last_chunk) and (
|
||||
end_idx % page_size != 0
|
||||
): # todo: remove the second condition
|
||||
# if not the last chunk and the last page is partial, delay the last partial page to the next send
|
||||
end_idx = end_idx - end_idx % page_size
|
||||
|
||||
# Update next start_send_idx
|
||||
req.start_send_idx = end_idx
|
||||
@@ -298,18 +306,21 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
if token_id is not None:
|
||||
if last_chunk is True:
|
||||
self.disagg_prefill_pending_queue.store_prefill_results(
|
||||
req.metadata_buffer_index, token_id
|
||||
)
|
||||
is_last = token_id is not None
|
||||
page_indices = kv_to_page_indices(
|
||||
kv_indices, self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
page_indices = kv_to_page_indices(kv_indices, page_size)
|
||||
|
||||
page_start_idx = start_idx // self.token_to_kv_pool_allocator.page_size
|
||||
page_start_idx = start_idx // page_size
|
||||
page_end_idx = page_start_idx + len(page_indices)
|
||||
|
||||
if len(page_indices) == 0:
|
||||
logger.info(
|
||||
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
|
||||
)
|
||||
return
|
||||
|
||||
req.disagg_kv_sender.send(
|
||||
page_indices, slice(page_start_idx, page_end_idx), is_last
|
||||
page_indices, slice(page_start_idx, page_end_idx), last_chunk
|
||||
)
|
||||
|
||||
@@ -76,22 +76,14 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||
|
||||
|
||||
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int, is_last: bool = True):
|
||||
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
||||
# 1. The page is guaruanteed to be full except the last page.
|
||||
# 2. page index = kv_index // page_size
|
||||
|
||||
# The return vector is kv_indices[::page_size] // page_size
|
||||
if page_size == 1: # shortcut
|
||||
return kv_indices
|
||||
|
||||
# if last chunk, send the last partial page
|
||||
# if not last chunk, delay the last partial page to the next send
|
||||
if is_last:
|
||||
return kv_indices[::page_size] // page_size
|
||||
else:
|
||||
if len(kv_indices) % page_size == 0: # no partial page
|
||||
return kv_indices[::page_size] // page_size
|
||||
else: # partial page
|
||||
return kv_indices[::page_size][:-1] // page_size
|
||||
return kv_indices[::page_size] // page_size
|
||||
|
||||
|
||||
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||
|
||||
Reference in New Issue
Block a user