Simplify prepare_extend_after_decode (#6987)

This commit is contained in:
Lianmin Zheng
2025-06-09 16:39:21 -07:00
committed by GitHub
parent a968c888c0
commit dc0705a504
9 changed files with 140 additions and 176 deletions

View File

@@ -23,6 +23,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
)
@@ -69,7 +70,6 @@ class EAGLEWorker(TpModelWorker):
self.server_args = server_args
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.padded_static_len = self.speculative_num_steps + 1
self.enable_nan_detection = server_args.enable_nan_detection
self.gpu_id = gpu_id
self.device = server_args.device
@@ -78,6 +78,7 @@ class EAGLEWorker(TpModelWorker):
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.padded_static_len = -1
# Override context length with target model's context length
server_args.context_length = target_worker.model_runner.model_config.context_len
@@ -184,7 +185,6 @@ class EAGLEWorker(TpModelWorker):
self.draft_model_runner,
skip_prefill=False,
)
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = True
elif self.server_args.attention_backend == "triton":
from sglang.srt.layers.attention.triton_backend import (
@@ -201,7 +201,6 @@ class EAGLEWorker(TpModelWorker):
self.draft_model_runner,
skip_prefill=False,
)
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "fa3":
from sglang.srt.layers.attention.flashattention_backend import (
@@ -218,7 +217,6 @@ class EAGLEWorker(TpModelWorker):
self.draft_model_runner,
skip_prefill=False,
)
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "flashmla":
from sglang.srt.layers.attention.flashmla_backend import (
@@ -231,7 +229,6 @@ class EAGLEWorker(TpModelWorker):
self.speculative_num_steps,
)
self.draft_extend_attn_backend = None
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = False
else:
raise ValueError(
@@ -319,10 +316,12 @@ class EAGLEWorker(TpModelWorker):
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
else:
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
logits_output, next_token_ids, bid, seq_lens_cpu = (
self.forward_target_extend(batch)
)
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
)
return logits_output, next_token_ids, bid, 0, False
@@ -346,7 +345,12 @@ class EAGLEWorker(TpModelWorker):
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
model_worker_batch
)
return logits_output, next_token_ids, model_worker_batch.bid
return (
logits_output,
next_token_ids,
model_worker_batch.bid,
model_worker_batch.seq_lens_cpu,
)
def draft(self, batch: ScheduleBatch):
# Parse args
@@ -452,7 +456,14 @@ class EAGLEWorker(TpModelWorker):
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
ret = EagleVerifyInput.create(
(
tree_mask,
position,
retrive_index,
retrive_next_token,
retrive_next_sibling,
draft_tokens,
) = build_tree_kernel_efficient(
spec_info.verified_id,
score_list,
token_list,
@@ -463,7 +474,22 @@ class EAGLEWorker(TpModelWorker):
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
)
return ret
return EagleVerifyInput(
draft_token=draft_tokens,
custom_mask=tree_mask,
positions=position,
retrive_index=retrive_index,
retrive_next_token=retrive_next_token,
retrive_next_sibling=retrive_next_sibling,
retrive_cum_len=None,
spec_steps=self.speculative_num_steps,
topk=self.topk,
draft_token_num=self.server_args.speculative_num_draft_tokens,
capture_hidden_mode=CaptureHiddenMode.FULL,
seq_lens_sum=forward_batch.seq_lens_sum,
seq_lens_cpu=forward_batch.seq_lens_cpu,
)
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
@@ -523,7 +549,9 @@ class EAGLEWorker(TpModelWorker):
spec_info.prepare_for_verify(batch, self.page_size)
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
)
if batch.has_grammar:
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
@@ -650,6 +678,7 @@ class EAGLEWorker(TpModelWorker):
batch: ScheduleBatch,
hidden_states: torch.Tensor,
next_token_ids: List[int],
seq_lens_cpu: torch.Tensor,
):
"""Run draft model extend. This API modifies the states of the batch.
@@ -664,7 +693,9 @@ class EAGLEWorker(TpModelWorker):
)
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -683,19 +714,18 @@ class EAGLEWorker(TpModelWorker):
return_logprob_backup = batch.return_logprob
# Prepare metadata
batch.forward_mode = ForwardMode.DRAFT_EXTEND
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
self.server_args.context_length,
pad_input=self.cuda_graph_runner_for_draft_extend is not None,
)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
batch.return_logprob = False
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
if forward_batch.seq_lens_cpu is not None:
forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item()
else:
forward_batch.seq_lens_sum = batch.seq_lens.sum().item()
# Run
can_cuda_graph = (
@@ -706,14 +736,19 @@ class EAGLEWorker(TpModelWorker):
logits_output = self.cuda_graph_runner_for_draft_extend.replay(
forward_batch
)
forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
logits_output.topk_p,
logits_output.topk_index,
)
forward_batch.spec_info.hidden_states = logits_output.hidden_states
else:
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
self.capture_for_decode(logits_output, forward_batch.spec_info)
self._detect_nan_if_needed(logits_output)
self.capture_for_decode(logits_output, forward_batch.spec_info)
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`