[PD] Better logs (#5715)

This commit is contained in:
Liangsheng Yin
2025-04-25 17:25:45 +08:00
committed by GitHub
parent 43fb95c2fa
commit c55550cbf0
3 changed files with 50 additions and 34 deletions

View File

@@ -307,7 +307,7 @@ class DecodeTransferQueue:
def extend(self, req_conns) -> None:
self.queue.extend(req_conns)
def pop_transferred(self) -> List[Req]:
def pop_transferred(self) -> List[DecodeRequest]:
if not self.queue:
return []
@@ -330,7 +330,7 @@ class DecodeTransferQueue:
assert len(decode_req.req.output_ids) == 0
assert decode_req.req.transferred_output_id is None
decode_req.req.transferred_output_id = output_id
transferred_reqs.append(decode_req.req)
transferred_reqs.append(decode_req)
indices_to_remove.add(i)
elif poll in [
KVPoll.Bootstrapping,
@@ -454,7 +454,7 @@ class SchedulerDisaggregationDecodeMixin:
return batch, result
@torch.no_grad()
def event_loop_normal_disagg_decode(self):
def event_loop_normal_disagg_decode(self: Scheduler):
"""A normal scheduler loop for decode worker in disaggregation mode."""
while True:
@@ -497,7 +497,7 @@ class SchedulerDisaggregationDecodeMixin:
self.last_batch = batch
@torch.no_grad()
def event_loop_overlap_disagg_decode(self):
def event_loop_overlap_disagg_decode(self: Scheduler):
result_queue = deque()
self.last_batch: Optional[ScheduleBatch] = None
self.last_batch_in_queue = False # last batch is modifed in-place, so we need another variable to track if it's extend
@@ -641,8 +641,15 @@ class SchedulerDisaggregationDecodeMixin:
def process_decode_queue(self: Scheduler):
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
def _num_pre_alloc(req):
return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
self.disagg_decode_transfer_queue.extend(req_conns)
alloc_reqs = (
self.disagg_decode_transfer_queue.pop_transferred()
) # the requests which kv has arrived
self.waiting_queue.extend(alloc_reqs)
self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
self.waiting_queue.extend([req.req for req in alloc_reqs])

View File

@@ -176,14 +176,14 @@ class SchedulerDisaggregationPrefillMixin:
"""
@torch.no_grad()
def event_loop_normal_disagg_prefill(self):
def event_loop_normal_disagg_prefill(self: Scheduler):
"""A normal scheduler loop for prefill worker in disaggregation mode."""
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
self.waiting_queue.extend(
self.disagg_prefill_pending_queue.pop_bootstrapped()
self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
)
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()
@@ -214,14 +214,14 @@ class SchedulerDisaggregationPrefillMixin:
self.running_batch.batch_is_full = False
@torch.no_grad()
def event_loop_overlap_disagg_prefill(self):
def event_loop_overlap_disagg_prefill(self: Scheduler):
self.result_queue = deque()
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
self.waiting_queue.extend(
self.disagg_prefill_pending_queue.pop_bootstrapped()
self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
)
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()
@@ -326,7 +326,7 @@ class SchedulerDisaggregationPrefillMixin:
raise Exception("Transferring failed")
for req in done_reqs:
self.disagg_prefill_pending_queue.req_to_metadata_buffer_idx_allocator.free(
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
req.metadata_buffer_index
)
@@ -342,9 +342,8 @@ class SchedulerDisaggregationPrefillMixin:
# only finished requests to running_batch.
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req)
if (
self.enable_overlap
): # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
if self.enable_overlap:
# Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
self.chunked_req.tmp_end_idx = min(
len(self.chunked_req.fill_ids),
len(self.chunked_req.origin_input_ids),
@@ -390,7 +389,7 @@ class SchedulerDisaggregationPrefillMixin:
.numpy()
)
if last_chunk is True:
self.disagg_prefill_pending_queue.store_prefill_results(
self.disagg_prefill_bootstrap_queue.store_prefill_results(
req.metadata_buffer_index, token_id
)
page_indices = kv_to_page_indices(kv_indices, page_size)