Fix grammar backend (#2018)

This commit is contained in:
Lianmin Zheng
2024-11-12 21:17:38 -08:00
committed by GitHub
parent 125b1199c5
commit ba069a24d3
13 changed files with 401 additions and 434 deletions

View File

@@ -13,30 +13,5 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from typing import Dict, Optional, Union
from pydantic import BaseModel
try:
from outlines.fsm.json_schema import build_regex_from_object
except ImportError:
# Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
# which only accepts string schema as input.
from outlines.fsm.json_schema import build_regex_from_schema
def build_regex_from_object(
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
):
if isinstance(object, type(BaseModel)):
schema = json.dumps(object.model_json_schema())
elif isinstance(object, Dict):
schema = json.dumps(object)
else:
schema = object
return build_regex_from_schema(schema, whitespace_pattern)
__all__ = [
"build_regex_from_object",
]
# TODO(lmzheng): make this an optional dependency
from sglang.srt.constrained.outlines_backend import build_regex_from_object

View File

@@ -95,9 +95,7 @@ class BaseToolCache:
def get_cache_hit_rate(self):
with self.lock_metrics:
if self.metrics["total"] == 0:
return 0
return self.metrics["hit"] / self.metrics["total"]
return self.metrics["hit"] / max(self.metrics["total"], 1)
def get_avg_init_time(self):
with self.lock_metrics:

View File

@@ -1,182 +0,0 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Cache for the compressed finite state machine."""
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple, Union
import torch
from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide
from sglang.srt.constrained.outlines_jump_forward import (
OutlinesJumpCache,
OutlinesJumpForwardMap,
)
from sglang.srt.constrained.xgrammar_cache import (
GrammarMatcher,
XGrammarBackend,
XGrammarJumpCache,
)
logger = logging.getLogger(__name__)
class JumpHelper:
def __init__(
self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
) -> None:
self.data: Union[List, str] = data
self.state: int = state
self.suffix_ids: List[int] = suffix_ids
def can_jump(self):
return len(self.data) > 0
class Grammar:
def __init__(
self,
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None],
) -> None:
self.grammar = grammar
self.jump_map = jump_map
def accept_token(self, token: int):
if isinstance(self.grammar, GrammarMatcher):
assert self.grammar.accept_token(token)
else:
guide, state = self.grammar
self.grammar = guide, guide.get_next_state(state, token)
def try_jump(self, tokenizer) -> JumpHelper:
if isinstance(self.jump_map, XGrammarJumpCache):
assert isinstance(self.grammar, GrammarMatcher)
return JumpHelper(self.grammar.find_jump_forward_string())
elif isinstance(self.jump_map, OutlinesJumpForwardMap):
assert isinstance(self.grammar, Tuple)
_, state = self.grammar
jump_forward_bytes = self.jump_map.jump_forward_byte(state)
if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
return JumpHelper() # can't jump
# preprocess the jump forward string
suffix_bytes = []
continuation_range = range(0x80, 0xC0)
cur_state = state
while (
len(jump_forward_bytes)
and jump_forward_bytes[0][0] in continuation_range
):
# continuation bytes
byte_edge = jump_forward_bytes.pop(0)
suffix_bytes.append(byte_edge[0])
cur_state = byte_edge[1]
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
return JumpHelper(suffix_ids, cur_state, suffix_bytes)
else:
return JumpHelper() # can't jump
def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
if isinstance(helper.data, str):
return helper.data, -1
else:
assert isinstance(self.jump_map, OutlinesJumpForwardMap)
return self.jump_map.jump_forward_symbol(helper.state)
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
if isinstance(self.grammar, GrammarMatcher):
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == new_output_ids[i]:
k = i + 1
else:
break
# rollback to the last token that is the same
if k < len(old_output_ids):
self.grammar.rollback(len(old_output_ids) - k)
for i in range(k, len(new_output_ids)):
assert self.grammar.accept_token(new_output_ids[i])
else:
self.grammar = self.grammar[0], next_state
def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
if isinstance(self.grammar, GrammarMatcher):
# Note that this bitmask is a bitset, not bool
bitmask = self.grammar.get_next_token_bitmask()
# Mask the tokens that are not allowed
vocab_mask[
self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
] = 1
else:
guide, state = self.grammar
vocab_mask.fill_(1)
vocab_mask[guide.get_next_instruction(state).tokens] = 0
class GrammarBackend:
def __init__(
self,
tokenizer_path,
tokenizer_args_dict,
skip_tokenizer_init=False,
whitespace_patterns=None,
backend=None,
allow_jump=False,
):
self.executor = ThreadPoolExecutor()
self.backend = backend
if backend == "xgrammar":
self.grammar_cache = XGrammarBackend(
tokenizer_path=tokenizer_path,
tokenizer_args_dict=tokenizer_args_dict,
skip_tokenizer_init=skip_tokenizer_init,
whitespace_patterns=whitespace_patterns,
)
self.jump_cache = XGrammarJumpCache() if allow_jump else None
else:
assert backend == "outlines"
self.grammar_cache = OutlinesCache(
tokenizer_path=tokenizer_path,
tokenizer_args_dict=tokenizer_args_dict,
skip_tokenizer_init=skip_tokenizer_init,
constrained_json_whitespace_pattern=whitespace_patterns,
)
self.jump_cache = OutlinesJumpCache() if allow_jump else None
def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
if isinstance(self.grammar_cache, XGrammarBackend):
return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
else:
guide, regex = self.grammar_cache.query(key)
jump_map = self.jump_cache.query(regex)
return Grammar((guide, 0), jump_map)
def query(self, key: Tuple[str, str], vocab_size: int) -> Future:
return self.executor.submit(self._query, key, vocab_size)
def reset(self):
self.grammar_cache.reset()
self.jump_cache.reset()

View File

@@ -0,0 +1,203 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Constrained decoding with outlines backend."""
import json
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple, Union
import torch
from interegular import InvalidSyntax, parse_pattern
from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
from sglang.srt.constrained.base_tool_cache import BaseToolCache
from sglang.srt.constrained.outlines_jump_forward import (
OutlinesJumpForwardCache,
OutlinesJumpForwardMap,
)
logger = logging.getLogger(__name__)
try:
from outlines.fsm.json_schema import build_regex_from_object
except ImportError:
# Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
# which only accepts string schema as input.
from outlines.fsm.json_schema import build_regex_from_schema
def build_regex_from_object(
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
):
if isinstance(object, type(BaseModel)):
schema = json.dumps(object.model_json_schema())
elif isinstance(object, Dict):
schema = json.dumps(object)
else:
schema = object
return build_regex_from_schema(schema, whitespace_pattern)
class OutlinesGrammar:
def __init__(
self,
guide: RegexGuide,
state: int,
jump_forward_map: Union[OutlinesJumpForwardMap, None],
) -> None:
self.guide = guide
self.state = state
self.jump_forward_map = jump_forward_map
def accept_token(self, token: int):
self.state = self.guide.get_next_state(self.state, token)
def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
if not self.jump_forward_map:
return None
jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
return None
# preprocess the jump forward string
suffix_bytes = []
continuation_range = range(0x80, 0xC0)
cur_state = self.state
while (
len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
):
# continuation bytes
byte_edge = jump_forward_bytes.pop(0)
suffix_bytes.append(byte_edge[0])
cur_state = byte_edge[1]
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
return suffix_ids, cur_state
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
_, cur_state = helper
return self.jump_forward_map.jump_forward_symbol(cur_state)
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
self.state = next_state
def fill_vocab_mask(self, vocab_mask: torch.Tensor):
vocab_mask.fill_(1)
vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
class OutlinesGrammarBackend:
def __init__(
self,
tokenizer,
whitespace_patterns: bool,
allow_jump_forward: bool,
):
self.executor = ThreadPoolExecutor()
self.grammar_cache = OutlinesCache(
tokenizer,
whitespace_pattern=whitespace_patterns,
)
self.jump_forward_cache = (
OutlinesJumpForwardCache() if allow_jump_forward else None
)
def _query(self, key: Tuple[str, str]) -> OutlinesGrammar:
guide, regex = self.grammar_cache.query(key)
jump_forward_map = (
self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
)
return OutlinesGrammar(guide, 0, jump_forward_map)
def query(self, key: Tuple[str, str]) -> Future:
return self.executor.submit(self._query, key)
def reset(self):
self.grammar_cache.reset()
if self.jump_forward_cache:
self.jump_forward_cache.reset()
class OutlinesCache(BaseToolCache):
def __init__(
self,
tokenizer,
whitespace_pattern=None,
):
super().__init__(enable=True)
try:
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
except AttributeError:
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
origin_pad_token_id = tokenizer.pad_token_id
def fset(self, value):
self._value = value
type(tokenizer).pad_token_id = property(
fget=type(tokenizer).pad_token_id.fget, fset=fset
)
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token = (
self.outlines_tokenizer.tokenizer.pad_token
)
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
self.whitespace_pattern = whitespace_pattern
def init_value(self, key):
key_type, key_string = key
if key_type == "json":
try:
regex = build_regex_from_object(
key_string,
whitespace_pattern=self.whitespace_pattern,
)
except NotImplementedError as e:
logger.warning(
f"skip invalid json schema: json_schema={key_string}, {e=}"
)
return None, key_string
elif key_type == "regex":
regex = key_string
else:
raise ValueError(f"Invalid key_type: {key_type}")
try:
parse_pattern(regex)
except InvalidSyntax as e:
logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
return None, regex
ret = RegexGuide(regex, self.outlines_tokenizer), regex
return ret
def _query(self, key: Tuple[str, str]):
guide, regex = self.grammar_cache.query(key)
jump_forward_map = (
self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
)
return OutlinesGrammar(guide, 0, jump_forward_map)

View File

@@ -1,96 +0,0 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Cache for the compressed finite state machine."""
import logging
from interegular import InvalidSyntax, parse_pattern
from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
from transformers import AutoTokenizer
from sglang.srt.constrained import build_regex_from_object
from sglang.srt.constrained.base_tool_cache import BaseToolCache
logger = logging.getLogger(__name__)
class OutlinesCache(BaseToolCache):
def __init__(
self,
tokenizer_path,
tokenizer_args_dict,
enable=True,
skip_tokenizer_init=False,
constrained_json_whitespace_pattern=None,
):
super().__init__(enable=enable)
if (
skip_tokenizer_init
or tokenizer_path.endswith(".json")
or tokenizer_path.endswith(".model")
):
# Do not support TiktokenTokenizer or SentencePieceTokenizer
return
tokenizer_args_dict.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
try:
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
except AttributeError:
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
origin_pad_token_id = tokenizer.pad_token_id
def fset(self, value):
self._value = value
type(tokenizer).pad_token_id = property(
fget=type(tokenizer).pad_token_id.fget, fset=fset
)
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token = (
self.outlines_tokenizer.tokenizer.pad_token
)
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern
def init_value(self, key):
key_type, key_string = key
if key_type == "json":
try:
regex = build_regex_from_object(
key_string,
whitespace_pattern=self.constrained_json_whitespace_pattern,
)
except NotImplementedError as e:
logger.warning(
f"skip invalid json schema: json_schema={key_string}, {e=}"
)
return None, key_string
elif key_type == "regex":
regex = key_string
else:
raise ValueError(f"Invalid key_type: {key_type}")
try:
parse_pattern(regex)
except InvalidSyntax as e:
logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
return None, regex
return RegexGuide(regex, self.outlines_tokenizer), regex

View File

@@ -164,7 +164,7 @@ class OutlinesJumpForwardMap:
)
class OutlinesJumpCache(BaseToolCache):
class OutlinesJumpForwardCache(BaseToolCache):
def __init__(self):
super().__init__()

View File

@@ -0,0 +1,127 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Constrained decoding with xgrammar backend."""
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple
import torch
try:
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
import_error = None
except ImportError as e:
import_error = e
class Dummy:
pass
GrammarMatcher = CompiledGrammar = CachedGrammarCompiler = Dummy
MAX_ROLLBACK_TOKENS = 10
class XGrammarGrammar:
def __init__(self, matcher: GrammarMatcher, vocab_size: int) -> None:
self.matcher = matcher
self.vocab_size = vocab_size
def accept_token(self, token: int):
assert self.matcher.accept_token(token)
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
return [], self.matcher.find_jump_forward_string()
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
_, data = helper
return data, -1
def jump_and_retokenize(
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
):
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == new_output_ids[i]:
k = i + 1
else:
break
# rollback to the last token that is the same
if k < len(old_output_ids):
self.matcher.rollback(len(old_output_ids) - k)
for i in range(k, len(new_output_ids)):
assert self.matcher.accept_token(new_output_ids[i])
def fill_vocab_mask(self, vocab_mask: torch.Tensor):
# Note that this bitmask is a bitset, not bool
bitmask = self.matcher.get_next_token_bitmask()
# Mask the tokens that are not allowed
vocab_mask[
self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
] = 1
class XGrammarGrammarBackend:
def __init__(
self,
tokenizer,
vocab_size: int,
):
if import_error:
raise import_error
self.executor = ThreadPoolExecutor()
self.grammar_cache = XGrammarCache(tokenizer, vocab_size)
self.vocab_size = vocab_size
def _query(self, key: Tuple[str, str]) -> XGrammarGrammar:
return XGrammarGrammar(self.grammar_cache.query(key), self.vocab_size)
def query(self, key: Tuple[str, str]) -> Future:
return self.executor.submit(self._query, key)
def reset(self):
self.grammar_cache.reset()
class XGrammarCache:
def __init__(self, tokenizer, vocab_size: int):
self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
self.vocab_size = vocab_size
def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
key_type, key_string = key
if key_type == "json":
return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
elif key_type == "regex":
raise ValueError("regex hasn't been supported by xgrammar yet")
else:
raise ValueError(f"Invalid key_type: {key_type}")
def query(self, key: Tuple[str, str]) -> GrammarMatcher:
ctx = self.get_context(key)
return GrammarMatcher(
ctx,
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
mask_vocab_size=self.vocab_size,
)
def reset(self):
self.grammar_cache.clear()

View File

@@ -1,75 +0,0 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Cache for the compressed finite state machine."""
from typing import Tuple
from transformers import AutoTokenizer
try:
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
except ImportError as e:
class Dummy:
pass
GrammarMatcher = Dummy
CompiledGrammar = Dummy
CachedGrammarCompiler = Dummy
MAX_ROLLBACK_TOKENS = 10
class XGrammarJumpCache:
"""A dummy class."""
def reset(self):
pass
class XGrammarBackend:
def __init__(
self,
tokenizer_path,
tokenizer_args_dict,
skip_tokenizer_init=False,
whitespace_patterns=None,
):
# TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
if skip_tokenizer_init:
return
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
self.grammar_cache: CachedGrammarCompiler = CachedGrammarCompiler(
tokenizer_or_vocab=tokenizer
)
def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
key_type, key_string = key
if key_type == "json":
return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
elif key_type == "regex":
raise ValueError("regex hasn't been supported by xgrammar yet")
else:
raise ValueError(f"Invalid key_type: {key_type}")
def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
ctx = self.get_context(key)
return GrammarMatcher(
ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
)
def reset(self):
self.grammar_cache.clear()