[PD] support spec decode (#6507)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Byron Hsu
2025-05-23 12:03:05 -07:00
committed by GitHub
parent 2f42749184
commit d2e0881a34
8 changed files with 190 additions and 5 deletions

View File

@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import (
from sglang.srt.managers.schedule_batch import FINISH_ABORT
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
logger = logging.getLogger(__name__)
@@ -76,6 +76,7 @@ class DecodePreallocQueue:
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
draft_token_to_kv_pool: Optional[KVCache],
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: List[torch.Tensor],
aux_dtype: torch.dtype,
@@ -91,6 +92,7 @@ class DecodePreallocQueue:
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
self.draft_token_to_kv_pool = draft_token_to_kv_pool
self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
self.aux_dtype = aux_dtype
self.metadata_buffers = metadata_buffers
@@ -119,6 +121,14 @@ class DecodePreallocQueue:
self.token_to_kv_pool.get_contiguous_buf_infos()
)
if self.draft_token_to_kv_pool is not None:
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
)
kv_data_ptrs += draft_kv_data_ptrs
kv_data_lens += draft_kv_data_lens
kv_item_lens += draft_kv_item_lens
kv_args.kv_data_ptrs = kv_data_ptrs
kv_args.kv_data_lens = kv_data_lens
kv_args.kv_item_lens = kv_item_lens

View File

@@ -51,6 +51,7 @@ def group_concurrent_contiguous(
return src_groups, dst_groups
# prefill
@dataclasses.dataclass
class TransferKVChunk:
room: int
@@ -60,6 +61,7 @@ class TransferKVChunk:
prefill_aux_index: Optional[int]
# decode
@dataclasses.dataclass
class TransferInfo:
room: int
@@ -93,6 +95,7 @@ class TransferInfo:
)
# decode
@dataclasses.dataclass
class KVArgsRegisterInfo:
room: str

View File

@@ -61,7 +61,8 @@ class MooncakeTransferEngine:
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
) -> int:
"""Synchronously transfer data to the specified address."""
# the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
# later: based on the cached queue pair to send data
ret = self.engine.transfer_sync_write(
session_id, buffer, peer_buffer_address, length
)

View File

@@ -61,6 +61,7 @@ class PrefillBootstrapQueue:
def __init__(
self,
token_to_kv_pool: KVCache,
draft_token_to_kv_pool: Optional[KVCache],
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: List[torch.Tensor],
aux_dtype: torch.dtype,
@@ -72,6 +73,8 @@ class PrefillBootstrapQueue:
scheduler: Scheduler,
):
self.token_to_kv_pool = token_to_kv_pool
self.draft_token_to_kv_pool = draft_token_to_kv_pool
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
self.aux_dtype = aux_dtype
@@ -98,6 +101,16 @@ class PrefillBootstrapQueue:
self.token_to_kv_pool.get_contiguous_buf_infos()
)
if self.draft_token_to_kv_pool is not None:
# We should also transfer the draft model's KV cache. The indices are
# always shared with the target model.
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
)
kv_data_ptrs += draft_kv_data_ptrs
kv_data_lens += draft_kv_data_lens
kv_item_lens += draft_kv_item_lens
kv_args.kv_data_ptrs = kv_data_ptrs
kv_args.kv_data_lens = kv_data_lens
kv_args.kv_item_lens = kv_item_lens

View File

@@ -591,6 +591,11 @@ class Scheduler(
self.disagg_decode_prealloc_queue = DecodePreallocQueue(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
draft_token_to_kv_pool=(
None
if self.draft_worker is None
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
metadata_buffers=metadata_buffers,
aux_dtype=aux_dtype,
@@ -624,6 +629,11 @@ class Scheduler(
self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
draft_token_to_kv_pool=(
None
if self.draft_worker is None
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
metadata_buffers=metadata_buffers,
aux_dtype=aux_dtype,
@@ -1409,6 +1419,13 @@ class Scheduler(
self.running_batch.batch_is_full = True
break
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# In prefill mode, the prealloc queue and transfer queue can also take memory,
# so we need to check against the actual available size.
if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
self.running_batch.batch_is_full = True
break
req.init_next_round_input(
None if prefix_computed else self.tree_cache,
self.enable_hierarchical_cache,