[PD] Add transfer backend abstraction (#5328)
This commit is contained in:
8
python/sglang/srt/disaggregation/base/__init__.py
Normal file
8
python/sglang/srt/disaggregation/base/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from .conn import (
|
||||||
|
BaseKVBootstrapServer,
|
||||||
|
BaseKVManager,
|
||||||
|
BaseKVReceiver,
|
||||||
|
BaseKVSender,
|
||||||
|
KVArgs,
|
||||||
|
KVPoll,
|
||||||
|
)
|
||||||
106
python/sglang/srt/disaggregation/base/conn.py
Normal file
106
python/sglang/srt/disaggregation/base/conn.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
|
||||||
|
|
||||||
|
class KVArgs:
|
||||||
|
engine_rank: int
|
||||||
|
kv_data_ptrs: list[int]
|
||||||
|
kv_data_lens: list[int]
|
||||||
|
kv_item_lens: list[int]
|
||||||
|
aux_data_ptrs: list[int]
|
||||||
|
aux_data_lens: list[int]
|
||||||
|
aux_item_lens: list[int]
|
||||||
|
ib_device: str
|
||||||
|
|
||||||
|
|
||||||
|
class KVPoll:
|
||||||
|
Failed = 0
|
||||||
|
Bootstrapping = 1
|
||||||
|
WaitingForInput = 2
|
||||||
|
Transferring = 3
|
||||||
|
Success = 4
|
||||||
|
|
||||||
|
|
||||||
|
class BaseKVManager(ABC):
|
||||||
|
"""Base class for managing transfers states"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): ...
|
||||||
|
|
||||||
|
|
||||||
|
class BaseKVSender(ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(
|
||||||
|
self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
|
||||||
|
): ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
||||||
|
"""
|
||||||
|
Notify the decoder server about the kv indices length and aux index
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def send(self, kv_indices: npt.NDArray[np.int64]):
|
||||||
|
"""
|
||||||
|
Send the kv cache at the given kv indices to the decoder server
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def poll(self) -> KVPoll:
|
||||||
|
"""
|
||||||
|
Check the status of the kv cache transfer
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def failure_exception(self):
|
||||||
|
"""
|
||||||
|
Raise an exception if the kv cache transfer fails
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class BaseKVReceiver(ABC):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mgr: BaseKVManager,
|
||||||
|
bootstrap_addr: str,
|
||||||
|
bootstrap_room: Optional[int] = None,
|
||||||
|
): ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||||
|
"""
|
||||||
|
Notify the prefill server about the kv indices and aux index
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def poll(self) -> KVPoll:
|
||||||
|
"""
|
||||||
|
Check the status of the kv cache transfer
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def failure_exception(self):
|
||||||
|
"""
|
||||||
|
Raise an exception if the kv cache transfer fails
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class BaseKVBootstrapServer(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, port: int): ...
|
||||||
@@ -28,10 +28,19 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
|
from sglang.srt.disaggregation.base import (
|
||||||
|
BaseKVManager,
|
||||||
|
BaseKVReceiver,
|
||||||
|
BaseKVSender,
|
||||||
|
KVArgs,
|
||||||
|
KVPoll,
|
||||||
|
)
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
|
KVClassType,
|
||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
|
TransferBackend,
|
||||||
|
get_kv_class,
|
||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
@@ -51,7 +60,7 @@ if TYPE_CHECKING:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class DecodeRequest:
|
class DecodeRequest:
|
||||||
req: Req
|
req: Req
|
||||||
kv_receiver: KVReceiver
|
kv_receiver: BaseKVReceiver
|
||||||
waiting_for_input: bool = False
|
waiting_for_input: bool = False
|
||||||
metadata_buffer_index: int = -1
|
metadata_buffer_index: int = -1
|
||||||
|
|
||||||
@@ -75,6 +84,7 @@ class DecodePreallocQueue:
|
|||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
bootstrap_port: int,
|
bootstrap_port: int,
|
||||||
|
transfer_backend: TransferBackend,
|
||||||
):
|
):
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
@@ -94,9 +104,10 @@ class DecodePreallocQueue:
|
|||||||
|
|
||||||
# Queue for requests pending pre-allocation
|
# Queue for requests pending pre-allocation
|
||||||
self.queue: List[DecodeRequest] = []
|
self.queue: List[DecodeRequest] = []
|
||||||
|
self.transfer_backend = transfer_backend
|
||||||
self.kv_manager = self._init_kv_manager()
|
self.kv_manager = self._init_kv_manager()
|
||||||
|
|
||||||
def _init_kv_manager(self) -> KVManager:
|
def _init_kv_manager(self) -> BaseKVManager:
|
||||||
kv_args = KVArgs()
|
kv_args = KVArgs()
|
||||||
kv_args.engine_rank = self.tp_rank
|
kv_args.engine_rank = self.tp_rank
|
||||||
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||||
@@ -117,13 +128,15 @@ class DecodePreallocQueue:
|
|||||||
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
||||||
]
|
]
|
||||||
kv_args.ib_device = "mock-ib-device"
|
kv_args.ib_device = "mock-ib-device"
|
||||||
kv_manager = KVManager(kv_args, DisaggregationMode("decode"))
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||||
|
kv_manager = kv_manager_class(kv_args, DisaggregationMode.DECODE)
|
||||||
return kv_manager
|
return kv_manager
|
||||||
|
|
||||||
def add(self, req: Req) -> None:
|
def add(self, req: Req) -> None:
|
||||||
"""Add a request to the pending queue."""
|
"""Add a request to the pending queue."""
|
||||||
|
|
||||||
kv_receiver = KVReceiver(
|
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
|
||||||
|
kv_receiver = kv_receiver_class(
|
||||||
mgr=self.kv_manager,
|
mgr=self.kv_manager,
|
||||||
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
||||||
bootstrap_room=req.bootstrap_room,
|
bootstrap_room=req.bootstrap_room,
|
||||||
|
|||||||
6
python/sglang/srt/disaggregation/mooncake/__init__.py
Normal file
6
python/sglang/srt/disaggregation/mooncake/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from .conn import (
|
||||||
|
MooncakeKVBootstrapServer,
|
||||||
|
MooncakeKVManager,
|
||||||
|
MooncakeKVReceiver,
|
||||||
|
MooncakeKVSender,
|
||||||
|
)
|
||||||
@@ -12,7 +12,15 @@ import numpy.typing as npt
|
|||||||
import zmq
|
import zmq
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from sglang.srt.disaggregation.transfer_engine.mooncake import MooncakeTransferEngine
|
from sglang.srt.disaggregation.base.conn import (
|
||||||
|
BaseKVBootstrapServer,
|
||||||
|
BaseKVManager,
|
||||||
|
BaseKVReceiver,
|
||||||
|
BaseKVSender,
|
||||||
|
KVArgs,
|
||||||
|
KVPoll,
|
||||||
|
)
|
||||||
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -44,25 +52,6 @@ def group_concurrent_contiguous(
|
|||||||
return src_groups, dst_groups
|
return src_groups, dst_groups
|
||||||
|
|
||||||
|
|
||||||
class KVArgs:
|
|
||||||
engine_rank: int
|
|
||||||
kv_data_ptrs: list[int]
|
|
||||||
kv_data_lens: list[int]
|
|
||||||
kv_item_lens: list[int]
|
|
||||||
aux_data_ptrs: list[int]
|
|
||||||
aux_data_lens: list[int]
|
|
||||||
aux_item_lens: list[int]
|
|
||||||
ib_device: str
|
|
||||||
|
|
||||||
|
|
||||||
class KVPoll:
|
|
||||||
Failed = 0
|
|
||||||
Bootstrapping = 1
|
|
||||||
WaitingForInput = 2
|
|
||||||
Transferring = 3
|
|
||||||
Success = 4
|
|
||||||
|
|
||||||
|
|
||||||
RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int64], Optional[int]]]
|
RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int64], Optional[int]]]
|
||||||
WaitingPoolType = Dict[
|
WaitingPoolType = Dict[
|
||||||
int, Tuple[str, list[int], npt.NDArray[np.int64], list[int], int]
|
int, Tuple[str, list[int], npt.NDArray[np.int64], list[int], int]
|
||||||
@@ -71,8 +60,7 @@ KVSENDER_POLLING_PORT = 17788
|
|||||||
KVRECEIVER_POLLING_PORT = 27788
|
KVRECEIVER_POLLING_PORT = 27788
|
||||||
|
|
||||||
|
|
||||||
class KVManager:
|
class MooncakeKVManager(BaseKVManager):
|
||||||
# TODO: make it general and support multiple transfer backend before merging
|
|
||||||
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode):
|
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode):
|
||||||
self.engine = MooncakeTransferEngine()
|
self.engine = MooncakeTransferEngine()
|
||||||
self.kv_args = args
|
self.kv_args = args
|
||||||
@@ -331,9 +319,11 @@ class KVManager:
|
|||||||
return self.engine.get_session_id()
|
return self.engine.get_session_id()
|
||||||
|
|
||||||
|
|
||||||
class KVSender:
|
class MooncakeKVSender(BaseKVSender):
|
||||||
|
|
||||||
def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
|
def __init__(
|
||||||
|
self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
|
||||||
|
):
|
||||||
self.kv_mgr = mgr
|
self.kv_mgr = mgr
|
||||||
self.bootstrap_room = bootstrap_room
|
self.bootstrap_room = bootstrap_room
|
||||||
self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput)
|
self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput)
|
||||||
@@ -353,10 +343,13 @@ class KVSender:
|
|||||||
raise Exception("Fake KVSender Exception")
|
raise Exception("Fake KVSender Exception")
|
||||||
|
|
||||||
|
|
||||||
class KVReceiver:
|
class MooncakeKVReceiver(BaseKVReceiver):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
|
self,
|
||||||
|
mgr: MooncakeKVManager,
|
||||||
|
bootstrap_addr: str,
|
||||||
|
bootstrap_room: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.bootstrap_room = bootstrap_room
|
self.bootstrap_room = bootstrap_room
|
||||||
self.bootstrap_addr = bootstrap_addr
|
self.bootstrap_addr = bootstrap_addr
|
||||||
@@ -403,7 +396,7 @@ class KVReceiver:
|
|||||||
raise Exception("Fake KVReceiver Exception")
|
raise Exception("Fake KVReceiver Exception")
|
||||||
|
|
||||||
|
|
||||||
class KVBootstrapServer:
|
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
||||||
def __init__(self, port: int):
|
def __init__(self, port: int):
|
||||||
self.port = port
|
self.port = port
|
||||||
self.app = web.Application()
|
self.app = web.Application()
|
||||||
@@ -24,10 +24,19 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
|
from sglang.srt.disaggregation.base import (
|
||||||
|
BaseKVManager,
|
||||||
|
BaseKVReceiver,
|
||||||
|
BaseKVSender,
|
||||||
|
KVArgs,
|
||||||
|
KVPoll,
|
||||||
|
)
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
|
KVClassType,
|
||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
|
TransferBackend,
|
||||||
|
get_kv_class,
|
||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||||
@@ -38,6 +47,7 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
|
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
|
||||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -56,6 +66,7 @@ class PrefillBootstrapQueue:
|
|||||||
tp_size: int,
|
tp_size: int,
|
||||||
bootstrap_port: int,
|
bootstrap_port: int,
|
||||||
gloo_group: ProcessGroup,
|
gloo_group: ProcessGroup,
|
||||||
|
transfer_backend: TransferBackend,
|
||||||
):
|
):
|
||||||
self.token_to_kv_pool = token_to_kv_pool
|
self.token_to_kv_pool = token_to_kv_pool
|
||||||
self.aux_dtype = aux_dtype
|
self.aux_dtype = aux_dtype
|
||||||
@@ -64,6 +75,7 @@ class PrefillBootstrapQueue:
|
|||||||
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
|
self.transfer_backend = transfer_backend
|
||||||
self.kv_manager = self._init_kv_manager()
|
self.kv_manager = self._init_kv_manager()
|
||||||
self.queue: List[Req] = []
|
self.queue: List[Req] = []
|
||||||
self.gloo_group = gloo_group
|
self.gloo_group = gloo_group
|
||||||
@@ -74,7 +86,7 @@ class PrefillBootstrapQueue:
|
|||||||
output_id_buffer = self.metadata_buffers[0]
|
output_id_buffer = self.metadata_buffers[0]
|
||||||
output_id_buffer[idx] = token_id
|
output_id_buffer[idx] = token_id
|
||||||
|
|
||||||
def _init_kv_manager(self) -> KVManager:
|
def _init_kv_manager(self) -> BaseKVManager:
|
||||||
kv_args = KVArgs()
|
kv_args = KVArgs()
|
||||||
kv_args.engine_rank = self.tp_rank
|
kv_args.engine_rank = self.tp_rank
|
||||||
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||||
@@ -96,11 +108,13 @@ class PrefillBootstrapQueue:
|
|||||||
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
||||||
]
|
]
|
||||||
kv_args.ib_device = "mock-ib-device"
|
kv_args.ib_device = "mock-ib-device"
|
||||||
kv_manager = KVManager(kv_args, DisaggregationMode("prefill"))
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||||
|
kv_manager = kv_manager_class(kv_args, DisaggregationMode.PREFILL)
|
||||||
return kv_manager
|
return kv_manager
|
||||||
|
|
||||||
def add(self, req: Req) -> None:
|
def add(self, req: Req) -> None:
|
||||||
req.disagg_kv_sender = KVSender(
|
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
|
||||||
|
req.disagg_kv_sender = kv_sender_class(
|
||||||
mgr=self.kv_manager,
|
mgr=self.kv_manager,
|
||||||
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
||||||
bootstrap_room=req.bootstrap_room,
|
bootstrap_room=req.bootstrap_room,
|
||||||
|
|||||||
@@ -42,3 +42,34 @@ class ReqToMetadataIdxAllocator:
|
|||||||
|
|
||||||
def free(self, free_index: int):
|
def free(self, free_index: int):
|
||||||
self.free_slots.append(free_index)
|
self.free_slots.append(free_index)
|
||||||
|
|
||||||
|
|
||||||
|
class TransferBackend(Enum):
|
||||||
|
MOONCAKE = "mooncake"
|
||||||
|
FAKE = "fake"
|
||||||
|
|
||||||
|
|
||||||
|
class KVClassType(Enum):
|
||||||
|
MANAGER = "manager"
|
||||||
|
SENDER = "sender"
|
||||||
|
RECEIVER = "receiver"
|
||||||
|
BOOTSTRAP_SERVER = "bootstrap_server"
|
||||||
|
|
||||||
|
|
||||||
|
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||||
|
if transfer_backend == TransferBackend.MOONCAKE:
|
||||||
|
from sglang.srt.disaggregation.mooncake import (
|
||||||
|
MooncakeKVBootstrapServer,
|
||||||
|
MooncakeKVManager,
|
||||||
|
MooncakeKVReceiver,
|
||||||
|
MooncakeKVSender,
|
||||||
|
)
|
||||||
|
|
||||||
|
class_mapping = {
|
||||||
|
KVClassType.MANAGER: MooncakeKVManager,
|
||||||
|
KVClassType.SENDER: MooncakeKVSender,
|
||||||
|
KVClassType.RECEIVER: MooncakeKVReceiver,
|
||||||
|
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
||||||
|
}
|
||||||
|
return class_mapping.get(class_type)
|
||||||
|
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ import triton.language as tl
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||||
from sglang.srt.disaggregation.conn import KVSender
|
from sglang.srt.disaggregation.base import BaseKVSender
|
||||||
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
|
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
@@ -525,7 +525,7 @@ class Req:
|
|||||||
# For disaggregation
|
# For disaggregation
|
||||||
self.bootstrap_host: str = bootstrap_host
|
self.bootstrap_host: str = bootstrap_host
|
||||||
self.bootstrap_room: Optional[int] = bootstrap_room
|
self.bootstrap_room: Optional[int] = bootstrap_room
|
||||||
self.disagg_kv_sender: Optional[KVSender] = None
|
self.disagg_kv_sender: Optional[BaseKVSender] = None
|
||||||
|
|
||||||
# used for warmup because we don't have a pair yet when init
|
# used for warmup because we don't have a pair yet when init
|
||||||
self.skip_kv_transfer: bool = False
|
self.skip_kv_transfer: bool = False
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.prefill import (
|
|||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
|
TransferBackend,
|
||||||
)
|
)
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||||
@@ -530,6 +531,10 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_disaggregation(self):
|
def init_disaggregation(self):
|
||||||
|
self.transfer_backend = TransferBackend(
|
||||||
|
self.server_args.disaggregation_transfer_backend
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.disaggregation_mode == DisaggregationMode.DECODE
|
self.disaggregation_mode == DisaggregationMode.DECODE
|
||||||
): # *2 for the headroom.
|
): # *2 for the headroom.
|
||||||
@@ -567,6 +572,7 @@ class Scheduler(
|
|||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||||
|
transfer_backend=self.transfer_backend,
|
||||||
)
|
)
|
||||||
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
|
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
# *2 for the headroom.
|
# *2 for the headroom.
|
||||||
@@ -592,6 +598,7 @@ class Scheduler(
|
|||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||||
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
|
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
|
||||||
|
transfer_backend=self.transfer_backend,
|
||||||
)
|
)
|
||||||
# The prefill requests that are in the middle of kv sending
|
# The prefill requests that are in the middle of kv sending
|
||||||
self.disagg_prefill_inflight_queue: List[Req] = []
|
self.disagg_prefill_inflight_queue: List[Req] = []
|
||||||
|
|||||||
@@ -48,8 +48,12 @@ from fastapi import BackgroundTasks
|
|||||||
|
|
||||||
from sglang.srt.aio_rwlock import RWLock
|
from sglang.srt.aio_rwlock import RWLock
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.disaggregation.conn import KVBootstrapServer
|
from sglang.srt.disaggregation.utils import (
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
DisaggregationMode,
|
||||||
|
KVClassType,
|
||||||
|
TransferBackend,
|
||||||
|
get_kv_class,
|
||||||
|
)
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
@@ -329,10 +333,16 @@ class TokenizerManager:
|
|||||||
self.disaggregation_mode = DisaggregationMode(
|
self.disaggregation_mode = DisaggregationMode(
|
||||||
self.server_args.disaggregation_mode
|
self.server_args.disaggregation_mode
|
||||||
)
|
)
|
||||||
|
self.transfer_backend = TransferBackend(
|
||||||
|
self.server_args.disaggregation_transfer_backend
|
||||||
|
)
|
||||||
# for disaggregtion, start kv boostrap server on prefill
|
# for disaggregtion, start kv boostrap server on prefill
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
# only start bootstrap server on prefill tm
|
# only start bootstrap server on prefill tm
|
||||||
self.bootstrap_server = KVBootstrapServer(
|
kv_bootstrap_server_class = get_kv_class(
|
||||||
|
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||||
|
)
|
||||||
|
self.bootstrap_server = kv_bootstrap_server_class(
|
||||||
self.server_args.disaggregation_bootstrap_port
|
self.server_args.disaggregation_bootstrap_port
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -195,6 +195,7 @@ class ServerArgs:
|
|||||||
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
|
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
|
||||||
disaggregation_mode: str = "null"
|
disaggregation_mode: str = "null"
|
||||||
disaggregation_bootstrap_port: int = 8998
|
disaggregation_bootstrap_port: int = 8998
|
||||||
|
disaggregation_transfer_backend: str = "mooncake"
|
||||||
|
|
||||||
# multimodal
|
# multimodal
|
||||||
disable_fast_image_processor: bool = False
|
disable_fast_image_processor: bool = False
|
||||||
@@ -1173,6 +1174,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.disaggregation_bootstrap_port,
|
default=ServerArgs.disaggregation_bootstrap_port,
|
||||||
help="Bootstrap server port on the prefill server. Default is 8998.",
|
help="Bootstrap server port on the prefill server. Default is 8998.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-transfer-backend",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.disaggregation_transfer_backend,
|
||||||
|
help="The backend for disaggregation transfer. Default is mooncake.",
|
||||||
|
)
|
||||||
|
|
||||||
# Multimodal
|
# Multimodal
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user