[PD] Add doc and simplify sender.send (#6019)

This commit is contained in:
Byron Hsu
2025-05-21 21:22:21 -07:00
committed by GitHub
parent 4d643f6c7a
commit 7513558074
6 changed files with 63 additions and 25 deletions

View File

@@ -33,28 +33,18 @@ class FakeKVSender(BaseKVSender):
self,
kv_indices: list[int],
aux_index: Optional[int] = None,
dest_ranks: Optional[list[int]] = None,
):
logger.info(
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
)
pass
def send(
self,
kv_indices: npt.NDArray[np.int64],
index_slice: slice,
is_last: bool,
):
logger.info(
f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
)
if is_last:
self.has_sent = True
logger.info(f"FakeKVSender send success")
else:
self.has_sent = False
logger.info(f"FakeKVSender send fake transferring")
self.has_sent = True
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
def failure_exception(self):
raise Exception("Fake KVSender Exception")

View File

@@ -464,6 +464,8 @@ class MooncakeKVSender(BaseKVSender):
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
self.session_id = self.kv_mgr.get_session_id()
# inner state
self.curr_idx = 0
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
@@ -472,9 +474,11 @@ class MooncakeKVSender(BaseKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int64],
index_slice: slice,
is_last: bool,
):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
is_last = self.curr_idx == self.num_kv_indices
if not is_last:
self.kv_mgr.add_transfer_request(
self.bootstrap_room, kv_indices, index_slice, False

View File

@@ -384,11 +384,10 @@ class SchedulerDisaggregationPrefillMixin:
if end_idx is not None
else min(len(req.fill_ids), len(req.origin_input_ids))
)
last_chunk = token_id is not None
if (not last_chunk) and (
end_idx % page_size != 0
): # todo: remove the second condition
if not last_chunk:
# if not the last chunk and the last page is partial, delay the last partial page to the next send
end_idx = end_idx - end_idx % page_size
@@ -405,16 +404,10 @@ class SchedulerDisaggregationPrefillMixin:
req.metadata_buffer_index, token_id
)
page_indices = kv_to_page_indices(kv_indices, page_size)
page_start_idx = start_idx // page_size
page_end_idx = page_start_idx + len(page_indices)
if len(page_indices) == 0:
logger.info(
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
)
return
req.disagg_kv_sender.send(
page_indices, slice(page_start_idx, page_end_idx), last_chunk
)
req.disagg_kv_sender.send(page_indices)

View File

@@ -407,6 +407,7 @@ class GenerateReqInput:
else None
),
return_hidden_states=self.return_hidden_states,
# if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
bootstrap_host=(
self.bootstrap_host[i] if self.bootstrap_host is not None else None
),