Fix a draft model accuracy bug in eagle; support step=1; return logprob in eagle (#4134)

Co-authored-by: Sehoon Kim <kssteven418@gmail.com>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
Lianmin Zheng
2025-03-06 06:13:59 -08:00
committed by GitHub
parent 3a3918121f
commit bc1534ff32
11 changed files with 304 additions and 106 deletions

View File

@@ -7,16 +7,14 @@ FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""
import math
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch
import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -37,7 +35,7 @@ if is_flashinfer_available():
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import PosEncodingMode
from flashinfer.decode import _get_range_buf, get_seq_lens
class WrapperDispatch(Enum):
@@ -73,8 +71,6 @@ class FlashInferAttnBackend(AttentionBackend):
):
super().__init__()
self.is_multimodal = model_runner.model_config.is_multimodal
# Parse constants
self.decode_use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -86,6 +82,7 @@ class FlashInferAttnBackend(AttentionBackend):
)
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
self.is_multimodal = model_runner.model_config.is_multimodal
assert not (
model_runner.sliding_window_size is not None
@@ -115,7 +112,6 @@ class FlashInferAttnBackend(AttentionBackend):
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer
max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = [
@@ -163,9 +159,11 @@ class FlashInferAttnBackend(AttentionBackend):
)
)
self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
)
)
self.decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer,
@@ -178,13 +176,14 @@ class FlashInferAttnBackend(AttentionBackend):
if not skip_prefill:
self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(
model_runner, self
)
) # for verify
self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {} # For verify
self.draft_extend_cuda_graph_metadata = {} # For draft extend
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
@@ -300,7 +299,6 @@ class FlashInferAttnBackend(AttentionBackend):
],
)
)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update(
req_pool_indices,
@@ -312,6 +310,10 @@ class FlashInferAttnBackend(AttentionBackend):
)
self.decode_cuda_graph_metadata[bs] = decode_wrappers
self.forward_metadata = DecodeMetadata(decode_wrappers)
for i in range(self.num_wrappers):
decode_wrappers[i].begin_forward = partial(
fast_decode_plan, decode_wrappers[i]
)
elif forward_mode.is_target_verify():
prefill_wrappers = []
for i in range(self.num_wrappers):
@@ -437,7 +439,7 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=False,
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
logits_soft_cap=logits_soft_cap,
)
o, _ = merge_state(o1, s1, o2, s2)
@@ -636,9 +638,15 @@ class FlashInferIndicesUpdaterDecode:
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
if wrapper.is_cuda_graph_enabled:
# Directly write to the cuda graph input buffer
kv_indices = wrapper._paged_kv_indices_buf
else:
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
@@ -649,9 +657,9 @@ class FlashInferIndicesUpdaterDecode:
self.req_to_token.shape[1],
)
else:
assert isinstance(spec_info, EagleDraftInput)
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
wrapper.begin_forward(
kv_indptr,
kv_indices,
@@ -699,7 +707,7 @@ class FlashInferIndicesUpdaterPrefill:
def update(
self,
req_pool_indices: torch.Tnesor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
@@ -713,7 +721,7 @@ class FlashInferIndicesUpdaterPrefill:
def update_single_wrapper(
self,
req_pool_indices: torch.Tnesor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
prefix_lens: torch.Tensor,
@@ -858,7 +866,6 @@ class FlashInferIndicesUpdaterPrefill:
kv_indices,
self.req_to_token.shape[1],
)
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
@@ -954,7 +961,10 @@ class FlashInferMultiStepDraftBackend:
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
def common_template(
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
self,
forward_batch: ForwardBatch,
kv_indices_buffer: torch.Tensor,
call_fn: Callable,
):
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs
@@ -1042,17 +1052,15 @@ class FlashInferMultiStepDraftBackend:
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[
forward_batch.batch_size
][0]
decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
forward_batch.batch_size,
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
@@ -1113,6 +1121,11 @@ def should_use_tensor_core(
return False
# Use as a fast path to override the indptr in flashinfer's plan function
# This is used to remove some host-to-device copy overhead.
global_override_indptr_cpu = None
def fast_decode_plan(
self,
indptr: torch.Tensor,
@@ -1142,6 +1155,9 @@ def fast_decode_plan(
if logits_soft_cap is None:
logits_soft_cap = 0.0
if self.use_tensor_cores:
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(
@@ -1154,7 +1170,7 @@ def fast_decode_plan(
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
# Skip these copies
# Skip these copies because we directly write to them during prepartion
# self._paged_kv_indptr_buf.copy_(indptr)
# self._paged_kv_indices_buf[: len(indices)] = indices
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
@@ -1162,6 +1178,7 @@ def fast_decode_plan(
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len
self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
if not q_data_type:
@@ -1184,27 +1201,55 @@ def fast_decode_plan(
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)
empty_q_data = self.empty_q_data
empty_kv_cache = self.empty_kv_cache
stream = torch.cuda.current_stream()
self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr.to("cpu"),
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
head_dim,
empty_q_data,
empty_kv_cache,
stream.cuda_stream,
indptr_host = (
global_override_indptr_cpu
if global_override_indptr_cpu is not None
else indptr.cpu()
)
if self.use_tensor_cores:
kv_lens_arr_host = get_seq_lens(
indptr_host, self.last_page_len[:batch_size], page_size
)
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_host,
indptr_host,
kv_lens_arr_host,
batch_size, # total_num_rows
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
head_dim,
head_dim,
False, # causal
torch.cuda.current_stream().cuda_stream,
)
else:
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr_host,
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
head_dim,
self.empty_q_data,
self.empty_kv_cache,
torch.cuda.current_stream().cuda_stream,
)
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap

View File

@@ -578,10 +578,12 @@ class TritonMultiStepDraftBackend:
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
forward_batch.batch_size,
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,

View File

@@ -396,16 +396,10 @@ class CudaGraphRunner:
run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
global global_graph_memory_pool
with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
out = run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
global_graph_memory_pool = graph.pool()
return graph, out

View File

@@ -26,7 +26,12 @@ def build_tree_kernel_efficient_preprocess(
draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
parent_list = torch.cat(parents_list[:-1], dim=1)
if len(parents_list) > 1:
parent_list = torch.cat(parents_list[:-1], dim=1)
else:
batch_size = parents_list[0].shape[0]
parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)
return parent_list, top_scores_index, draft_tokens

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import bisect
import time
from typing import TYPE_CHECKING, Callable
import torch
@@ -162,20 +161,11 @@ class EAGLEDraftCudaGraphRunner:
run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
set_global_graph_memory_pool(graph.pool())
return graph, out
@@ -204,7 +194,7 @@ class EAGLEDraftCudaGraphRunner:
# Attention backend
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch
forward_batch, forward_batch.batch_size
)
# Replay

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, List
import torch
import torch.nn.functional as F
@@ -62,6 +62,7 @@ class EagleDraftInput:
batch.input_ids[pt : pt + extend_len] = torch.concat(
(input_ids[1:], self.verified_id[i].reshape(1))
)
pt += extend_len
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
assert self.verified_id.numel() == batch.out_cache_loc.shape[0]

View File

@@ -1,20 +1,19 @@
import logging
import os
import time
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple
import torch
from huggingface_hub import snapshot_download
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
@@ -27,7 +26,6 @@ from sglang.srt.speculative.eagle_utils import (
fast_topk,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import get_available_gpu_memory
logger = logging.getLogger(__name__)
@@ -44,16 +42,30 @@ class EAGLEWorker(TpModelWorker):
nccl_port: int,
target_worker: TpModelWorker,
):
# Parse arguments
self.server_args = server_args
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.padded_static_len = self.speculative_num_steps + 1
self.enable_nan_detection = server_args.enable_nan_detection
self.gpu_id = gpu_id
self.device = server_args.device
self.target_worker = target_worker
# Override context length with target model's context length
server_args.context_length = target_worker.model_runner.model_config.context_len
os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
# Do not capture cuda graph in `super().__init__()`
# We will capture it later
# It will be captured later.
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True
# Share the allocator with a target worker.
# Draft and target worker own their own KV cache pools.
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
# Lossy optimization by using hot tokens
# Load hot token ids
if server_args.speculative_token_map is not None:
self.hot_token_id = load_token_map(server_args.speculative_token_map)
server_args.json_model_override_args = (
@@ -62,13 +74,7 @@ class EAGLEWorker(TpModelWorker):
else:
self.hot_token_id = None
# We share the allocator with a target worker. Draft/target worker
# owns its own KV cache.
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
# Init target worker
# Init draft worker
super().__init__(
gpu_id=gpu_id,
tp_rank=tp_rank,
@@ -79,18 +85,6 @@ class EAGLEWorker(TpModelWorker):
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
self.target_worker = target_worker
# Parse arguments
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.server_args = server_args
self.use_nan_detection = self.server_args.enable_nan_detection
self.device = self.model_runner.device
self.gpu_id = self.model_runner.gpu_id
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -103,8 +97,12 @@ class EAGLEWorker(TpModelWorker):
backup_disable_cuda_graph
)
self.init_attention_backend()
self.init_cuda_graphs()
def init_attention_backend(self):
# Create multi-step attn backends and cuda graph runners
if server_args.attention_backend == "flashinfer":
if self.server_args.attention_backend == "flashinfer":
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferMultiStepDraftBackend,
)
@@ -114,7 +112,7 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
elif server_args.attention_backend == "triton":
elif self.server_args.attention_backend == "triton":
from sglang.srt.layers.attention.triton_backend import (
TritonMultiStepDraftBackend,
)
@@ -126,11 +124,9 @@ class EAGLEWorker(TpModelWorker):
)
else:
raise ValueError(
f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
)
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
self.init_cuda_graphs()
def init_cuda_graphs(self):
"""Capture cuda graphs."""
@@ -356,6 +352,41 @@ class EAGLEWorker(TpModelWorker):
batch.forward_mode = ForwardMode.DECODE
batch.spec_info = res.draft_input
if batch.return_logprob:
# Compute output logprobs using the sampler.
num_tokens_per_req = [
accept + 1 for accept in res.accept_length_per_req_cpu
]
self.target_worker.model_runner.update_output_logprobs(
logits_output,
batch.sampling_info,
batch.top_logprobs_nums,
batch.token_ids_logprobs,
res.verified_id,
# +1 for bonus token.
num_tokens_per_req=num_tokens_per_req,
)
# Add output logprobs to the request.
pt = 0
# NOTE: tolist() of these values are skipped when output is processed
next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
verified_ids = res.verified_id.tolist()
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
for _ in range(num_tokens):
if req.return_logprob:
token_id = verified_ids[pt]
req.output_token_logprobs_val.append(next_token_logprobs[pt])
req.output_token_logprobs_idx.append(token_id)
if req.top_logprobs_num > 0:
req.output_top_logprobs_val.append(
res.logits_output.next_token_top_logprobs_val[pt]
)
req.output_top_logprobs_idx.append(
res.logits_output.next_token_top_logprobs_idx[pt]
)
pt += 1
return logits_output, res, model_worker_batch
def forward_draft_extend(
@@ -381,6 +412,7 @@ class EAGLEWorker(TpModelWorker):
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
forward_batch.return_logprob = False
logits_output = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
assert isinstance(forward_batch.spec_info, EagleDraftInput)
@@ -393,6 +425,8 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
# We don't need logprob for this extend.
original_return_logprob = batch.return_logprob
batch.return_logprob = False
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
@@ -404,6 +438,7 @@ class EAGLEWorker(TpModelWorker):
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.return_logprob = original_return_logprob
batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup
@@ -415,7 +450,7 @@ class EAGLEWorker(TpModelWorker):
draft_input.hidden_states = logits_output.hidden_states
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
if self.use_nan_detection:
if self.enable_nan_detection:
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")