[PD] support spec decode (#6507)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Byron Hsu
2025-05-23 12:03:05 -07:00
committed by GitHub
parent 2f42749184
commit d2e0881a34
8 changed files with 190 additions and 5 deletions

View File

@@ -61,6 +61,7 @@ class PrefillBootstrapQueue:
def __init__(
self,
token_to_kv_pool: KVCache,
draft_token_to_kv_pool: Optional[KVCache],
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: List[torch.Tensor],
aux_dtype: torch.dtype,
@@ -72,6 +73,8 @@ class PrefillBootstrapQueue:
scheduler: Scheduler,
):
self.token_to_kv_pool = token_to_kv_pool
self.draft_token_to_kv_pool = draft_token_to_kv_pool
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
self.aux_dtype = aux_dtype
@@ -98,6 +101,16 @@ class PrefillBootstrapQueue:
self.token_to_kv_pool.get_contiguous_buf_infos()
)
if self.draft_token_to_kv_pool is not None:
# We should also transfer draft model kv cache. The indices are
# always shared with a target model.
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
)
kv_data_ptrs += draft_kv_data_ptrs
kv_data_lens += draft_kv_data_lens
kv_item_lens += draft_kv_item_lens
kv_args.kv_data_ptrs = kv_data_ptrs
kv_args.kv_data_lens = kv_data_lens
kv_args.kv_item_lens = kv_item_lens