diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 11b712acc..e6436952d 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -231,7 +231,7 @@ class MooncakeKVManager(BaseKVManager): chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice] assert len(chunked_dst_kv_indice) == len( kv_chunk.prefill_kv_indices - ) + ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}" ret = self.send_kvcache( req.mooncake_session_id, diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 9e080926a..5295cb232 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -306,4 +306,10 @@ class SchedulerDisaggregationPrefillMixin: page_indices = kv_to_page_indices( kv_indices, self.token_to_kv_pool_allocator.page_size ) - req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last) + + page_start_idx = start_idx // self.token_to_kv_pool_allocator.page_size + page_end_idx = page_start_idx + len(page_indices) + + req.disagg_kv_sender.send( + page_indices, slice(page_start_idx, page_end_idx), is_last + ) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 7b43566f1..e51677b27 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -76,13 +76,22 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): raise ValueError(f"Unsupported transfer backend: {transfer_backend}") -def kv_to_page_indices(kv_indices: np.ndarray, page_size: int): +def kv_to_page_indices(kv_indices: np.ndarray, page_size: int, is_last: bool = True): # 1. The page is guaruanteed to be full except the last page. # 2. page index = kv_index // page_size - # The return vector is kv_indices[::page_size] // page_size + if page_size == 1: # shortcut return kv_indices - return kv_indices[::page_size] // page_size + + # if last chunk, send the last partial page + # if not last chunk, delay the last partial page to the next send + if is_last: + return kv_indices[::page_size] // page_size + else: + if len(kv_indices) % page_size == 0: # no partial page + return kv_indices[::page_size] // page_size + else: # partial page + return kv_indices[::page_size][:-1] // page_size def kv_to_page_num(num_kv_indices: int, page_size: int): diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 8b38b1cbe..079dfd9b1 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -446,13 +446,16 @@ class MLATokenToKVPool(KVCache): ] self.layer_transfer_counter = None + self.page_size = page_size # for disagg def get_contiguous_buf_infos(self): # MLA has only one kv_buffer, so only the information of this buffer needs to be returned. kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)] kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)] - kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)] + kv_item_lens = [ + self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num) + ] return kv_data_ptrs, kv_data_lens, kv_item_lens def get_key_buffer(self, layer_id: int): diff --git a/scripts/playground/disaggregation/cli.py b/scripts/playground/disaggregation/cli.py new file mode 100644 index 000000000..4a05e2534 --- /dev/null +++ b/scripts/playground/disaggregation/cli.py @@ -0,0 +1,13 @@ +prompt = "Hello " * 16000 + +import json + +import requests + +response = requests.post( + "http://0.0.0.0:8000/generate", + json={"text": prompt, "sampling_params": {"temperature": 0}}, +) + + +print("Response content (raw):", response.content)