[PD] Update prefill.py (#7190)

This commit is contained in:
Byron Hsu
2025-06-14 15:59:54 -07:00
committed by GitHub
parent ab1a4fa5cb
commit 7d316991b2
11 changed files with 458 additions and 245 deletions

View File

@@ -619,7 +619,7 @@ class Scheduler(
self.disaggregation_mode == DisaggregationMode.DECODE
): # *2 for the headroom.
buffer_size = (self.req_to_token_pool.size) * 2
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
@@ -627,7 +627,7 @@ class Scheduler(
# The decode requests polling kv cache
self.disagg_decode_transfer_queue = DecodeTransferQueue(
gloo_group=self.attn_tp_cpu_group,
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
metadata_buffers=self.disagg_metadata_buffers,
scheduler=self,
tree_cache=self.tree_cache,
@@ -642,7 +642,7 @@ class Scheduler(
if self.draft_worker is None
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
metadata_buffers=self.disagg_metadata_buffers,
scheduler=self,
transfer_queue=self.disagg_decode_transfer_queue,
@@ -660,7 +660,7 @@ class Scheduler(
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
# *2 for the headroom.
buffer_size = self.max_running_requests * 2
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
buffer_size
)
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
@@ -672,14 +672,20 @@ class Scheduler(
if self.draft_worker is None
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
metadata_buffers=self.disagg_metadata_buffers,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
gpu_id=self.gpu_id,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
gloo_group=self.attn_tp_cpu_group,
transfer_backend=self.transfer_backend,
max_total_num_tokens=self.max_total_num_tokens,
decode_tp_size=self.server_args.disaggregation_decode_tp,
decode_dp_size=self.server_args.disaggregation_decode_dp,
scheduler=self,
pp_rank=self.pp_rank,
pp_size=self.pp_size,
transfer_backend=self.transfer_backend,
)
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = []
@@ -1110,7 +1116,9 @@ class Scheduler(
def _add_request_to_queue(self, req: Req):
req.queue_time_start = time.perf_counter()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_bootstrap_queue.add(req)
self.disagg_prefill_bootstrap_queue.add(
req, self.model_config.num_key_value_heads
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
@@ -1118,7 +1126,9 @@ class Scheduler(
def _extend_requests_to_queue(self, reqs: List[Req]):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_bootstrap_queue.extend(reqs)
self.disagg_prefill_bootstrap_queue.extend(
reqs, self.model_config.num_key_value_heads
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
# If this is a decode server, we put the request to the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs)