diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index c91b4b813..d4331c234 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -1,23 +1,32 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Optional +from typing import TYPE_CHECKING, List, Optional import numpy as np import numpy.typing as npt -from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.server_args import ServerArgs +if TYPE_CHECKING: + from sglang.srt.disaggregation.utils import DisaggregationMode + class KVArgs: engine_rank: int - kv_data_ptrs: list[int] - kv_data_lens: list[int] - kv_item_lens: list[int] - aux_data_ptrs: list[int] - aux_data_lens: list[int] - aux_item_lens: list[int] + kv_data_ptrs: List[int] + kv_data_lens: List[int] + kv_item_lens: List[int] + aux_data_ptrs: List[int] + aux_data_lens: List[int] + aux_item_lens: List[int] ib_device: str + ib_traffic_class: str gpu_id: int + # for different tp + decode_tp_size: int + # for pp prefill + prefill_pp_size: int class KVPoll: @@ -45,7 +54,12 @@ class BaseKVSender(ABC): @abstractmethod def __init__( - self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int + self, + mgr: BaseKVManager, + bootstrap_addr: str, + bootstrap_room: int, + dest_tp_ranks: List[int], + pp_rank: int, ): ... @abstractmethod diff --git a/python/sglang/srt/disaggregation/common/utils.py b/python/sglang/srt/disaggregation/common/utils.py new file mode 100644 index 000000000..ba0cfd6af --- /dev/null +++ b/python/sglang/srt/disaggregation/common/utils.py @@ -0,0 +1,42 @@ +import threading +from collections import deque +from typing import List, Tuple + +import numpy as np +import numpy.typing as npt + + +class FastQueue: + def __init__(self): + self._buf = deque() + self._cond = threading.Condition() + + def put(self, item): + with self._cond: + self._buf.append(item) + # wake up a thread of wait() + self._cond.notify() + + def get(self): + with self._cond: + # if queue is empty ,block until is notified() + while not self._buf: + self._cond.wait() + return self._buf.popleft() + + +def group_concurrent_contiguous( + src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] +) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: + """Vectorised NumPy implementation.""" + if src_indices.size == 0: + return [], [] + + brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 + src_groups = np.split(src_indices, brk) + dst_groups = np.split(dst_indices, brk) + + src_groups = [g.tolist() for g in src_groups] + dst_groups = [g.tolist() for g in dst_groups] + + return src_groups, dst_groups diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 99b2bf330..e2cc25eeb 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -33,8 +33,8 @@ from torch.distributed import ProcessGroup from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll from sglang.srt.disaggregation.utils import ( + FAKE_BOOTSTRAP_HOST, DisaggregationMode, - FakeBootstrapHost, KVClassType, MetadataBuffers, ReqToMetadataIdxAllocator, @@ -207,7 +207,7 @@ class DecodePreallocQueue: def add(self, req: Req) -> None: """Add a request to the pending queue.""" - if req.bootstrap_host == FakeBootstrapHost: + if req.bootstrap_host == FAKE_BOOTSTRAP_HOST: # Fake transfer for warmup reqs kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER) else: diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index d080c8e2e..25335dd68 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -17,7 +17,14 @@ logger = logging.getLogger(__name__) # For warmup reqs, we don't kv transfer, we use the fake sender and receiver class FakeKVSender(BaseKVSender): - def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int): + def __init__( + self, + mgr: BaseKVManager, + bootstrap_addr: str, + bootstrap_room: int, + dest_tp_ranks: List[int], + pp_rank: int, + ): self.has_sent = False def poll(self) -> KVPoll: diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 5d55eb468..b3d83db69 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -28,12 +28,12 @@ from sglang.srt.disaggregation.base.conn import ( KVArgs, KVPoll, ) -from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine -from sglang.srt.disaggregation.utils import ( - DisaggregationMode, +from sglang.srt.disaggregation.common.utils import ( FastQueue, group_concurrent_contiguous, ) +from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine +from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_free_port, @@ -677,7 +677,12 @@ class MooncakeKVManager(BaseKVManager): class MooncakeKVSender(BaseKVSender): def __init__( - self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int + self, + mgr: MooncakeKVManager, + bootstrap_addr: str, + bootstrap_room: int, + dest_tp_ranks: List[int], + pp_rank: int, ): self.kv_mgr = mgr self.bootstrap_room = bootstrap_room diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index f9a0e931c..18378bbf4 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -24,10 +24,8 @@ from sglang.srt.disaggregation.common.conn import ( CommonKVManager, CommonKVReceiver, ) -from sglang.srt.disaggregation.utils import ( - DisaggregationMode, - group_concurrent_contiguous, -) +from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous +from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_local_ip_by_remote @@ -350,7 +348,14 @@ class NixlKVManager(CommonKVManager): class NixlKVSender(BaseKVSender): - def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int): + def __init__( + self, + mgr: NixlKVManager, + bootstrap_addr: str, + bootstrap_room: int, + dest_tp_ranks: List[int], + pp_rank: int, + ): self.kv_mgr = mgr self.bootstrap_room = bootstrap_room self.aux_index = None diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index a008b404f..3382c9473 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -27,10 +27,10 @@ from typing import TYPE_CHECKING, List, Optional import torch -from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll +from sglang.srt.disaggregation.base import BaseKVManager, KVPoll from sglang.srt.disaggregation.utils import ( + FAKE_BOOTSTRAP_HOST, DisaggregationMode, - FakeBootstrapHost, KVClassType, MetadataBuffers, ReqToMetadataIdxAllocator, @@ -51,7 +51,6 @@ if TYPE_CHECKING: from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler from sglang.srt.mem_cache.memory_pool import KVCache - logger = logging.getLogger(__name__) @@ -68,35 +67,45 @@ class PrefillBootstrapQueue: metadata_buffers: MetadataBuffers, tp_rank: int, tp_size: int, + gpu_id: int, bootstrap_port: int, gloo_group: ProcessGroup, - transfer_backend: TransferBackend, + max_total_num_tokens: int, + decode_tp_size: int, + decode_dp_size: int, scheduler: Scheduler, + pp_rank: int, + pp_size: int, + transfer_backend: TransferBackend, ): self.token_to_kv_pool = token_to_kv_pool self.draft_token_to_kv_pool = draft_token_to_kv_pool - self.is_mla_backend = is_mla_backend(token_to_kv_pool) - self.metadata_buffers = metadata_buffers self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator self.tp_rank = tp_rank self.tp_size = tp_size - self.transfer_backend = transfer_backend - self.scheduler = scheduler - self.kv_manager = self._init_kv_manager() - self.queue: List[Req] = [] - self.gloo_group = gloo_group + self.decode_tp_size = decode_tp_size + self.decode_dp_size = decode_dp_size + self.pp_rank = pp_rank + self.pp_size = pp_size + self.gpu_id = gpu_id self.bootstrap_port = bootstrap_port - - def store_prefill_results(self, idx: int, token_id: int): - assert token_id >= 0, f"token_id: {token_id} is negative" - output_id_buffer = self.metadata_buffers[0] - output_id_buffer[idx] = token_id + self.queue: List[Req] = [] + self.pp_rank = pp_rank + self.pp_size = pp_size + self.gloo_group = gloo_group + self.max_total_num_tokens = max_total_num_tokens + self.scheduler = scheduler + self.transfer_backend = transfer_backend + self.kv_manager = self._init_kv_manager() def _init_kv_manager(self) -> BaseKVManager: - kv_args = KVArgs() + kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS) + kv_args = kv_args_class() kv_args.engine_rank = self.tp_rank + kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size + kv_args.prefill_pp_size = self.pp_size kv_data_ptrs, kv_data_lens, kv_item_lens = ( self.token_to_kv_pool.get_contiguous_buf_infos() ) @@ -115,12 +124,12 @@ class PrefillBootstrapQueue: kv_args.kv_data_lens = kv_data_lens kv_args.kv_item_lens = kv_item_lens - # Define req -> input ids buffer kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( self.metadata_buffers.get_buf_infos() ) kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id + kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) kv_manager = kv_manager_class( kv_args, @@ -130,23 +139,39 @@ class PrefillBootstrapQueue: ) return kv_manager - def add(self, req: Req) -> None: - if req.bootstrap_host == FakeBootstrapHost: - # Fake transfer for warmup reqs + def add(self, req: Req, num_kv_heads: int) -> None: + if self._check_if_req_exceed_kv_capacity(req): + return + + if req.bootstrap_host == FAKE_BOOTSTRAP_HOST: kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER) else: kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER) + + dest_tp_ranks = [self.tp_rank] + req.disagg_kv_sender = kv_sender_class( mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", bootstrap_room=req.bootstrap_room, + dest_tp_ranks=dest_tp_ranks, + pp_rank=self.pp_rank, ) self._process_req(req) self.queue.append(req) - def extend(self, reqs: List[Req]) -> None: + def extend(self, reqs: List[Req], num_kv_heads: int) -> None: for req in reqs: - self.add(req) + self.add(req, num_kv_heads) + + def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool: + if len(req.origin_input_ids) > self.max_total_num_tokens: + message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}" + logger.error(message) + prepare_abort(req, message) + self.scheduler.stream_output([req], req.return_logprob) + return True + return False def _process_req(self, req: Req) -> None: """ @@ -154,19 +179,40 @@ class PrefillBootstrapQueue: """ req.sampling_params.max_new_tokens = 1 - def pop_bootstrapped(self) -> List[Req]: - """pop the reqs which has finished bootstrapping""" + def pop_bootstrapped( + self, + return_failed_reqs: bool = False, + rids_to_check: Optional[List[str]] = None, + ) -> List[Req]: + """ + pop the reqs which has finished bootstrapping + + return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank + rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank. + """ + bootstrapped_reqs = [] + failed_reqs = [] indices_to_remove = set() if len(self.queue) == 0: - return [] + if return_failed_reqs is False: + return [] + else: + return [], [] polls = poll_and_all_reduce( [req.disagg_kv_sender for req in self.queue], self.gloo_group ) - for i, (req, poll) in enumerate(zip(self.queue, polls)): + + if rids_to_check is not None: + # if req not in reqs_info_to_check, skip + if req.rid not in rids_to_check: + continue + # Either waiting for input or failed + assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed + if poll == KVPoll.Bootstrapping: continue elif poll == KVPoll.Failed: @@ -181,9 +227,10 @@ class PrefillBootstrapQueue: ) self.scheduler.stream_output([req], req.return_logprob) indices_to_remove.add(i) + failed_reqs.append(req) continue - # KV.WaitingForInput + # KV.WaitingForInput - init here num_kv_indices = len(req.origin_input_ids) if self.req_to_metadata_buffer_idx_allocator.available_size() == 0: break @@ -192,9 +239,9 @@ class PrefillBootstrapQueue: self.req_to_metadata_buffer_idx_allocator.alloc() ) assert req.metadata_buffer_index is not None + num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size) req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index) - bootstrapped_reqs.append(req) indices_to_remove.add(i) @@ -202,7 +249,10 @@ class PrefillBootstrapQueue: entry for i, entry in enumerate(self.queue) if i not in indices_to_remove ] - return bootstrapped_reqs + if return_failed_reqs is False: + return bootstrapped_reqs + else: + return bootstrapped_reqs, failed_reqs class SchedulerDisaggregationPrefillMixin: @@ -211,7 +261,7 @@ class SchedulerDisaggregationPrefillMixin: """ @torch.no_grad() - def event_loop_normal_disagg_prefill(self: Scheduler): + def event_loop_normal_disagg_prefill(self: Scheduler) -> None: """A normal scheduler loop for prefill worker in disaggregation mode.""" while True: @@ -229,7 +279,6 @@ class SchedulerDisaggregationPrefillMixin: or self.server_args.enable_sp_layernorm ): batch, _ = self.prepare_dp_attn_batch(batch) - self.cur_batch = batch if batch: @@ -250,7 +299,7 @@ class SchedulerDisaggregationPrefillMixin: self.running_batch.batch_is_full = False @torch.no_grad() - def event_loop_overlap_disagg_prefill(self: Scheduler): + def event_loop_overlap_disagg_prefill(self: Scheduler) -> None: self.result_queue = deque() while True: @@ -268,9 +317,7 @@ class SchedulerDisaggregationPrefillMixin: or self.server_args.enable_sp_layernorm ): batch, _ = self.prepare_dp_attn_batch(batch) - self.cur_batch = batch - if batch: result = self.run_batch(batch) self.result_queue.append((batch.copy(), result)) @@ -287,6 +334,9 @@ class SchedulerDisaggregationPrefillMixin: if self.last_batch: tmp_batch, tmp_result = self.result_queue.popleft() + tmp_batch.next_batch_sampling_info = ( + self.tp_worker.cur_sampling_info if batch else None + ) self.process_batch_result_disagg_prefill(tmp_batch, tmp_result) if len(self.disagg_prefill_inflight_queue) > 0: @@ -309,7 +359,7 @@ class SchedulerDisaggregationPrefillMixin: launch_done: Optional[threading.Event] = None, ) -> None: """ - Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue + Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue Adapted from process_batch_result_prefill """ ( @@ -325,7 +375,7 @@ class SchedulerDisaggregationPrefillMixin: ) logprob_pt = 0 - # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue + # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue if self.enable_overlap: # wait logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result( @@ -397,11 +447,15 @@ class SchedulerDisaggregationPrefillMixin: # We need to remove the sync in the following function for overlap schedule. self.set_next_batch_sampling_info_done(batch) - def process_disagg_prefill_inflight_queue(self: Scheduler) -> None: + def process_disagg_prefill_inflight_queue( + self: Scheduler, rids_to_check: Optional[List[str]] = None + ) -> List[Req]: """ Poll the requests in the middle of transfer. If done, return the request. + rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank. """ - assert len(self.disagg_prefill_inflight_queue) > 0 + if len(self.disagg_prefill_inflight_queue) == 0: + return [] done_reqs = [] @@ -413,6 +467,14 @@ class SchedulerDisaggregationPrefillMixin: undone_reqs: List[Req] = [] # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue for req, poll in zip(self.disagg_prefill_inflight_queue, polls): + + if rids_to_check is not None: + if req.rid not in rids_to_check: + undone_reqs.append(req) + continue + + assert poll == KVPoll.Success or poll == KVPoll.Failed + if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]: undone_reqs.append(req) elif poll == KVPoll.Success: # transfer done @@ -434,11 +496,8 @@ class SchedulerDisaggregationPrefillMixin: req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR ) done_reqs.append(req) - - for req in done_reqs: - self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free( - req.metadata_buffer_index - ) + else: + assert False, f"Unexpected polling state {poll=}" # Stream requests which have finished transfer self.stream_output( @@ -446,9 +505,32 @@ class SchedulerDisaggregationPrefillMixin: any(req.return_logprob for req in done_reqs), None, ) + for req in done_reqs: + req: Req + self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index) + req.metadata_buffer_index = -1 self.disagg_prefill_inflight_queue = undone_reqs + return done_reqs + + def get_transferred_rids(self: Scheduler) -> List[str]: + """ + Used by PP, get the transferred rids but **do not pop** + """ + polls = poll_and_all_reduce( + [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue], + self.tp_worker.get_tp_group().cpu_group, + ) + + transferred_rids: List[str] = [] + + for req, poll in zip(self.disagg_prefill_inflight_queue, polls): + if poll == KVPoll.Success or poll == KVPoll.Failed: + transferred_rids.append(req.rid) + + return transferred_rids + def process_prefill_chunk(self: Scheduler) -> None: if self.last_batch and self.last_batch.forward_mode.is_extend(): if self.chunked_req: diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index db7dd3239..6b52342dd 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -14,15 +14,15 @@ import requests import torch import torch.distributed as dist -from sglang.srt.utils import get_ip, get_local_ip_by_remote +from sglang.srt.utils import get_ip if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req -FakeBootstrapHost = "2.2.2.2" - -# env var for testing failure, convert to float explicitly -FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0)) +######################### +# Constants & Enums +######################### +FAKE_BOOTSTRAP_HOST = "2.2.2.2" class DisaggregationMode(Enum): @@ -31,6 +31,14 @@ class DisaggregationMode(Enum): DECODE = "decode" +######################### +# Synchronization +######################### + +# env var for testing failure, convert to float explicitly +FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0)) + + def poll_and_all_reduce(pollers, gloo_group): # at a certain prob, the poll is failed to simulate failure if FAILURE_PROB > 0: @@ -47,6 +55,11 @@ def poll_and_all_reduce(pollers, gloo_group): return tensor_to_reduce.tolist() +######################### +# Metadata Buffers +######################### + + class ReqToMetadataIdxAllocator: """A memory pool that maps a request to its first output token location.""" @@ -70,138 +83,6 @@ class ReqToMetadataIdxAllocator: self.free_slots.append(free_index) -class TransferBackend(Enum): - MOONCAKE = "mooncake" - NIXL = "nixl" - FAKE = "fake" - - -class KVClassType(Enum): - MANAGER = "manager" - SENDER = "sender" - RECEIVER = "receiver" - BOOTSTRAP_SERVER = "bootstrap_server" - - -def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): - from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender - - if transfer_backend == TransferBackend.MOONCAKE: - from sglang.srt.disaggregation.mooncake import ( - MooncakeKVBootstrapServer, - MooncakeKVManager, - MooncakeKVReceiver, - MooncakeKVSender, - ) - - class_mapping = { - KVClassType.MANAGER: MooncakeKVManager, - KVClassType.SENDER: MooncakeKVSender, - KVClassType.RECEIVER: (MooncakeKVReceiver), - KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, - } - return class_mapping.get(class_type) - if transfer_backend == TransferBackend.NIXL: - from sglang.srt.disaggregation.nixl import ( - NixlKVBootstrapServer, - NixlKVManager, - NixlKVReceiver, - NixlKVSender, - ) - - class_mapping = { - KVClassType.MANAGER: NixlKVManager, - KVClassType.SENDER: NixlKVSender, - KVClassType.RECEIVER: (NixlKVReceiver), - KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer, - } - return class_mapping.get(class_type) - if transfer_backend == TransferBackend.FAKE: - from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender - - class_mapping = { - KVClassType.SENDER: FakeKVSender, - KVClassType.RECEIVER: (FakeKVReceiver), - } - return class_mapping.get(class_type) - - raise ValueError(f"Unsupported transfer backend: {transfer_backend}") - - -def kv_to_page_indices(kv_indices: np.ndarray, page_size: int): - # 1. The page is guaranteed to be full except the last page. - # 2. page index = kv_index // page_size - # The return vector is kv_indices[::page_size] // page_size - if page_size == 1: # shortcut - return kv_indices - - return kv_indices[::page_size] // page_size - - -def kv_to_page_num(num_kv_indices: int, page_size: int): - # ceil(num_kv_indices / page_size) - return (num_kv_indices + page_size - 1) // page_size - - -@dataclasses.dataclass -class PDRegistryRequest: - """A request to register a machine itself to the LB.""" - - mode: str - registry_url: str - bootstrap_port: Optional[int] = None - - def __post_init__(self): - if self.mode == "prefill" and self.bootstrap_port is None: - raise ValueError("Bootstrap port must be set in PREFILL mode.") - elif self.mode == "decode" and self.bootstrap_port is not None: - raise ValueError("Bootstrap port must not be set in DECODE mode.") - elif self.mode not in ["prefill", "decode"]: - raise ValueError( - f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'." - ) - - -def register_disaggregation_server( - mode: str, server_port: int, bootstrap_port: int, pdlb_url: str -): - boostrap_port = bootstrap_port if mode == "prefill" else None - registry_request = PDRegistryRequest( - mode=mode, - registry_url=f"http://{get_ip()}:{server_port}", - bootstrap_port=boostrap_port, - ) - res = requests.post( - f"{pdlb_url}/register", - json=dataclasses.asdict(registry_request), - ) - if res.status_code != 200: - warnings.warn( - f"Failed to register disaggregation server: {res.status_code} {res.text}" - ) - - -def is_mla_backend(target_kv_pool) -> bool: - from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool - - return isinstance(target_kv_pool, MLATokenToKVPool) - - -def prepare_abort(req: Req, error_message: str, status_code=None): - from sglang.srt.managers.schedule_batch import FINISH_ABORT - - # populate finish metadata and stream output - req.finished_reason = FINISH_ABORT(error_message, status_code) - - if req.return_logprob: - req.input_token_logprobs_val = [] - req.input_token_logprobs_idx = [] - req.input_top_logprobs_val = [] - req.input_top_logprobs_idx = [] - req.input_token_ids_logprobs_val = [] - req.input_token_ids_logprobs_idx = [] - - class MetadataBuffers: def __init__(self, size: int, max_top_logprobs_num: int = 128): # TODO: abort top_logprobs_num > 128 in PD @@ -282,37 +163,160 @@ class MetadataBuffers: ) -class FastQueue: - def __init__(self): - self._buf = deque() - self._cond = threading.Condition() - - def put(self, item): - with self._cond: - self._buf.append(item) - # wake up a thread of wait() - self._cond.notify() - - def get(self): - with self._cond: - # if queue is empty ,block until is notified() - while not self._buf: - self._cond.wait() - return self._buf.popleft() +######################### +# Transfer Backend +######################### -def group_concurrent_contiguous( - src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] -) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: - """Vectorised NumPy implementation.""" - if src_indices.size == 0: - return [], [] +class TransferBackend(Enum): + MOONCAKE = "mooncake" + NIXL = "nixl" + FAKE = "fake" - brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 - src_groups = np.split(src_indices, brk) - dst_groups = np.split(dst_indices, brk) - src_groups = [g.tolist() for g in src_groups] - dst_groups = [g.tolist() for g in dst_groups] +class KVClassType(Enum): + KVARGS = "kvargs" + MANAGER = "manager" + SENDER = "sender" + RECEIVER = "receiver" + BOOTSTRAP_SERVER = "bootstrap_server" - return src_groups, dst_groups + +def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): + from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender + + if transfer_backend == TransferBackend.MOONCAKE: + from sglang.srt.disaggregation.base import KVArgs + from sglang.srt.disaggregation.mooncake import ( + MooncakeKVBootstrapServer, + MooncakeKVManager, + MooncakeKVReceiver, + MooncakeKVSender, + ) + + class_mapping = { + KVClassType.KVARGS: KVArgs, + KVClassType.MANAGER: MooncakeKVManager, + KVClassType.SENDER: MooncakeKVSender, + KVClassType.RECEIVER: (MooncakeKVReceiver), + KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, + } + return class_mapping.get(class_type) + if transfer_backend == TransferBackend.NIXL: + from sglang.srt.disaggregation.base import KVArgs + from sglang.srt.disaggregation.nixl import ( + NixlKVBootstrapServer, + NixlKVManager, + NixlKVReceiver, + NixlKVSender, + ) + + class_mapping = { + KVClassType.KVARGS: KVArgs, + KVClassType.MANAGER: NixlKVManager, + KVClassType.SENDER: NixlKVSender, + KVClassType.RECEIVER: (NixlKVReceiver), + KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer, + } + return class_mapping.get(class_type) + if transfer_backend == TransferBackend.FAKE: + from sglang.srt.disaggregation.base import KVArgs + from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender + + class_mapping = { + KVClassType.KVARGS: KVArgs, + KVClassType.SENDER: FakeKVSender, + KVClassType.RECEIVER: (FakeKVReceiver), + } + return class_mapping.get(class_type) + + raise ValueError(f"Unsupported transfer backend: {transfer_backend}") + + +######################### +# KV Pages +######################### + + +def kv_to_page_indices(kv_indices: np.ndarray, page_size: int): + # 1. The page is guaranteed to be full except the last page. + # 2. page index = kv_index // page_size + # The return vector is kv_indices[::page_size] // page_size + if page_size == 1: # shortcut + return kv_indices + + return kv_indices[::page_size] // page_size + + +def kv_to_page_num(num_kv_indices: int, page_size: int): + # ceil(num_kv_indices / page_size) + return (num_kv_indices + page_size - 1) // page_size + + +######################### +# PDLB Registry +######################### + + +@dataclasses.dataclass +class PDRegistryRequest: + """A request to register a machine itself to the LB.""" + + mode: str + registry_url: str + bootstrap_port: Optional[int] = None + + def __post_init__(self): + if self.mode == "prefill" and self.bootstrap_port is None: + raise ValueError("Bootstrap port must be set in PREFILL mode.") + elif self.mode == "decode" and self.bootstrap_port is not None: + raise ValueError("Bootstrap port must not be set in DECODE mode.") + elif self.mode not in ["prefill", "decode"]: + raise ValueError( + f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'." + ) + + +def register_disaggregation_server( + mode: str, server_port: int, bootstrap_port: int, pdlb_url: str +): + boostrap_port = bootstrap_port if mode == "prefill" else None + registry_request = PDRegistryRequest( + mode=mode, + registry_url=f"http://{get_ip()}:{server_port}", + bootstrap_port=boostrap_port, + ) + res = requests.post( + f"{pdlb_url}/register", + json=dataclasses.asdict(registry_request), + ) + if res.status_code != 200: + warnings.warn( + f"Failed to register disaggregation server: {res.status_code} {res.text}" + ) + + +######################### +# Misc +######################### + + +def is_mla_backend(target_kv_pool) -> bool: + from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool + + return isinstance(target_kv_pool, MLATokenToKVPool) + + +def prepare_abort(req: Req, error_message: str, status_code=None): + from sglang.srt.managers.schedule_batch import FINISH_ABORT + + # populate finish metadata and stream output + req.finished_reason = FINISH_ABORT(error_message, status_code) + + if req.return_logprob: + req.input_token_logprobs_val = [] + req.input_token_logprobs_idx = [] + req.input_top_logprobs_val = [] + req.input_top_logprobs_idx = [] + req.input_token_ids_logprobs_val = [] + req.input_token_ids_logprobs_idx = [] diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 06c673438..89417fd86 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -43,7 +43,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse from sglang.srt.disaggregation.utils import ( - FakeBootstrapHost, + FAKE_BOOTSTRAP_HOST, register_disaggregation_server, ) from sglang.srt.entrypoints.engine import _launch_subprocesses @@ -878,7 +878,7 @@ def _wait_and_warmup( "max_new_tokens": 8, "ignore_eos": True, }, - "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size, + "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size, # This is a hack to ensure fake transfer is enabled during prefill warmup # ensure each dp rank has a unique bootstrap_room during prefill warmup "bootstrap_room": [ diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 37f39096c..0b3a76667 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -619,7 +619,7 @@ class Scheduler( self.disaggregation_mode == DisaggregationMode.DECODE ): # *2 for the headroom. buffer_size = (self.req_to_token_pool.size) * 2 - req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator( + self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator( buffer_size ) self.disagg_metadata_buffers = MetadataBuffers(buffer_size) @@ -627,7 +627,7 @@ class Scheduler( # The decode requests polling kv cache self.disagg_decode_transfer_queue = DecodeTransferQueue( gloo_group=self.attn_tp_cpu_group, - req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, + req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator, metadata_buffers=self.disagg_metadata_buffers, scheduler=self, tree_cache=self.tree_cache, @@ -642,7 +642,7 @@ class Scheduler( if self.draft_worker is None else self.draft_worker.model_runner.token_to_kv_pool ), - req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, + req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator, metadata_buffers=self.disagg_metadata_buffers, scheduler=self, transfer_queue=self.disagg_decode_transfer_queue, @@ -660,7 +660,7 @@ class Scheduler( elif self.disaggregation_mode == DisaggregationMode.PREFILL: # *2 for the headroom. buffer_size = self.max_running_requests * 2 - req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator( + self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator( buffer_size ) self.disagg_metadata_buffers = MetadataBuffers(buffer_size) @@ -672,14 +672,20 @@ class Scheduler( if self.draft_worker is None else self.draft_worker.model_runner.token_to_kv_pool ), - req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, + req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator, metadata_buffers=self.disagg_metadata_buffers, tp_rank=self.tp_rank, tp_size=self.tp_size, + gpu_id=self.gpu_id, bootstrap_port=self.server_args.disaggregation_bootstrap_port, gloo_group=self.attn_tp_cpu_group, - transfer_backend=self.transfer_backend, + max_total_num_tokens=self.max_total_num_tokens, + decode_tp_size=self.server_args.disaggregation_decode_tp, + decode_dp_size=self.server_args.disaggregation_decode_dp, scheduler=self, + pp_rank=self.pp_rank, + pp_size=self.pp_size, + transfer_backend=self.transfer_backend, ) # The prefill requests that are in the middle of kv sending self.disagg_prefill_inflight_queue: List[Req] = [] @@ -1110,7 +1116,9 @@ class Scheduler( def _add_request_to_queue(self, req: Req): req.queue_time_start = time.perf_counter() if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.disagg_prefill_bootstrap_queue.add(req) + self.disagg_prefill_bootstrap_queue.add( + req, self.model_config.num_key_value_heads + ) elif self.disaggregation_mode == DisaggregationMode.DECODE: self.disagg_decode_prealloc_queue.add(req) else: @@ -1118,7 +1126,9 @@ class Scheduler( def _extend_requests_to_queue(self, reqs: List[Req]): if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.disagg_prefill_bootstrap_queue.extend(reqs) + self.disagg_prefill_bootstrap_queue.extend( + reqs, self.model_config.num_key_value_heads + ) elif self.disaggregation_mode == DisaggregationMode.DECODE: # If this is a decode server, we put the request to the decode pending prealloc queue self.disagg_decode_prealloc_queue.extend(reqs) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d0a97eb6a..04b6f96cb 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -227,6 +227,9 @@ class ServerArgs: disaggregation_mode: str = "null" disaggregation_transfer_backend: str = "mooncake" disaggregation_bootstrap_port: int = 8998 + disaggregation_decode_tp: Optional[int] = None + disaggregation_decode_dp: Optional[int] = None + disaggregation_prefill_pp: Optional[int] = 1 disaggregation_ib_device: Optional[str] = None num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD pdlb_url: Optional[str] = None @@ -505,12 +508,27 @@ class ServerArgs: self.triton_attention_num_kv_splits = 16 # PD disaggregation - if self.disaggregation_mode == "prefill": - self.disable_cuda_graph = True - logger.warning("Cuda graph is disabled for prefill server") - elif self.disaggregation_mode == "decode": + if self.disaggregation_mode == "decode": + assert ( + self.disaggregation_decode_tp is None + ), "Cannot set --disaggregation-decode-tp for the decode engine." + assert ( + self.disaggregation_decode_dp is None + ), "Cannot set --disaggregation-decode-dp for the decode engine." + self.disable_radix_cache = True logger.warning("KV cache is forced as chunk cache for decode server") + elif self.disaggregation_mode == "prefill": + if self.disaggregation_decode_tp is None: + self.disaggregation_decode_tp = self.tp_size + if self.disaggregation_decode_dp is None: + self.disaggregation_decode_dp = self.dp_size + + self.disaggregation_prefill_pp = self.pp_size + self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp) + + self.disable_cuda_graph = True + logger.warning("Cuda graph is disabled for prefill server") os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = ( "1" if self.enable_torch_compile else "0" @@ -520,6 +538,14 @@ class ServerArgs: "1" if self.disable_outlines_disk_cache else "0" ) + def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int): + larger_tp = max(decode_tp, prefill_tp) + smaller_tp = min(decode_tp, prefill_tp) + assert larger_tp % smaller_tp == 0, ( + "Different tp size is supported only when one tp is multiple of the other. " + f"decode_tp={decode_tp}, prefill_tp={prefill_tp}" + ) + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args @@ -1512,6 +1538,24 @@ class ServerArgs: default=ServerArgs.disaggregation_bootstrap_port, help="Bootstrap server port on the prefill server. Default is 8998.", ) + parser.add_argument( + "--disaggregation-decode-tp", + type=int, + default=ServerArgs.disaggregation_decode_tp, + help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.", + ) + parser.add_argument( + "--disaggregation-decode-dp", + type=int, + default=ServerArgs.disaggregation_decode_dp, + help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.", + ) + parser.add_argument( + "--disaggregation-prefill-pp", + type=int, + default=ServerArgs.disaggregation_prefill_pp, + help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.", + ) parser.add_argument( "--disaggregation-ib-device", type=str,