Apply constraint grammar to EAGLE (#6499)

Co-authored-by: merrymercy <lianminzheng@gmail.com>
This commit is contained in:
Ke Bao
2025-05-22 08:18:41 +08:00
committed by GitHub
parent 969660c762
commit 6ce0ed073b
3 changed files with 198 additions and 0 deletions

View File

@@ -9,15 +9,18 @@ import torch.nn.functional as F
import triton
import triton.language as tl
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch,
get_last_loc,
global_server_args_dict,
)
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
@@ -187,6 +190,7 @@ class EagleVerifyInput:
draft_token_num: int
spec_steps: int
capture_hidden_mode: CaptureHiddenMode
grammar: BaseGrammarObject = None
@classmethod
def create(
@@ -307,6 +311,7 @@ class EagleVerifyInput:
logits_output: torch.Tensor,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
vocab_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Verify and find accepted tokens based on logits output and batch
@@ -343,6 +348,13 @@ class EagleVerifyInput:
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
)
# Apply grammar mask
if vocab_mask is not None:
assert self.grammar is not None
self.grammar.apply_vocab_mask(
logits=logits_output.next_token_logits, vocab_mask=vocab_mask
)
# Sample tokens
if batch.sampling_info.is_all_greedy:
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
@@ -440,6 +452,15 @@ class EagleVerifyInput:
break
else:
new_accept_index_.append(idx)
# update grammar state
if req.grammar is not None:
try:
req.grammar.accept_token(id)
except ValueError as e:
logger.info(
f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
)
raise e
if not req.finished():
new_accept_index.extend(new_accept_index_)
unfinished_index.append(i)
@@ -801,3 +822,113 @@ def _generate_simulated_accept_index(
accept_length.fill_(simulate_acc_len - 1)
predict.fill_(100) # some legit token id
return sim_accept_index
def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """
    Walk the draft-model tree depth-first and build the per-position logits mask.

    For every draft position whose token the grammar accepts, the grammar's
    vocab bitmask for the state *after* that token is written into
    ``allocate_token_bitmask[pos]``; rejected positions keep an all-zero mask.
    The grammar object is mutated during the walk but restored via rollback
    before returning.
    """
    assert (
        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
    )

    # Start from a clean slate: an all-zero row marks "position not reachable".
    allocate_token_bitmask.fill_(0)

    def _walk(node, parent):
        # Position 0 holds the token produced by the target model in the
        # previous iteration, so it is accepted unconditionally. Any other
        # node is accepted iff its token's bit is set in the bitmask that was
        # filled for its parent position (32 bits packed per integer word).
        if node == 0:
            accepted = True
        else:
            token = draft_tokens[node]
            word = allocate_token_bitmask[parent][token // 32]
            accepted = bool(word & (1 << (token % 32)))

        if accepted:
            if node != 0:
                # Advance the grammar state past this draft token.
                grammar.accept_token(draft_tokens[node])
            if not grammar.is_terminated():
                # Record which tokens may follow at this position.
                grammar.fill_vocab_mask(allocate_token_bitmask, node)
            if retrieve_next_token[node] != -1:
                # Descend into the first child; this node becomes the parent.
                _walk(retrieve_next_token[node], node)
            if node != 0:
                # Undo the acceptance so siblings see the parent's state.
                grammar.rollback(1)

        if retrieve_next_sibling[node] != -1:
            # Continue with the next sibling under the same parent.
            _walk(retrieve_next_sibling[node], parent)

    _walk(0, -1)
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """
    Generate the logit mask for structured (grammar-constrained) output.

    A draft model's tokens can be either valid or invalid with respect to a
    request's grammar, so each request's draft tree is traversed (DFS) to
    figure out which draft tokens the grammar accepts and what the
    corresponding logit mask at each position is.

    Returns the CPU bitmask tensor covering all draft positions of all
    requests, or None when no request carries a grammar. The last grammar
    seen is stored on ``verify_input.grammar`` for the later verify step.
    """
    assert len(reqs) == retrieve_next_token_cpu.shape[0]

    tokens_per_req = draft_tokens_cpu.shape[-1]
    bitmask = None
    last_grammar = None

    for i, req in enumerate(reqs):
        if req.grammar is None:
            continue
        if bitmask is None:
            # Allocate lazily so grammar-free batches pay nothing: one mask
            # row per draft position across the whole batch.
            bitmask = req.grammar.allocate_vocab_mask(
                vocab_size=vocab_size,
                batch_size=draft_tokens_cpu.numel(),
                device="cpu",
            )
        last_grammar = req.grammar
        rows = slice(i * tokens_per_req, (i + 1) * tokens_per_req)
        traverse_tree(
            retrieve_next_token_cpu[i],
            retrieve_next_sibling_cpu[i],
            draft_tokens_cpu[i],
            req.grammar,
            bitmask[rows],
        )

    verify_input.grammar = last_grammar
    return bitmask

View File

@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
EagleVerifyInput,
EagleVerifyOutput,
assign_draft_cache_locs,
generate_token_bitmask,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -492,11 +493,41 @@ class EAGLEWorker(TpModelWorker):
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch()
if batch.has_grammar:
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
draft_tokens_cpu = spec_info.draft_token.view(
spec_info.retrive_next_token.shape
).cpu()
# Forward
logits_output, _, can_run_cuda_graph = (
self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
)
vocab_mask = None
if batch.has_grammar:
# Generate the logit mask for structured output.
# Overlap the CPU operations for bitmask generation with the forward pass.
vocab_mask = generate_token_bitmask(
batch.reqs,
spec_info,
retrieve_next_token_cpu,
retrieve_next_sibling_cpu,
draft_tokens_cpu,
batch.sampling_info.vocab_size,
)
if vocab_mask is not None:
assert spec_info.grammar is not None
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
# otherwise, this vocab mask will be the one from the previous extend stage
# and will be applied to produce wrong results
batch.sampling_info.vocab_mask = None
self._detect_nan_if_needed(logits_output)
spec_info.hidden_states = logits_output.hidden_states
res: EagleVerifyOutput = spec_info.verify(
@@ -504,6 +535,7 @@ class EAGLEWorker(TpModelWorker):
logits_output,
self.token_to_kv_pool_allocator,
self.page_size,
vocab_mask,
)
# Post process based on verified outputs.