diff --git a/python/sglang/srt/disaggregation/conn.py b/python/sglang/srt/disaggregation/conn.py new file mode 100644 index 000000000..3989504ad --- /dev/null +++ b/python/sglang/srt/disaggregation/conn.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import logging +from enum import Enum +from typing import Optional + +import numpy as np +import numpy.typing as npt + +logger = logging.getLogger(__name__) + + +class KVArgs: + engine_rank: int + kv_data_ptrs: list[int] + kv_data_lens: list[int] + kv_item_lens: list[int] + aux_data_ptrs: list[int] + aux_data_lens: list[int] + aux_item_lens: list[int] + ib_device: str + + +class KVManager: + def __init__(self, args: KVArgs): ... + + +class KVPoll: + Failed = 0 + Bootstrapping = 1 + WaitingForInput = 2 + Transferring = 3 + Success = 4 + + +class KVSender: + def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int): + self.has_sent = False + + def init(self, num_kv_indices: int, aux_index: Optional[int] = None): ... + + def send(self, kv_indices: npt.NDArray[np.int32]): + self.has_sent = True + + def poll(self) -> KVPoll: + if self.has_sent is False: + # Assume handshake completed instantly + return KVPoll.WaitingForInput + else: + # Assume transfer completed instantly + return KVPoll.Success + + def failure_exception(self): + raise Exception("Fake KVSender Exception") + + +class KVReceiver: + def __init__( + self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None + ): + self.has_init = False + + def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): + self.has_init = True + + def poll(self) -> KVPoll: + if self.has_init is False: + # Assume handshake completed instantly + return KVPoll.WaitingForInput + else: + # Assume transfer completed instantly + return KVPoll.Success + + def failure_exception(self): + raise Exception("Fake KVReceiver Exception") + + +class KVBootstrapServer: + def __init__(self, port: int): ... + + def poll(self) -> KVPoll: ... diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py new file mode 100644 index 000000000..e4e3dde71 --- /dev/null +++ b/python/sglang/srt/disaggregation/decode.py @@ -0,0 +1,495 @@ +""" +Life cycle of a request in the decode server + +1. PreallocQueue: + a. Initialize a receiver for each request + b. The request handshakes first, and pre-allocate kv once there is available kv. + c. Move the request to TransferQueue. + +2. TransferQueue: + a. Poll the receiver to check the transfer state + b. If the transfer has finished, move the request to waiting queue + +3. WaitingQueue: + a. Use the requests in the queue to construct a PrebuiltExtendBatch + b. Skip the prefill forward but only populate metadata + +4. RunningBatch: + a. 
Merge the resolved PrebuiltExtendBatch into running batch to run decoding +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch +from torch.distributed import ProcessGroup + +from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver +from sglang.srt.disaggregation.utils import ( + ReqToMetadataIdxAllocator, + poll_and_all_reduce, +) +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.managers.schedule_batch import Req, ScheduleBatch + from sglang.srt.managers.scheduler import Scheduler + from sglang.srt.server_args import ServerArgs + + +@dataclass +class DecodeRequest: + req: Req + kv_receiver: KVReceiver + waiting_for_input: bool = False + metadata_buffer_index: int = -1 + + +class DecodePreallocQueue: + """ + Store the requests that are preallocating. + """ + + def __init__( + self, + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool_allocator: TokenToKVPoolAllocator, + req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, + metadata_buffers: List[torch.Tensor], + aux_dtype: torch.dtype, + scheduler: Scheduler, + transfer_queue: DecodeTransferQueue, + tree_cache: BasePrefixCache, + gloo_group: ProcessGroup, + tp_rank: int, + tp_size: int, + bootstrap_port: int, + ): + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache() + self.aux_dtype = aux_dtype + self.metadata_buffers = metadata_buffers + self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator + self.scheduler = scheduler + self.transfer_queue = transfer_queue + self.tree_cache = tree_cache # this is always a chunk cache + self.gloo_group = gloo_group + self.tp_rank = tp_rank + self.tp_size = tp_size + self.bootstrap_port = bootstrap_port + + self.num_reserved_decode_tokens = 512 + + # Queue for requests pending pre-allocation + self.queue: List[DecodeRequest] = [] + self.kv_manager = self._init_kv_manager() + + def _init_kv_manager(self) -> KVManager: + kv_args = KVArgs() + kv_args.engine_rank = self.tp_rank + kv_data_ptrs, kv_data_lens, kv_item_lens = ( + self.token_to_kv_pool.get_contiguous_buf_infos() + ) + + kv_args.kv_data_ptrs = kv_data_ptrs + kv_args.kv_data_lens = kv_data_lens + kv_args.kv_item_lens = kv_item_lens + + kv_args.aux_data_ptrs = [ + output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers + ] + kv_args.aux_data_lens = [ + metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers + ] + kv_args.aux_item_lens = [ + metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers + ] + kv_args.ib_device = "mock-ib-device" + kv_manager = KVManager(kv_args) + return kv_manager + + def add(self, req: Req) -> None: + """Add a request to the pending queue.""" + + kv_receiver = KVReceiver( + mgr=self.kv_manager, + bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", + bootstrap_room=req.bootstrap_room, + ) + self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver)) + + def extend(self, reqs: List[Req]) -> 
None: + """Add a request to the pending queue.""" + for req in reqs: + self.add(req) + + def _update_handshake_waiters(self) -> None: + if not self.queue: + return + + if all(decode_req.waiting_for_input for decode_req in self.queue): + return + + polls = poll_and_all_reduce( + [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group + ) + + for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): + if poll == KVPoll.Bootstrapping: + pass + elif poll == KVPoll.WaitingForInput: + decode_req.waiting_for_input = True + elif poll == KVPoll.Failed: + raise Exception("Handshake failed") + + def pop_preallocated(self) -> List[DecodeRequest]: + """Pop the preallocated requests from the pending queue (FIFO).""" + self._update_handshake_waiters() + + preallocated_reqs = [] + indices_to_remove = set() + allocatable_tokens = self._allocatable_tokens(count_retracted=True) + + for i, decode_req in enumerate(self.queue): + if not decode_req.waiting_for_input: + continue + + if self.req_to_token_pool.available_size() <= 0: + break + + if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0: + break + + required_tokens_for_request = ( + len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens + ) + + if required_tokens_for_request > allocatable_tokens: + break + + allocatable_tokens -= required_tokens_for_request + self._pre_alloc(decode_req.req) + + kv_indices = ( + self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][ + : len(decode_req.req.origin_input_ids) + ] + .cpu() + .numpy() + ) + + decode_req.metadata_buffer_index = ( + self.req_to_metadata_buffer_idx_allocator.alloc() + ) + assert decode_req.metadata_buffer_index is not None + decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index) + preallocated_reqs.append(decode_req) + indices_to_remove.add(i) + + self.queue = [ + entry for i, entry in enumerate(self.queue) if i not in indices_to_remove + ] + + return preallocated_reqs + + def _allocatable_tokens(self) -> int: + allocatable_tokens = ( + self.token_to_kv_pool_allocator.available_size() + - self.num_reserved_decode_tokens + * ( + len(self.scheduler.running_batch.reqs) + + len(self.transfer_queue.queue) + + len(self.scheduler.waiting_queue) + ) + ) + + # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration + # the extend batch is not in any queue, so we need to explicitly add the tokens slots here + if ( + self.scheduler.last_batch + and self.scheduler.last_batch.forward_mode.is_extend() + ): + allocatable_tokens -= self.num_reserved_decode_tokens * len( + self.scheduler.last_batch.reqs + ) + + return allocatable_tokens + + def _pre_alloc(self, req: Req) -> torch.Tensor: + """Pre-allocate the memory for req_to_token and token_kv_pool""" + req_pool_indices = self.req_to_token_pool.alloc(1) + + assert req_pool_indices is not None + + req.req_pool_idx = req_pool_indices[0] + kv_loc = self.token_to_kv_pool_allocator.alloc( + len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0) + ) + + assert kv_loc is not None + + self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc) + + # populate metadata + req.fill_ids = req.origin_input_ids + req.output_ids + req.extend_input_len = len(req.origin_input_ids) + + return kv_loc + + +class DecodeTransferQueue: + """ + Store the requests that is polling kv + """ + + def __init__( + self, + gloo_group: ProcessGroup, + req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, + metadata_buffers: 
torch.Tensor, + ): + self.queue: List[DecodeRequest] = [] + self.gloo_group = gloo_group + self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator + self.metadata_buffers = metadata_buffers + + def add(self, req_conn: DecodeRequest) -> None: + self.queue.append(req_conn) + + def extend(self, req_conns) -> None: + self.queue.extend(req_conns) + + def pop_transferred(self) -> List[Req]: + if not self.queue: + return [] + + polls = poll_and_all_reduce( + [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group + ) + + transferred_reqs = [] + indices_to_remove = set() + for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): + if poll == KVPoll.Failed: + raise Exception("Transfer failed") + elif poll == KVPoll.Success: + # pop and push it to waiting queue + idx = decode_req.metadata_buffer_index + assert len(decode_req.req.output_ids) == 0 + output_id_buffer = self.metadata_buffers[0] + # the last dimension is padded by the same values. + output_id = output_id_buffer[idx][0].item() + assert len(decode_req.req.output_ids) == 0 + assert decode_req.req.transferred_output_id is None + decode_req.req.transferred_output_id = output_id + transferred_reqs.append(decode_req.req) + indices_to_remove.add(i) + elif poll in [ + KVPoll.Bootstrapping, + KVPoll.WaitingForInput, + KVPoll.Transferring, + ]: + pass + else: + raise ValueError(f"Unexpected poll case: {poll}") + + for i in indices_to_remove: + idx = self.queue[i].metadata_buffer_index + assert idx != -1 + self.req_to_metadata_buffer_idx_allocator.free(idx) + + self.queue = [ + entry for i, entry in enumerate(self.queue) if i not in indices_to_remove + ] + + return transferred_reqs + + +class ScheduleBatchDisaggregationDecodeMixin: + + def prepare_for_prebuilt_extend(self: ScheduleBatch): + """ + Prepare a prebuilt extend by populate metadata + Adapted from .prepare_for_extend(). 
+ """ + + self.forward_mode = ForwardMode.EXTEND + reqs = self.reqs + input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] + extend_num_tokens = sum(len(ids) for ids in input_ids) + seq_lens = [] + pre_lens = [] + req_pool_indices = [] + + # Pre-calculate total size + total_size = sum(req.extend_input_len for req in reqs) + out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device) + + # Fill the tensor in one pass + offset = 0 + for i, req in enumerate(reqs): + req_pool_indices.append(req.req_pool_idx) + + chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][ + : req.extend_input_len + ] + assert ( + offset + req.extend_input_len <= total_size + ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}" + out_cache_loc[offset : offset + req.extend_input_len] = chunk + offset += req.extend_input_len + + pre_len = len(req.prefix_indices) + seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1) + seq_lens.append(seq_len) + if len(req.output_ids) == 0: + assert ( + seq_len - pre_len == req.extend_input_len + ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}" + + req.cached_tokens += pre_len - req.already_computed + req.already_computed = seq_len + req.is_retracted = False + pre_lens.append(pre_len) + req.extend_logprob_start_len = 0 + + extend_input_logprob_token_ids = None + + # Set fields + self.input_ids = torch.tensor( + sum(input_ids, []), dtype=torch.int32, device=self.device + ) + self.req_pool_indices = torch.tensor( + req_pool_indices, dtype=torch.int64, device=self.device + ) + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device) + self.out_cache_loc = out_cache_loc + self.seq_lens_sum = sum(seq_lens) + self.extend_num_tokens = extend_num_tokens + self.prefix_lens = [len(r.prefix_indices) for r in reqs] + self.extend_lens = [r.extend_input_len for r in reqs] + self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs] + self.extend_input_logprob_token_ids = extend_input_logprob_token_ids + + # Build sampling info + self.sampling_info = SamplingBatchInfo.from_schedule_batch( + self, + self.model_config.vocab_size, + ) + + def process_prebuilt_extend( + self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig + ): + """Assign the buffered last input id to schedule batch""" + self.output_ids = [] + for req in self.reqs: + if req.output_ids and len(req.output_ids) > 0: + # resumed retracted req + self.output_ids.append(req.output_ids[-1]) + else: + assert req.transferred_output_id is not None + req.output_ids.append(req.transferred_output_id) + self.output_ids.append(req.transferred_output_id) + self.tree_cache.cache_unfinished_req(req) + self.output_ids = torch.tensor(self.output_ids, device=self.device) + + +class SchedulerDisaggregationDecodeMixin: + + def get_next_disagg_decode_batch_to_run( + self: Scheduler, + ) -> Optional[Tuple[ScheduleBatch, bool]]: + """Create fake completed prefill if possible and merge with running batch""" + # Merge the prefill batch into the running batch + last_batch = self.last_batch + if last_batch and last_batch.forward_mode.is_extend(): + # chunked prefill doesn't happen in decode instance. + assert self.chunked_req is None + # Filter finished batches. 
+ last_batch.filter_batch() + if not last_batch.is_empty(): + if self.running_batch.is_empty(): + self.running_batch = last_batch + else: + # merge running_batch with prefill batch + self.running_batch.merge_batch(last_batch) + + new_prebuilt_batch = self.get_new_prebuilt_batch() + + ret: Optional[ScheduleBatch] = None + if new_prebuilt_batch: + ret = new_prebuilt_batch + else: + if self.running_batch.is_empty(): + ret = None + else: + self.running_batch = self.update_running_batch(self.running_batch) + ret = self.running_batch if not self.running_batch.is_empty() else None + + return ret + + def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: + """Create a schedulebatch for fake completed prefill""" + if len(self.waiting_queue) == 0: + return None + + curr_batch_size = self.running_batch.batch_size() + + batch_size = min(self.req_to_token_pool.size, self.max_running_requests) + + num_not_used_batch = batch_size - curr_batch_size + + # pop req from waiting queue + can_run_list: List[Req] = [] + waiting_queue: List[Req] = [] + + for i in range(len(self.waiting_queue)): + req = self.waiting_queue[i] + # we can only add at least `num_not_used_batch` new batch to the running queue + if i < num_not_used_batch: + can_run_list.append(req) + req.init_next_round_input(self.tree_cache) + else: + waiting_queue.append(req) + + self.waiting_queue = waiting_queue + if len(can_run_list) == 0: + return None + # local import to avoid circular import + from sglang.srt.managers.schedule_batch import ScheduleBatch + + # construct a schedule batch with those requests and mark as decode + new_batch = ScheduleBatch.init_new( + can_run_list, + self.req_to_token_pool, + self.token_to_kv_pool_allocator, + self.tree_cache, + self.model_config, + self.enable_overlap, + self.spec_algorithm, + self.server_args.enable_custom_logit_processor, + ) + + # construct fake completed prefill + new_batch.prepare_for_prebuilt_extend() + new_batch.process_prebuilt_extend(self.server_args, self.model_config) + + return new_batch + + def process_decode_queue(self: Scheduler): + req_conns = self.disagg_decode_prealloc_queue.pop_preallocated() + self.disagg_decode_transfer_queue.extend(req_conns) + alloc_reqs = ( + self.disagg_decode_transfer_queue.pop_transferred() + ) # the requests which kv has arrived + self.waiting_queue.extend(alloc_reqs) diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py new file mode 100644 index 000000000..fdc9f2f45 --- /dev/null +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -0,0 +1,285 @@ +""" +Minimal HTTP load balancer for prefill and decode servers for testing purpose. 
+""" + +import asyncio +import random +import urllib +from itertools import chain + +import aiohttp +import orjson +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.responses import ORJSONResponse, Response, StreamingResponse + + +class MiniLoadBalancer: + def __init__(self, prefill_servers, decode_servers): + self.prefill_servers = prefill_servers + self.decode_servers = decode_servers + + def select_pair(self): + return random.choice(self.prefill_servers), random.choice(self.decode_servers) + + async def generate_request(self, request_data): + prefill_server, decode_server = self.select_pair() + + # Parse and transform prefill_server + parsed_url = urllib.parse.urlparse(prefill_server) + hostname = parsed_url.hostname + bootstrap_host = f"{hostname}" + + modified_request = request_data.copy() + modified_request.update( + { + "bootstrap_host": bootstrap_host, + "bootstrap_room": random.randint(0, 2**63 - 1), + } + ) + + async with aiohttp.ClientSession() as session: + # Create the tasks + tasks = [ + session.post(f"{prefill_server}/generate", json=modified_request), + session.post(f"{decode_server}/generate", json=modified_request), + ] + + prefill_response = None + decode_response = None + + # Process responses as they arrive + for i, response in enumerate(asyncio.as_completed(tasks)): + response = await response + # Check if this is the prefill or decode response based on order created + if i == 0: # First completed task + if str(response.url).startswith(prefill_server): + prefill_response = response + if response.status != 200: + raise HTTPException( + status_code=response.status, + detail=f"Prefill server error: Status {response.status} Details: {await response.text()}", + ) + else: + decode_response = response + if response.status != 200: + raise HTTPException( + status_code=response.status, + detail=f"Decode server error: Status {response.status} Details: {await response.text()}", + ) + else: # Second completed task + if str(response.url).startswith(prefill_server): + prefill_response = response + else: + decode_response = response + + if response.status != 200: + raise HTTPException( + status_code=response.status, + detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}", + ) + + return await decode_response.json() + + +app = FastAPI() +load_balancer = None + + +@app.get("/health") +async def health_check(): + return Response(status_code=200) + + +@app.get("/health_generate") +async def health_check(): + prefill_servers, decode_servers = ( + load_balancer.prefill_servers, + load_balancer.decode_servers, + ) + async with aiohttp.ClientSession() as session: + # Create the tasks + tasks = [] + for server in chain(prefill_servers, decode_servers): + tasks.append(session.post(f"{server}/health_generate")) + for i, response in enumerate(asyncio.as_completed(tasks)): + await response + return Response(status_code=200) + + +@app.post("/flush_cache") +async def flush_cache(): + prefill_servers, decode_servers = ( + load_balancer.prefill_servers, + load_balancer.decode_servers, + ) + async with aiohttp.ClientSession() as session: + # Create the tasks + tasks = [] + for server in chain(prefill_servers, decode_servers): + tasks.append(session.post(f"{server}/flush_cache")) + for i, response in enumerate(asyncio.as_completed(tasks)): + await response + return Response(status_code=200) + + +@app.get("/get_server_info") +async def get_server_info(): + prefill_servers, 
decode_servers = ( + load_balancer.prefill_servers, + load_balancer.decode_servers, + ) + prefill_infos = [] + decode_infos = [] + async with aiohttp.ClientSession() as session: + for server in chain(prefill_servers): + server_info = await session.get(f"{server}/get_server_info") + prefill_infos.append(await server_info.json()) + for server in chain(decode_servers): + server_info = await session.get(f"{server}/get_server_info") + decode_infos.append(await server_info.json()) + + return {"prefill": prefill_infos, "decode": decode_infos} + + +@app.get("/get_model_info") +async def get_model_info(): + # Dummy model information + model_info = { + "model_path": "/path/to/dummy/model", + "tokenizer_path": "/path/to/dummy/tokenizer", + "is_generation": True, + "preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128}, + } + return ORJSONResponse(content=model_info) + + +@app.post("/generate") +async def handle_generate_request(request_data: dict): + prefill_server, decode_server = load_balancer.select_pair() + + # Parse and transform prefill_server for bootstrap data + parsed_url = urllib.parse.urlparse(prefill_server) + hostname = parsed_url.hostname + modified_request = request_data.copy() + modified_request.update( + { + "bootstrap_host": hostname, + "bootstrap_room": random.randint(0, 2**63 - 1), + } + ) + + # Check if streaming is requested + if request_data.get("stream", False): + + async def stream_results(): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=3600) + ) as session: + try: + # Create the tasks + tasks = [ + session.post( + f"{prefill_server}/generate", json=modified_request + ), + session.post( + f"{decode_server}/generate", json=modified_request + ), + ] + + prefill_response = None + decode_response = None + + # Process responses as they arrive + for i, response_task in enumerate(asyncio.as_completed(tasks)): + response = await response_task + + # Check the response immediately + if str(response.url).startswith(prefill_server): + prefill_response = response + if response.status != 200: + error_msg = { + "error": { + "message": f"Prefill server error: Status {response.status}, Details: {await response.text()}" + } + } + yield b"data: " + orjson.dumps( + error_msg, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + return + else: + decode_response = response + if response.status != 200: + error_msg = { + "error": { + "message": f"Decode server error: Status {response.status}" + } + } + yield b"data: " + orjson.dumps( + error_msg, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + return + + # Stream successful decode server response + async for line in decode_response.content: + yield line + yield b"data: [DONE]\n\n" + + except Exception as e: + error_msg = { + "error": {"message": f"Stream processing error: {str(e)}"} + } + yield b"data: " + orjson.dumps( + error_msg, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + + return StreamingResponse( + stream_results(), + media_type="text/event-stream", + ) + + # Non-streaming case + result = await load_balancer.generate_request(request_data) + return ORJSONResponse(content=result) + + +@app.get("/v1/models") +async def get_models(): + prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server + async with aiohttp.ClientSession() as session: + try: + response = await session.get(f"{prefill_server}/v1/models") + if response.status != 200: + raise HTTPException( + status_code=response.status, + detail=f"Prefill server error: Status {response.status}", + ) + return 
ORJSONResponse(content=await response.json()) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +def run(prefill_addrs, decode_addrs, host, port): + global load_balancer + load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs) + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Mini Load Balancer Server") + parser.add_argument( + "--prefill", required=True, help="Comma-separated URLs for prefill servers" + ) + parser.add_argument( + "--decode", required=True, help="Comma-separated URLs for decode servers" + ) + parser.add_argument( + "--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)" + ) + parser.add_argument( + "--port", type=int, default=8000, help="Port to bind the server (default: 8000)" + ) + args = parser.parse_args() + run(args.prefill.split(","), args.decode.split(","), args.host, args.port) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py new file mode 100644 index 000000000..fad26c571 --- /dev/null +++ b/python/sglang/srt/disaggregation/prefill.py @@ -0,0 +1,249 @@ +""" +Life cycle of a request in the prefill server + +1. Bootstrap Queue + a. Initialize a sender for each request + b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished + c. Poll senders to check bootstrap state + d. Once bootstrap is complete, move request to Waiting Queue + +2. Waiting Queue + a. Use PrefillAdder to pop requests + b. Run forward + c. Add the request to Infight Queue + +3. Infight Queue + a. Poll (non-blocking) the sender of the request + b. Once the transfer has finished, return the request +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, List, Optional + +import torch + +from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender +from sglang.srt.disaggregation.utils import ( + ReqToMetadataIdxAllocator, + poll_and_all_reduce, +) +from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch + +if TYPE_CHECKING: + from torch.distributed import ProcessGroup + + from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler + from sglang.srt.mem_cache.memory_pool import KVCache + +logger = logging.getLogger(__name__) + + +class PrefillBootstrapQueue: + """ + Store the requests in bootstrapping + """ + + def __init__( + self, + token_to_kv_pool: KVCache, + req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, + metadata_buffers: List[torch.Tensor], + aux_dtype: torch.dtype, + tp_rank: int, + tp_size: int, + bootstrap_port: int, + gloo_group: ProcessGroup, + ): + self.token_to_kv_pool = token_to_kv_pool + self.aux_dtype = aux_dtype + + self.metadata_buffers = metadata_buffers + self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator + self.tp_rank = tp_rank + self.tp_size = tp_size + self.kv_manager = self._init_kv_manager() + self.queue: List[Req] = [] + self.gloo_group = gloo_group + self.bootstrap_port = bootstrap_port + + def allocate_token_id(self, idx: int, token_id: int): + assert token_id >= 0, f"token_id: {token_id} is negative" + output_id_buffer = self.metadata_buffers[0] + output_id_buffer[idx] = token_id + + def _init_kv_manager(self) -> KVManager: + kv_args = KVArgs() + kv_args.engine_rank = self.tp_rank + kv_data_ptrs, kv_data_lens, kv_item_lens = ( + self.token_to_kv_pool.get_contiguous_buf_infos() + ) + 
+ kv_args.kv_data_ptrs = kv_data_ptrs + kv_args.kv_data_lens = kv_data_lens + kv_args.kv_item_lens = kv_item_lens + + # Define req -> input ids buffer + kv_args.aux_data_ptrs = [ + metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers + ] + kv_args.aux_data_lens = [ + metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers + ] + kv_args.aux_item_lens = [ + metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers + ] + kv_args.ib_device = "mock-ib-device" + kv_manager = KVManager(kv_args) + return kv_manager + + def add(self, req: Req) -> None: + req.disagg_kv_sender = KVSender( + mgr=self.kv_manager, + bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", + bootstrap_room=req.bootstrap_room, + ) + self._process_req(req) + self.queue.append(req) + + def _process_req(self, req: Req) -> None: + """ + Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate + """ + req.sampling_params.max_new_tokens = 1 + + def pop_bootstrapped(self) -> List[Req]: + """pop the reqs which has finished bootstrapping""" + bootstrapped_reqs = [] + indices_to_remove = set() + + if len(self.queue) == 0: + return [] + + polls = poll_and_all_reduce( + [req.disagg_kv_sender for req in self.queue], self.gloo_group + ) + + for i, (req, poll) in enumerate(zip(self.queue, polls)): + if poll == KVPoll.Bootstrapping: + continue + elif poll == KVPoll.Failed: + raise Exception("Bootstrap failed") + + # KV.WaitingForInput - init here + num_kv_indices = len(req.origin_input_ids) + if self.req_to_metadata_buffer_idx_allocator.available_size() == 0: + break + + req.metadata_buffer_index = ( + self.req_to_metadata_buffer_idx_allocator.alloc() + ) + assert req.metadata_buffer_index is not None + req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index) + + bootstrapped_reqs.append(req) + indices_to_remove.add(i) + + self.queue = [ + entry for i, entry in enumerate(self.queue) if i not in indices_to_remove + ] + + return bootstrapped_reqs + + +class SchedulerDisaggregationPrefillMixin: + """ + Mixin for Scheduler to handle disaggregation prefill + """ + + def process_batch_result_disagg_prefill( + self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult + ) -> None: + """ + Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue + Adapted from process_batch_result_prefill + """ + + next_token_ids = result.next_token_ids.tolist() + + for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True): + req: Req + if req.is_chunked <= 0: + # There is no output_ids for prefill + req.output_ids.append(next_token_id) + self.tree_cache.cache_unfinished_req(req) # update the tree and lock + self.send_kv_chunk(req, token_id=next_token_id) + self.disagg_prefill_infight_queue.append(req) + else: + # being chunked reqs' prefill is not finished + req.is_chunked -= 1 + + # TODO: Not sure if this is necessary + if batch.next_batch_sampling_info: + batch.next_batch_sampling_info.update_regex_vocab_mask() + # We need to remove this for overlap schedule. + self.current_stream.synchronize() + batch.next_batch_sampling_info.sampling_info_done.set() + + def process_disagg_prefill_infight_queue(self: Scheduler) -> None: + """ + Poll the requests in the middle of transfer. If done, return the request. 
+ """ + assert len(self.disagg_prefill_infight_queue) > 0 + + done_reqs = [] + + polls = poll_and_all_reduce( + [req.disagg_kv_sender for req in self.disagg_prefill_infight_queue], + self.tp_worker.get_tp_cpu_group(), + ) + + undone_reqs: List[Req] = [] + # Check .poll() for the reqs in disagg_prefill_infight_queue. If Success, respond to the client and remove it from the queue + for req, poll in zip(self.disagg_prefill_infight_queue, polls): + if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]: + undone_reqs.append(req) + elif poll == KVPoll.Success: # transfer done + self.tree_cache.cache_finished_req(req) # unlock the tree + req.finished_reason = FINISH_LENGTH(length=0) + done_reqs.append(req) + elif poll == KVPoll.Failed: + raise Exception("Transferring failed") + + # Stream requests which have finished transfer + self.stream_output(done_reqs, False, None) + + self.disagg_prefill_infight_queue = undone_reqs + + def process_prefill_chunk(self: Scheduler) -> None: + if self.last_batch and self.last_batch.forward_mode.is_extend(): + if self.chunked_req: + # Move the chunked request out of the batch so that we can merge + # only finished requests to running_batch. + self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req) + self.tree_cache.cache_unfinished_req(self.chunked_req) + self.send_kv_chunk(self.chunked_req) + # chunked request keeps its rid but will get a new req_pool_idx + self.req_to_token_pool.free(self.chunked_req.req_pool_idx) + self.running_batch.batch_is_full = False + + def send_kv_chunk( + self: Scheduler, req: Req, token_id: Optional[int] = None + ) -> None: + """ + Send a prefilled chunk to the decode server + """ + start_idx = req.start_send_idx + end_idx = min(len(req.fill_ids), len(req.origin_input_ids)) + kv_indices = ( + self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx] + .cpu() + .numpy() + ) + req.start_send_idx = end_idx + if token_id is not None: + self.disagg_prefill_pending_queue.allocate_token_id( + req.metadata_buffer_index, token_id + ) + req.disagg_kv_sender.send(kv_indices) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py new file mode 100644 index 000000000..76da71a00 --- /dev/null +++ b/python/sglang/srt/disaggregation/utils.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from collections import deque +from enum import Enum +from typing import List + +import torch +import torch.distributed as dist + + +class DisaggregationMode(Enum): + NULL = "null" + PREFILL = "prefill" + DECODE = "decode" + + +def poll_and_all_reduce(pollers, gloo_group): + polls = [int(poller.poll()) for poller in pollers] + tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu") + dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group) + return tensor_to_reduce.tolist() + + +class ReqToMetadataIdxAllocator: + """A memory pool that maps a request to its first output token location.""" + + def __init__( + self, + size: int, + ): + self.size = size + self.free_slots = deque(list(range(size))) + + def available_size(self): + return len(self.free_slots) + + def alloc(self) -> List[int]: + if len(self.free_slots) == 0: + return None + + return self.free_slots.popleft() + + def free(self, free_index: int): + self.free_slots.append(free_index) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3b8259bfc..b40f93002 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ 
b/python/sglang/srt/managers/schedule_batch.py @@ -42,6 +42,8 @@ import triton.language as tl from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject +from sglang.srt.disaggregation.conn import KVSender +from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator @@ -396,6 +398,24 @@ class Req: self.spec_verify_ct = 0 self.lora_path = lora_path + # For disaggregation + self.bootstrap_host: str = "0.0.0.0" + self.bootstrap_room: Optional[int] = None + self.disagg_kv_sender: Optional[KVSender] = None + + # used for warmup because we don't have a pair yet when init + self.skip_kv_transfer: bool = False + # the start index of the sent kv cache + # We want to send it chunk by chunk for chunked prefill. + # After every chunk forward, we do the following: + # kv_send(req.input_ids[req.start_send_idx:len(req.fill_ids)]) + # start_send_idx = len(req.fill_ids) + self.start_send_idx: int = 0 + + self.metadata_buffer_index: int = -1 + # The first output_id transferred from prefill instance. + self.transferred_output_id: Optional[int] = None + @property def seqlen(self): return len(self.origin_input_ids) + len(self.output_ids) @@ -531,7 +551,7 @@ bid = 0 @dataclasses.dataclass -class ScheduleBatch: +class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): """Store all information of a batch on the scheduler.""" # Request, memory pool, and cache diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ef4ecaf90..71a7e2c3a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -37,6 +37,19 @@ from torch.distributed import barrier from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import create_grammar_backend +from sglang.srt.disaggregation.decode import ( + DecodePreallocQueue, + DecodeTransferQueue, + SchedulerDisaggregationDecodeMixin, +) +from sglang.srt.disaggregation.prefill import ( + PrefillBootstrapQueue, + SchedulerDisaggregationPrefillMixin, +) +from sglang.srt.disaggregation.utils import ( + DisaggregationMode, + ReqToMetadataIdxAllocator, +) from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput @@ -137,7 +150,11 @@ class EmbeddingBatchResult: bid: int -class Scheduler(SchedulerOutputProcessorMixin): +class Scheduler( + SchedulerOutputProcessorMixin, + SchedulerDisaggregationDecodeMixin, + SchedulerDisaggregationPrefillMixin, +): """A scheduler that manages a tensor parallel GPU worker.""" def __init__( @@ -389,6 +406,11 @@ class Scheduler(SchedulerOutputProcessorMixin): ] ) + self.disaggregation_mode = DisaggregationMode( + self.server_args.disaggregation_mode + ) + self.init_disaggregation() + def init_tokenizer(self): server_args = self.server_args @@ -489,6 +511,73 @@ class Scheduler(SchedulerOutputProcessorMixin): }, ) + def init_disaggregation(self): + if ( + self.disaggregation_mode == DisaggregationMode.DECODE + ): # *2 for the headroom. 
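+ # One metadata slot per req_to_token_pool entry, doubled so the metadata index pool is not the limiting factor during pre-allocation.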
+ buffer_size = (self.req_to_token_pool.size) * 2 + req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator( + buffer_size + ) + aux_dtype = torch.int32 + # A list of metadata buffers. The shape is (b, metadata_size) where + # b corresponds to a max running requests. The last shape * dtype.itemsize + # should be larger than 64 bytes to work with RDMA, so we pad it. + output_id_buffer = torch.zeros( + (buffer_size, 16), dtype=aux_dtype, device="cpu" + ) + metadata_buffers = [output_id_buffer] + + # The decode requests polling kv cache + self.disagg_decode_transfer_queue = DecodeTransferQueue( + gloo_group=self.tp_worker.get_attention_tp_cpu_group(), + req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, + metadata_buffers=metadata_buffers, + ) + + # The decode requests pending for pre-allocation + self.disagg_decode_prealloc_queue = DecodePreallocQueue( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, + metadata_buffers=metadata_buffers, + aux_dtype=aux_dtype, + scheduler=self, + transfer_queue=self.disagg_decode_transfer_queue, + tree_cache=self.tree_cache, + gloo_group=self.tp_worker.get_attention_tp_cpu_group(), + tp_rank=self.tp_rank, + tp_size=self.tp_size, + bootstrap_port=self.server_args.disaggregation_bootstrap_port, + ) + elif self.disaggregation_mode == DisaggregationMode.PREFILL: + # *2 for the headroom. + buffer_size = self.max_running_requests * 2 + req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator( + buffer_size + ) + aux_dtype = torch.int32 + # A list of metadata buffers. The shape is (b, metadata_size) where + # b corresponds to a max running requests. The last shape * dtype.itemsize + # should be larger than 64 bytes to work with RDMA, so we pad it. 
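+ # 16 int32 entries per row -> 16 * 4 = 64 bytes per aux item after padding.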
+ output_id_buffer = torch.zeros( + (buffer_size, 16), dtype=aux_dtype, device="cpu" + ) + metadata_buffers = [output_id_buffer] + + self.disagg_prefill_pending_queue = PrefillBootstrapQueue( + token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(), + req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, + metadata_buffers=metadata_buffers, + aux_dtype=aux_dtype, + tp_rank=self.tp_rank, + tp_size=self.tp_size, + bootstrap_port=self.server_args.disaggregation_bootstrap_port, + gloo_group=self.tp_worker.get_attention_tp_cpu_group(), + ) + # The prefill requests that are in the middle of kv sending + self.disagg_prefill_infight_queue: List[Req] = [] + @DynamicGradMode() def event_loop_normal(self): """A normal scheduler loop.""" @@ -549,6 +638,70 @@ class Scheduler(SchedulerOutputProcessorMixin): self.last_batch = batch + @torch.no_grad() + def event_loop_normal_disagg_prefill(self): + """A normal scheduler loop for prefill worker in disaggregation mode.""" + + while True: + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) + self.waiting_queue.extend( + self.disagg_prefill_pending_queue.pop_bootstrapped() + ) + self.process_prefill_chunk() + batch = self.get_new_batch_prefill() + self.cur_batch = batch + + if batch: + result = self.run_batch(batch) + self.process_batch_result_disagg_prefill(batch, result) + + if len(self.disagg_prefill_infight_queue) > 0: + self.process_disagg_prefill_infight_queue() + + if batch is None and len(self.disagg_prefill_infight_queue) == 0: + self.check_memory() + self.new_token_ratio = self.init_new_token_ratio + + self.last_batch = batch + # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it + # Otherwise, it hangs under high concurrency + self.running_batch.batch_is_full = False + + @torch.no_grad() + def event_loop_normal_disagg_decode(self): + """A normal scheduler loop for decode worker in disaggregation mode.""" + + while True: + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) + # polling and allocating kv cache + self.process_decode_queue() + batch = self.get_next_disagg_decode_batch_to_run() + self.cur_batch = batch + + if batch: + # Generate fake extend output. + if batch.forward_mode.is_extend(): + # Note: Logprobs should be handled on the prefill engine. 
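+ # The first output token was already sampled on the prefill engine and arrived through the metadata buffer, so this prebuilt extend batch skips run_batch() and only streams that token out (with logprobs disabled).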
+ self.stream_output( + batch.reqs, [False for _ in range(len(batch.reqs))] + ) + else: + result = self.run_batch(batch) + self.process_batch_result(batch, result) + + if batch is None and ( + len(self.disagg_decode_transfer_queue.queue) + + len(self.disagg_decode_prealloc_queue.queue) + == 0 + ): + # When the server is idle, do self-check and re-init some states + self.check_memory() + self.new_token_ratio = self.init_new_token_ratio + + self.last_batch = batch + def recv_requests(self) -> List[Req]: """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" if self.attn_tp_rank == 0: @@ -778,10 +931,20 @@ class Scheduler(SchedulerOutputProcessorMixin): self._add_request_to_queue(req) def _add_request_to_queue(self, req: Req): - self.waiting_queue.append(req) + if self.disaggregation_mode == DisaggregationMode.PREFILL: + self.disagg_prefill_pending_queue.add(req) - def _extend_requests_to_queue(self, reqs: List[Req]): - self.waiting_queue.extend(reqs) + elif self.disaggregation_mode == DisaggregationMode.DECODE: + self.disagg_decode_prealloc_queue.add(req) + + else: + self.waiting_queue.append(req) + + def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False): + if self.disaggregation_mode == DisaggregationMode.DECODE: + self.disagg_decode_prealloc_queue.extend(reqs) + else: + self.waiting_queue.extend(reqs) def handle_embedding_request( self, @@ -1814,10 +1977,18 @@ def run_scheduler_process( "max_req_input_len": scheduler.max_req_input_len, } ) - if scheduler.enable_overlap: - scheduler.event_loop_overlap() - else: - scheduler.event_loop_normal() + disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode + + if disaggregation_mode == DisaggregationMode.NULL: + if scheduler.enable_overlap: + scheduler.event_loop_overlap() + else: + scheduler.event_loop_normal() + elif disaggregation_mode == DisaggregationMode.PREFILL: + scheduler.event_loop_normal_disagg_prefill() + elif disaggregation_mode == DisaggregationMode.DECODE: + scheduler.event_loop_normal_disagg_decode() + except Exception: traceback = get_exception_traceback() logger.error(f"Scheduler hit an exception: {traceback}") diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index c211d76ff..bbc1cbbbc 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -49,6 +49,8 @@ from fastapi import BackgroundTasks from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.disaggregation.conn import KVBootstrapServer +from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.image_processor import ( get_dummy_image_processor, @@ -313,6 +315,16 @@ class TokenizerManager: ] ) + self.disaggregation_mode = DisaggregationMode( + self.server_args.disaggregation_mode + ) + # for disaggregtion, start kv boostrap server on prefill + if self.disaggregation_mode == DisaggregationMode.PREFILL: + # only start bootstrap server on prefill tm + self.bootstrap_server = KVBootstrapServer( + self.server_args.disaggregation_bootstrap_port + ) + async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index b1cbb739c..2b0f72be8 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ 
b/python/sglang/srt/mem_cache/memory_pool.py @@ -271,6 +271,19 @@ class MHATokenToKVPool(KVCache): v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize return k_size_bytes, v_size_bytes + # for disagg + def get_contiguous_buf_infos(self): + kv_data_ptrs = [ + self.get_key_buffer(i).data_ptr() for i in range(self.layer_num) + ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)] + kv_data_lens = [ + self.get_key_buffer(i).nbytes for i in range(self.layer_num) + ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)] + kv_item_lens = [ + self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num) + ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)] + return kv_data_ptrs, kv_data_lens, kv_item_lens + # Todo: different memory layout def get_flat_data(self, indices): # prepare a large chunk of contiguous data for efficient transfer diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 5c1584c8c..8ad0e4b1a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -185,6 +185,10 @@ class ServerArgs: debug_tensor_dump_input_file: Optional[str] = None debug_tensor_dump_inject: bool = False + # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only) + disaggregation_mode: str = "null" + disaggregation_bootstrap_port: int = 8998 + def __post_init__(self): # Set missing default values if self.tokenizer_path is None: @@ -325,6 +329,18 @@ class ServerArgs: if is_hip(): self.triton_attention_num_kv_splits = 16 + # PD disaggregation + if self.disaggregation_mode == "prefill": + self.disable_cuda_graph = True + logger.warning("Cuda graph is disabled for prefill server") + self.disable_overlap_schedule = True + logger.warning("Overlap scheduler is disabled for prefill server") + elif self.disaggregation_mode == "decode": + self.disable_radix_cache = True + logger.warning("KV cache is forced as chunk cache for decode server") + self.disable_overlap_schedule = True + logger.warning("Overlap scheduler is disabled for decode server") + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args @@ -1063,6 +1079,21 @@ class ServerArgs: help="Inject the outputs from jax as the input of every layer.", ) + # Disaggregation + parser.add_argument( + "--disaggregation-mode", + type=str, + default="null", + choices=["null", "prefill", "decode"], + help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated', + ) + parser.add_argument( + "--disaggregation-bootstrap-port", + type=int, + default=ServerArgs.disaggregation_bootstrap_port, + help="Bootstrap server port on the prefill server. Default is 8998.", + ) + @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size
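To make the bootstrap/transfer life cycle above concrete, the following is a minimal, self-contained sketch (not part of this patch) that drives the placeholder classes from python/sglang/srt/disaggregation/conn.py and the ReqToMetadataIdxAllocator from utils.py the way the two schedulers do. The buffer sizes, address, room id, and KV indices are made-up illustration values, and a single in-process allocator stands in for the separate prefill/decode metadata pools; in the real code the pointers and lengths come from MHATokenToKVPool.get_contiguous_buf_infos() and the metadata buffers built in Scheduler.init_disaggregation().

import numpy as np
import torch

from sglang.srt.disaggregation.conn import (
    KVArgs,
    KVManager,
    KVPoll,
    KVReceiver,
    KVSender,
)
from sglang.srt.disaggregation.utils import ReqToMetadataIdxAllocator

# Stand-in buffers for the KV cache and the (buffer_size, 16) int32 aux buffer.
kv_buffer = torch.zeros(1024, dtype=torch.uint8)
output_id_buffer = torch.zeros((8, 16), dtype=torch.int32)

kv_args = KVArgs()
kv_args.engine_rank = 0
kv_args.kv_data_ptrs = [kv_buffer.data_ptr()]
kv_args.kv_data_lens = [kv_buffer.nbytes]
kv_args.kv_item_lens = [kv_buffer[0].nbytes]
kv_args.aux_data_ptrs = [output_id_buffer.data_ptr()]
kv_args.aux_data_lens = [output_id_buffer.nbytes]
kv_args.aux_item_lens = [output_id_buffer[0].nbytes]
kv_args.ib_device = "mock-ib-device"
kv_manager = KVManager(kv_args)

bootstrap_addr = "127.0.0.1:8998"  # prefill host + disaggregation_bootstrap_port
bootstrap_room = 42  # normally a random 63-bit id injected by the load balancer

# Decode side (DecodePreallocQueue): create the receiver, wait for the
# handshake, pre-allocate KV slots and a metadata index, then register them.
metadata_idx_allocator = ReqToMetadataIdxAllocator(size=8)
receiver = KVReceiver(kv_manager, bootstrap_addr, bootstrap_room)
assert receiver.poll() == KVPoll.WaitingForInput  # mock handshake is instant
kv_indices = np.arange(16, dtype=np.int32)  # slots taken from req_to_token_pool
aux_index = metadata_idx_allocator.alloc()
receiver.init(kv_indices, aux_index)

# Prefill side (PrefillBootstrapQueue / send_kv_chunk): announce the transfer
# size, then send the KV indices once the prefill forward pass has finished.
sender = KVSender(kv_manager, bootstrap_addr, bootstrap_room)
assert sender.poll() == KVPoll.WaitingForInput
sender.init(num_kv_indices=len(kv_indices), aux_index=aux_index)
sender.send(kv_indices)

# Both schedulers poll until the transfer reports success; the decode side then
# frees the metadata index, as DecodeTransferQueue.pop_transferred() does.
assert sender.poll() == KVPoll.Success
assert receiver.poll() == KVPoll.Success
metadata_idx_allocator.free(aux_index)

In an actual deployment the two engines are started with the new --disaggregation-mode prefill|decode and --disaggregation-bootstrap-port flags and fronted by python -m sglang.srt.disaggregation.mini_lb --prefill ... --decode ..., which picks a prefill/decode pair per request and injects the bootstrap_host/bootstrap_room fields consumed above.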