Apply constraint grammar to EAGLE (#6499)
Co-authored-by: merrymercy <lianminzheng@gmail.com>
@@ -9,15 +9,18 @@ import torch.nn.functional as F
import triton
import triton.language as tl

from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import (
    Req,
    ScheduleBatch,
    get_last_loc,
    global_server_args_dict,
)
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2

@@ -187,6 +190,7 @@ class EagleVerifyInput:
    draft_token_num: int
    spec_steps: int
    capture_hidden_mode: CaptureHiddenMode
    grammar: BaseGrammarObject = None

    @classmethod
    def create(

@@ -307,6 +311,7 @@ class EagleVerifyInput:
        logits_output: torch.Tensor,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        page_size: int,
        vocab_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Verify and find accepted tokens based on logits output and batch

@@ -343,6 +348,13 @@ class EagleVerifyInput:
            torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
        )

        # Apply grammar mask
        if vocab_mask is not None:
            assert self.grammar is not None
            self.grammar.apply_vocab_mask(
                logits=logits_output.next_token_logits, vocab_mask=vocab_mask
            )

        # Sample tokens
        if batch.sampling_info.is_all_greedy:
            target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)

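For context, the grammar backends used on this code path pack the vocabulary mask 32 booleans per int32 word (see `traverse_tree` below). A minimal, hypothetical sketch of what applying such a packed mask to the logits amounts to — not the backend's actual `apply_vocab_mask` kernel — looks like this:

    import torch

    def apply_packed_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        # logits: (num_tokens, vocab_size) float tensor.
        # vocab_mask: (num_tokens, ceil(vocab_size / 32)) int32 tensor, where bit
        # (t % 32) of word (t // 32) marks token t as allowed by the grammar.
        vocab_size = logits.shape[-1]
        token_ids = torch.arange(vocab_size, device=logits.device)
        words = vocab_mask[:, token_ids // 32]      # (num_tokens, vocab_size)
        allowed = (words >> (token_ids % 32)) & 1   # 1 where the bit is set
        logits.masked_fill_(allowed == 0, float("-inf"))
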
@@ -440,6 +452,15 @@ class EagleVerifyInput:
                    break
                else:
                    new_accept_index_.append(idx)
                    # update grammar state
                    if req.grammar is not None:
                        try:
                            req.grammar.accept_token(id)
                        except ValueError as e:
                            logger.info(
                                f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
                            )
                            raise e
            if not req.finished():
                new_accept_index.extend(new_accept_index_)
                unfinished_index.append(i)

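The grammar object is driven incrementally here: `accept_token` advances the matcher state and raises `ValueError` on an illegal token. A toy stand-in with the same surface — purely illustrative, the real object comes from a `BaseGrammarObject` backend — makes the contract explicit:

    class ToyGrammar:
        """Accepts only the fixed token sequence `target`, mimicking the
        accept_token / rollback / is_terminated contract used above."""

        def __init__(self, target: list):
            self.target = target
            self.pos = 0

        def accept_token(self, token_id: int) -> None:
            if self.pos >= len(self.target) or token_id != self.target[self.pos]:
                raise ValueError(f"token {token_id} rejected at position {self.pos}")
            self.pos += 1

        def rollback(self, n: int) -> None:
            # Undo the last n accepted tokens (used by the tree DFS below).
            self.pos -= n

        def is_terminated(self) -> bool:
            return self.pos == len(self.target)
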
@@ -801,3 +822,113 @@ def _generate_simulated_accept_index(
    accept_length.fill_(simulate_acc_len - 1)
    predict.fill_(100)  # some legit token id
    return sim_accept_index


def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """
    Traverse the tree constructed by the draft model to generate the logits mask.
    """
    assert (
        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
    )

    allocate_token_bitmask.fill_(0)

    def dfs(
        curr: int,
        retrieve_next_token: torch.Tensor,
        retrieve_next_sibling: torch.Tensor,
        parent_pos: int,
    ):
        if curr == 0:
            # The first token was generated by the target model, so it was
            # already accepted in the previous iteration.
            accepted = True
        else:
            parent_bitmask = allocate_token_bitmask[parent_pos]
            curr_token_id = draft_tokens[curr]
            # 32 boolean bitmask values are packed into each 32-bit integer
            accepted = (
                parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
            ) != 0

        if accepted:
            if curr != 0:
                # Accept the current token
                grammar.accept_token(draft_tokens[curr])
            if not grammar.is_terminated():
                # Generate the bitmask for the current token
                grammar.fill_vocab_mask(allocate_token_bitmask, curr)
                if retrieve_next_token[curr] != -1:
                    # Visit the child node
                    dfs(
                        retrieve_next_token[curr],
                        retrieve_next_token,
                        retrieve_next_sibling,
                        curr,
                    )

            if curr != 0:
                # Roll back the current token
                grammar.rollback(1)

        if retrieve_next_sibling[curr] != -1:
            # Visit the sibling node
            dfs(
                retrieve_next_sibling[curr],
                retrieve_next_token,
                retrieve_next_sibling,
                parent_pos,
            )

    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)

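For intuition, `retrieve_next_token` and `retrieve_next_sibling` encode the draft tree in first-child / next-sibling form. A plain-Python walk over toy arrays (hypothetical data, no grammar involved) visits nodes in the same order as the `dfs` above:

    # Toy first-child/next-sibling encoding of a 5-node draft tree:
    #       0
    #      / \
    #     1   2
    #        / \
    #       3   4
    next_token = [1, -1, 3, -1, -1]    # first child of each node, -1 if leaf
    next_sibling = [-1, 2, -1, 4, -1]  # next sibling of each node, -1 if none

    def walk(curr: int, order: list) -> None:
        order.append(curr)
        if next_token[curr] != -1:
            walk(next_token[curr], order)     # descend to the first child
        if next_sibling[curr] != -1:
            walk(next_sibling[curr], order)   # then move to the next sibling

    order = []
    walk(0, order)
    print(order)  # [0, 1, 2, 3, 4]
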
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """
    Generate the logit mask for structured output.
    A draft-model token can be either valid or invalid with respect to the
    grammar, so we perform a DFS over the draft tree to figure out:
    1. which tokens are accepted by the grammar, and
    2. the corresponding logit mask.
    """
    num_draft_tokens = draft_tokens_cpu.shape[-1]

    allocate_token_bitmask = None
    assert len(reqs) == retrieve_next_token_cpu.shape[0]
    grammar = None
    for i, req in enumerate(reqs):
        if req.grammar is not None:
            if allocate_token_bitmask is None:
                allocate_token_bitmask = req.grammar.allocate_vocab_mask(
                    vocab_size=vocab_size,
                    batch_size=draft_tokens_cpu.numel(),
                    device="cpu",
                )
            grammar = req.grammar
            traverse_tree(
                retrieve_next_token_cpu[i],
                retrieve_next_sibling_cpu[i],
                draft_tokens_cpu[i],
                req.grammar,
                allocate_token_bitmask[
                    i * num_draft_tokens : (i + 1) * num_draft_tokens
                ],
            )

    verify_input.grammar = grammar
    return allocate_token_bitmask

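Note that the bitmask is allocated once for the whole batch (`batch_size=draft_tokens_cpu.numel()`) and each request is handed a contiguous slice of rows. A quick shape check, under assumed sizes:

    import math

    num_reqs, num_draft_tokens, vocab_size = 2, 4, 128256  # assumed sizes
    total_rows = num_reqs * num_draft_tokens     # == draft_tokens_cpu.numel()
    words_per_row = math.ceil(vocab_size / 32)   # 32 mask bits per int32 word
    # Request i owns rows [i * num_draft_tokens, (i + 1) * num_draft_tokens)
    for i in range(num_reqs):
        lo, hi = i * num_draft_tokens, (i + 1) * num_draft_tokens
        print(f"req {i}: rows {lo}..{hi - 1} of a ({total_rows}, {words_per_row}) bitmask")
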
@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
    EagleVerifyInput,
    EagleVerifyOutput,
    assign_draft_cache_locs,
    generate_token_bitmask,
    select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

@@ -492,11 +493,41 @@ class EAGLEWorker(TpModelWorker):
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = spec_info
        model_worker_batch = batch.get_model_worker_batch()

        if batch.has_grammar:
            retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
            retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
            draft_tokens_cpu = spec_info.draft_token.view(
                spec_info.retrive_next_token.shape
            ).cpu()

        # Forward
        logits_output, _, can_run_cuda_graph = (
            self.target_worker.forward_batch_generation(
                model_worker_batch, skip_sample=True
            )
        )

        vocab_mask = None
        if batch.has_grammar:
            # Generate the logit mask for structured output.
            # Overlap the CPU operations for bitmask generation with the forward pass.
            vocab_mask = generate_token_bitmask(
                batch.reqs,
                spec_info,
                retrieve_next_token_cpu,
                retrieve_next_sibling_cpu,
                draft_tokens_cpu,
                batch.sampling_info.vocab_size,
            )

            if vocab_mask is not None:
                assert spec_info.grammar is not None
                vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
                # Reset the sampling-stage vocab mask; otherwise the mask left over
                # from the previous extend stage would be applied again and produce
                # wrong results.
                batch.sampling_info.vocab_mask = None

        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res: EagleVerifyOutput = spec_info.verify(

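The ordering above is what makes the overlap work: the small tree tensors are copied to the CPU first, the target forward is enqueued on the GPU, and only then does the Python thread run the DFS, so bitmask generation proceeds while the GPU is busy. A self-contained sketch of that pattern (illustrative shapes, no EAGLE specifics):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(4096, 4096, device=device)

    # CUDA kernel launches are asynchronous: this enqueues the matmul and
    # returns immediately, leaving the Python thread free for CPU work.
    y = x @ x

    # CPU-side work (a stand-in for the bitmask DFS) runs while the GPU computes.
    mask = torch.zeros(8, 4096 // 32, dtype=torch.int32)
    mask[:, 0] = 1

    # Moving the finished mask to the GPU synchronizes only when it is consumed.
    mask = mask.to(y.device)
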
@@ -504,6 +535,7 @@ class EAGLEWorker(TpModelWorker):
            logits_output,
            self.token_to_kv_pool_allocator,
            self.page_size,
            vocab_mask,
        )

        # Post process based on verified outputs.