[Eagle] Refactor eagle speculative decoding (#3986)

Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
Ying Sheng
2025-03-05 08:06:07 -08:00
committed by GitHub
parent 5be8f1ed98
commit d3d4d76758
22 changed files with 670 additions and 352 deletions

View File

@@ -164,7 +164,7 @@ class Scheduler:
self.server_args.speculative_num_draft_tokens
+ (
self.server_args.speculative_eagle_topk
* self.server_args.speculative_num_steps
* self.server_args.speculative_num_draft_tokens
)
)
if not self.spec_algorithm.is_none()
@@ -309,7 +309,9 @@ class Scheduler:
)
# Init memory pool and cache
self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
self.tp_worker.get_memory_pool()
)
if (
server_args.chunked_prefill_size is not None
@@ -317,18 +319,18 @@ class Scheduler:
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
if self.enable_hierarchical_cache:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
disable=server_args.disable_radix_cache,
)
@@ -458,7 +460,6 @@ class Scheduler:
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(ProfileReq, self.profile),
(GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state),
]
)
@@ -809,7 +810,8 @@ class Scheduler:
running_bs: int,
):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
self._largest_prefill_len = max(
self._largest_prefill_len, adder.log_input_tokens
@@ -844,7 +846,8 @@ class Scheduler:
self.num_generated_tokens = 0
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
if RECORD_STEP_TIME:
@@ -894,7 +897,8 @@ class Scheduler:
def check_memory(self):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (
@@ -999,7 +1003,7 @@ class Scheduler:
# Prefill policy
adder = PrefillAdder(
self.tree_cache,
self.token_to_kv_pool,
self.token_to_kv_pool_allocator,
self.running_batch,
self.new_token_ratio,
self.max_prefill_tokens,
@@ -1099,7 +1103,7 @@ class Scheduler:
new_batch = ScheduleBatch.init_new(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool,
self.token_to_kv_pool_allocator,
self.tree_cache,
self.model_config,
self.enable_overlap,
@@ -1143,8 +1147,6 @@ class Scheduler:
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
self.new_token_ratio = new_token_ratio
if self.draft_worker:
self.draft_worker.finish_request(retracted_reqs)
logger.info(
"Decode out of memory happened. "
@@ -1184,11 +1186,12 @@ class Scheduler:
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
bid = model_worker_batch.bid
else:
(
logits_output,
next_token_ids,
model_worker_batch,
bid,
num_accepted_tokens,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.spec_num_total_accepted_tokens += (
@@ -1214,7 +1217,7 @@ class Scheduler:
next_token_ids=next_token_ids,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
bid=model_worker_batch.bid,
bid=bid,
)
else: # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()
@@ -1230,6 +1233,7 @@ class Scheduler:
result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
if batch.forward_mode.is_decode():
assert isinstance(result, GenerationBatchResult)
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
@@ -1302,7 +1306,7 @@ class Scheduler:
if self.is_mixed_chunk and self.enable_overlap and req.finished():
# Free the one delayed token for the mixed decode batch
j = len(batch.out_cache_loc) - len(batch.reqs) + i
self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
continue
if req.is_chunked <= 0:
@@ -1420,23 +1424,27 @@ class Scheduler:
self.num_generated_tokens += len(batch.reqs)
if self.enable_overlap:
assert batch.spec_algorithm.is_none()
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
next_token_logprobs = logits_output.next_token_logprobs
else:
elif batch.spec_algorithm.is_none():
# Spec decoding handles output logprobs inside the verify process.
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist()
self.token_to_kv_pool.free_group_begin()
self.token_to_kv_pool_allocator.free_group_begin()
# Check finish condition
# NOTE: the lengths of reqs and next_token_ids don't match under spec decoding.
# next_token_ids should be ignored in spec decoding cases.
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if req.is_retracted:
continue
if self.enable_overlap and req.finished():
# Free the one delayed token
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
continue
if batch.spec_algorithm.is_none():
@@ -1479,7 +1487,7 @@ class Scheduler:
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs, batch.return_logprob)
self.token_to_kv_pool.free_group_end()
self.token_to_kv_pool_allocator.free_group_end()
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if (
@@ -1718,9 +1726,6 @@ class Scheduler:
and not self.model_config.is_multimodal_gen
)
):
if self.draft_worker and req.finished():
self.draft_worker.finish_request(req)
rids.append(req.rid)
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
@@ -1860,7 +1865,7 @@ class Scheduler:
idle_batch = ScheduleBatch.init_new(
[],
self.req_to_token_pool,
self.token_to_kv_pool,
self.token_to_kv_pool_allocator,
self.tree_cache,
self.model_config,
self.enable_overlap,
@@ -1916,11 +1921,11 @@ class Scheduler:
if self.grammar_backend:
self.grammar_backend.reset()
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()
self.token_to_kv_pool_allocator.clear()
if not self.spec_algorithm.is_none():
self.draft_worker.model_runner.req_to_token_pool.clear()
self.draft_worker.model_runner.token_to_kv_pool.clear()
self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
self.num_generated_tokens = 0
self.forward_ct_decode = 0