From a9499885e97e77b8b0a57642423c6e2c1a6fcaa8 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Sun, 13 Apr 2025 10:39:39 -0700 Subject: [PATCH] [PD] Add transfer backend abstraction (#5328) --- .../srt/disaggregation/base/__init__.py | 8 ++ python/sglang/srt/disaggregation/base/conn.py | 106 ++++++++++++++++++ python/sglang/srt/disaggregation/decode.py | 23 +++- .../srt/disaggregation/mooncake/__init__.py | 6 + .../srt/disaggregation/{ => mooncake}/conn.py | 47 ++++---- .../transfer_engine.py} | 0 python/sglang/srt/disaggregation/prefill.py | 22 +++- python/sglang/srt/disaggregation/utils.py | 31 +++++ python/sglang/srt/managers/schedule_batch.py | 4 +- python/sglang/srt/managers/scheduler.py | 7 ++ .../sglang/srt/managers/tokenizer_manager.py | 16 ++- python/sglang/srt/server_args.py | 7 ++ 12 files changed, 236 insertions(+), 41 deletions(-) create mode 100644 python/sglang/srt/disaggregation/base/__init__.py create mode 100644 python/sglang/srt/disaggregation/base/conn.py create mode 100644 python/sglang/srt/disaggregation/mooncake/__init__.py rename python/sglang/srt/disaggregation/{ => mooncake}/conn.py (95%) rename python/sglang/srt/disaggregation/{transfer_engine/mooncake.py => mooncake/transfer_engine.py} (100%) diff --git a/python/sglang/srt/disaggregation/base/__init__.py b/python/sglang/srt/disaggregation/base/__init__.py new file mode 100644 index 000000000..bfeecfe1c --- /dev/null +++ b/python/sglang/srt/disaggregation/base/__init__.py @@ -0,0 +1,8 @@ +from .conn import ( + BaseKVBootstrapServer, + BaseKVManager, + BaseKVReceiver, + BaseKVSender, + KVArgs, + KVPoll, +) diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py new file mode 100644 index 000000000..bdf5f5027 --- /dev/null +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -0,0 +1,106 @@ +from abc import ABC, abstractmethod +from typing import Optional + +import numpy as np +import numpy.typing as npt + +from sglang.srt.disaggregation.utils import DisaggregationMode + + +class KVArgs: + engine_rank: int + kv_data_ptrs: list[int] + kv_data_lens: list[int] + kv_item_lens: list[int] + aux_data_ptrs: list[int] + aux_data_lens: list[int] + aux_item_lens: list[int] + ib_device: str + + +class KVPoll: + Failed = 0 + Bootstrapping = 1 + WaitingForInput = 2 + Transferring = 3 + Success = 4 + + +class BaseKVManager(ABC): + """Base class for managing transfers states""" + + @abstractmethod + def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): ... + + +class BaseKVSender(ABC): + + @abstractmethod + def __init__( + self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int + ): ... + + @abstractmethod + def init(self, num_kv_indices: int, aux_index: Optional[int] = None): + """ + Notify the decoder server about the kv indices length and aux index + """ + ... + + @abstractmethod + def send(self, kv_indices: npt.NDArray[np.int64]): + """ + Send the kv cache at the given kv indices to the decoder server + """ + ... + + @abstractmethod + def poll(self) -> KVPoll: + """ + Check the status of the kv cache transfer + """ + ... + + @abstractmethod + def failure_exception(self): + """ + Raise an exception if the kv cache transfer fails + """ + ... + + +class BaseKVReceiver(ABC): + + @abstractmethod + def __init__( + self, + mgr: BaseKVManager, + bootstrap_addr: str, + bootstrap_room: Optional[int] = None, + ): ... + + @abstractmethod + def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): + """ + Notify the prefill server about the kv indices and aux index + """ + ... + + @abstractmethod + def poll(self) -> KVPoll: + """ + Check the status of the kv cache transfer + """ + ... + + @abstractmethod + def failure_exception(self): + """ + Raise an exception if the kv cache transfer fails + """ + ... + + +class BaseKVBootstrapServer(ABC): + @abstractmethod + def __init__(self, port: int): ... diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index e0918a083..1f4f9cfa7 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -28,10 +28,19 @@ import numpy as np import torch from torch.distributed import ProcessGroup -from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver +from sglang.srt.disaggregation.base import ( + BaseKVManager, + BaseKVReceiver, + BaseKVSender, + KVArgs, + KVPoll, +) from sglang.srt.disaggregation.utils import ( DisaggregationMode, + KVClassType, ReqToMetadataIdxAllocator, + TransferBackend, + get_kv_class, poll_and_all_reduce, ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache @@ -51,7 +60,7 @@ if TYPE_CHECKING: @dataclass class DecodeRequest: req: Req - kv_receiver: KVReceiver + kv_receiver: BaseKVReceiver waiting_for_input: bool = False metadata_buffer_index: int = -1 @@ -75,6 +84,7 @@ class DecodePreallocQueue: tp_rank: int, tp_size: int, bootstrap_port: int, + transfer_backend: TransferBackend, ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator @@ -94,9 +104,10 @@ class DecodePreallocQueue: # Queue for requests pending pre-allocation self.queue: List[DecodeRequest] = [] + self.transfer_backend = transfer_backend self.kv_manager = self._init_kv_manager() - def _init_kv_manager(self) -> KVManager: + def _init_kv_manager(self) -> BaseKVManager: kv_args = KVArgs() kv_args.engine_rank = self.tp_rank kv_data_ptrs, kv_data_lens, kv_item_lens = ( @@ -117,13 +128,15 @@ class DecodePreallocQueue: metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers ] kv_args.ib_device = "mock-ib-device" - kv_manager = KVManager(kv_args, DisaggregationMode("decode")) + kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) + kv_manager = kv_manager_class(kv_args, DisaggregationMode.DECODE) return kv_manager def add(self, req: Req) -> None: """Add a request to the pending queue.""" - kv_receiver = KVReceiver( + kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER) + kv_receiver = kv_receiver_class( mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", bootstrap_room=req.bootstrap_room, diff --git a/python/sglang/srt/disaggregation/mooncake/__init__.py b/python/sglang/srt/disaggregation/mooncake/__init__.py new file mode 100644 index 000000000..035097f68 --- /dev/null +++ b/python/sglang/srt/disaggregation/mooncake/__init__.py @@ -0,0 +1,6 @@ +from .conn import ( + MooncakeKVBootstrapServer, + MooncakeKVManager, + MooncakeKVReceiver, + MooncakeKVSender, +) diff --git a/python/sglang/srt/disaggregation/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py similarity index 95% rename from python/sglang/srt/disaggregation/conn.py rename to python/sglang/srt/disaggregation/mooncake/conn.py index f61add327..c1bdb7c74 100644 --- a/python/sglang/srt/disaggregation/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -12,7 +12,15 @@ import numpy.typing as npt import zmq from aiohttp import web -from sglang.srt.disaggregation.transfer_engine.mooncake import MooncakeTransferEngine +from sglang.srt.disaggregation.base.conn import ( + BaseKVBootstrapServer, + BaseKVManager, + BaseKVReceiver, + BaseKVSender, + KVArgs, + KVPoll, +) +from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.utils import DisaggregationMode logger = logging.getLogger(__name__) @@ -44,25 +52,6 @@ def group_concurrent_contiguous( return src_groups, dst_groups -class KVArgs: - engine_rank: int - kv_data_ptrs: list[int] - kv_data_lens: list[int] - kv_item_lens: list[int] - aux_data_ptrs: list[int] - aux_data_lens: list[int] - aux_item_lens: list[int] - ib_device: str - - -class KVPoll: - Failed = 0 - Bootstrapping = 1 - WaitingForInput = 2 - Transferring = 3 - Success = 4 - - RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int64], Optional[int]]] WaitingPoolType = Dict[ int, Tuple[str, list[int], npt.NDArray[np.int64], list[int], int] @@ -71,8 +60,7 @@ KVSENDER_POLLING_PORT = 17788 KVRECEIVER_POLLING_PORT = 27788 -class KVManager: - # TODO: make it general and support multiple transfer backend before merging +class MooncakeKVManager(BaseKVManager): def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): self.engine = MooncakeTransferEngine() self.kv_args = args @@ -331,9 +319,11 @@ class KVManager: return self.engine.get_session_id() -class KVSender: +class MooncakeKVSender(BaseKVSender): - def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int): + def __init__( + self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int + ): self.kv_mgr = mgr self.bootstrap_room = bootstrap_room self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput) @@ -353,10 +343,13 @@ class KVSender: raise Exception("Fake KVSender Exception") -class KVReceiver: +class MooncakeKVReceiver(BaseKVReceiver): def __init__( - self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None + self, + mgr: MooncakeKVManager, + bootstrap_addr: str, + bootstrap_room: Optional[int] = None, ): self.bootstrap_room = bootstrap_room self.bootstrap_addr = bootstrap_addr @@ -403,7 +396,7 @@ class KVReceiver: raise Exception("Fake KVReceiver Exception") -class KVBootstrapServer: +class MooncakeKVBootstrapServer(BaseKVBootstrapServer): def __init__(self, port: int): self.port = port self.app = web.Application() diff --git a/python/sglang/srt/disaggregation/transfer_engine/mooncake.py b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py similarity index 100% rename from python/sglang/srt/disaggregation/transfer_engine/mooncake.py rename to python/sglang/srt/disaggregation/mooncake/transfer_engine.py diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 865955be5..c6020767b 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -24,10 +24,19 @@ from typing import TYPE_CHECKING, List, Optional import torch -from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender +from sglang.srt.disaggregation.base import ( + BaseKVManager, + BaseKVReceiver, + BaseKVSender, + KVArgs, + KVPoll, +) from sglang.srt.disaggregation.utils import ( DisaggregationMode, + KVClassType, ReqToMetadataIdxAllocator, + TransferBackend, + get_kv_class, poll_and_all_reduce, ) from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch @@ -38,6 +47,7 @@ if TYPE_CHECKING: from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler from sglang.srt.mem_cache.memory_pool import KVCache + logger = logging.getLogger(__name__) @@ -56,6 +66,7 @@ class PrefillBootstrapQueue: tp_size: int, bootstrap_port: int, gloo_group: ProcessGroup, + transfer_backend: TransferBackend, ): self.token_to_kv_pool = token_to_kv_pool self.aux_dtype = aux_dtype @@ -64,6 +75,7 @@ class PrefillBootstrapQueue: self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator self.tp_rank = tp_rank self.tp_size = tp_size + self.transfer_backend = transfer_backend self.kv_manager = self._init_kv_manager() self.queue: List[Req] = [] self.gloo_group = gloo_group @@ -74,7 +86,7 @@ class PrefillBootstrapQueue: output_id_buffer = self.metadata_buffers[0] output_id_buffer[idx] = token_id - def _init_kv_manager(self) -> KVManager: + def _init_kv_manager(self) -> BaseKVManager: kv_args = KVArgs() kv_args.engine_rank = self.tp_rank kv_data_ptrs, kv_data_lens, kv_item_lens = ( @@ -96,11 +108,13 @@ class PrefillBootstrapQueue: metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers ] kv_args.ib_device = "mock-ib-device" - kv_manager = KVManager(kv_args, DisaggregationMode("prefill")) + kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) + kv_manager = kv_manager_class(kv_args, DisaggregationMode.PREFILL) return kv_manager def add(self, req: Req) -> None: - req.disagg_kv_sender = KVSender( + kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER) + req.disagg_kv_sender = kv_sender_class( mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", bootstrap_room=req.bootstrap_room, diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 76da71a00..54d344416 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -42,3 +42,34 @@ class ReqToMetadataIdxAllocator: def free(self, free_index: int): self.free_slots.append(free_index) + + +class TransferBackend(Enum): + MOONCAKE = "mooncake" + FAKE = "fake" + + +class KVClassType(Enum): + MANAGER = "manager" + SENDER = "sender" + RECEIVER = "receiver" + BOOTSTRAP_SERVER = "bootstrap_server" + + +def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): + if transfer_backend == TransferBackend.MOONCAKE: + from sglang.srt.disaggregation.mooncake import ( + MooncakeKVBootstrapServer, + MooncakeKVManager, + MooncakeKVReceiver, + MooncakeKVSender, + ) + + class_mapping = { + KVClassType.MANAGER: MooncakeKVManager, + KVClassType.SENDER: MooncakeKVSender, + KVClassType.RECEIVER: MooncakeKVReceiver, + KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, + } + return class_mapping.get(class_type) + raise ValueError(f"Unsupported transfer backend: {transfer_backend}") diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 020dbcf4b..fa9f40112 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -45,7 +45,7 @@ import triton.language as tl from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject -from sglang.srt.disaggregation.conn import KVSender +from sglang.srt.disaggregation.base import BaseKVSender from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache @@ -525,7 +525,7 @@ class Req: # For disaggregation self.bootstrap_host: str = bootstrap_host self.bootstrap_room: Optional[int] = bootstrap_room - self.disagg_kv_sender: Optional[KVSender] = None + self.disagg_kv_sender: Optional[BaseKVSender] = None # used for warmup because we don't have a pair yet when init self.skip_kv_transfer: bool = False diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ae469ae2a..156146e83 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -49,6 +49,7 @@ from sglang.srt.disaggregation.prefill import ( from sglang.srt.disaggregation.utils import ( DisaggregationMode, ReqToMetadataIdxAllocator, + TransferBackend, ) from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info @@ -530,6 +531,10 @@ class Scheduler( ) def init_disaggregation(self): + self.transfer_backend = TransferBackend( + self.server_args.disaggregation_transfer_backend + ) + if ( self.disaggregation_mode == DisaggregationMode.DECODE ): # *2 for the headroom. @@ -567,6 +572,7 @@ class Scheduler( tp_rank=self.tp_rank, tp_size=self.tp_size, bootstrap_port=self.server_args.disaggregation_bootstrap_port, + transfer_backend=self.transfer_backend, ) elif self.disaggregation_mode == DisaggregationMode.PREFILL: # *2 for the headroom. @@ -592,6 +598,7 @@ class Scheduler( tp_size=self.tp_size, bootstrap_port=self.server_args.disaggregation_bootstrap_port, gloo_group=self.tp_worker.get_attention_tp_cpu_group(), + transfer_backend=self.transfer_backend, ) # The prefill requests that are in the middle of kv sending self.disagg_prefill_inflight_queue: List[Req] = [] diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 69df67058..1acd97f5b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -48,8 +48,12 @@ from fastapi import BackgroundTasks from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.disaggregation.conn import KVBootstrapServer -from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.disaggregation.utils import ( + DisaggregationMode, + KVClassType, + TransferBackend, + get_kv_class, +) from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.io_struct import ( AbortReq, @@ -329,10 +333,16 @@ class TokenizerManager: self.disaggregation_mode = DisaggregationMode( self.server_args.disaggregation_mode ) + self.transfer_backend = TransferBackend( + self.server_args.disaggregation_transfer_backend + ) # for disaggregtion, start kv boostrap server on prefill if self.disaggregation_mode == DisaggregationMode.PREFILL: # only start bootstrap server on prefill tm - self.bootstrap_server = KVBootstrapServer( + kv_bootstrap_server_class = get_kv_class( + self.transfer_backend, KVClassType.BOOTSTRAP_SERVER + ) + self.bootstrap_server = kv_bootstrap_server_class( self.server_args.disaggregation_bootstrap_port ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 59b744c14..6d78b654a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -195,6 +195,7 @@ class ServerArgs: # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only) disaggregation_mode: str = "null" disaggregation_bootstrap_port: int = 8998 + disaggregation_transfer_backend: str = "mooncake" # multimodal disable_fast_image_processor: bool = False @@ -1173,6 +1174,12 @@ class ServerArgs: default=ServerArgs.disaggregation_bootstrap_port, help="Bootstrap server port on the prefill server. Default is 8998.", ) + parser.add_argument( + "--disaggregation-transfer-backend", + type=str, + default=ServerArgs.disaggregation_transfer_backend, + help="The backend for disaggregation transfer. Default is mooncake.", + ) # Multimodal parser.add_argument(