[PD] Update prefill.py (#7190)
This commit is contained in:
@@ -619,7 +619,7 @@ class Scheduler(
|
||||
self.disaggregation_mode == DisaggregationMode.DECODE
|
||||
): # *2 for the headroom.
|
||||
buffer_size = (self.req_to_token_pool.size) * 2
|
||||
-req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||
+self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||
buffer_size
|
||||
)
|
||||
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
|
||||
@@ -627,7 +627,7 @@ class Scheduler(
|
||||
# The decode requests polling kv cache
|
||||
self.disagg_decode_transfer_queue = DecodeTransferQueue(
|
||||
gloo_group=self.attn_tp_cpu_group,
|
||||
-req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||
+req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||
metadata_buffers=self.disagg_metadata_buffers,
|
||||
scheduler=self,
|
||||
tree_cache=self.tree_cache,
|
||||
@@ -642,7 +642,7 @@ class Scheduler(
|
||||
if self.draft_worker is None
|
||||
else self.draft_worker.model_runner.token_to_kv_pool
|
||||
),
|
||||
-req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||
+req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||
metadata_buffers=self.disagg_metadata_buffers,
|
||||
scheduler=self,
|
||||
transfer_queue=self.disagg_decode_transfer_queue,
|
||||
@@ -660,7 +660,7 @@ class Scheduler(
|
||||
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# *2 for the headroom.
|
||||
buffer_size = self.max_running_requests * 2
|
||||
-req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||
+self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||
buffer_size
|
||||
)
|
||||
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
|
||||
@@ -672,14 +672,20 @@ class Scheduler(
|
||||
if self.draft_worker is None
|
||||
else self.draft_worker.model_runner.token_to_kv_pool
|
||||
),
|
||||
-req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||
+req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||
metadata_buffers=self.disagg_metadata_buffers,
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
gpu_id=self.gpu_id,
|
||||
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||
gloo_group=self.attn_tp_cpu_group,
|
||||
-transfer_backend=self.transfer_backend,
|
||||
+max_total_num_tokens=self.max_total_num_tokens,
|
||||
+decode_tp_size=self.server_args.disaggregation_decode_tp,
|
||||
+decode_dp_size=self.server_args.disaggregation_decode_dp,
|
||||
+scheduler=self,
|
||||
+pp_rank=self.pp_rank,
|
||||
+pp_size=self.pp_size,
|
||||
+transfer_backend=self.transfer_backend,
|
||||
)
|
||||
# The prefill requests that are in the middle of kv sending
|
||||
self.disagg_prefill_inflight_queue: List[Req] = []
|
||||
@@ -1110,7 +1116,9 @@ class Scheduler(
|
||||
def _add_request_to_queue(self, req: Req):
|
||||
req.queue_time_start = time.perf_counter()
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
-self.disagg_prefill_bootstrap_queue.add(req)
|
||||
+self.disagg_prefill_bootstrap_queue.add(
|
||||
+req, self.model_config.num_key_value_heads
|
||||
+)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.disagg_decode_prealloc_queue.add(req)
|
||||
else:
|
||||
@@ -1118,7 +1126,9 @@ class Scheduler(
|
||||
|
||||
def _extend_requests_to_queue(self, reqs: List[Req]):
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
-self.disagg_prefill_bootstrap_queue.extend(reqs)
|
||||
+self.disagg_prefill_bootstrap_queue.extend(
|
||||
+reqs, self.model_config.num_key_value_heads
|
||||
+)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
# If this is a decode server, we put the request to the decode pending prealloc queue
|
||||
self.disagg_decode_prealloc_queue.extend(reqs)
|
||||
|
||||
Reference in New Issue
Block a user