From c6c6264073a88897c5440dc25a3b86c6560814bd Mon Sep 17 00:00:00 2001 From: ybyang <10629930+whybeyoung@users.noreply.github.com> Date: Tue, 29 Apr 2025 00:33:20 +0800 Subject: [PATCH] [PD] support pd fake transfer for warmup (#5726) --- python/sglang/srt/disaggregation/decode.py | 10 ++- .../srt/disaggregation/fake/__init__.py | 1 + python/sglang/srt/disaggregation/fake/conn.py | 88 +++++++++++++++++++ python/sglang/srt/disaggregation/prefill.py | 7 +- python/sglang/srt/disaggregation/utils.py | 18 +++- python/sglang/srt/entrypoints/http_server.py | 29 +++++- 6 files changed, 146 insertions(+), 7 deletions(-) create mode 100644 python/sglang/srt/disaggregation/fake/__init__.py create mode 100644 python/sglang/srt/disaggregation/fake/conn.py diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index dd8e8ca6a..813a5e9cf 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll from sglang.srt.disaggregation.utils import ( DisaggregationMode, + FakeBootstrapHost, KVClassType, ReqToMetadataIdxAllocator, TransferBackend, @@ -133,8 +134,13 @@ class DecodePreallocQueue: def add(self, req: Req) -> None: """Add a request to the pending queue.""" - - kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER) + if req.bootstrap_host == FakeBootstrapHost: + # Fake transfer for warmup reqs + kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER) + else: + kv_receiver_class = get_kv_class( + self.transfer_backend, KVClassType.RECEIVER + ) kv_receiver = kv_receiver_class( mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", diff --git a/python/sglang/srt/disaggregation/fake/__init__.py b/python/sglang/srt/disaggregation/fake/__init__.py new file mode 100644 index 000000000..4adebb7c3 --- /dev/null +++ b/python/sglang/srt/disaggregation/fake/__init__.py @@ -0,0 +1 @@ +from .conn import FakeKVReceiver, FakeKVSender diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py new file mode 100644 index 000000000..f65289f44 --- /dev/null +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -0,0 +1,88 @@ +import logging +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import numpy.typing as npt + +from sglang.srt.disaggregation.base.conn import ( + BaseKVManager, + BaseKVReceiver, + BaseKVSender, + KVArgs, + KVPoll, +) + +logger = logging.getLogger(__name__) + + +# For warmup reqs, we don't kv transfer, we use the fake sender and receiver +class FakeKVSender(BaseKVSender): + def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int): + self.has_sent = False + + def poll(self) -> KVPoll: + if self.has_sent is False: + # Assume handshake completed instantly + return KVPoll.WaitingForInput + else: + # Assume transfer completed instantly + logger.info("FakeKVSender poll success") + return KVPoll.Success + + def init( + self, + kv_indices: list[int], + aux_index: Optional[int] = None, + dest_ranks: Optional[list[int]] = None, + ): + logger.info( + f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}" + ) + pass + + def send( + self, + kv_indices: npt.NDArray[np.int64], + index_slice: slice, + is_last: bool, + ): + logger.info( + f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}" + ) + if is_last: + self.has_sent = True + logger.info(f"FakeKVSender send success") + else: + self.has_sent = False + logger.info(f"FakeKVSender send fake transfering") + + def failure_exception(self): + raise Exception("Fake KVSender Exception") + + +class FakeKVReceiver(BaseKVReceiver): + def __init__( + self, + mgr: BaseKVManager, + bootstrap_addr: str, + bootstrap_room: Optional[int] = None, + ): + self.has_init = False + + def poll(self) -> KVPoll: + if self.has_init is False: + # Assume handshake completed instantly + return KVPoll.WaitingForInput + else: + # Assume transfer completed instantly + logger.info("FakeKVReceiver poll success") + return KVPoll.Success + + def init(self, kv_indices: list[int], aux_index: Optional[int] = None): + self.has_init = True + logger.info( + f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}" + ) + + def failure_exception(self): + raise Exception("Fake KVReceiver Exception") diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index d6a8fa398..6204faca2 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -29,6 +29,7 @@ import torch from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll from sglang.srt.disaggregation.utils import ( DisaggregationMode, + FakeBootstrapHost, KVClassType, ReqToMetadataIdxAllocator, TransferBackend, @@ -116,7 +117,11 @@ class PrefillBootstrapQueue: return kv_manager def add(self, req: Req) -> None: - kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER) + if req.bootstrap_host == FakeBootstrapHost: + # Fake transfer for warmup reqs + kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER) + else: + kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER) req.disagg_kv_sender = kv_sender_class( mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}", diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index cb7654088..b50a86c63 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -15,6 +15,9 @@ class DisaggregationMode(Enum): DECODE = "decode" +FakeBootstrapHost = "2.2.2.2" + + def poll_and_all_reduce(pollers, gloo_group): polls = [int(poller.poll()) for poller in pollers] tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu") @@ -59,6 +62,8 @@ class KVClassType(Enum): def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): + from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender + if transfer_backend == TransferBackend.MOONCAKE: from sglang.srt.disaggregation.mooncake import ( MooncakeKVBootstrapServer, @@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): class_mapping = { KVClassType.MANAGER: MooncakeKVManager, KVClassType.SENDER: MooncakeKVSender, - KVClassType.RECEIVER: MooncakeKVReceiver, + KVClassType.RECEIVER: (MooncakeKVReceiver), KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, } return class_mapping.get(class_type) @@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): class_mapping = { KVClassType.MANAGER: NixlKVManager, KVClassType.SENDER: NixlKVSender, - KVClassType.RECEIVER: NixlKVReceiver, + KVClassType.RECEIVER: (NixlKVReceiver), KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer, } return class_mapping.get(class_type) + if transfer_backend == TransferBackend.FAKE: + from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender + + class_mapping = { + KVClassType.SENDER: FakeKVSender, + KVClassType.RECEIVER: (FakeKVReceiver), + } + return class_mapping.get(class_type) + raise ValueError(f"Unsupported transfer backend: {transfer_backend}") diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 7f1cc01fd..147ab4131 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse +from sglang.srt.disaggregation.utils import FakeBootstrapHost from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( @@ -821,8 +822,32 @@ def _wait_and_warmup( ) assert res.status_code == 200, f"{res}" else: - # Warmup request currently hangs in disaggregation mode, so we skip it. - logger.info("Skipping warmup request in disaggregation mode") + logger.info(f"Start of prefill warmup ...") + json_data = { + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": 8, + "ignore_eos": True, + }, + "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size, + # This is a hack to ensure fake transfer is enabled during prefill warmup + # ensure each dp rank has a unique bootstrap_room during prefill warmup + "bootstrap_room": [ + i * (2**63 // server_args.dp_size) + (i % server_args.tp_size) + for i in range(server_args.dp_size) + ], + "input_ids": [[0, 1, 2, 3]] * server_args.dp_size, + } + res = requests.post( + url + request_name, + json=json_data, + headers=headers, + timeout=1800, # because of deep gemm precache is very long if not precache. + ) + logger.info( + f"End of prefill warmup with status {res.status_code}, resp: {res.json()}" + ) + except Exception: last_traceback = get_exception_traceback() if pipe_finish_writer is not None: