Fix dependency and error message for xgrammar (#2024)
This commit is contained in:
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""The baseclass of backends for grammar-guided constrained decoding."""
|
"""The baseclass of a backend for grammar-guided constrained decoding."""
|
||||||
|
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|||||||
@@ -22,7 +22,9 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||||||
import interegular
|
import interegular
|
||||||
import torch
|
import torch
|
||||||
from outlines.fsm.guide import RegexGuide
|
from outlines.fsm.guide import RegexGuide
|
||||||
|
from outlines.fsm.json_schema import build_regex_from_schema
|
||||||
from outlines.models.transformers import TransformerTokenizer
|
from outlines.models.transformers import TransformerTokenizer
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from sglang.srt.constrained.base_grammar_backend import (
|
from sglang.srt.constrained.base_grammar_backend import (
|
||||||
BaseGrammarBackend,
|
BaseGrammarBackend,
|
||||||
@@ -33,26 +35,6 @@ from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from outlines.fsm.json_schema import build_regex_from_object
|
|
||||||
except ImportError:
|
|
||||||
# Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
|
|
||||||
# which only accepts string schema as input.
|
|
||||||
from outlines.fsm.json_schema import build_regex_from_schema
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
def build_regex_from_object(
|
|
||||||
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
|
||||||
):
|
|
||||||
if isinstance(object, type(BaseModel)):
|
|
||||||
schema = json.dumps(object.model_json_schema())
|
|
||||||
elif isinstance(object, Dict):
|
|
||||||
schema = json.dumps(object)
|
|
||||||
else:
|
|
||||||
schema = object
|
|
||||||
return build_regex_from_schema(schema, whitespace_pattern)
|
|
||||||
|
|
||||||
|
|
||||||
class OutlinesGrammar(BaseGrammarObject):
|
class OutlinesGrammar(BaseGrammarObject):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -169,3 +151,15 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
|
|||||||
else:
|
else:
|
||||||
jump_forward_map = None
|
jump_forward_map = None
|
||||||
return OutlinesGrammar(guide, jump_forward_map)
|
return OutlinesGrammar(guide, jump_forward_map)
|
||||||
|
|
||||||
|
|
||||||
|
def build_regex_from_object(
|
||||||
|
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
||||||
|
):
|
||||||
|
if isinstance(object, type(BaseModel)):
|
||||||
|
schema = json.dumps(object.model_json_schema())
|
||||||
|
elif isinstance(object, Dict):
|
||||||
|
schema = json.dumps(object)
|
||||||
|
else:
|
||||||
|
schema = object
|
||||||
|
return build_regex_from_schema(schema, whitespace_pattern)
|
||||||
|
|||||||
@@ -19,7 +19,16 @@ import logging
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
|
|
||||||
|
try:
|
||||||
|
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
|
||||||
|
|
||||||
|
import_error = None
|
||||||
|
except ImportError as e:
|
||||||
|
CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
|
||||||
|
ImportError
|
||||||
|
)
|
||||||
|
import_error = e
|
||||||
|
|
||||||
from sglang.srt.constrained.base_grammar_backend import (
|
from sglang.srt.constrained.base_grammar_backend import (
|
||||||
BaseGrammarBackend,
|
BaseGrammarBackend,
|
||||||
@@ -95,10 +104,21 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
if import_error:
|
||||||
|
logger.warning(
|
||||||
|
f"Ignore import error for the grammar backend: {import_error}"
|
||||||
|
)
|
||||||
|
self.grammar_cache = None
|
||||||
|
return
|
||||||
|
|
||||||
self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
|
self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||||
|
if import_error:
|
||||||
|
raise import_error
|
||||||
|
|
||||||
key_type, key_string = key
|
key_type, key_string = key
|
||||||
if key_type == "json":
|
if key_type == "json":
|
||||||
try:
|
try:
|
||||||
@@ -126,4 +146,5 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.grammar_cache.clear()
|
if self.grammar_cache:
|
||||||
|
self.grammar_cache.clear()
|
||||||
|
|||||||
Reference in New Issue
Block a user