[Eagle] Refactor eagle speculative decoding (#3986)
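
Rename the scheduler-side KV-cache handle from token_to_kv_pool to
token_to_kv_pool_allocator (BaseTokenToKVPool -> TokenToKVPoolAllocator) across
ScheduleBatch, SchedulePolicy, Scheduler, and the TP workers. TpModelWorker
gains optional req_to_token_pool / token_to_kv_pool_allocator arguments so the
EAGLE draft worker can share the target worker's pools. The speculative path
now returns its own batch id (bid) and the number of accepted tokens, the
unused SetInternalState plumbing and the draft worker's finish_request hooks
are removed, and HiCacheController's host pool type is corrected from
MLATokenToKVPoolHost to MHATokenToKVPoolHost.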

Co-authored-by: Ke Bao <ISPObaoke@163.com>
Authored-by: Ying Sheng
Committed: 2025-03-05 08:06:07 -08:00 (via GitHub)
parent 5be8f1ed98
commit d3d4d76758
22 changed files with 670 additions and 352 deletions

View File

@@ -22,7 +22,7 @@ from typing import List, Optional
 import torch
-from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
 logger = logging.getLogger(__name__)
@@ -128,7 +128,7 @@ class HiCacheController:
     def __init__(
         self,
         mem_pool_device: MHATokenToKVPool,
-        mem_pool_host: MLATokenToKVPoolHost,
+        mem_pool_host: MHATokenToKVPoolHost,
         write_policy: str = "write_through_selective",
     ):

View File

@@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
-from sglang.srt.server_args import ServerArgs
 if TYPE_CHECKING:
+    from sglang.srt.server_args import ServerArgs
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 # Put some global args for easy access
@@ -523,7 +521,7 @@ class ScheduleBatch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool: BaseTokenToKVPool = None
+    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None
     # Batch configs
@@ -596,7 +594,7 @@ class ScheduleBatch:
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
@@ -606,7 +604,7 @@ class ScheduleBatch:
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
-            token_to_kv_pool=token_to_kv_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
             tree_cache=tree_cache,
             model_config=model_config,
             enable_overlap=enable_overlap,
@@ -637,19 +635,19 @@ class ScheduleBatch:
         return req_pool_indices
     def alloc_token_slots(self, num_tokens: int):
-        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
         if out_cache_loc is None:
             if self.tree_cache is not None:
-                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
+                out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
             if out_cache_loc is None:
                 phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                 logger.error(
                     f"{phase_str} out of memory. Try to lower your batch size.\n"
                     f"Try to allocate {num_tokens} tokens.\n"
-                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+                    f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                 )
                 if self.tree_cache is not None:
                     self.tree_cache.pretty_print()
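
For readers following the rename: the allocator used above only does slot
bookkeeping; the KV tensors live elsewhere. Below is a minimal sketch of the
interface implied by the call sites in this commit (alloc, free,
available_size, clear, and the free-group pair). It is an illustration
inferred from this diff, not the actual class in
sglang.srt.mem_cache.memory_pool.

from typing import List, Optional

import torch


class TokenToKVPoolAllocatorSketch:
    """Hands out KV-cache slot indices; the KV buffers are owned elsewhere."""

    def __init__(self, size: int):
        self.size = size
        self.free_slots = torch.arange(size, dtype=torch.int64)
        self.is_in_free_group = False
        self.free_group: List[torch.Tensor] = []

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
        if need_size > len(self.free_slots):
            return None  # caller evicts from the tree cache and retries
        out = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return out

    def free(self, free_index: torch.Tensor):
        if self.is_in_free_group:
            self.free_group.append(free_index)  # defer until free_group_end
        else:
            self.free_slots = torch.cat((self.free_slots, free_index))

    def free_group_begin(self):
        self.is_in_free_group = True
        self.free_group = []

    def free_group_end(self):
        self.is_in_free_group = False
        if self.free_group:
            self.free(torch.cat(self.free_group))  # one batched release

    def clear(self):
        self.free_slots = torch.arange(self.size, dtype=torch.int64)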
@@ -917,12 +915,12 @@ class ScheduleBatch:
     def check_decode_mem(self, buf_multiplier=1):
         bs = len(self.reqs) * buf_multiplier
-        if self.token_to_kv_pool.available_size() >= bs:
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
-        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
-        if self.token_to_kv_pool.available_size() >= bs:
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
         return False
@@ -945,6 +943,10 @@ class ScheduleBatch:
             reverse=True,
         )
+        retracted_reqs = []
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
         def get_required_tokens(num_reqs: int):
             headroom_for_spec_decode = 0
             if server_args.speculative_algorithm:
@@ -958,18 +960,15 @@ class ScheduleBatch:
                 num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
             )
-        retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
-        first_iter = True
         while (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
             < get_required_tokens(len(sorted_indices))
             or first_iter
         ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
-                    self.token_to_kv_pool.available_size() > 0
+                    self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break
@@ -983,7 +982,7 @@ class ScheduleBatch:
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
-               self.token_to_kv_pool.free(token_indices)
+               self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
@@ -992,7 +991,7 @@ class ScheduleBatch:
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
-               self.token_to_kv_pool.free(token_indices)
+               self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                # release the last node
@@ -1001,10 +1000,13 @@ class ScheduleBatch:
                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
-                   - self.token_to_kv_pool.available_size()
+                   - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
-               self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
+               self.tree_cache.evict(
+                   residual_size, self.token_to_kv_pool_allocator.free
+               )
            req.reset_for_retract()
        self.filter_batch(keep_indices=sorted_indices)
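
For intuition about the loop above: retraction keeps evicting the
lowest-priority request until the allocator can cover
global_config.retract_decode_steps future decode tokens for every remaining
request (plus headroom when speculative decoding is enabled). A toy
calculation with made-up numbers:

# Toy numbers only; retract_decode_steps mirrors global_config.retract_decode_steps.
retract_decode_steps = 20
headroom_for_spec_decode = 0  # non-zero only when speculative decoding is on
available = 95                # stand-in for token_to_kv_pool_allocator.available_size()
num_reqs = 8

def get_required_tokens(num_reqs: int) -> int:
    return num_reqs * retract_decode_steps + headroom_for_spec_decode

while available < get_required_tokens(num_reqs):
    num_reqs -= 1    # retract the lowest-priority request...
    available += 30  # ...and reclaim its KV slots (toy per-request footprint)

print(num_reqs, available)  # -> 6 155: two requests retracted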
@@ -1183,7 +1185,7 @@ class ScheduleBatch:
         if self.spec_info:
             self.spec_info.merge_batch(other.spec_info)
-    def get_model_worker_batch(self):
+    def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
-    # The indices of output tokens in the token_to_kv_pool
+    # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
     # The sum of all sequence lengths
View File

@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union
 import torch
-from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    Req,
+    ScheduleBatch,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
@@ -75,7 +79,7 @@ class SchedulePolicy:
         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
-            req_to_token_pool=None, token_to_kv_pool=None, disable=False
+            req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
         )
     def calc_priority(self, waiting_queue: List[Req]) -> bool:
@@ -251,7 +255,7 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
         rem_input_tokens: int,
@@ -259,7 +263,7 @@ class PrefillAdder:
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
@@ -291,7 +295,7 @@ class PrefillAdder:
     @property
     def rem_total_tokens(self):
         return (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
             - self.rem_total_token_offset
         )
@@ -299,7 +303,7 @@ class PrefillAdder:
     @property
     def cur_rem_tokens(self):
         return (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
            - self.cur_rem_token_offset
        )
@@ -332,7 +336,6 @@ class PrefillAdder:
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
-
         self._prefill_one_req(
             0,
             req.extend_input_len,
@@ -400,8 +403,8 @@ class PrefillAdder:
             tokens_freed += tokens_occupied
         if (
-            self.rem_chunk_tokens is None
-            or req.extend_input_len <= self.rem_chunk_tokens
+            self.rem_chunk_tokens is None  # chunked prefill is disabled
+            or req.extend_input_len <= self.rem_chunk_tokens  # it is the last chunk
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)
@@ -411,10 +414,11 @@ class PrefillAdder:
                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
+            if self.rem_chunk_tokens == 0:
+                return AddReqResult.OTHER
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
-            if trunc_len == 0:
-                return AddReqResult.OTHER
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[:trunc_len]
@@ -457,10 +461,11 @@ class PrefillAdder:
                 ),
             )
         else:
+            if self.rem_chunk_tokens == 0:
+                return AddReqResult.OTHER
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
-            if trunc_len == 0:
-                return AddReqResult.OTHER
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
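
Both branches above now bail out before mutating the request when the chunk
budget is exhausted. A self-contained sketch of the truncation itself (ReqStub
is a stand-in using the field names from the diff, with made-up sizes):

from dataclasses import dataclass
from typing import List

@dataclass
class ReqStub:
    fill_ids: List[int]        # cached prefix tokens + new input tokens
    prefix_indices: List[int]  # portion already present in the radix cache
    extend_input_len: int = 0

rem_chunk_tokens = 512  # remaining budget for the current prefill chunk
req = ReqStub(fill_ids=list(range(2048)), prefix_indices=list(range(1024)))
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)  # 1024 left

if req.extend_input_len > rem_chunk_tokens and rem_chunk_tokens > 0:
    # Chunked prefill: only prefill what fits; the rest runs next round.
    trunc_len = rem_chunk_tokens
    req.extend_input_len = trunc_len
    req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]

assert req.extend_input_len == 512 and len(req.fill_ids) == 1536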

View File

@@ -164,7 +164,7 @@ class Scheduler:
                 self.server_args.speculative_num_draft_tokens
                 + (
                     self.server_args.speculative_eagle_topk
-                    * self.server_args.speculative_num_steps
+                    * self.server_args.speculative_num_draft_tokens
                 )
             )
             if not self.spec_algorithm.is_none()
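
The one-line change above swaps the last factor of the reserved-token formula
from speculative_num_steps to speculative_num_draft_tokens. With illustrative
values (made up, not defaults), the per-request reservation changes as follows:

# Illustrative values; the real ones come from server_args.
speculative_num_draft_tokens = 8  # tokens sent to the target model to verify
speculative_eagle_topk = 4        # branches kept per EAGLE draft step
speculative_num_steps = 5         # draft autoregressive steps

before = speculative_num_draft_tokens + speculative_eagle_topk * speculative_num_steps
after = speculative_num_draft_tokens + speculative_eagle_topk * speculative_num_draft_tokens
print(before, after)  # -> 28 40 tokens reserved per request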
@@ -309,7 +309,9 @@ class Scheduler:
         )
         # Init memory pool and cache
-        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            self.tp_worker.get_memory_pool()
+        )
         if (
             server_args.chunked_prefill_size is not None
@@ -317,18 +319,18 @@ class Scheduler:
         ):
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool=self.token_to_kv_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
             )
         else:
             if self.enable_hierarchical_cache:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool=self.token_to_kv_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 )
             else:
                 self.tree_cache = RadixCache(
                     req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool=self.token_to_kv_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                     disable=server_args.disable_radix_cache,
                 )
@@ -458,7 +460,6 @@ class Scheduler:
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
-                (SetInternalStateReq, self.set_internal_state),
             ]
         )
@@ -809,7 +810,8 @@ class Scheduler:
         running_bs: int,
     ):
         num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )
         self._largest_prefill_len = max(
             self._largest_prefill_len, adder.log_input_tokens
@@ -844,7 +846,8 @@ class Scheduler:
         self.num_generated_tokens = 0
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )
         if RECORD_STEP_TIME:
@@ -894,7 +897,8 @@ class Scheduler:
     def check_memory(self):
         available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )
         protected_size = self.tree_cache.protected_size()
         memory_leak = available_size != (
@@ -999,7 +1003,7 @@ class Scheduler:
         # Prefill policy
         adder = PrefillAdder(
             self.tree_cache,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.running_batch,
             self.new_token_ratio,
             self.max_prefill_tokens,
@@ -1099,7 +1103,7 @@ class Scheduler:
         new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
@@ -1143,8 +1147,6 @@ class Scheduler:
             retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
             self.new_token_ratio = new_token_ratio
-            if self.draft_worker:
-                self.draft_worker.finish_request(retracted_reqs)
             logger.info(
                 "Decode out of memory happened. "
@@ -1184,11 +1186,12 @@ class Scheduler:
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
+                bid = model_worker_batch.bid
             else:
                 (
                     logits_output,
                     next_token_ids,
-                    model_worker_batch,
+                    bid,
                     num_accepted_tokens,
                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
                 self.spec_num_total_accepted_tokens += (
@@ -1214,7 +1217,7 @@ class Scheduler:
                 next_token_ids=next_token_ids,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=model_worker_batch.bid,
+                bid=bid,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
@@ -1230,6 +1233,7 @@ class Scheduler:
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
     ):
         if batch.forward_mode.is_decode():
+            assert isinstance(result, GenerationBatchResult)
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
@@ -1302,7 +1306,7 @@ class Scheduler:
                 if self.is_mixed_chunk and self.enable_overlap and req.finished():
                     # Free the one delayed token for the mixed decode batch
                     j = len(batch.out_cache_loc) - len(batch.reqs) + i
-                    self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
                     continue
                 if req.is_chunked <= 0:
@@ -1420,23 +1424,27 @@ class Scheduler:
         self.num_generated_tokens += len(batch.reqs)
         if self.enable_overlap:
+            assert batch.spec_algorithm.is_none()
             logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
-        else:
+        elif batch.spec_algorithm.is_none():
+            # spec decoding handles output logprobs inside verify process.
             next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 next_token_logprobs = logits_output.next_token_logprobs.tolist()
-        self.token_to_kv_pool.free_group_begin()
+        self.token_to_kv_pool_allocator.free_group_begin()
         # Check finish condition
+        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
+        # We should ignore using next_token_ids for spec decoding cases.
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue
             if self.enable_overlap and req.finished():
                 # Free the one delayed token
-                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
+                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
                 continue
+            if batch.spec_algorithm.is_none():
@@ -1479,7 +1487,7 @@ class Scheduler:
             batch.next_batch_sampling_info.sampling_info_done.set()
         self.stream_output(batch.reqs, batch.return_logprob)
-        self.token_to_kv_pool.free_group_end()
+        self.token_to_kv_pool_allocator.free_group_end()
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (
@@ -1718,9 +1726,6 @@ class Scheduler:
                     and not self.model_config.is_multimodal_gen
                 )
             ):
-                if self.draft_worker and req.finished():
-                    self.draft_worker.finish_request(req)
-
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
@@ -1860,7 +1865,7 @@ class Scheduler:
         idle_batch = ScheduleBatch.init_new(
             [],
             self.req_to_token_pool,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
@@ -1916,11 +1921,11 @@ class Scheduler:
         if self.grammar_backend:
             self.grammar_backend.reset()
         self.req_to_token_pool.clear()
-        self.token_to_kv_pool.clear()
+        self.token_to_kv_pool_allocator.clear()
         if not self.spec_algorithm.is_none():
             self.draft_worker.model_runner.req_to_token_pool.clear()
-            self.draft_worker.model_runner.token_to_kv_pool.clear()
+            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
         self.num_generated_tokens = 0
         self.forward_ct_decode = 0

View File

@@ -82,8 +82,6 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -257,9 +255,6 @@ class TokenizerManager:
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
         self._result_dispatcher = TypeBasedDispatcher(
             [
@@ -309,10 +304,6 @@ class TokenizerManager:
                 (
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
@@ -774,14 +765,6 @@ class TokenizerManager:
         )
         return res[0].internal_state
-    async def set_internal_state(
-        self, obj: SetInternalStateReq
-    ) -> SetInternalStateReqOutput:
-        res: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return res[0]
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None

View File

@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -49,6 +50,8 @@ class TpModelWorker:
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.tp_rank = tp_rank
@@ -77,6 +80,8 @@ class TpModelWorker:
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
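
The two new optional arguments are what let the EAGLE draft worker reuse the
target worker's request slots and KV slot allocator instead of allocating its
own. A hypothetical wiring sketch follows; the draft worker's actual
construction site is not part of this diff, and TpModelWorker's parameters
other than the keywords shown in the hunks above (e.g. gpu_id) are assumptions:

# Hypothetical helper; in the real code the EAGLE worker does this internally.
from typing import Optional

from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs


def build_shared_pool_draft_worker(
    target_worker: TpModelWorker,
    draft_server_args: ServerArgs,  # points at the draft model
    gpu_id: int,
    tp_rank: int,
    nccl_port: int,
    dp_rank: Optional[int] = None,
) -> TpModelWorker:
    # Reuse the target worker's request slots and KV slot allocator.
    req_to_token_pool, token_to_kv_pool_allocator = target_worker.get_memory_pool()
    return TpModelWorker(
        server_args=draft_server_args,
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        dp_rank=dp_rank,
        nccl_port=nccl_port,
        is_draft_worker=True,
        req_to_token_pool=req_to_token_pool,
        token_to_kv_pool_allocator=token_to_kv_pool_allocator,
    )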
@@ -154,7 +159,7 @@ class TpModelWorker:
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
-            self.model_runner.token_to_kv_pool,
+            self.model_runner.token_to_kv_pool_allocator,
         )
     def forward_batch_generation(

View File

@@ -100,7 +100,7 @@ class TpModelWorkerClient:
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
-            self.worker.model_runner.token_to_kv_pool,
+            self.worker.model_runner.token_to_kv_pool_allocator,
         )
     def forward_thread_func(self):