[PD] Update prefill.py (#7190)
This commit is contained in:
@@ -1,23 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
|
||||
|
||||
class KVArgs:
|
||||
engine_rank: int
|
||||
kv_data_ptrs: list[int]
|
||||
kv_data_lens: list[int]
|
||||
kv_item_lens: list[int]
|
||||
aux_data_ptrs: list[int]
|
||||
aux_data_lens: list[int]
|
||||
aux_item_lens: list[int]
|
||||
kv_data_ptrs: List[int]
|
||||
kv_data_lens: List[int]
|
||||
kv_item_lens: List[int]
|
||||
aux_data_ptrs: List[int]
|
||||
aux_data_lens: List[int]
|
||||
aux_item_lens: List[int]
|
||||
ib_device: str
|
||||
ib_traffic_class: str
|
||||
gpu_id: int
|
||||
# for different tp
|
||||
decode_tp_size: int
|
||||
# for pp prefill
|
||||
prefill_pp_size: int
|
||||
|
||||
|
||||
class KVPoll:
|
||||
@@ -45,7 +54,12 @@ class BaseKVSender(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
|
||||
self,
|
||||
mgr: BaseKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: int,
|
||||
dest_tp_ranks: List[int],
|
||||
pp_rank: int,
|
||||
): ...
|
||||
|
||||
@abstractmethod
|
||||
|
||||
42
python/sglang/srt/disaggregation/common/utils.py
Normal file
42
python/sglang/srt/disaggregation/common/utils.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
class FastQueue:
|
||||
def __init__(self):
|
||||
self._buf = deque()
|
||||
self._cond = threading.Condition()
|
||||
|
||||
def put(self, item):
|
||||
with self._cond:
|
||||
self._buf.append(item)
|
||||
# wake up a thread of wait()
|
||||
self._cond.notify()
|
||||
|
||||
def get(self):
|
||||
with self._cond:
|
||||
# if queue is empty ,block until is notified()
|
||||
while not self._buf:
|
||||
self._cond.wait()
|
||||
return self._buf.popleft()
|
||||
|
||||
|
||||
def group_concurrent_contiguous(
|
||||
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
||||
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
||||
"""Vectorised NumPy implementation."""
|
||||
if src_indices.size == 0:
|
||||
return [], []
|
||||
|
||||
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
|
||||
src_groups = np.split(src_indices, brk)
|
||||
dst_groups = np.split(dst_indices, brk)
|
||||
|
||||
src_groups = [g.tolist() for g in src_groups]
|
||||
dst_groups = [g.tolist() for g in dst_groups]
|
||||
|
||||
return src_groups, dst_groups
|
||||
@@ -33,8 +33,8 @@ from torch.distributed import ProcessGroup
|
||||
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
DisaggregationMode,
|
||||
FakeBootstrapHost,
|
||||
KVClassType,
|
||||
MetadataBuffers,
|
||||
ReqToMetadataIdxAllocator,
|
||||
@@ -207,7 +207,7 @@ class DecodePreallocQueue:
|
||||
|
||||
def add(self, req: Req) -> None:
|
||||
"""Add a request to the pending queue."""
|
||||
if req.bootstrap_host == FakeBootstrapHost:
|
||||
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
|
||||
# Fake transfer for warmup reqs
|
||||
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
|
||||
else:
|
||||
|
||||
@@ -17,7 +17,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
|
||||
class FakeKVSender(BaseKVSender):
|
||||
def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
|
||||
def __init__(
|
||||
self,
|
||||
mgr: BaseKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: int,
|
||||
dest_tp_ranks: List[int],
|
||||
pp_rank: int,
|
||||
):
|
||||
self.has_sent = False
|
||||
|
||||
def poll(self) -> KVPoll:
|
||||
|
||||
@@ -28,12 +28,12 @@ from sglang.srt.disaggregation.base.conn import (
|
||||
KVArgs,
|
||||
KVPoll,
|
||||
)
|
||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
from sglang.srt.disaggregation.common.utils import (
|
||||
FastQueue,
|
||||
group_concurrent_contiguous,
|
||||
)
|
||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_free_port,
|
||||
@@ -677,7 +677,12 @@ class MooncakeKVManager(BaseKVManager):
|
||||
class MooncakeKVSender(BaseKVSender):
|
||||
|
||||
def __init__(
|
||||
self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
|
||||
self,
|
||||
mgr: MooncakeKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: int,
|
||||
dest_tp_ranks: List[int],
|
||||
pp_rank: int,
|
||||
):
|
||||
self.kv_mgr = mgr
|
||||
self.bootstrap_room = bootstrap_room
|
||||
|
||||
@@ -24,10 +24,8 @@ from sglang.srt.disaggregation.common.conn import (
|
||||
CommonKVManager,
|
||||
CommonKVReceiver,
|
||||
)
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
group_concurrent_contiguous,
|
||||
)
|
||||
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_local_ip_by_remote
|
||||
|
||||
@@ -350,7 +348,14 @@ class NixlKVManager(CommonKVManager):
|
||||
|
||||
class NixlKVSender(BaseKVSender):
|
||||
|
||||
def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
|
||||
def __init__(
|
||||
self,
|
||||
mgr: NixlKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: int,
|
||||
dest_tp_ranks: List[int],
|
||||
pp_rank: int,
|
||||
):
|
||||
self.kv_mgr = mgr
|
||||
self.bootstrap_room = bootstrap_room
|
||||
self.aux_index = None
|
||||
|
||||
@@ -27,10 +27,10 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
DisaggregationMode,
|
||||
FakeBootstrapHost,
|
||||
KVClassType,
|
||||
MetadataBuffers,
|
||||
ReqToMetadataIdxAllocator,
|
||||
@@ -51,7 +51,6 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -68,35 +67,45 @@ class PrefillBootstrapQueue:
|
||||
metadata_buffers: MetadataBuffers,
|
||||
tp_rank: int,
|
||||
tp_size: int,
|
||||
gpu_id: int,
|
||||
bootstrap_port: int,
|
||||
gloo_group: ProcessGroup,
|
||||
transfer_backend: TransferBackend,
|
||||
max_total_num_tokens: int,
|
||||
decode_tp_size: int,
|
||||
decode_dp_size: int,
|
||||
scheduler: Scheduler,
|
||||
pp_rank: int,
|
||||
pp_size: int,
|
||||
transfer_backend: TransferBackend,
|
||||
):
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.draft_token_to_kv_pool = draft_token_to_kv_pool
|
||||
|
||||
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
|
||||
|
||||
self.metadata_buffers = metadata_buffers
|
||||
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.transfer_backend = transfer_backend
|
||||
self.scheduler = scheduler
|
||||
self.kv_manager = self._init_kv_manager()
|
||||
self.queue: List[Req] = []
|
||||
self.gloo_group = gloo_group
|
||||
self.decode_tp_size = decode_tp_size
|
||||
self.decode_dp_size = decode_dp_size
|
||||
self.pp_rank = pp_rank
|
||||
self.pp_size = pp_size
|
||||
self.gpu_id = gpu_id
|
||||
self.bootstrap_port = bootstrap_port
|
||||
|
||||
def store_prefill_results(self, idx: int, token_id: int):
|
||||
assert token_id >= 0, f"token_id: {token_id} is negative"
|
||||
output_id_buffer = self.metadata_buffers[0]
|
||||
output_id_buffer[idx] = token_id
|
||||
self.queue: List[Req] = []
|
||||
self.pp_rank = pp_rank
|
||||
self.pp_size = pp_size
|
||||
self.gloo_group = gloo_group
|
||||
self.max_total_num_tokens = max_total_num_tokens
|
||||
self.scheduler = scheduler
|
||||
self.transfer_backend = transfer_backend
|
||||
self.kv_manager = self._init_kv_manager()
|
||||
|
||||
def _init_kv_manager(self) -> BaseKVManager:
|
||||
kv_args = KVArgs()
|
||||
kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
|
||||
kv_args = kv_args_class()
|
||||
kv_args.engine_rank = self.tp_rank
|
||||
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
|
||||
kv_args.prefill_pp_size = self.pp_size
|
||||
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
@@ -115,12 +124,12 @@ class PrefillBootstrapQueue:
|
||||
kv_args.kv_data_lens = kv_data_lens
|
||||
kv_args.kv_item_lens = kv_item_lens
|
||||
|
||||
# Define req -> input ids buffer
|
||||
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
|
||||
self.metadata_buffers.get_buf_infos()
|
||||
)
|
||||
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
||||
kv_args.gpu_id = self.scheduler.gpu_id
|
||||
|
||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||
kv_manager = kv_manager_class(
|
||||
kv_args,
|
||||
@@ -130,23 +139,39 @@ class PrefillBootstrapQueue:
|
||||
)
|
||||
return kv_manager
|
||||
|
||||
def add(self, req: Req) -> None:
|
||||
if req.bootstrap_host == FakeBootstrapHost:
|
||||
# Fake transfer for warmup reqs
|
||||
def add(self, req: Req, num_kv_heads: int) -> None:
|
||||
if self._check_if_req_exceed_kv_capacity(req):
|
||||
return
|
||||
|
||||
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
|
||||
kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
|
||||
else:
|
||||
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
|
||||
|
||||
dest_tp_ranks = [self.tp_rank]
|
||||
|
||||
req.disagg_kv_sender = kv_sender_class(
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
||||
bootstrap_room=req.bootstrap_room,
|
||||
dest_tp_ranks=dest_tp_ranks,
|
||||
pp_rank=self.pp_rank,
|
||||
)
|
||||
self._process_req(req)
|
||||
self.queue.append(req)
|
||||
|
||||
def extend(self, reqs: List[Req]) -> None:
|
||||
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
|
||||
for req in reqs:
|
||||
self.add(req)
|
||||
self.add(req, num_kv_heads)
|
||||
|
||||
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
|
||||
if len(req.origin_input_ids) > self.max_total_num_tokens:
|
||||
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
|
||||
logger.error(message)
|
||||
prepare_abort(req, message)
|
||||
self.scheduler.stream_output([req], req.return_logprob)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _process_req(self, req: Req) -> None:
|
||||
"""
|
||||
@@ -154,19 +179,40 @@ class PrefillBootstrapQueue:
|
||||
"""
|
||||
req.sampling_params.max_new_tokens = 1
|
||||
|
||||
def pop_bootstrapped(self) -> List[Req]:
|
||||
"""pop the reqs which has finished bootstrapping"""
|
||||
def pop_bootstrapped(
|
||||
self,
|
||||
return_failed_reqs: bool = False,
|
||||
rids_to_check: Optional[List[str]] = None,
|
||||
) -> List[Req]:
|
||||
"""
|
||||
pop the reqs which has finished bootstrapping
|
||||
|
||||
return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
|
||||
rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
|
||||
"""
|
||||
|
||||
bootstrapped_reqs = []
|
||||
failed_reqs = []
|
||||
indices_to_remove = set()
|
||||
|
||||
if len(self.queue) == 0:
|
||||
return []
|
||||
if return_failed_reqs is False:
|
||||
return []
|
||||
else:
|
||||
return [], []
|
||||
|
||||
polls = poll_and_all_reduce(
|
||||
[req.disagg_kv_sender for req in self.queue], self.gloo_group
|
||||
)
|
||||
|
||||
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
||||
|
||||
if rids_to_check is not None:
|
||||
# if req not in reqs_info_to_check, skip
|
||||
if req.rid not in rids_to_check:
|
||||
continue
|
||||
# Either waiting for input or failed
|
||||
assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed
|
||||
|
||||
if poll == KVPoll.Bootstrapping:
|
||||
continue
|
||||
elif poll == KVPoll.Failed:
|
||||
@@ -181,9 +227,10 @@ class PrefillBootstrapQueue:
|
||||
)
|
||||
self.scheduler.stream_output([req], req.return_logprob)
|
||||
indices_to_remove.add(i)
|
||||
failed_reqs.append(req)
|
||||
continue
|
||||
|
||||
# KV.WaitingForInput
|
||||
# KV.WaitingForInput - init here
|
||||
num_kv_indices = len(req.origin_input_ids)
|
||||
if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
|
||||
break
|
||||
@@ -192,9 +239,9 @@ class PrefillBootstrapQueue:
|
||||
self.req_to_metadata_buffer_idx_allocator.alloc()
|
||||
)
|
||||
assert req.metadata_buffer_index is not None
|
||||
|
||||
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
||||
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
||||
|
||||
bootstrapped_reqs.append(req)
|
||||
indices_to_remove.add(i)
|
||||
|
||||
@@ -202,7 +249,10 @@ class PrefillBootstrapQueue:
|
||||
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
||||
]
|
||||
|
||||
return bootstrapped_reqs
|
||||
if return_failed_reqs is False:
|
||||
return bootstrapped_reqs
|
||||
else:
|
||||
return bootstrapped_reqs, failed_reqs
|
||||
|
||||
|
||||
class SchedulerDisaggregationPrefillMixin:
|
||||
@@ -211,7 +261,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_normal_disagg_prefill(self: Scheduler):
|
||||
def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
|
||||
"""A normal scheduler loop for prefill worker in disaggregation mode."""
|
||||
|
||||
while True:
|
||||
@@ -229,7 +279,6 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
or self.server_args.enable_sp_layernorm
|
||||
):
|
||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
@@ -250,7 +299,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_overlap_disagg_prefill(self: Scheduler):
|
||||
def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
|
||||
self.result_queue = deque()
|
||||
|
||||
while True:
|
||||
@@ -268,9 +317,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
or self.server_args.enable_sp_layernorm
|
||||
):
|
||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
self.result_queue.append((batch.copy(), result))
|
||||
@@ -287,6 +334,9 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
|
||||
if self.last_batch:
|
||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||
tmp_batch.next_batch_sampling_info = (
|
||||
self.tp_worker.cur_sampling_info if batch else None
|
||||
)
|
||||
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
||||
|
||||
if len(self.disagg_prefill_inflight_queue) > 0:
|
||||
@@ -309,7 +359,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
launch_done: Optional[threading.Event] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
||||
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
||||
Adapted from process_batch_result_prefill
|
||||
"""
|
||||
(
|
||||
@@ -325,7 +375,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
)
|
||||
|
||||
logprob_pt = 0
|
||||
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
||||
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
||||
if self.enable_overlap:
|
||||
# wait
|
||||
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
|
||||
@@ -397,11 +447,15 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# We need to remove the sync in the following function for overlap schedule.
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
|
||||
def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
|
||||
def process_disagg_prefill_inflight_queue(
|
||||
self: Scheduler, rids_to_check: Optional[List[str]] = None
|
||||
) -> List[Req]:
|
||||
"""
|
||||
Poll the requests in the middle of transfer. If done, return the request.
|
||||
rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
|
||||
"""
|
||||
assert len(self.disagg_prefill_inflight_queue) > 0
|
||||
if len(self.disagg_prefill_inflight_queue) == 0:
|
||||
return []
|
||||
|
||||
done_reqs = []
|
||||
|
||||
@@ -413,6 +467,14 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
undone_reqs: List[Req] = []
|
||||
# Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
|
||||
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
|
||||
|
||||
if rids_to_check is not None:
|
||||
if req.rid not in rids_to_check:
|
||||
undone_reqs.append(req)
|
||||
continue
|
||||
|
||||
assert poll == KVPoll.Success or poll == KVPoll.Failed
|
||||
|
||||
if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
|
||||
undone_reqs.append(req)
|
||||
elif poll == KVPoll.Success: # transfer done
|
||||
@@ -434,11 +496,8 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
)
|
||||
done_reqs.append(req)
|
||||
|
||||
for req in done_reqs:
|
||||
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
|
||||
req.metadata_buffer_index
|
||||
)
|
||||
else:
|
||||
assert False, f"Unexpected polling state {poll=}"
|
||||
|
||||
# Stream requests which have finished transfer
|
||||
self.stream_output(
|
||||
@@ -446,9 +505,32 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
any(req.return_logprob for req in done_reqs),
|
||||
None,
|
||||
)
|
||||
for req in done_reqs:
|
||||
req: Req
|
||||
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
|
||||
req.metadata_buffer_index = -1
|
||||
|
||||
self.disagg_prefill_inflight_queue = undone_reqs
|
||||
|
||||
return done_reqs
|
||||
|
||||
def get_transferred_rids(self: Scheduler) -> List[str]:
|
||||
"""
|
||||
Used by PP, get the transferred rids but **do not pop**
|
||||
"""
|
||||
polls = poll_and_all_reduce(
|
||||
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
|
||||
self.tp_worker.get_tp_group().cpu_group,
|
||||
)
|
||||
|
||||
transferred_rids: List[str] = []
|
||||
|
||||
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
|
||||
if poll == KVPoll.Success or poll == KVPoll.Failed:
|
||||
transferred_rids.append(req.rid)
|
||||
|
||||
return transferred_rids
|
||||
|
||||
def process_prefill_chunk(self: Scheduler) -> None:
|
||||
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
||||
if self.chunked_req:
|
||||
|
||||
@@ -14,15 +14,15 @@ import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.utils import get_ip, get_local_ip_by_remote
|
||||
from sglang.srt.utils import get_ip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
|
||||
FakeBootstrapHost = "2.2.2.2"
|
||||
|
||||
# env var for testing failure, convert to float explicitly
|
||||
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
|
||||
#########################
|
||||
# Constants & Enums
|
||||
#########################
|
||||
FAKE_BOOTSTRAP_HOST = "2.2.2.2"
|
||||
|
||||
|
||||
class DisaggregationMode(Enum):
|
||||
@@ -31,6 +31,14 @@ class DisaggregationMode(Enum):
|
||||
DECODE = "decode"
|
||||
|
||||
|
||||
#########################
|
||||
# Synchronization
|
||||
#########################
|
||||
|
||||
# env var for testing failure, convert to float explicitly
|
||||
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
|
||||
|
||||
|
||||
def poll_and_all_reduce(pollers, gloo_group):
|
||||
# at a certain prob, the poll is failed to simulate failure
|
||||
if FAILURE_PROB > 0:
|
||||
@@ -47,6 +55,11 @@ def poll_and_all_reduce(pollers, gloo_group):
|
||||
return tensor_to_reduce.tolist()
|
||||
|
||||
|
||||
#########################
|
||||
# Metadata Buffers
|
||||
#########################
|
||||
|
||||
|
||||
class ReqToMetadataIdxAllocator:
|
||||
"""A memory pool that maps a request to its first output token location."""
|
||||
|
||||
@@ -70,138 +83,6 @@ class ReqToMetadataIdxAllocator:
|
||||
self.free_slots.append(free_index)
|
||||
|
||||
|
||||
class TransferBackend(Enum):
|
||||
MOONCAKE = "mooncake"
|
||||
NIXL = "nixl"
|
||||
FAKE = "fake"
|
||||
|
||||
|
||||
class KVClassType(Enum):
|
||||
MANAGER = "manager"
|
||||
SENDER = "sender"
|
||||
RECEIVER = "receiver"
|
||||
BOOTSTRAP_SERVER = "bootstrap_server"
|
||||
|
||||
|
||||
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||
|
||||
if transfer_backend == TransferBackend.MOONCAKE:
|
||||
from sglang.srt.disaggregation.mooncake import (
|
||||
MooncakeKVBootstrapServer,
|
||||
MooncakeKVManager,
|
||||
MooncakeKVReceiver,
|
||||
MooncakeKVSender,
|
||||
)
|
||||
|
||||
class_mapping = {
|
||||
KVClassType.MANAGER: MooncakeKVManager,
|
||||
KVClassType.SENDER: MooncakeKVSender,
|
||||
KVClassType.RECEIVER: (MooncakeKVReceiver),
|
||||
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
if transfer_backend == TransferBackend.NIXL:
|
||||
from sglang.srt.disaggregation.nixl import (
|
||||
NixlKVBootstrapServer,
|
||||
NixlKVManager,
|
||||
NixlKVReceiver,
|
||||
NixlKVSender,
|
||||
)
|
||||
|
||||
class_mapping = {
|
||||
KVClassType.MANAGER: NixlKVManager,
|
||||
KVClassType.SENDER: NixlKVSender,
|
||||
KVClassType.RECEIVER: (NixlKVReceiver),
|
||||
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
if transfer_backend == TransferBackend.FAKE:
|
||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||
|
||||
class_mapping = {
|
||||
KVClassType.SENDER: FakeKVSender,
|
||||
KVClassType.RECEIVER: (FakeKVReceiver),
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
|
||||
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||
|
||||
|
||||
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
||||
# 1. The page is guaranteed to be full except the last page.
|
||||
# 2. page index = kv_index // page_size
|
||||
# The return vector is kv_indices[::page_size] // page_size
|
||||
if page_size == 1: # shortcut
|
||||
return kv_indices
|
||||
|
||||
return kv_indices[::page_size] // page_size
|
||||
|
||||
|
||||
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||
# ceil(num_kv_indices / page_size)
|
||||
return (num_kv_indices + page_size - 1) // page_size
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PDRegistryRequest:
|
||||
"""A request to register a machine itself to the LB."""
|
||||
|
||||
mode: str
|
||||
registry_url: str
|
||||
bootstrap_port: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mode == "prefill" and self.bootstrap_port is None:
|
||||
raise ValueError("Bootstrap port must be set in PREFILL mode.")
|
||||
elif self.mode == "decode" and self.bootstrap_port is not None:
|
||||
raise ValueError("Bootstrap port must not be set in DECODE mode.")
|
||||
elif self.mode not in ["prefill", "decode"]:
|
||||
raise ValueError(
|
||||
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
|
||||
)
|
||||
|
||||
|
||||
def register_disaggregation_server(
|
||||
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
|
||||
):
|
||||
boostrap_port = bootstrap_port if mode == "prefill" else None
|
||||
registry_request = PDRegistryRequest(
|
||||
mode=mode,
|
||||
registry_url=f"http://{get_ip()}:{server_port}",
|
||||
bootstrap_port=boostrap_port,
|
||||
)
|
||||
res = requests.post(
|
||||
f"{pdlb_url}/register",
|
||||
json=dataclasses.asdict(registry_request),
|
||||
)
|
||||
if res.status_code != 200:
|
||||
warnings.warn(
|
||||
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
||||
)
|
||||
|
||||
|
||||
def is_mla_backend(target_kv_pool) -> bool:
|
||||
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
||||
|
||||
return isinstance(target_kv_pool, MLATokenToKVPool)
|
||||
|
||||
|
||||
def prepare_abort(req: Req, error_message: str, status_code=None):
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT
|
||||
|
||||
# populate finish metadata and stream output
|
||||
req.finished_reason = FINISH_ABORT(error_message, status_code)
|
||||
|
||||
if req.return_logprob:
|
||||
req.input_token_logprobs_val = []
|
||||
req.input_token_logprobs_idx = []
|
||||
req.input_top_logprobs_val = []
|
||||
req.input_top_logprobs_idx = []
|
||||
req.input_token_ids_logprobs_val = []
|
||||
req.input_token_ids_logprobs_idx = []
|
||||
|
||||
|
||||
class MetadataBuffers:
|
||||
def __init__(self, size: int, max_top_logprobs_num: int = 128):
|
||||
# TODO: abort top_logprobs_num > 128 in PD
|
||||
@@ -282,37 +163,160 @@ class MetadataBuffers:
|
||||
)
|
||||
|
||||
|
||||
class FastQueue:
|
||||
def __init__(self):
|
||||
self._buf = deque()
|
||||
self._cond = threading.Condition()
|
||||
|
||||
def put(self, item):
|
||||
with self._cond:
|
||||
self._buf.append(item)
|
||||
# wake up a thread of wait()
|
||||
self._cond.notify()
|
||||
|
||||
def get(self):
|
||||
with self._cond:
|
||||
# if queue is empty ,block until is notified()
|
||||
while not self._buf:
|
||||
self._cond.wait()
|
||||
return self._buf.popleft()
|
||||
#########################
|
||||
# Transfer Backend
|
||||
#########################
|
||||
|
||||
|
||||
def group_concurrent_contiguous(
|
||||
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
||||
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
||||
"""Vectorised NumPy implementation."""
|
||||
if src_indices.size == 0:
|
||||
return [], []
|
||||
class TransferBackend(Enum):
|
||||
MOONCAKE = "mooncake"
|
||||
NIXL = "nixl"
|
||||
FAKE = "fake"
|
||||
|
||||
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
|
||||
src_groups = np.split(src_indices, brk)
|
||||
dst_groups = np.split(dst_indices, brk)
|
||||
|
||||
src_groups = [g.tolist() for g in src_groups]
|
||||
dst_groups = [g.tolist() for g in dst_groups]
|
||||
class KVClassType(Enum):
|
||||
KVARGS = "kvargs"
|
||||
MANAGER = "manager"
|
||||
SENDER = "sender"
|
||||
RECEIVER = "receiver"
|
||||
BOOTSTRAP_SERVER = "bootstrap_server"
|
||||
|
||||
return src_groups, dst_groups
|
||||
|
||||
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||
|
||||
if transfer_backend == TransferBackend.MOONCAKE:
|
||||
from sglang.srt.disaggregation.base import KVArgs
|
||||
from sglang.srt.disaggregation.mooncake import (
|
||||
MooncakeKVBootstrapServer,
|
||||
MooncakeKVManager,
|
||||
MooncakeKVReceiver,
|
||||
MooncakeKVSender,
|
||||
)
|
||||
|
||||
class_mapping = {
|
||||
KVClassType.KVARGS: KVArgs,
|
||||
KVClassType.MANAGER: MooncakeKVManager,
|
||||
KVClassType.SENDER: MooncakeKVSender,
|
||||
KVClassType.RECEIVER: (MooncakeKVReceiver),
|
||||
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
if transfer_backend == TransferBackend.NIXL:
|
||||
from sglang.srt.disaggregation.base import KVArgs
|
||||
from sglang.srt.disaggregation.nixl import (
|
||||
NixlKVBootstrapServer,
|
||||
NixlKVManager,
|
||||
NixlKVReceiver,
|
||||
NixlKVSender,
|
||||
)
|
||||
|
||||
class_mapping = {
|
||||
KVClassType.KVARGS: KVArgs,
|
||||
KVClassType.MANAGER: NixlKVManager,
|
||||
KVClassType.SENDER: NixlKVSender,
|
||||
KVClassType.RECEIVER: (NixlKVReceiver),
|
||||
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
if transfer_backend == TransferBackend.FAKE:
|
||||
from sglang.srt.disaggregation.base import KVArgs
|
||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||
|
||||
class_mapping = {
|
||||
KVClassType.KVARGS: KVArgs,
|
||||
KVClassType.SENDER: FakeKVSender,
|
||||
KVClassType.RECEIVER: (FakeKVReceiver),
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
|
||||
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||
|
||||
|
||||
#########################
|
||||
# KV Pages
|
||||
#########################
|
||||
|
||||
|
||||
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
|
||||
# 1. The page is guaranteed to be full except the last page.
|
||||
# 2. page index = kv_index // page_size
|
||||
# The return vector is kv_indices[::page_size] // page_size
|
||||
if page_size == 1: # shortcut
|
||||
return kv_indices
|
||||
|
||||
return kv_indices[::page_size] // page_size
|
||||
|
||||
|
||||
def kv_to_page_num(num_kv_indices: int, page_size: int):
|
||||
# ceil(num_kv_indices / page_size)
|
||||
return (num_kv_indices + page_size - 1) // page_size
|
||||
|
||||
|
||||
#########################
|
||||
# PDLB Registry
|
||||
#########################
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PDRegistryRequest:
|
||||
"""A request to register a machine itself to the LB."""
|
||||
|
||||
mode: str
|
||||
registry_url: str
|
||||
bootstrap_port: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mode == "prefill" and self.bootstrap_port is None:
|
||||
raise ValueError("Bootstrap port must be set in PREFILL mode.")
|
||||
elif self.mode == "decode" and self.bootstrap_port is not None:
|
||||
raise ValueError("Bootstrap port must not be set in DECODE mode.")
|
||||
elif self.mode not in ["prefill", "decode"]:
|
||||
raise ValueError(
|
||||
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
|
||||
)
|
||||
|
||||
|
||||
def register_disaggregation_server(
|
||||
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
|
||||
):
|
||||
boostrap_port = bootstrap_port if mode == "prefill" else None
|
||||
registry_request = PDRegistryRequest(
|
||||
mode=mode,
|
||||
registry_url=f"http://{get_ip()}:{server_port}",
|
||||
bootstrap_port=boostrap_port,
|
||||
)
|
||||
res = requests.post(
|
||||
f"{pdlb_url}/register",
|
||||
json=dataclasses.asdict(registry_request),
|
||||
)
|
||||
if res.status_code != 200:
|
||||
warnings.warn(
|
||||
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
||||
)
|
||||
|
||||
|
||||
#########################
|
||||
# Misc
|
||||
#########################
|
||||
|
||||
|
||||
def is_mla_backend(target_kv_pool) -> bool:
    """Return True if ``target_kv_pool`` is an ``MLATokenToKVPool`` instance."""
    # Lazy import — presumably to avoid an import cycle with the memory-pool
    # module; confirm before hoisting to module level.
    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool

    return isinstance(target_kv_pool, MLATokenToKVPool)
|
||||
|
||||
|
||||
def prepare_abort(req: Req, error_message: str, status_code=None):
    """Mark ``req`` as finished with an abort reason.

    Args:
        req: The request to abort.
        error_message: Human-readable reason recorded on the request.
        status_code: Optional status code stored alongside the abort reason.
    """
    from sglang.srt.managers.schedule_batch import FINISH_ABORT

    # populate finish metadata and stream output
    req.finished_reason = FINISH_ABORT(error_message, status_code)

    # When logprobs were requested, reset every input-logprob buffer to empty
    # lists so downstream output packing reads consistent (empty) fields.
    if req.return_logprob:
        req.input_token_logprobs_val = []
        req.input_token_logprobs_idx = []
        req.input_top_logprobs_val = []
        req.input_top_logprobs_idx = []
        req.input_token_ids_logprobs_val = []
        req.input_token_ids_logprobs_idx = []
|
||||
|
||||
@@ -43,7 +43,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
FakeBootstrapHost,
|
||||
FAKE_BOOTSTRAP_HOST,
|
||||
register_disaggregation_server,
|
||||
)
|
||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||
@@ -878,7 +878,7 @@ def _wait_and_warmup(
|
||||
"max_new_tokens": 8,
|
||||
"ignore_eos": True,
|
||||
},
|
||||
"bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
|
||||
"bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
|
||||
# This is a hack to ensure fake transfer is enabled during prefill warmup
|
||||
# ensure each dp rank has a unique bootstrap_room during prefill warmup
|
||||
"bootstrap_room": [
|
||||
|
||||
@@ -619,7 +619,7 @@ class Scheduler(
|
||||
self.disaggregation_mode == DisaggregationMode.DECODE
|
||||
): # *2 for the headroom.
|
||||
buffer_size = (self.req_to_token_pool.size) * 2
|
||||
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||
buffer_size
|
||||
)
|
||||
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
|
||||
@@ -627,7 +627,7 @@ class Scheduler(
|
||||
# The decode requests polling kv cache
|
||||
self.disagg_decode_transfer_queue = DecodeTransferQueue(
|
||||
gloo_group=self.attn_tp_cpu_group,
|
||||
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||
metadata_buffers=self.disagg_metadata_buffers,
|
||||
scheduler=self,
|
||||
tree_cache=self.tree_cache,
|
||||
@@ -642,7 +642,7 @@ class Scheduler(
|
||||
if self.draft_worker is None
|
||||
else self.draft_worker.model_runner.token_to_kv_pool
|
||||
),
|
||||
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||
metadata_buffers=self.disagg_metadata_buffers,
|
||||
scheduler=self,
|
||||
transfer_queue=self.disagg_decode_transfer_queue,
|
||||
@@ -660,7 +660,7 @@ class Scheduler(
|
||||
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# *2 for the headroom.
|
||||
buffer_size = self.max_running_requests * 2
|
||||
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||
self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||
buffer_size
|
||||
)
|
||||
self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
|
||||
@@ -672,14 +672,20 @@ class Scheduler(
|
||||
if self.draft_worker is None
|
||||
else self.draft_worker.model_runner.token_to_kv_pool
|
||||
),
|
||||
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||
metadata_buffers=self.disagg_metadata_buffers,
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
gpu_id=self.gpu_id,
|
||||
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||
gloo_group=self.attn_tp_cpu_group,
|
||||
transfer_backend=self.transfer_backend,
|
||||
max_total_num_tokens=self.max_total_num_tokens,
|
||||
decode_tp_size=self.server_args.disaggregation_decode_tp,
|
||||
decode_dp_size=self.server_args.disaggregation_decode_dp,
|
||||
scheduler=self,
|
||||
pp_rank=self.pp_rank,
|
||||
pp_size=self.pp_size,
|
||||
transfer_backend=self.transfer_backend,
|
||||
)
|
||||
# The prefill requests that are in the middle of kv sending
|
||||
self.disagg_prefill_inflight_queue: List[Req] = []
|
||||
@@ -1110,7 +1116,9 @@ class Scheduler(
|
||||
def _add_request_to_queue(self, req: Req):
|
||||
req.queue_time_start = time.perf_counter()
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
self.disagg_prefill_bootstrap_queue.add(req)
|
||||
self.disagg_prefill_bootstrap_queue.add(
|
||||
req, self.model_config.num_key_value_heads
|
||||
)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.disagg_decode_prealloc_queue.add(req)
|
||||
else:
|
||||
@@ -1118,7 +1126,9 @@ class Scheduler(
|
||||
|
||||
def _extend_requests_to_queue(self, reqs: List[Req]):
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
self.disagg_prefill_bootstrap_queue.extend(reqs)
|
||||
self.disagg_prefill_bootstrap_queue.extend(
|
||||
reqs, self.model_config.num_key_value_heads
|
||||
)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
# If this is a decode server, we put the request to the decode pending prealloc queue
|
||||
self.disagg_decode_prealloc_queue.extend(reqs)
|
||||
|
||||
@@ -227,6 +227,9 @@ class ServerArgs:
|
||||
disaggregation_mode: str = "null"
|
||||
disaggregation_transfer_backend: str = "mooncake"
|
||||
disaggregation_bootstrap_port: int = 8998
|
||||
disaggregation_decode_tp: Optional[int] = None
|
||||
disaggregation_decode_dp: Optional[int] = None
|
||||
disaggregation_prefill_pp: Optional[int] = 1
|
||||
disaggregation_ib_device: Optional[str] = None
|
||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||
pdlb_url: Optional[str] = None
|
||||
@@ -505,12 +508,27 @@ class ServerArgs:
|
||||
self.triton_attention_num_kv_splits = 16
|
||||
|
||||
# PD disaggregation
|
||||
if self.disaggregation_mode == "prefill":
|
||||
self.disable_cuda_graph = True
|
||||
logger.warning("Cuda graph is disabled for prefill server")
|
||||
elif self.disaggregation_mode == "decode":
|
||||
if self.disaggregation_mode == "decode":
|
||||
assert (
|
||||
self.disaggregation_decode_tp is None
|
||||
), "Cannot set --disaggregation-decode-tp for the decode engine."
|
||||
assert (
|
||||
self.disaggregation_decode_dp is None
|
||||
), "Cannot set --disaggregation-decode-dp for the decode engine."
|
||||
|
||||
self.disable_radix_cache = True
|
||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||
elif self.disaggregation_mode == "prefill":
|
||||
if self.disaggregation_decode_tp is None:
|
||||
self.disaggregation_decode_tp = self.tp_size
|
||||
if self.disaggregation_decode_dp is None:
|
||||
self.disaggregation_decode_dp = self.dp_size
|
||||
|
||||
self.disaggregation_prefill_pp = self.pp_size
|
||||
self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)
|
||||
|
||||
self.disable_cuda_graph = True
|
||||
logger.warning("Cuda graph is disabled for prefill server")
|
||||
|
||||
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
|
||||
"1" if self.enable_torch_compile else "0"
|
||||
@@ -520,6 +538,14 @@ class ServerArgs:
|
||||
"1" if self.disable_outlines_disk_cache else "0"
|
||||
)
|
||||
|
||||
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
|
||||
larger_tp = max(decode_tp, prefill_tp)
|
||||
smaller_tp = min(decode_tp, prefill_tp)
|
||||
assert larger_tp % smaller_tp == 0, (
|
||||
"Different tp size is supported only when one tp is multiple of the other. "
|
||||
f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
# Model and port args
|
||||
@@ -1512,6 +1538,24 @@ class ServerArgs:
|
||||
default=ServerArgs.disaggregation_bootstrap_port,
|
||||
help="Bootstrap server port on the prefill server. Default is 8998.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-decode-tp",
|
||||
type=int,
|
||||
default=ServerArgs.disaggregation_decode_tp,
|
||||
help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-decode-dp",
|
||||
type=int,
|
||||
default=ServerArgs.disaggregation_decode_dp,
|
||||
help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-prefill-pp",
|
||||
type=int,
|
||||
default=ServerArgs.disaggregation_prefill_pp,
|
||||
help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-ib-device",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user