[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Req,
|
||||
ScheduleBatch,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
|
||||
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
|
||||
@@ -75,7 +79,7 @@ class SchedulePolicy:
|
||||
|
||||
# It is used to find the matching prefix for in-batch prefix caching.
|
||||
self.waiting_queue_radix_tree = RadixCache(
|
||||
req_to_token_pool=None, token_to_kv_pool=None, disable=False
|
||||
req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
|
||||
)
|
||||
|
||||
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
||||
@@ -251,7 +255,7 @@ class PrefillAdder:
|
||||
def __init__(
|
||||
self,
|
||||
tree_cache: BasePrefixCache,
|
||||
token_to_kv_pool: BaseTokenToKVPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
running_batch: ScheduleBatch,
|
||||
new_token_ratio: float,
|
||||
rem_input_tokens: int,
|
||||
@@ -259,7 +263,7 @@ class PrefillAdder:
|
||||
mixed_with_decode_tokens: int = 0,
|
||||
):
|
||||
self.tree_cache = tree_cache
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.running_batch = running_batch
|
||||
self.new_token_ratio = new_token_ratio
|
||||
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
||||
@@ -291,7 +295,7 @@ class PrefillAdder:
|
||||
@property
|
||||
def rem_total_tokens(self):
|
||||
return (
|
||||
self.token_to_kv_pool.available_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
- self.rem_total_token_offset
|
||||
)
|
||||
@@ -299,7 +303,7 @@ class PrefillAdder:
|
||||
@property
|
||||
def cur_rem_tokens(self):
|
||||
return (
|
||||
self.token_to_kv_pool.available_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
- self.cur_rem_token_offset
|
||||
)
|
||||
@@ -332,7 +336,6 @@ class PrefillAdder:
|
||||
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||
self.can_run_list.append(req)
|
||||
|
||||
self._prefill_one_req(
|
||||
0,
|
||||
req.extend_input_len,
|
||||
@@ -400,8 +403,8 @@ class PrefillAdder:
|
||||
tokens_freed += tokens_occupied
|
||||
|
||||
if (
|
||||
self.rem_chunk_tokens is None
|
||||
or req.extend_input_len <= self.rem_chunk_tokens
|
||||
self.rem_chunk_tokens is None # chunked prefill is disabled
|
||||
or req.extend_input_len <= self.rem_chunk_tokens # it is the last chunk
|
||||
):
|
||||
# Non-chunked prefill
|
||||
self.can_run_list.append(req)
|
||||
@@ -411,10 +414,11 @@ class PrefillAdder:
|
||||
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
|
||||
)
|
||||
else:
|
||||
if self.rem_chunk_tokens == 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
# Chunked prefill
|
||||
trunc_len = self.rem_chunk_tokens
|
||||
if trunc_len == 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
req.extend_input_len = trunc_len
|
||||
req.fill_ids = req.fill_ids[:trunc_len]
|
||||
@@ -457,10 +461,11 @@ class PrefillAdder:
|
||||
),
|
||||
)
|
||||
else:
|
||||
if self.rem_chunk_tokens == 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
# Chunked prefill
|
||||
trunc_len = self.rem_chunk_tokens
|
||||
if trunc_len == 0:
|
||||
return AddReqResult.OTHER
|
||||
|
||||
req.extend_input_len = trunc_len
|
||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||
|
||||
Reference in New Issue
Block a user