Minor PD style fix (#7215)

This commit is contained in:
Byron Hsu
2025-06-15 16:12:12 -07:00
committed by GitHub
parent 88f9c347b2
commit 96be97bfff
8 changed files with 33 additions and 28 deletions

View File

@@ -1,4 +1,4 @@
from .conn import ( from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer, BaseKVBootstrapServer,
BaseKVManager, BaseKVManager,
BaseKVReceiver, BaseKVReceiver,

View File

@@ -1 +1,5 @@
from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver from sglang.srt.disaggregation.common.conn import (
CommonKVBootstrapServer,
CommonKVManager,
CommonKVReceiver,
)

View File

@@ -45,11 +45,7 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce, poll_and_all_reduce,
prepare_abort, prepare_abort,
) )
from sglang.srt.managers.schedule_batch import ( from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
FINISH_ABORT,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ( from sglang.srt.mem_cache.memory_pool import (
KVCache, KVCache,
@@ -248,6 +244,7 @@ class DecodePreallocQueue:
mgr=self.kv_manager, mgr=self.kv_manager,
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
bootstrap_room=req.bootstrap_room, bootstrap_room=req.bootstrap_room,
data_parallel_rank=req.data_parallel_rank,
) )
self.queue.append( self.queue.append(
@@ -636,15 +633,6 @@ class DecodeTransferQueue:
class SchedulerDisaggregationDecodeMixin: class SchedulerDisaggregationDecodeMixin:
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
batch, _ = self.prepare_dp_attn_batch(batch)
result = None
if batch:
result = self.run_batch(batch)
if not delay_process:
self.process_batch_result(batch, result)
return batch, result
@torch.no_grad() @torch.no_grad()
def event_loop_normal_disagg_decode(self: Scheduler): def event_loop_normal_disagg_decode(self: Scheduler):
"""A normal scheduler loop for decode worker in disaggregation mode.""" """A normal scheduler loop for decode worker in disaggregation mode."""
@@ -773,6 +761,15 @@ class SchedulerDisaggregationDecodeMixin:
self.last_batch = batch self.last_batch = batch
self.last_batch_in_queue = last_batch_in_queue self.last_batch_in_queue = last_batch_in_queue
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
batch, _ = self.prepare_dp_attn_batch(batch)
result = None
if batch:
result = self.run_batch(batch)
if not delay_process:
self.process_batch_result(batch, result)
return batch, result
def get_next_disagg_decode_batch_to_run( def get_next_disagg_decode_batch_to_run(
self: Scheduler, self: Scheduler,
) -> Optional[Tuple[ScheduleBatch, bool]]: ) -> Optional[Tuple[ScheduleBatch, bool]]:

View File

@@ -1 +1 @@
from .conn import FakeKVReceiver, FakeKVSender from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Dict, List, Optional, Tuple, Union from typing import List, Optional
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import (
BaseKVManager, BaseKVManager,
BaseKVReceiver, BaseKVReceiver,
BaseKVSender, BaseKVSender,
KVArgs,
KVPoll, KVPoll,
) )
@@ -33,7 +32,7 @@ class FakeKVSender(BaseKVSender):
return KVPoll.WaitingForInput return KVPoll.WaitingForInput
else: else:
# Assume transfer completed instantly # Assume transfer completed instantly
logger.info("FakeKVSender poll success") logger.debug("FakeKVSender poll success")
return KVPoll.Success return KVPoll.Success
def init( def init(
@@ -41,7 +40,7 @@ class FakeKVSender(BaseKVSender):
kv_indices: list[int], kv_indices: list[int],
aux_index: Optional[int] = None, aux_index: Optional[int] = None,
): ):
logger.info( logger.debug(
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}" f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
) )
pass pass
@@ -51,7 +50,7 @@ class FakeKVSender(BaseKVSender):
kv_indices: npt.NDArray[np.int32], kv_indices: npt.NDArray[np.int32],
): ):
self.has_sent = True self.has_sent = True
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}") logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
def failure_exception(self): def failure_exception(self):
raise Exception("Fake KVSender Exception") raise Exception("Fake KVSender Exception")
@@ -73,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver):
return KVPoll.WaitingForInput return KVPoll.WaitingForInput
else: else:
# Assume transfer completed instantly # Assume transfer completed instantly
logger.info("FakeKVReceiver poll success") logger.debug("FakeKVReceiver poll success")
return KVPoll.Success return KVPoll.Success
def init(self, kv_indices: list[int], aux_index: Optional[int] = None): def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
self.has_init = True self.has_init = True
logger.info( logger.debug(
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}" f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
) )

View File

@@ -1,4 +1,4 @@
from .conn import ( from sglang.srt.disaggregation.mooncake.conn import (
MooncakeKVBootstrapServer, MooncakeKVBootstrapServer,
MooncakeKVManager, MooncakeKVManager,
MooncakeKVReceiver, MooncakeKVReceiver,

View File

@@ -1 +1,6 @@
from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender from sglang.srt.disaggregation.nixl.conn import (
NixlKVBootstrapServer,
NixlKVManager,
NixlKVReceiver,
NixlKVSender,
)

View File

@@ -202,7 +202,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
} }
return class_mapping.get(class_type) return class_mapping.get(class_type)
if transfer_backend == TransferBackend.NIXL: elif transfer_backend == TransferBackend.NIXL:
from sglang.srt.disaggregation.base import KVArgs from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.nixl import ( from sglang.srt.disaggregation.nixl import (
NixlKVBootstrapServer, NixlKVBootstrapServer,
@@ -219,7 +219,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer, KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
} }
return class_mapping.get(class_type) return class_mapping.get(class_type)
if transfer_backend == TransferBackend.FAKE: elif transfer_backend == TransferBackend.FAKE:
from sglang.srt.disaggregation.base import KVArgs from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender