Apply constraint grammar to EAGLE (#6499)
Co-authored-by: merrymercy <lianminzheng@gmail.com>
This commit is contained in:
@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
EagleVerifyInput,
|
||||
EagleVerifyOutput,
|
||||
assign_draft_cache_locs,
|
||||
generate_token_bitmask,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
@@ -492,11 +493,41 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||
batch.spec_info = spec_info
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
|
||||
if batch.has_grammar:
|
||||
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
|
||||
retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
|
||||
draft_tokens_cpu = spec_info.draft_token.view(
|
||||
spec_info.retrive_next_token.shape
|
||||
).cpu()
|
||||
|
||||
# Forward
|
||||
logits_output, _, can_run_cuda_graph = (
|
||||
self.target_worker.forward_batch_generation(
|
||||
model_worker_batch, skip_sample=True
|
||||
)
|
||||
)
|
||||
|
||||
vocab_mask = None
|
||||
if batch.has_grammar:
|
||||
# Generate the logit mask for structured output.
|
||||
# Overlap the CPU operations for bitmask generation with the forward pass.
|
||||
vocab_mask = generate_token_bitmask(
|
||||
batch.reqs,
|
||||
spec_info,
|
||||
retrieve_next_token_cpu,
|
||||
retrieve_next_sibling_cpu,
|
||||
draft_tokens_cpu,
|
||||
batch.sampling_info.vocab_size,
|
||||
)
|
||||
|
||||
if vocab_mask is not None:
|
||||
assert spec_info.grammar is not None
|
||||
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
|
||||
# otherwise, this vocab mask will be the one from the previous extend stage
|
||||
# and will be applied to produce wrong results
|
||||
batch.sampling_info.vocab_mask = None
|
||||
|
||||
self._detect_nan_if_needed(logits_output)
|
||||
spec_info.hidden_states = logits_output.hidden_states
|
||||
res: EagleVerifyOutput = spec_info.verify(
|
||||
@@ -504,6 +535,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
logits_output,
|
||||
self.token_to_kv_pool_allocator,
|
||||
self.page_size,
|
||||
vocab_mask,
|
||||
)
|
||||
|
||||
# Post process based on verified outputs.
|
||||
|
||||
Reference in New Issue
Block a user