[PD] Add transfer backend abstraction (#5328)

This commit is contained in:
Byron Hsu
2025-04-13 10:39:39 -07:00
committed by GitHub
parent f765579046
commit a9499885e9
12 changed files with 236 additions and 41 deletions

View File

@@ -45,7 +45,7 @@ import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.conn import KVSender
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
@@ -525,7 +525,7 @@ class Req:
# For disaggregation
self.bootstrap_host: str = bootstrap_host
self.bootstrap_room: Optional[int] = bootstrap_room
self.disagg_kv_sender: Optional[KVSender] = None
self.disagg_kv_sender: Optional[BaseKVSender] = None
# used for warmup, because no prefill/decode pair exists yet at init time
self.skip_kv_transfer: bool = False

View File

@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.prefill import (
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
ReqToMetadataIdxAllocator,
TransferBackend,
)
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
@@ -530,6 +531,10 @@ class Scheduler(
)
def init_disaggregation(self):
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
if (
self.disaggregation_mode == DisaggregationMode.DECODE
): # *2 for the headroom.
@@ -567,6 +572,7 @@ class Scheduler(
tp_rank=self.tp_rank,
tp_size=self.tp_size,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
transfer_backend=self.transfer_backend,
)
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
# *2 for the headroom.
@@ -592,6 +598,7 @@ class Scheduler(
tp_size=self.tp_size,
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
transfer_backend=self.transfer_backend,
)
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = []

View File

@@ -48,8 +48,12 @@ from fastapi import BackgroundTasks
from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.conn import KVBootstrapServer
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
KVClassType,
TransferBackend,
get_kv_class,
)
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
AbortReq,
@@ -329,10 +333,16 @@ class TokenizerManager:
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
)
self.transfer_backend = TransferBackend(
self.server_args.disaggregation_transfer_backend
)
# for disaggregation, start kv bootstrap server on prefill
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# only start bootstrap server on the prefill tokenizer manager
self.bootstrap_server = KVBootstrapServer(
kv_bootstrap_server_class = get_kv_class(
self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
)
self.bootstrap_server = kv_bootstrap_server_class(
self.server_args.disaggregation_bootstrap_port
)