diff --git a/python/sglang/srt/disaggregation/base/__init__.py b/python/sglang/srt/disaggregation/base/__init__.py index bfeecfe1c..ef0f797fc 100644 --- a/python/sglang/srt/disaggregation/base/__init__.py +++ b/python/sglang/srt/disaggregation/base/__init__.py @@ -1,4 +1,4 @@ -from .conn import ( +from sglang.srt.disaggregation.base.conn import ( BaseKVBootstrapServer, BaseKVManager, BaseKVReceiver, diff --git a/python/sglang/srt/disaggregation/common/__init__.py b/python/sglang/srt/disaggregation/common/__init__.py index 950db151f..8294c3892 100644 --- a/python/sglang/srt/disaggregation/common/__init__.py +++ b/python/sglang/srt/disaggregation/common/__init__.py @@ -1 +1,5 @@ -from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver +from sglang.srt.disaggregation.common.conn import ( + CommonKVBootstrapServer, + CommonKVManager, + CommonKVReceiver, +) diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 9c0860cd0..76b06c8ba 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -45,11 +45,7 @@ from sglang.srt.disaggregation.utils import ( poll_and_all_reduce, prepare_abort, ) -from sglang.srt.managers.schedule_batch import ( - FINISH_ABORT, - ScheduleBatch, - global_server_args_dict, -) +from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import ( KVCache, @@ -248,6 +244,7 @@ class DecodePreallocQueue: mgr=self.kv_manager, bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}", bootstrap_room=req.bootstrap_room, + data_parallel_rank=req.data_parallel_rank, ) self.queue.append( @@ -636,15 +633,6 @@ class DecodeTransferQueue: class SchedulerDisaggregationDecodeMixin: - def _prepare_idle_batch_and_run(self, batch, delay_process=False): - batch, _ = self.prepare_dp_attn_batch(batch) - result = None - if batch: - result = self.run_batch(batch) - if not delay_process: - self.process_batch_result(batch, result) - return batch, result - @torch.no_grad() def event_loop_normal_disagg_decode(self: Scheduler): """A normal scheduler loop for decode worker in disaggregation mode.""" @@ -773,6 +761,15 @@ class SchedulerDisaggregationDecodeMixin: self.last_batch = batch self.last_batch_in_queue = last_batch_in_queue + def _prepare_idle_batch_and_run(self, batch, delay_process=False): + batch, _ = self.prepare_dp_attn_batch(batch) + result = None + if batch: + result = self.run_batch(batch) + if not delay_process: + self.process_batch_result(batch, result) + return batch, result + def get_next_disagg_decode_batch_to_run( self: Scheduler, ) -> Optional[Tuple[ScheduleBatch, bool]]: diff --git a/python/sglang/srt/disaggregation/fake/__init__.py b/python/sglang/srt/disaggregation/fake/__init__.py index 4adebb7c3..d7cdb4b27 100644 --- a/python/sglang/srt/disaggregation/fake/__init__.py +++ b/python/sglang/srt/disaggregation/fake/__init__.py @@ -1 +1 @@ -from .conn import FakeKVReceiver, FakeKVSender +from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender diff --git a/python/sglang/srt/disaggregation/fake/conn.py b/python/sglang/srt/disaggregation/fake/conn.py index 63a39ac2f..d25f47a38 100644 --- a/python/sglang/srt/disaggregation/fake/conn.py +++ b/python/sglang/srt/disaggregation/fake/conn.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional import numpy as np import numpy.typing as npt @@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import ( BaseKVManager, BaseKVReceiver, BaseKVSender, - KVArgs, KVPoll, ) @@ -33,7 +32,7 @@ class FakeKVSender(BaseKVSender): return KVPoll.WaitingForInput else: # Assume transfer completed instantly - logger.info("FakeKVSender poll success") + logger.debug("FakeKVSender poll success") return KVPoll.Success def init( @@ -41,7 +40,7 @@ class FakeKVSender(BaseKVSender): kv_indices: list[int], aux_index: Optional[int] = None, ): - logger.info( + logger.debug( f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}" ) pass @@ -51,7 +50,7 @@ class FakeKVSender(BaseKVSender): kv_indices: npt.NDArray[np.int32], ): self.has_sent = True - logger.info(f"FakeKVSender send with kv_indices: {kv_indices}") + logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}") def failure_exception(self): raise Exception("Fake KVSender Exception") @@ -73,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver): return KVPoll.WaitingForInput else: # Assume transfer completed instantly - logger.info("FakeKVReceiver poll success") + logger.debug("FakeKVReceiver poll success") return KVPoll.Success def init(self, kv_indices: list[int], aux_index: Optional[int] = None): self.has_init = True - logger.info( + logger.debug( f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}" ) diff --git a/python/sglang/srt/disaggregation/mooncake/__init__.py b/python/sglang/srt/disaggregation/mooncake/__init__.py index 035097f68..bea967e4e 100644 --- a/python/sglang/srt/disaggregation/mooncake/__init__.py +++ b/python/sglang/srt/disaggregation/mooncake/__init__.py @@ -1,4 +1,4 @@ -from .conn import ( +from sglang.srt.disaggregation.mooncake.conn import ( MooncakeKVBootstrapServer, MooncakeKVManager, MooncakeKVReceiver, diff --git a/python/sglang/srt/disaggregation/nixl/__init__.py b/python/sglang/srt/disaggregation/nixl/__init__.py index d9456e0fd..4df7baba2 100644 --- a/python/sglang/srt/disaggregation/nixl/__init__.py +++ b/python/sglang/srt/disaggregation/nixl/__init__.py @@ -1 +1,6 @@ -from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender +from sglang.srt.disaggregation.nixl.conn import ( + NixlKVBootstrapServer, + NixlKVManager, + NixlKVReceiver, + NixlKVSender, +) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 6b52342dd..7dd6a6bec 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -202,7 +202,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, } return class_mapping.get(class_type) - if transfer_backend == TransferBackend.NIXL: + elif transfer_backend == TransferBackend.NIXL: from sglang.srt.disaggregation.base import KVArgs from sglang.srt.disaggregation.nixl import ( NixlKVBootstrapServer, @@ -219,7 +219,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer, } return class_mapping.get(class_type) - if transfer_backend == TransferBackend.FAKE: + elif transfer_backend == TransferBackend.FAKE: from sglang.srt.disaggregation.base import KVArgs from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender