Fix grammar backend (#2018)
@@ -13,30 +13,5 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import json
-from typing import Dict, Optional, Union
-
-from pydantic import BaseModel
-
-try:
-    from outlines.fsm.json_schema import build_regex_from_object
-except ImportError:
-    # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
-    # which only accepts string schema as input.
-    from outlines.fsm.json_schema import build_regex_from_schema
-
-    def build_regex_from_object(
-        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
-    ):
-        if isinstance(object, type(BaseModel)):
-            schema = json.dumps(object.model_json_schema())
-        elif isinstance(object, Dict):
-            schema = json.dumps(object)
-        else:
-            schema = object
-        return build_regex_from_schema(schema, whitespace_pattern)
-
-
-__all__ = [
-    "build_regex_from_object",
-]
+# TODO(lmzheng): make this an optional dependency
+from sglang.srt.constrained.outlines_backend import build_regex_from_object
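For orientation, a minimal usage sketch of the re-exported build_regex_from_object shim. The Answer model below is a made-up example, not part of this commit, and the sketch assumes an outlines release (0.0.32 or later) where only build_regex_from_schema exists:

# Hypothetical example; the Answer model is illustrative only.
from pydantic import BaseModel

from sglang.srt.constrained import build_regex_from_object


class Answer(BaseModel):
    name: str
    score: int


# Accepts a Pydantic model class, a dict, or a JSON-schema string and
# returns a regex that matches exactly the JSON documents valid under it.
regex = build_regex_from_object(Answer)
print(regex[:80])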
@@ -95,9 +95,7 @@ class BaseToolCache:
 
     def get_cache_hit_rate(self):
         with self.lock_metrics:
-            if self.metrics["total"] == 0:
-                return 0
-            return self.metrics["hit"] / self.metrics["total"]
+            return self.metrics["hit"] / max(self.metrics["total"], 1)
 
     def get_avg_init_time(self):
         with self.lock_metrics:
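The hunk above folds the zero-total guard into a single expression. A quick sanity check of the equivalence, using hypothetical metrics values and assuming hit never exceeds total (which holds for this cache):

for metrics in ({"hit": 0, "total": 0}, {"hit": 3, "total": 4}):
    old = 0 if metrics["total"] == 0 else metrics["hit"] / metrics["total"]
    new = metrics["hit"] / max(metrics["total"], 1)
    assert old == new  # 0.0 and 0.75 respectively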
@@ -1,182 +0,0 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Cache for the compressed finite state machine."""
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple, Union

import torch

from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide
from sglang.srt.constrained.outlines_jump_forward import (
    OutlinesJumpCache,
    OutlinesJumpForwardMap,
)
from sglang.srt.constrained.xgrammar_cache import (
    GrammarMatcher,
    XGrammarBackend,
    XGrammarJumpCache,
)

logger = logging.getLogger(__name__)


class JumpHelper:

    def __init__(
        self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
    ) -> None:
        self.data: Union[List, str] = data
        self.state: int = state
        self.suffix_ids: List[int] = suffix_ids

    def can_jump(self):
        return len(self.data) > 0


class Grammar:

    def __init__(
        self,
        grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
        jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None],
    ) -> None:
        self.grammar = grammar
        self.jump_map = jump_map

    def accept_token(self, token: int):
        if isinstance(self.grammar, GrammarMatcher):
            assert self.grammar.accept_token(token)
        else:
            guide, state = self.grammar
            self.grammar = guide, guide.get_next_state(state, token)

    def try_jump(self, tokenizer) -> JumpHelper:
        if isinstance(self.jump_map, XGrammarJumpCache):
            assert isinstance(self.grammar, GrammarMatcher)
            return JumpHelper(self.grammar.find_jump_forward_string())
        elif isinstance(self.jump_map, OutlinesJumpForwardMap):
            assert isinstance(self.grammar, Tuple)

            _, state = self.grammar
            jump_forward_bytes = self.jump_map.jump_forward_byte(state)
            if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
                return JumpHelper()  # can't jump

            # preprocess the jump forward string
            suffix_bytes = []
            continuation_range = range(0x80, 0xC0)
            cur_state = state
            while (
                len(jump_forward_bytes)
                and jump_forward_bytes[0][0] in continuation_range
            ):
                # continuation bytes
                byte_edge = jump_forward_bytes.pop(0)
                suffix_bytes.append(byte_edge[0])
                cur_state = byte_edge[1]

            suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
            suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
            return JumpHelper(suffix_ids, cur_state, suffix_bytes)
        else:
            return JumpHelper()  # can't jump

    def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
        if isinstance(helper.data, str):
            return helper.data, -1
        else:
            assert isinstance(self.jump_map, OutlinesJumpForwardMap)
            return self.jump_map.jump_forward_symbol(helper.state)

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        if isinstance(self.grammar, GrammarMatcher):
            k = 0
            for i, old_id in enumerate(old_output_ids):
                if old_id == new_output_ids[i]:
                    k = i + 1
                else:
                    break

            # rollback to the last token that is the same
            if k < len(old_output_ids):
                self.grammar.rollback(len(old_output_ids) - k)

            for i in range(k, len(new_output_ids)):
                assert self.grammar.accept_token(new_output_ids[i])
        else:
            self.grammar = self.grammar[0], next_state

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
        if isinstance(self.grammar, GrammarMatcher):
            # Note that this bitmask is a bitset, not bool
            bitmask = self.grammar.get_next_token_bitmask()
            # Mask the tokens that are not allowed
            vocab_mask[
                self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
            ] = 1
        else:
            guide, state = self.grammar
            vocab_mask.fill_(1)
            vocab_mask[guide.get_next_instruction(state).tokens] = 0


class GrammarBackend:

    def __init__(
        self,
        tokenizer_path,
        tokenizer_args_dict,
        skip_tokenizer_init=False,
        whitespace_patterns=None,
        backend=None,
        allow_jump=False,
    ):
        self.executor = ThreadPoolExecutor()
        self.backend = backend

        if backend == "xgrammar":
            self.grammar_cache = XGrammarBackend(
                tokenizer_path=tokenizer_path,
                tokenizer_args_dict=tokenizer_args_dict,
                skip_tokenizer_init=skip_tokenizer_init,
                whitespace_patterns=whitespace_patterns,
            )
            self.jump_cache = XGrammarJumpCache() if allow_jump else None
        else:
            assert backend == "outlines"
            self.grammar_cache = OutlinesCache(
                tokenizer_path=tokenizer_path,
                tokenizer_args_dict=tokenizer_args_dict,
                skip_tokenizer_init=skip_tokenizer_init,
                constrained_json_whitespace_pattern=whitespace_patterns,
            )
            self.jump_cache = OutlinesJumpCache() if allow_jump else None

    def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
        if isinstance(self.grammar_cache, XGrammarBackend):
            return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
        else:
            guide, regex = self.grammar_cache.query(key)
            jump_map = self.jump_cache.query(regex)
            return Grammar((guide, 0), jump_map)

    def query(self, key: Tuple[str, str], vocab_size: int) -> Future:
        return self.executor.submit(self._query, key, vocab_size)

    def reset(self):
        self.grammar_cache.reset()
        self.jump_cache.reset()
python/sglang/srt/constrained/outlines_backend.py (new file, 203 lines)
@@ -0,0 +1,203 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Constrained decoding with outlines backend."""

import json
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple, Union

import torch
from interegular import InvalidSyntax, parse_pattern
from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel

from sglang.srt.constrained.base_tool_cache import BaseToolCache
from sglang.srt.constrained.outlines_jump_forward import (
    OutlinesJumpForwardCache,
    OutlinesJumpForwardMap,
)

logger = logging.getLogger(__name__)


try:
    from outlines.fsm.json_schema import build_regex_from_object
except ImportError:
    # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
    # which only accepts string schema as input.
    from outlines.fsm.json_schema import build_regex_from_schema

    def build_regex_from_object(
        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
    ):
        if isinstance(object, type(BaseModel)):
            schema = json.dumps(object.model_json_schema())
        elif isinstance(object, Dict):
            schema = json.dumps(object)
        else:
            schema = object
        return build_regex_from_schema(schema, whitespace_pattern)


class OutlinesGrammar:
    def __init__(
        self,
        guide: RegexGuide,
        state: int,
        jump_forward_map: Union[OutlinesJumpForwardMap, None],
    ) -> None:
        self.guide = guide
        self.state = state
        self.jump_forward_map = jump_forward_map

    def accept_token(self, token: int):
        self.state = self.guide.get_next_state(self.state, token)

    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
        if not self.jump_forward_map:
            return None

        jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
        if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
            return None

        # preprocess the jump forward string
        suffix_bytes = []
        continuation_range = range(0x80, 0xC0)
        cur_state = self.state
        while (
            len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
        ):
            # continuation bytes
            byte_edge = jump_forward_bytes.pop(0)
            suffix_bytes.append(byte_edge[0])
            cur_state = byte_edge[1]

        suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
        suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
        return suffix_ids, cur_state

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, cur_state = helper
        return self.jump_forward_map.jump_forward_symbol(cur_state)

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        self.state = next_state

    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
        vocab_mask.fill_(1)
        vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0


class OutlinesGrammarBackend:
    def __init__(
        self,
        tokenizer,
        whitespace_patterns: bool,
        allow_jump_forward: bool,
    ):
        self.executor = ThreadPoolExecutor()
        self.grammar_cache = OutlinesCache(
            tokenizer,
            whitespace_pattern=whitespace_patterns,
        )
        self.jump_forward_cache = (
            OutlinesJumpForwardCache() if allow_jump_forward else None
        )

    def _query(self, key: Tuple[str, str]) -> OutlinesGrammar:
        guide, regex = self.grammar_cache.query(key)
        jump_forward_map = (
            self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
        )
        return OutlinesGrammar(guide, 0, jump_forward_map)

    def query(self, key: Tuple[str, str]) -> Future:
        return self.executor.submit(self._query, key)

    def reset(self):
        self.grammar_cache.reset()
        if self.jump_forward_cache:
            self.jump_forward_cache.reset()


class OutlinesCache(BaseToolCache):
    def __init__(
        self,
        tokenizer,
        whitespace_pattern=None,
    ):
        super().__init__(enable=True)

        try:
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
        except AttributeError:
            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
            origin_pad_token_id = tokenizer.pad_token_id

            def fset(self, value):
                self._value = value

            type(tokenizer).pad_token_id = property(
                fget=type(tokenizer).pad_token_id.fget, fset=fset
            )
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token = (
                self.outlines_tokenizer.tokenizer.pad_token
            )
            self.outlines_tokenizer.vocabulary = (
                self.outlines_tokenizer.tokenizer.get_vocab()
            )
        self.whitespace_pattern = whitespace_pattern

    def init_value(self, key):
        key_type, key_string = key
        if key_type == "json":
            try:
                regex = build_regex_from_object(
                    key_string,
                    whitespace_pattern=self.whitespace_pattern,
                )
            except NotImplementedError as e:
                logger.warning(
                    f"skip invalid json schema: json_schema={key_string}, {e=}"
                )
                return None, key_string
        elif key_type == "regex":
            regex = key_string
        else:
            raise ValueError(f"Invalid key_type: {key_type}")
        try:
            parse_pattern(regex)
        except InvalidSyntax as e:
            logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
            return None, regex

        ret = RegexGuide(regex, self.outlines_tokenizer), regex
        return ret

    def _query(self, key: Tuple[str, str]):
        guide, regex = self.grammar_cache.query(key)
        jump_forward_map = (
            self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
        )
        return OutlinesGrammar(guide, 0, jump_forward_map)
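For context, a rough sketch of how the new outlines backend is meant to be driven during decoding, based only on the methods shown above. The tokenizer name and the sampling-side lines are placeholders, not code from this commit:

# Hypothetical driver; the "gpt2" tokenizer and sampling wiring are illustrative only.
import torch
from transformers import AutoTokenizer

from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

tokenizer = AutoTokenizer.from_pretrained("gpt2")
backend = OutlinesGrammarBackend(
    tokenizer, whitespace_patterns=None, allow_jump_forward=True
)

# Compilation happens off-thread; query() returns a Future of OutlinesGrammar.
grammar = backend.query(("regex", r"(yes|no)")).result()

vocab_mask = torch.zeros(len(tokenizer), dtype=torch.bool)
grammar.fill_vocab_mask(vocab_mask)   # 1 marks tokens the grammar disallows
# logits[vocab_mask] = float("-inf")  # assumed sampling-side use, then:
# grammar.accept_token(sampled_token_id)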
@@ -1,96 +0,0 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Cache for the compressed finite state machine."""
import logging

from interegular import InvalidSyntax, parse_pattern
from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
from transformers import AutoTokenizer

from sglang.srt.constrained import build_regex_from_object
from sglang.srt.constrained.base_tool_cache import BaseToolCache

logger = logging.getLogger(__name__)


class OutlinesCache(BaseToolCache):
    def __init__(
        self,
        tokenizer_path,
        tokenizer_args_dict,
        enable=True,
        skip_tokenizer_init=False,
        constrained_json_whitespace_pattern=None,
    ):
        super().__init__(enable=enable)

        if (
            skip_tokenizer_init
            or tokenizer_path.endswith(".json")
            or tokenizer_path.endswith(".model")
        ):
            # Do not support TiktokenTokenizer or SentencePieceTokenizer
            return

        tokenizer_args_dict.setdefault("padding_side", "left")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
        try:
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
        except AttributeError:
            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
            origin_pad_token_id = tokenizer.pad_token_id

            def fset(self, value):
                self._value = value

            type(tokenizer).pad_token_id = property(
                fget=type(tokenizer).pad_token_id.fget, fset=fset
            )
            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
            self.outlines_tokenizer.pad_token = (
                self.outlines_tokenizer.tokenizer.pad_token
            )
            self.outlines_tokenizer.vocabulary = (
                self.outlines_tokenizer.tokenizer.get_vocab()
            )
        self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern

    def init_value(self, key):
        key_type, key_string = key
        if key_type == "json":
            try:
                regex = build_regex_from_object(
                    key_string,
                    whitespace_pattern=self.constrained_json_whitespace_pattern,
                )
            except NotImplementedError as e:
                logger.warning(
                    f"skip invalid json schema: json_schema={key_string}, {e=}"
                )
                return None, key_string
        elif key_type == "regex":
            regex = key_string
        else:
            raise ValueError(f"Invalid key_type: {key_type}")
        try:
            parse_pattern(regex)
        except InvalidSyntax as e:
            logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
            return None, regex
        return RegexGuide(regex, self.outlines_tokenizer), regex
@@ -164,7 +164,7 @@ class OutlinesJumpForwardMap:
 )
 
 
-class OutlinesJumpCache(BaseToolCache):
+class OutlinesJumpForwardCache(BaseToolCache):
     def __init__(self):
         super().__init__()
 
python/sglang/srt/constrained/xgrammar_backend.py (new file, 127 lines)
@@ -0,0 +1,127 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Constrained decoding with xgrammar backend."""

from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple

import torch

try:
    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher

    import_error = None
except ImportError as e:
    import_error = e

    class Dummy:
        pass

    GrammarMatcher = CompiledGrammar = CachedGrammarCompiler = Dummy


MAX_ROLLBACK_TOKENS = 10


class XGrammarGrammar:

    def __init__(self, matcher: GrammarMatcher, vocab_size: int) -> None:
        self.matcher = matcher
        self.vocab_size = vocab_size

    def accept_token(self, token: int):
        assert self.matcher.accept_token(token)

    def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
        return [], self.matcher.find_jump_forward_string()

    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
        _, data = helper
        return data, -1

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ):
        k = 0
        for i, old_id in enumerate(old_output_ids):
            if old_id == new_output_ids[i]:
                k = i + 1
            else:
                break

        # rollback to the last token that is the same
        if k < len(old_output_ids):
            self.matcher.rollback(len(old_output_ids) - k)

        for i in range(k, len(new_output_ids)):
            assert self.matcher.accept_token(new_output_ids[i])

    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
        # Note that this bitmask is a bitset, not bool
        bitmask = self.matcher.get_next_token_bitmask()
        # Mask the tokens that are not allowed
        vocab_mask[
            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
        ] = 1


class XGrammarGrammarBackend:
    def __init__(
        self,
        tokenizer,
        vocab_size: int,
    ):
        if import_error:
            raise import_error

        self.executor = ThreadPoolExecutor()
        self.grammar_cache = XGrammarCache(tokenizer, vocab_size)
        self.vocab_size = vocab_size

    def _query(self, key: Tuple[str, str]) -> XGrammarGrammar:
        return XGrammarGrammar(self.grammar_cache.query(key), self.vocab_size)

    def query(self, key: Tuple[str, str]) -> Future:
        return self.executor.submit(self._query, key)

    def reset(self):
        self.grammar_cache.reset()


class XGrammarCache:
    def __init__(self, tokenizer, vocab_size: int):
        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
        self.vocab_size = vocab_size

    def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
        key_type, key_string = key
        if key_type == "json":
            return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
        elif key_type == "regex":
            raise ValueError("regex hasn't been supported by xgrammar yet")
        else:
            raise ValueError(f"Invalid key_type: {key_type}")

    def query(self, key: Tuple[str, str]) -> GrammarMatcher:
        ctx = self.get_context(key)
        return GrammarMatcher(
            ctx,
            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
            mask_vocab_size=self.vocab_size,
        )

    def reset(self):
        self.grammar_cache.clear()
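Similarly, a hedged sketch of how the xgrammar path is intended to be exercised, based on the methods shown above. The tokenizer name, the JSON schema, and the way the mask is applied to logits are assumptions for illustration, not code from this commit, and the sketch assumes the xgrammar package exposing the API used above is installed:

# Hypothetical usage; "gpt2" and the schema are placeholders.
import torch
from transformers import AutoTokenizer

from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend

tokenizer = AutoTokenizer.from_pretrained("gpt2")
vocab_size = len(tokenizer)
backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)

schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
grammar = backend.query(("json", schema)).result()

vocab_mask = torch.zeros(vocab_size, dtype=torch.bool)
grammar.fill_vocab_mask(vocab_mask)  # sets 1 for tokens the grammar rejects
# logits.masked_fill_(vocab_mask, float("-inf"))  # assumed sampling-side use
# grammar.accept_token(next_token_id)             # advance the matcher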
@@ -1,75 +0,0 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Cache for the compressed finite state machine."""

from typing import Tuple

from transformers import AutoTokenizer

try:
    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
except ImportError as e:

    class Dummy:
        pass

    GrammarMatcher = Dummy
    CompiledGrammar = Dummy
    CachedGrammarCompiler = Dummy


MAX_ROLLBACK_TOKENS = 10


class XGrammarJumpCache:
    """A dummy class."""

    def reset(self):
        pass


class XGrammarBackend:
    def __init__(
        self,
        tokenizer_path,
        tokenizer_args_dict,
        skip_tokenizer_init=False,
        whitespace_patterns=None,
    ):
        # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
        if skip_tokenizer_init:
            return

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
        self.grammar_cache: CachedGrammarCompiler = CachedGrammarCompiler(
            tokenizer_or_vocab=tokenizer
        )

    def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
        key_type, key_string = key
        if key_type == "json":
            return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
        elif key_type == "regex":
            raise ValueError("regex hasn't been supported by xgrammar yet")
        else:
            raise ValueError(f"Invalid key_type: {key_type}")

    def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
        ctx = self.get_context(key)
        return GrammarMatcher(
            ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
        )

    def reset(self):
        self.grammar_cache.clear()