Support page size > 1 + eagle (#4908)

This commit is contained in:
Lianmin Zheng
2025-03-30 00:46:23 -07:00
committed by GitHub
parent 5ec5eaf760
commit b26bc86b36
16 changed files with 374 additions and 71 deletions

View File

@@ -11,7 +11,7 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
@@ -67,6 +67,7 @@ class EAGLEWorker(TpModelWorker):
self.gpu_id = gpu_id
self.device = server_args.device
self.target_worker = target_worker
self.page_size = server_args.page_size
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
@@ -234,14 +235,11 @@ class EAGLEWorker(TpModelWorker):
"""
if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info, to_free_cache_loc = self.draft(batch)
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch = self.verify(
batch, spec_info
)
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
# If it is None, it means all requests are finished
if batch.spec_info.verified_id is not None:
with self.draft_tp_context(self.draft_model_runner.tp_group):
@@ -305,9 +303,59 @@ class EAGLEWorker(TpModelWorker):
)
# Allocate cache locations
out_cache_loc = batch.alloc_token_slots(
num_seqs * self.topk * self.speculative_num_steps
)
if self.page_size == 1:
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
num_seqs * self.topk * self.speculative_num_steps, backup_state=True
)
else:
if self.topk == 1:
prefix_lens = batch.seq_lens
seq_lens = prefix_lens + self.speculative_num_steps
extend_num_tokens = num_seqs * self.speculative_num_steps
else:
# In this case, the last partial page needs to be duplicated.
# KV cache layout in batch.req_to_token_pool.req_to_token:
#
                    # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
                    #    prefix    top-k = 0    top-k = 1    top-k = 2
#
# "-" means prefix tokens
# "x" means speculative draft tokens
# "." means padded tokens
# TODO: fuse these ops
prefix_lens = batch.seq_lens
last_page_lens = prefix_lens % self.page_size
num_new_pages = (
last_page_lens + self.speculative_num_steps + self.page_size - 1
) // self.page_size
seq_lens = (
prefix_lens // self.page_size * self.page_size
+ num_new_pages * (self.page_size * self.topk)
)
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
raise NotImplementedError(
"page_size > 1 and top_k > 1 are not supported."
)
# TODO: Support page_size > 1 and top_k > 1
# 1. Duplicate the KV cache in the last partial page for all top-k segments
# 2. Modify generate_draft_decode_kv_indices accordingly
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
prefix_lens,
)
out_cache_loc, token_to_kv_pool_state_backup = (
batch.alloc_paged_token_slots_extend(
prefix_lens,
seq_lens,
last_loc,
extend_num_tokens,
backup_state=True,
)
)
assign_draft_cache_locs[(num_seqs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
@@ -316,6 +364,7 @@ class EAGLEWorker(TpModelWorker):
batch.req_to_token_pool.req_to_token.shape[1],
self.topk,
self.speculative_num_steps,
self.page_size,
)
batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
@@ -343,6 +392,8 @@ class EAGLEWorker(TpModelWorker):
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
ret = EagleVerifyInput.create(
spec_info.verified_id,
score_list,
@@ -354,7 +405,7 @@ class EAGLEWorker(TpModelWorker):
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
)
return ret, out_cache_loc
return ret
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
@@ -411,7 +462,7 @@ class EAGLEWorker(TpModelWorker):
return score_list, token_list, parents_list
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch)
spec_info.prepare_for_verify(batch, self.page_size)
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch()
@@ -421,7 +472,10 @@ class EAGLEWorker(TpModelWorker):
self._detect_nan_if_needed(logits_output)
spec_info.hidden_states = logits_output.hidden_states
res: EagleVerifyOutput = spec_info.verify(
batch, logits_output, self.token_to_kv_pool_allocator
batch,
logits_output,
self.token_to_kv_pool_allocator,
self.page_size,
)
# Post process based on verified outputs.