[PD] support spec decode (#6507)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -61,6 +61,7 @@ class PrefillBootstrapQueue:
|
||||
def __init__(
|
||||
self,
|
||||
token_to_kv_pool: KVCache,
|
||||
draft_token_to_kv_pool: Optional[KVCache],
|
||||
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||
metadata_buffers: List[torch.Tensor],
|
||||
aux_dtype: torch.dtype,
|
||||
@@ -72,6 +73,8 @@ class PrefillBootstrapQueue:
|
||||
scheduler: Scheduler,
|
||||
):
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.draft_token_to_kv_pool = draft_token_to_kv_pool
|
||||
|
||||
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
|
||||
self.aux_dtype = aux_dtype
|
||||
|
||||
@@ -98,6 +101,16 @@ class PrefillBootstrapQueue:
|
||||
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
|
||||
if self.draft_token_to_kv_pool is not None:
|
||||
# We should also transfer draft model kv cache. The indices are
|
||||
# always shared with a target model.
|
||||
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
|
||||
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
|
||||
)
|
||||
kv_data_ptrs += draft_kv_data_ptrs
|
||||
kv_data_lens += draft_kv_data_lens
|
||||
kv_item_lens += draft_kv_item_lens
|
||||
|
||||
kv_args.kv_data_ptrs = kv_data_ptrs
|
||||
kv_args.kv_data_lens = kv_data_lens
|
||||
kv_args.kv_item_lens = kv_item_lens
|
||||
|
||||
Reference in New Issue
Block a user