[EAGLE] many fixes for eagle (#4195)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Sehoon Kim <sehoon@x.ai>
This commit is contained in:
Lianmin Zheng
2025-03-07 22:12:13 -08:00
parent d052f4c8a9
commit d4017a6b63
15 changed files with 202 additions and 135 deletions

View File

@@ -18,12 +18,15 @@ dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]
[project.optional-dependencies]
runtime_common = [
"aiohttp",
"datasets",
"decord",
"fastapi",
"hf_transfer",
"huggingface_hub",
"interegular",
"llguidance>=0.6.15",
"modelscope",
"ninja",
"orjson",
"packaging",
"pillow",
@@ -33,13 +36,10 @@ runtime_common = [
"python-multipart",
"pyzmq>=25.1.2",
"torchao>=0.7.0",
"transformers==4.48.3",
"uvicorn",
"uvloop",
"xgrammar==0.1.14",
"ninja",
"transformers==4.48.3",
"llguidance>=0.6.15",
"datasets"
]
srt = [

View File

@@ -81,7 +81,7 @@ class ModelConfig:
if context_length is not None:
if context_length > derived_context_len:
if get_bool_env_var(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
):
logger.warning(
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "

View File

@@ -106,6 +106,8 @@ class Engine:
tokenizer_manager, scheduler_info = _launch_subprocesses(
server_args=server_args
)
self.server_args = server_args
self.tokenizer_manager = tokenizer_manager
self.scheduler_info = scheduler_info

View File

@@ -42,7 +42,6 @@ class Sampler(nn.Module):
return_logprob: bool,
top_logprobs_nums: List[int],
token_ids_logprobs: List[List[int]],
batch_next_token_ids: Optional[torch.Tensor] = None,
):
"""Run a sampler & compute logprobs and update logits_output accordingly.
@@ -72,8 +71,7 @@ class Sampler(nn.Module):
if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
if batch_next_token_ids is None:
batch_next_token_ids = torch.argmax(logits, -1)
batch_next_token_ids = torch.argmax(logits, -1)
if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else:
@@ -94,43 +92,39 @@ class Sampler(nn.Module):
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
).clamp(min=torch.finfo(probs.dtype).min)
if batch_next_token_ids is None:
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
)
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs,
uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
)
if self.use_nan_detection and not torch.all(success):
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(
batch_next_token_ids
)
elif global_server_args_dict["sampling_backend"] == "pytorch":
if batch_next_token_ids is None:
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs,
uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
filter_apply_order="joint",
)
if self.use_nan_detection and not torch.all(success):
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
elif global_server_args_dict["sampling_backend"] == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
sampling_info.top_ks,
sampling_info.top_ps,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
)
if return_logprob:
# clamp to avoid -inf
logprobs = torch.log(

View File

@@ -957,11 +957,13 @@ class Scheduler:
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
self.batch_is_full = False
# Filter batch
last_bs = self.last_batch.batch_size()
self.last_batch.filter_batch()
if self.last_batch.batch_size() < last_bs:
self.batch_is_full = False
# Merge the new batch into the running batch
if not self.last_batch.is_empty():
if self.running_batch is None:
self.running_batch = self.last_batch

View File

@@ -300,10 +300,11 @@ class CudaGraphRunner:
def capture(self):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
# Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
tqdm.tqdm(self.capture_bs)
tqdm.tqdm(reversed(self.capture_bs))
if get_tensor_model_parallel_rank() == 0
else self.capture_bs
else reversed(self.capture_bs)
)
for bs in capture_range:
with patch_model(

View File

@@ -928,45 +928,6 @@ class ModelRunner:
sampling_info.update_regex_vocab_mask()
sampling_info.apply_logits_bias(logits_output.next_token_logits)
def update_output_logprobs(
    self,
    logits_output: LogitsProcessorOutput,
    sampling_info: SamplingBatchInfo,
    top_logprobs_nums: List[int],
    token_ids_logprobs: List[List[int]],
    next_token_ids: torch.Tensor,
    *,
    num_tokens_per_req: List[int],
):
    """Update the logits_output's output logprob based on next_token_ids.

    Used on the speculative-decoding path, where one verify step may accept
    several tokens per request, so the flattened token dimension of
    ``logits_output`` is longer than the number of requests.

    Args:
        logits_output: The logits output from the model forward.
        sampling_info: Sampling info for logprob calculation.
        top_logprobs_nums: Number of top logprobs requested, one per request.
        token_ids_logprobs: Token ids whose logprobs are requested, one list
            per request.
        next_token_ids: Next token ids (flattened over all requests' tokens).
        num_tokens_per_req: The number of tokens per request.

    Returns:
        None. ``logits_output`` is updated in place by the sampler call.
    """
    self._preprocess_logits(logits_output, sampling_info)

    # We should repeat top_logprobs_nums to match num_tokens_per_req:
    # per-request settings are expanded to one entry per emitted token so
    # they line up with the flattened token dimension.
    top_logprobs_nums_repeat_interleaved = []
    token_ids_logprobs_repeat_interleaved = []
    for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
        top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
    for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
        token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)

    # Passing batch_next_token_ids means the sampler does not sample new
    # ids; it only computes and attaches logprobs for the provided ids.
    self.sampler(
        logits_output,
        sampling_info,
        True,  # return_logprob
        top_logprobs_nums_repeat_interleaved,
        token_ids_logprobs_repeat_interleaved,
        batch_next_token_ids=next_token_ids,
    )
def sample(
self,
logits_output: LogitsProcessorOutput,

View File

@@ -56,7 +56,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
]
def _merge(self, their: "BatchedFrequencyPenalizer"):
print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
self.frequency_penalties = torch.cat(
[self.frequency_penalties, their.frequency_penalties], dim=0
)

View File

@@ -56,7 +56,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
]
def _merge(self, their: "BatchedPresencePenalizer"):
print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
self.presence_penalties = torch.cat(
[self.presence_penalties, their.presence_penalties], dim=0
)

View File

@@ -7,6 +7,7 @@ import torch
from huggingface_hub import snapshot_download
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
@@ -302,13 +303,10 @@ class EAGLEWorker(TpModelWorker):
# Set inputs
forward_batch.input_ids = input_ids
out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
forward_batch.out_cache_loc = out_cache_loc[
forward_batch.batch_size
* self.topk
* i : forward_batch.batch_size
* self.topk
* (i + 1)
]
:, self.topk * i : self.topk * (i + 1)
].flatten()
forward_batch.positions.add_(1)
forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
spec_info.hidden_states = hidden_states
@@ -353,42 +351,70 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info = res.draft_input
if batch.return_logprob:
# Compute output logprobs using the sampler.
num_tokens_per_req = [
accept + 1 for accept in res.accept_length_per_req_cpu
]
self.target_worker.model_runner.update_output_logprobs(
logits_output,
batch.sampling_info,
batch.top_logprobs_nums,
batch.token_ids_logprobs,
res.verified_id,
# +1 for bonus token.
num_tokens_per_req=num_tokens_per_req,
)
# Add output logprobs to the request.
pt = 0
# NOTE: tolist() of these values are skipped when output is processed
next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
verified_ids = res.verified_id.tolist()
for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
for _ in range(num_tokens):
if req.return_logprob:
token_id = verified_ids[pt]
req.output_token_logprobs_val.append(next_token_logprobs[pt])
req.output_token_logprobs_idx.append(token_id)
if req.top_logprobs_num > 0:
req.output_top_logprobs_val.append(
res.logits_output.next_token_top_logprobs_val[pt]
)
req.output_top_logprobs_idx.append(
res.logits_output.next_token_top_logprobs_idx[pt]
)
pt += 1
self.add_logprob_values(batch, res, logits_output)
return logits_output, res, model_worker_batch
def add_logprob_values(
    self,
    batch: ScheduleBatch,
    res: EagleVerifyOutput,
    logits_output: LogitsProcessorOutput,
):
    """Compute output logprobs for the verified tokens of a speculative
    verify step and append them to each request's logprob lists.

    Args:
        batch: The scheduled batch whose requests receive the logprob values.
        res: Verification output holding the verified token ids, the
            per-request accepted lengths, and the target model's logits.
        logits_output: Logits output from the model forward.
            NOTE(review): this parameter is immediately shadowed by
            ``res.logits_output`` below — presumably both refer to the same
            object; confirm at the call site.
    """
    # Extract args
    logits_output = res.logits_output
    top_logprobs_nums = batch.top_logprobs_nums
    token_ids_logprobs = batch.token_ids_logprobs
    logprobs = torch.nn.functional.log_softmax(
        logits_output.next_token_logits, dim=-1
    )
    batch_next_token_ids = res.verified_id
    # +1 for the bonus token emitted after the accepted draft tokens.
    num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

    # We should repeat top_logprobs_nums to match num_tokens_per_req:
    # per-request settings are expanded to one entry per emitted token so
    # they align with the flattened token dimension of `logprobs`.
    top_logprobs_nums_repeat_interleaved = []
    token_ids_logprobs_repeat_interleaved = []
    for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
        top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
    for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
        token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)

    # Extract logprobs; skip the work entirely when no request asked for it.
    if any(x > 0 for x in top_logprobs_nums):
        (
            logits_output.next_token_top_logprobs_val,
            logits_output.next_token_top_logprobs_idx,
        ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)

    if any(x is not None for x in token_ids_logprobs):
        (
            logits_output.next_token_token_ids_logprobs_val,
            logits_output.next_token_token_ids_logprobs_idx,
        ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)

    # Gather the logprob of each verified token id from the full distribution.
    logits_output.next_token_logprobs = logprobs[
        torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
        batch_next_token_ids,
    ]

    # Add output logprobs to the request.
    pt = 0
    next_token_logprobs = logits_output.next_token_logprobs.tolist()
    verified_ids = batch_next_token_ids.tolist()
    for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
        for _ in range(num_tokens):
            if req.return_logprob:
                req.output_token_logprobs_val.append(next_token_logprobs[pt])
                req.output_token_logprobs_idx.append(verified_ids[pt])
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
                        res.logits_output.next_token_top_logprobs_val[pt]
                    )
                    req.output_top_logprobs_idx.append(
                        res.logits_output.next_token_top_logprobs_idx[pt]
                    )
            # pt indexes the flattened token dimension, so it must advance
            # for every token even when the request skips logprob output.
            pt += 1
def forward_draft_extend(
self,
batch: ScheduleBatch,