[PD] Update prefill.py (#7190)
This commit is contained in:
@@ -1,23 +1,32 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
|
||||||
|
|
||||||
class KVArgs:
|
class KVArgs:
|
||||||
engine_rank: int
|
engine_rank: int
|
||||||
kv_data_ptrs: list[int]
|
kv_data_ptrs: List[int]
|
||||||
kv_data_lens: list[int]
|
kv_data_lens: List[int]
|
||||||
kv_item_lens: list[int]
|
kv_item_lens: List[int]
|
||||||
aux_data_ptrs: list[int]
|
aux_data_ptrs: List[int]
|
||||||
aux_data_lens: list[int]
|
aux_data_lens: List[int]
|
||||||
aux_item_lens: list[int]
|
aux_item_lens: List[int]
|
||||||
ib_device: str
|
ib_device: str
|
||||||
|
ib_traffic_class: str
|
||||||
gpu_id: int
|
gpu_id: int
|
||||||
|
# for different tp
|
||||||
|
decode_tp_size: int
|
||||||
|
# for pp prefill
|
||||||
|
prefill_pp_size: int
|
||||||
|
|
||||||
|
|
||||||
class KVPoll:
|
class KVPoll:
|
||||||
@@ -45,7 +54,12 @@ class BaseKVSender(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(
|
def __init__(
|
||||||
self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
|
self,
|
||||||
|
mgr: BaseKVManager,
|
||||||
|
bootstrap_addr: str,
|
||||||
|
bootstrap_room: int,
|
||||||
|
dest_tp_ranks: List[int],
|
||||||
|
pp_rank: int,
|
||||||
): ...
|
): ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
42
python/sglang/srt/disaggregation/common/utils.py
Normal file
42
python/sglang/srt/disaggregation/common/utils.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import threading
|
||||||
|
from collections import deque
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
|
|
||||||
|
class FastQueue:
    """Minimal unbounded FIFO queue for passing items between threads.

    Items are kept in a ``collections.deque`` and every access is guarded by
    a single ``threading.Condition``. ``get`` blocks while the queue is empty.
    """

    def __init__(self):
        self._buf = deque()
        self._cond = threading.Condition()

    def put(self, item):
        """Append ``item`` to the tail and wake one thread blocked in ``get``."""
        cond = self._cond
        with cond:
            self._buf.append(item)
            # Only one item was added, so waking a single waiter is enough.
            cond.notify()

    def get(self):
        """Remove and return the oldest item, blocking until one exists."""
        cond, buf = self._cond, self._buf
        with cond:
            # wait_for re-checks the predicate after every wake-up, so a
            # waiter never pops from a deque another consumer just emptied.
            cond.wait_for(lambda: len(buf) > 0)
            return buf.popleft()
|
||||||
|
|
||||||
|
|
||||||
|
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split paired index arrays into runs that are contiguous on both sides.

    Position ``i`` of ``src_indices`` is paired with position ``i`` of
    ``dst_indices``. A new group starts at every position where either array
    does not increase by exactly 1 relative to the previous position, so each
    returned group is a run of consecutive integers in both source and
    destination.

    Args:
        src_indices: Source indices.
        dst_indices: Destination indices paired element-wise with
            ``src_indices``. Presumably both are 1-D and of equal length;
            this is not validated here.

    Returns:
        ``(src_groups, dst_groups)``: two equally long lists whose elements
        are plain Python lists of ints (the NumPy chunks are converted with
        ``tolist()``). Both are empty when ``src_indices`` is empty.
    """
    if src_indices.size == 0:
        return [], []

    # A break occurs after position i when either side's step differs from 1.
    src_breaks = np.diff(src_indices) != 1
    dst_breaks = np.diff(dst_indices) != 1
    # +1 converts "index of the diff" into "index where the next group starts".
    brk = np.where(src_breaks | dst_breaks)[0] + 1

    src_groups = [g.tolist() for g in np.split(src_indices, brk)]
    dst_groups = [g.tolist() for g in np.split(dst_indices, brk)]

    return src_groups, dst_groups
|
||||||
@@ -33,8 +33,8 @@ from torch.distributed import ProcessGroup
|
|||||||
|
|
||||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
|
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
|
FAKE_BOOTSTRAP_HOST,
|
||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
FakeBootstrapHost,
|
|
||||||
KVClassType,
|
KVClassType,
|
||||||
MetadataBuffers,
|
MetadataBuffers,
|
||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
@@ -207,7 +207,7 @@ class DecodePreallocQueue:
|
|||||||
|
|
||||||
def add(self, req: Req) -> None:
|
def add(self, req: Req) -> None:
|
||||||
"""Add a request to the pending queue."""
|
"""Add a request to the pending queue."""
|
||||||
if req.bootstrap_host == FakeBootstrapHost:
|
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
|
||||||
# Fake transfer for warmup reqs
|
# Fake transfer for warmup reqs
|
||||||
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
|
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -17,7 +17,14 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
|
# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
|
||||||
class FakeKVSender(BaseKVSender):
|
class FakeKVSender(BaseKVSender):
|
||||||
def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
|
def __init__(
|
||||||
|
self,
|
||||||
|
mgr: BaseKVManager,
|
||||||
|
bootstrap_addr: str,
|
||||||
|
bootstrap_room: int,
|
||||||
|
dest_tp_ranks: List[int],
|
||||||
|
pp_rank: int,
|
||||||
|
):
|
||||||
self.has_sent = False
|
self.has_sent = False
|
||||||
|
|
||||||
def poll(self) -> KVPoll:
|
def poll(self) -> KVPoll:
|
||||||
|
|||||||
@@ -28,12 +28,12 @@ from sglang.srt.disaggregation.base.conn import (
|
|||||||
KVArgs,
|
KVArgs,
|
||||||
KVPoll,
|
KVPoll,
|
||||||
)
|
)
|
||||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
from sglang.srt.disaggregation.common.utils import (
|
||||||
from sglang.srt.disaggregation.utils import (
|
|
||||||
DisaggregationMode,
|
|
||||||
FastQueue,
|
FastQueue,
|
||||||
group_concurrent_contiguous,
|
group_concurrent_contiguous,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||||
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_free_port,
|
get_free_port,
|
||||||
@@ -677,7 +677,12 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
class MooncakeKVSender(BaseKVSender):
|
class MooncakeKVSender(BaseKVSender):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
|
self,
|
||||||
|
mgr: MooncakeKVManager,
|
||||||
|
bootstrap_addr: str,
|
||||||
|
bootstrap_room: int,
|
||||||
|
dest_tp_ranks: List[int],
|
||||||
|
pp_rank: int,
|
||||||
):
|
):
|
||||||
self.kv_mgr = mgr
|
self.kv_mgr = mgr
|
||||||
self.bootstrap_room = bootstrap_room
|
self.bootstrap_room = bootstrap_room
|
||||||
|
|||||||
@@ -24,10 +24,8 @@ from sglang.srt.disaggregation.common.conn import (
|
|||||||
CommonKVManager,
|
CommonKVManager,
|
||||||
CommonKVReceiver,
|
CommonKVReceiver,
|
||||||
)
|
)
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
||||||
DisaggregationMode,
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
group_concurrent_contiguous,
|
|
||||||
)
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_local_ip_by_remote
|
from sglang.srt.utils import get_local_ip_by_remote
|
||||||
|
|
||||||
@@ -350,7 +348,14 @@ class NixlKVManager(CommonKVManager):
|
|||||||
|
|
||||||
class NixlKVSender(BaseKVSender):
|
class NixlKVSender(BaseKVSender):
|
||||||
|
|
||||||
def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
|
def __init__(
|
||||||
|
self,
|
||||||
|
mgr: NixlKVManager,
|
||||||
|
bootstrap_addr: str,
|
||||||
|
bootstrap_room: int,
|
||||||
|
dest_tp_ranks: List[int],
|
||||||
|
pp_rank: int,
|
||||||
|
):
|
||||||
self.kv_mgr = mgr
|
self.kv_mgr = mgr
|
||||||
self.bootstrap_room = bootstrap_room
|
self.bootstrap_room = bootstrap_room
|
||||||
self.aux_index = None
|
self.aux_index = None
|
||||||
|
|||||||
@@ -27,10 +27,10 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
|
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
|
FAKE_BOOTSTRAP_HOST,
|
||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
FakeBootstrapHost,
|
|
||||||
KVClassType,
|
KVClassType,
|
||||||
MetadataBuffers,
|
MetadataBuffers,
|
||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
@@ -51,7 +51,6 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
|
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
|
||||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -68,35 +67,45 @@ class PrefillBootstrapQueue:
|
|||||||
metadata_buffers: MetadataBuffers,
|
metadata_buffers: MetadataBuffers,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
|
gpu_id: int,
|
||||||
bootstrap_port: int,
|
bootstrap_port: int,
|
||||||
gloo_group: ProcessGroup,
|
gloo_group: ProcessGroup,
|
||||||
transfer_backend: TransferBackend,
|
max_total_num_tokens: int,
|
||||||
|
decode_tp_size: int,
|
||||||
|
decode_dp_size: int,
|
||||||
scheduler: Scheduler,
|
scheduler: Scheduler,
|
||||||
|
pp_rank: int,
|
||||||
|
pp_size: int,
|
||||||
|
transfer_backend: TransferBackend,
|
||||||
):
|
):
|
||||||
self.token_to_kv_pool = token_to_kv_pool
|
self.token_to_kv_pool = token_to_kv_pool
|
||||||
self.draft_token_to_kv_pool = draft_token_to_kv_pool
|
self.draft_token_to_kv_pool = draft_token_to_kv_pool
|
||||||
|
|
||||||
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
|
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
|
||||||
|
|
||||||
self.metadata_buffers = metadata_buffers
|
self.metadata_buffers = metadata_buffers
|
||||||
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.transfer_backend = transfer_backend
|
self.decode_tp_size = decode_tp_size
|
||||||
self.scheduler = scheduler
|
self.decode_dp_size = decode_dp_size
|
||||||
self.kv_manager = self._init_kv_manager()
|
self.pp_rank = pp_rank
|
||||||
self.queue: List[Req] = []
|
self.pp_size = pp_size
|
||||||
self.gloo_group = gloo_group
|
self.gpu_id = gpu_id
|
||||||
self.bootstrap_port = bootstrap_port
|
self.bootstrap_port = bootstrap_port
|
||||||
|
self.queue: List[Req] = []
|
||||||
def store_prefill_results(self, idx: int, token_id: int):
|
self.pp_rank = pp_rank
|
||||||
assert token_id >= 0, f"token_id: {token_id} is negative"
|
self.pp_size = pp_size
|
||||||
output_id_buffer = self.metadata_buffers[0]
|
self.gloo_group = gloo_group
|
||||||
output_id_buffer[idx] = token_id
|
self.max_total_num_tokens = max_total_num_tokens
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self.transfer_backend = transfer_backend
|
||||||
|
self.kv_manager = self._init_kv_manager()
|
||||||
|
|
||||||
def _init_kv_manager(self) -> BaseKVManager:
|
def _init_kv_manager(self) -> BaseKVManager:
|
||||||
kv_args = KVArgs()
|
kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
|
||||||
|
kv_args = kv_args_class()
|
||||||
kv_args.engine_rank = self.tp_rank
|
kv_args.engine_rank = self.tp_rank
|
||||||
|
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
|
||||||
|
kv_args.prefill_pp_size = self.pp_size
|
||||||
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||||
self.token_to_kv_pool.get_contiguous_buf_infos()
|
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||||
)
|
)
|
||||||
@@ -115,12 +124,12 @@ class PrefillBootstrapQueue:
|
|||||||
kv_args.kv_data_lens = kv_data_lens
|
kv_args.kv_data_lens = kv_data_lens
|
||||||
kv_args.kv_item_lens = kv_item_lens
|
kv_args.kv_item_lens = kv_item_lens
|
||||||
|
|
||||||
# Define req -> input ids buffer
|
|
||||||
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
|
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
|
||||||
self.metadata_buffers.get_buf_infos()
|
self.metadata_buffers.get_buf_infos()
|
||||||
)
|
)
|
||||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||||
kv_args.gpu_id = self.scheduler.gpu_id
|
kv_args.gpu_id = self.scheduler.gpu_id
|
||||||
|
|
||||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||||
kv_manager = kv_manager_class(
|
kv_manager = kv_manager_class(
|
||||||
kv_args,
|
kv_args,
|
||||||
@@ -130,23 +139,39 @@ class PrefillBootstrapQueue:
|
|||||||
)
|
)
|
||||||
return kv_manager
|
return kv_manager
|
||||||
|
|
||||||
def add(self, req: Req) -> None:
|
def add(self, req: Req, num_kv_heads: int) -> None:
|
||||||
if req.bootstrap_host == FakeBootstrapHost:
|
if self._check_if_req_exceed_kv_capacity(req):
|
||||||
# Fake transfer for warmup reqs
|
return
|
||||||
|
|
||||||
|
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
|
||||||
kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
|
kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
|
||||||
else:
|
else:
|
||||||
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
|
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
|
||||||
|
|
||||||
|
dest_tp_ranks = [self.tp_rank]
|
||||||
|
|
||||||
req.disagg_kv_sender = kv_sender_class(
|
req.disagg_kv_sender = kv_sender_class(
|
||||||
mgr=self.kv_manager,
|
mgr=self.kv_manager,
|
||||||
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
||||||
bootstrap_room=req.bootstrap_room,
|
bootstrap_room=req.bootstrap_room,
|
||||||
|
dest_tp_ranks=dest_tp_ranks,
|
||||||
|
pp_rank=self.pp_rank,
|
||||||
)
|
)
|
||||||
self._process_req(req)
|
self._process_req(req)
|
||||||
self.queue.append(req)
|
self.queue.append(req)
|
||||||
|
|
||||||
def extend(self, reqs: List[Req]) -> None:
|
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
self.add(req)
|
self.add(req, num_kv_heads)
|
||||||
|
|
||||||
|
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
    """Abort ``req`` if its prompt is longer than ``max_total_num_tokens``.

    Returns ``False`` without side effects when the number of
    ``origin_input_ids`` is within ``self.max_total_num_tokens``. Otherwise
    the problem is logged, the request is marked aborted via
    ``prepare_abort``, its output is streamed through the scheduler, and
    ``True`` is returned.
    """
    within_capacity = len(req.origin_input_ids) <= self.max_total_num_tokens
    if within_capacity:
        return False

    message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
    logger.error(message)
    prepare_abort(req, message)
    # Stream the aborted request straight away.
    self.scheduler.stream_output([req], req.return_logprob)
    return True
|
||||||
|
|
||||||
def _process_req(self, req: Req) -> None:
|
def _process_req(self, req: Req) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -154,19 +179,40 @@ class PrefillBootstrapQueue:
|
|||||||
"""
|
"""
|
||||||
req.sampling_params.max_new_tokens = 1
|
req.sampling_params.max_new_tokens = 1
|
||||||
|
|
||||||
def pop_bootstrapped(self) -> List[Req]:
|
def pop_bootstrapped(
|
||||||
"""pop the reqs which has finished bootstrapping"""
|
self,
|
||||||
|
return_failed_reqs: bool = False,
|
||||||
|
rids_to_check: Optional[List[str]] = None,
|
||||||
|
) -> List[Req]:
|
||||||
|
"""
|
||||||
|
pop the reqs which has finished bootstrapping
|
||||||
|
|
||||||
|
return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
|
||||||
|
rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
|
||||||
|
"""
|
||||||
|
|
||||||
bootstrapped_reqs = []
|
bootstrapped_reqs = []
|
||||||
|
failed_reqs = []
|
||||||
indices_to_remove = set()
|
indices_to_remove = set()
|
||||||
|
|
||||||
if len(self.queue) == 0:
|
if len(self.queue) == 0:
|
||||||
|
if return_failed_reqs is False:
|
||||||
return []
|
return []
|
||||||
|
else:
|
||||||
|
return [], []
|
||||||
|
|
||||||
polls = poll_and_all_reduce(
|
polls = poll_and_all_reduce(
|
||||||
[req.disagg_kv_sender for req in self.queue], self.gloo_group
|
[req.disagg_kv_sender for req in self.queue], self.gloo_group
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
||||||
|
|
||||||
|
if rids_to_check is not None:
|
||||||
|
# if req not in reqs_info_to_check, skip
|
||||||
|
if req.rid not in rids_to_check:
|
||||||
|
continue
|
||||||
|
# Either waiting for input or failed
|
||||||
|
assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed
|
||||||
|
|
||||||
if poll == KVPoll.Bootstrapping:
|
if poll == KVPoll.Bootstrapping:
|
||||||
continue
|
continue
|
||||||
elif poll == KVPoll.Failed:
|
elif poll == KVPoll.Failed:
|
||||||
@@ -181,9 +227,10 @@ class PrefillBootstrapQueue:
|
|||||||
)
|
)
|
||||||
self.scheduler.stream_output([req], req.return_logprob)
|
self.scheduler.stream_output([req], req.return_logprob)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
|
failed_reqs.append(req)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# KV.WaitingForInput
|
# KV.WaitingForInput - init here
|
||||||
num_kv_indices = len(req.origin_input_ids)
|
num_kv_indices = len(req.origin_input_ids)
|
||||||
if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
|
if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
|
||||||
break
|
break
|
||||||
@@ -192,9 +239,9 @@ class PrefillBootstrapQueue:
|
|||||||
self.req_to_metadata_buffer_idx_allocator.alloc()
|
self.req_to_metadata_buffer_idx_allocator.alloc()
|
||||||
)
|
)
|
||||||
assert req.metadata_buffer_index is not None
|
assert req.metadata_buffer_index is not None
|
||||||
|
|
||||||
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
||||||
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
||||||
|
|
||||||
bootstrapped_reqs.append(req)
|
bootstrapped_reqs.append(req)
|
||||||
indices_to_remove.add(i)
|
indices_to_remove.add(i)
|
||||||
|
|
||||||
@@ -202,7 +249,10 @@ class PrefillBootstrapQueue:
|
|||||||
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if return_failed_reqs is False:
|
||||||
return bootstrapped_reqs
|
return bootstrapped_reqs
|
||||||
|
else:
|
||||||
|
return bootstrapped_reqs, failed_reqs
|
||||||
|
|
||||||
|
|
||||||
class SchedulerDisaggregationPrefillMixin:
|
class SchedulerDisaggregationPrefillMixin:
|
||||||
@@ -211,7 +261,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def event_loop_normal_disagg_prefill(self: Scheduler):
|
def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
|
||||||
"""A normal scheduler loop for prefill worker in disaggregation mode."""
|
"""A normal scheduler loop for prefill worker in disaggregation mode."""
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -229,7 +279,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
or self.server_args.enable_sp_layernorm
|
or self.server_args.enable_sp_layernorm
|
||||||
):
|
):
|
||||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||||
|
|
||||||
self.cur_batch = batch
|
self.cur_batch = batch
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
@@ -250,7 +299,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
self.running_batch.batch_is_full = False
|
self.running_batch.batch_is_full = False
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def event_loop_overlap_disagg_prefill(self: Scheduler):
|
def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
|
||||||
self.result_queue = deque()
|
self.result_queue = deque()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -268,9 +317,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
or self.server_args.enable_sp_layernorm
|
or self.server_args.enable_sp_layernorm
|
||||||
):
|
):
|
||||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||||
|
|
||||||
self.cur_batch = batch
|
self.cur_batch = batch
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.result_queue.append((batch.copy(), result))
|
self.result_queue.append((batch.copy(), result))
|
||||||
@@ -287,6 +334,9 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
|
|
||||||
if self.last_batch:
|
if self.last_batch:
|
||||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||||
|
tmp_batch.next_batch_sampling_info = (
|
||||||
|
self.tp_worker.cur_sampling_info if batch else None
|
||||||
|
)
|
||||||
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
||||||
|
|
||||||
if len(self.disagg_prefill_inflight_queue) > 0:
|
if len(self.disagg_prefill_inflight_queue) > 0:
|
||||||
@@ -309,7 +359,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
launch_done: Optional[threading.Event] = None,
|
launch_done: Optional[threading.Event] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
||||||
Adapted from process_batch_result_prefill
|
Adapted from process_batch_result_prefill
|
||||||
"""
|
"""
|
||||||
(
|
(
|
||||||
@@ -325,7 +375,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logprob_pt = 0
|
logprob_pt = 0
|
||||||
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
# wait
|
# wait
|
||||||
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
|
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
|
||||||
@@ -397,11 +447,15 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
# We need to remove the sync in the following function for overlap schedule.
|
# We need to remove the sync in the following function for overlap schedule.
|
||||||
self.set_next_batch_sampling_info_done(batch)
|
self.set_next_batch_sampling_info_done(batch)
|
||||||
|
|
||||||
def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
|
def process_disagg_prefill_inflight_queue(
|
||||||
|
self: Scheduler, rids_to_check: Optional[List[str]] = None
|
||||||
|
) -> List[Req]:
|
||||||
"""
|
"""
|
||||||
Poll the requests in the middle of transfer. If done, return the request.
|
Poll the requests in the middle of transfer. If done, return the request.
|
||||||
|
rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
|
||||||
"""
|
"""
|
||||||
assert len(self.disagg_prefill_inflight_queue) > 0
|
if len(self.disagg_prefill_inflight_queue) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
done_reqs = []
|
done_reqs = []
|
||||||
|
|
||||||
@@ -413,6 +467,14 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
undone_reqs: List[Req] = []
|
undone_reqs: List[Req] = []
|
||||||
# Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
|
# Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
|
||||||
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
|
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
|
||||||
|
|
||||||
|
if rids_to_check is not None:
|
||||||
|
if req.rid not in rids_to_check:
|
||||||
|
undone_reqs.append(req)
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert poll == KVPoll.Success or poll == KVPoll.Failed
|
||||||
|
|
||||||
if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
|
if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
|
||||||
undone_reqs.append(req)
|
undone_reqs.append(req)
|
||||||
elif poll == KVPoll.Success: # transfer done
|
elif poll == KVPoll.Success: # transfer done
|
||||||
@@ -434,11 +496,8 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
|
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
)
|
)
|
||||||
done_reqs.append(req)
|
done_reqs.append(req)
|
||||||
|
else:
|
||||||
for req in done_reqs:
|
assert False, f"Unexpected polling state {poll=}"
|
||||||
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
|
|
||||||
req.metadata_buffer_index
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stream requests which have finished transfer
|
# Stream requests which have finished transfer
|
||||||
self.stream_output(
|
self.stream_output(
|
||||||
@@ -446,9 +505,32 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
any(req.return_logprob for req in done_reqs),
|
any(req.return_logprob for req in done_reqs),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
for req in done_reqs:
|
||||||
|
req: Req
|
||||||
|
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
|
||||||
|
req.metadata_buffer_index = -1
|
||||||
|
|
||||||
self.disagg_prefill_inflight_queue = undone_reqs
|
self.disagg_prefill_inflight_queue = undone_reqs
|
||||||
|
|
||||||
|
return done_reqs
|
||||||
|
|
||||||
|
def get_transferred_rids(self: Scheduler) -> List[str]:
    """
    Used by PP: report which in-flight requests have finished transferring,
    but **do not pop** them.

    Every request in ``disagg_prefill_inflight_queue`` is polled through
    ``poll_and_all_reduce`` over the TP group's CPU group. The rids whose
    reduced poll state is ``KVPoll.Success`` or ``KVPoll.Failed`` are
    returned in queue order. The queue itself is left untouched.
    """
    inflight = self.disagg_prefill_inflight_queue
    senders = [req.disagg_kv_sender for req in inflight]
    cpu_group = self.tp_worker.get_tp_group().cpu_group
    polls = poll_and_all_reduce(senders, cpu_group)

    return [
        req.rid
        for req, poll in zip(inflight, polls)
        if poll == KVPoll.Success or poll == KVPoll.Failed
    ]
|
||||||
|
|
||||||
def process_prefill_chunk(self: Scheduler) -> None:
|
def process_prefill_chunk(self: Scheduler) -> None:
|
||||||
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
||||||
if self.chunked_req:
|
if self.chunked_req:
|
||||||
|
|||||||
@@ -14,15 +14,15 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.utils import get_ip, get_local_ip_by_remote
|
from sglang.srt.utils import get_ip
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
|
|
||||||
FakeBootstrapHost = "2.2.2.2"
|
#########################
|
||||||
|
# Constants & Enums
|
||||||
# env var for testing failure, convert to float explicitly
|
#########################
|
||||||
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
|
FAKE_BOOTSTRAP_HOST = "2.2.2.2"
|
||||||
|
|
||||||
|
|
||||||
class DisaggregationMode(Enum):
|
class DisaggregationMode(Enum):
|
||||||
@@ -31,6 +31,14 @@ class DisaggregationMode(Enum):
|
|||||||
DECODE = "decode"
|
DECODE = "decode"
|
||||||
|
|
||||||
|
|
||||||
|
#########################
|
||||||
|
# Synchronization
|
||||||
|
#########################
|
||||||
|
|
||||||
|
# env var for testing failure, convert to float explicitly
|
||||||
|
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
|
||||||
|
|
||||||
|
|
||||||
def poll_and_all_reduce(pollers, gloo_group):
|
def poll_and_all_reduce(pollers, gloo_group):
|
||||||
# at a certain prob, the poll is failed to simulate failure
|
# at a certain prob, the poll is failed to simulate failure
|
||||||
if FAILURE_PROB > 0:
|
if FAILURE_PROB > 0:
|
||||||
@@ -47,6 +55,11 @@ def poll_and_all_reduce(pollers, gloo_group):
|
|||||||
return tensor_to_reduce.tolist()
|
return tensor_to_reduce.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
#########################
|
||||||
|
# Metadata Buffers
|
||||||
|
#########################
|
||||||
|
|
||||||
|
|
||||||
class ReqToMetadataIdxAllocator:
|
class ReqToMetadataIdxAllocator:
|
||||||
"""A memory pool that maps a request to its first output token location."""
|
"""A memory pool that maps a request to its first output token location."""
|
||||||
|
|
||||||
@@ -70,138 +83,6 @@ class ReqToMetadataIdxAllocator:
|
|||||||
self.free_slots.append(free_index)
|
self.free_slots.append(free_index)
|
||||||
|
|
||||||
|
|
||||||
class TransferBackend(Enum):
    """Identifiers of the KV transfer backends selectable by name."""

    MOONCAKE = "mooncake"
    NIXL = "nixl"
    # Used for warmup requests, which do not transfer any KV data.
    FAKE = "fake"
|
|
||||||
|
|
||||||
|
|
||||||
class KVClassType(Enum):
    """Roles for which a transfer backend can provide a class."""

    MANAGER = "manager"
    SENDER = "sender"
    RECEIVER = "receiver"
    BOOTSTRAP_SERVER = "bootstrap_server"
|
|
||||||
|
|
||||||
|
|
||||||
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
    """Look up the class a transfer backend provides for a given role.

    Backend packages are imported lazily inside the matching branch, so only
    the selected backend's package is loaded (the fake package is always
    imported).

    Args:
        transfer_backend: Which backend to resolve against.
        class_type: Which role (manager, sender, receiver, bootstrap server)
            is wanted.

    Returns:
        The class registered for ``class_type``, or ``None`` when the backend
        has no entry for that role (for example, the fake backend only
        registers a sender and a receiver).

    Raises:
        ValueError: If ``transfer_backend`` is not a supported backend.
    """
    # Imported unconditionally, as before; the FAKE branch reuses these names.
    from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

    if transfer_backend == TransferBackend.MOONCAKE:
        from sglang.srt.disaggregation.mooncake import (
            MooncakeKVBootstrapServer,
            MooncakeKVManager,
            MooncakeKVReceiver,
            MooncakeKVSender,
        )

        registry = {
            KVClassType.MANAGER: MooncakeKVManager,
            KVClassType.SENDER: MooncakeKVSender,
            KVClassType.RECEIVER: MooncakeKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
        }
    elif transfer_backend == TransferBackend.NIXL:
        from sglang.srt.disaggregation.nixl import (
            NixlKVBootstrapServer,
            NixlKVManager,
            NixlKVReceiver,
            NixlKVSender,
        )

        registry = {
            KVClassType.MANAGER: NixlKVManager,
            KVClassType.SENDER: NixlKVSender,
            KVClassType.RECEIVER: NixlKVReceiver,
            KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
        }
    elif transfer_backend == TransferBackend.FAKE:
        registry = {
            KVClassType.SENDER: FakeKVSender,
            KVClassType.RECEIVER: FakeKVReceiver,
        }
    else:
        raise ValueError(f"Unsupported transfer backend: {transfer_backend}")

    return registry.get(class_type)
|
|
||||||
|
|
||||||
|
|
||||||
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
|
||||||
# 1. The page is guaranteed to be full except the last page.
|
|
||||||
# 2. page index = kv_index // page_size
|
|
||||||
# The return vector is kv_indices[::page_size] // page_size
|
|
||||||
if page_size == 1: # shortcut
|
|
||||||
return kv_indices
|
|
||||||
|
|
||||||
return kv_indices[::page_size] // page_size
|
|
||||||
|
|
||||||
|
|
||||||
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
|
||||||
# ceil(num_kv_indices / page_size)
|
|
||||||
return (num_kv_indices + page_size - 1) // page_size
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class PDRegistryRequest:
|
|
||||||
"""A request to register a machine itself to the LB."""
|
|
||||||
|
|
||||||
mode: str
|
|
||||||
registry_url: str
|
|
||||||
bootstrap_port: Optional[int] = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.mode == "prefill" and self.bootstrap_port is None:
|
|
||||||
raise ValueError("Bootstrap port must be set in PREFILL mode.")
|
|
||||||
elif self.mode == "decode" and self.bootstrap_port is not None:
|
|
||||||
raise ValueError("Bootstrap port must not be set in DECODE mode.")
|
|
||||||
elif self.mode not in ["prefill", "decode"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def register_disaggregation_server(
|
|
||||||
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
|
|
||||||
):
|
|
||||||
boostrap_port = bootstrap_port if mode == "prefill" else None
|
|
||||||
registry_request = PDRegistryRequest(
|
|
||||||
mode=mode,
|
|
||||||
registry_url=f"http://{get_ip()}:{server_port}",
|
|
||||||
bootstrap_port=boostrap_port,
|
|
||||||
)
|
|
||||||
res = requests.post(
|
|
||||||
f"{pdlb_url}/register",
|
|
||||||
json=dataclasses.asdict(registry_request),
|
|
||||||
)
|
|
||||||
if res.status_code != 200:
|
|
||||||
warnings.warn(
|
|
||||||
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_mla_backend(target_kv_pool) -> bool:
|
|
||||||
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
|
||||||
|
|
||||||
return isinstance(target_kv_pool, MLATokenToKVPool)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_abort(req: Req, error_message: str, status_code=None):
|
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT
|
|
||||||
|
|
||||||
# populate finish metadata and stream output
|
|
||||||
req.finished_reason = FINISH_ABORT(error_message, status_code)
|
|
||||||
|
|
||||||
if req.return_logprob:
|
|
||||||
req.input_token_logprobs_val = []
|
|
||||||
req.input_token_logprobs_idx = []
|
|
||||||
req.input_top_logprobs_val = []
|
|
||||||
req.input_top_logprobs_idx = []
|
|
||||||
req.input_token_ids_logprobs_val = []
|
|
||||||
req.input_token_ids_logprobs_idx = []
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataBuffers:
|
class MetadataBuffers:
|
||||||
def __init__(self, size: int, max_top_logprobs_num: int = 128):
|
def __init__(self, size: int, max_top_logprobs_num: int = 128):
|
||||||
# TODO: abort top_logprobs_num > 128 in PD
|
# TODO: abort top_logprobs_num > 128 in PD
|
||||||
@@ -282,37 +163,160 @@ class MetadataBuffers:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class FastQueue:
|
#########################
|
||||||
def __init__(self):
|
# Transfer Backend
|
||||||
self._buf = deque()
|
#########################
|
||||||
self._cond = threading.Condition()
|
|
||||||
|
|
||||||
def put(self, item):
|
|
||||||
with self._cond:
|
|
||||||
self._buf.append(item)
|
|
||||||
# wake up a thread of wait()
|
|
||||||
self._cond.notify()
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
with self._cond:
|
|
||||||
# if queue is empty ,block until is notified()
|
|
||||||
while not self._buf:
|
|
||||||
self._cond.wait()
|
|
||||||
return self._buf.popleft()
|
|
||||||
|
|
||||||
|
|
||||||
def group_concurrent_contiguous(
|
class TransferBackend(Enum):
|
||||||
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
MOONCAKE = "mooncake"
|
||||||
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
NIXL = "nixl"
|
||||||
"""Vectorised NumPy implementation."""
|
FAKE = "fake"
|
||||||
if src_indices.size == 0:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
|
|
||||||
src_groups = np.split(src_indices, brk)
|
|
||||||
dst_groups = np.split(dst_indices, brk)
|
|
||||||
|
|
||||||
src_groups = [g.tolist() for g in src_groups]
|
class KVClassType(Enum):
|
||||||
dst_groups = [g.tolist() for g in dst_groups]
|
KVARGS = "kvargs"
|
||||||
|
MANAGER = "manager"
|
||||||
|
SENDER = "sender"
|
||||||
|
RECEIVER = "receiver"
|
||||||
|
BOOTSTRAP_SERVER = "bootstrap_server"
|
||||||
|
|
||||||
return src_groups, dst_groups
|
|
||||||
|
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||||
|
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||||
|
|
||||||
|
if transfer_backend == TransferBackend.MOONCAKE:
|
||||||
|
from sglang.srt.disaggregation.base import KVArgs
|
||||||
|
from sglang.srt.disaggregation.mooncake import (
|
||||||
|
MooncakeKVBootstrapServer,
|
||||||
|
MooncakeKVManager,
|
||||||
|
MooncakeKVReceiver,
|
||||||
|
MooncakeKVSender,
|
||||||
|
)
|
||||||
|
|
||||||
|
class_mapping = {
|
||||||
|
KVClassType.KVARGS: KVArgs,
|
||||||
|
KVClassType.MANAGER: MooncakeKVManager,
|
||||||
|
KVClassType.SENDER: MooncakeKVSender,
|
||||||
|
KVClassType.RECEIVER: (MooncakeKVReceiver),
|
||||||
|
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
||||||
|
}
|
||||||
|
return class_mapping.get(class_type)
|
||||||
|
if transfer_backend == TransferBackend.NIXL:
|
||||||
|
from sglang.srt.disaggregation.base import KVArgs
|
||||||
|
from sglang.srt.disaggregation.nixl import (
|
||||||
|
NixlKVBootstrapServer,
|
||||||
|
NixlKVManager,
|
||||||
|
NixlKVReceiver,
|
||||||
|
NixlKVSender,
|
||||||
|
)
|
||||||
|
|
||||||
|
class_mapping = {
|
||||||
|
KVClassType.KVARGS: KVArgs,
|
||||||
|
KVClassType.MANAGER: NixlKVManager,
|
||||||
|
KVClassType.SENDER: NixlKVSender,
|
||||||
|
KVClassType.RECEIVER: (NixlKVReceiver),
|
||||||
|
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
|
||||||
|
}
|
||||||
|
return class_mapping.get(class_type)
|
||||||
|
if transfer_backend == TransferBackend.FAKE:
|
||||||
|
from sglang.srt.disaggregation.base import KVArgs
|
||||||
|
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||||
|
|
||||||
|
class_mapping = {
|
||||||
|
KVClassType.KVARGS: KVArgs,
|
||||||
|
KVClassType.SENDER: FakeKVSender,
|
||||||
|
KVClassType.RECEIVER: (FakeKVReceiver),
|
||||||
|
}
|
||||||
|
return class_mapping.get(class_type)
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||||
|
|
||||||
|
|
||||||
|
#########################
|
||||||
|
# KV Pages
|
||||||
|
#########################
|
||||||
|
|
||||||
|
|
||||||
|
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
||||||
|
# 1. The page is guaranteed to be full except the last page.
|
||||||
|
# 2. page index = kv_index // page_size
|
||||||
|
# The return vector is kv_indices[::page_size] // page_size
|
||||||
|
if page_size == 1: # shortcut
|
||||||
|
return kv_indices
|
||||||
|
|
||||||
|
return kv_indices[::page_size] // page_size
|
||||||
|
|
||||||
|
|
||||||
|
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||||
|
# ceil(num_kv_indices / page_size)
|
||||||
|
return (num_kv_indices + page_size - 1) // page_size
|
||||||
|
|
||||||
|
|
||||||
|
#########################
|
||||||
|
# PDLB Registry
|
||||||
|
#########################
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class PDRegistryRequest:
|
||||||
|
"""A request to register a machine itself to the LB."""
|
||||||
|
|
||||||
|
mode: str
|
||||||
|
registry_url: str
|
||||||
|
bootstrap_port: Optional[int] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.mode == "prefill" and self.bootstrap_port is None:
|
||||||
|
raise ValueError("Bootstrap port must be set in PREFILL mode.")
|
||||||
|
elif self.mode == "decode" and self.bootstrap_port is not None:
|
||||||
|
raise ValueError("Bootstrap port must not be set in DECODE mode.")
|
||||||
|
elif self.mode not in ["prefill", "decode"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_disaggregation_server(
|
||||||
|
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
|
||||||
|
):
|
||||||
|
boostrap_port = bootstrap_port if mode == "prefill" else None
|
||||||
|
registry_request = PDRegistryRequest(
|
||||||
|
mode=mode,
|
||||||
|
registry_url=f"http://{get_ip()}:{server_port}",
|
||||||
|
bootstrap_port=boostrap_port,
|
||||||
|
)
|
||||||
|
res = requests.post(
|
||||||
|
f"{pdlb_url}/register",
|
||||||
|
json=dataclasses.asdict(registry_request),
|
||||||
|
)
|
||||||
|
if res.status_code != 200:
|
||||||
|
warnings.warn(
|
||||||
|
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
#########################
|
||||||
|
# Misc
|
||||||
|
#########################
|
||||||
|
|
||||||
|
|
||||||
|
def is_mla_backend(target_kv_pool) -> bool:
|
||||||
|
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
||||||
|
|
||||||
|
return isinstance(target_kv_pool, MLATokenToKVPool)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_abort(req: Req, error_message: str, status_code=None):
|
||||||
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT
|
||||||
|
|
||||||
|
# populate finish metadata and stream output
|
||||||
|
req.finished_reason = FINISH_ABORT(error_message, status_code)
|
||||||
|
|
||||||
|
if req.return_logprob:
|
||||||
|
req.input_token_logprobs_val = []
|
||||||
|
req.input_token_logprobs_idx = []
|
||||||
|
req.input_top_logprobs_val = []
|
||||||
|
req.input_top_logprobs_idx = []
|
||||||
|
req.input_token_ids_logprobs_val = []
|
||||||
|
req.input_token_ids_logprobs_idx = []
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
from sglang.srt.disaggregation.utils import (
|
from sglang.srt.disaggregation.utils import (
|
||||||
FakeBootstrapHost,
|
FAKE_BOOTSTRAP_HOST,
|
||||||
register_disaggregation_server,
|
register_disaggregation_server,
|
||||||
)
|
)
|
||||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||||
@@ -878,7 +878,7 @@ def _wait_and_warmup(
|
|||||||
"max_new_tokens": 8,
|
"max_new_tokens": 8,
|
||||||
"ignore_eos": True,
|
"ignore_eos": True,
|
||||||
},
|
},
|
||||||
"bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
|
"bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
|
||||||
# This is a hack to ensure fake transfer is enabled during prefill warmup
|
# This is a hack to ensure fake transfer is enabled during prefill warmup
|
||||||
# ensure each dp rank has a unique bootstrap_room during prefill warmup
|
# ensure each dp rank has a unique bootstrap_room during prefill warmup
|
||||||
"bootstrap_room": [
|
"bootstrap_room": [
|
||||||
|
|||||||
@@ -619,7 +619,7 @@ class Scheduler(
|
|||||||
self.disaggregation_mode == DisaggregationMode.DECODE
|
self.disaggregation_mode == DisaggregationMode.DECODE
|
||||||
): # *2 for the headroom.
|
): # *2 for the headroom.
|
||||||
buffer_size = (self.req_to_token_pool.size) * 2
|
buffer_size = (self.req_to_token_pool.size) * 2
|
||||||
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||||
buffer_size
|
buffer_size
|
||||||
)
|
)
|
||||||
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
|
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
|
||||||
@@ -627,7 +627,7 @@ class Scheduler(
|
|||||||
# The decode requests polling kv cache
|
# The decode requests polling kv cache
|
||||||
self.disagg_decode_transfer_queue = DecodeTransferQueue(
|
self.disagg_decode_transfer_queue = DecodeTransferQueue(
|
||||||
gloo_group=self.attn_tp_cpu_group,
|
gloo_group=self.attn_tp_cpu_group,
|
||||||
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||||
metadata_buffers=self.disagg_metadata_buffers,
|
metadata_buffers=self.disagg_metadata_buffers,
|
||||||
scheduler=self,
|
scheduler=self,
|
||||||
tree_cache=self.tree_cache,
|
tree_cache=self.tree_cache,
|
||||||
@@ -642,7 +642,7 @@ class Scheduler(
|
|||||||
if self.draft_worker is None
|
if self.draft_worker is None
|
||||||
else self.draft_worker.model_runner.token_to_kv_pool
|
else self.draft_worker.model_runner.token_to_kv_pool
|
||||||
),
|
),
|
||||||
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||||
metadata_buffers=self.disagg_metadata_buffers,
|
metadata_buffers=self.disagg_metadata_buffers,
|
||||||
scheduler=self,
|
scheduler=self,
|
||||||
transfer_queue=self.disagg_decode_transfer_queue,
|
transfer_queue=self.disagg_decode_transfer_queue,
|
||||||
@@ -660,7 +660,7 @@ class Scheduler(
|
|||||||
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
|
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
# *2 for the headroom.
|
# *2 for the headroom.
|
||||||
buffer_size = self.max_running_requests * 2
|
buffer_size = self.max_running_requests * 2
|
||||||
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||||
buffer_size
|
buffer_size
|
||||||
)
|
)
|
||||||
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
|
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
|
||||||
@@ -672,14 +672,20 @@ class Scheduler(
|
|||||||
if self.draft_worker is None
|
if self.draft_worker is None
|
||||||
else self.draft_worker.model_runner.token_to_kv_pool
|
else self.draft_worker.model_runner.token_to_kv_pool
|
||||||
),
|
),
|
||||||
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||||
metadata_buffers=self.disagg_metadata_buffers,
|
metadata_buffers=self.disagg_metadata_buffers,
|
||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
|
gpu_id=self.gpu_id,
|
||||||
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||||
gloo_group=self.attn_tp_cpu_group,
|
gloo_group=self.attn_tp_cpu_group,
|
||||||
transfer_backend=self.transfer_backend,
|
max_total_num_tokens=self.max_total_num_tokens,
|
||||||
|
decode_tp_size=self.server_args.disaggregation_decode_tp,
|
||||||
|
decode_dp_size=self.server_args.disaggregation_decode_dp,
|
||||||
scheduler=self,
|
scheduler=self,
|
||||||
|
pp_rank=self.pp_rank,
|
||||||
|
pp_size=self.pp_size,
|
||||||
|
transfer_backend=self.transfer_backend,
|
||||||
)
|
)
|
||||||
# The prefill requests that are in the middle of kv sending
|
# The prefill requests that are in the middle of kv sending
|
||||||
self.disagg_prefill_inflight_queue: List[Req] = []
|
self.disagg_prefill_inflight_queue: List[Req] = []
|
||||||
@@ -1110,7 +1116,9 @@ class Scheduler(
|
|||||||
def _add_request_to_queue(self, req: Req):
|
def _add_request_to_queue(self, req: Req):
|
||||||
req.queue_time_start = time.perf_counter()
|
req.queue_time_start = time.perf_counter()
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self.disagg_prefill_bootstrap_queue.add(req)
|
self.disagg_prefill_bootstrap_queue.add(
|
||||||
|
req, self.model_config.num_key_value_heads
|
||||||
|
)
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
self.disagg_decode_prealloc_queue.add(req)
|
self.disagg_decode_prealloc_queue.add(req)
|
||||||
else:
|
else:
|
||||||
@@ -1118,7 +1126,9 @@ class Scheduler(
|
|||||||
|
|
||||||
def _extend_requests_to_queue(self, reqs: List[Req]):
|
def _extend_requests_to_queue(self, reqs: List[Req]):
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self.disagg_prefill_bootstrap_queue.extend(reqs)
|
self.disagg_prefill_bootstrap_queue.extend(
|
||||||
|
reqs, self.model_config.num_key_value_heads
|
||||||
|
)
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
# If this is a decode server, we put the request to the decode pending prealloc queue
|
# If this is a decode server, we put the request to the decode pending prealloc queue
|
||||||
self.disagg_decode_prealloc_queue.extend(reqs)
|
self.disagg_decode_prealloc_queue.extend(reqs)
|
||||||
|
|||||||
@@ -227,6 +227,9 @@ class ServerArgs:
|
|||||||
disaggregation_mode: str = "null"
|
disaggregation_mode: str = "null"
|
||||||
disaggregation_transfer_backend: str = "mooncake"
|
disaggregation_transfer_backend: str = "mooncake"
|
||||||
disaggregation_bootstrap_port: int = 8998
|
disaggregation_bootstrap_port: int = 8998
|
||||||
|
disaggregation_decode_tp: Optional[int] = None
|
||||||
|
disaggregation_decode_dp: Optional[int] = None
|
||||||
|
disaggregation_prefill_pp: Optional[int] = 1
|
||||||
disaggregation_ib_device: Optional[str] = None
|
disaggregation_ib_device: Optional[str] = None
|
||||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||||
pdlb_url: Optional[str] = None
|
pdlb_url: Optional[str] = None
|
||||||
@@ -505,12 +508,27 @@ class ServerArgs:
|
|||||||
self.triton_attention_num_kv_splits = 16
|
self.triton_attention_num_kv_splits = 16
|
||||||
|
|
||||||
# PD disaggregation
|
# PD disaggregation
|
||||||
if self.disaggregation_mode == "prefill":
|
if self.disaggregation_mode == "decode":
|
||||||
self.disable_cuda_graph = True
|
assert (
|
||||||
logger.warning("Cuda graph is disabled for prefill server")
|
self.disaggregation_decode_tp is None
|
||||||
elif self.disaggregation_mode == "decode":
|
), "Cannot set --disaggregation-decode-tp for the decode engine."
|
||||||
|
assert (
|
||||||
|
self.disaggregation_decode_dp is None
|
||||||
|
), "Cannot set --disaggregation-decode-dp for the decode engine."
|
||||||
|
|
||||||
self.disable_radix_cache = True
|
self.disable_radix_cache = True
|
||||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||||
|
elif self.disaggregation_mode == "prefill":
|
||||||
|
if self.disaggregation_decode_tp is None:
|
||||||
|
self.disaggregation_decode_tp = self.tp_size
|
||||||
|
if self.disaggregation_decode_dp is None:
|
||||||
|
self.disaggregation_decode_dp = self.dp_size
|
||||||
|
|
||||||
|
self.disaggregation_prefill_pp = self.pp_size
|
||||||
|
self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)
|
||||||
|
|
||||||
|
self.disable_cuda_graph = True
|
||||||
|
logger.warning("Cuda graph is disabled for prefill server")
|
||||||
|
|
||||||
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
|
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
|
||||||
"1" if self.enable_torch_compile else "0"
|
"1" if self.enable_torch_compile else "0"
|
||||||
@@ -520,6 +538,14 @@ class ServerArgs:
|
|||||||
"1" if self.disable_outlines_disk_cache else "0"
|
"1" if self.disable_outlines_disk_cache else "0"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
|
||||||
|
larger_tp = max(decode_tp, prefill_tp)
|
||||||
|
smaller_tp = min(decode_tp, prefill_tp)
|
||||||
|
assert larger_tp % smaller_tp == 0, (
|
||||||
|
"Different tp size is supported only when one tp is multiple of the other. "
|
||||||
|
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
# Model and port args
|
# Model and port args
|
||||||
@@ -1512,6 +1538,24 @@ class ServerArgs:
|
|||||||
default=ServerArgs.disaggregation_bootstrap_port,
|
default=ServerArgs.disaggregation_bootstrap_port,
|
||||||
help="Bootstrap server port on the prefill server. Default is 8998.",
|
help="Bootstrap server port on the prefill server. Default is 8998.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-decode-tp",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.disaggregation_decode_tp,
|
||||||
|
help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-decode-dp",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.disaggregation_decode_dp,
|
||||||
|
help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-prefill-pp",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.disaggregation_prefill_pp,
|
||||||
|
help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disaggregation-ib-device",
|
"--disaggregation-ib-device",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user