[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
@@ -164,7 +164,7 @@ class Scheduler:
|
||||
self.server_args.speculative_num_draft_tokens
|
||||
+ (
|
||||
self.server_args.speculative_eagle_topk
|
||||
* self.server_args.speculative_num_steps
|
||||
* self.server_args.speculative_num_draft_tokens
|
||||
)
|
||||
)
|
||||
if not self.spec_algorithm.is_none()
|
||||
@@ -309,7 +309,9 @@ class Scheduler:
|
||||
)
|
||||
|
||||
# Init memory pool and cache
|
||||
self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
|
||||
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
||||
self.tp_worker.get_memory_pool()
|
||||
)
|
||||
|
||||
if (
|
||||
server_args.chunked_prefill_size is not None
|
||||
@@ -317,18 +319,18 @@ class Scheduler:
|
||||
):
|
||||
self.tree_cache = ChunkCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
else:
|
||||
if self.enable_hierarchical_cache:
|
||||
self.tree_cache = HiRadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
else:
|
||||
self.tree_cache = RadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
|
||||
@@ -458,7 +460,6 @@ class Scheduler:
|
||||
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
||||
(ProfileReq, self.profile),
|
||||
(GetInternalStateReq, self.get_internal_state),
|
||||
(SetInternalStateReq, self.set_internal_state),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -809,7 +810,8 @@ class Scheduler:
|
||||
running_bs: int,
|
||||
):
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
self._largest_prefill_len = max(
|
||||
self._largest_prefill_len, adder.log_input_tokens
|
||||
@@ -844,7 +846,8 @@ class Scheduler:
|
||||
self.num_generated_tokens = 0
|
||||
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
|
||||
if RECORD_STEP_TIME:
|
||||
@@ -894,7 +897,8 @@ class Scheduler:
|
||||
|
||||
def check_memory(self):
|
||||
available_size = (
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
protected_size = self.tree_cache.protected_size()
|
||||
memory_leak = available_size != (
|
||||
@@ -999,7 +1003,7 @@ class Scheduler:
|
||||
# Prefill policy
|
||||
adder = PrefillAdder(
|
||||
self.tree_cache,
|
||||
self.token_to_kv_pool,
|
||||
self.token_to_kv_pool_allocator,
|
||||
self.running_batch,
|
||||
self.new_token_ratio,
|
||||
self.max_prefill_tokens,
|
||||
@@ -1099,7 +1103,7 @@ class Scheduler:
|
||||
new_batch = ScheduleBatch.init_new(
|
||||
can_run_list,
|
||||
self.req_to_token_pool,
|
||||
self.token_to_kv_pool,
|
||||
self.token_to_kv_pool_allocator,
|
||||
self.tree_cache,
|
||||
self.model_config,
|
||||
self.enable_overlap,
|
||||
@@ -1143,8 +1147,6 @@ class Scheduler:
|
||||
|
||||
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
|
||||
self.new_token_ratio = new_token_ratio
|
||||
if self.draft_worker:
|
||||
self.draft_worker.finish_request(retracted_reqs)
|
||||
|
||||
logger.info(
|
||||
"Decode out of memory happened. "
|
||||
@@ -1184,11 +1186,12 @@ class Scheduler:
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
bid = model_worker_batch.bid
|
||||
else:
|
||||
(
|
||||
logits_output,
|
||||
next_token_ids,
|
||||
model_worker_batch,
|
||||
bid,
|
||||
num_accepted_tokens,
|
||||
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
||||
self.spec_num_total_accepted_tokens += (
|
||||
@@ -1214,7 +1217,7 @@ class Scheduler:
|
||||
next_token_ids=next_token_ids,
|
||||
extend_input_len_per_req=extend_input_len_per_req,
|
||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||
bid=model_worker_batch.bid,
|
||||
bid=bid,
|
||||
)
|
||||
else: # embedding or reward model
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
@@ -1230,6 +1233,7 @@ class Scheduler:
|
||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||
):
|
||||
if batch.forward_mode.is_decode():
|
||||
assert isinstance(result, GenerationBatchResult)
|
||||
self.process_batch_result_decode(batch, result)
|
||||
if batch.is_empty():
|
||||
self.running_batch = None
|
||||
@@ -1302,7 +1306,7 @@ class Scheduler:
|
||||
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
||||
# Free the one delayed token for the mixed decode batch
|
||||
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
||||
self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
|
||||
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
|
||||
continue
|
||||
|
||||
if req.is_chunked <= 0:
|
||||
@@ -1420,23 +1424,27 @@ class Scheduler:
|
||||
self.num_generated_tokens += len(batch.reqs)
|
||||
|
||||
if self.enable_overlap:
|
||||
assert batch.spec_algorithm.is_none()
|
||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
next_token_logprobs = logits_output.next_token_logprobs
|
||||
else:
|
||||
elif batch.spec_algorithm.is_none():
|
||||
# spec decoding handles output logprobs inside verify process.
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||
|
||||
self.token_to_kv_pool.free_group_begin()
|
||||
self.token_to_kv_pool_allocator.free_group_begin()
|
||||
|
||||
# Check finish condition
|
||||
# NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
|
||||
# We should ignore using next_token_ids for spec decoding cases.
|
||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||
if req.is_retracted:
|
||||
continue
|
||||
|
||||
if self.enable_overlap and req.finished():
|
||||
# Free the one delayed token
|
||||
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
||||
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
|
||||
continue
|
||||
|
||||
if batch.spec_algorithm.is_none():
|
||||
@@ -1479,7 +1487,7 @@ class Scheduler:
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
self.stream_output(batch.reqs, batch.return_logprob)
|
||||
|
||||
self.token_to_kv_pool.free_group_end()
|
||||
self.token_to_kv_pool_allocator.free_group_end()
|
||||
|
||||
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
|
||||
if (
|
||||
@@ -1718,9 +1726,6 @@ class Scheduler:
|
||||
and not self.model_config.is_multimodal_gen
|
||||
)
|
||||
):
|
||||
if self.draft_worker and req.finished():
|
||||
self.draft_worker.finish_request(req)
|
||||
|
||||
rids.append(req.rid)
|
||||
finished_reasons.append(
|
||||
req.finished_reason.to_json() if req.finished_reason else None
|
||||
@@ -1860,7 +1865,7 @@ class Scheduler:
|
||||
idle_batch = ScheduleBatch.init_new(
|
||||
[],
|
||||
self.req_to_token_pool,
|
||||
self.token_to_kv_pool,
|
||||
self.token_to_kv_pool_allocator,
|
||||
self.tree_cache,
|
||||
self.model_config,
|
||||
self.enable_overlap,
|
||||
@@ -1916,11 +1921,11 @@ class Scheduler:
|
||||
if self.grammar_backend:
|
||||
self.grammar_backend.reset()
|
||||
self.req_to_token_pool.clear()
|
||||
self.token_to_kv_pool.clear()
|
||||
self.token_to_kv_pool_allocator.clear()
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
self.draft_worker.model_runner.req_to_token_pool.clear()
|
||||
self.draft_worker.model_runner.token_to_kv_pool.clear()
|
||||
self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
|
||||
|
||||
self.num_generated_tokens = 0
|
||||
self.forward_ct_decode = 0
|
||||
|
||||
Reference in New Issue
Block a user