[PD] Support decode retract and update decode.py (#7196)

Byron Hsu
2025-06-14 19:48:05 -07:00
committed by GitHub
parent 349bb2c92a
commit db0cc57e75
6 changed files with 378 additions and 43 deletions


@@ -31,7 +31,7 @@ import numpy as np
import torch
from torch.distributed import ProcessGroup
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
@@ -45,9 +45,17 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
KVCache,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -145,7 +153,11 @@ class DecodePreallocQueue:
gloo_group: ProcessGroup,
tp_rank: int,
tp_size: int,
dp_size: int,
gpu_id: int,
bootstrap_port: int,
max_total_num_tokens: int,
prefill_pp_size: int,
transfer_backend: TransferBackend,
):
self.req_to_token_pool = req_to_token_pool
@@ -161,25 +173,35 @@ class DecodePreallocQueue:
self.gloo_group = gloo_group
self.tp_rank = tp_rank
self.tp_size = tp_size
self.dp_size = dp_size
self.gpu_id = gpu_id
self.bootstrap_port = bootstrap_port
self.max_total_num_tokens = max_total_num_tokens
self.prefill_pp_size = prefill_pp_size
self.num_reserved_decode_tokens = int(
os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
)
self.transfer_backend = transfer_backend
# Queue for requests pending pre-allocation
self.queue: List[DecodeRequest] = []
self.transfer_backend = transfer_backend
self.retracted_queue: List[Req] = []
self.prefill_pp_size = prefill_pp_size
self.kv_manager = self._init_kv_manager()
def _init_kv_manager(self) -> BaseKVManager:
kv_args = KVArgs()
kv_args.engine_rank = self.tp_rank
kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
kv_args = kv_args_class()
attn_tp_size = self.tp_size // self.dp_size
kv_args.engine_rank = self.tp_rank % (attn_tp_size)
kv_args.decode_tp_size = attn_tp_size
kv_args.prefill_pp_size = self.prefill_pp_size
kv_data_ptrs, kv_data_lens, kv_item_lens = (
self.token_to_kv_pool.get_contiguous_buf_infos()
)
if self.draft_token_to_kv_pool is not None:
# We should also transfer the draft model KV cache. The indices are
# always shared with the target model.
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
)
@@ -194,6 +216,7 @@ class DecodePreallocQueue:
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
self.metadata_buffers.get_buf_infos()
)
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -205,27 +228,83 @@ class DecodePreallocQueue:
)
return kv_manager
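The per-backend KVArgs above derives the transfer engine rank from the rank inside the attention TP group rather than the global TP rank. A small illustration with assumed sizes (the concrete values are not from this diff):

# Illustrative sizes only.
tp_size, dp_size = 8, 2
attn_tp_size = tp_size // dp_size            # 4 attention-TP ranks per DP replica
engine_ranks = [tp_rank % attn_tp_size for tp_rank in range(tp_size)]
# -> [0, 1, 2, 3, 0, 1, 2, 3]: ranks 0-3 and 4-7 form two replicas that
#    each expose engine ranks 0-3 to the KV transfer backend.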
def add(self, req: Req) -> None:
def add(self, req: Req, is_retracted: bool = False) -> None:
"""Add a request to the pending queue."""
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
# Fake transfer for warmup reqs
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
else:
kv_receiver_class = get_kv_class(
self.transfer_backend, KVClassType.RECEIVER
)
kv_receiver = kv_receiver_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
data_parallel_rank=req.data_parallel_rank,
)
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
if self._check_if_req_exceed_kv_capacity(req):
return
def extend(self, reqs: List[Req]) -> None:
if is_retracted:
self.retracted_queue.append(req)
else:
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
kv_receiver_class = get_kv_class(
TransferBackend.FAKE, KVClassType.RECEIVER
)
else:
kv_receiver_class = get_kv_class(
self.transfer_backend, KVClassType.RECEIVER
)
kv_receiver = kv_receiver_class(
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
)
self.queue.append(
DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
)
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
if len(req.origin_input_ids) > self.max_total_num_tokens:
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
logger.error(message)
prepare_abort(req, message)
self.scheduler.stream_output([req], req.return_logprob)
return True
return False
def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
"""Add a request to the pending queue."""
for req in reqs:
self.add(req)
self.add(req, is_retracted=is_retracted)
def resume_retracted_reqs(self) -> List[Req]:
# TODO: refactor the scheduling part and reuse the unified engine logic as much as possible
# allocate memory
resumed_reqs = []
indices_to_remove = set()
allocatable_tokens = self._allocatable_tokens(count_retracted=False)
for i, req in enumerate(self.retracted_queue):
if self.req_to_token_pool.available_size() <= 0:
break
required_tokens_for_request = (
len(req.origin_input_ids)
+ len(req.output_ids)
+ self.num_reserved_decode_tokens
)
if required_tokens_for_request > allocatable_tokens:
break
resumed_reqs.append(req)
indices_to_remove.add(i)
req.is_retracted = False
self._pre_alloc(req)
allocatable_tokens -= required_tokens_for_request
# load the KV cache back from CPU, then release the CPU copy
req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
self.retracted_queue = [
entry
for i, entry in enumerate(self.retracted_queue)
if i not in indices_to_remove
]
return resumed_reqs
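For context, resume_retracted_reqs is the second half of a round trip: when decode memory runs out, a running request is retracted (its KV pages are copied to host memory and freed on device) and re-queued here with is_retracted=True, then resumed once enough tokens are allocatable again. A minimal sketch of the retract half follows; the helper names are assumptions, since that path is not shown in this hunk:

# Sketch only: assumes Req exposes an offload_kv_cache() counterpart to the
# load_kv_cache() call above, and that the scheduler re-queues the request here.
# Neither name is confirmed by this diff.
def retract_for_decode(scheduler, req) -> None:
    # Copy the request's device KV pages into a host-side buffer and free them
    # on the GPU so the remaining running requests can keep decoding.
    req.offload_kv_cache(
        scheduler.req_to_token_pool, scheduler.token_to_kv_pool_allocator
    )
    req.is_retracted = True
    # Re-queue without creating a new KV receiver: the KV is already local,
    # so is_retracted=True routes the request into retracted_queue.
    scheduler.disagg_decode_prealloc_queue.add(req, is_retracted=True)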
def _update_handshake_waiters(self) -> None:
if not self.queue:
@@ -255,6 +334,8 @@ class DecodePreallocQueue:
error_message,
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
else:
raise ValueError(f"Unexpected poll case: {poll}")
def pop_preallocated(self) -> List[DecodeRequest]:
"""Pop the preallocated requests from the pending queue (FIFO)."""
@@ -262,8 +343,16 @@ class DecodePreallocQueue:
preallocated_reqs = []
indices_to_remove = set()
allocatable_tokens = self._allocatable_tokens()
# We need to make sure that the sum of inflight (retractable) tokens and allocatable tokens is greater than the maximum input+output length of every inflight request.
# Otherwise one request could run out of memory during decode while all the other requests sit in the transfer queue, where they cannot be retracted.
retractable_tokens = sum(
len(r.origin_input_ids) + len(r.output_ids)
for r in self.scheduler.running_batch.reqs
)
allocatable_tokens = self._allocatable_tokens(
retractable_tokens=retractable_tokens, count_retracted=True
)
# First, remove all failed requests from the queue
for i, decode_req in enumerate(self.queue):
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
@@ -272,6 +361,7 @@ class DecodePreallocQueue:
)
indices_to_remove.add(i)
# Then, preallocate the remaining requests if possible
for i, decode_req in enumerate(self.queue):
if i in indices_to_remove:
continue
@@ -285,10 +375,23 @@ class DecodePreallocQueue:
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
break
# Memory estimation: don't admit the request if the projected memory requirement cannot be met
# TODO: add new_token ratio
origin_input_len = len(decode_req.req.origin_input_ids)
required_tokens_for_request = (
len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
origin_input_len + self.num_reserved_decode_tokens
)
if (
max(
required_tokens_for_request,
origin_input_len
+ decode_req.req.sampling_params.max_new_tokens
- retractable_tokens,
)
> allocatable_tokens
):
break
if required_tokens_for_request > allocatable_tokens:
break
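The check above admits a request only if both its immediate reservation and its worst-case footprint fit; the worst case assumes the request runs all the way to max_new_tokens while every currently running request has already been retracted. A numeric illustration with made-up values:

# Made-up numbers, for illustration only.
origin_input_len = 1000
num_reserved_decode_tokens = 512     # default of SGLANG_NUM_RESERVED_DECODE_TOKENS
max_new_tokens = 2048
retractable_tokens = 1500            # input + output tokens of running requests
allocatable_tokens = 1800

required_now = origin_input_len + num_reserved_decode_tokens          # 1512
worst_case = origin_input_len + max_new_tokens - retractable_tokens   # 1548
admit = max(required_now, worst_case) <= allocatable_tokens           # True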
@@ -321,15 +424,35 @@ class DecodePreallocQueue:
return preallocated_reqs
def _allocatable_tokens(self) -> int:
allocatable_tokens = (
self.token_to_kv_pool_allocator.available_size()
- self.num_reserved_decode_tokens
def _allocatable_tokens(
self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
) -> int:
need_space_for_single_req = (
max(
[
x.sampling_params.max_new_tokens
+ len(x.origin_input_ids)
- retractable_tokens
for x in self.scheduler.running_batch.reqs
]
)
if retractable_tokens is not None
and len(self.scheduler.running_batch.reqs) > 0
else 0
)
available_size = self.token_to_kv_pool_allocator.available_size()
allocatable_tokens = available_size - max(
# reserve some space for future decode steps
self.num_reserved_decode_tokens
* (
len(self.scheduler.running_batch.reqs)
+ len(self.transfer_queue.queue)
+ len(self.scheduler.waiting_queue)
)
),
# make sure each request can finish if it reaches max_tokens, even with all other requests retracted
need_space_for_single_req,
)
# Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
@@ -342,15 +465,27 @@ class DecodePreallocQueue:
self.scheduler.last_batch.reqs
)
if count_retracted:
allocatable_tokens -= sum(
[
len(req.origin_input_ids)
+ len(req.output_ids)
+ self.num_reserved_decode_tokens
for req in self.retracted_queue
]
)
return allocatable_tokens
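Putting the pieces together: the allocatable budget is the free KV pool minus the larger of two reservations, the per-request decode headroom and the space a single request needs to reach max_tokens with everything else retracted; with count_retracted=True, tokens already promised to retracted requests are subtracted as well. A numeric illustration with made-up values:

# Made-up numbers, for illustration only.
available = 10_000                    # token_to_kv_pool_allocator.available_size()
num_reserved_decode_tokens = 512
num_inflight = 6                      # running batch + transfer queue + waiting queue
need_space_for_single_req = 4_000     # max(input + max_new_tokens) - retractable_tokens

headroom = max(num_reserved_decode_tokens * num_inflight,   # 3_072
               need_space_for_single_req)                    # 4_000
allocatable = available - headroom                            # 6_000
# count_retracted=True then subtracts input + output + reserve for every
# request still sitting in retracted_queue.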
def _pre_alloc(self, req: Req) -> torch.Tensor:
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
req_pool_indices = self.req_to_token_pool.alloc(1)
assert req_pool_indices is not None
assert (
req_pool_indices is not None
), "req_pool_indices is full! There is a bug in memory estimation."
req.req_pool_idx = req_pool_indices[0]
if self.token_to_kv_pool_allocator.page_size == 1:
kv_loc = self.token_to_kv_pool_allocator.alloc(
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
@@ -375,7 +510,10 @@ class DecodePreallocQueue:
),
extend_num_tokens=num_tokens,
)
assert kv_loc is not None
assert (
kv_loc is not None
), "KV cache is full! There is a bug in memory estimation."
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
@@ -395,6 +533,7 @@ class DecodeTransferQueue:
self,
gloo_group: ProcessGroup,
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
tp_rank: int,
metadata_buffers: MetadataBuffers,
scheduler: Scheduler,
tree_cache: BasePrefixCache,
@@ -402,6 +541,7 @@ class DecodeTransferQueue:
self.queue: List[DecodeRequest] = []
self.gloo_group = gloo_group
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
self.tp_rank = tp_rank
self.metadata_buffers = metadata_buffers
self.scheduler = scheduler
self.tree_cache = tree_cache
@@ -412,10 +552,9 @@ class DecodeTransferQueue:
def extend(self, decode_reqs: List[DecodeRequest]) -> None:
self.queue.extend(decode_reqs)
def pop_transferred(self) -> List[DecodeRequest]:
def pop_transferred(self) -> List[Req]:
if not self.queue:
return []
polls = poll_and_all_reduce(
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
)
@@ -424,7 +563,7 @@ class DecodeTransferQueue:
indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if poll == KVPoll.Failed:
error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
try:
decode_req.kv_receiver.failure_exception()
except Exception as e:
@@ -543,7 +682,8 @@ class SchedulerDisaggregationDecodeMixin:
batch, _ = self._prepare_idle_batch_and_run(None)
if batch is None and (
len(self.disagg_decode_transfer_queue.queue)
len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
@@ -622,7 +762,8 @@ class SchedulerDisaggregationDecodeMixin:
self.process_batch_result(tmp_batch, tmp_result)
if batch is None and (
len(self.disagg_decode_transfer_queue.queue)
len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
@@ -716,6 +857,13 @@ class SchedulerDisaggregationDecodeMixin:
return new_batch
def process_decode_queue(self: Scheduler):
# Try to resume retracted requests if there is enough space for another `num_reserved_decode_tokens` decode steps
resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
self.waiting_queue.extend(resumed_reqs)
if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
# If there are still retracted requests, do not preallocate new ones
return
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
self.disagg_decode_transfer_queue.extend(req_conns)
alloc_reqs = (

@@ -25,6 +25,7 @@ from collections import deque
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional
import numpy as np
import torch
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
@@ -575,6 +576,7 @@ class SchedulerDisaggregationPrefillMixin:
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
.cpu()
.numpy()
.astype(np.int64)
)
req.start_send_idx = end_idx
if last_chunk:
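On the prefill side, the req_to_token slice is now cast to int64 before being handed to the KV sender. The pool presumably stores indices in a narrower integer dtype, and a consumer that treats the buffer as 64-bit values would otherwise misread it; a small self-contained illustration of that failure mode (not taken from this diff):

import numpy as np

idx32 = np.array([5, 6, 7, 8], dtype=np.int32)
# Reinterpreting the raw int32 bytes as int64 packs pairs of indices together
# (values shown for a little-endian host):
print(np.frombuffer(idx32.tobytes(), dtype=np.int64))   # [25769803781 34359738375]
# Widening explicitly keeps the values intact:
print(idx32.astype(np.int64))                            # [5 6 7 8]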