[PD] Add transfer backend abstraction (#5328)
This commit is contained in:
@@ -45,7 +45,7 @@ import triton.language as tl
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||
from sglang.srt.disaggregation.conn import KVSender
|
||||
from sglang.srt.disaggregation.base import BaseKVSender
|
||||
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
@@ -525,7 +525,7 @@ class Req:
|
||||
# For disaggregation
|
||||
self.bootstrap_host: str = bootstrap_host
|
||||
self.bootstrap_room: Optional[int] = bootstrap_room
|
||||
self.disagg_kv_sender: Optional[KVSender] = None
|
||||
self.disagg_kv_sender: Optional[BaseKVSender] = None
|
||||
|
||||
# used for warmup because we don't have a pair yet when init
|
||||
self.skip_kv_transfer: bool = False
|
||||
|
||||
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.prefill import (
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
)
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||
@@ -530,6 +531,10 @@ class Scheduler(
|
||||
)
|
||||
|
||||
def init_disaggregation(self):
|
||||
self.transfer_backend = TransferBackend(
|
||||
self.server_args.disaggregation_transfer_backend
|
||||
)
|
||||
|
||||
if (
|
||||
self.disaggregation_mode == DisaggregationMode.DECODE
|
||||
): # *2 for the headroom.
|
||||
@@ -567,6 +572,7 @@ class Scheduler(
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||
transfer_backend=self.transfer_backend,
|
||||
)
|
||||
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# *2 for the headroom.
|
||||
@@ -592,6 +598,7 @@ class Scheduler(
|
||||
tp_size=self.tp_size,
|
||||
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
|
||||
transfer_backend=self.transfer_backend,
|
||||
)
|
||||
# The prefill requests that are in the middle of kv sending
|
||||
self.disagg_prefill_inflight_queue: List[Req] = []
|
||||
|
||||
@@ -48,8 +48,12 @@ from fastapi import BackgroundTasks
|
||||
|
||||
from sglang.srt.aio_rwlock import RWLock
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.disaggregation.conn import KVBootstrapServer
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
KVClassType,
|
||||
TransferBackend,
|
||||
get_kv_class,
|
||||
)
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
@@ -329,10 +333,16 @@ class TokenizerManager:
|
||||
self.disaggregation_mode = DisaggregationMode(
|
||||
self.server_args.disaggregation_mode
|
||||
)
|
||||
self.transfer_backend = TransferBackend(
|
||||
self.server_args.disaggregation_transfer_backend
|
||||
)
|
||||
# for disaggregtion, start kv boostrap server on prefill
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# only start bootstrap server on prefill tm
|
||||
self.bootstrap_server = KVBootstrapServer(
|
||||
kv_bootstrap_server_class = get_kv_class(
|
||||
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
||||
)
|
||||
self.bootstrap_server = kv_bootstrap_server_class(
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user