Support page size > 1 + eagle (#4908)
This commit is contained in:
@@ -11,7 +11,7 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
|
||||
from sglang.srt.layers.dp_attention import disable_dp_size
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
@@ -67,6 +67,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.gpu_id = gpu_id
|
||||
self.device = server_args.device
|
||||
self.target_worker = target_worker
|
||||
self.page_size = server_args.page_size
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
@@ -234,14 +235,11 @@ class EAGLEWorker(TpModelWorker):
|
||||
"""
|
||||
if batch.forward_mode.is_decode():
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
spec_info, to_free_cache_loc = self.draft(batch)
|
||||
spec_info = self.draft(batch)
|
||||
logits_output, verify_output, model_worker_batch = self.verify(
|
||||
batch, spec_info
|
||||
)
|
||||
|
||||
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
|
||||
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
|
||||
|
||||
# If it is None, it means all requests are finished
|
||||
if batch.spec_info.verified_id is not None:
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
@@ -305,9 +303,59 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
|
||||
# Allocate cache locations
|
||||
out_cache_loc = batch.alloc_token_slots(
|
||||
num_seqs * self.topk * self.speculative_num_steps
|
||||
)
|
||||
if self.page_size == 1:
|
||||
out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
|
||||
num_seqs * self.topk * self.speculative_num_steps, backup_state=True
|
||||
)
|
||||
else:
|
||||
if self.topk == 1:
|
||||
prefix_lens = batch.seq_lens
|
||||
seq_lens = prefix_lens + self.speculative_num_steps
|
||||
extend_num_tokens = num_seqs * self.speculative_num_steps
|
||||
else:
|
||||
# In this case, the last partial page needs to be duplicated.
|
||||
# KV cache layout in batch.req_to_token_pool.req_to_token:
|
||||
#
|
||||
# | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
|
||||
# prefix top-k = 0 tok-k = 1 top-k = 2
|
||||
#
|
||||
# "-" means prefix tokens
|
||||
# "x" means speculative draft tokens
|
||||
# "." means padded tokens
|
||||
|
||||
# TODO: fuse these ops
|
||||
prefix_lens = batch.seq_lens
|
||||
last_page_lens = prefix_lens % self.page_size
|
||||
num_new_pages = (
|
||||
last_page_lens + self.speculative_num_steps + self.page_size - 1
|
||||
) // self.page_size
|
||||
seq_lens = (
|
||||
prefix_lens // self.page_size * self.page_size
|
||||
+ num_new_pages * (self.page_size * self.topk)
|
||||
)
|
||||
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
|
||||
raise NotImplementedError(
|
||||
"page_size > 1 and top_k > 1 are not supported."
|
||||
)
|
||||
# TODO: Support page_size > 1 and top_k > 1
|
||||
# 1. Duplicate the KV cache in the last partial page for all top-k segments
|
||||
# 2. Modify generate_draft_decode_kv_indices accordingly
|
||||
|
||||
last_loc = get_last_loc(
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.req_pool_indices,
|
||||
prefix_lens,
|
||||
)
|
||||
out_cache_loc, token_to_kv_pool_state_backup = (
|
||||
batch.alloc_paged_token_slots_extend(
|
||||
prefix_lens,
|
||||
seq_lens,
|
||||
last_loc,
|
||||
extend_num_tokens,
|
||||
backup_state=True,
|
||||
)
|
||||
)
|
||||
|
||||
assign_draft_cache_locs[(num_seqs,)](
|
||||
batch.req_pool_indices,
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
@@ -316,6 +364,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.req_to_token_pool.req_to_token.shape[1],
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
self.page_size,
|
||||
)
|
||||
batch.out_cache_loc = out_cache_loc
|
||||
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||
@@ -343,6 +392,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
# Run forward steps
|
||||
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
||||
|
||||
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
|
||||
|
||||
ret = EagleVerifyInput.create(
|
||||
spec_info.verified_id,
|
||||
score_list,
|
||||
@@ -354,7 +405,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.speculative_num_steps,
|
||||
self.server_args.speculative_num_draft_tokens,
|
||||
)
|
||||
return ret, out_cache_loc
|
||||
return ret
|
||||
|
||||
def draft_forward(self, forward_batch: ForwardBatch):
|
||||
# Parse args
|
||||
@@ -411,7 +462,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
return score_list, token_list, parents_list
|
||||
|
||||
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
|
||||
spec_info.prepare_for_verify(batch)
|
||||
spec_info.prepare_for_verify(batch, self.page_size)
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
batch.spec_info = spec_info
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
@@ -421,7 +472,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
spec_info.hidden_states = logits_output.hidden_states
|
||||
res: EagleVerifyOutput = spec_info.verify(
|
||||
batch, logits_output, self.token_to_kv_pool_allocator
|
||||
batch,
|
||||
logits_output,
|
||||
self.token_to_kv_pool_allocator,
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
# Post process based on verified outputs.
|
||||
|
||||
Reference in New Issue
Block a user