Apply constraint grammar to EAGLE (#6499)

Co-authored-by: merrymercy <lianminzheng@gmail.com>
This commit is contained in:
Ke Bao
2025-05-22 08:18:41 +08:00
committed by GitHub
parent 969660c762
commit 6ce0ed073b
3 changed files with 198 additions and 0 deletions

View File

@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
EagleVerifyInput,
EagleVerifyOutput,
assign_draft_cache_locs,
generate_token_bitmask,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -492,11 +493,41 @@ class EAGLEWorker(TpModelWorker):
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch()
if batch.has_grammar:
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
draft_tokens_cpu = spec_info.draft_token.view(
spec_info.retrive_next_token.shape
).cpu()
# Forward
logits_output, _, can_run_cuda_graph = (
self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
)
vocab_mask = None
if batch.has_grammar:
# Generate the logit mask for structured output.
# Overlap the CPU operations for bitmask generation with the forward pass.
vocab_mask = generate_token_bitmask(
batch.reqs,
spec_info,
retrieve_next_token_cpu,
retrieve_next_sibling_cpu,
draft_tokens_cpu,
batch.sampling_info.vocab_size,
)
if vocab_mask is not None:
assert spec_info.grammar is not None
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
# Clear the vocab mask stored on the sampling info; otherwise the mask left over
# from the previous extend stage would be applied and produce wrong results.
batch.sampling_info.vocab_mask = None
self._detect_nan_if_needed(logits_output)
spec_info.hidden_states = logits_output.hidden_states
res: EagleVerifyOutput = spec_info.verify(
@@ -504,6 +535,7 @@ class EAGLEWorker(TpModelWorker):
logits_output,
self.token_to_kv_pool_allocator,
self.page_size,
vocab_mask,
)
# Post process based on verified outputs.