diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py
index cc68b97f8..831c1d1a9 100644
--- a/python/sglang/srt/constrained/outlines_backend.py
+++ b/python/sglang/srt/constrained/outlines_backend.py
@@ -81,10 +81,20 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state
 
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
         vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
 
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))
+
     def copy(self):
         return OutlinesGrammar(self.guide, self.jump_forward_map)
 
diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py
index ab4df5c98..acaae10c0 100644
--- a/python/sglang/srt/constrained/xgrammar_backend.py
+++ b/python/sglang/srt/constrained/xgrammar_backend.py
@@ -21,7 +21,12 @@ from typing import List, Tuple
 import torch
 
 try:
-    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
+    from xgrammar import (
+        CachedGrammarCompiler,
+        CompiledGrammar,
+        GrammarMatcher,
+        TokenizerInfo,
+    )
 
     import_error = None
 except ImportError as e:
@@ -80,19 +85,23 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])
 
-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
-        # Note that this bitmask is a bitset, not bool
-        bitmask = self.matcher.get_next_token_bitmask()
-        # Mask the tokens that are not allowed
-        vocab_mask[
-            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
-        ] = 1
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
         matcher = GrammarMatcher(
             self.ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
 
@@ -112,7 +121,8 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
             self.grammar_cache = None
             return
 
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
+        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
+        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
 
     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
@@ -122,9 +132,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         key_type, key_string = key
         if key_type == "json":
             try:
-                ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(
-                    key_string
-                )
+                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
             except RuntimeError as e:
                 logging.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
@@ -141,7 +149,7 @@
         matcher = GrammarMatcher(
             ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 02750d5df..8096fec5a 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -645,7 +645,7 @@ class ModelRunner:
 
         # Apply regex vocab_mask
         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
 
         return logits
 
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index a341c2b17..61aa341fd 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import dataclasses
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
 
@@ -29,7 +29,7 @@ class SamplingBatchInfo:
     vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
-
+    apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
     grammars: Optional[List] = None
 
     # Penalizer
@@ -135,17 +135,23 @@ class SamplingBatchInfo:
     def update_regex_vocab_mask(self):
         if not self.grammars or not any(grammar for grammar in self.grammars):
             self.vocab_mask = None
+            self.apply_mask = None
             return
 
-        self.vocab_mask = torch.zeros(
-            len(self.temperatures),
-            self.vocab_size,
-            dtype=torch.bool,
+        # find a grammar from the list
+        grammar = next(grammar for grammar in self.grammars if grammar is not None)
+
+        # maybe we can reuse the existing mask?
+        self.vocab_mask = grammar.allocate_vocab_mask(
+            vocab_size=self.vocab_size,
+            batch_size=len(self.temperatures),
             device=self.device,
         )
+        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+
         for i, grammar in enumerate(self.grammars):
             if grammar is not None:
-                grammar.fill_vocab_mask(self.vocab_mask[i])
+                grammar.fill_vocab_mask(self.vocab_mask, i)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
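
For context, a minimal sketch of how the new three-step mask protocol from this diff fits together end to end. `DummyGrammar`, the vocab size, and the batch contents below are illustrative stand-ins, not part of the diff; the class follows the Outlines-style bool-mask convention (`True` means the token is rejected):

```python
import torch


# Hypothetical stand-in for a BaseGrammarObject backend; it mirrors the
# Outlines-style bool-mask convention from the diff (True = token rejected).
class DummyGrammar:
    def __init__(self, allowed_tokens):
        self.allowed_tokens = allowed_tokens

    def allocate_vocab_mask(self, vocab_size, batch_size, device):
        # One mask row per request in the batch.
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask, idx):
        row = vocab_mask[idx]                # view into this request's row
        row.fill_(True)                      # reject everything by default
        row[self.allowed_tokens] = False     # re-enable the legal tokens

    @staticmethod
    def apply_vocab_mask(logits, vocab_mask):
        logits.masked_fill_(vocab_mask, float("-inf"))


# A 3-request batch; None marks an unconstrained request whose mask row
# stays all-False, so its logits are left untouched.
grammars = [DummyGrammar([1, 3]), None, DummyGrammar([0])]
logits = torch.randn(len(grammars), 8)

# Step 1: allocate one mask for the whole batch (as update_regex_vocab_mask does).
grammar = next(g for g in grammars if g is not None)
vocab_mask = grammar.allocate_vocab_mask(vocab_size=8, batch_size=len(grammars), device="cpu")

# Step 2: each constrained request fills its own row.
for i, g in enumerate(grammars):
    if g is not None:
        g.fill_vocab_mask(vocab_mask, i)

# Step 3: apply once over the batch via the backend's static method.
type(grammar).apply_vocab_mask(logits, vocab_mask)
print(logits)  # rows 0 and 2 are -inf everywhere except their allowed tokens
```

Keeping `allocate`/`fill`/`apply` together on the grammar object lets each backend pick its own mask representation (a bool mask for Outlines, a packed token bitmask for XGrammar) while `SamplingBatchInfo` and `ModelRunner` stay backend-agnostic and just call `apply_mask`.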