Fix a draft model accuracy bug in eagle; support step=1; return logprob in eagle (#4134)
Co-authored-by: Sehoon Kim <kssteven418@gmail.com> Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
@@ -1,20 +1,19 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
||||
EAGLEDraftCudaGraphRunner,
|
||||
@@ -27,7 +26,6 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
fast_topk,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import get_available_gpu_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,16 +42,30 @@ class EAGLEWorker(TpModelWorker):
|
||||
nccl_port: int,
|
||||
target_worker: TpModelWorker,
|
||||
):
|
||||
# Parse arguments
|
||||
self.server_args = server_args
|
||||
self.topk = server_args.speculative_eagle_topk
|
||||
self.speculative_num_steps = server_args.speculative_num_steps
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.enable_nan_detection = server_args.enable_nan_detection
|
||||
self.gpu_id = gpu_id
|
||||
self.device = server_args.device
|
||||
self.target_worker = target_worker
|
||||
|
||||
# Override context length with target model's context length
|
||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||
os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
|
||||
|
||||
# Do not capture cuda graph in `super().__init__()`
|
||||
# We will capture it later
|
||||
# It will be captured later.
|
||||
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
||||
server_args.disable_cuda_graph = True
|
||||
# Share the allocator with a target worker.
|
||||
# Draft and target worker own their own KV cache pools.
|
||||
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
||||
target_worker.get_memory_pool()
|
||||
)
|
||||
|
||||
# Lossy optimization by using hot tokens
|
||||
# Load hot token ids
|
||||
if server_args.speculative_token_map is not None:
|
||||
self.hot_token_id = load_token_map(server_args.speculative_token_map)
|
||||
server_args.json_model_override_args = (
|
||||
@@ -62,13 +74,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
else:
|
||||
self.hot_token_id = None
|
||||
|
||||
# We share the allocator with a target worker. Draft/target worker
|
||||
# owns its own KV cache.
|
||||
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
||||
target_worker.get_memory_pool()
|
||||
)
|
||||
|
||||
# Init target worker
|
||||
# Init draft worker
|
||||
super().__init__(
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
@@ -79,18 +85,6 @@ class EAGLEWorker(TpModelWorker):
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
self.target_worker = target_worker
|
||||
|
||||
# Parse arguments
|
||||
self.topk = server_args.speculative_eagle_topk
|
||||
self.speculative_num_steps = server_args.speculative_num_steps
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
self.server_args = server_args
|
||||
self.use_nan_detection = self.server_args.enable_nan_detection
|
||||
self.device = self.model_runner.device
|
||||
self.gpu_id = self.model_runner.gpu_id
|
||||
|
||||
# Share the embedding and lm_head
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
@@ -103,8 +97,12 @@ class EAGLEWorker(TpModelWorker):
|
||||
backup_disable_cuda_graph
|
||||
)
|
||||
|
||||
self.init_attention_backend()
|
||||
self.init_cuda_graphs()
|
||||
|
||||
def init_attention_backend(self):
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
if server_args.attention_backend == "flashinfer":
|
||||
if self.server_args.attention_backend == "flashinfer":
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferMultiStepDraftBackend,
|
||||
)
|
||||
@@ -114,7 +112,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
elif server_args.attention_backend == "triton":
|
||||
elif self.server_args.attention_backend == "triton":
|
||||
from sglang.srt.layers.attention.triton_backend import (
|
||||
TritonMultiStepDraftBackend,
|
||||
)
|
||||
@@ -126,11 +124,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
|
||||
f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
|
||||
)
|
||||
|
||||
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
|
||||
self.init_cuda_graphs()
|
||||
|
||||
def init_cuda_graphs(self):
|
||||
"""Capture cuda graphs."""
|
||||
@@ -356,6 +352,41 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.spec_info = res.draft_input
|
||||
|
||||
if batch.return_logprob:
|
||||
# Compute output logprobs using the sampler.
|
||||
num_tokens_per_req = [
|
||||
accept + 1 for accept in res.accept_length_per_req_cpu
|
||||
]
|
||||
self.target_worker.model_runner.update_output_logprobs(
|
||||
logits_output,
|
||||
batch.sampling_info,
|
||||
batch.top_logprobs_nums,
|
||||
batch.token_ids_logprobs,
|
||||
res.verified_id,
|
||||
# +1 for bonus token.
|
||||
num_tokens_per_req=num_tokens_per_req,
|
||||
)
|
||||
|
||||
# Add output logprobs to the request.
|
||||
pt = 0
|
||||
# NOTE: tolist() of these values are skipped when output is processed
|
||||
next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
|
||||
verified_ids = res.verified_id.tolist()
|
||||
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
|
||||
for _ in range(num_tokens):
|
||||
if req.return_logprob:
|
||||
token_id = verified_ids[pt]
|
||||
req.output_token_logprobs_val.append(next_token_logprobs[pt])
|
||||
req.output_token_logprobs_idx.append(token_id)
|
||||
if req.top_logprobs_num > 0:
|
||||
req.output_top_logprobs_val.append(
|
||||
res.logits_output.next_token_top_logprobs_val[pt]
|
||||
)
|
||||
req.output_top_logprobs_idx.append(
|
||||
res.logits_output.next_token_top_logprobs_idx[pt]
|
||||
)
|
||||
pt += 1
|
||||
|
||||
return logits_output, res, model_worker_batch
|
||||
|
||||
def forward_draft_extend(
|
||||
@@ -381,6 +412,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
)
|
||||
forward_batch.return_logprob = False
|
||||
logits_output = self.draft_model_runner.forward(forward_batch)
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
@@ -393,6 +425,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
|
||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
# We don't need logprob for this extend.
|
||||
original_return_logprob = batch.return_logprob
|
||||
batch.return_logprob = False
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
forward_batch = ForwardBatch.init_new(
|
||||
model_worker_batch, self.draft_model_runner
|
||||
@@ -404,6 +438,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
# Restore backup.
|
||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||
batch.return_logprob = original_return_logprob
|
||||
batch.forward_mode = ForwardMode.DECODE
|
||||
batch.seq_lens = seq_lens_backup
|
||||
|
||||
@@ -415,7 +450,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
draft_input.hidden_states = logits_output.hidden_states
|
||||
|
||||
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
|
||||
if self.use_nan_detection:
|
||||
if self.enable_nan_detection:
|
||||
logits = logits_output.next_token_logits
|
||||
if torch.any(torch.isnan(logits)):
|
||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||
|
||||
Reference in New Issue
Block a user