Minor PD style fix (#7215)

This commit is contained in:
Byron Hsu
2025-06-15 16:12:12 -07:00
committed by GitHub
parent 88f9c347b2
commit 96be97bfff
8 changed files with 33 additions and 28 deletions

View File

@@ -1,4 +1,4 @@
from .conn import (
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,

View File

@@ -1 +1,5 @@
from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver
from sglang.srt.disaggregation.common.conn import (
CommonKVBootstrapServer,
CommonKVManager,
CommonKVReceiver,
)

View File

@@ -45,11 +45,7 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
KVCache,
@@ -248,6 +244,7 @@ class DecodePreallocQueue:
mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room,
data_parallel_rank=req.data_parallel_rank,
)
self.queue.append(
@@ -636,15 +633,6 @@ class DecodeTransferQueue:
class SchedulerDisaggregationDecodeMixin:
# Pass `batch` through self.prepare_dp_attn_batch and, when the prepared
# batch is truthy, execute it with self.run_batch.
#
# Parameters:
#   batch: handed to self.prepare_dp_attn_batch; rebound to the first item
#       of the pair that call returns.
#   delay_process: when False (the default), self.process_batch_result is
#       called here with (batch, result); when True that call is skipped.
#
# Returns:
#   The tuple (batch, result). `result` is still None if self.run_batch
#   was never called.
#
# NOTE(review): leading indentation is stripped in this view, so statement
# nesting is not visible. Presumably `if not delay_process:` sits inside
# `if batch:` -- confirm against the real file.
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
# The second item of the returned pair is discarded.
batch, _ = self.prepare_dp_attn_batch(batch)
result = None
if batch:
result = self.run_batch(batch)
if not delay_process:
self.process_batch_result(batch, result)
return batch, result
@torch.no_grad()
def event_loop_normal_disagg_decode(self: Scheduler):
"""A normal scheduler loop for decode worker in disaggregation mode."""
@@ -773,6 +761,15 @@ class SchedulerDisaggregationDecodeMixin:
self.last_batch = batch
self.last_batch_in_queue = last_batch_in_queue
# Pass `batch` through self.prepare_dp_attn_batch and, when the prepared
# batch is truthy, execute it with self.run_batch.
#
# Parameters:
#   batch: handed to self.prepare_dp_attn_batch; rebound to the first item
#       of the pair that call returns.
#   delay_process: when False (the default), self.process_batch_result is
#       called here with (batch, result); when True that call is skipped.
#
# Returns:
#   The tuple (batch, result). `result` is still None if self.run_batch
#   was never called.
#
# NOTE(review): leading indentation is stripped in this view, so statement
# nesting is not visible. Presumably `if not delay_process:` sits inside
# `if batch:` -- confirm against the real file.
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
# The second item of the returned pair is discarded.
batch, _ = self.prepare_dp_attn_batch(batch)
result = None
if batch:
result = self.run_batch(batch)
if not delay_process:
self.process_batch_result(batch, result)
return batch, result
def get_next_disagg_decode_batch_to_run(
self: Scheduler,
) -> Optional[Tuple[ScheduleBatch, bool]]:

View File

@@ -1 +1 @@
from .conn import FakeKVReceiver, FakeKVSender
from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender

View File

@@ -1,5 +1,5 @@
import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional
import numpy as np
import numpy.typing as npt
@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import (
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
@@ -33,7 +32,7 @@ class FakeKVSender(BaseKVSender):
return KVPoll.WaitingForInput
else:
# Assume transfer completed instantly
logger.info("FakeKVSender poll success")
logger.debug("FakeKVSender poll success")
return KVPoll.Success
def init(
@@ -41,7 +40,7 @@ class FakeKVSender(BaseKVSender):
kv_indices: list[int],
aux_index: Optional[int] = None,
):
logger.info(
logger.debug(
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
)
pass
@@ -51,7 +50,7 @@ class FakeKVSender(BaseKVSender):
kv_indices: npt.NDArray[np.int32],
):
self.has_sent = True
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
# Unconditionally raise a plain `Exception` whose message is
# "Fake KVSender Exception". The exception is raised, not returned.
#
# NOTE(review): presumably a method of FakeKVSender, going by the hunk
# header and the message text; the enclosing class line is not in view.
def failure_exception(self):
raise Exception("Fake KVSender Exception")
@@ -73,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver):
return KVPoll.WaitingForInput
else:
# Assume transfer completed instantly
logger.info("FakeKVReceiver poll success")
logger.debug("FakeKVReceiver poll success")
return KVPoll.Success
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
self.has_init = True
logger.info(
logger.debug(
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
)

View File

@@ -1,4 +1,4 @@
from .conn import (
from sglang.srt.disaggregation.mooncake.conn import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
MooncakeKVReceiver,

View File

@@ -1 +1,6 @@
from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender
from sglang.srt.disaggregation.nixl.conn import (
NixlKVBootstrapServer,
NixlKVManager,
NixlKVReceiver,
NixlKVSender,
)

View File

@@ -202,7 +202,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
}
return class_mapping.get(class_type)
if transfer_backend == TransferBackend.NIXL:
elif transfer_backend == TransferBackend.NIXL:
from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.nixl import (
NixlKVBootstrapServer,
@@ -219,7 +219,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
}
return class_mapping.get(class_type)
if transfer_backend == TransferBackend.FAKE:
elif transfer_backend == TransferBackend.FAKE:
from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender