Revert "[EAGLE] Refactor code for page size > 1 & more simplifications" (#7210)

Lianmin Zheng
2025-06-15 02:48:00 -07:00
committed by GitHub
parent 5f1ab32717
commit fff10809bf
7 changed files with 150 additions and 647 deletions


@@ -35,17 +35,11 @@ from sglang.srt.speculative.eagle_utils import (
     EagleVerifyInput,
     EagleVerifyOutput,
     assign_draft_cache_locs,
-    fast_topk,
     generate_token_bitmask,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import (
-    empty_context,
-    get_available_gpu_memory,
-    is_cuda,
-    next_power_of_2,
-)
+from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
 
 if is_cuda():
     from sgl_kernel import segment_packbits
@@ -158,12 +152,6 @@ class EAGLEWorker(TpModelWorker):
         self.init_attention_backend()
         self.init_cuda_graphs()
 
-        # Some dummy tensors
-        self.num_new_pages_per_topk = torch.empty(
-            (), dtype=torch.int64, device=self.device
-        )
-        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
-
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
         if self.server_args.attention_backend == "flashinfer":
@@ -266,7 +254,7 @@ class EAGLEWorker(TpModelWorker):
             self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )
 
         # Capture extend
@@ -281,7 +269,7 @@ class EAGLEWorker(TpModelWorker):
             )
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
+            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )
 
     @property
@@ -302,6 +290,7 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepted,
             the batch id (used for overlap schedule), and number of accepted tokens.
         """
+
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
@@ -377,21 +366,14 @@ class EAGLEWorker(TpModelWorker):
         )
 
         # Allocate cache locations
-        # Layout of the out_cache_loc
-        # [        topk 0        ] [        topk 1        ]
-        # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
         if self.page_size == 1:
             out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
-                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
+                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
             )
         else:
             if self.topk == 1:
-                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
-                    batch.req_to_token_pool.req_to_token,
-                    batch.req_pool_indices,
-                    batch.seq_lens,
-                    self.speculative_num_steps,
-                )
+                prefix_lens = batch.seq_lens
+                seq_lens = prefix_lens + self.speculative_num_steps
                 extend_num_tokens = num_seqs * self.speculative_num_steps
             else:
                 # In this case, the last partial page needs to be duplicated.
@@ -404,33 +386,29 @@ class EAGLEWorker(TpModelWorker):
                 # "x" means speculative draft tokens
                 # "." means padded tokens
 
-                # TODO(lmzheng): The current implementation is still a fake support
-                # for page size > 1. In the `assign_draft_cache_locs` below,
-                # we directly move the indices instead of the real kv cache.
-                # This only works when the kernel backend runs with page size = 1.
-                # If the kernel backend runs with page size > 1, we need to
-                # duplicate the real KV cache. The overhead of duplicating KV
-                # cache seems okay because the draft KV cache only has one layer.
-                # see a related copy operation in MHATokenToKVPool::move_kv_cache.
-
-                (
-                    prefix_lens,
-                    seq_lens,
-                    last_loc,
-                    self.num_new_pages_per_topk,
-                    self.extend_lens,
-                ) = get_last_loc_large_page_size_large_top_k(
-                    batch.req_to_token_pool.req_to_token,
-                    batch.req_pool_indices,
-                    batch.seq_lens,
-                    self.speculative_num_steps,
-                    self.topk,
-                    self.page_size,
+                # TODO: fuse these ops
+                prefix_lens = batch.seq_lens
+                last_page_lens = prefix_lens % self.page_size
+                num_new_pages = (
+                    last_page_lens + self.speculative_num_steps + self.page_size - 1
+                ) // self.page_size
+                seq_lens = (
+                    prefix_lens // self.page_size * self.page_size
+                    + num_new_pages * (self.page_size * self.topk)
                 )
+                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
+                raise NotImplementedError(
+                    "page_size > 1 and top_k > 1 are not supported."
+                )
+                # TODO: Support page_size > 1 and top_k > 1
+                # 1. Duplicate the KV cache in the last partial page for all top-k segments
+                # 2. Modify generate_draft_decode_kv_indices accordingly
 
-                # TODO(lmzheng): remove this device sync
-                extend_num_tokens = torch.sum(self.extend_lens).item()
-
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
             out_cache_loc, token_to_kv_pool_state_backup = (
                 batch.alloc_paged_token_slots_extend(
                     prefix_lens,
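A small illustrative trace of the page-alignment arithmetic restored above. The sizes below (page_size=4, 3 draft steps, top-k=2, one request with a 7-token prefix) are made up for the example, not taken from the commit:

```python
import torch

page_size, speculative_num_steps, topk = 4, 3, 2
prefix_lens = torch.tensor([7])  # one request, 7 prefix tokens (illustrative)

last_page_lens = prefix_lens % page_size                      # tensor([3]): 3 tokens on the partial page
num_new_pages = (
    last_page_lens + speculative_num_steps + page_size - 1
) // page_size                                                # ceil((3 + 3) / 4) = tensor([2])
seq_lens = (
    prefix_lens // page_size * page_size                      # fully used pages: 4 tokens
    + num_new_pages * (page_size * topk)                      # 2 pages * 4 slots * 2 branches = 16
)                                                             # tensor([20])
extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()  # 13 slots, padding included
```

The partial page is counted once per top-k branch, which is exactly the duplication the TODO describes; until that copy is implemented, the branch raises NotImplementedError.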
@@ -445,31 +423,19 @@ class EAGLEWorker(TpModelWorker):
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
-            self.extend_lens,
-            self.num_new_pages_per_topk,
             out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
             self.page_size,
-            next_power_of_2(num_seqs),
-            next_power_of_2(self.speculative_num_steps),
         )
-
-        if self.page_size > 1 and self.topk > 1:
-            # Remove padded slots
-            out_cache_loc = out_cache_loc[
-                : num_seqs * self.topk * self.speculative_num_steps
-            ]
-
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
-        batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
-        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-
         # Get forward batch
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
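One line restored above is easy to gloss over: `spec_info.positions` starts each of a request's top-k draft branches at that request's current length. A minimal sketch with made-up lengths:

```python
import torch

seq_lens = torch.tensor([7, 12])  # two requests (illustrative lengths)
topk = 2

# Each request contributes topk draft branches, and every branch begins
# decoding at that request's current sequence length.
positions = seq_lens.repeat_interleave(topk, dim=0)
print(positions)  # tensor([ 7,  7, 12, 12])
```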
@@ -538,13 +504,6 @@ class EAGLEWorker(TpModelWorker):
         if self.hot_token_id is not None:
             topk_index = self.hot_token_id[topk_index]
 
-        out_cache_loc = out_cache_loc.reshape(
-            forward_batch.batch_size, self.topk, self.speculative_num_steps
-        )
-        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
-            self.speculative_num_steps, -1
-        )
-
         # Return values
         score_list: List[torch.Tensor] = []
         token_list: List[torch.Tensor] = []
@@ -566,7 +525,10 @@ class EAGLEWorker(TpModelWorker):
 
             # Set inputs
             forward_batch.input_ids = input_ids
-            forward_batch.out_cache_loc = out_cache_loc[i:]
+            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
+            forward_batch.out_cache_loc = out_cache_loc[
+                :, self.topk * i : self.topk * (i + 1)
+            ].flatten()
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
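The removed and restored indexing assume different flat layouts of `out_cache_loc`, each produced by the matching `assign_draft_cache_locs` kernel: the refactor grouped a request's slots by top-k branch, while the restored code groups them by draft step. A toy sketch (slot values are just labels) of how each version selects step `i`'s slots:

```python
import torch

bs, topk, steps = 1, 2, 3
i = 1  # draft step of interest

# Removed layout: slots contiguous per top-k branch,
# e.g. [t0s0, t0s1, t0s2, t1s0, t1s1, t1s2] for each request.
loc_topk_major = torch.arange(bs * topk * steps)
per_step = loc_topk_major.reshape(bs, topk, steps).permute((2, 0, 1)).reshape(steps, -1)
print(per_step[i])  # step 1 slots for all (request, branch) pairs: tensor([1, 4])
# The removed loop handed the backend the tail per_step[i:], whose first row is step i.

# Restored layout: slots contiguous per step,
# e.g. [s0t0, s0t1, s1t0, s1t1, s2t0, s2t1] for each request.
loc_step_major = torch.arange(bs * topk * steps).view(bs, -1)
print(loc_step_major[:, topk * i : topk * (i + 1)].flatten())  # step 1 slots: tensor([2, 3])
```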
@@ -624,7 +586,7 @@ class EAGLEWorker(TpModelWorker):
         if vocab_mask is not None:
             assert spec_info.grammar is not None
             vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
-            # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
+            # otherwise, this vocab mask will be the one from the previous extend stage
             # and will be applied to produce wrong results
             batch.sampling_info.vocab_mask = None
 
@@ -645,13 +607,13 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
-        if batch.return_logprob:
-            self.add_logprob_values(batch, res, logits_output)
-
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input
 
+        if batch.return_logprob:
+            self.add_logprob_values(batch, res, logits_output)
+
         return logits_output, res, model_worker_batch, can_run_cuda_graph
 
     def add_logprob_values(
@@ -664,16 +626,8 @@ class EAGLEWorker(TpModelWorker):
         logits_output = res.logits_output
         top_logprobs_nums = batch.top_logprobs_nums
         token_ids_logprobs = batch.token_ids_logprobs
-        accepted_indices = res.accepted_indices
-        assert len(accepted_indices) == len(logits_output.next_token_logits)
-        temperatures = batch.sampling_info.temperatures
-        num_draft_tokens = batch.spec_info.draft_token_num
-        # acceptance indices are the indices in a "flattened" batch.
-        # dividing it to num_draft_tokens will yield the actual batch index.
-        temperatures = temperatures[accepted_indices // num_draft_tokens]
-
         logprobs = torch.nn.functional.log_softmax(
-            logits_output.next_token_logits / temperatures, dim=-1
+            logits_output.next_token_logits, dim=-1
         )
         batch_next_token_ids = res.verified_id
         num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
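The deleted lines divided each accepted token's logits by its request's sampling temperature before `log_softmax`; the integer division maps an index in the flattened (request × draft position) batch back to its owning request. A toy sketch of that arithmetic with invented shapes:

```python
import torch

num_draft_tokens = 3                           # candidate positions per request
temperatures = torch.tensor([[0.5], [1.0]])    # (batch, 1) per-request temperatures
accepted_indices = torch.tensor([0, 1, 3])     # flattened indices of accepted tokens

req_idx = accepted_indices // num_draft_tokens  # tensor([0, 0, 1]): owning request
per_token_temp = temperatures[req_idx]          # (3, 1), broadcasts over the vocab dim

logits = torch.randn(3, 8)                      # (accepted tokens, vocab), dummy values
logprobs = torch.nn.functional.log_softmax(logits / per_token_temp, dim=-1)
```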
@@ -708,7 +662,7 @@ class EAGLEWorker(TpModelWorker):
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
-        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
             for _ in range(num_tokens):
                 if req.return_logprob:
                     req.output_token_logprobs_val.append(next_token_logprobs[pt])
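Note the revert also drops `strict=True` (Python 3.10+), so a length mismatch between `batch.reqs` and `num_tokens_per_req` would again be truncated silently rather than raising:

```python
list(zip([1, 2], ["a"]))               # [(1, 'a')] -- the extra element is dropped silently
list(zip([1, 2], ["a"], strict=True))  # ValueError: zip() argument 2 is shorter than argument 1
```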
@@ -736,6 +690,7 @@ class EAGLEWorker(TpModelWorker):
             hidden_states: Hidden states from the target model forward
             next_token_ids: Next token ids generated from the target forward.
         """
+        # Sometimes we get hidden states produced by CaptureHiddenMode.FULL, so we have to select just the last
         batch.spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             verified_id=next_token_ids,
@@ -746,6 +701,7 @@ class EAGLEWorker(TpModelWorker):
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -768,7 +724,9 @@ class EAGLEWorker(TpModelWorker):
             batch,
             self.speculative_num_steps,
         )
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -832,47 +790,3 @@ def load_token_map(token_map_path: str) -> List[int]:
         token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
     hot_token_id = torch.load(token_map_path, weights_only=True)
     return torch.tensor(hot_token_id, dtype=torch.int32)
-
-
-@torch.compile(dynamic=True)
-def get_last_loc_large_page_size_top_k_1(
-    req_to_token: torch.Tensor,
-    req_pool_indices: torch.Tensor,
-    seq_lens,
-    speculative_num_steps: int,
-):
-    prefix_lens = seq_lens
-    seq_lens = prefix_lens + speculative_num_steps
-    last_loc = get_last_loc(
-        req_to_token,
-        req_pool_indices,
-        prefix_lens,
-    )
-    return prefix_lens, seq_lens, last_loc
-
-
-@torch.compile(dynamic=True)
-def get_last_loc_large_page_size_large_top_k(
-    req_to_token: torch.Tensor,
-    req_pool_indices: torch.Tensor,
-    seq_lens: torch.Tensor,
-    speculative_num_steps: int,
-    topk: int,
-    page_size: int,
-):
-    prefix_lens = seq_lens
-    last_page_lens = prefix_lens % page_size
-    num_new_pages_per_topk = (
-        last_page_lens + speculative_num_steps + page_size - 1
-    ) // page_size
-    seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
-        page_size * topk
-    )
-    extend_lens = seq_lens - prefix_lens
-    last_loc = get_last_loc(
-        req_to_token,
-        req_pool_indices,
-        prefix_lens,
-    )
-    return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
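For contrast with the inline math restored in the draft path, here is a worked trace of the removed per-top-k helper, with illustrative numbers only (page_size=16, 5 draft steps, top-k=4, a 21-token prefix):

```python
import torch

page_size, speculative_num_steps, topk = 16, 5, 4
prefix_lens = torch.tensor([21])

last_page_lens = prefix_lens % page_size              # tensor([5])
num_new_pages_per_topk = (
    last_page_lens + speculative_num_steps + page_size - 1
) // page_size                                        # ceil((5 + 5) / 16) = tensor([1])
seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
    page_size * topk
)                                                     # 16 + 1 * 64 = tensor([80])
extend_lens = seq_lens - prefix_lens                  # tensor([59])
# The partial page (tokens 16..20) plus the 5 draft tokens fit in one new page,
# and that page is allocated once for each of the 4 top-k branches.
```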