Simplify prepare_extend_after_decode (#6987)
This commit is contained in:
@@ -23,6 +23,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
||||
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
||||
EAGLEDraftCudaGraphRunner,
|
||||
)
|
||||
@@ -69,7 +70,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.server_args = server_args
|
||||
self.topk = server_args.speculative_eagle_topk
|
||||
self.speculative_num_steps = server_args.speculative_num_steps
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.enable_nan_detection = server_args.enable_nan_detection
|
||||
self.gpu_id = gpu_id
|
||||
self.device = server_args.device
|
||||
@@ -78,6 +78,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
self.padded_static_len = -1
|
||||
|
||||
# Override context length with target model's context length
|
||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||
@@ -184,7 +185,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = True
|
||||
elif self.server_args.attention_backend == "triton":
|
||||
from sglang.srt.layers.attention.triton_backend import (
|
||||
@@ -201,7 +201,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "fa3":
|
||||
from sglang.srt.layers.attention.flashattention_backend import (
|
||||
@@ -218,7 +217,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "flashmla":
|
||||
from sglang.srt.layers.attention.flashmla_backend import (
|
||||
@@ -231,7 +229,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -319,10 +316,12 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
|
||||
else:
|
||||
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
|
||||
logits_output, next_token_ids, bid, seq_lens_cpu = (
|
||||
self.forward_target_extend(batch)
|
||||
)
|
||||
with self.draft_tp_context(self.draft_model_runner.tp_group):
|
||||
self.forward_draft_extend(
|
||||
batch, logits_output.hidden_states, next_token_ids
|
||||
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
|
||||
)
|
||||
return logits_output, next_token_ids, bid, 0, False
|
||||
|
||||
@@ -346,7 +345,12 @@ class EAGLEWorker(TpModelWorker):
|
||||
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
return logits_output, next_token_ids, model_worker_batch.bid
|
||||
return (
|
||||
logits_output,
|
||||
next_token_ids,
|
||||
model_worker_batch.bid,
|
||||
model_worker_batch.seq_lens_cpu,
|
||||
)
|
||||
|
||||
def draft(self, batch: ScheduleBatch):
|
||||
# Parse args
|
||||
@@ -452,7 +456,14 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
|
||||
|
||||
ret = EagleVerifyInput.create(
|
||||
(
|
||||
tree_mask,
|
||||
position,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
draft_tokens,
|
||||
) = build_tree_kernel_efficient(
|
||||
spec_info.verified_id,
|
||||
score_list,
|
||||
token_list,
|
||||
@@ -463,7 +474,22 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.speculative_num_steps,
|
||||
self.server_args.speculative_num_draft_tokens,
|
||||
)
|
||||
return ret
|
||||
|
||||
return EagleVerifyInput(
|
||||
draft_token=draft_tokens,
|
||||
custom_mask=tree_mask,
|
||||
positions=position,
|
||||
retrive_index=retrive_index,
|
||||
retrive_next_token=retrive_next_token,
|
||||
retrive_next_sibling=retrive_next_sibling,
|
||||
retrive_cum_len=None,
|
||||
spec_steps=self.speculative_num_steps,
|
||||
topk=self.topk,
|
||||
draft_token_num=self.server_args.speculative_num_draft_tokens,
|
||||
capture_hidden_mode=CaptureHiddenMode.FULL,
|
||||
seq_lens_sum=forward_batch.seq_lens_sum,
|
||||
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||
)
|
||||
|
||||
def draft_forward(self, forward_batch: ForwardBatch):
|
||||
# Parse args
|
||||
@@ -523,7 +549,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
spec_info.prepare_for_verify(batch, self.page_size)
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
batch.spec_info = spec_info
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=spec_info.seq_lens_cpu
|
||||
)
|
||||
|
||||
if batch.has_grammar:
|
||||
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
|
||||
@@ -650,6 +678,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch: ScheduleBatch,
|
||||
hidden_states: torch.Tensor,
|
||||
next_token_ids: List[int],
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
):
|
||||
"""Run draft model extend. This API modifies the states of the batch.
|
||||
|
||||
@@ -664,7 +693,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
batch.spec_info.prepare_for_extend(batch)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
model_worker_batch = batch.get_model_worker_batch(
|
||||
seq_lens_cpu_cache=seq_lens_cpu
|
||||
)
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
@@ -683,19 +714,18 @@ class EAGLEWorker(TpModelWorker):
|
||||
return_logprob_backup = batch.return_logprob
|
||||
|
||||
# Prepare metadata
|
||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||
batch.spec_info.prepare_extend_after_decode(
|
||||
batch,
|
||||
self.speculative_num_steps,
|
||||
self.server_args.context_length,
|
||||
pad_input=self.cuda_graph_runner_for_draft_extend is not None,
|
||||
)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
batch.return_logprob = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
if forward_batch.seq_lens_cpu is not None:
|
||||
forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item()
|
||||
else:
|
||||
forward_batch.seq_lens_sum = batch.seq_lens.sum().item()
|
||||
|
||||
# Run
|
||||
can_cuda_graph = (
|
||||
@@ -706,14 +736,19 @@ class EAGLEWorker(TpModelWorker):
|
||||
logits_output = self.cuda_graph_runner_for_draft_extend.replay(
|
||||
forward_batch
|
||||
)
|
||||
forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
|
||||
logits_output.topk_p,
|
||||
logits_output.topk_index,
|
||||
)
|
||||
forward_batch.spec_info.hidden_states = logits_output.hidden_states
|
||||
else:
|
||||
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
|
||||
logits_output = self.draft_model_runner.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
|
||||
# Restore backup.
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
|
||||
Reference in New Issue
Block a user