[PD] Release initial code (#4654)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: Ying1123 <sqy1415@gmail.com> Co-authored-by: merrymercy <lianminzheng@gmail.com> Co-authored-by: makro Co-authored-by: dhou-xai
This commit is contained in:
81
python/sglang/srt/disaggregation/conn.py
Normal file
81
python/sglang/srt/disaggregation/conn.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KVArgs:
|
||||||
|
engine_rank: int
|
||||||
|
kv_data_ptrs: list[int]
|
||||||
|
kv_data_lens: list[int]
|
||||||
|
kv_item_lens: list[int]
|
||||||
|
aux_data_ptrs: list[int]
|
||||||
|
aux_data_lens: list[int]
|
||||||
|
aux_item_lens: list[int]
|
||||||
|
ib_device: str
|
||||||
|
|
||||||
|
|
||||||
|
class KVManager:
|
||||||
|
def __init__(self, args: KVArgs): ...
|
||||||
|
|
||||||
|
|
||||||
|
class KVPoll:
|
||||||
|
Failed = 0
|
||||||
|
Bootstrapping = 1
|
||||||
|
WaitingForInput = 2
|
||||||
|
Transferring = 3
|
||||||
|
Success = 4
|
||||||
|
|
||||||
|
|
||||||
|
class KVSender:
|
||||||
|
def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
|
||||||
|
self.has_sent = False
|
||||||
|
|
||||||
|
def init(self, num_kv_indices: int, aux_index: Optional[int] = None): ...
|
||||||
|
|
||||||
|
def send(self, kv_indices: npt.NDArray[np.int32]):
|
||||||
|
self.has_sent = True
|
||||||
|
|
||||||
|
def poll(self) -> KVPoll:
|
||||||
|
if self.has_sent is False:
|
||||||
|
# Assume handshake completed instantly
|
||||||
|
return KVPoll.WaitingForInput
|
||||||
|
else:
|
||||||
|
# Assume transfer completed instantly
|
||||||
|
return KVPoll.Success
|
||||||
|
|
||||||
|
def failure_exception(self):
|
||||||
|
raise Exception("Fake KVSender Exception")
|
||||||
|
|
||||||
|
|
||||||
|
class KVReceiver:
|
||||||
|
def __init__(
|
||||||
|
self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
|
||||||
|
):
|
||||||
|
self.has_init = False
|
||||||
|
|
||||||
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||||
|
self.has_init = True
|
||||||
|
|
||||||
|
def poll(self) -> KVPoll:
|
||||||
|
if self.has_init is False:
|
||||||
|
# Assume handshake completed instantly
|
||||||
|
return KVPoll.WaitingForInput
|
||||||
|
else:
|
||||||
|
# Assume transfer completed instantly
|
||||||
|
return KVPoll.Success
|
||||||
|
|
||||||
|
def failure_exception(self):
|
||||||
|
raise Exception("Fake KVReceiver Exception")
|
||||||
|
|
||||||
|
|
||||||
|
class KVBootstrapServer:
|
||||||
|
def __init__(self, port: int): ...
|
||||||
|
|
||||||
|
def poll(self) -> KVPoll: ...
|
||||||
495
python/sglang/srt/disaggregation/decode.py
Normal file
495
python/sglang/srt/disaggregation/decode.py
Normal file
@@ -0,0 +1,495 @@
|
|||||||
|
"""
|
||||||
|
Life cycle of a request in the decode server
|
||||||
|
|
||||||
|
1. PreallocQueue:
|
||||||
|
a. Initialize a receiver for each request
|
||||||
|
b. The request handshakes first, and pre-allocate kv once there is available kv.
|
||||||
|
c. Move the request to TransferQueue.
|
||||||
|
|
||||||
|
2. TransferQueue:
|
||||||
|
a. Poll the receiver to check the transfer state
|
||||||
|
b. If the transfer has finished, move the request to waiting queue
|
||||||
|
|
||||||
|
3. WaitingQueue:
|
||||||
|
a. Use the requests in the queue to construct a PrebuiltExtendBatch
|
||||||
|
b. Skip the prefill forward but only populate metadata
|
||||||
|
|
||||||
|
4. RunningBatch:
|
||||||
|
a. Merge the resolved PrebuiltExtendBatch into running batch to run decoding
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
|
||||||
|
from sglang.srt.disaggregation.utils import (
|
||||||
|
ReqToMetadataIdxAllocator,
|
||||||
|
poll_and_all_reduce,
|
||||||
|
)
|
||||||
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
|
from sglang.srt.managers.scheduler import Scheduler
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DecodeRequest:
|
||||||
|
req: Req
|
||||||
|
kv_receiver: KVReceiver
|
||||||
|
waiting_for_input: bool = False
|
||||||
|
metadata_buffer_index: int = -1
|
||||||
|
|
||||||
|
|
||||||
|
class DecodePreallocQueue:
|
||||||
|
"""
|
||||||
|
Store the requests that are preallocating.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
req_to_token_pool: ReqToTokenPool,
|
||||||
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
|
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||||
|
metadata_buffers: List[torch.Tensor],
|
||||||
|
aux_dtype: torch.dtype,
|
||||||
|
scheduler: Scheduler,
|
||||||
|
transfer_queue: DecodeTransferQueue,
|
||||||
|
tree_cache: BasePrefixCache,
|
||||||
|
gloo_group: ProcessGroup,
|
||||||
|
tp_rank: int,
|
||||||
|
tp_size: int,
|
||||||
|
bootstrap_port: int,
|
||||||
|
):
|
||||||
|
self.req_to_token_pool = req_to_token_pool
|
||||||
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
|
self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
|
||||||
|
self.aux_dtype = aux_dtype
|
||||||
|
self.metadata_buffers = metadata_buffers
|
||||||
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self.transfer_queue = transfer_queue
|
||||||
|
self.tree_cache = tree_cache # this is always a chunk cache
|
||||||
|
self.gloo_group = gloo_group
|
||||||
|
self.tp_rank = tp_rank
|
||||||
|
self.tp_size = tp_size
|
||||||
|
self.bootstrap_port = bootstrap_port
|
||||||
|
|
||||||
|
self.num_reserved_decode_tokens = 512
|
||||||
|
|
||||||
|
# Queue for requests pending pre-allocation
|
||||||
|
self.queue: List[DecodeRequest] = []
|
||||||
|
self.kv_manager = self._init_kv_manager()
|
||||||
|
|
||||||
|
def _init_kv_manager(self) -> KVManager:
|
||||||
|
kv_args = KVArgs()
|
||||||
|
kv_args.engine_rank = self.tp_rank
|
||||||
|
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||||
|
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_args.kv_data_ptrs = kv_data_ptrs
|
||||||
|
kv_args.kv_data_lens = kv_data_lens
|
||||||
|
kv_args.kv_item_lens = kv_item_lens
|
||||||
|
|
||||||
|
kv_args.aux_data_ptrs = [
|
||||||
|
output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
|
||||||
|
]
|
||||||
|
kv_args.aux_data_lens = [
|
||||||
|
metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
|
||||||
|
]
|
||||||
|
kv_args.aux_item_lens = [
|
||||||
|
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
||||||
|
]
|
||||||
|
kv_args.ib_device = "mock-ib-device"
|
||||||
|
kv_manager = KVManager(kv_args)
|
||||||
|
return kv_manager
|
||||||
|
|
||||||
|
def add(self, req: Req) -> None:
|
||||||
|
"""Add a request to the pending queue."""
|
||||||
|
|
||||||
|
kv_receiver = KVReceiver(
|
||||||
|
mgr=self.kv_manager,
|
||||||
|
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
||||||
|
bootstrap_room=req.bootstrap_room,
|
||||||
|
)
|
||||||
|
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
|
||||||
|
|
||||||
|
def extend(self, reqs: List[Req]) -> None:
|
||||||
|
"""Add a request to the pending queue."""
|
||||||
|
for req in reqs:
|
||||||
|
self.add(req)
|
||||||
|
|
||||||
|
def _update_handshake_waiters(self) -> None:
|
||||||
|
if not self.queue:
|
||||||
|
return
|
||||||
|
|
||||||
|
if all(decode_req.waiting_for_input for decode_req in self.queue):
|
||||||
|
return
|
||||||
|
|
||||||
|
polls = poll_and_all_reduce(
|
||||||
|
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
||||||
|
if poll == KVPoll.Bootstrapping:
|
||||||
|
pass
|
||||||
|
elif poll == KVPoll.WaitingForInput:
|
||||||
|
decode_req.waiting_for_input = True
|
||||||
|
elif poll == KVPoll.Failed:
|
||||||
|
raise Exception("Handshake failed")
|
||||||
|
|
||||||
|
def pop_preallocated(self) -> List[DecodeRequest]:
|
||||||
|
"""Pop the preallocated requests from the pending queue (FIFO)."""
|
||||||
|
self._update_handshake_waiters()
|
||||||
|
|
||||||
|
preallocated_reqs = []
|
||||||
|
indices_to_remove = set()
|
||||||
|
allocatable_tokens = self._allocatable_tokens(count_retracted=True)
|
||||||
|
|
||||||
|
for i, decode_req in enumerate(self.queue):
|
||||||
|
if not decode_req.waiting_for_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self.req_to_token_pool.available_size() <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
required_tokens_for_request = (
|
||||||
|
len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
if required_tokens_for_request > allocatable_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
allocatable_tokens -= required_tokens_for_request
|
||||||
|
self._pre_alloc(decode_req.req)
|
||||||
|
|
||||||
|
kv_indices = (
|
||||||
|
self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
|
||||||
|
: len(decode_req.req.origin_input_ids)
|
||||||
|
]
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
|
||||||
|
decode_req.metadata_buffer_index = (
|
||||||
|
self.req_to_metadata_buffer_idx_allocator.alloc()
|
||||||
|
)
|
||||||
|
assert decode_req.metadata_buffer_index is not None
|
||||||
|
decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
|
||||||
|
preallocated_reqs.append(decode_req)
|
||||||
|
indices_to_remove.add(i)
|
||||||
|
|
||||||
|
self.queue = [
|
||||||
|
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
||||||
|
]
|
||||||
|
|
||||||
|
return preallocated_reqs
|
||||||
|
|
||||||
|
def _allocatable_tokens(self) -> int:
|
||||||
|
allocatable_tokens = (
|
||||||
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
|
- self.num_reserved_decode_tokens
|
||||||
|
* (
|
||||||
|
len(self.scheduler.running_batch.reqs)
|
||||||
|
+ len(self.transfer_queue.queue)
|
||||||
|
+ len(self.scheduler.waiting_queue)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
|
||||||
|
# the extend batch is not in any queue, so we need to explicitly add the tokens slots here
|
||||||
|
if (
|
||||||
|
self.scheduler.last_batch
|
||||||
|
and self.scheduler.last_batch.forward_mode.is_extend()
|
||||||
|
):
|
||||||
|
allocatable_tokens -= self.num_reserved_decode_tokens * len(
|
||||||
|
self.scheduler.last_batch.reqs
|
||||||
|
)
|
||||||
|
|
||||||
|
return allocatable_tokens
|
||||||
|
|
||||||
|
def _pre_alloc(self, req: Req) -> torch.Tensor:
|
||||||
|
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
|
||||||
|
req_pool_indices = self.req_to_token_pool.alloc(1)
|
||||||
|
|
||||||
|
assert req_pool_indices is not None
|
||||||
|
|
||||||
|
req.req_pool_idx = req_pool_indices[0]
|
||||||
|
kv_loc = self.token_to_kv_pool_allocator.alloc(
|
||||||
|
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert kv_loc is not None
|
||||||
|
|
||||||
|
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
|
||||||
|
|
||||||
|
# populate metadata
|
||||||
|
req.fill_ids = req.origin_input_ids + req.output_ids
|
||||||
|
req.extend_input_len = len(req.origin_input_ids)
|
||||||
|
|
||||||
|
return kv_loc
|
||||||
|
|
||||||
|
|
||||||
|
class DecodeTransferQueue:
|
||||||
|
"""
|
||||||
|
Store the requests that is polling kv
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gloo_group: ProcessGroup,
|
||||||
|
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||||
|
metadata_buffers: torch.Tensor,
|
||||||
|
):
|
||||||
|
self.queue: List[DecodeRequest] = []
|
||||||
|
self.gloo_group = gloo_group
|
||||||
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||||
|
self.metadata_buffers = metadata_buffers
|
||||||
|
|
||||||
|
def add(self, req_conn: DecodeRequest) -> None:
|
||||||
|
self.queue.append(req_conn)
|
||||||
|
|
||||||
|
def extend(self, req_conns) -> None:
|
||||||
|
self.queue.extend(req_conns)
|
||||||
|
|
||||||
|
def pop_transferred(self) -> List[Req]:
|
||||||
|
if not self.queue:
|
||||||
|
return []
|
||||||
|
|
||||||
|
polls = poll_and_all_reduce(
|
||||||
|
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
|
||||||
|
)
|
||||||
|
|
||||||
|
transferred_reqs = []
|
||||||
|
indices_to_remove = set()
|
||||||
|
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
||||||
|
if poll == KVPoll.Failed:
|
||||||
|
raise Exception("Transfer failed")
|
||||||
|
elif poll == KVPoll.Success:
|
||||||
|
# pop and push it to waiting queue
|
||||||
|
idx = decode_req.metadata_buffer_index
|
||||||
|
assert len(decode_req.req.output_ids) == 0
|
||||||
|
output_id_buffer = self.metadata_buffers[0]
|
||||||
|
# the last dimension is padded by the same values.
|
||||||
|
output_id = output_id_buffer[idx][0].item()
|
||||||
|
assert len(decode_req.req.output_ids) == 0
|
||||||
|
assert decode_req.req.transferred_output_id is None
|
||||||
|
decode_req.req.transferred_output_id = output_id
|
||||||
|
transferred_reqs.append(decode_req.req)
|
||||||
|
indices_to_remove.add(i)
|
||||||
|
elif poll in [
|
||||||
|
KVPoll.Bootstrapping,
|
||||||
|
KVPoll.WaitingForInput,
|
||||||
|
KVPoll.Transferring,
|
||||||
|
]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected poll case: {poll}")
|
||||||
|
|
||||||
|
for i in indices_to_remove:
|
||||||
|
idx = self.queue[i].metadata_buffer_index
|
||||||
|
assert idx != -1
|
||||||
|
self.req_to_metadata_buffer_idx_allocator.free(idx)
|
||||||
|
|
||||||
|
self.queue = [
|
||||||
|
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
||||||
|
]
|
||||||
|
|
||||||
|
return transferred_reqs
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleBatchDisaggregationDecodeMixin:
|
||||||
|
|
||||||
|
def prepare_for_prebuilt_extend(self: ScheduleBatch):
|
||||||
|
"""
|
||||||
|
Prepare a prebuilt extend by populate metadata
|
||||||
|
Adapted from .prepare_for_extend().
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.forward_mode = ForwardMode.EXTEND
|
||||||
|
reqs = self.reqs
|
||||||
|
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||||
|
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||||
|
seq_lens = []
|
||||||
|
pre_lens = []
|
||||||
|
req_pool_indices = []
|
||||||
|
|
||||||
|
# Pre-calculate total size
|
||||||
|
total_size = sum(req.extend_input_len for req in reqs)
|
||||||
|
out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
|
||||||
|
|
||||||
|
# Fill the tensor in one pass
|
||||||
|
offset = 0
|
||||||
|
for i, req in enumerate(reqs):
|
||||||
|
req_pool_indices.append(req.req_pool_idx)
|
||||||
|
|
||||||
|
chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
|
: req.extend_input_len
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
offset + req.extend_input_len <= total_size
|
||||||
|
), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
|
||||||
|
out_cache_loc[offset : offset + req.extend_input_len] = chunk
|
||||||
|
offset += req.extend_input_len
|
||||||
|
|
||||||
|
pre_len = len(req.prefix_indices)
|
||||||
|
seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
|
||||||
|
seq_lens.append(seq_len)
|
||||||
|
if len(req.output_ids) == 0:
|
||||||
|
assert (
|
||||||
|
seq_len - pre_len == req.extend_input_len
|
||||||
|
), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
|
||||||
|
|
||||||
|
req.cached_tokens += pre_len - req.already_computed
|
||||||
|
req.already_computed = seq_len
|
||||||
|
req.is_retracted = False
|
||||||
|
pre_lens.append(pre_len)
|
||||||
|
req.extend_logprob_start_len = 0
|
||||||
|
|
||||||
|
extend_input_logprob_token_ids = None
|
||||||
|
|
||||||
|
# Set fields
|
||||||
|
self.input_ids = torch.tensor(
|
||||||
|
sum(input_ids, []), dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
|
self.req_pool_indices = torch.tensor(
|
||||||
|
req_pool_indices, dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
|
||||||
|
self.out_cache_loc = out_cache_loc
|
||||||
|
self.seq_lens_sum = sum(seq_lens)
|
||||||
|
self.extend_num_tokens = extend_num_tokens
|
||||||
|
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||||
|
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||||
|
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||||
|
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
||||||
|
|
||||||
|
# Build sampling info
|
||||||
|
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||||
|
self,
|
||||||
|
self.model_config.vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_prebuilt_extend(
|
||||||
|
self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
|
||||||
|
):
|
||||||
|
"""Assign the buffered last input id to schedule batch"""
|
||||||
|
self.output_ids = []
|
||||||
|
for req in self.reqs:
|
||||||
|
if req.output_ids and len(req.output_ids) > 0:
|
||||||
|
# resumed retracted req
|
||||||
|
self.output_ids.append(req.output_ids[-1])
|
||||||
|
else:
|
||||||
|
assert req.transferred_output_id is not None
|
||||||
|
req.output_ids.append(req.transferred_output_id)
|
||||||
|
self.output_ids.append(req.transferred_output_id)
|
||||||
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerDisaggregationDecodeMixin:
|
||||||
|
|
||||||
|
def get_next_disagg_decode_batch_to_run(
|
||||||
|
self: Scheduler,
|
||||||
|
) -> Optional[Tuple[ScheduleBatch, bool]]:
|
||||||
|
"""Create fake completed prefill if possible and merge with running batch"""
|
||||||
|
# Merge the prefill batch into the running batch
|
||||||
|
last_batch = self.last_batch
|
||||||
|
if last_batch and last_batch.forward_mode.is_extend():
|
||||||
|
# chunked prefill doesn't happen in decode instance.
|
||||||
|
assert self.chunked_req is None
|
||||||
|
# Filter finished batches.
|
||||||
|
last_batch.filter_batch()
|
||||||
|
if not last_batch.is_empty():
|
||||||
|
if self.running_batch.is_empty():
|
||||||
|
self.running_batch = last_batch
|
||||||
|
else:
|
||||||
|
# merge running_batch with prefill batch
|
||||||
|
self.running_batch.merge_batch(last_batch)
|
||||||
|
|
||||||
|
new_prebuilt_batch = self.get_new_prebuilt_batch()
|
||||||
|
|
||||||
|
ret: Optional[ScheduleBatch] = None
|
||||||
|
if new_prebuilt_batch:
|
||||||
|
ret = new_prebuilt_batch
|
||||||
|
else:
|
||||||
|
if self.running_batch.is_empty():
|
||||||
|
ret = None
|
||||||
|
else:
|
||||||
|
self.running_batch = self.update_running_batch(self.running_batch)
|
||||||
|
ret = self.running_batch if not self.running_batch.is_empty() else None
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
|
||||||
|
"""Create a schedulebatch for fake completed prefill"""
|
||||||
|
if len(self.waiting_queue) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
curr_batch_size = self.running_batch.batch_size()
|
||||||
|
|
||||||
|
batch_size = min(self.req_to_token_pool.size, self.max_running_requests)
|
||||||
|
|
||||||
|
num_not_used_batch = batch_size - curr_batch_size
|
||||||
|
|
||||||
|
# pop req from waiting queue
|
||||||
|
can_run_list: List[Req] = []
|
||||||
|
waiting_queue: List[Req] = []
|
||||||
|
|
||||||
|
for i in range(len(self.waiting_queue)):
|
||||||
|
req = self.waiting_queue[i]
|
||||||
|
# we can only add at least `num_not_used_batch` new batch to the running queue
|
||||||
|
if i < num_not_used_batch:
|
||||||
|
can_run_list.append(req)
|
||||||
|
req.init_next_round_input(self.tree_cache)
|
||||||
|
else:
|
||||||
|
waiting_queue.append(req)
|
||||||
|
|
||||||
|
self.waiting_queue = waiting_queue
|
||||||
|
if len(can_run_list) == 0:
|
||||||
|
return None
|
||||||
|
# local import to avoid circular import
|
||||||
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
|
|
||||||
|
# construct a schedule batch with those requests and mark as decode
|
||||||
|
new_batch = ScheduleBatch.init_new(
|
||||||
|
can_run_list,
|
||||||
|
self.req_to_token_pool,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
|
self.tree_cache,
|
||||||
|
self.model_config,
|
||||||
|
self.enable_overlap,
|
||||||
|
self.spec_algorithm,
|
||||||
|
self.server_args.enable_custom_logit_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# construct fake completed prefill
|
||||||
|
new_batch.prepare_for_prebuilt_extend()
|
||||||
|
new_batch.process_prebuilt_extend(self.server_args, self.model_config)
|
||||||
|
|
||||||
|
return new_batch
|
||||||
|
|
||||||
|
def process_decode_queue(self: Scheduler):
|
||||||
|
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
||||||
|
self.disagg_decode_transfer_queue.extend(req_conns)
|
||||||
|
alloc_reqs = (
|
||||||
|
self.disagg_decode_transfer_queue.pop_transferred()
|
||||||
|
) # the requests which kv has arrived
|
||||||
|
self.waiting_queue.extend(alloc_reqs)
|
||||||
285
python/sglang/srt/disaggregation/mini_lb.py
Normal file
285
python/sglang/srt/disaggregation/mini_lb.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
"""
|
||||||
|
Minimal HTTP load balancer for prefill and decode servers for testing purpose.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import urllib
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import orjson
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
|
|
||||||
|
class MiniLoadBalancer:
|
||||||
|
def __init__(self, prefill_servers, decode_servers):
|
||||||
|
self.prefill_servers = prefill_servers
|
||||||
|
self.decode_servers = decode_servers
|
||||||
|
|
||||||
|
def select_pair(self):
|
||||||
|
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
|
||||||
|
|
||||||
|
async def generate_request(self, request_data):
|
||||||
|
prefill_server, decode_server = self.select_pair()
|
||||||
|
|
||||||
|
# Parse and transform prefill_server
|
||||||
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||||
|
hostname = parsed_url.hostname
|
||||||
|
bootstrap_host = f"{hostname}"
|
||||||
|
|
||||||
|
modified_request = request_data.copy()
|
||||||
|
modified_request.update(
|
||||||
|
{
|
||||||
|
"bootstrap_host": bootstrap_host,
|
||||||
|
"bootstrap_room": random.randint(0, 2**63 - 1),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Create the tasks
|
||||||
|
tasks = [
|
||||||
|
session.post(f"{prefill_server}/generate", json=modified_request),
|
||||||
|
session.post(f"{decode_server}/generate", json=modified_request),
|
||||||
|
]
|
||||||
|
|
||||||
|
prefill_response = None
|
||||||
|
decode_response = None
|
||||||
|
|
||||||
|
# Process responses as they arrive
|
||||||
|
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||||
|
response = await response
|
||||||
|
# Check if this is the prefill or decode response based on order created
|
||||||
|
if i == 0: # First completed task
|
||||||
|
if str(response.url).startswith(prefill_server):
|
||||||
|
prefill_response = response
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=response.status,
|
||||||
|
detail=f"Prefill server error: Status {response.status} Details: {await response.text()}",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
decode_response = response
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=response.status,
|
||||||
|
detail=f"Decode server error: Status {response.status} Details: {await response.text()}",
|
||||||
|
)
|
||||||
|
else: # Second completed task
|
||||||
|
if str(response.url).startswith(prefill_server):
|
||||||
|
prefill_response = response
|
||||||
|
else:
|
||||||
|
decode_response = response
|
||||||
|
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=response.status,
|
||||||
|
detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return await decode_response.json()
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
load_balancer = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health_generate")
|
||||||
|
async def health_check():
|
||||||
|
prefill_servers, decode_servers = (
|
||||||
|
load_balancer.prefill_servers,
|
||||||
|
load_balancer.decode_servers,
|
||||||
|
)
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Create the tasks
|
||||||
|
tasks = []
|
||||||
|
for server in chain(prefill_servers, decode_servers):
|
||||||
|
tasks.append(session.post(f"{server}/health_generate"))
|
||||||
|
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||||
|
await response
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/flush_cache")
|
||||||
|
async def flush_cache():
|
||||||
|
prefill_servers, decode_servers = (
|
||||||
|
load_balancer.prefill_servers,
|
||||||
|
load_balancer.decode_servers,
|
||||||
|
)
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Create the tasks
|
||||||
|
tasks = []
|
||||||
|
for server in chain(prefill_servers, decode_servers):
|
||||||
|
tasks.append(session.post(f"{server}/flush_cache"))
|
||||||
|
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||||
|
await response
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_server_info")
|
||||||
|
async def get_server_info():
|
||||||
|
prefill_servers, decode_servers = (
|
||||||
|
load_balancer.prefill_servers,
|
||||||
|
load_balancer.decode_servers,
|
||||||
|
)
|
||||||
|
prefill_infos = []
|
||||||
|
decode_infos = []
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
for server in chain(prefill_servers):
|
||||||
|
server_info = await session.get(f"{server}/get_server_info")
|
||||||
|
prefill_infos.append(await server_info.json())
|
||||||
|
for server in chain(decode_servers):
|
||||||
|
server_info = await session.get(f"{server}/get_server_info")
|
||||||
|
decode_infos.append(await server_info.json())
|
||||||
|
|
||||||
|
return {"prefill": prefill_infos, "decode": decode_infos}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_model_info")
|
||||||
|
async def get_model_info():
|
||||||
|
# Dummy model information
|
||||||
|
model_info = {
|
||||||
|
"model_path": "/path/to/dummy/model",
|
||||||
|
"tokenizer_path": "/path/to/dummy/tokenizer",
|
||||||
|
"is_generation": True,
|
||||||
|
"preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
|
||||||
|
}
|
||||||
|
return ORJSONResponse(content=model_info)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/generate")
|
||||||
|
async def handle_generate_request(request_data: dict):
|
||||||
|
prefill_server, decode_server = load_balancer.select_pair()
|
||||||
|
|
||||||
|
# Parse and transform prefill_server for bootstrap data
|
||||||
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||||
|
hostname = parsed_url.hostname
|
||||||
|
modified_request = request_data.copy()
|
||||||
|
modified_request.update(
|
||||||
|
{
|
||||||
|
"bootstrap_host": hostname,
|
||||||
|
"bootstrap_room": random.randint(0, 2**63 - 1),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if streaming is requested
|
||||||
|
if request_data.get("stream", False):
|
||||||
|
|
||||||
|
async def stream_results():
|
||||||
|
async with aiohttp.ClientSession(
|
||||||
|
timeout=aiohttp.ClientTimeout(total=3600)
|
||||||
|
) as session:
|
||||||
|
try:
|
||||||
|
# Create the tasks
|
||||||
|
tasks = [
|
||||||
|
session.post(
|
||||||
|
f"{prefill_server}/generate", json=modified_request
|
||||||
|
),
|
||||||
|
session.post(
|
||||||
|
f"{decode_server}/generate", json=modified_request
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
prefill_response = None
|
||||||
|
decode_response = None
|
||||||
|
|
||||||
|
# Process responses as they arrive
|
||||||
|
for i, response_task in enumerate(asyncio.as_completed(tasks)):
|
||||||
|
response = await response_task
|
||||||
|
|
||||||
|
# Check the response immediately
|
||||||
|
if str(response.url).startswith(prefill_server):
|
||||||
|
prefill_response = response
|
||||||
|
if response.status != 200:
|
||||||
|
error_msg = {
|
||||||
|
"error": {
|
||||||
|
"message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
yield b"data: " + orjson.dumps(
|
||||||
|
error_msg, option=orjson.OPT_NON_STR_KEYS
|
||||||
|
) + b"\n\n"
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
decode_response = response
|
||||||
|
if response.status != 200:
|
||||||
|
error_msg = {
|
||||||
|
"error": {
|
||||||
|
"message": f"Decode server error: Status {response.status}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
yield b"data: " + orjson.dumps(
|
||||||
|
error_msg, option=orjson.OPT_NON_STR_KEYS
|
||||||
|
) + b"\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
# Stream successful decode server response
|
||||||
|
async for line in decode_response.content:
|
||||||
|
yield line
|
||||||
|
yield b"data: [DONE]\n\n"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = {
|
||||||
|
"error": {"message": f"Stream processing error: {str(e)}"}
|
||||||
|
}
|
||||||
|
yield b"data: " + orjson.dumps(
|
||||||
|
error_msg, option=orjson.OPT_NON_STR_KEYS
|
||||||
|
) + b"\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
stream_results(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-streaming case
|
||||||
|
result = await load_balancer.generate_request(request_data)
|
||||||
|
return ORJSONResponse(content=result)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/models")
|
||||||
|
async def get_models():
|
||||||
|
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
response = await session.get(f"{prefill_server}/v1/models")
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=response.status,
|
||||||
|
detail=f"Prefill server error: Status {response.status}",
|
||||||
|
)
|
||||||
|
return ORJSONResponse(content=await response.json())
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def run(prefill_addrs, decode_addrs, host, port):
|
||||||
|
global load_balancer
|
||||||
|
load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
|
||||||
|
uvicorn.run(app, host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode", required=True, help="Comma-separated URLs for decode servers"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
|
||||||
249
python/sglang/srt/disaggregation/prefill.py
Normal file
249
python/sglang/srt/disaggregation/prefill.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
"""
|
||||||
|
Life cycle of a request in the prefill server
|
||||||
|
|
||||||
|
1. Bootstrap Queue
|
||||||
|
a. Initialize a sender for each request
|
||||||
|
b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished
|
||||||
|
c. Poll senders to check bootstrap state
|
||||||
|
d. Once bootstrap is complete, move request to Waiting Queue
|
||||||
|
|
||||||
|
2. Waiting Queue
|
||||||
|
a. Use PrefillAdder to pop requests
|
||||||
|
b. Run forward
|
||||||
|
c. Add the request to Infight Queue
|
||||||
|
|
||||||
|
3. Infight Queue
|
||||||
|
a. Poll (non-blocking) the sender of the request
|
||||||
|
b. Once the transfer has finished, return the request
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
|
||||||
|
from sglang.srt.disaggregation.utils import (
|
||||||
|
ReqToMetadataIdxAllocator,
|
||||||
|
poll_and_all_reduce,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
|
||||||
|
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PrefillBootstrapQueue:
|
||||||
|
"""
|
||||||
|
Store the requests in bootstrapping
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
token_to_kv_pool: KVCache,
|
||||||
|
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||||
|
metadata_buffers: List[torch.Tensor],
|
||||||
|
aux_dtype: torch.dtype,
|
||||||
|
tp_rank: int,
|
||||||
|
tp_size: int,
|
||||||
|
bootstrap_port: int,
|
||||||
|
gloo_group: ProcessGroup,
|
||||||
|
):
|
||||||
|
self.token_to_kv_pool = token_to_kv_pool
|
||||||
|
self.aux_dtype = aux_dtype
|
||||||
|
|
||||||
|
self.metadata_buffers = metadata_buffers
|
||||||
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||||
|
self.tp_rank = tp_rank
|
||||||
|
self.tp_size = tp_size
|
||||||
|
self.kv_manager = self._init_kv_manager()
|
||||||
|
self.queue: List[Req] = []
|
||||||
|
self.gloo_group = gloo_group
|
||||||
|
self.bootstrap_port = bootstrap_port
|
||||||
|
|
||||||
|
def allocate_token_id(self, idx: int, token_id: int):
|
||||||
|
assert token_id >= 0, f"token_id: {token_id} is negative"
|
||||||
|
output_id_buffer = self.metadata_buffers[0]
|
||||||
|
output_id_buffer[idx] = token_id
|
||||||
|
|
||||||
|
def _init_kv_manager(self) -> KVManager:
|
||||||
|
kv_args = KVArgs()
|
||||||
|
kv_args.engine_rank = self.tp_rank
|
||||||
|
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||||
|
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_args.kv_data_ptrs = kv_data_ptrs
|
||||||
|
kv_args.kv_data_lens = kv_data_lens
|
||||||
|
kv_args.kv_item_lens = kv_item_lens
|
||||||
|
|
||||||
|
# Define req -> input ids buffer
|
||||||
|
kv_args.aux_data_ptrs = [
|
||||||
|
metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
|
||||||
|
]
|
||||||
|
kv_args.aux_data_lens = [
|
||||||
|
metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
|
||||||
|
]
|
||||||
|
kv_args.aux_item_lens = [
|
||||||
|
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
||||||
|
]
|
||||||
|
kv_args.ib_device = "mock-ib-device"
|
||||||
|
kv_manager = KVManager(kv_args)
|
||||||
|
return kv_manager
|
||||||
|
|
||||||
|
def add(self, req: Req) -> None:
|
||||||
|
req.disagg_kv_sender = KVSender(
|
||||||
|
mgr=self.kv_manager,
|
||||||
|
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
||||||
|
bootstrap_room=req.bootstrap_room,
|
||||||
|
)
|
||||||
|
self._process_req(req)
|
||||||
|
self.queue.append(req)
|
||||||
|
|
||||||
|
def _process_req(self, req: Req) -> None:
|
||||||
|
"""
|
||||||
|
Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
|
||||||
|
"""
|
||||||
|
req.sampling_params.max_new_tokens = 1
|
||||||
|
|
||||||
|
def pop_bootstrapped(self) -> List[Req]:
|
||||||
|
"""pop the reqs which has finished bootstrapping"""
|
||||||
|
bootstrapped_reqs = []
|
||||||
|
indices_to_remove = set()
|
||||||
|
|
||||||
|
if len(self.queue) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
polls = poll_and_all_reduce(
|
||||||
|
[req.disagg_kv_sender for req in self.queue], self.gloo_group
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
||||||
|
if poll == KVPoll.Bootstrapping:
|
||||||
|
continue
|
||||||
|
elif poll == KVPoll.Failed:
|
||||||
|
raise Exception("Bootstrap failed")
|
||||||
|
|
||||||
|
# KV.WaitingForInput - init here
|
||||||
|
num_kv_indices = len(req.origin_input_ids)
|
||||||
|
if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
req.metadata_buffer_index = (
|
||||||
|
self.req_to_metadata_buffer_idx_allocator.alloc()
|
||||||
|
)
|
||||||
|
assert req.metadata_buffer_index is not None
|
||||||
|
req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
|
||||||
|
|
||||||
|
bootstrapped_reqs.append(req)
|
||||||
|
indices_to_remove.add(i)
|
||||||
|
|
||||||
|
self.queue = [
|
||||||
|
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
||||||
|
]
|
||||||
|
|
||||||
|
return bootstrapped_reqs
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerDisaggregationPrefillMixin:
|
||||||
|
"""
|
||||||
|
Mixin for Scheduler to handle disaggregation prefill
|
||||||
|
"""
|
||||||
|
|
||||||
|
def process_batch_result_disagg_prefill(
|
||||||
|
self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
||||||
|
Adapted from process_batch_result_prefill
|
||||||
|
"""
|
||||||
|
|
||||||
|
next_token_ids = result.next_token_ids.tolist()
|
||||||
|
|
||||||
|
for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
|
||||||
|
req: Req
|
||||||
|
if req.is_chunked <= 0:
|
||||||
|
# There is no output_ids for prefill
|
||||||
|
req.output_ids.append(next_token_id)
|
||||||
|
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
||||||
|
self.send_kv_chunk(req, token_id=next_token_id)
|
||||||
|
self.disagg_prefill_infight_queue.append(req)
|
||||||
|
else:
|
||||||
|
# being chunked reqs' prefill is not finished
|
||||||
|
req.is_chunked -= 1
|
||||||
|
|
||||||
|
# TODO: Not sure if this is necessary
|
||||||
|
if batch.next_batch_sampling_info:
|
||||||
|
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||||
|
# We need to remove this for overlap schedule.
|
||||||
|
self.current_stream.synchronize()
|
||||||
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||||
|
|
||||||
|
def process_disagg_prefill_infight_queue(self: Scheduler) -> None:
|
||||||
|
"""
|
||||||
|
Poll the requests in the middle of transfer. If done, return the request.
|
||||||
|
"""
|
||||||
|
assert len(self.disagg_prefill_infight_queue) > 0
|
||||||
|
|
||||||
|
done_reqs = []
|
||||||
|
|
||||||
|
polls = poll_and_all_reduce(
|
||||||
|
[req.disagg_kv_sender for req in self.disagg_prefill_infight_queue],
|
||||||
|
self.tp_worker.get_tp_cpu_group(),
|
||||||
|
)
|
||||||
|
|
||||||
|
undone_reqs: List[Req] = []
|
||||||
|
# Check .poll() for the reqs in disagg_prefill_infight_queue. If Success, respond to the client and remove it from the queue
|
||||||
|
for req, poll in zip(self.disagg_prefill_infight_queue, polls):
|
||||||
|
if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
|
||||||
|
undone_reqs.append(req)
|
||||||
|
elif poll == KVPoll.Success: # transfer done
|
||||||
|
self.tree_cache.cache_finished_req(req) # unlock the tree
|
||||||
|
req.finished_reason = FINISH_LENGTH(length=0)
|
||||||
|
done_reqs.append(req)
|
||||||
|
elif poll == KVPoll.Failed:
|
||||||
|
raise Exception("Transferring failed")
|
||||||
|
|
||||||
|
# Stream requests which have finished transfer
|
||||||
|
self.stream_output(done_reqs, False, None)
|
||||||
|
|
||||||
|
self.disagg_prefill_infight_queue = undone_reqs
|
||||||
|
|
||||||
|
def process_prefill_chunk(self: Scheduler) -> None:
|
||||||
|
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
||||||
|
if self.chunked_req:
|
||||||
|
# Move the chunked request out of the batch so that we can merge
|
||||||
|
# only finished requests to running_batch.
|
||||||
|
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
|
||||||
|
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
||||||
|
self.send_kv_chunk(self.chunked_req)
|
||||||
|
# chunked request keeps its rid but will get a new req_pool_idx
|
||||||
|
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
||||||
|
self.running_batch.batch_is_full = False
|
||||||
|
|
||||||
|
def send_kv_chunk(
|
||||||
|
self: Scheduler, req: Req, token_id: Optional[int] = None
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Send a prefilled chunk to the decode server
|
||||||
|
"""
|
||||||
|
start_idx = req.start_send_idx
|
||||||
|
end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
|
||||||
|
kv_indices = (
|
||||||
|
self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
req.start_send_idx = end_idx
|
||||||
|
if token_id is not None:
|
||||||
|
self.disagg_prefill_pending_queue.allocate_token_id(
|
||||||
|
req.metadata_buffer_index, token_id
|
||||||
|
)
|
||||||
|
req.disagg_kv_sender.send(kv_indices)
|
||||||
44
python/sglang/srt/disaggregation/utils.py
Normal file
44
python/sglang/srt/disaggregation/utils.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
class DisaggregationMode(Enum):
|
||||||
|
NULL = "null"
|
||||||
|
PREFILL = "prefill"
|
||||||
|
DECODE = "decode"
|
||||||
|
|
||||||
|
|
||||||
|
def poll_and_all_reduce(pollers, gloo_group):
|
||||||
|
polls = [int(poller.poll()) for poller in pollers]
|
||||||
|
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
|
||||||
|
dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
|
||||||
|
return tensor_to_reduce.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
class ReqToMetadataIdxAllocator:
|
||||||
|
"""A memory pool that maps a request to its first output token location."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
size: int,
|
||||||
|
):
|
||||||
|
self.size = size
|
||||||
|
self.free_slots = deque(list(range(size)))
|
||||||
|
|
||||||
|
def available_size(self):
|
||||||
|
return len(self.free_slots)
|
||||||
|
|
||||||
|
def alloc(self) -> List[int]:
|
||||||
|
if len(self.free_slots) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.free_slots.popleft()
|
||||||
|
|
||||||
|
def free(self, free_index: int):
|
||||||
|
self.free_slots.append(free_index)
|
||||||
@@ -42,6 +42,8 @@ import triton.language as tl
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||||
|
from sglang.srt.disaggregation.conn import KVSender
|
||||||
|
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
@@ -396,6 +398,24 @@ class Req:
|
|||||||
self.spec_verify_ct = 0
|
self.spec_verify_ct = 0
|
||||||
self.lora_path = lora_path
|
self.lora_path = lora_path
|
||||||
|
|
||||||
|
# For disaggregation
|
||||||
|
self.bootstrap_host: str = "0.0.0.0"
|
||||||
|
self.bootstrap_room: Optional[int] = None
|
||||||
|
self.disagg_kv_sender: Optional[KVSender] = None
|
||||||
|
|
||||||
|
# used for warmup because we don't have a pair yet when init
|
||||||
|
self.skip_kv_transfer: bool = False
|
||||||
|
# the start index of the sent kv cache
|
||||||
|
# We want to send it chunk by chunk for chunked prefill.
|
||||||
|
# After every chunk forward, we do the following:
|
||||||
|
# kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)])
|
||||||
|
# start_send_idx = len(req.fill_ids)
|
||||||
|
self.start_send_idx: int = 0
|
||||||
|
|
||||||
|
self.metadata_buffer_index: int = -1
|
||||||
|
# The first output_id transferred from prefill instance.
|
||||||
|
self.transferred_output_id: Optional[int] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def seqlen(self):
|
def seqlen(self):
|
||||||
return len(self.origin_input_ids) + len(self.output_ids)
|
return len(self.origin_input_ids) + len(self.output_ids)
|
||||||
@@ -531,7 +551,7 @@ bid = 0
|
|||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ScheduleBatch:
|
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||||
"""Store all information of a batch on the scheduler."""
|
"""Store all information of a batch on the scheduler."""
|
||||||
|
|
||||||
# Request, memory pool, and cache
|
# Request, memory pool, and cache
|
||||||
|
|||||||
@@ -37,6 +37,19 @@ from torch.distributed import barrier
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
|
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
|
||||||
|
from sglang.srt.disaggregation.decode import (
|
||||||
|
DecodePreallocQueue,
|
||||||
|
DecodeTransferQueue,
|
||||||
|
SchedulerDisaggregationDecodeMixin,
|
||||||
|
)
|
||||||
|
from sglang.srt.disaggregation.prefill import (
|
||||||
|
PrefillBootstrapQueue,
|
||||||
|
SchedulerDisaggregationPrefillMixin,
|
||||||
|
)
|
||||||
|
from sglang.srt.disaggregation.utils import (
|
||||||
|
DisaggregationMode,
|
||||||
|
ReqToMetadataIdxAllocator,
|
||||||
|
)
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
@@ -137,7 +150,11 @@ class EmbeddingBatchResult:
|
|||||||
bid: int
|
bid: int
|
||||||
|
|
||||||
|
|
||||||
class Scheduler(SchedulerOutputProcessorMixin):
|
class Scheduler(
|
||||||
|
SchedulerOutputProcessorMixin,
|
||||||
|
SchedulerDisaggregationDecodeMixin,
|
||||||
|
SchedulerDisaggregationPrefillMixin,
|
||||||
|
):
|
||||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -389,6 +406,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.disaggregation_mode = DisaggregationMode(
|
||||||
|
self.server_args.disaggregation_mode
|
||||||
|
)
|
||||||
|
self.init_disaggregation()
|
||||||
|
|
||||||
def init_tokenizer(self):
|
def init_tokenizer(self):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|
||||||
@@ -489,6 +511,73 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def init_disaggregation(self):
|
||||||
|
if (
|
||||||
|
self.disaggregation_mode == DisaggregationMode.DECODE
|
||||||
|
): # *2 for the headroom.
|
||||||
|
buffer_size = (self.req_to_token_pool.size) * 2
|
||||||
|
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||||
|
buffer_size
|
||||||
|
)
|
||||||
|
aux_dtype = torch.int32
|
||||||
|
# A list of metadata buffers. The shape is (b, metadata_size) where
|
||||||
|
# b corresponds to a max running requests. The last shape * dtype.itemsize
|
||||||
|
# should be larger than 64 bytes to work with RDMA, so we pad it.
|
||||||
|
output_id_buffer = torch.zeros(
|
||||||
|
(buffer_size, 16), dtype=aux_dtype, device="cpu"
|
||||||
|
)
|
||||||
|
metadata_buffers = [output_id_buffer]
|
||||||
|
|
||||||
|
# The decode requests polling kv cache
|
||||||
|
self.disagg_decode_transfer_queue = DecodeTransferQueue(
|
||||||
|
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
|
||||||
|
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||||
|
metadata_buffers=metadata_buffers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The decode requests pending for pre-allocation
|
||||||
|
self.disagg_decode_prealloc_queue = DecodePreallocQueue(
|
||||||
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
|
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||||
|
metadata_buffers=metadata_buffers,
|
||||||
|
aux_dtype=aux_dtype,
|
||||||
|
scheduler=self,
|
||||||
|
transfer_queue=self.disagg_decode_transfer_queue,
|
||||||
|
tree_cache=self.tree_cache,
|
||||||
|
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
tp_size=self.tp_size,
|
||||||
|
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||||
|
)
|
||||||
|
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
|
# *2 for the headroom.
|
||||||
|
buffer_size = self.max_running_requests * 2
|
||||||
|
req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
|
||||||
|
buffer_size
|
||||||
|
)
|
||||||
|
aux_dtype = torch.int32
|
||||||
|
# A list of metadata buffers. The shape is (b, metadata_size) where
|
||||||
|
# b corresponds to a max running requests. The last shape * dtype.itemsize
|
||||||
|
# should be larger than 64 bytes to work with RDMA, so we pad it.
|
||||||
|
output_id_buffer = torch.zeros(
|
||||||
|
(buffer_size, 16), dtype=aux_dtype, device="cpu"
|
||||||
|
)
|
||||||
|
metadata_buffers = [output_id_buffer]
|
||||||
|
|
||||||
|
self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
|
||||||
|
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
|
||||||
|
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||||
|
metadata_buffers=metadata_buffers,
|
||||||
|
aux_dtype=aux_dtype,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
tp_size=self.tp_size,
|
||||||
|
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||||
|
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
|
||||||
|
)
|
||||||
|
# The prefill requests that are in the middle of kv sending
|
||||||
|
self.disagg_prefill_infight_queue: List[Req] = []
|
||||||
|
|
||||||
@DynamicGradMode()
|
@DynamicGradMode()
|
||||||
def event_loop_normal(self):
|
def event_loop_normal(self):
|
||||||
"""A normal scheduler loop."""
|
"""A normal scheduler loop."""
|
||||||
@@ -549,6 +638,70 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def event_loop_normal_disagg_prefill(self):
|
||||||
|
"""A normal scheduler loop for prefill worker in disaggregation mode."""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
recv_reqs = self.recv_requests()
|
||||||
|
self.process_input_requests(recv_reqs)
|
||||||
|
self.waiting_queue.extend(
|
||||||
|
self.disagg_prefill_pending_queue.pop_bootstrapped()
|
||||||
|
)
|
||||||
|
self.process_prefill_chunk()
|
||||||
|
batch = self.get_new_batch_prefill()
|
||||||
|
self.cur_batch = batch
|
||||||
|
|
||||||
|
if batch:
|
||||||
|
result = self.run_batch(batch)
|
||||||
|
self.process_batch_result_disagg_prefill(batch, result)
|
||||||
|
|
||||||
|
if len(self.disagg_prefill_infight_queue) > 0:
|
||||||
|
self.process_disagg_prefill_infight_queue()
|
||||||
|
|
||||||
|
if batch is None and len(self.disagg_prefill_infight_queue) == 0:
|
||||||
|
self.check_memory()
|
||||||
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
|
self.last_batch = batch
|
||||||
|
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
|
||||||
|
# Otherwise, it hangs under high concurrency
|
||||||
|
self.running_batch.batch_is_full = False
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def event_loop_normal_disagg_decode(self):
|
||||||
|
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
recv_reqs = self.recv_requests()
|
||||||
|
self.process_input_requests(recv_reqs)
|
||||||
|
# polling and allocating kv cache
|
||||||
|
self.process_decode_queue()
|
||||||
|
batch = self.get_next_disagg_decode_batch_to_run()
|
||||||
|
self.cur_batch = batch
|
||||||
|
|
||||||
|
if batch:
|
||||||
|
# Generate fake extend output.
|
||||||
|
if batch.forward_mode.is_extend():
|
||||||
|
# Note: Logprobs should be handled on the prefill engine.
|
||||||
|
self.stream_output(
|
||||||
|
batch.reqs, [False for _ in range(len(batch.reqs))]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = self.run_batch(batch)
|
||||||
|
self.process_batch_result(batch, result)
|
||||||
|
|
||||||
|
if batch is None and (
|
||||||
|
len(self.disagg_decode_transfer_queue.queue)
|
||||||
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||||
|
== 0
|
||||||
|
):
|
||||||
|
# When the server is idle, do self-check and re-init some states
|
||||||
|
self.check_memory()
|
||||||
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
|
self.last_batch = batch
|
||||||
|
|
||||||
def recv_requests(self) -> List[Req]:
|
def recv_requests(self) -> List[Req]:
|
||||||
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
||||||
if self.attn_tp_rank == 0:
|
if self.attn_tp_rank == 0:
|
||||||
@@ -778,10 +931,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
|||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
|
|
||||||
def _add_request_to_queue(self, req: Req):
|
def _add_request_to_queue(self, req: Req):
|
||||||
self.waiting_queue.append(req)
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
|
self.disagg_prefill_pending_queue.add(req)
|
||||||
|
|
||||||
def _extend_requests_to_queue(self, reqs: List[Req]):
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
self.waiting_queue.extend(reqs)
|
self.disagg_decode_prealloc_queue.add(req)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
|
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
|
||||||
|
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
|
self.disagg_decode_prealloc_queue.extend(reqs)
|
||||||
|
else:
|
||||||
|
self.waiting_queue.extend(reqs)
|
||||||
|
|
||||||
def handle_embedding_request(
|
def handle_embedding_request(
|
||||||
self,
|
self,
|
||||||
@@ -1814,10 +1977,18 @@ def run_scheduler_process(
|
|||||||
"max_req_input_len": scheduler.max_req_input_len,
|
"max_req_input_len": scheduler.max_req_input_len,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if scheduler.enable_overlap:
|
disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
|
||||||
scheduler.event_loop_overlap()
|
|
||||||
else:
|
if disaggregation_mode == DisaggregationMode.NULL:
|
||||||
scheduler.event_loop_normal()
|
if scheduler.enable_overlap:
|
||||||
|
scheduler.event_loop_overlap()
|
||||||
|
else:
|
||||||
|
scheduler.event_loop_normal()
|
||||||
|
elif disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
|
scheduler.event_loop_normal_disagg_prefill()
|
||||||
|
elif disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
|
scheduler.event_loop_normal_disagg_decode()
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback = get_exception_traceback()
|
traceback = get_exception_traceback()
|
||||||
logger.error(f"Scheduler hit an exception: {traceback}")
|
logger.error(f"Scheduler hit an exception: {traceback}")
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ from fastapi import BackgroundTasks
|
|||||||
|
|
||||||
from sglang.srt.aio_rwlock import RWLock
|
from sglang.srt.aio_rwlock import RWLock
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.disaggregation.conn import KVBootstrapServer
|
||||||
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.managers.image_processor import (
|
from sglang.srt.managers.image_processor import (
|
||||||
get_dummy_image_processor,
|
get_dummy_image_processor,
|
||||||
@@ -313,6 +315,16 @@ class TokenizerManager:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.disaggregation_mode = DisaggregationMode(
|
||||||
|
self.server_args.disaggregation_mode
|
||||||
|
)
|
||||||
|
# for disaggregtion, start kv boostrap server on prefill
|
||||||
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
|
# only start bootstrap server on prefill tm
|
||||||
|
self.bootstrap_server = KVBootstrapServer(
|
||||||
|
self.server_args.disaggregation_bootstrap_port
|
||||||
|
)
|
||||||
|
|
||||||
async def generate_request(
|
async def generate_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
|
|||||||
@@ -271,6 +271,19 @@ class MHATokenToKVPool(KVCache):
|
|||||||
v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
|
v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
|
||||||
return k_size_bytes, v_size_bytes
|
return k_size_bytes, v_size_bytes
|
||||||
|
|
||||||
|
# for disagg
|
||||||
|
def get_contiguous_buf_infos(self):
|
||||||
|
kv_data_ptrs = [
|
||||||
|
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
|
||||||
|
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
|
||||||
|
kv_data_lens = [
|
||||||
|
self.get_key_buffer(i).nbytes for i in range(self.layer_num)
|
||||||
|
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
|
||||||
|
kv_item_lens = [
|
||||||
|
self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
|
||||||
|
] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
|
||||||
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
|
|
||||||
# Todo: different memory layout
|
# Todo: different memory layout
|
||||||
def get_flat_data(self, indices):
|
def get_flat_data(self, indices):
|
||||||
# prepare a large chunk of contiguous data for efficient transfer
|
# prepare a large chunk of contiguous data for efficient transfer
|
||||||
|
|||||||
@@ -185,6 +185,10 @@ class ServerArgs:
|
|||||||
debug_tensor_dump_input_file: Optional[str] = None
|
debug_tensor_dump_input_file: Optional[str] = None
|
||||||
debug_tensor_dump_inject: bool = False
|
debug_tensor_dump_inject: bool = False
|
||||||
|
|
||||||
|
# For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
|
||||||
|
disaggregation_mode: str = "null"
|
||||||
|
disaggregation_bootstrap_port: int = 8998
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Set missing default values
|
# Set missing default values
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
@@ -325,6 +329,18 @@ class ServerArgs:
|
|||||||
if is_hip():
|
if is_hip():
|
||||||
self.triton_attention_num_kv_splits = 16
|
self.triton_attention_num_kv_splits = 16
|
||||||
|
|
||||||
|
# PD disaggregation
|
||||||
|
if self.disaggregation_mode == "prefill":
|
||||||
|
self.disable_cuda_graph = True
|
||||||
|
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||||
|
self.disable_overlap_schedule = True
|
||||||
|
logger.warning("Overlap scheduler is disabled for prefill server")
|
||||||
|
elif self.disaggregation_mode == "decode":
|
||||||
|
self.disable_radix_cache = True
|
||||||
|
logger.warning("Cuda graph is disabled for prefill server")
|
||||||
|
self.disable_overlap_schedule = True
|
||||||
|
logger.warning("Overlap scheduler is disabled for decode server")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
# Model and port args
|
# Model and port args
|
||||||
@@ -1063,6 +1079,21 @@ class ServerArgs:
|
|||||||
help="Inject the outputs from jax as the input of every layer.",
|
help="Inject the outputs from jax as the input of every layer.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Disaggregation
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-mode",
|
||||||
|
type=str,
|
||||||
|
default="null",
|
||||||
|
choices=["null", "prefill", "decode"],
|
||||||
|
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-bootstrap-port",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.disaggregation_bootstrap_port,
|
||||||
|
help="Bootstrap server port on the prefill server. Default is 8998.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
args.tp_size = args.tensor_parallel_size
|
args.tp_size = args.tensor_parallel_size
|
||||||
|
|||||||
Reference in New Issue
Block a user