From 7f076c2ce6d2de2625233b98c4b6990d24d09b66 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Mon, 25 Nov 2024 18:58:30 -0500 Subject: [PATCH] Update XGrammar to the latest API (#2176) Co-authored-by: Ben Gitter --- python/pyproject.toml | 2 +- .../srt/constrained/xgrammar_backend.py | 67 +++++++------------ test/srt/test_json_constrained.py | 37 +++++++++- 3 files changed, 61 insertions(+), 45 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index a47426427..4b6da31b3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -22,7 +22,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop", - "modelscope", "xgrammar"] + "modelscope", "xgrammar==0.1.4"] srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"] # HIP (Heterogeneous-computing Interface for Portability) for AMD diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 9195aa30d..1bcc51c64 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -17,21 +17,14 @@ import logging from typing import List, Tuple import torch - -try: - from xgrammar import ( - CachedGrammarCompiler, - CompiledGrammar, - GrammarMatcher, - TokenizerInfo, - ) - - import_error = None -except ImportError as e: - CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = ( - ImportError - ) - import_error = e +from xgrammar import ( + CompiledGrammar, + GrammarCompiler, + GrammarMatcher, + TokenizerInfo, + allocate_token_bitmask, + apply_token_bitmask_inplace, +) from sglang.srt.constrained.base_grammar_backend import ( BaseGrammarBackend, @@ -41,7 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import ( logger = logging.getLogger(__name__) -MAX_ROLLBACK_TOKENS = 10 +MAX_ROLLBACK_TOKENS = 200 class XGrammarGrammar(BaseGrammarObject): @@ -86,21 +79,22 @@ class XGrammarGrammar(BaseGrammarObject): def allocate_vocab_mask( self, vocab_size: int, batch_size: int, device ) -> torch.Tensor: - return self.matcher.allocate_token_bitmask(vocab_size, batch_size) + return allocate_token_bitmask(batch_size, vocab_size) def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: self.matcher.fill_next_token_bitmask(vocab_mask, idx) @staticmethod def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: - GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask) + if vocab_mask.device.type != logits.device.type: + # vocab_mask must then be on the same device as logits + # when applying the token bitmask, so we check and move if needed + vocab_mask = vocab_mask.to(logits.device) + + apply_token_bitmask_inplace(logits, vocab_mask) def copy(self): - matcher = GrammarMatcher( - self.ctx, - max_rollback_tokens=MAX_ROLLBACK_TOKENS, - vocab_size=self.vocab_size, - ) + matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) return XGrammarGrammar(matcher, self.vocab_size, self.ctx) @@ -112,25 +106,18 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ): super().__init__() - if import_error: - logger.warning( - f"Ignore import error for the grammar backend: {import_error}" - ) - self.grammar_cache = None - return - - tokenizer_info = TokenizerInfo.from_huggingface(tokenizer) - self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info) + tokenizer_info = TokenizerInfo.from_huggingface( + tokenizer, vocab_size=vocab_size + ) + self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info) self.vocab_size = vocab_size def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: - if import_error: - raise import_error key_type, key_string = key if key_type == "json": try: - ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string) + ctx = self.grammar_compiler.compile_json_schema(schema=key_string) except RuntimeError as e: logging.warning( f"Skip invalid json_schema: json_schema={key_string}, {e=}" @@ -144,13 +131,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend): else: raise ValueError(f"Invalid key_type: {key_type}") - matcher = GrammarMatcher( - ctx, - max_rollback_tokens=MAX_ROLLBACK_TOKENS, - vocab_size=self.vocab_size, - ) + matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS) return XGrammarGrammar(matcher, self.vocab_size, ctx) def reset(self): - if self.grammar_cache: - self.grammar_cache.clear() + if self.grammar_compiler: + self.grammar_compiler.clear_cache() diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 2d08d6684..ae27b036f 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -17,7 +17,7 @@ from sglang.test.test_utils import ( ) -class TestJSONConstrained(unittest.TestCase): +class TestJSONConstrainedOutlinesBackend(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST @@ -36,7 +36,12 @@ class TestJSONConstrained(unittest.TestCase): cls.model, cls.base_url, timeout=300, - other_args=["--max-running-requests", "10"], + other_args=[ + "--max-running-requests", + "10", + "--grammar-backend", + "outlines", + ], ) @classmethod @@ -121,5 +126,33 @@ class TestJSONConstrained(unittest.TestCase): list(executor.map(self.run_decode, json_schemas)) +class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.json_schema = json.dumps( + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], + } + ) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=300, + other_args=[ + "--max-running-requests", + "10", + "--grammar-backend", + "xgrammar", + ], + ) + + if __name__ == "__main__": unittest.main()