[PD] Add transfer backend abstraction (#5328)
This commit is contained in:
@@ -24,10 +24,19 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
|
||||
from sglang.srt.disaggregation.base import (
|
||||
BaseKVManager,
|
||||
BaseKVReceiver,
|
||||
BaseKVSender,
|
||||
KVArgs,
|
||||
KVPoll,
|
||||
)
|
||||
from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
KVClassType,
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
get_kv_class,
|
||||
poll_and_all_reduce,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||
@@ -38,6 +47,7 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,6 +66,7 @@ class PrefillBootstrapQueue:
|
||||
tp_size: int,
|
||||
bootstrap_port: int,
|
||||
gloo_group: ProcessGroup,
|
||||
transfer_backend: TransferBackend,
|
||||
):
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.aux_dtype = aux_dtype
|
||||
@@ -64,6 +75,7 @@ class PrefillBootstrapQueue:
|
||||
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.transfer_backend = transfer_backend
|
||||
self.kv_manager = self._init_kv_manager()
|
||||
self.queue: List[Req] = []
|
||||
self.gloo_group = gloo_group
|
||||
@@ -74,7 +86,7 @@ class PrefillBootstrapQueue:
|
||||
output_id_buffer = self.metadata_buffers[0]
|
||||
output_id_buffer[idx] = token_id
|
||||
|
||||
def _init_kv_manager(self) -> KVManager:
|
||||
def _init_kv_manager(self) -> BaseKVManager:
|
||||
kv_args = KVArgs()
|
||||
kv_args.engine_rank = self.tp_rank
|
||||
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||
@@ -96,11 +108,13 @@ class PrefillBootstrapQueue:
|
||||
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
||||
]
|
||||
kv_args.ib_device = "mock-ib-device"
|
||||
kv_manager = KVManager(kv_args, DisaggregationMode("prefill"))
|
||||
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
||||
kv_manager = kv_manager_class(kv_args, DisaggregationMode.PREFILL)
|
||||
return kv_manager
|
||||
|
||||
def add(self, req: Req) -> None:
|
||||
req.disagg_kv_sender = KVSender(
|
||||
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
|
||||
req.disagg_kv_sender = kv_sender_class(
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
||||
bootstrap_room=req.bootstrap_room,
|
||||
|
||||
Reference in New Issue
Block a user