[PD] Support logprob & Add failure test (#6558)

This commit is contained in:
Byron Hsu
2025-05-23 14:29:20 -07:00
committed by GitHub
parent 1b2e8f76d9
commit 8233cc10fd
10 changed files with 595 additions and 241 deletions

View File

@@ -32,6 +32,7 @@ from sglang.srt.disaggregation.utils import (
DisaggregationMode,
FakeBootstrapHost,
KVClassType,
MetadataBuffers,
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
@@ -63,8 +64,7 @@ class PrefillBootstrapQueue:
token_to_kv_pool: KVCache,
draft_token_to_kv_pool: Optional[KVCache],
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: List[torch.Tensor],
aux_dtype: torch.dtype,
metadata_buffers: MetadataBuffers,
tp_rank: int,
tp_size: int,
bootstrap_port: int,
@@ -76,7 +76,6 @@ class PrefillBootstrapQueue:
self.draft_token_to_kv_pool = draft_token_to_kv_pool
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
self.aux_dtype = aux_dtype
self.metadata_buffers = metadata_buffers
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -116,15 +115,9 @@ class PrefillBootstrapQueue:
kv_args.kv_item_lens = kv_item_lens
# Define req -> input ids buffer
kv_args.aux_data_ptrs = [
metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
]
kv_args.aux_data_lens = [
metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.aux_item_lens = [
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
self.metadata_buffers.get_buf_infos()
)
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -299,10 +292,9 @@ class SchedulerDisaggregationPrefillMixin:
launch_done: Optional[threading.Event] = None,
) -> None:
"""
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
Adapted from process_batch_result_prefill
"""
(
logits_output,
next_token_ids,
@@ -315,27 +307,78 @@ class SchedulerDisaggregationPrefillMixin:
result.extend_logprob_start_len_per_req,
)
logprob_pt = 0
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
if self.enable_overlap:
# wait
_, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
launch_done
)
else:
next_token_ids = result.next_token_ids.tolist()
for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
if batch.return_logprob:
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
for i, (req, next_token_id) in enumerate(
zip(batch.reqs, next_token_ids, strict=True)
):
req: Req
if req.is_chunked <= 0:
# There is no output_ids for prefill
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
self.send_kv_chunk(req, token_id=next_token_id)
self.disagg_prefill_inflight_queue.append(req)
if req.return_logprob:
assert extend_logprob_start_len_per_req is not None
assert extend_input_len_per_req is not None
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
num_input_logprobs = extend_input_len - extend_logprob_start_len
self.add_logprob_return_values(
i,
req,
logprob_pt,
next_token_ids,
num_input_logprobs,
logits_output,
)
logprob_pt += num_input_logprobs
self.send_kv_chunk(req, last_chunk=True)
if req.grammar is not None:
req.grammar.accept_token(next_token_id)
req.grammar.finished = req.finished()
else:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
if req.return_logprob:
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
if extend_logprob_start_len < extend_input_len:
# Update input logprobs.
num_input_logprobs = extend_input_len - extend_logprob_start_len
self.add_input_logprob_return_values(
i,
req,
logits_output,
logprob_pt,
num_input_logprobs,
last_prefill_chunk=False,
)
logprob_pt += num_input_logprobs
if self.enable_overlap:
self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
# We need to remove the sync in the following function for overlap schedule.
self.set_next_batch_sampling_info_done(batch)
def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
"""
@@ -379,7 +422,11 @@ class SchedulerDisaggregationPrefillMixin:
)
# Stream requests which have finished transfer
self.stream_output(done_reqs, False, None)
self.stream_output(
done_reqs,
any(req.return_logprob for req in done_reqs),
None,
)
self.disagg_prefill_inflight_queue = undone_reqs
@@ -405,7 +452,7 @@ class SchedulerDisaggregationPrefillMixin:
def send_kv_chunk(
self: Scheduler,
req: Req,
token_id: Optional[int] = None,
last_chunk: bool = False,
end_idx: Optional[int] = None,
) -> None:
"""
@@ -413,37 +460,28 @@ class SchedulerDisaggregationPrefillMixin:
"""
page_size = self.token_to_kv_pool_allocator.page_size
start_idx = req.start_send_idx
# if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
# the resolved length is not the same as fill_ids's length
end_idx = (
end_idx
if end_idx is not None
else min(len(req.fill_ids), len(req.origin_input_ids))
)
last_chunk = token_id is not None
if not last_chunk:
# if not the last chunk and the last page is partial, delay the last partial page to the next send
end_idx = end_idx - end_idx % page_size
# Update next start_send_idx
req.start_send_idx = end_idx
kv_indices = (
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
.cpu()
.numpy()
)
if last_chunk is True:
self.disagg_prefill_bootstrap_queue.store_prefill_results(
req.metadata_buffer_index, token_id
)
req.start_send_idx = end_idx
if last_chunk:
self.disagg_metadata_buffers.set_buf(req)
page_indices = kv_to_page_indices(kv_indices, page_size)
if len(page_indices) == 0:
logger.info(
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
)
return
req.disagg_kv_sender.send(page_indices)