[PD] Better logs (#5715)
This commit is contained in:
@@ -307,7 +307,7 @@ class DecodeTransferQueue:
|
||||
def extend(self, req_conns) -> None:
|
||||
self.queue.extend(req_conns)
|
||||
|
||||
def pop_transferred(self) -> List[Req]:
|
||||
def pop_transferred(self) -> List[DecodeRequest]:
|
||||
if not self.queue:
|
||||
return []
|
||||
|
||||
@@ -330,7 +330,7 @@ class DecodeTransferQueue:
|
||||
assert len(decode_req.req.output_ids) == 0
|
||||
assert decode_req.req.transferred_output_id is None
|
||||
decode_req.req.transferred_output_id = output_id
|
||||
transferred_reqs.append(decode_req.req)
|
||||
transferred_reqs.append(decode_req)
|
||||
indices_to_remove.add(i)
|
||||
elif poll in [
|
||||
KVPoll.Bootstrapping,
|
||||
@@ -454,7 +454,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
return batch, result
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_normal_disagg_decode(self):
|
||||
def event_loop_normal_disagg_decode(self: Scheduler):
|
||||
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
||||
|
||||
while True:
|
||||
@@ -497,7 +497,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.last_batch = batch
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_overlap_disagg_decode(self):
|
||||
def event_loop_overlap_disagg_decode(self: Scheduler):
|
||||
result_queue = deque()
|
||||
self.last_batch: Optional[ScheduleBatch] = None
|
||||
self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend
|
||||
@@ -641,8 +641,15 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
|
||||
def process_decode_queue(self: Scheduler):
|
||||
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
||||
|
||||
def _num_pre_alloc(req):
|
||||
return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
|
||||
|
||||
self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
|
||||
self.disagg_decode_transfer_queue.extend(req_conns)
|
||||
alloc_reqs = (
|
||||
self.disagg_decode_transfer_queue.pop_transferred()
|
||||
) # the requests whose KV has arrived
|
||||
self.waiting_queue.extend(alloc_reqs)
|
||||
self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
|
||||
|
||||
self.waiting_queue.extend([req.req for req in alloc_reqs])
|
||||
|
||||
@@ -176,14 +176,14 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_normal_disagg_prefill(self):
|
||||
def event_loop_normal_disagg_prefill(self: Scheduler):
|
||||
"""A normal scheduler loop for prefill worker in disaggregation mode."""
|
||||
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
self.waiting_queue.extend(
|
||||
self.disagg_prefill_pending_queue.pop_bootstrapped()
|
||||
self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
|
||||
)
|
||||
self.process_prefill_chunk()
|
||||
batch = self.get_new_batch_prefill()
|
||||
@@ -214,14 +214,14 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_overlap_disagg_prefill(self):
|
||||
def event_loop_overlap_disagg_prefill(self: Scheduler):
|
||||
self.result_queue = deque()
|
||||
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
self.waiting_queue.extend(
|
||||
self.disagg_prefill_pending_queue.pop_bootstrapped()
|
||||
self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
|
||||
)
|
||||
self.process_prefill_chunk()
|
||||
batch = self.get_new_batch_prefill()
|
||||
@@ -326,7 +326,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
raise Exception("Transferring failed")
|
||||
|
||||
for req in done_reqs:
|
||||
self.disagg_prefill_pending_queue.req_to_metadata_buffer_idx_allocator.free(
|
||||
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
|
||||
req.metadata_buffer_index
|
||||
)
|
||||
|
||||
@@ -342,9 +342,8 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# only finished requests to running_batch.
|
||||
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
|
||||
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
||||
if (
|
||||
self.enable_overlap
|
||||
): # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
|
||||
if self.enable_overlap:
|
||||
# Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
|
||||
self.chunked_req.tmp_end_idx = min(
|
||||
len(self.chunked_req.fill_ids),
|
||||
len(self.chunked_req.origin_input_ids),
|
||||
@@ -390,7 +389,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
.numpy()
|
||||
)
|
||||
if last_chunk is True:
|
||||
self.disagg_prefill_pending_queue.store_prefill_results(
|
||||
self.disagg_prefill_bootstrap_queue.store_prefill_results(
|
||||
req.metadata_buffer_index, token_id
|
||||
)
|
||||
page_indices = kv_to_page_indices(kv_indices, page_size)
|
||||
|
||||
Reference in New Issue
Block a user