Minor PD style fix (#7215)
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from .conn import (
|
from sglang.srt.disaggregation.base.conn import (
|
||||||
BaseKVBootstrapServer,
|
BaseKVBootstrapServer,
|
||||||
BaseKVManager,
|
BaseKVManager,
|
||||||
BaseKVReceiver,
|
BaseKVReceiver,
|
||||||
|
|||||||
@@ -1 +1,5 @@
|
|||||||
from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver
|
from sglang.srt.disaggregation.common.conn import (
|
||||||
|
CommonKVBootstrapServer,
|
||||||
|
CommonKVManager,
|
||||||
|
CommonKVReceiver,
|
||||||
|
)
|
||||||
|
|||||||
@@ -45,11 +45,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
prepare_abort,
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||||
FINISH_ABORT,
|
|
||||||
ScheduleBatch,
|
|
||||||
global_server_args_dict,
|
|
||||||
)
|
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
KVCache,
|
KVCache,
|
||||||
@@ -248,6 +244,7 @@ class DecodePreallocQueue:
|
|||||||
mgr=self.kv_manager,
|
mgr=self.kv_manager,
|
||||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||||
bootstrap_room=req.bootstrap_room,
|
bootstrap_room=req.bootstrap_room,
|
||||||
|
data_parallel_rank=req.data_parallel_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.queue.append(
|
self.queue.append(
|
||||||
@@ -636,15 +633,6 @@ class DecodeTransferQueue:
|
|||||||
|
|
||||||
class SchedulerDisaggregationDecodeMixin:
|
class SchedulerDisaggregationDecodeMixin:
|
||||||
|
|
||||||
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
|
||||||
batch, _ = self.prepare_dp_attn_batch(batch)
|
|
||||||
result = None
|
|
||||||
if batch:
|
|
||||||
result = self.run_batch(batch)
|
|
||||||
if not delay_process:
|
|
||||||
self.process_batch_result(batch, result)
|
|
||||||
return batch, result
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def event_loop_normal_disagg_decode(self: Scheduler):
|
def event_loop_normal_disagg_decode(self: Scheduler):
|
||||||
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
||||||
@@ -773,6 +761,15 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
self.last_batch_in_queue = last_batch_in_queue
|
self.last_batch_in_queue = last_batch_in_queue
|
||||||
|
|
||||||
|
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
||||||
|
batch, _ = self.prepare_dp_attn_batch(batch)
|
||||||
|
result = None
|
||||||
|
if batch:
|
||||||
|
result = self.run_batch(batch)
|
||||||
|
if not delay_process:
|
||||||
|
self.process_batch_result(batch, result)
|
||||||
|
return batch, result
|
||||||
|
|
||||||
def get_next_disagg_decode_batch_to_run(
|
def get_next_disagg_decode_batch_to_run(
|
||||||
self: Scheduler,
|
self: Scheduler,
|
||||||
) -> Optional[Tuple[ScheduleBatch, bool]]:
|
) -> Optional[Tuple[ScheduleBatch, bool]]:
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from .conn import FakeKVReceiver, FakeKVSender
|
from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import (
|
|||||||
BaseKVManager,
|
BaseKVManager,
|
||||||
BaseKVReceiver,
|
BaseKVReceiver,
|
||||||
BaseKVSender,
|
BaseKVSender,
|
||||||
KVArgs,
|
|
||||||
KVPoll,
|
KVPoll,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,7 +32,7 @@ class FakeKVSender(BaseKVSender):
|
|||||||
return KVPoll.WaitingForInput
|
return KVPoll.WaitingForInput
|
||||||
else:
|
else:
|
||||||
# Assume transfer completed instantly
|
# Assume transfer completed instantly
|
||||||
logger.info("FakeKVSender poll success")
|
logger.debug("FakeKVSender poll success")
|
||||||
return KVPoll.Success
|
return KVPoll.Success
|
||||||
|
|
||||||
def init(
|
def init(
|
||||||
@@ -41,7 +40,7 @@ class FakeKVSender(BaseKVSender):
|
|||||||
kv_indices: list[int],
|
kv_indices: list[int],
|
||||||
aux_index: Optional[int] = None,
|
aux_index: Optional[int] = None,
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
@@ -51,7 +50,7 @@ class FakeKVSender(BaseKVSender):
|
|||||||
kv_indices: npt.NDArray[np.int32],
|
kv_indices: npt.NDArray[np.int32],
|
||||||
):
|
):
|
||||||
self.has_sent = True
|
self.has_sent = True
|
||||||
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
|
logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
|
||||||
|
|
||||||
def failure_exception(self):
|
def failure_exception(self):
|
||||||
raise Exception("Fake KVSender Exception")
|
raise Exception("Fake KVSender Exception")
|
||||||
@@ -73,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver):
|
|||||||
return KVPoll.WaitingForInput
|
return KVPoll.WaitingForInput
|
||||||
else:
|
else:
|
||||||
# Assume transfer completed instantly
|
# Assume transfer completed instantly
|
||||||
logger.info("FakeKVReceiver poll success")
|
logger.debug("FakeKVReceiver poll success")
|
||||||
return KVPoll.Success
|
return KVPoll.Success
|
||||||
|
|
||||||
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
|
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
|
||||||
self.has_init = True
|
self.has_init = True
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .conn import (
|
from sglang.srt.disaggregation.mooncake.conn import (
|
||||||
MooncakeKVBootstrapServer,
|
MooncakeKVBootstrapServer,
|
||||||
MooncakeKVManager,
|
MooncakeKVManager,
|
||||||
MooncakeKVReceiver,
|
MooncakeKVReceiver,
|
||||||
|
|||||||
@@ -1 +1,6 @@
|
|||||||
from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender
|
from sglang.srt.disaggregation.nixl.conn import (
|
||||||
|
NixlKVBootstrapServer,
|
||||||
|
NixlKVManager,
|
||||||
|
NixlKVReceiver,
|
||||||
|
NixlKVSender,
|
||||||
|
)
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
|||||||
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
||||||
}
|
}
|
||||||
return class_mapping.get(class_type)
|
return class_mapping.get(class_type)
|
||||||
if transfer_backend == TransferBackend.NIXL:
|
elif transfer_backend == TransferBackend.NIXL:
|
||||||
from sglang.srt.disaggregation.base import KVArgs
|
from sglang.srt.disaggregation.base import KVArgs
|
||||||
from sglang.srt.disaggregation.nixl import (
|
from sglang.srt.disaggregation.nixl import (
|
||||||
NixlKVBootstrapServer,
|
NixlKVBootstrapServer,
|
||||||
@@ -219,7 +219,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
|||||||
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
|
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
|
||||||
}
|
}
|
||||||
return class_mapping.get(class_type)
|
return class_mapping.get(class_type)
|
||||||
if transfer_backend == TransferBackend.FAKE:
|
elif transfer_backend == TransferBackend.FAKE:
|
||||||
from sglang.srt.disaggregation.base import KVArgs
|
from sglang.srt.disaggregation.base import KVArgs
|
||||||
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user