[PD] support spec decode (#6507)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -76,6 +76,7 @@ class DecodePreallocQueue:
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
draft_token_to_kv_pool: Optional[KVCache],
|
||||
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||
metadata_buffers: List[torch.Tensor],
|
||||
aux_dtype: torch.dtype,
|
||||
@@ -91,6 +92,7 @@ class DecodePreallocQueue:
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
|
||||
self.draft_token_to_kv_pool = draft_token_to_kv_pool
|
||||
self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
|
||||
self.aux_dtype = aux_dtype
|
||||
self.metadata_buffers = metadata_buffers
|
||||
@@ -119,6 +121,14 @@ class DecodePreallocQueue:
|
||||
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
|
||||
if self.draft_token_to_kv_pool is not None:
|
||||
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
|
||||
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
kv_data_ptrs += draft_kv_data_ptrs
|
||||
kv_data_lens += draft_kv_data_lens
|
||||
kv_item_lens += draft_kv_item_lens
|
||||
|
||||
kv_args.kv_data_ptrs = kv_data_ptrs
|
||||
kv_args.kv_data_lens = kv_data_lens
|
||||
kv_args.kv_item_lens = kv_item_lens
|
||||
|
||||
@@ -51,6 +51,7 @@ def group_concurrent_contiguous(
|
||||
return src_groups, dst_groups
|
||||
|
||||
|
||||
# prefill
|
||||
@dataclasses.dataclass
|
||||
class TransferKVChunk:
|
||||
room: int
|
||||
@@ -60,6 +61,7 @@ class TransferKVChunk:
|
||||
prefill_aux_index: Optional[int]
|
||||
|
||||
|
||||
# decode
|
||||
@dataclasses.dataclass
|
||||
class TransferInfo:
|
||||
room: int
|
||||
@@ -93,6 +95,7 @@ class TransferInfo:
|
||||
)
|
||||
|
||||
|
||||
# decode
|
||||
@dataclasses.dataclass
|
||||
class KVArgsRegisterInfo:
|
||||
room: str
|
||||
|
||||
@@ -61,7 +61,8 @@ class MooncakeTransferEngine:
|
||||
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
|
||||
) -> int:
|
||||
"""Synchronously transfer data to the specified address."""
|
||||
|
||||
# the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
|
||||
# later: based on the cached queue pair to send data
|
||||
ret = self.engine.transfer_sync_write(
|
||||
session_id, buffer, peer_buffer_address, length
|
||||
)
|
||||
|
||||
@@ -61,6 +61,7 @@ class PrefillBootstrapQueue:
|
||||
def __init__(
|
||||
self,
|
||||
token_to_kv_pool: KVCache,
|
||||
draft_token_to_kv_pool: Optional[KVCache],
|
||||
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||
metadata_buffers: List[torch.Tensor],
|
||||
aux_dtype: torch.dtype,
|
||||
@@ -72,6 +73,8 @@ class PrefillBootstrapQueue:
|
||||
scheduler: Scheduler,
|
||||
):
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.draft_token_to_kv_pool = draft_token_to_kv_pool
|
||||
|
||||
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
|
||||
self.aux_dtype = aux_dtype
|
||||
|
||||
@@ -98,6 +101,16 @@ class PrefillBootstrapQueue:
|
||||
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
|
||||
if self.draft_token_to_kv_pool is not None:
|
||||
# We should also transfer draft model kv cache. The indices are
|
||||
# always shared with a target model.
|
||||
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
|
||||
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
kv_data_ptrs += draft_kv_data_ptrs
|
||||
kv_data_lens += draft_kv_data_lens
|
||||
kv_item_lens += draft_kv_item_lens
|
||||
|
||||
kv_args.kv_data_ptrs = kv_data_ptrs
|
||||
kv_args.kv_data_lens = kv_data_lens
|
||||
kv_args.kv_item_lens = kv_item_lens
|
||||
|
||||
@@ -591,6 +591,11 @@ class Scheduler(
|
||||
self.disagg_decode_prealloc_queue = DecodePreallocQueue(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
draft_token_to_kv_pool=(
|
||||
None
|
||||
if self.draft_worker is None
|
||||
else self.draft_worker.model_runner.token_to_kv_pool
|
||||
),
|
||||
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||
metadata_buffers=metadata_buffers,
|
||||
aux_dtype=aux_dtype,
|
||||
@@ -624,6 +629,11 @@ class Scheduler(
|
||||
|
||||
self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
|
||||
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
|
||||
draft_token_to_kv_pool=(
|
||||
None
|
||||
if self.draft_worker is None
|
||||
else self.draft_worker.model_runner.token_to_kv_pool
|
||||
),
|
||||
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||
metadata_buffers=metadata_buffers,
|
||||
aux_dtype=aux_dtype,
|
||||
@@ -1409,6 +1419,13 @@ class Scheduler(
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
# In prefill mode, prealloc queue and transfer queue can also take memory,
|
||||
# so we need to check if the available size for the actual available size.
|
||||
if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
req.init_next_round_input(
|
||||
None if prefix_computed else self.tree_cache,
|
||||
self.enable_hierarchical_cache,
|
||||
|
||||
Reference in New Issue
Block a user