Minor PD (prefill/decode disaggregation) style fix (#7215)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from .conn import (
|
||||
from sglang.srt.disaggregation.base.conn import (
|
||||
BaseKVBootstrapServer,
|
||||
BaseKVManager,
|
||||
BaseKVReceiver,
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver
|
||||
from sglang.srt.disaggregation.common.conn import (
|
||||
CommonKVBootstrapServer,
|
||||
CommonKVManager,
|
||||
CommonKVReceiver,
|
||||
)
|
||||
|
||||
@@ -45,11 +45,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
poll_and_all_reduce,
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
ScheduleBatch,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
KVCache,
|
||||
@@ -248,6 +244,7 @@ class DecodePreallocQueue:
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||
bootstrap_room=req.bootstrap_room,
|
||||
data_parallel_rank=req.data_parallel_rank,
|
||||
)
|
||||
|
||||
self.queue.append(
|
||||
@@ -636,15 +633,6 @@ class DecodeTransferQueue:
|
||||
|
||||
class SchedulerDisaggregationDecodeMixin:
|
||||
|
||||
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||
result = None
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
if not delay_process:
|
||||
self.process_batch_result(batch, result)
|
||||
return batch, result
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_normal_disagg_decode(self: Scheduler):
|
||||
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
||||
@@ -773,6 +761,15 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.last_batch = batch
|
||||
self.last_batch_in_queue = last_batch_in_queue
|
||||
|
||||
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||
result = None
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
if not delay_process:
|
||||
self.process_batch_result(batch, result)
|
||||
return batch, result
|
||||
|
||||
def get_next_disagg_decode_batch_to_run(
|
||||
self: Scheduler,
|
||||
) -> Optional[Tuple[ScheduleBatch, bool]]:
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .conn import FakeKVReceiver, FakeKVSender
|
||||
from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import (
|
||||
BaseKVManager,
|
||||
BaseKVReceiver,
|
||||
BaseKVSender,
|
||||
KVArgs,
|
||||
KVPoll,
|
||||
)
|
||||
|
||||
@@ -33,7 +32,7 @@ class FakeKVSender(BaseKVSender):
|
||||
return KVPoll.WaitingForInput
|
||||
else:
|
||||
# Assume transfer completed instantly
|
||||
logger.info("FakeKVSender poll success")
|
||||
logger.debug("FakeKVSender poll success")
|
||||
return KVPoll.Success
|
||||
|
||||
def init(
|
||||
@@ -41,7 +40,7 @@ class FakeKVSender(BaseKVSender):
|
||||
kv_indices: list[int],
|
||||
aux_index: Optional[int] = None,
|
||||
):
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
||||
)
|
||||
pass
|
||||
@@ -51,7 +50,7 @@ class FakeKVSender(BaseKVSender):
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
):
|
||||
self.has_sent = True
|
||||
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
|
||||
logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
|
||||
|
||||
def failure_exception(self):
|
||||
raise Exception("Fake KVSender Exception")
|
||||
@@ -73,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver):
|
||||
return KVPoll.WaitingForInput
|
||||
else:
|
||||
# Assume transfer completed instantly
|
||||
logger.info("FakeKVReceiver poll success")
|
||||
logger.debug("FakeKVReceiver poll success")
|
||||
return KVPoll.Success
|
||||
|
||||
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
|
||||
self.has_init = True
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .conn import (
|
||||
from sglang.srt.disaggregation.mooncake.conn import (
|
||||
MooncakeKVBootstrapServer,
|
||||
MooncakeKVManager,
|
||||
MooncakeKVReceiver,
|
||||
|
||||
@@ -1 +1,6 @@
|
||||
from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender
|
||||
from sglang.srt.disaggregation.nixl.conn import (
|
||||
NixlKVBootstrapServer,
|
||||
NixlKVManager,
|
||||
NixlKVReceiver,
|
||||
NixlKVSender,
|
||||
)
|
||||
|
||||
@@ -202,7 +202,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
if transfer_backend == TransferBackend.NIXL:
|
||||
elif transfer_backend == TransferBackend.NIXL:
|
||||
from sglang.srt.disaggregation.base import KVArgs
|
||||
from sglang.srt.disaggregation.nixl import (
|
||||
NixlKVBootstrapServer,
|
||||
@@ -219,7 +219,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
if transfer_backend == TransferBackend.FAKE:
|
||||
elif transfer_backend == TransferBackend.FAKE:
|
||||
from sglang.srt.disaggregation.base import KVArgs
|
||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||
|
||||
|
||||
Reference in New Issue
Block a user