[PD] Fix large page size + chunk prefill (#5588)
This commit is contained in:
@@ -231,7 +231,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
|
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
|
||||||
assert len(chunked_dst_kv_indice) == len(
|
assert len(chunked_dst_kv_indice) == len(
|
||||||
kv_chunk.prefill_kv_indices
|
kv_chunk.prefill_kv_indices
|
||||||
)
|
), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
|
||||||
|
|
||||||
ret = self.send_kvcache(
|
ret = self.send_kvcache(
|
||||||
req.mooncake_session_id,
|
req.mooncake_session_id,
|
||||||
|
|||||||
@@ -306,4 +306,10 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
page_indices = kv_to_page_indices(
|
page_indices = kv_to_page_indices(
|
||||||
kv_indices, self.token_to_kv_pool_allocator.page_size
|
kv_indices, self.token_to_kv_pool_allocator.page_size
|
||||||
)
|
)
|
||||||
req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
|
|
||||||
|
page_start_idx = start_idx // self.token_to_kv_pool_allocator.page_size
|
||||||
|
page_end_idx = page_start_idx + len(page_indices)
|
||||||
|
|
||||||
|
req.disagg_kv_sender.send(
|
||||||
|
page_indices, slice(page_start_idx, page_end_idx), is_last
|
||||||
|
)
|
||||||
|
|||||||
@@ -76,13 +76,22 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
|||||||
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||||
|
|
||||||
|
|
||||||
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int, is_last: bool = True):
|
||||||
# 1. The page is guaranteed to be full except the last page.
|
# 1. The page is guaranteed to be full except the last page.
|
||||||
# 2. page index = kv_index // page_size
|
# 2. page index = kv_index // page_size
|
||||||
# The return vector is kv_indices[::page_size] // page_size
|
|
||||||
if page_size == 1: # shortcut
|
if page_size == 1: # shortcut
|
||||||
return kv_indices
|
return kv_indices
|
||||||
return kv_indices[::page_size] // page_size
|
|
||||||
|
# if last chunk, send the last partial page
|
||||||
|
# if not last chunk, delay the last partial page to the next send
|
||||||
|
if is_last:
|
||||||
|
return kv_indices[::page_size] // page_size
|
||||||
|
else:
|
||||||
|
if len(kv_indices) % page_size == 0: # no partial page
|
||||||
|
return kv_indices[::page_size] // page_size
|
||||||
|
else: # partial page
|
||||||
|
return kv_indices[::page_size][:-1] // page_size
|
||||||
|
|
||||||
|
|
||||||
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||||
|
|||||||
@@ -446,13 +446,16 @@ class MLATokenToKVPool(KVCache):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.layer_transfer_counter = None
|
self.layer_transfer_counter = None
|
||||||
|
self.page_size = page_size
|
||||||
|
|
||||||
# for disagg
|
# for disagg
|
||||||
def get_contiguous_buf_infos(self):
|
def get_contiguous_buf_infos(self):
|
||||||
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
|
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
|
||||||
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
|
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
|
||||||
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
|
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
|
||||||
kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
|
kv_item_lens = [
|
||||||
|
self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
|
||||||
|
]
|
||||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
|
|||||||
13
scripts/playground/disaggregation/cli.py
Normal file
13
scripts/playground/disaggregation/cli.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
prompt = "Hello " * 16000
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
"http://0.0.0.0:8000/generate",
|
||||||
|
json={"text": prompt, "sampling_params": {"temperature": 0}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
print("Response content (raw):", response.content)
|
||||||
Reference in New Issue
Block a user