[PD] Support logprob & Add failure test (#6558)
This commit is contained in:
@@ -32,6 +32,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
FakeBootstrapHost,
|
||||
KVClassType,
|
||||
MetadataBuffers,
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
get_kv_class,
|
||||
@@ -63,8 +64,7 @@ class PrefillBootstrapQueue:
|
||||
token_to_kv_pool: KVCache,
|
||||
draft_token_to_kv_pool: Optional[KVCache],
|
||||
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||
metadata_buffers: List[torch.Tensor],
|
||||
aux_dtype: torch.dtype,
|
||||
metadata_buffers: MetadataBuffers,
|
||||
tp_rank: int,
|
||||
tp_size: int,
|
||||
bootstrap_port: int,
|
||||
@@ -76,7 +76,6 @@ class PrefillBootstrapQueue:
|
||||
self.draft_token_to_kv_pool = draft_token_to_kv_pool
|
||||
|
||||
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
|
||||
self.aux_dtype = aux_dtype
|
||||
|
||||
self.metadata_buffers = metadata_buffers
|
||||
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||
@@ -116,15 +115,9 @@ class PrefillBootstrapQueue:
|
||||
kv_args.kv_item_lens = kv_item_lens
|
||||
|
||||
# Define req -> input ids buffer
|
||||
kv_args.aux_data_ptrs = [
|
||||
metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
|
||||
]
|
||||
kv_args.aux_data_lens = [
|
||||
metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
|
||||
]
|
||||
kv_args.aux_item_lens = [
|
||||
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
||||
]
|
||||
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
|
||||
self.metadata_buffers.get_buf_infos()
|
||||
)
|
||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||
kv_args.gpu_id = self.scheduler.gpu_id
|
||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||
@@ -299,10 +292,9 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
launch_done: Optional[threading.Event] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
||||
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
||||
Adapted from process_batch_result_prefill
|
||||
"""
|
||||
|
||||
(
|
||||
logits_output,
|
||||
next_token_ids,
|
||||
@@ -315,27 +307,78 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
result.extend_logprob_start_len_per_req,
|
||||
)
|
||||
|
||||
logprob_pt = 0
|
||||
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
||||
if self.enable_overlap:
|
||||
# wait
|
||||
_, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
|
||||
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
|
||||
launch_done
|
||||
)
|
||||
else:
|
||||
next_token_ids = result.next_token_ids.tolist()
|
||||
|
||||
for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
|
||||
if batch.return_logprob:
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
logits_output.next_token_logprobs = (
|
||||
logits_output.next_token_logprobs.tolist()
|
||||
)
|
||||
if logits_output.input_token_logprobs is not None:
|
||||
logits_output.input_token_logprobs = tuple(
|
||||
logits_output.input_token_logprobs.tolist()
|
||||
)
|
||||
for i, (req, next_token_id) in enumerate(
|
||||
zip(batch.reqs, next_token_ids, strict=True)
|
||||
):
|
||||
req: Req
|
||||
if req.is_chunked <= 0:
|
||||
# There is no output_ids for prefill
|
||||
req.output_ids.append(next_token_id)
|
||||
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
||||
self.send_kv_chunk(req, token_id=next_token_id)
|
||||
self.disagg_prefill_inflight_queue.append(req)
|
||||
if req.return_logprob:
|
||||
assert extend_logprob_start_len_per_req is not None
|
||||
assert extend_input_len_per_req is not None
|
||||
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
||||
extend_input_len = extend_input_len_per_req[i]
|
||||
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
||||
self.add_logprob_return_values(
|
||||
i,
|
||||
req,
|
||||
logprob_pt,
|
||||
next_token_ids,
|
||||
num_input_logprobs,
|
||||
logits_output,
|
||||
)
|
||||
logprob_pt += num_input_logprobs
|
||||
self.send_kv_chunk(req, last_chunk=True)
|
||||
|
||||
if req.grammar is not None:
|
||||
req.grammar.accept_token(next_token_id)
|
||||
req.grammar.finished = req.finished()
|
||||
else:
|
||||
# prefill is not finished yet for requests that are still being chunked
|
||||
req.is_chunked -= 1
|
||||
|
||||
if req.return_logprob:
|
||||
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
||||
extend_input_len = extend_input_len_per_req[i]
|
||||
if extend_logprob_start_len < extend_input_len:
|
||||
# Update input logprobs.
|
||||
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
||||
self.add_input_logprob_return_values(
|
||||
i,
|
||||
req,
|
||||
logits_output,
|
||||
logprob_pt,
|
||||
num_input_logprobs,
|
||||
last_prefill_chunk=False,
|
||||
)
|
||||
logprob_pt += num_input_logprobs
|
||||
|
||||
if self.enable_overlap:
|
||||
self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
|
||||
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
|
||||
|
||||
# We need to remove the sync in the following function for overlap schedule.
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
|
||||
def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
|
||||
"""
|
||||
@@ -379,7 +422,11 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
)
|
||||
|
||||
# Stream requests which have finished transfer
|
||||
self.stream_output(done_reqs, False, None)
|
||||
self.stream_output(
|
||||
done_reqs,
|
||||
any(req.return_logprob for req in done_reqs),
|
||||
None,
|
||||
)
|
||||
|
||||
self.disagg_prefill_inflight_queue = undone_reqs
|
||||
|
||||
@@ -405,7 +452,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
def send_kv_chunk(
|
||||
self: Scheduler,
|
||||
req: Req,
|
||||
token_id: Optional[int] = None,
|
||||
last_chunk: bool = False,
|
||||
end_idx: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -413,37 +460,28 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
"""
|
||||
page_size = self.token_to_kv_pool_allocator.page_size
|
||||
start_idx = req.start_send_idx
|
||||
# if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
|
||||
# the resolved length is not the same as fill_ids's length
|
||||
end_idx = (
|
||||
end_idx
|
||||
if end_idx is not None
|
||||
else min(len(req.fill_ids), len(req.origin_input_ids))
|
||||
)
|
||||
|
||||
last_chunk = token_id is not None
|
||||
|
||||
if not last_chunk:
|
||||
# if not the last chunk and the last page is partial, delay the last partial page to the next send
|
||||
end_idx = end_idx - end_idx % page_size
|
||||
|
||||
# Update next start_send_idx
|
||||
req.start_send_idx = end_idx
|
||||
|
||||
kv_indices = (
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
if last_chunk is True:
|
||||
self.disagg_prefill_bootstrap_queue.store_prefill_results(
|
||||
req.metadata_buffer_index, token_id
|
||||
)
|
||||
req.start_send_idx = end_idx
|
||||
if last_chunk:
|
||||
self.disagg_metadata_buffers.set_buf(req)
|
||||
page_indices = kv_to_page_indices(kv_indices, page_size)
|
||||
if len(page_indices) == 0:
|
||||
logger.info(
|
||||
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
|
||||
)
|
||||
return
|
||||
|
||||
req.disagg_kv_sender.send(page_indices)
|
||||
|
||||
Reference in New Issue
Block a user