diff --git a/docs/backend/pd_disaggregation.md b/docs/backend/pd_disaggregation.md new file mode 100644 index 000000000..de95763a6 --- /dev/null +++ b/docs/backend/pd_disaggregation.md @@ -0,0 +1,49 @@ +# PD Disaggregation + +## Why and What is PD Disaggregation? + +Large Language Model (LLM) inference comprises two distinct phases: **Prefill** and **Decode**. The Prefill phase is computation-intensive, processing the entire input sequence, while the Decode phase is memory-intensive, managing the Key-Value (KV) cache for token generation. Traditionally, these phases are handled within a unified engine, where combined scheduling of prefill and decode batches introduces inefficiencies. To address these challenges, we introduce **Prefill and Decode (PD) Disaggregation** in SGLang. + +### Issues with Unified Scheduling + +The conventional unified engine, which processes prefill and decode batches together, results in two significant problems: + +1. **Prefill Interruption**: Incoming prefill batches frequently interrupt ongoing decode batches, causing substantial delays in token generation. +2. **DP Attention Imbalance**: In data-parallel (DP) attention, one DP worker may process a prefill batch while another handles a decode batch simultaneously, leading to increased decode latency. + +PD Disaggregation resolves these by separating the two stages, enabling tailored optimizations for each. + +For the design details, please refer to [this design document](https://docs.google.com/document/d/1rQXJwKd5b9b1aOzLh98mnyMhBMhlxXA5ATZTHoQrwvc/edit?tab=t.0). + +Currently, we support Mooncake and NIXL as transfer engines. 
+ + +## Mooncake +### Requirements + +```bash +uv pip install mooncake-transfer-engine +``` + +### Usage + +### Llama Single Node + +```bash +$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-ib-device mlx5_roce0 +$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-ib-device mlx5_roce0 +$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000 +``` + +### DeepSeek Multi-Node + +```bash +# prefill 0 +$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8 +# prefill 1 +$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8 +# decode 0 +$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128 +# decode 1 +$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} 
--disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128 +``` diff --git a/docs/index.rst b/docs/index.rst index eac4cbd8f..edd786372 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -55,6 +55,7 @@ The core features include: backend/custom_chat_template.md backend/quantization.md backend/lora.ipynb + backend/pd_disaggregation.md .. toctree:: :maxdepth: 1 diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 1e2bd4461..1e650753e 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -33,28 +33,18 @@ class FakeKVSender(BaseKVSender): self, kv_indices: list[int], aux_index: Optional[int] = None, - dest_ranks: Optional[list[int]] = None, ): logger.info( - f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}" + f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}" ) pass def send( self, kv_indices: npt.NDArray[np.int64], - index_slice: slice, - is_last: bool, ): - logger.info( - f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}" - ) - if is_last: - self.has_sent = True - logger.info(f"FakeKVSender send success") - else: - self.has_sent = False - logger.info(f"FakeKVSender send fake transferring") + self.has_sent = True + logger.info(f"FakeKVSender send with kv_indices: {kv_indices}") def failure_exception(self): raise Exception("Fake KVSender Exception") diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 7226805bc..4b843e02e 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ 
b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -464,6 +464,8 @@ class MooncakeKVSender(BaseKVSender): self.aux_index = None self.bootstrap_server_url = bootstrap_addr self.session_id = self.kv_mgr.get_session_id() + # inner state + self.curr_idx = 0 def init(self, num_kv_indices: int, aux_index: Optional[int] = None): self.num_kv_indices = num_kv_indices @@ -472,9 +474,11 @@ class MooncakeKVSender(BaseKVSender): def send( self, kv_indices: npt.NDArray[np.int64], - index_slice: slice, - is_last: bool, ): + index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) + self.curr_idx += len(kv_indices) + is_last = self.curr_idx == self.num_kv_indices + if not is_last: self.kv_mgr.add_transfer_request( self.bootstrap_room, kv_indices, index_slice, False diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 6af1928ff..83fe5a838 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -384,11 +384,10 @@ class SchedulerDisaggregationPrefillMixin: if end_idx is not None else min(len(req.fill_ids), len(req.origin_input_ids)) ) + last_chunk = token_id is not None - if (not last_chunk) and ( - end_idx % page_size != 0 - ): # todo: remove the second condition + if not last_chunk: # if not the last chunk and the last page is partial, delay the last partial page to the next send end_idx = end_idx - end_idx % page_size @@ -405,16 +404,10 @@ class SchedulerDisaggregationPrefillMixin: req.metadata_buffer_index, token_id ) page_indices = kv_to_page_indices(kv_indices, page_size) - - page_start_idx = start_idx // page_size - page_end_idx = page_start_idx + len(page_indices) - if len(page_indices) == 0: logger.info( f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty" ) return - req.disagg_kv_sender.send( - page_indices, slice(page_start_idx, page_end_idx), last_chunk - ) + req.disagg_kv_sender.send(page_indices) diff 
--git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 5734cd95c..1dd9c519e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -407,6 +407,7 @@ class GenerateReqInput: else None ), return_hidden_states=self.return_hidden_states, + # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list bootstrap_host=( self.bootstrap_host[i] if self.bootstrap_host is not None else None ),