Fix dependency and error message for xgrammar (#2024)
This commit is contained in:
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""The baseclass of backends for grammar-guided constrained decoding."""
|
||||
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
||||
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -22,7 +22,9 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
import interegular
|
||||
import torch
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
from outlines.models.transformers import TransformerTokenizer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
BaseGrammarBackend,
|
||||
@@ -33,26 +35,6 @@ from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
    from outlines.fsm.json_schema import build_regex_from_object
except ImportError:
    # outlines >= 0.0.32 replaced build_regex_from_object with
    # build_regex_from_schema, which accepts only a string schema, so
    # rebuild the old entry point on top of the new function.
    from outlines.fsm.json_schema import build_regex_from_schema
    from pydantic import BaseModel

    def build_regex_from_object(
        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
    ):
        """Back-compat shim: build a regex from a schema given as a plain
        JSON-schema string, a pydantic model class, or a dict."""
        if isinstance(object, type(BaseModel)):
            # A pydantic model class: serialize its JSON schema to a string.
            return build_regex_from_schema(
                json.dumps(object.model_json_schema()), whitespace_pattern
            )
        if isinstance(object, Dict):
            return build_regex_from_schema(json.dumps(object), whitespace_pattern)
        # Already a string schema.
        return build_regex_from_schema(object, whitespace_pattern)
|
||||
|
||||
|
||||
class OutlinesGrammar(BaseGrammarObject):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -169,3 +151,15 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
||||
else:
|
||||
jump_forward_map = None
|
||||
return OutlinesGrammar(guide, jump_forward_map)
|
||||
|
||||
|
||||
def build_regex_from_object(
    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
):
    """Build a constrained-decoding regex from a schema.

    Accepts the schema as a plain JSON-schema string, a pydantic model
    class, or a dict; normalizes it to a string before delegating to
    outlines' build_regex_from_schema.
    """
    if isinstance(object, type(BaseModel)):
        # Pydantic model class: dump its JSON schema to a string first.
        return build_regex_from_schema(
            json.dumps(object.model_json_schema()), whitespace_pattern
        )
    if isinstance(object, Dict):
        return build_regex_from_schema(json.dumps(object), whitespace_pattern)
    # Already a string schema.
    return build_regex_from_schema(object, whitespace_pattern)
|
||||
|
||||
@@ -19,7 +19,16 @@ import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
|
||||
|
||||
try:
    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
except ImportError as e:
    # Keep this module importable without xgrammar installed; the backend
    # surfaces the saved error later instead of failing at import time.
    # NOTE(review): TokenizerInfo gets a placeholder here but is not imported
    # in the try branch — confirm it is imported/used elsewhere in this file.
    import_error = e
    CachedGrammarCompiler = ImportError
    CompiledGrammar = ImportError
    GrammarMatcher = ImportError
    TokenizerInfo = ImportError
else:
    import_error = None
|
||||
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
BaseGrammarBackend,
|
||||
@@ -95,10 +104,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if import_error:
|
||||
logger.warning(
|
||||
f"Ignore import error for the grammar backend: {import_error}"
|
||||
)
|
||||
self.grammar_cache = None
|
||||
return
|
||||
|
||||
self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||
if import_error:
|
||||
raise import_error
|
||||
|
||||
key_type, key_string = key
|
||||
if key_type == "json":
|
||||
try:
|
||||
@@ -126,4 +146,5 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
||||
|
||||
def reset(self):
    """Clear the compiled-grammar cache.

    self.grammar_cache is None when the xgrammar import failed (the
    constructor logs a warning and returns early in that case), so guard
    before clearing to avoid an AttributeError on None.
    """
    if self.grammar_cache:
        self.grammar_cache.clear()
|
||||
|
||||
Reference in New Issue
Block a user