Fix grammar backend (#2018)
This commit is contained in:
@@ -13,30 +13,5 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
# TODO(lmzheng): make this an optional dependency
|
||||||
from typing import Dict, Optional, Union
|
from sglang.srt.constrained.outlines_backend import build_regex_from_object
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
try:
|
|
||||||
from outlines.fsm.json_schema import build_regex_from_object
|
|
||||||
except ImportError:
|
|
||||||
# Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
|
|
||||||
# which only accepts string schema as input.
|
|
||||||
from outlines.fsm.json_schema import build_regex_from_schema
|
|
||||||
|
|
||||||
def build_regex_from_object(
|
|
||||||
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
|
||||||
):
|
|
||||||
if isinstance(object, type(BaseModel)):
|
|
||||||
schema = json.dumps(object.model_json_schema())
|
|
||||||
elif isinstance(object, Dict):
|
|
||||||
schema = json.dumps(object)
|
|
||||||
else:
|
|
||||||
schema = object
|
|
||||||
return build_regex_from_schema(schema, whitespace_pattern)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"build_regex_from_object",
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -95,9 +95,7 @@ class BaseToolCache:
|
|||||||
|
|
||||||
def get_cache_hit_rate(self):
|
def get_cache_hit_rate(self):
|
||||||
with self.lock_metrics:
|
with self.lock_metrics:
|
||||||
if self.metrics["total"] == 0:
|
return self.metrics["hit"] / max(self.metrics["total"], 1)
|
||||||
return 0
|
|
||||||
return self.metrics["hit"] / self.metrics["total"]
|
|
||||||
|
|
||||||
def get_avg_init_time(self):
|
def get_avg_init_time(self):
|
||||||
with self.lock_metrics:
|
with self.lock_metrics:
|
||||||
|
|||||||
@@ -1,182 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""Cache for the compressed finite state machine."""
|
|
||||||
import logging
|
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
|
||||||
from typing import List, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide
|
|
||||||
from sglang.srt.constrained.outlines_jump_forward import (
|
|
||||||
OutlinesJumpCache,
|
|
||||||
OutlinesJumpForwardMap,
|
|
||||||
)
|
|
||||||
from sglang.srt.constrained.xgrammar_cache import (
|
|
||||||
GrammarMatcher,
|
|
||||||
XGrammarBackend,
|
|
||||||
XGrammarJumpCache,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class JumpHelper:
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
|
|
||||||
) -> None:
|
|
||||||
self.data: Union[List, str] = data
|
|
||||||
self.state: int = state
|
|
||||||
self.suffix_ids: List[int] = suffix_ids
|
|
||||||
|
|
||||||
def can_jump(self):
|
|
||||||
return len(self.data) > 0
|
|
||||||
|
|
||||||
|
|
||||||
class Grammar:
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
|
|
||||||
jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None],
|
|
||||||
) -> None:
|
|
||||||
self.grammar = grammar
|
|
||||||
self.jump_map = jump_map
|
|
||||||
|
|
||||||
def accept_token(self, token: int):
|
|
||||||
if isinstance(self.grammar, GrammarMatcher):
|
|
||||||
assert self.grammar.accept_token(token)
|
|
||||||
else:
|
|
||||||
guide, state = self.grammar
|
|
||||||
self.grammar = guide, guide.get_next_state(state, token)
|
|
||||||
|
|
||||||
def try_jump(self, tokenizer) -> JumpHelper:
|
|
||||||
if isinstance(self.jump_map, XGrammarJumpCache):
|
|
||||||
assert isinstance(self.grammar, GrammarMatcher)
|
|
||||||
return JumpHelper(self.grammar.find_jump_forward_string())
|
|
||||||
elif isinstance(self.jump_map, OutlinesJumpForwardMap):
|
|
||||||
assert isinstance(self.grammar, Tuple)
|
|
||||||
|
|
||||||
_, state = self.grammar
|
|
||||||
jump_forward_bytes = self.jump_map.jump_forward_byte(state)
|
|
||||||
if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
|
|
||||||
return JumpHelper() # can't jump
|
|
||||||
|
|
||||||
# preprocess the jump forward string
|
|
||||||
suffix_bytes = []
|
|
||||||
continuation_range = range(0x80, 0xC0)
|
|
||||||
cur_state = state
|
|
||||||
while (
|
|
||||||
len(jump_forward_bytes)
|
|
||||||
and jump_forward_bytes[0][0] in continuation_range
|
|
||||||
):
|
|
||||||
# continuation bytes
|
|
||||||
byte_edge = jump_forward_bytes.pop(0)
|
|
||||||
suffix_bytes.append(byte_edge[0])
|
|
||||||
cur_state = byte_edge[1]
|
|
||||||
|
|
||||||
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
|
|
||||||
suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
|
|
||||||
return JumpHelper(suffix_ids, cur_state, suffix_bytes)
|
|
||||||
else:
|
|
||||||
return JumpHelper() # can't jump
|
|
||||||
|
|
||||||
def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
|
|
||||||
if isinstance(helper.data, str):
|
|
||||||
return helper.data, -1
|
|
||||||
else:
|
|
||||||
assert isinstance(self.jump_map, OutlinesJumpForwardMap)
|
|
||||||
return self.jump_map.jump_forward_symbol(helper.state)
|
|
||||||
|
|
||||||
def jump_and_retokenize(
|
|
||||||
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
|
||||||
):
|
|
||||||
if isinstance(self.grammar, GrammarMatcher):
|
|
||||||
k = 0
|
|
||||||
for i, old_id in enumerate(old_output_ids):
|
|
||||||
if old_id == new_output_ids[i]:
|
|
||||||
k = i + 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
# rollback to the last token that is the same
|
|
||||||
if k < len(old_output_ids):
|
|
||||||
self.grammar.rollback(len(old_output_ids) - k)
|
|
||||||
|
|
||||||
for i in range(k, len(new_output_ids)):
|
|
||||||
assert self.grammar.accept_token(new_output_ids[i])
|
|
||||||
else:
|
|
||||||
self.grammar = self.grammar[0], next_state
|
|
||||||
|
|
||||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
|
|
||||||
if isinstance(self.grammar, GrammarMatcher):
|
|
||||||
# Note that this bitmask is a bitset, not bool
|
|
||||||
bitmask = self.grammar.get_next_token_bitmask()
|
|
||||||
# Mask the tokens that are not allowed
|
|
||||||
vocab_mask[
|
|
||||||
self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
|
|
||||||
] = 1
|
|
||||||
else:
|
|
||||||
guide, state = self.grammar
|
|
||||||
vocab_mask.fill_(1)
|
|
||||||
vocab_mask[guide.get_next_instruction(state).tokens] = 0
|
|
||||||
|
|
||||||
|
|
||||||
class GrammarBackend:
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tokenizer_path,
|
|
||||||
tokenizer_args_dict,
|
|
||||||
skip_tokenizer_init=False,
|
|
||||||
whitespace_patterns=None,
|
|
||||||
backend=None,
|
|
||||||
allow_jump=False,
|
|
||||||
):
|
|
||||||
self.executor = ThreadPoolExecutor()
|
|
||||||
self.backend = backend
|
|
||||||
|
|
||||||
if backend == "xgrammar":
|
|
||||||
self.grammar_cache = XGrammarBackend(
|
|
||||||
tokenizer_path=tokenizer_path,
|
|
||||||
tokenizer_args_dict=tokenizer_args_dict,
|
|
||||||
skip_tokenizer_init=skip_tokenizer_init,
|
|
||||||
whitespace_patterns=whitespace_patterns,
|
|
||||||
)
|
|
||||||
self.jump_cache = XGrammarJumpCache() if allow_jump else None
|
|
||||||
else:
|
|
||||||
assert backend == "outlines"
|
|
||||||
self.grammar_cache = OutlinesCache(
|
|
||||||
tokenizer_path=tokenizer_path,
|
|
||||||
tokenizer_args_dict=tokenizer_args_dict,
|
|
||||||
skip_tokenizer_init=skip_tokenizer_init,
|
|
||||||
constrained_json_whitespace_pattern=whitespace_patterns,
|
|
||||||
)
|
|
||||||
self.jump_cache = OutlinesJumpCache() if allow_jump else None
|
|
||||||
|
|
||||||
def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
|
|
||||||
if isinstance(self.grammar_cache, XGrammarBackend):
|
|
||||||
return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
|
|
||||||
else:
|
|
||||||
guide, regex = self.grammar_cache.query(key)
|
|
||||||
jump_map = self.jump_cache.query(regex)
|
|
||||||
return Grammar((guide, 0), jump_map)
|
|
||||||
|
|
||||||
def query(self, key: Tuple[str, str], vocab_size: int) -> Future:
|
|
||||||
return self.executor.submit(self._query, key, vocab_size)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.grammar_cache.reset()
|
|
||||||
self.jump_cache.reset()
|
|
||||||
203
python/sglang/srt/constrained/outlines_backend.py
Normal file
203
python/sglang/srt/constrained/outlines_backend.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""Constrained decoding with outlines backend."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from interegular import InvalidSyntax, parse_pattern
|
||||||
|
from outlines.fsm.guide import RegexGuide
|
||||||
|
from outlines.models.transformers import TransformerTokenizer
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||||
|
from sglang.srt.constrained.outlines_jump_forward import (
|
||||||
|
OutlinesJumpForwardCache,
|
||||||
|
OutlinesJumpForwardMap,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from outlines.fsm.json_schema import build_regex_from_object
|
||||||
|
except ImportError:
|
||||||
|
# Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
|
||||||
|
# which only accepts string schema as input.
|
||||||
|
from outlines.fsm.json_schema import build_regex_from_schema
|
||||||
|
|
||||||
|
def build_regex_from_object(
|
||||||
|
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
||||||
|
):
|
||||||
|
if isinstance(object, type(BaseModel)):
|
||||||
|
schema = json.dumps(object.model_json_schema())
|
||||||
|
elif isinstance(object, Dict):
|
||||||
|
schema = json.dumps(object)
|
||||||
|
else:
|
||||||
|
schema = object
|
||||||
|
return build_regex_from_schema(schema, whitespace_pattern)
|
||||||
|
|
||||||
|
|
||||||
|
class OutlinesGrammar:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
guide: RegexGuide,
|
||||||
|
state: int,
|
||||||
|
jump_forward_map: Union[OutlinesJumpForwardMap, None],
|
||||||
|
) -> None:
|
||||||
|
self.guide = guide
|
||||||
|
self.state = state
|
||||||
|
self.jump_forward_map = jump_forward_map
|
||||||
|
|
||||||
|
def accept_token(self, token: int):
|
||||||
|
self.state = self.guide.get_next_state(self.state, token)
|
||||||
|
|
||||||
|
def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
|
||||||
|
if not self.jump_forward_map:
|
||||||
|
return None
|
||||||
|
|
||||||
|
jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
|
||||||
|
if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# preprocess the jump forward string
|
||||||
|
suffix_bytes = []
|
||||||
|
continuation_range = range(0x80, 0xC0)
|
||||||
|
cur_state = self.state
|
||||||
|
while (
|
||||||
|
len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
|
||||||
|
):
|
||||||
|
# continuation bytes
|
||||||
|
byte_edge = jump_forward_bytes.pop(0)
|
||||||
|
suffix_bytes.append(byte_edge[0])
|
||||||
|
cur_state = byte_edge[1]
|
||||||
|
|
||||||
|
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
|
||||||
|
suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
|
||||||
|
return suffix_ids, cur_state
|
||||||
|
|
||||||
|
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
||||||
|
_, cur_state = helper
|
||||||
|
return self.jump_forward_map.jump_forward_symbol(cur_state)
|
||||||
|
|
||||||
|
def jump_and_retokenize(
|
||||||
|
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
||||||
|
):
|
||||||
|
self.state = next_state
|
||||||
|
|
||||||
|
def fill_vocab_mask(self, vocab_mask: torch.Tensor):
|
||||||
|
vocab_mask.fill_(1)
|
||||||
|
vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class OutlinesGrammarBackend:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
whitespace_patterns: bool,
|
||||||
|
allow_jump_forward: bool,
|
||||||
|
):
|
||||||
|
self.executor = ThreadPoolExecutor()
|
||||||
|
self.grammar_cache = OutlinesCache(
|
||||||
|
tokenizer,
|
||||||
|
whitespace_pattern=whitespace_patterns,
|
||||||
|
)
|
||||||
|
self.jump_forward_cache = (
|
||||||
|
OutlinesJumpForwardCache() if allow_jump_forward else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _query(self, key: Tuple[str, str]) -> OutlinesGrammar:
|
||||||
|
guide, regex = self.grammar_cache.query(key)
|
||||||
|
jump_forward_map = (
|
||||||
|
self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
|
||||||
|
)
|
||||||
|
return OutlinesGrammar(guide, 0, jump_forward_map)
|
||||||
|
|
||||||
|
def query(self, key: Tuple[str, str]) -> Future:
|
||||||
|
return self.executor.submit(self._query, key)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.grammar_cache.reset()
|
||||||
|
if self.jump_forward_cache:
|
||||||
|
self.jump_forward_cache.reset()
|
||||||
|
|
||||||
|
|
||||||
|
class OutlinesCache(BaseToolCache):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
whitespace_pattern=None,
|
||||||
|
):
|
||||||
|
super().__init__(enable=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
|
||||||
|
except AttributeError:
|
||||||
|
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
|
||||||
|
origin_pad_token_id = tokenizer.pad_token_id
|
||||||
|
|
||||||
|
def fset(self, value):
|
||||||
|
self._value = value
|
||||||
|
|
||||||
|
type(tokenizer).pad_token_id = property(
|
||||||
|
fget=type(tokenizer).pad_token_id.fget, fset=fset
|
||||||
|
)
|
||||||
|
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
|
||||||
|
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
|
||||||
|
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
|
||||||
|
self.outlines_tokenizer.pad_token = (
|
||||||
|
self.outlines_tokenizer.tokenizer.pad_token
|
||||||
|
)
|
||||||
|
self.outlines_tokenizer.vocabulary = (
|
||||||
|
self.outlines_tokenizer.tokenizer.get_vocab()
|
||||||
|
)
|
||||||
|
self.whitespace_pattern = whitespace_pattern
|
||||||
|
|
||||||
|
def init_value(self, key):
|
||||||
|
key_type, key_string = key
|
||||||
|
if key_type == "json":
|
||||||
|
try:
|
||||||
|
regex = build_regex_from_object(
|
||||||
|
key_string,
|
||||||
|
whitespace_pattern=self.whitespace_pattern,
|
||||||
|
)
|
||||||
|
except NotImplementedError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"skip invalid json schema: json_schema={key_string}, {e=}"
|
||||||
|
)
|
||||||
|
return None, key_string
|
||||||
|
elif key_type == "regex":
|
||||||
|
regex = key_string
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid key_type: {key_type}")
|
||||||
|
try:
|
||||||
|
parse_pattern(regex)
|
||||||
|
except InvalidSyntax as e:
|
||||||
|
logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
|
||||||
|
return None, regex
|
||||||
|
|
||||||
|
ret = RegexGuide(regex, self.outlines_tokenizer), regex
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _query(self, key: Tuple[str, str]):
|
||||||
|
guide, regex = self.grammar_cache.query(key)
|
||||||
|
jump_forward_map = (
|
||||||
|
self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
|
||||||
|
)
|
||||||
|
return OutlinesGrammar(guide, 0, jump_forward_map)
|
||||||
@@ -1,96 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""Cache for the compressed finite state machine."""
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from interegular import InvalidSyntax, parse_pattern
|
|
||||||
from outlines.fsm.guide import RegexGuide
|
|
||||||
from outlines.models.transformers import TransformerTokenizer
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
from sglang.srt.constrained import build_regex_from_object
|
|
||||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OutlinesCache(BaseToolCache):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tokenizer_path,
|
|
||||||
tokenizer_args_dict,
|
|
||||||
enable=True,
|
|
||||||
skip_tokenizer_init=False,
|
|
||||||
constrained_json_whitespace_pattern=None,
|
|
||||||
):
|
|
||||||
super().__init__(enable=enable)
|
|
||||||
|
|
||||||
if (
|
|
||||||
skip_tokenizer_init
|
|
||||||
or tokenizer_path.endswith(".json")
|
|
||||||
or tokenizer_path.endswith(".model")
|
|
||||||
):
|
|
||||||
# Do not support TiktokenTokenizer or SentencePieceTokenizer
|
|
||||||
return
|
|
||||||
|
|
||||||
tokenizer_args_dict.setdefault("padding_side", "left")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
|
|
||||||
try:
|
|
||||||
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
|
|
||||||
except AttributeError:
|
|
||||||
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
|
|
||||||
origin_pad_token_id = tokenizer.pad_token_id
|
|
||||||
|
|
||||||
def fset(self, value):
|
|
||||||
self._value = value
|
|
||||||
|
|
||||||
type(tokenizer).pad_token_id = property(
|
|
||||||
fget=type(tokenizer).pad_token_id.fget, fset=fset
|
|
||||||
)
|
|
||||||
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
|
|
||||||
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
|
|
||||||
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
|
|
||||||
self.outlines_tokenizer.pad_token = (
|
|
||||||
self.outlines_tokenizer.tokenizer.pad_token
|
|
||||||
)
|
|
||||||
self.outlines_tokenizer.vocabulary = (
|
|
||||||
self.outlines_tokenizer.tokenizer.get_vocab()
|
|
||||||
)
|
|
||||||
self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern
|
|
||||||
|
|
||||||
def init_value(self, key):
|
|
||||||
key_type, key_string = key
|
|
||||||
if key_type == "json":
|
|
||||||
try:
|
|
||||||
regex = build_regex_from_object(
|
|
||||||
key_string,
|
|
||||||
whitespace_pattern=self.constrained_json_whitespace_pattern,
|
|
||||||
)
|
|
||||||
except NotImplementedError as e:
|
|
||||||
logger.warning(
|
|
||||||
f"skip invalid json schema: json_schema={key_string}, {e=}"
|
|
||||||
)
|
|
||||||
return None, key_string
|
|
||||||
elif key_type == "regex":
|
|
||||||
regex = key_string
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid key_type: {key_type}")
|
|
||||||
try:
|
|
||||||
parse_pattern(regex)
|
|
||||||
except InvalidSyntax as e:
|
|
||||||
logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
|
|
||||||
return None, regex
|
|
||||||
return RegexGuide(regex, self.outlines_tokenizer), regex
|
|
||||||
@@ -164,7 +164,7 @@ class OutlinesJumpForwardMap:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OutlinesJumpCache(BaseToolCache):
|
class OutlinesJumpForwardCache(BaseToolCache):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
127
python/sglang/srt/constrained/xgrammar_backend.py
Normal file
127
python/sglang/srt/constrained/xgrammar_backend.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""Constrained decoding with xgrammar backend."""
|
||||||
|
|
||||||
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
|
||||||
|
|
||||||
|
import_error = None
|
||||||
|
except ImportError as e:
|
||||||
|
import_error = e
|
||||||
|
|
||||||
|
class Dummy:
|
||||||
|
pass
|
||||||
|
|
||||||
|
GrammarMatcher = CompiledGrammar = CachedGrammarCompiler = Dummy
|
||||||
|
|
||||||
|
|
||||||
|
MAX_ROLLBACK_TOKENS = 10
|
||||||
|
|
||||||
|
|
||||||
|
class XGrammarGrammar:
|
||||||
|
|
||||||
|
def __init__(self, matcher: GrammarMatcher, vocab_size: int) -> None:
|
||||||
|
self.matcher = matcher
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
|
def accept_token(self, token: int):
|
||||||
|
assert self.matcher.accept_token(token)
|
||||||
|
|
||||||
|
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
|
||||||
|
return [], self.matcher.find_jump_forward_string()
|
||||||
|
|
||||||
|
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
||||||
|
_, data = helper
|
||||||
|
return data, -1
|
||||||
|
|
||||||
|
def jump_and_retokenize(
|
||||||
|
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
||||||
|
):
|
||||||
|
k = 0
|
||||||
|
for i, old_id in enumerate(old_output_ids):
|
||||||
|
if old_id == new_output_ids[i]:
|
||||||
|
k = i + 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# rollback to the last token that is the same
|
||||||
|
if k < len(old_output_ids):
|
||||||
|
self.matcher.rollback(len(old_output_ids) - k)
|
||||||
|
|
||||||
|
for i in range(k, len(new_output_ids)):
|
||||||
|
assert self.matcher.accept_token(new_output_ids[i])
|
||||||
|
|
||||||
|
def fill_vocab_mask(self, vocab_mask: torch.Tensor):
|
||||||
|
# Note that this bitmask is a bitset, not bool
|
||||||
|
bitmask = self.matcher.get_next_token_bitmask()
|
||||||
|
# Mask the tokens that are not allowed
|
||||||
|
vocab_mask[
|
||||||
|
self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
|
||||||
|
] = 1
|
||||||
|
|
||||||
|
|
||||||
|
class XGrammarGrammarBackend:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer,
|
||||||
|
vocab_size: int,
|
||||||
|
):
|
||||||
|
if import_error:
|
||||||
|
raise import_error
|
||||||
|
|
||||||
|
self.executor = ThreadPoolExecutor()
|
||||||
|
self.grammar_cache = XGrammarCache(tokenizer, vocab_size)
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
|
def _query(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||||
|
return XGrammarGrammar(self.grammar_cache.query(key), self.vocab_size)
|
||||||
|
|
||||||
|
def query(self, key: Tuple[str, str]) -> Future:
|
||||||
|
return self.executor.submit(self._query, key)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.grammar_cache.reset()
|
||||||
|
|
||||||
|
|
||||||
|
class XGrammarCache:
|
||||||
|
def __init__(self, tokenizer, vocab_size: int):
|
||||||
|
self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
|
def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
|
||||||
|
key_type, key_string = key
|
||||||
|
if key_type == "json":
|
||||||
|
return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
|
||||||
|
elif key_type == "regex":
|
||||||
|
raise ValueError("regex hasn't been supported by xgrammar yet")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid key_type: {key_type}")
|
||||||
|
|
||||||
|
def query(self, key: Tuple[str, str]) -> GrammarMatcher:
|
||||||
|
ctx = self.get_context(key)
|
||||||
|
return GrammarMatcher(
|
||||||
|
ctx,
|
||||||
|
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
||||||
|
mask_vocab_size=self.vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.grammar_cache.clear()
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""Cache for the compressed finite state machine."""
|
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
try:
|
|
||||||
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
|
|
||||||
except ImportError as e:
|
|
||||||
|
|
||||||
class Dummy:
|
|
||||||
pass
|
|
||||||
|
|
||||||
GrammarMatcher = Dummy
|
|
||||||
CompiledGrammar = Dummy
|
|
||||||
CachedGrammarCompiler = Dummy
|
|
||||||
|
|
||||||
|
|
||||||
MAX_ROLLBACK_TOKENS = 10
|
|
||||||
|
|
||||||
|
|
||||||
class XGrammarJumpCache:
|
|
||||||
"""A dummy class."""
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class XGrammarBackend:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
tokenizer_path,
|
|
||||||
tokenizer_args_dict,
|
|
||||||
skip_tokenizer_init=False,
|
|
||||||
whitespace_patterns=None,
|
|
||||||
):
|
|
||||||
# TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
|
|
||||||
if skip_tokenizer_init:
|
|
||||||
return
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
|
|
||||||
self.grammar_cache: CachedGrammarCompiler = CachedGrammarCompiler(
|
|
||||||
tokenizer_or_vocab=tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
|
|
||||||
key_type, key_string = key
|
|
||||||
if key_type == "json":
|
|
||||||
return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
|
|
||||||
elif key_type == "regex":
|
|
||||||
raise ValueError("regex hasn't been supported by xgrammar yet")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid key_type: {key_type}")
|
|
||||||
|
|
||||||
def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
|
|
||||||
ctx = self.get_context(key)
|
|
||||||
return GrammarMatcher(
|
|
||||||
ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.grammar_cache.clear()
|
|
||||||
@@ -37,7 +37,6 @@ import torch
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained.grammar import Grammar
|
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
@@ -249,7 +248,7 @@ class Req:
|
|||||||
self.embedding = None
|
self.embedding = None
|
||||||
|
|
||||||
# Constrained decoding
|
# Constrained decoding
|
||||||
self.grammar: Optional[Grammar] = None
|
self.grammar = None
|
||||||
|
|
||||||
# The number of cached tokens, that were already cached in the KV cache
|
# The number of cached tokens, that were already cached in the KV cache
|
||||||
self.cached_tokens = 0
|
self.cached_tokens = 0
|
||||||
@@ -359,8 +358,6 @@ class Req:
|
|||||||
return
|
return
|
||||||
|
|
||||||
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
||||||
assert self.grammar is not None and self.tokenizer is not None
|
|
||||||
|
|
||||||
if self.origin_input_text is None:
|
if self.origin_input_text is None:
|
||||||
# Recovering text can only use unpadded ids
|
# Recovering text can only use unpadded ids
|
||||||
self.origin_input_text = self.tokenizer.decode(
|
self.origin_input_text = self.tokenizer.decode(
|
||||||
@@ -809,9 +806,10 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
for i, req in enumerate(self.reqs):
|
for i, req in enumerate(self.reqs):
|
||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
jump_helper = req.grammar.try_jump(req.tokenizer)
|
jump_helper = req.grammar.try_jump_forward(req.tokenizer)
|
||||||
if jump_helper.can_jump():
|
if jump_helper:
|
||||||
suffix_ids = jump_helper.suffix_ids
|
suffix_ids, _ = jump_helper
|
||||||
|
|
||||||
# Current ids, for cache and revert
|
# Current ids, for cache and revert
|
||||||
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
|
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
|
||||||
cur_output_ids = req.output_ids
|
cur_output_ids = req.output_ids
|
||||||
@@ -827,6 +825,8 @@ class ScheduleBatch:
|
|||||||
next_state,
|
next_state,
|
||||||
) = req.grammar.jump_forward_str_state(jump_helper)
|
) = req.grammar.jump_forward_str_state(jump_helper)
|
||||||
|
|
||||||
|
# Make the incrementally decoded text part of jump_forward_str
|
||||||
|
# so that the UTF-8 will not corrupt
|
||||||
jump_forward_str = new_text + jump_forward_str
|
jump_forward_str = new_text + jump_forward_str
|
||||||
if not req.jump_forward_and_retokenize(
|
if not req.jump_forward_and_retokenize(
|
||||||
jump_forward_str, next_state
|
jump_forward_str, next_state
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from concurrent import futures
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
@@ -29,7 +30,6 @@ import zmq
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained.grammar import GrammarBackend
|
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
@@ -100,7 +100,7 @@ class Scheduler:
|
|||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = server_args.tp_size
|
self.tp_size = server_args.tp_size
|
||||||
self.schedule_policy = server_args.schedule_policy
|
self.schedule_policy = server_args.schedule_policy
|
||||||
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
|
self.disable_jump_forward = server_args.disable_jump_forward
|
||||||
self.lora_paths = server_args.lora_paths
|
self.lora_paths = server_args.lora_paths
|
||||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||||
self.enable_overlap = server_args.enable_overlap_schedule
|
self.enable_overlap = server_args.enable_overlap_schedule
|
||||||
@@ -234,22 +234,33 @@ class Scheduler:
|
|||||||
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init the grammar cache for constrained generation
|
# Init the grammar backend for constrained generation
|
||||||
self.grammar_cache = None
|
|
||||||
self.grammar_queue: List[Req] = []
|
self.grammar_queue: List[Req] = []
|
||||||
|
|
||||||
if not server_args.skip_tokenizer_init:
|
if not server_args.skip_tokenizer_init:
|
||||||
self.grammar_cache = GrammarBackend(
|
if server_args.grammar_backend == "outlines":
|
||||||
server_args.tokenizer_path,
|
from sglang.srt.constrained.outlines_backend import (
|
||||||
{
|
OutlinesGrammarBackend,
|
||||||
"tokenizer_mode": server_args.tokenizer_mode,
|
)
|
||||||
"trust_remote_code": server_args.trust_remote_code,
|
|
||||||
},
|
self.grammar_backend = OutlinesGrammarBackend(
|
||||||
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
self.tokenizer,
|
||||||
whitespace_patterns=server_args.constrained_json_whitespace_pattern,
|
whitespace_patterns=server_args.constrained_json_whitespace_pattern,
|
||||||
backend=server_args.grammar_backend,
|
allow_jump_forward=not server_args.disable_jump_forward,
|
||||||
allow_jump=not server_args.disable_regex_jump_forward,
|
)
|
||||||
)
|
elif server_args.grammar_backend == "xgrammar":
|
||||||
|
from sglang.srt.constrained.xgrammar_backend import (
|
||||||
|
XGrammarGrammarBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.grammar_backend = XGrammarGrammarBackend(
|
||||||
|
self.tokenizer, vocab_size=self.model_config.vocab_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid grammar backend: {server_args.grammar_backend}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.grammar_backend = None
|
||||||
|
|
||||||
# Init new token estimation
|
# Init new token estimation
|
||||||
assert (
|
assert (
|
||||||
@@ -461,15 +472,14 @@ class Scheduler:
|
|||||||
req.sampling_params.json_schema is not None
|
req.sampling_params.json_schema is not None
|
||||||
or req.sampling_params.regex is not None
|
or req.sampling_params.regex is not None
|
||||||
):
|
):
|
||||||
assert self.grammar_cache is not None
|
assert self.grammar_backend is not None
|
||||||
if req.sampling_params.json_schema is not None:
|
if req.sampling_params.json_schema is not None:
|
||||||
req.grammar = self.grammar_cache.query(
|
req.grammar = self.grammar_backend.query(
|
||||||
("json", req.sampling_params.json_schema),
|
("json", req.sampling_params.json_schema),
|
||||||
self.model_config.vocab_size,
|
|
||||||
)
|
)
|
||||||
elif req.sampling_params.regex is not None:
|
elif req.sampling_params.regex is not None:
|
||||||
req.grammar = self.grammar_cache.query(
|
req.grammar = self.grammar_backend.query(
|
||||||
("regex", req.sampling_params.regex), self.model_config.vocab_size
|
("regex", req.sampling_params.regex)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Truncate prompts that are too long
|
# Truncate prompts that are too long
|
||||||
@@ -638,14 +648,14 @@ class Scheduler:
|
|||||||
return self.running_batch
|
return self.running_batch
|
||||||
|
|
||||||
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||||
# Check if the grammar queue is ready
|
# Check if the grammar is ready in the grammar queue
|
||||||
if self.grammar_queue:
|
if self.grammar_queue:
|
||||||
new_grammar_queue = []
|
new_grammar_queue = []
|
||||||
for req in self.grammar_queue:
|
for req in self.grammar_queue:
|
||||||
if req.grammar.done():
|
try:
|
||||||
req.grammar = req.grammar.result()
|
req.grammar = req.grammar.result(timeout=0.05)
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
else:
|
except futures._base.TimeoutError:
|
||||||
new_grammar_queue.append(req)
|
new_grammar_queue.append(req)
|
||||||
self.grammar_queue = new_grammar_queue
|
self.grammar_queue = new_grammar_queue
|
||||||
|
|
||||||
@@ -783,7 +793,7 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check for jump-forward
|
# Check for jump-forward
|
||||||
if not self.disable_regex_jump_forward:
|
if not self.disable_jump_forward:
|
||||||
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
|
||||||
self.waiting_queue.extend(jump_forward_reqs)
|
self.waiting_queue.extend(jump_forward_reqs)
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
@@ -1142,8 +1152,8 @@ class Scheduler:
|
|||||||
):
|
):
|
||||||
self.tree_cache.reset()
|
self.tree_cache.reset()
|
||||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||||
if self.grammar_cache is not None:
|
if self.grammar_backend is not None:
|
||||||
self.grammar_cache.reset()
|
self.grammar_backend.reset()
|
||||||
# TODO(dark): reset the bnf cache
|
# TODO(dark): reset the bnf cache
|
||||||
self.req_to_token_pool.clear()
|
self.req_to_token_pool.clear()
|
||||||
self.token_to_kv_pool.clear()
|
self.token_to_kv_pool.clear()
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
from sglang.srt.constrained.grammar import Grammar
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
@@ -31,7 +30,7 @@ class SamplingBatchInfo:
|
|||||||
logit_bias: torch.Tensor = None
|
logit_bias: torch.Tensor = None
|
||||||
vocab_mask: Optional[torch.Tensor] = None
|
vocab_mask: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
grammars: Optional[List[Optional[Grammar]]] = None
|
grammars: Optional[List] = None
|
||||||
|
|
||||||
# Penalizer
|
# Penalizer
|
||||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||||
@@ -146,7 +145,7 @@ class SamplingBatchInfo:
|
|||||||
)
|
)
|
||||||
for i, grammar in enumerate(self.grammars):
|
for i, grammar in enumerate(self.grammars):
|
||||||
if grammar is not None:
|
if grammar is not None:
|
||||||
grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
|
grammar.fill_vocab_mask(self.vocab_mask[i])
|
||||||
|
|
||||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||||
if self.penalizer_orchestrator:
|
if self.penalizer_orchestrator:
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ class ServerArgs:
|
|||||||
disable_flashinfer: bool = False
|
disable_flashinfer: bool = False
|
||||||
disable_flashinfer_sampling: bool = False
|
disable_flashinfer_sampling: bool = False
|
||||||
disable_radix_cache: bool = False
|
disable_radix_cache: bool = False
|
||||||
disable_regex_jump_forward: bool = False
|
disable_jump_forward: bool = False
|
||||||
disable_cuda_graph: bool = False
|
disable_cuda_graph: bool = False
|
||||||
disable_cuda_graph_padding: bool = False
|
disable_cuda_graph_padding: bool = False
|
||||||
disable_disk_cache: bool = False
|
disable_disk_cache: bool = False
|
||||||
@@ -574,7 +574,7 @@ class ServerArgs:
|
|||||||
type=str,
|
type=str,
|
||||||
choices=["xgrammar", "outlines"],
|
choices=["xgrammar", "outlines"],
|
||||||
default=ServerArgs.grammar_backend,
|
default=ServerArgs.grammar_backend,
|
||||||
help="Choose the backend for constrained decoding.",
|
help="Choose the backend for grammar-guided decoding.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
@@ -594,9 +594,9 @@ class ServerArgs:
|
|||||||
help="Disable RadixAttention for prefix caching.",
|
help="Disable RadixAttention for prefix caching.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-regex-jump-forward",
|
"--disable-jump-forward",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable regex jump-forward.",
|
help="Disable jump-forward for grammar-guided decoding.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-cuda-graph",
|
"--disable-cuda-graph",
|
||||||
@@ -616,7 +616,6 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-custom-all-reduce",
|
"--disable-custom-all-reduce",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
|
||||||
help="Disable the custom all-reduce kernel and fall back to NCCL.",
|
help="Disable the custom all-reduce kernel and fall back to NCCL.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -688,7 +687,6 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--delete-ckpt-after-loading",
|
"--delete-ckpt-after-loading",
|
||||||
default=ServerArgs.delete_ckpt_after_loading,
|
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Delete the model checkpoint after loading the model.",
|
help="Delete the model checkpoint after loading the model.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -61,18 +61,27 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
"logprob_start_len": 0,
|
"logprob_start_len": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print(json.dumps(response.json()))
|
ret = response.json()
|
||||||
|
print(json.dumps(ret))
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
|
|
||||||
if not json_schema:
|
if not json_schema:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Make sure the json output is valid
|
||||||
try:
|
try:
|
||||||
js_obj = json.loads(response.json()["text"])
|
js_obj = json.loads(ret["text"])
|
||||||
except (TypeError, json.decoder.JSONDecodeError):
|
except (TypeError, json.decoder.JSONDecodeError):
|
||||||
raise
|
raise
|
||||||
assert isinstance(js_obj["name"], str)
|
|
||||||
assert isinstance(js_obj["population"], int)
|
self.assertIsInstance(js_obj["name"], str)
|
||||||
|
self.assertIsInstance(js_obj["population"], int)
|
||||||
|
|
||||||
|
# Make sure jump forward is triggered
|
||||||
|
self.assertGreater(
|
||||||
|
ret["meta_info"]["completion_tokens"],
|
||||||
|
ret["meta_info"]["completion_tokens_wo_jump_forward"],
|
||||||
|
)
|
||||||
|
|
||||||
def test_json_generate(self):
|
def test_json_generate(self):
|
||||||
self.run_decode(json_schema=self.json_schema)
|
self.run_decode(json_schema=self.json_schema)
|
||||||
@@ -100,8 +109,9 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
except (TypeError, json.decoder.JSONDecodeError):
|
except (TypeError, json.decoder.JSONDecodeError):
|
||||||
print("JSONDecodeError", text)
|
print("JSONDecodeError", text)
|
||||||
raise
|
raise
|
||||||
assert isinstance(js_obj["name"], str), f"{js_obj=}"
|
|
||||||
assert isinstance(js_obj["population"], int), f"{js_obj=}"
|
self.assertIsInstance(js_obj["name"], str)
|
||||||
|
self.assertIsInstance(js_obj["population"], int)
|
||||||
|
|
||||||
def test_mix_json_and_other(self):
|
def test_mix_json_and_other(self):
|
||||||
json_schemas = [None, None, self.json_schema, self.json_schema] * 10
|
json_schemas = [None, None, self.json_schema, self.json_schema] * 10
|
||||||
|
|||||||
Reference in New Issue
Block a user