[PD] Release initial code (#4654)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: makro
Co-authored-by: dhou-xai
python/sglang/srt/disaggregation/conn.py (new file, 81 lines)
@@ -0,0 +1,81 @@
from __future__ import annotations

import logging
from enum import Enum
from typing import Optional

import numpy as np
import numpy.typing as npt

logger = logging.getLogger(__name__)


class KVArgs:
    engine_rank: int
    kv_data_ptrs: list[int]
    kv_data_lens: list[int]
    kv_item_lens: list[int]
    aux_data_ptrs: list[int]
    aux_data_lens: list[int]
    aux_item_lens: list[int]
    ib_device: str


class KVManager:
    def __init__(self, args: KVArgs): ...


class KVPoll:
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class KVSender:
    def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
        self.has_sent = False

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None): ...

    def send(self, kv_indices: npt.NDArray[np.int32]):
        self.has_sent = True

    def poll(self) -> KVPoll:
        if self.has_sent is False:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        else:
            # Assume transfer completed instantly
            return KVPoll.Success

    def failure_exception(self):
        raise Exception("Fake KVSender Exception")


class KVReceiver:
    def __init__(
        self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
    ):
        self.has_init = False

    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        self.has_init = True

    def poll(self) -> KVPoll:
        if self.has_init is False:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput
        else:
            # Assume transfer completed instantly
            return KVPoll.Success

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")


class KVBootstrapServer:
    def __init__(self, port: int): ...

    def poll(self) -> KVPoll: ...
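Note: these classes are mocks that let the scheduler changes below run end to end without a real transfer backend; poll() jumps straight to Success once send()/init() has been called. A minimal sketch of the intended polling protocol, with made-up buffer values (a real engine fills KVArgs from its KV-cache tensors):

    import numpy as np

    from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender

    args = KVArgs()
    args.engine_rank = 0
    # Hypothetical placeholder values, purely for illustration.
    args.kv_data_ptrs, args.kv_data_lens, args.kv_item_lens = [0], [0], [0]
    args.aux_data_ptrs, args.aux_data_lens, args.aux_item_lens = [0], [0], [0]
    args.ib_device = "mock-ib-device"

    sender = KVSender(KVManager(args), bootstrap_addr="127.0.0.1:8998", bootstrap_room=42)
    assert sender.poll() == KVPoll.WaitingForInput  # mock handshake completes instantly
    sender.init(num_kv_indices=4, aux_index=0)
    sender.send(np.arange(4, dtype=np.int32))
    assert sender.poll() == KVPoll.Success  # mock transfer completes instantly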
python/sglang/srt/disaggregation/decode.py (new file, 495 lines)
@@ -0,0 +1,495 @@
"""
Life cycle of a request in the decode server

1. PreallocQueue:
    a. Initialize a receiver for each request
    b. The request handshakes first, and KV is pre-allocated once there is KV available
    c. Move the request to TransferQueue

2. TransferQueue:
    a. Poll the receiver to check the transfer state
    b. If the transfer has finished, move the request to the waiting queue

3. WaitingQueue:
    a. Use the requests in the queue to construct a PrebuiltExtendBatch
    b. Skip the prefill forward and only populate metadata

4. RunningBatch:
    a. Merge the resolved PrebuiltExtendBatch into the running batch to run decoding
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional

import torch
from torch.distributed import ProcessGroup

from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
from sglang.srt.disaggregation.utils import (
    ReqToMetadataIdxAllocator,
    poll_and_all_reduce,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
    from sglang.srt.managers.scheduler import Scheduler
    from sglang.srt.server_args import ServerArgs


@dataclass
class DecodeRequest:
    req: Req
    kv_receiver: KVReceiver
    waiting_for_input: bool = False
    metadata_buffer_index: int = -1


class DecodePreallocQueue:
    """Store the requests that are pre-allocating."""

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: List[torch.Tensor],
        aux_dtype: torch.dtype,
        scheduler: Scheduler,
        transfer_queue: DecodeTransferQueue,
        tree_cache: BasePrefixCache,
        gloo_group: ProcessGroup,
        tp_rank: int,
        tp_size: int,
        bootstrap_port: int,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
        self.aux_dtype = aux_dtype
        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.scheduler = scheduler
        self.transfer_queue = transfer_queue
        self.tree_cache = tree_cache  # this is always a chunk cache
        self.gloo_group = gloo_group
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.bootstrap_port = bootstrap_port

        self.num_reserved_decode_tokens = 512

        # Queue for requests pending pre-allocation
        self.queue: List[DecodeRequest] = []
        self.kv_manager = self._init_kv_manager()

    def _init_kv_manager(self) -> KVManager:
        kv_args = KVArgs()
        kv_args.engine_rank = self.tp_rank
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )

        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens

        kv_args.aux_data_ptrs = [
            output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
        ]
        kv_args.aux_data_lens = [
            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
        ]
        kv_args.aux_item_lens = [
            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
        ]
        kv_args.ib_device = "mock-ib-device"
        kv_manager = KVManager(kv_args)
        return kv_manager

    def add(self, req: Req) -> None:
        """Add a request to the pending queue."""
        kv_receiver = KVReceiver(
            mgr=self.kv_manager,
            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
            bootstrap_room=req.bootstrap_room,
        )
        self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))

    def extend(self, reqs: List[Req]) -> None:
        """Add multiple requests to the pending queue."""
        for req in reqs:
            self.add(req)

    def _update_handshake_waiters(self) -> None:
        if not self.queue:
            return

        if all(decode_req.waiting_for_input for decode_req in self.queue):
            return

        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        for decode_req, poll in zip(self.queue, polls):
            if poll == KVPoll.Bootstrapping:
                pass
            elif poll == KVPoll.WaitingForInput:
                decode_req.waiting_for_input = True
            elif poll == KVPoll.Failed:
                raise Exception("Handshake failed")

    def pop_preallocated(self) -> List[DecodeRequest]:
        """Pop the preallocated requests from the pending queue (FIFO)."""
        self._update_handshake_waiters()

        preallocated_reqs = []
        indices_to_remove = set()
        allocatable_tokens = self._allocatable_tokens()

        for i, decode_req in enumerate(self.queue):
            if not decode_req.waiting_for_input:
                continue

            if self.req_to_token_pool.available_size() <= 0:
                break

            if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                break

            required_tokens_for_request = (
                len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
            )

            if required_tokens_for_request > allocatable_tokens:
                break

            allocatable_tokens -= required_tokens_for_request
            self._pre_alloc(decode_req.req)

            kv_indices = (
                self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
                    : len(decode_req.req.origin_input_ids)
                ]
                .cpu()
                .numpy()
            )

            decode_req.metadata_buffer_index = (
                self.req_to_metadata_buffer_idx_allocator.alloc()
            )
            assert decode_req.metadata_buffer_index is not None
            decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
            preallocated_reqs.append(decode_req)
            indices_to_remove.add(i)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return preallocated_reqs

    def _allocatable_tokens(self) -> int:
        allocatable_tokens = (
            self.token_to_kv_pool_allocator.available_size()
            - self.num_reserved_decode_tokens
            * (
                len(self.scheduler.running_batch.reqs)
                + len(self.transfer_queue.queue)
                + len(self.scheduler.waiting_queue)
            )
        )

        # Note: if the last fake extend batch just finished and we enter
        # `pop_preallocated` immediately in the next iteration, the extend batch
        # is not in any queue yet, so we explicitly account for its token slots here.
        if (
            self.scheduler.last_batch
            and self.scheduler.last_batch.forward_mode.is_extend()
        ):
            allocatable_tokens -= self.num_reserved_decode_tokens * len(
                self.scheduler.last_batch.reqs
            )

        return allocatable_tokens

    def _pre_alloc(self, req: Req) -> torch.Tensor:
        """Pre-allocate the memory for req_to_token and token_kv_pool."""
        req_pool_indices = self.req_to_token_pool.alloc(1)

        assert req_pool_indices is not None

        req.req_pool_idx = req_pool_indices[0]
        kv_loc = self.token_to_kv_pool_allocator.alloc(
            len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
        )

        assert kv_loc is not None

        self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)

        # populate metadata
        req.fill_ids = req.origin_input_ids + req.output_ids
        req.extend_input_len = len(req.origin_input_ids)

        return kv_loc


class DecodeTransferQueue:
    """Store the requests that are polling for their KV cache."""

    def __init__(
        self,
        gloo_group: ProcessGroup,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: List[torch.Tensor],
    ):
        self.queue: List[DecodeRequest] = []
        self.gloo_group = gloo_group
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.metadata_buffers = metadata_buffers

    def add(self, req_conn: DecodeRequest) -> None:
        self.queue.append(req_conn)

    def extend(self, req_conns) -> None:
        self.queue.extend(req_conns)

    def pop_transferred(self) -> List[Req]:
        if not self.queue:
            return []

        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        transferred_reqs = []
        indices_to_remove = set()
        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Failed:
                raise Exception("Transfer failed")
            elif poll == KVPoll.Success:
                # pop it and push it to the waiting queue
                idx = decode_req.metadata_buffer_index
                assert len(decode_req.req.output_ids) == 0
                output_id_buffer = self.metadata_buffers[0]
                # the last dimension is padded with the same value
                output_id = output_id_buffer[idx][0].item()
                assert decode_req.req.transferred_output_id is None
                decode_req.req.transferred_output_id = output_id
                transferred_reqs.append(decode_req.req)
                indices_to_remove.add(i)
            elif poll in [
                KVPoll.Bootstrapping,
                KVPoll.WaitingForInput,
                KVPoll.Transferring,
            ]:
                pass
            else:
                raise ValueError(f"Unexpected poll case: {poll}")

        for i in indices_to_remove:
            idx = self.queue[i].metadata_buffer_index
            assert idx != -1
            self.req_to_metadata_buffer_idx_allocator.free(idx)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return transferred_reqs


class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend batch by populating metadata.
        Adapted from .prepare_for_extend().
        """
        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate the total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)

            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0

        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to the schedule batch."""
        self.output_ids = []
        for req in self.reqs:
            if req.output_ids and len(req.output_ids) > 0:
                # resumed retracted req
                self.output_ids.append(req.output_ids[-1])
            else:
                assert req.transferred_output_id is not None
                req.output_ids.append(req.transferred_output_id)
                self.output_ids.append(req.transferred_output_id)
            self.tree_cache.cache_unfinished_req(req)
        self.output_ids = torch.tensor(self.output_ids, device=self.device)


class SchedulerDisaggregationDecodeMixin:

    def get_next_disagg_decode_batch_to_run(
        self: Scheduler,
    ) -> Optional[ScheduleBatch]:
        """Create a fake completed prefill batch if possible and merge it with the running batch."""
        # Merge the prefill batch into the running batch
        last_batch = self.last_batch
        if last_batch and last_batch.forward_mode.is_extend():
            # chunked prefill doesn't happen on a decode instance
            assert self.chunked_req is None
            # Filter out finished requests.
            last_batch.filter_batch()
            if not last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = last_batch
                else:
                    # merge running_batch with the prefill batch
                    self.running_batch.merge_batch(last_batch)

        new_prebuilt_batch = self.get_new_prebuilt_batch()

        ret: Optional[ScheduleBatch] = None
        if new_prebuilt_batch:
            ret = new_prebuilt_batch
        else:
            if self.running_batch.is_empty():
                ret = None
            else:
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None

        return ret

    def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
        """Create a ScheduleBatch for a fake completed prefill."""
        if len(self.waiting_queue) == 0:
            return None

        curr_batch_size = self.running_batch.batch_size()
        batch_size = min(self.req_to_token_pool.size, self.max_running_requests)
        num_not_used_batch = batch_size - curr_batch_size

        # pop requests from the waiting queue
        can_run_list: List[Req] = []
        waiting_queue: List[Req] = []

        for i in range(len(self.waiting_queue)):
            req = self.waiting_queue[i]
            # we can add at most `num_not_used_batch` new requests to the running batch
            if i < num_not_used_batch:
                can_run_list.append(req)
                req.init_next_round_input(self.tree_cache)
            else:
                waiting_queue.append(req)

        self.waiting_queue = waiting_queue
        if len(can_run_list) == 0:
            return None

        # local import to avoid a circular import
        from sglang.srt.managers.schedule_batch import ScheduleBatch

        # construct a schedule batch with those requests and mark it as a decode batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )

        # construct a fake completed prefill
        new_batch.prepare_for_prebuilt_extend()
        new_batch.process_prebuilt_extend(self.server_args, self.model_config)

        return new_batch

    def process_decode_queue(self: Scheduler):
        req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
        self.disagg_decode_transfer_queue.extend(req_conns)
        alloc_reqs = (
            self.disagg_decode_transfer_queue.pop_transferred()
        )  # the requests whose KV has arrived
        self.waiting_queue.extend(alloc_reqs)
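_allocatable_tokens() reserves num_reserved_decode_tokens (512) slots for every request that is already running, transferring, or waiting, so a new request is only admitted if its prompt plus its own 512-token reservation still fits. A worked example of that arithmetic, with made-up pool sizes:

    # Made-up numbers, purely to illustrate _allocatable_tokens().
    available = 100_000            # token_to_kv_pool_allocator.available_size()
    reserved_per_req = 512         # num_reserved_decode_tokens
    running, transferring, waiting = 10, 4, 6
    allocatable = available - reserved_per_req * (running + transferring + waiting)
    assert allocatable == 89_760

    prompt_len = 3_000             # len(req.origin_input_ids) of the next request
    required = prompt_len + reserved_per_req
    assert required <= allocatable  # so this request would be pre-allocated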
python/sglang/srt/disaggregation/mini_lb.py (new file, 285 lines)
@@ -0,0 +1,285 @@
"""
Minimal HTTP load balancer for prefill and decode servers, for testing purposes.
"""

import asyncio
import random
import urllib.parse
from itertools import chain

import aiohttp
import orjson
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse


class MiniLoadBalancer:
    def __init__(self, prefill_servers, decode_servers):
        self.prefill_servers = prefill_servers
        self.decode_servers = decode_servers

    def select_pair(self):
        return random.choice(self.prefill_servers), random.choice(self.decode_servers)

    async def generate_request(self, request_data):
        prefill_server, decode_server = self.select_pair()

        # Parse prefill_server to get the bootstrap host
        parsed_url = urllib.parse.urlparse(prefill_server)
        bootstrap_host = parsed_url.hostname

        modified_request = request_data.copy()
        modified_request.update(
            {
                "bootstrap_host": bootstrap_host,
                "bootstrap_room": random.randint(0, 2**63 - 1),
            }
        )

        async with aiohttp.ClientSession() as session:
            # Create the tasks
            tasks = [
                session.post(f"{prefill_server}/generate", json=modified_request),
                session.post(f"{decode_server}/generate", json=modified_request),
            ]

            prefill_response = None
            decode_response = None

            # Process responses as they arrive; classify each one by its URL
            # and fail fast on any non-200 status.
            for future in asyncio.as_completed(tasks):
                response = await future
                is_prefill = str(response.url).startswith(prefill_server)
                if is_prefill:
                    prefill_response = response
                else:
                    decode_response = response
                if response.status != 200:
                    raise HTTPException(
                        status_code=response.status,
                        detail=f"{'Prefill' if is_prefill else 'Decode'} server error: "
                        f"Status {response.status} Details: {await response.text()}",
                    )

            return await decode_response.json()


app = FastAPI()
load_balancer = None


@app.get("/health")
async def health_check():
    return Response(status_code=200)


@app.get("/health_generate")
async def health_check_generate():
    prefill_servers, decode_servers = (
        load_balancer.prefill_servers,
        load_balancer.decode_servers,
    )
    async with aiohttp.ClientSession() as session:
        # Create the tasks
        tasks = []
        for server in chain(prefill_servers, decode_servers):
            tasks.append(session.post(f"{server}/health_generate"))
        for future in asyncio.as_completed(tasks):
            await future
    return Response(status_code=200)


@app.post("/flush_cache")
async def flush_cache():
    prefill_servers, decode_servers = (
        load_balancer.prefill_servers,
        load_balancer.decode_servers,
    )
    async with aiohttp.ClientSession() as session:
        # Create the tasks
        tasks = []
        for server in chain(prefill_servers, decode_servers):
            tasks.append(session.post(f"{server}/flush_cache"))
        for future in asyncio.as_completed(tasks):
            await future
    return Response(status_code=200)


@app.get("/get_server_info")
async def get_server_info():
    prefill_servers, decode_servers = (
        load_balancer.prefill_servers,
        load_balancer.decode_servers,
    )
    prefill_infos = []
    decode_infos = []
    async with aiohttp.ClientSession() as session:
        for server in prefill_servers:
            server_info = await session.get(f"{server}/get_server_info")
            prefill_infos.append(await server_info.json())
        for server in decode_servers:
            server_info = await session.get(f"{server}/get_server_info")
            decode_infos.append(await server_info.json())

    return {"prefill": prefill_infos, "decode": decode_infos}


@app.get("/get_model_info")
async def get_model_info():
    # Dummy model information
    model_info = {
        "model_path": "/path/to/dummy/model",
        "tokenizer_path": "/path/to/dummy/tokenizer",
        "is_generation": True,
        "preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
    }
    return ORJSONResponse(content=model_info)


@app.post("/generate")
async def handle_generate_request(request_data: dict):
    prefill_server, decode_server = load_balancer.select_pair()

    # Parse prefill_server to get the bootstrap host
    parsed_url = urllib.parse.urlparse(prefill_server)
    hostname = parsed_url.hostname
    modified_request = request_data.copy()
    modified_request.update(
        {
            "bootstrap_host": hostname,
            "bootstrap_room": random.randint(0, 2**63 - 1),
        }
    )

    # Check if streaming is requested
    if request_data.get("stream", False):

        async def stream_results():
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=3600)
            ) as session:
                try:
                    # Create the tasks
                    tasks = [
                        session.post(
                            f"{prefill_server}/generate", json=modified_request
                        ),
                        session.post(
                            f"{decode_server}/generate", json=modified_request
                        ),
                    ]

                    prefill_response = None
                    decode_response = None

                    # Process responses as they arrive
                    for response_task in asyncio.as_completed(tasks):
                        response = await response_task

                        # Check the response immediately
                        if str(response.url).startswith(prefill_server):
                            prefill_response = response
                            if response.status != 200:
                                error_msg = {
                                    "error": {
                                        "message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
                                    }
                                }
                                yield b"data: " + orjson.dumps(
                                    error_msg, option=orjson.OPT_NON_STR_KEYS
                                ) + b"\n\n"
                                return
                        else:
                            decode_response = response
                            if response.status != 200:
                                error_msg = {
                                    "error": {
                                        "message": f"Decode server error: Status {response.status}"
                                    }
                                }
                                yield b"data: " + orjson.dumps(
                                    error_msg, option=orjson.OPT_NON_STR_KEYS
                                ) + b"\n\n"
                                return

                    # Stream the successful decode server response
                    async for line in decode_response.content:
                        yield line
                    yield b"data: [DONE]\n\n"

                except Exception as e:
                    error_msg = {
                        "error": {"message": f"Stream processing error: {str(e)}"}
                    }
                    yield b"data: " + orjson.dumps(
                        error_msg, option=orjson.OPT_NON_STR_KEYS
                    ) + b"\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
        )

    # Non-streaming case
    result = await load_balancer.generate_request(request_data)
    return ORJSONResponse(content=result)


@app.get("/v1/models")
async def get_models():
    prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server
    async with aiohttp.ClientSession() as session:
        try:
            response = await session.get(f"{prefill_server}/v1/models")
            if response.status != 200:
                raise HTTPException(
                    status_code=response.status,
                    detail=f"Prefill server error: Status {response.status}",
                )
            return ORJSONResponse(content=await response.json())
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))


def run(prefill_addrs, decode_addrs, host, port):
    global load_balancer
    load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
    parser.add_argument(
        "--prefill", required=True, help="Comma-separated URLs for prefill servers"
    )
    parser.add_argument(
        "--decode", required=True, help="Comma-separated URLs for decode servers"
    )
    parser.add_argument(
        "--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
    )
    args = parser.parse_args()
    run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
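Besides the CLI entry point above, the balancer can be started programmatically. A sketch with hypothetical addresses; each one must point at a server launched with --disaggregation-mode prefill or decode respectively:

    from sglang.srt.disaggregation.mini_lb import run

    # Hypothetical server addresses, for illustration only.
    run(
        prefill_addrs=["http://127.0.0.1:30000"],
        decode_addrs=["http://127.0.0.1:30001"],
        host="0.0.0.0",
        port=8000,
    )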
python/sglang/srt/disaggregation/prefill.py (new file, 249 lines)
@@ -0,0 +1,249 @@
"""
Life cycle of a request in the prefill server

1. Bootstrap Queue
    a. Initialize a sender for each request
    b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished
    c. Poll senders to check the bootstrap state
    d. Once bootstrap is complete, move the request to the Waiting Queue

2. Waiting Queue
    a. Use PrefillAdder to pop requests
    b. Run the forward pass
    c. Add the request to the Infight Queue

3. Infight Queue
    a. Poll (non-blocking) the sender of the request
    b. Once the transfer has finished, return the request
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List, Optional

import torch

from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
from sglang.srt.disaggregation.utils import (
    ReqToMetadataIdxAllocator,
    poll_and_all_reduce,
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch

if TYPE_CHECKING:
    from torch.distributed import ProcessGroup

    from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
    from sglang.srt.mem_cache.memory_pool import KVCache

logger = logging.getLogger(__name__)


class PrefillBootstrapQueue:
    """Store the requests that are bootstrapping."""

    def __init__(
        self,
        token_to_kv_pool: KVCache,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: List[torch.Tensor],
        aux_dtype: torch.dtype,
        tp_rank: int,
        tp_size: int,
        bootstrap_port: int,
        gloo_group: ProcessGroup,
    ):
        self.token_to_kv_pool = token_to_kv_pool
        self.aux_dtype = aux_dtype

        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.kv_manager = self._init_kv_manager()
        self.queue: List[Req] = []
        self.gloo_group = gloo_group
        self.bootstrap_port = bootstrap_port

    def allocate_token_id(self, idx: int, token_id: int):
        assert token_id >= 0, f"token_id: {token_id} is negative"
        output_id_buffer = self.metadata_buffers[0]
        output_id_buffer[idx] = token_id

    def _init_kv_manager(self) -> KVManager:
        kv_args = KVArgs()
        kv_args.engine_rank = self.tp_rank
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )

        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens

        # Define the req -> output ids buffer
        kv_args.aux_data_ptrs = [
            metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
        ]
        kv_args.aux_data_lens = [
            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
        ]
        kv_args.aux_item_lens = [
            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
        ]
        kv_args.ib_device = "mock-ib-device"
        kv_manager = KVManager(kv_args)
        return kv_manager

    def add(self, req: Req) -> None:
        req.disagg_kv_sender = KVSender(
            mgr=self.kv_manager,
            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
            bootstrap_room=req.bootstrap_room,
        )
        self._process_req(req)
        self.queue.append(req)

    def _process_req(self, req: Req) -> None:
        """Set max_new_tokens = 1 so that the PrefillAdder's memory estimation is accurate."""
        req.sampling_params.max_new_tokens = 1

    def pop_bootstrapped(self) -> List[Req]:
        """Pop the requests that have finished bootstrapping."""
        bootstrapped_reqs = []
        indices_to_remove = set()

        if len(self.queue) == 0:
            return []

        polls = poll_and_all_reduce(
            [req.disagg_kv_sender for req in self.queue], self.gloo_group
        )

        for i, (req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Bootstrapping:
                continue
            elif poll == KVPoll.Failed:
                raise Exception("Bootstrap failed")

            # KVPoll.WaitingForInput - init here
            num_kv_indices = len(req.origin_input_ids)
            if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
                break

            req.metadata_buffer_index = (
                self.req_to_metadata_buffer_idx_allocator.alloc()
            )
            assert req.metadata_buffer_index is not None
            req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)

            bootstrapped_reqs.append(req)
            indices_to_remove.add(i)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return bootstrapped_reqs


class SchedulerDisaggregationPrefillMixin:
    """Mixin for the Scheduler to handle disaggregated prefill."""

    def process_batch_result_disagg_prefill(
        self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
    ) -> None:
        """
        Transfer KV for completed prefill requests and add them to disagg_prefill_infight_queue.
        Adapted from process_batch_result_prefill.
        """
        next_token_ids = result.next_token_ids.tolist()

        for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
            req: Req
            if req.is_chunked <= 0:
                # There are no output_ids yet during prefill
                req.output_ids.append(next_token_id)
                self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                self.send_kv_chunk(req, token_id=next_token_id)
                self.disagg_prefill_infight_queue.append(req)
            else:
                # this chunked request's prefill is not finished yet
                req.is_chunked -= 1

        # TODO: Not sure if this is necessary
        if batch.next_batch_sampling_info:
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            # We need to remove this for the overlap schedule.
            self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

    def process_disagg_prefill_infight_queue(self: Scheduler) -> None:
        """Poll the requests in the middle of the transfer. If done, stream them back."""
        assert len(self.disagg_prefill_infight_queue) > 0

        done_reqs = []

        polls = poll_and_all_reduce(
            [req.disagg_kv_sender for req in self.disagg_prefill_infight_queue],
            self.tp_worker.get_tp_cpu_group(),
        )

        undone_reqs: List[Req] = []
        # Check .poll() for the reqs in disagg_prefill_infight_queue.
        # If Success, respond to the client and remove the request from the queue.
        for req, poll in zip(self.disagg_prefill_infight_queue, polls):
            if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
                undone_reqs.append(req)
            elif poll == KVPoll.Success:  # transfer done
                self.tree_cache.cache_finished_req(req)  # unlock the tree
                req.finished_reason = FINISH_LENGTH(length=0)
                done_reqs.append(req)
            elif poll == KVPoll.Failed:
                raise Exception("Transferring failed")

        # Stream the requests that have finished the transfer
        self.stream_output(done_reqs, False, None)

        self.disagg_prefill_infight_queue = undone_reqs

    def process_prefill_chunk(self: Scheduler) -> None:
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.chunked_req:
                # Move the chunked request out of the batch so that we can merge
                # only finished requests into running_batch.
                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                self.tree_cache.cache_unfinished_req(self.chunked_req)
                self.send_kv_chunk(self.chunked_req)
                # a chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                self.running_batch.batch_is_full = False

    def send_kv_chunk(
        self: Scheduler, req: Req, token_id: Optional[int] = None
    ) -> None:
        """Send a prefilled chunk to the decode server."""
        start_idx = req.start_send_idx
        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
        kv_indices = (
            self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
            .cpu()
            .numpy()
        )
        req.start_send_idx = end_idx
        if token_id is not None:
            self.disagg_prefill_pending_queue.allocate_token_id(
                req.metadata_buffer_index, token_id
            )
        req.disagg_kv_sender.send(kv_indices)
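send_kv_chunk() always sends the half-open slice [start_send_idx, end_idx) and then advances start_send_idx, so each token's KV is sent exactly once across chunked prefill steps. A sketch of that bookkeeping for a hypothetical 1000-token prompt prefilled in 400-token chunks:

    # Hypothetical chunk schedule, purely to illustrate the index bookkeeping.
    prompt_len = 1000
    start_send_idx = 0
    sent = []
    for fill_len in (400, 800, 1000):          # len(req.fill_ids) after each chunk
        end_idx = min(fill_len, prompt_len)    # never send beyond the prompt
        sent.append((start_send_idx, end_idx))
        start_send_idx = end_idx
    assert sent == [(0, 400), (400, 800), (800, 1000)]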
python/sglang/srt/disaggregation/utils.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from __future__ import annotations

from collections import deque
from enum import Enum
from typing import Optional

import torch
import torch.distributed as dist


class DisaggregationMode(Enum):
    NULL = "null"
    PREFILL = "prefill"
    DECODE = "decode"


def poll_and_all_reduce(pollers, gloo_group):
    polls = [int(poller.poll()) for poller in pollers]
    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
    return tensor_to_reduce.tolist()


class ReqToMetadataIdxAllocator:
    """A memory pool that maps a request to the location of its first output token."""

    def __init__(
        self,
        size: int,
    ):
        self.size = size
        self.free_slots = deque(list(range(size)))

    def available_size(self):
        return len(self.free_slots)

    def alloc(self) -> Optional[int]:
        if len(self.free_slots) == 0:
            return None

        return self.free_slots.popleft()

    def free(self, free_index: int):
        self.free_slots.append(free_index)
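poll_and_all_reduce() applies an element-wise MIN across TP ranks, so a request only advances to a state once every rank has reached it, and since KVPoll.Failed is 0, a failure on any rank wins the reduction. A single-process sketch of that semantics (no process group needed):

    from sglang.srt.disaggregation.conn import KVPoll

    # States reported by two hypothetical TP ranks for the same request:
    assert min(KVPoll.Success, KVPoll.Transferring) == KVPoll.Transferring
    # The least-advanced rank determines the visible state, ...
    assert min(KVPoll.Success, KVPoll.Failed) == KVPoll.Failed
    # ... and Failed (0) dominates everything.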
python/sglang/srt/managers/schedule_batch.py
@@ -42,6 +42,8 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
+from sglang.srt.disaggregation.conn import KVSender
+from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
@@ -396,6 +398,24 @@ class Req:
         self.spec_verify_ct = 0
         self.lora_path = lora_path

+        # For disaggregation
+        self.bootstrap_host: str = "0.0.0.0"
+        self.bootstrap_room: Optional[int] = None
+        self.disagg_kv_sender: Optional[KVSender] = None
+
+        # used for warmup because we don't have a pair yet at init time
+        self.skip_kv_transfer: bool = False
+        # The start index of the KV cache that has been sent.
+        # We want to send the KV cache chunk by chunk for chunked prefill.
+        # After every chunk forward, we do the following:
+        #   kv_send(req.input_ids[req.start_send_idx : len(req.fill_ids)])
+        #   start_send_idx = len(req.fill_ids)
+        self.start_send_idx: int = 0
+
+        self.metadata_buffer_index: int = -1
+        # The first output_id transferred from the prefill instance.
+        self.transferred_output_id: Optional[int] = None
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -531,7 +551,7 @@ bid = 0


 @dataclasses.dataclass
-class ScheduleBatch:
+class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     """Store all information of a batch on the scheduler."""

     # Request, memory pool, and cache
python/sglang/srt/managers/scheduler.py
@@ -37,6 +37,19 @@ from torch.distributed import barrier
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
+from sglang.srt.disaggregation.decode import (
+    DecodePreallocQueue,
+    DecodeTransferQueue,
+    SchedulerDisaggregationDecodeMixin,
+)
+from sglang.srt.disaggregation.prefill import (
+    PrefillBootstrapQueue,
+    SchedulerDisaggregationPrefillMixin,
+)
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    ReqToMetadataIdxAllocator,
+)
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -137,7 +150,11 @@ class EmbeddingBatchResult:
     bid: int


-class Scheduler(SchedulerOutputProcessorMixin):
+class Scheduler(
+    SchedulerOutputProcessorMixin,
+    SchedulerDisaggregationDecodeMixin,
+    SchedulerDisaggregationPrefillMixin,
+):
     """A scheduler that manages a tensor parallel GPU worker."""

     def __init__(
@@ -389,6 +406,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
             ]
         )

+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
     def init_tokenizer(self):
         server_args = self.server_args
@@ -489,6 +511,73 @@ class Scheduler(SchedulerOutputProcessorMixin):
             },
         )

+    def init_disaggregation(self):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            buffer_size = self.req_to_token_pool.size * 2  # *2 for the headroom
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size), where
+            # b corresponds to the max number of running requests. The last dim
+            # times dtype.itemsize should be larger than 64 bytes to work with
+            # RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            # The decode requests polling the kv cache
+            self.disagg_decode_transfer_queue = DecodeTransferQueue(
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+            )
+
+            # The decode requests pending for pre-allocation
+            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                scheduler=self,
+                transfer_queue=self.disagg_decode_transfer_queue,
+                tree_cache=self.tree_cache,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+            )
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            buffer_size = self.max_running_requests * 2  # *2 for the headroom
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # Same (buffer_size, 16) metadata buffer padding as on the decode
+            # side above, for the 64-byte RDMA minimum.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+            )
+            # The prefill requests that are in the middle of kv sending
+            self.disagg_prefill_infight_queue: List[Req] = []
+
     @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
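The (buffer_size, 16) shape in init_disaggregation() comes from the 64-byte RDMA minimum mentioned in the comment: with aux_dtype = torch.int32, one 16-element row is exactly 64 bytes. A quick check:

    import torch

    row = torch.zeros(16, dtype=torch.int32)  # one metadata row per request
    assert row.nbytes == 16 * 4 == 64         # exactly the 64-byte RDMA minimum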
@@ -549,6 +638,70 @@ class Scheduler(SchedulerOutputProcessorMixin):

         self.last_batch = batch

+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for a prefill worker in disaggregation mode."""
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_infight_queue) > 0:
+                self.process_disagg_prefill_infight_queue()
+
+            if batch is None and len(self.disagg_prefill_infight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter
+            # update_running_batch, which resets it.
+            # Otherwise, it hangs under high concurrency.
+            self.running_batch.batch_is_full = False
+
+    @torch.no_grad()
+    def event_loop_normal_disagg_decode(self):
+        """A normal scheduler loop for a decode worker in disaggregation mode."""
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(
+                        batch.reqs, [False for _ in range(len(batch.reqs))]
+                    )
+                else:
+                    result = self.run_batch(batch)
+                    self.process_batch_result(batch, result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do a self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
@@ -778,10 +931,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self._add_request_to_queue(req)

     def _add_request_to_queue(self, req: Req):
-        self.waiting_queue.append(req)
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_pending_queue.add(req)
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.add(req)
+        else:
+            self.waiting_queue.append(req)

-    def _extend_requests_to_queue(self, reqs: List[Req]):
-        self.waiting_queue.extend(reqs)
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.extend(reqs)
+        else:
+            self.waiting_queue.extend(reqs)

     def handle_embedding_request(
         self,
@@ -1814,10 +1977,18 @@ def run_scheduler_process(
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-        if scheduler.enable_overlap:
-            scheduler.event_loop_overlap()
-        else:
-            scheduler.event_loop_normal()
+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
+
+        if disaggregation_mode == DisaggregationMode.NULL:
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap()
+            else:
+                scheduler.event_loop_normal()
+        elif disaggregation_mode == DisaggregationMode.PREFILL:
+            scheduler.event_loop_normal_disagg_prefill()
+        elif disaggregation_mode == DisaggregationMode.DECODE:
+            scheduler.event_loop_normal_disagg_decode()
+
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"Scheduler hit an exception: {traceback}")
python/sglang/srt/managers/tokenizer_manager.py
@@ -49,6 +49,8 @@ from fastapi import BackgroundTasks

 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.disaggregation.conn import KVBootstrapServer
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
     get_dummy_image_processor,
@@ -313,6 +315,16 @@ class TokenizerManager:
             ]
         )

+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        # For disaggregation, start the KV bootstrap server on the prefill side.
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start the bootstrap server on the prefill TokenizerManager
+            self.bootstrap_server = KVBootstrapServer(
+                self.server_args.disaggregation_bootstrap_port
+            )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
python/sglang/srt/mem_cache/memory_pool.py
@@ -271,6 +271,19 @@ class MHATokenToKVPool(KVCache):
             v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
         return k_size_bytes, v_size_bytes

+    # for disaggregation
+    def get_contiguous_buf_infos(self):
+        kv_data_ptrs = [
+            self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [
+            self.get_key_buffer(i).nbytes for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
+        kv_item_lens = [
+            self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
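get_contiguous_buf_infos() returns 2 * layer_num entries per list, with all key buffers before all value buffers. A small sketch of the resulting ordering for a hypothetical 2-layer pool:

    # Hypothetical 2-layer pool, to show the K-then-V ordering of the lists.
    layer_num = 2
    k_bufs = [f"K{i}" for i in range(layer_num)]
    v_bufs = [f"V{i}" for i in range(layer_num)]
    assert k_bufs + v_bufs == ["K0", "K1", "V0", "V1"]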
python/sglang/srt/server_args.py
@@ -185,6 +185,10 @@ class ServerArgs:
     debug_tensor_dump_input_file: Optional[str] = None
     debug_tensor_dump_inject: bool = False

+    # For PD disaggregation: can be "null" (not disaggregated),
+    # "prefill" (prefill-only), or "decode" (decode-only)
+    disaggregation_mode: str = "null"
+    disaggregation_bootstrap_port: int = 8998
+
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:
@@ -325,6 +329,18 @@ class ServerArgs:
         if is_hip():
             self.triton_attention_num_kv_splits = 16

+        # PD disaggregation
+        if self.disaggregation_mode == "prefill":
+            self.disable_cuda_graph = True
+            logger.warning("Cuda graph is disabled for prefill server")
+            self.disable_overlap_schedule = True
+            logger.warning("Overlap scheduler is disabled for prefill server")
+        elif self.disaggregation_mode == "decode":
+            self.disable_radix_cache = True
+            logger.warning("KV cache is forced as chunk cache for decode server")
+            self.disable_overlap_schedule = True
+            logger.warning("Overlap scheduler is disabled for decode server")
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
@@ -1063,6 +1079,21 @@ class ServerArgs:
             help="Inject the outputs from jax as the input of every layer.",
         )

+        # Disaggregation
+        parser.add_argument(
+            "--disaggregation-mode",
+            type=str,
+            default="null",
+            choices=["null", "prefill", "decode"],
+            help='Only used for PD disaggregation. "prefill" for a prefill-only server, '
+            'and "decode" for a decode-only server. If not specified, the server is not PD disaggregated.',
+        )
+        parser.add_argument(
+            "--disaggregation-bootstrap-port",
+            type=int,
+            default=ServerArgs.disaggregation_bootstrap_port,
+            help="Bootstrap server port on the prefill server. Default is 8998.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size