[Eagle] Refactor eagle speculative decoding (#3986)

Co-authored-by: Ke Bao <ISPObaoke@163.com>
Author: Ying Sheng
Date: 2025-03-05 08:06:07 -08:00 (committed via GitHub)
Parent: 5be8f1ed98
Commit: d3d4d76758
22 changed files with 670 additions and 352 deletions

python/sglang/srt/managers/schedule_batch.py

@@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
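One cleanup in the import block: the top-level `from sglang.srt.server_args import ServerArgs` moves under the `if TYPE_CHECKING:` guard, alongside the speculative-decoding imports. This is the standard pattern for imports needed only as type annotations, typically to break a circular import. A minimal sketch of the pattern (the function is a hypothetical stand-in):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers, never executed at runtime,
    # so any runtime circular import between the modules disappears.
    from sglang.srt.server_args import ServerArgs


def needs_server_args(server_args: ServerArgs) -> None:
    # With `from __future__ import annotations`, the annotation stays a
    # plain string at runtime and needs no real import.
    ...
```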
@@ -523,7 +521,7 @@ class ScheduleBatch:
# Request, memory pool, and cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
tree_cache: BasePrefixCache = None
# Batch configs
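This is the core rename of the refactor: `ScheduleBatch` now holds a `TokenToKVPoolAllocator` instead of the `BaseTokenToKVPool` itself, so the batch only does slot bookkeeping and never touches KV storage directly. The real class lives in `sglang.srt.mem_cache.memory_pool`; the sketch below is only the interface implied by the call sites in this diff (`alloc` returning `None` on failure, `free`, `available_size`), not the actual implementation:

```python
import torch


class TokenToKVPoolAllocatorSketch:
    """Hypothetical minimal allocator matching the call sites in this diff."""

    def __init__(self, size: int, device: str = "cpu"):
        # Every KV-cache slot index starts out free.
        self.free_slots = torch.arange(size, dtype=torch.int64, device=device)

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int):
        # Return None on failure so callers can evict from the tree cache
        # and retry, exactly as alloc_token_slots does below.
        if need_size > len(self.free_slots):
            return None
        out = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return out

    def free(self, free_index: torch.Tensor):
        # Returned slots become available again.
        self.free_slots = torch.cat((self.free_slots, free_index))
```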
@@ -596,7 +594,7 @@ class ScheduleBatch:
cls,
reqs: List[Req],
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
tree_cache: BasePrefixCache,
model_config: ModelConfig,
enable_overlap: bool,
@@ -606,7 +604,7 @@ class ScheduleBatch:
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
tree_cache=tree_cache,
model_config=model_config,
enable_overlap=enable_overlap,
@@ -637,19 +635,19 @@ class ScheduleBatch:
return req_pool_indices
def alloc_token_slots(self, num_tokens: int):
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
if out_cache_loc is None:
if self.tree_cache is not None:
self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
if out_cache_loc is None:
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
logger.error(
f"{phase_str} out of memory. Try to lower your batch size.\n"
f"Try to allocate {num_tokens} tokens.\n"
f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
)
if self.tree_cache is not None:
self.tree_cache.pretty_print()
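`alloc_token_slots` keeps its evict-then-retry shape and only swaps the callee: ask the allocator first, and on failure have the prefix tree cache evict cached tokens back into the allocator (via `allocator.free`) before retrying. A standalone sketch of that pattern, reusing the allocator sketch above and a hypothetical `tree_cache`; the real method logs the failure and pretty-prints the cache rather than raising:

```python
import torch


def alloc_token_slots(allocator, tree_cache, num_tokens: int) -> torch.Tensor:
    # Fast path: the allocator still has enough free slots.
    out_cache_loc = allocator.alloc(num_tokens)

    if out_cache_loc is None and tree_cache is not None:
        # Evict cached prefixes; each evicted slot is handed back to the
        # allocator through its free() callback, then retry once.
        tree_cache.evict(num_tokens, allocator.free)
        out_cache_loc = allocator.alloc(num_tokens)

    if out_cache_loc is None:
        # The real method logs "out of memory" with the available-token
        # count and pretty-prints the tree cache at this point.
        raise RuntimeError(f"Failed to allocate {num_tokens} token slots")

    return out_cache_loc
```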
@@ -917,12 +915,12 @@ class ScheduleBatch:
def check_decode_mem(self, buf_multiplier=1):
bs = len(self.reqs) * buf_multiplier
if self.token_to_kv_pool.available_size() >= bs:
if self.token_to_kv_pool_allocator.available_size() >= bs:
return True
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
if self.token_to_kv_pool.available_size() >= bs:
if self.token_to_kv_pool_allocator.available_size() >= bs:
return True
return False
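`check_decode_mem` follows the same reclaim-before-fail pattern, just as a boolean probe: one decode step needs one new KV slot per request, scaled by `buf_multiplier` (greater than 1 when a step writes several tokens per request, as in speculative decoding). A sketch under those assumptions:

```python
def check_decode_mem(batch, buf_multiplier: int = 1) -> bool:
    # Slots needed for one decode step across the whole batch.
    need = len(batch.reqs) * buf_multiplier
    allocator = batch.token_to_kv_pool_allocator

    if allocator.available_size() >= need:
        return True

    # Try to reclaim space from the prefix cache, then check again.
    batch.tree_cache.evict(need, allocator.free)
    return allocator.available_size() >= need
```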
@@ -945,6 +943,10 @@ class ScheduleBatch:
reverse=True,
)
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
def get_required_tokens(num_reqs: int):
headroom_for_spec_decode = 0
if server_args.speculative_algorithm:
@@ -958,18 +960,15 @@ class ScheduleBatch:
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
)
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while (
self.token_to_kv_pool.available_size()
self.token_to_kv_pool_allocator.available_size()
< get_required_tokens(len(sorted_indices))
or first_iter
):
if len(sorted_indices) == 1:
# Corner case: only one request left
assert (
self.token_to_kv_pool.available_size() > 0
self.token_to_kv_pool_allocator.available_size() > 0
), "No space left for only one request"
break
@@ -983,7 +982,7 @@ class ScheduleBatch:
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
del self.tree_cache.entries[req.rid]
else:
@@ -992,7 +991,7 @@ class ScheduleBatch:
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
# release the last node
@@ -1001,10 +1000,13 @@ class ScheduleBatch:
# NOTE(lsyin): we should use the newly evictable memory instantly.
residual_size = (
len(sorted_indices) * global_config.retract_decode_steps
- self.token_to_kv_pool.available_size()
- self.token_to_kv_pool_allocator.available_size()
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
self.tree_cache.evict(
residual_size, self.token_to_kv_pool_allocator.free
)
req.reset_for_retract()
self.filter_batch(keep_indices=sorted_indices)
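The retraction path is the most involved consumer of the allocator. When decode cannot proceed, the scheduler pops the lowest-priority requests one by one, frees their KV slots (all of them, or only the uncached suffix when the radix cache still holds the prefix), and keeps going until every surviving request has headroom for `retract_decode_steps` further decode steps. A simplified standalone sketch that omits the tree-cache branches, the speculative-decoding headroom, and the single-request corner case:

```python
def retract_decode_sketch(batch, sorted_indices, retract_decode_steps: int):
    # sorted_indices is ordered so the lowest-priority request is last.
    retracted_reqs = []
    seq_lens_cpu = batch.seq_lens.cpu().numpy()
    first_iter = True

    def required_tokens(num_reqs: int) -> int:
        # Headroom so each surviving request can run retract_decode_steps
        # more decode steps before the next retraction check.
        return num_reqs * retract_decode_steps

    while (
        batch.token_to_kv_pool_allocator.available_size()
        < required_tokens(len(sorted_indices))
        or first_iter
    ):
        first_iter = False
        idx = sorted_indices.pop()  # retract the lowest-priority request
        req = batch.reqs[idx]

        # Hand every KV slot of this request back to the allocator and
        # release its row in the request-to-token pool.
        token_indices = batch.req_to_token_pool.req_to_token[
            req.req_pool_idx, : seq_lens_cpu[idx]
        ]
        batch.token_to_kv_pool_allocator.free(token_indices)
        batch.req_to_token_pool.free(req.req_pool_idx)

        req.reset_for_retract()
        retracted_reqs.append(req)

    # Survivors stay in the batch; retracted requests are re-queued later.
    batch.filter_batch(keep_indices=sorted_indices)
    return retracted_reqs
```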
@@ -1183,7 +1185,7 @@ class ScheduleBatch:
if self.spec_info:
self.spec_info.merge_batch(other.spec_info)
def get_model_worker_batch(self):
def get_model_worker_batch(self) -> ModelWorkerBatch:
if self.forward_mode.is_decode_or_idle():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
req_pool_indices: torch.Tensor
# The sequence length
seq_lens: torch.Tensor
# The indices of output tokens in the token_to_kv_pool
# The indices of output tokens in the token_to_kv_pool_allocator
out_cache_loc: torch.Tensor
# The sum of all sequence lengths