From 6ce0ed073bde704ef182f6ecada477725da9a314 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Thu, 22 May 2025 08:18:41 +0800 Subject: [PATCH] Apply constraint grammar to EAGLE (#6499) Co-authored-by: merrymercy --- python/sglang/srt/speculative/eagle_utils.py | 131 ++++++++++++++++++ python/sglang/srt/speculative/eagle_worker.py | 32 +++++ test/srt/test_eagle_infer.py | 35 +++++ 3 files changed, 198 insertions(+) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index f7d598de9..eb1b3b44f 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -9,15 +9,18 @@ import torch.nn.functional as F import triton import triton.language as tl +from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import ( + Req, ScheduleBatch, get_last_loc, global_server_args_dict, ) from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2 @@ -187,6 +190,7 @@ class EagleVerifyInput: draft_token_num: int spec_steps: int capture_hidden_mode: CaptureHiddenMode + grammar: BaseGrammarObject = None @classmethod def create( @@ -307,6 +311,7 @@ class EagleVerifyInput: logits_output: torch.Tensor, token_to_kv_pool_allocator: TokenToKVPoolAllocator, page_size: int, + vocab_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Verify and find accepted tokens based on logits output and batch @@ -343,6 +348,13 @@ class EagleVerifyInput: torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) ) + # Apply grammar mask + if vocab_mask is not None: + assert self.grammar is not None + self.grammar.apply_vocab_mask( + logits=logits_output.next_token_logits, vocab_mask=vocab_mask + ) + # Sample tokens if batch.sampling_info.is_all_greedy: target_predict = torch.argmax(logits_output.next_token_logits, dim=-1) @@ -440,6 +452,15 @@ class EagleVerifyInput: break else: new_accept_index_.append(idx) + # update grammar state + if req.grammar is not None: + try: + req.grammar.accept_token(id) + except ValueError as e: + logger.info( + f"{i=}, {req=}\n" f"{accept_index=}\n" f"{predict=}\n" + ) + raise e if not req.finished(): new_accept_index.extend(new_accept_index_) unfinished_index.append(i) @@ -801,3 +822,113 @@ def _generate_simulated_accept_index( accept_length.fill_(simulate_acc_len - 1) predict.fill_(100) # some legit token id return sim_accept_index + + +def traverse_tree( + retrieve_next_token: torch.Tensor, + retrieve_next_sibling: torch.Tensor, + draft_tokens: torch.Tensor, + grammar: BaseGrammarObject, + allocate_token_bitmask: torch.Tensor, +): + """ + Traverse the tree constructed by the draft model to generate the logits mask. + """ + assert ( + retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape + ) + + allocate_token_bitmask.fill_(0) + + def dfs( + curr: int, + retrieve_next_token: torch.Tensor, + retrieve_next_sibling: torch.Tensor, + parent_pos: int, + ): + if curr == 0: + # the first token generated by the target model, and thus it is always + # accepted from the previous iteration + accepted = True + else: + parent_bitmask = allocate_token_bitmask[parent_pos] + curr_token_id = draft_tokens[curr] + # 32 boolean bitmask values are packed into 32-bit integers + accepted = ( + parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32)) + ) != 0 + + if accepted: + if curr != 0: + # Accept the current token + grammar.accept_token(draft_tokens[curr]) + if not grammar.is_terminated(): + # Generate the bitmask for the current token + grammar.fill_vocab_mask(allocate_token_bitmask, curr) + if retrieve_next_token[curr] != -1: + # Visit the child node + dfs( + retrieve_next_token[curr], + retrieve_next_token, + retrieve_next_sibling, + curr, + ) + + if curr != 0: + # Rollback the current token + grammar.rollback(1) + + if retrieve_next_sibling[curr] != -1: + # Visit the sibling node + dfs( + retrieve_next_sibling[curr], + retrieve_next_token, + retrieve_next_sibling, + parent_pos, + ) + + dfs(0, retrieve_next_token, retrieve_next_sibling, -1) + + +def generate_token_bitmask( + reqs: List[Req], + verify_input: EagleVerifyInput, + retrieve_next_token_cpu: torch.Tensor, + retrieve_next_sibling_cpu: torch.Tensor, + draft_tokens_cpu: torch.Tensor, + vocab_size: int, +): + """ + Generate the logit mask for structured output. + Draft model's token can be either valid or invalid with respect to the grammar. + We need to perform DFS to figure out: + 1. which tokens are accepted by the grammar + 2. what is the corresponding logit mask. + """ + + num_draft_tokens = draft_tokens_cpu.shape[-1] + + allocate_token_bitmask = None + assert len(reqs) == retrieve_next_token_cpu.shape[0] + grammar = None + for i, req in enumerate(reqs): + if req.grammar is not None: + if allocate_token_bitmask is None: + allocate_token_bitmask = req.grammar.allocate_vocab_mask( + vocab_size=vocab_size, + batch_size=draft_tokens_cpu.numel(), + device="cpu", + ) + grammar = req.grammar + traverse_tree( + retrieve_next_token_cpu[i], + retrieve_next_sibling_cpu[i], + draft_tokens_cpu[i], + req.grammar, + allocate_token_bitmask[ + i * num_draft_tokens : (i + 1) * num_draft_tokens + ], + ) + + verify_input.grammar = grammar + return allocate_token_bitmask diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index ebbff0e8f..647fafaad 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -31,6 +31,7 @@ from sglang.srt.speculative.eagle_utils import ( EagleVerifyInput, EagleVerifyOutput, assign_draft_cache_locs, + generate_token_bitmask, select_top_k_tokens, ) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -492,11 +493,41 @@ class EAGLEWorker(TpModelWorker): batch.forward_mode = ForwardMode.TARGET_VERIFY batch.spec_info = spec_info model_worker_batch = batch.get_model_worker_batch() + + if batch.has_grammar: + retrieve_next_token_cpu = spec_info.retrive_next_token.cpu() + retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu() + draft_tokens_cpu = spec_info.draft_token.view( + spec_info.retrive_next_token.shape + ).cpu() + + # Forward logits_output, _, can_run_cuda_graph = ( self.target_worker.forward_batch_generation( model_worker_batch, skip_sample=True ) ) + + vocab_mask = None + if batch.has_grammar: + # Generate the logit mask for structured output. + # Overlap the CPU operations for bitmask generation with the forward pass. + vocab_mask = generate_token_bitmask( + batch.reqs, + spec_info, + retrieve_next_token_cpu, + retrieve_next_sibling_cpu, + draft_tokens_cpu, + batch.sampling_info.vocab_size, + ) + + if vocab_mask is not None: + assert spec_info.grammar is not None + vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device) + # otherwise, this vocab mask will be the one from the previous extend stage + # and will be applied to produce wrong results + batch.sampling_info.vocab_mask = None + self._detect_nan_if_needed(logits_output) spec_info.hidden_states = logits_output.hidden_states res: EagleVerifyOutput = spec_info.verify( @@ -504,6 +535,7 @@ class EAGLEWorker(TpModelWorker): logits_output, self.token_to_kv_pool_allocator, self.page_size, + vocab_mask, ) # Post process based on verified outputs. diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 7f653777a..4384d2c64 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -481,6 +481,41 @@ class TestEAGLEServer(CustomTestCase): with ThreadPoolExecutor(8) as executor: list(executor.map(self.run_decode, args)) + def test_constrained_decoding(self): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Give me a json"}, + ] + + response = requests.post( + self.base_url + "/v1/chat/completions", + json={ + "model": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + "messages": messages, + "temperature": 0, + "response_format": {"type": "json_object"}, + }, + ) + self.assertEqual(response.status_code, 200) + res = response.json() + + # Validate response structure + self.assertIn("choices", res) + self.assertEqual(len(res["choices"]), 1) + self.assertIn("message", res["choices"][0]) + self.assertIn("content", res["choices"][0]["message"]) + + # Validate JSON content + content_json = res["choices"][0]["message"]["content"] + is_valid_json = True + try: + content = json.loads(content_json) + self.assertIsInstance(content, dict) + except Exception: + print(f"parse JSON failed: {content_json}") + is_valid_json = False + self.assertTrue(is_valid_json) + class TestEAGLERetract(TestEAGLEServer): @classmethod