From c722d9bdc30e9730f82f6d646c171c43a4837e12 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 13 Nov 2024 14:04:25 -0800 Subject: [PATCH] Fix dependency and error message for xgrammar (#2024) --- .../srt/constrained/base_grammar_backend.py | 2 +- .../srt/constrained/outlines_backend.py | 34 ++++++++----------- .../srt/constrained/xgrammar_backend.py | 25 ++++++++++++-- 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index d1192685e..5534667e4 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ -"""The baseclass of backends for grammar-guided constrained decoding.""" +"""The baseclass of a backend for grammar-guided constrained decoding.""" from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py index 4999e8dbd..cc68b97f8 100644 --- a/python/sglang/srt/constrained/outlines_backend.py +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -22,7 +22,9 @@ from typing import Dict, List, Optional, Tuple, Union import interegular import torch from outlines.fsm.guide import RegexGuide +from outlines.fsm.json_schema import build_regex_from_schema from outlines.models.transformers import TransformerTokenizer +from pydantic import BaseModel from sglang.srt.constrained.base_grammar_backend import ( BaseGrammarBackend, @@ -33,26 +35,6 @@ from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap logger = logging.getLogger(__name__) -try: - from outlines.fsm.json_schema import build_regex_from_object -except ImportError: - # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema, - # which only accepts string schema as input. - from outlines.fsm.json_schema import build_regex_from_schema - from pydantic import BaseModel - - def build_regex_from_object( - object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None - ): - if isinstance(object, type(BaseModel)): - schema = json.dumps(object.model_json_schema()) - elif isinstance(object, Dict): - schema = json.dumps(object) - else: - schema = object - return build_regex_from_schema(schema, whitespace_pattern) - - class OutlinesGrammar(BaseGrammarObject): def __init__( self, @@ -169,3 +151,15 @@ class OutlinesGrammarBackend(BaseGrammarBackend): else: jump_forward_map = None return OutlinesGrammar(guide, jump_forward_map) + + +def build_regex_from_object( + object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None +): + if isinstance(object, type(BaseModel)): + schema = json.dumps(object.model_json_schema()) + elif isinstance(object, Dict): + schema = json.dumps(object) + else: + schema = object + return build_regex_from_schema(schema, whitespace_pattern) diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index c36ae00b4..ab4df5c98 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -19,7 +19,16 @@ import logging from typing import List, Tuple import torch -from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher + +try: + from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher + + import_error = None +except ImportError as e: + CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = ( + ImportError + ) + import_error = e from sglang.srt.constrained.base_grammar_backend import ( BaseGrammarBackend, @@ -95,10 +104,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend): vocab_size: int, ): super().__init__() + + if import_error: + logger.warning( + f"Ignore import error for the grammar backend: {import_error}" + ) + self.grammar_cache = None + return + self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer) self.vocab_size = vocab_size def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: + if import_error: + raise import_error + key_type, key_string = key if key_type == "json": try: @@ -126,4 +146,5 @@ class XGrammarGrammarBackend(BaseGrammarBackend): return XGrammarGrammar(matcher, self.vocab_size, ctx) def reset(self): - self.grammar_cache.clear() + if self.grammar_cache: + self.grammar_cache.clear()