[PD] support pd fake transfer for warmup (#5726)
This commit is contained in:
@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
FakeBootstrapHost,
|
||||
KVClassType,
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
@@ -133,8 +134,13 @@ class DecodePreallocQueue:
|
||||
|
||||
def add(self, req: Req) -> None:
|
||||
"""Add a request to the pending queue."""
|
||||
|
||||
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
|
||||
if req.bootstrap_host == FakeBootstrapHost:
|
||||
# Fake transfer for warmup reqs
|
||||
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
|
||||
else:
|
||||
kv_receiver_class = get_kv_class(
|
||||
self.transfer_backend, KVClassType.RECEIVER
|
||||
)
|
||||
kv_receiver = kv_receiver_class(
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||
|
||||
1
python/sglang/srt/disaggregation/fake/__init__.py
Normal file
1
python/sglang/srt/disaggregation/fake/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .conn import FakeKVReceiver, FakeKVSender
|
||||
88
python/sglang/srt/disaggregation/fake/conn.py
Normal file
88
python/sglang/srt/disaggregation/fake/conn.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from sglang.srt.disaggregation.base.conn import (
|
||||
BaseKVManager,
|
||||
BaseKVReceiver,
|
||||
BaseKVSender,
|
||||
KVArgs,
|
||||
KVPoll,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# For warmup reqs, we don't kv transfer, we use the fake sender and receiver
|
||||
class FakeKVSender(BaseKVSender):
|
||||
def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
|
||||
self.has_sent = False
|
||||
|
||||
def poll(self) -> KVPoll:
|
||||
if self.has_sent is False:
|
||||
# Assume handshake completed instantly
|
||||
return KVPoll.WaitingForInput
|
||||
else:
|
||||
# Assume transfer completed instantly
|
||||
logger.info("FakeKVSender poll success")
|
||||
return KVPoll.Success
|
||||
|
||||
def init(
|
||||
self,
|
||||
kv_indices: list[int],
|
||||
aux_index: Optional[int] = None,
|
||||
dest_ranks: Optional[list[int]] = None,
|
||||
):
|
||||
logger.info(
|
||||
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
|
||||
)
|
||||
pass
|
||||
|
||||
def send(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int64],
|
||||
index_slice: slice,
|
||||
is_last: bool,
|
||||
):
|
||||
logger.info(
|
||||
f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
|
||||
)
|
||||
if is_last:
|
||||
self.has_sent = True
|
||||
logger.info(f"FakeKVSender send success")
|
||||
else:
|
||||
self.has_sent = False
|
||||
logger.info(f"FakeKVSender send fake transfering")
|
||||
|
||||
def failure_exception(self):
|
||||
raise Exception("Fake KVSender Exception")
|
||||
|
||||
|
||||
class FakeKVReceiver(BaseKVReceiver):
|
||||
def __init__(
|
||||
self,
|
||||
mgr: BaseKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
):
|
||||
self.has_init = False
|
||||
|
||||
def poll(self) -> KVPoll:
|
||||
if self.has_init is False:
|
||||
# Assume handshake completed instantly
|
||||
return KVPoll.WaitingForInput
|
||||
else:
|
||||
# Assume transfer completed instantly
|
||||
logger.info("FakeKVReceiver poll success")
|
||||
return KVPoll.Success
|
||||
|
||||
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
|
||||
self.has_init = True
|
||||
logger.info(
|
||||
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
||||
)
|
||||
|
||||
def failure_exception(self):
|
||||
raise Exception("Fake KVReceiver Exception")
|
||||
@@ -29,6 +29,7 @@ import torch
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
FakeBootstrapHost,
|
||||
KVClassType,
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
@@ -116,7 +117,11 @@ class PrefillBootstrapQueue:
|
||||
return kv_manager
|
||||
|
||||
def add(self, req: Req) -> None:
|
||||
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
|
||||
if req.bootstrap_host == FakeBootstrapHost:
|
||||
# Fake transfer for warmup reqs
|
||||
kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
|
||||
else:
|
||||
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
|
||||
req.disagg_kv_sender = kv_sender_class(
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
||||
|
||||
@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
|
||||
DECODE = "decode"
|
||||
|
||||
|
||||
FakeBootstrapHost = "2.2.2.2"
|
||||
|
||||
|
||||
def poll_and_all_reduce(pollers, gloo_group):
|
||||
polls = [int(poller.poll()) for poller in pollers]
|
||||
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
|
||||
@@ -59,6 +62,8 @@ class KVClassType(Enum):
|
||||
|
||||
|
||||
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||
|
||||
if transfer_backend == TransferBackend.MOONCAKE:
|
||||
from sglang.srt.disaggregation.mooncake import (
|
||||
MooncakeKVBootstrapServer,
|
||||
@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
class_mapping = {
|
||||
KVClassType.MANAGER: MooncakeKVManager,
|
||||
KVClassType.SENDER: MooncakeKVSender,
|
||||
KVClassType.RECEIVER: MooncakeKVReceiver,
|
||||
KVClassType.RECEIVER: (MooncakeKVReceiver),
|
||||
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
class_mapping = {
|
||||
KVClassType.MANAGER: NixlKVManager,
|
||||
KVClassType.SENDER: NixlKVSender,
|
||||
KVClassType.RECEIVER: NixlKVReceiver,
|
||||
KVClassType.RECEIVER: (NixlKVReceiver),
|
||||
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
if transfer_backend == TransferBackend.FAKE:
|
||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||
|
||||
class_mapping = {
|
||||
KVClassType.SENDER: FakeKVSender,
|
||||
KVClassType.RECEIVER: (FakeKVReceiver),
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
|
||||
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.disaggregation.utils import FakeBootstrapHost
|
||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||
from sglang.srt.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import (
|
||||
@@ -821,8 +822,32 @@ def _wait_and_warmup(
|
||||
)
|
||||
assert res.status_code == 200, f"{res}"
|
||||
else:
|
||||
# Warmup request currently hangs in disaggregation mode, so we skip it.
|
||||
logger.info("Skipping warmup request in disaggregation mode")
|
||||
logger.info(f"Start of prefill warmup ...")
|
||||
json_data = {
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"max_new_tokens": 8,
|
||||
"ignore_eos": True,
|
||||
},
|
||||
"bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
|
||||
# This is a hack to ensure fake transfer is enabled during prefill warmup
|
||||
# ensure each dp rank has a unique bootstrap_room during prefill warmup
|
||||
"bootstrap_room": [
|
||||
i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
|
||||
for i in range(server_args.dp_size)
|
||||
],
|
||||
"input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
|
||||
}
|
||||
res = requests.post(
|
||||
url + request_name,
|
||||
json=json_data,
|
||||
headers=headers,
|
||||
timeout=1800, # because of deep gemm precache is very long if not precache.
|
||||
)
|
||||
logger.info(
|
||||
f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
last_traceback = get_exception_traceback()
|
||||
if pipe_finish_writer is not None:
|
||||
|
||||
Reference in New Issue
Block a user