[PD] Support decode retract and update decode.py (#7196)

Byron Hsu authored 2025-06-14 19:48:05 -07:00, committed by GitHub
parent 349bb2c92a
commit db0cc57e75
6 changed files with 378 additions and 43 deletions

@@ -628,6 +628,7 @@ class Scheduler(
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
+                tp_rank=self.tp_rank,
                 metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
@@ -650,7 +651,11 @@ class Scheduler(
                 gloo_group=self.attn_tp_cpu_group,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
+                dp_size=self.server_args.dp_size,
+                gpu_id=self.gpu_id,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                max_total_num_tokens=self.max_total_num_tokens,
+                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                 transfer_backend=self.transfer_backend,
             )
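
The extra constructor arguments give the decode-side prealloc queue the context it needs to manage retracted requests alongside fresh ones (rank/parallelism layout, the bootstrap port for the prefill-decode KV handshake, and the token budget). A minimal sketch of what such a queue might carry; the field names and the class itself are hypothetical, not the actual sglang DecodePreallocQueue signature:

from dataclasses import dataclass, field
from typing import List


@dataclass
class PreallocQueueSketch:
    """Hypothetical stand-in for the decode prealloc queue (illustration only)."""
    tp_rank: int
    tp_size: int
    dp_size: int
    gpu_id: int
    bootstrap_port: int           # port used for the prefill<->decode KV handshake
    max_total_num_tokens: int     # token budget consulted when pre-allocating KV cache
    prefill_pp_size: int          # pipeline-parallel degree on the prefill side
    queue: List[object] = field(default_factory=list)            # fresh requests awaiting prealloc
    retracted_queue: List[object] = field(default_factory=list)  # retracted requests awaiting re-admission
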
@@ -1124,14 +1129,14 @@ class Scheduler(
         else:
             self.waiting_queue.append(req)
 
-    def _extend_requests_to_queue(self, reqs: List[Req]):
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_bootstrap_queue.extend(
                 reqs, self.model_config.num_key_value_heads
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # If this is a decode server, we put the request to the decode pending prealloc queue
-            self.disagg_decode_prealloc_queue.extend(reqs)
+            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
         else:
             self.waiting_queue.extend(reqs)
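
The new is_retracted flag lets the prealloc queue tell re-admitted requests apart from brand-new ones. A sketch of how extend(reqs, is_retracted) could route the two kinds; the retracted_queue attribute appears in the logging hunk below, while the class, the queue field, and pop_ready are assumed here purely for illustration:

class DecodePreallocQueueSketch:
    """Illustrative only; not the real sglang DecodePreallocQueue."""

    def __init__(self) -> None:
        self.queue = []            # new requests that still need KV buffers and a bootstrap handshake
        self.retracted_queue = []  # requests that already ran and were retracted under memory pressure

    def extend(self, reqs, is_retracted: bool = False) -> None:
        if is_retracted:
            # Retracted requests already completed the prefill/KV transfer, so they
            # only wait for memory to be re-allocated, not for a new transfer.
            self.retracted_queue.extend(reqs)
        else:
            self.queue.extend(reqs)

    def pop_ready(self, can_allocate) -> list:
        # Serve retracted requests first so they are not starved behind new arrivals.
        ready = [r for r in self.retracted_queue if can_allocate(r)]
        self.retracted_queue = [r for r in self.retracted_queue if r not in ready]
        return ready
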
@@ -1274,6 +1279,7 @@ class Scheduler(
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
             f"cuda graph: {can_run_cuda_graph}, "
@@ -1575,7 +1581,7 @@ class Scheduler(
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
self._extend_requests_to_queue(retracted_reqs)
self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
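
Seen from the scheduler, decode-side retraction now becomes a round trip through the prealloc queue: requests evicted to relieve KV-cache pressure are handed back with is_retracted=True and re-admitted once memory frees up. A toy version of that control flow, assuming X-style helper names (out_of_kv_cache, retract_until_fit, forward are invented for this sketch; only _extend_requests_to_queue and its flag come from the diff):

def run_decode_step(scheduler) -> None:
    """Toy control flow for one decode-side scheduling step (illustration only)."""
    batch = scheduler.running_batch

    if batch.out_of_kv_cache():
        # Evict requests until the batch fits again, in the spirit of the
        # retraction branch shown in the last hunk; evicted requests keep their state.
        retracted_reqs = batch.retract_until_fit()
        # Re-queue them through the same entry point as fresh requests, but flagged
        # so the prealloc queue can skip the bootstrap/KV-transfer phase for them.
        scheduler._extend_requests_to_queue(retracted_reqs, is_retracted=True)

    batch.forward()  # decode one token for the requests that remain in the batch
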