[Eagle] Refactor eagle speculative decoding (#3986)

Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
Ying Sheng
2025-03-05 08:06:07 -08:00
committed by GitHub
parent 5be8f1ed98
commit d3d4d76758
22 changed files with 670 additions and 352 deletions

View File

@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union
import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
@@ -75,7 +79,7 @@ class SchedulePolicy:
# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
req_to_token_pool=None, token_to_kv_pool=None, disable=False
req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
)
def calc_priority(self, waiting_queue: List[Req]) -> bool:
@@ -251,7 +255,7 @@ class PrefillAdder:
def __init__(
self,
tree_cache: BasePrefixCache,
token_to_kv_pool: BaseTokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
running_batch: ScheduleBatch,
new_token_ratio: float,
rem_input_tokens: int,
@@ -259,7 +263,7 @@ class PrefillAdder:
mixed_with_decode_tokens: int = 0,
):
self.tree_cache = tree_cache
self.token_to_kv_pool = token_to_kv_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.running_batch = running_batch
self.new_token_ratio = new_token_ratio
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
@@ -291,7 +295,7 @@ class PrefillAdder:
@property
def rem_total_tokens(self):
return (
self.token_to_kv_pool.available_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
- self.rem_total_token_offset
)
@@ -299,7 +303,7 @@ class PrefillAdder:
@property
def cur_rem_tokens(self):
return (
self.token_to_kv_pool.available_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
- self.cur_rem_token_offset
)
@@ -332,7 +336,6 @@ class PrefillAdder:
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
self.can_run_list.append(req)
self._prefill_one_req(
0,
req.extend_input_len,
@@ -400,8 +403,8 @@ class PrefillAdder:
tokens_freed += tokens_occupied
if (
self.rem_chunk_tokens is None
or req.extend_input_len <= self.rem_chunk_tokens
self.rem_chunk_tokens is None # chunked prefill is disabled
or req.extend_input_len <= self.rem_chunk_tokens # it is the last chunk
):
# Non-chunked prefill
self.can_run_list.append(req)
@@ -411,10 +414,11 @@ class PrefillAdder:
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
)
else:
if self.rem_chunk_tokens == 0:
return AddReqResult.OTHER
# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[:trunc_len]
@@ -457,10 +461,11 @@ class PrefillAdder:
),
)
else:
if self.rem_chunk_tokens == 0:
return AddReqResult.OTHER
# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]