[Feature] Speculative decoding support lookahead (#9873)

Co-authored-by: a4zhangfei <a4zhangfei@qq.com>
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
Zhihao Zhang
2025-09-19 07:42:41 +08:00
committed by GitHub
parent 2a2ff9a840
commit e7bc600304
30 changed files with 2058 additions and 32 deletions

View File

@@ -385,6 +385,18 @@ class Scheduler(
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_lookahead():
from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker
self.draft_worker = LOOKAHEADWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
server_args=server_args,
nccl_port=port_args.nccl_port,
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
else:
self.draft_worker = None
@@ -740,8 +752,8 @@ class Scheduler(
else (
server_args.speculative_num_draft_tokens
+ (
server_args.speculative_eagle_topk
* server_args.speculative_num_steps
(server_args.speculative_eagle_topk or 1)
* (server_args.speculative_num_steps or 1)
)
)
)
@@ -784,7 +796,7 @@ class Scheduler(
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
draft_token_to_kv_pool=(
None
if self.draft_worker is None
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -821,7 +833,7 @@ class Scheduler(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
draft_token_to_kv_pool=(
None
if self.draft_worker is None
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -2358,9 +2370,8 @@ class Scheduler(
self.req_to_token_pool.clear()
self.token_to_kv_pool_allocator.clear()
if not self.spec_algorithm.is_none():
self.draft_worker.model_runner.req_to_token_pool.clear()
self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
if self.draft_worker:
self.draft_worker.clear_cache_pool()
self.num_generated_tokens = 0
self.forward_ct_decode = 0