Apply constraint grammar to EAGLE (#6499)
Co-authored-by: merrymercy <lianminzheng@gmail.com>
This commit is contained in:
@@ -9,15 +9,18 @@ import torch.nn.functional as F
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
Req,
|
||||||
ScheduleBatch,
|
ScheduleBatch,
|
||||||
get_last_loc,
|
get_last_loc,
|
||||||
global_server_args_dict,
|
global_server_args_dict,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||||
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
||||||
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
|
from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
|
||||||
|
|
||||||
@@ -187,6 +190,7 @@ class EagleVerifyInput:
|
|||||||
draft_token_num: int
|
draft_token_num: int
|
||||||
spec_steps: int
|
spec_steps: int
|
||||||
capture_hidden_mode: CaptureHiddenMode
|
capture_hidden_mode: CaptureHiddenMode
|
||||||
|
grammar: BaseGrammarObject = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@@ -307,6 +311,7 @@ class EagleVerifyInput:
|
|||||||
logits_output: torch.Tensor,
|
logits_output: torch.Tensor,
|
||||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
|
vocab_mask: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Verify and find accepted tokens based on logits output and batch
|
Verify and find accepted tokens based on logits output and batch
|
||||||
@@ -343,6 +348,13 @@ class EagleVerifyInput:
|
|||||||
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
|
torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply grammar mask
|
||||||
|
if vocab_mask is not None:
|
||||||
|
assert self.grammar is not None
|
||||||
|
self.grammar.apply_vocab_mask(
|
||||||
|
logits=logits_output.next_token_logits, vocab_mask=vocab_mask
|
||||||
|
)
|
||||||
|
|
||||||
# Sample tokens
|
# Sample tokens
|
||||||
if batch.sampling_info.is_all_greedy:
|
if batch.sampling_info.is_all_greedy:
|
||||||
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
||||||
@@ -440,6 +452,15 @@ class EagleVerifyInput:
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
new_accept_index_.append(idx)
|
new_accept_index_.append(idx)
|
||||||
|
# update grammar state
|
||||||
|
if req.grammar is not None:
|
||||||
|
try:
|
||||||
|
req.grammar.accept_token(id)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.info(
|
||||||
|
f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
if not req.finished():
|
if not req.finished():
|
||||||
new_accept_index.extend(new_accept_index_)
|
new_accept_index.extend(new_accept_index_)
|
||||||
unfinished_index.append(i)
|
unfinished_index.append(i)
|
||||||
@@ -801,3 +822,113 @@ def _generate_simulated_accept_index(
|
|||||||
accept_length.fill_(simulate_acc_len - 1)
|
accept_length.fill_(simulate_acc_len - 1)
|
||||||
predict.fill_(100) # some legit token id
|
predict.fill_(100) # some legit token id
|
||||||
return sim_accept_index
|
return sim_accept_index
|
||||||
|
|
||||||
|
|
||||||
|
def traverse_tree(
|
||||||
|
retrieve_next_token: torch.Tensor,
|
||||||
|
retrieve_next_sibling: torch.Tensor,
|
||||||
|
draft_tokens: torch.Tensor,
|
||||||
|
grammar: BaseGrammarObject,
|
||||||
|
allocate_token_bitmask: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Traverse the tree constructed by the draft model to generate the logits mask.
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
allocate_token_bitmask.fill_(0)
|
||||||
|
|
||||||
|
def dfs(
|
||||||
|
curr: int,
|
||||||
|
retrieve_next_token: torch.Tensor,
|
||||||
|
retrieve_next_sibling: torch.Tensor,
|
||||||
|
parent_pos: int,
|
||||||
|
):
|
||||||
|
if curr == 0:
|
||||||
|
# the first token generated by the target model, and thus it is always
|
||||||
|
# accepted from the previous iteration
|
||||||
|
accepted = True
|
||||||
|
else:
|
||||||
|
parent_bitmask = allocate_token_bitmask[parent_pos]
|
||||||
|
curr_token_id = draft_tokens[curr]
|
||||||
|
# 32 boolean bitmask values are packed into 32-bit integers
|
||||||
|
accepted = (
|
||||||
|
parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
|
||||||
|
) != 0
|
||||||
|
|
||||||
|
if accepted:
|
||||||
|
if curr != 0:
|
||||||
|
# Accept the current token
|
||||||
|
grammar.accept_token(draft_tokens[curr])
|
||||||
|
if not grammar.is_terminated():
|
||||||
|
# Generate the bitmask for the current token
|
||||||
|
grammar.fill_vocab_mask(allocate_token_bitmask, curr)
|
||||||
|
if retrieve_next_token[curr] != -1:
|
||||||
|
# Visit the child node
|
||||||
|
dfs(
|
||||||
|
retrieve_next_token[curr],
|
||||||
|
retrieve_next_token,
|
||||||
|
retrieve_next_sibling,
|
||||||
|
curr,
|
||||||
|
)
|
||||||
|
|
||||||
|
if curr != 0:
|
||||||
|
# Rollback the current token
|
||||||
|
grammar.rollback(1)
|
||||||
|
|
||||||
|
if retrieve_next_sibling[curr] != -1:
|
||||||
|
# Visit the sibling node
|
||||||
|
dfs(
|
||||||
|
retrieve_next_sibling[curr],
|
||||||
|
retrieve_next_token,
|
||||||
|
retrieve_next_sibling,
|
||||||
|
parent_pos,
|
||||||
|
)
|
||||||
|
|
||||||
|
dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_token_bitmask(
|
||||||
|
reqs: List[Req],
|
||||||
|
verify_input: EagleVerifyInput,
|
||||||
|
retrieve_next_token_cpu: torch.Tensor,
|
||||||
|
retrieve_next_sibling_cpu: torch.Tensor,
|
||||||
|
draft_tokens_cpu: torch.Tensor,
|
||||||
|
vocab_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate the logit mask for structured output.
|
||||||
|
Draft model's token can be either valid or invalid with respect to the grammar.
|
||||||
|
We need to perform DFS to figure out:
|
||||||
|
1. which tokens are accepted by the grammar
|
||||||
|
2. what is the corresponding logit mask.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_draft_tokens = draft_tokens_cpu.shape[-1]
|
||||||
|
|
||||||
|
allocate_token_bitmask = None
|
||||||
|
assert len(reqs) == retrieve_next_token_cpu.shape[0]
|
||||||
|
grammar = None
|
||||||
|
for i, req in enumerate(reqs):
|
||||||
|
if req.grammar is not None:
|
||||||
|
if allocate_token_bitmask is None:
|
||||||
|
allocate_token_bitmask = req.grammar.allocate_vocab_mask(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
batch_size=draft_tokens_cpu.numel(),
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
grammar = req.grammar
|
||||||
|
traverse_tree(
|
||||||
|
retrieve_next_token_cpu[i],
|
||||||
|
retrieve_next_sibling_cpu[i],
|
||||||
|
draft_tokens_cpu[i],
|
||||||
|
req.grammar,
|
||||||
|
allocate_token_bitmask[
|
||||||
|
i * num_draft_tokens : (i + 1) * num_draft_tokens
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_input.grammar = grammar
|
||||||
|
return allocate_token_bitmask
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import (
|
|||||||
EagleVerifyInput,
|
EagleVerifyInput,
|
||||||
EagleVerifyOutput,
|
EagleVerifyOutput,
|
||||||
assign_draft_cache_locs,
|
assign_draft_cache_locs,
|
||||||
|
generate_token_bitmask,
|
||||||
select_top_k_tokens,
|
select_top_k_tokens,
|
||||||
)
|
)
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
@@ -492,11 +493,41 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||||
batch.spec_info = spec_info
|
batch.spec_info = spec_info
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
|
||||||
|
if batch.has_grammar:
|
||||||
|
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
|
||||||
|
retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
|
||||||
|
draft_tokens_cpu = spec_info.draft_token.view(
|
||||||
|
spec_info.retrive_next_token.shape
|
||||||
|
).cpu()
|
||||||
|
|
||||||
|
# Forward
|
||||||
logits_output, _, can_run_cuda_graph = (
|
logits_output, _, can_run_cuda_graph = (
|
||||||
self.target_worker.forward_batch_generation(
|
self.target_worker.forward_batch_generation(
|
||||||
model_worker_batch, skip_sample=True
|
model_worker_batch, skip_sample=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
vocab_mask = None
|
||||||
|
if batch.has_grammar:
|
||||||
|
# Generate the logit mask for structured output.
|
||||||
|
# Overlap the CPU operations for bitmask generation with the forward pass.
|
||||||
|
vocab_mask = generate_token_bitmask(
|
||||||
|
batch.reqs,
|
||||||
|
spec_info,
|
||||||
|
retrieve_next_token_cpu,
|
||||||
|
retrieve_next_sibling_cpu,
|
||||||
|
draft_tokens_cpu,
|
||||||
|
batch.sampling_info.vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
if vocab_mask is not None:
|
||||||
|
assert spec_info.grammar is not None
|
||||||
|
vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
|
||||||
|
# otherwise, this vocab mask will be the one from the previous extend stage
|
||||||
|
# and will be applied to produce wrong results
|
||||||
|
batch.sampling_info.vocab_mask = None
|
||||||
|
|
||||||
self._detect_nan_if_needed(logits_output)
|
self._detect_nan_if_needed(logits_output)
|
||||||
spec_info.hidden_states = logits_output.hidden_states
|
spec_info.hidden_states = logits_output.hidden_states
|
||||||
res: EagleVerifyOutput = spec_info.verify(
|
res: EagleVerifyOutput = spec_info.verify(
|
||||||
@@ -504,6 +535,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
logits_output,
|
logits_output,
|
||||||
self.token_to_kv_pool_allocator,
|
self.token_to_kv_pool_allocator,
|
||||||
self.page_size,
|
self.page_size,
|
||||||
|
vocab_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Post process based on verified outputs.
|
# Post process based on verified outputs.
|
||||||
|
|||||||
@@ -481,6 +481,41 @@ class TestEAGLEServer(CustomTestCase):
|
|||||||
with ThreadPoolExecutor(8) as executor:
|
with ThreadPoolExecutor(8) as executor:
|
||||||
list(executor.map(self.run_decode, args))
|
list(executor.map(self.run_decode, args))
|
||||||
|
|
||||||
|
def test_constrained_decoding(self):
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "Give me a json"},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": 0,
|
||||||
|
"response_format": {"type": "json_object"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
res = response.json()
|
||||||
|
|
||||||
|
# Validate response structure
|
||||||
|
self.assertIn("choices", res)
|
||||||
|
self.assertEqual(len(res["choices"]), 1)
|
||||||
|
self.assertIn("message", res["choices"][0])
|
||||||
|
self.assertIn("content", res["choices"][0]["message"])
|
||||||
|
|
||||||
|
# Validate JSON content
|
||||||
|
content_json = res["choices"][0]["message"]["content"]
|
||||||
|
is_valid_json = True
|
||||||
|
try:
|
||||||
|
content = json.loads(content_json)
|
||||||
|
self.assertIsInstance(content, dict)
|
||||||
|
except Exception:
|
||||||
|
print(f"parse JSON failed: {content_json}")
|
||||||
|
is_valid_json = False
|
||||||
|
self.assertTrue(is_valid_json)
|
||||||
|
|
||||||
|
|
||||||
class TestEAGLERetract(TestEAGLEServer):
|
class TestEAGLERetract(TestEAGLEServer):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user