From deded17f38a2d48de6731883fe7d0231082795eb Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 21 Apr 2025 10:27:02 -0700 Subject: [PATCH] [PD] Fix edge case and simplify large page size + chunked prefill (#5589) --- python/sglang/srt/disaggregation/prefill.py | 25 +++++++++++++++------ python/sglang/srt/disaggregation/utils.py | 14 +++--------- scripts/playground/disaggregation/cli.py | 6 ++--- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 5295cb232..568c9973e 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -287,8 +287,16 @@ class SchedulerDisaggregationPrefillMixin: """ Send a prefilled chunk to the decode server """ + page_size = self.token_to_kv_pool_allocator.page_size start_idx = req.start_send_idx end_idx = min(len(req.fill_ids), len(req.origin_input_ids)) + last_chunk = token_id is not None + + if (not last_chunk) and ( + end_idx % page_size != 0 + ): # todo: remove the second condition + # if not the last chunk and the last page is partial, delay the last partial page to the next send + end_idx = end_idx - end_idx % page_size # Update next start_send_idx req.start_send_idx = end_idx @@ -298,18 +306,21 @@ class SchedulerDisaggregationPrefillMixin: .cpu() .numpy() ) - if token_id is not None: + if last_chunk is True: self.disagg_prefill_pending_queue.store_prefill_results( req.metadata_buffer_index, token_id ) - is_last = token_id is not None - page_indices = kv_to_page_indices( - kv_indices, self.token_to_kv_pool_allocator.page_size - ) + page_indices = kv_to_page_indices(kv_indices, page_size) - page_start_idx = start_idx // self.token_to_kv_pool_allocator.page_size + page_start_idx = start_idx // page_size page_end_idx = page_start_idx + len(page_indices) + if len(page_indices) == 0: + logger.info( + f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty" + ) + return + req.disagg_kv_sender.send( - page_indices, slice(page_start_idx, page_end_idx), is_last + page_indices, slice(page_start_idx, page_end_idx), last_chunk ) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index e51677b27..4836c9a77 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -76,22 +76,14 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): raise ValueError(f"Unsupported transfer backend: {transfer_backend}") -def kv_to_page_indices(kv_indices: np.ndarray, page_size: int, is_last: bool = True): +def kv_to_page_indices(kv_indices: np.ndarray, page_size: int): # 1. The page is guaruanteed to be full except the last page. # 2. page index = kv_index // page_size - + # The return vector is kv_indices[::page_size] // page_size if page_size == 1: # shortcut return kv_indices - # if last chunk, send the last partial page - # if not last chunk, delay the last partial page to the next send - if is_last: - return kv_indices[::page_size] // page_size - else: - if len(kv_indices) % page_size == 0: # no partial page - return kv_indices[::page_size] // page_size - else: # partial page - return kv_indices[::page_size][:-1] // page_size + return kv_indices[::page_size] // page_size def kv_to_page_num(num_kv_indices: int, page_size: int): diff --git a/scripts/playground/disaggregation/cli.py b/scripts/playground/disaggregation/cli.py index 4a05e2534..5bcc5629e 100644 --- a/scripts/playground/disaggregation/cli.py +++ b/scripts/playground/disaggregation/cli.py @@ -1,4 +1,4 @@ -prompt = "Hello " * 16000 +prompt = [0] * 431 import json @@ -6,8 +6,8 @@ import requests response = requests.post( "http://0.0.0.0:8000/generate", - json={"text": prompt, "sampling_params": {"temperature": 0}}, + json={"input_ids": [prompt] * 32, "sampling_params": {"temperature": 0}}, ) -print("Response content (raw):", response.content) +# print("Response content (raw):", response.content)