[PD] support pd fake transfer for warmup (#5726)
This commit is contained in:
@@ -32,6 +32,7 @@ from torch.distributed import ProcessGroup
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
FakeBootstrapHost,
|
||||
KVClassType,
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
@@ -133,8 +134,13 @@ class DecodePreallocQueue:
|
||||
|
||||
def add(self, req: Req) -> None:
|
||||
"""Add a request to the pending queue."""
|
||||
|
||||
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
|
||||
if req.bootstrap_host == FakeBootstrapHost:
|
||||
# Fake transfer for warmup reqs
|
||||
kv_receiver_class = get_kv_class(TransferBackend.FAKE, KVClassType.RECEIVER)
|
||||
else:
|
||||
kv_receiver_class = get_kv_class(
|
||||
self.transfer_backend, KVClassType.RECEIVER
|
||||
)
|
||||
kv_receiver = kv_receiver_class(
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||
|
||||
1
python/sglang/srt/disaggregation/fake/__init__.py
Normal file
1
python/sglang/srt/disaggregation/fake/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .conn import FakeKVReceiver, FakeKVSender
|
||||
88
python/sglang/srt/disaggregation/fake/conn.py
Normal file
88
python/sglang/srt/disaggregation/fake/conn.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from sglang.srt.disaggregation.base.conn import (
|
||||
BaseKVManager,
|
||||
BaseKVReceiver,
|
||||
BaseKVSender,
|
||||
KVArgs,
|
||||
KVPoll,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Warmup requests do not transfer any KV cache; they use the fake sender and receiver below.
|
||||
class FakeKVSender(BaseKVSender):
    """Stand-in KV sender for warmup requests that performs no transfer.

    The sender only tracks whether the final chunk has been handed to
    ``send`` and reports that state through ``poll``.
    """

    def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
        """Create the fake sender.

        The arguments are accepted so this class can be constructed the same
        way as the other sender classes; none of them is stored or used.
        """
        # Flipped to True by ``send`` once the last chunk has been "sent".
        self.has_sent = False

    def poll(self) -> KVPoll:
        """Report the fake transfer state.

        Returns:
            ``KVPoll.WaitingForInput`` until the last chunk has been passed
            to ``send``; ``KVPoll.Success`` afterwards.
        """
        if not self.has_sent:
            # Assume handshake completed instantly
            return KVPoll.WaitingForInput

        # Assume transfer completed instantly
        logger.info("FakeKVSender poll success")
        return KVPoll.Success

    def init(
        self,
        kv_indices: list[int],
        aux_index: Optional[int] = None,
        dest_ranks: Optional[list[int]] = None,
    ):
        """Log the arguments; no state is changed and nothing is transferred."""
        logger.info(
            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
        )

    def send(
        self,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
    ):
        """Pretend to send one chunk of KV indices.

        Only ``is_last`` affects state: when it is true the sender is marked
        as done, so the next ``poll`` returns ``KVPoll.Success``; otherwise
        the sender is marked as still in progress.
        """
        logger.info(
            f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
        )
        if is_last:
            self.has_sent = True
            logger.info("FakeKVSender send success")
        else:
            self.has_sent = False
            logger.info("FakeKVSender send fake transfering")

    def failure_exception(self):
        """Always raise a generic ``Exception`` describing a fake sender failure."""
        raise Exception("Fake KVSender Exception")
|
||||
|
||||
|
||||
class FakeKVReceiver(BaseKVReceiver):
    """Stand-in KV receiver for warmup requests that performs no transfer.

    The receiver reports success from ``poll`` as soon as ``init`` has been
    called once.
    """

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        """Create the fake receiver; the arguments are not stored or used."""
        # Becomes True once ``init`` has been called.
        self.has_init = False

    def poll(self) -> KVPoll:
        """Return ``KVPoll.Success`` after ``init`` was called, else ``KVPoll.WaitingForInput``."""
        initialized = self.has_init is not False
        if initialized:
            # Assume transfer completed instantly
            logger.info("FakeKVReceiver poll success")
            return KVPoll.Success

        # Assume handshake completed instantly
        return KVPoll.WaitingForInput

    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
        """Mark the receiver as initialized and log the arguments it was given."""
        self.has_init = True
        message = f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
        logger.info(message)

    def failure_exception(self):
        """Always raise a generic ``Exception`` describing a fake receiver failure."""
        raise Exception("Fake KVReceiver Exception")
|
||||
@@ -29,6 +29,7 @@ import torch
|
||||
from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
FakeBootstrapHost,
|
||||
KVClassType,
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
@@ -116,7 +117,11 @@ class PrefillBootstrapQueue:
|
||||
return kv_manager
|
||||
|
||||
def add(self, req: Req) -> None:
|
||||
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
|
||||
if req.bootstrap_host == FakeBootstrapHost:
|
||||
# Fake transfer for warmup reqs
|
||||
kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
|
||||
else:
|
||||
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
|
||||
req.disagg_kv_sender = kv_sender_class(
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
||||
|
||||
@@ -15,6 +15,9 @@ class DisaggregationMode(Enum):
|
||||
DECODE = "decode"
|
||||
|
||||
|
||||
# Sentinel bootstrap host: a request whose ``bootstrap_host`` equals this value
# is treated as a warmup request and is given the fake KV sender/receiver
# (TransferBackend.FAKE) instead of the configured transfer backend.
FakeBootstrapHost = "2.2.2.2"
|
||||
|
||||
|
||||
def poll_and_all_reduce(pollers, gloo_group):
|
||||
polls = [int(poller.poll()) for poller in pollers]
|
||||
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
|
||||
@@ -59,6 +62,8 @@ class KVClassType(Enum):
|
||||
|
||||
|
||||
def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||
|
||||
if transfer_backend == TransferBackend.MOONCAKE:
|
||||
from sglang.srt.disaggregation.mooncake import (
|
||||
MooncakeKVBootstrapServer,
|
||||
@@ -70,7 +75,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
class_mapping = {
|
||||
KVClassType.MANAGER: MooncakeKVManager,
|
||||
KVClassType.SENDER: MooncakeKVSender,
|
||||
KVClassType.RECEIVER: MooncakeKVReceiver,
|
||||
KVClassType.RECEIVER: (MooncakeKVReceiver),
|
||||
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
@@ -85,10 +90,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
class_mapping = {
|
||||
KVClassType.MANAGER: NixlKVManager,
|
||||
KVClassType.SENDER: NixlKVSender,
|
||||
KVClassType.RECEIVER: NixlKVReceiver,
|
||||
KVClassType.RECEIVER: (NixlKVReceiver),
|
||||
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
if transfer_backend == TransferBackend.FAKE:
|
||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||
|
||||
class_mapping = {
|
||||
KVClassType.SENDER: FakeKVSender,
|
||||
KVClassType.RECEIVER: (FakeKVReceiver),
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
|
||||
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user