support parallel grammar preprocessing (#1996)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -13,25 +13,11 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""For constrained decoding."""
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
try:
|
|
||||||
from outlines.caching import cache as disk_cache
|
|
||||||
from outlines.caching import disable_cache
|
|
||||||
from outlines.fsm.guide import RegexGuide
|
|
||||||
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
|
||||||
from outlines.models.transformers import TransformerTokenizer
|
|
||||||
except ImportError as e:
|
|
||||||
print(
|
|
||||||
f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from outlines.fsm.json_schema import build_regex_from_object
|
from outlines.fsm.json_schema import build_regex_from_object
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -51,31 +37,6 @@ except ImportError:
|
|||||||
return build_regex_from_schema(schema, whitespace_pattern)
|
return build_regex_from_schema(schema, whitespace_pattern)
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from xgrammar import (
|
|
||||||
GrammarMatcher,
|
|
||||||
GrammarMatcherInitContext,
|
|
||||||
GrammarMatcherInitContextCache,
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
|
||||||
|
|
||||||
class Dummy:
|
|
||||||
pass
|
|
||||||
|
|
||||||
GrammarMatcher = Dummy
|
|
||||||
GrammarMatcherInitContext = Dummy
|
|
||||||
GrammarMatcherInitContextCache = Dummy
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RegexGuide",
|
|
||||||
"FSMInfo",
|
|
||||||
"make_deterministic_fsm",
|
|
||||||
"build_regex_from_object",
|
"build_regex_from_object",
|
||||||
"TransformerTokenizer",
|
|
||||||
"disk_cache",
|
|
||||||
"disable_cache",
|
|
||||||
"make_byte_level_fsm",
|
|
||||||
"GrammarMatcher",
|
|
||||||
"GrammarMatcherInitContext",
|
|
||||||
"GrammarMatcherInitContextCache",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -13,25 +13,47 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""Base tool cache for constrained decoding tools."""
|
"""Base cache class for constrained decoding tools."""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from threading import Event, Lock
|
||||||
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MapEntry:
|
||||||
|
event: Event
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter((self.event, self.value))
|
||||||
|
|
||||||
|
|
||||||
class BaseToolCache:
|
class BaseToolCache:
|
||||||
|
|
||||||
def __init__(self, enable=True):
|
def __init__(self, enable=True):
|
||||||
self.enable = enable
|
self.enable: bool = enable
|
||||||
|
self.cache: Dict[str, MapEntry] = {}
|
||||||
|
self.metrics: Dict[str, Any] = {}
|
||||||
|
self.lock_cache: Lock = Lock()
|
||||||
|
self.lock_metrics: Lock = Lock()
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.cache = {}
|
with self.lock_cache:
|
||||||
self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
|
self.cache = {}
|
||||||
|
with self.lock_metrics:
|
||||||
|
self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
|
||||||
|
|
||||||
def query(self, key):
|
def _init_with_timer(self, key) -> Tuple[Any, float]:
|
||||||
def _init_with_timer(key):
|
start = time.monotonic()
|
||||||
start = time.monotonic()
|
val = self.init_value(key)
|
||||||
val = self.init_value(key)
|
init_time = time.monotonic() - start
|
||||||
init_time = time.monotonic() - start
|
return val, init_time
|
||||||
|
|
||||||
|
def update_time(self, init_time):
|
||||||
|
with self.lock_metrics:
|
||||||
curr_total = self.metrics["total"]
|
curr_total = self.metrics["total"]
|
||||||
new_total = curr_total + 1
|
new_total = curr_total + 1
|
||||||
|
|
||||||
@@ -39,27 +61,44 @@ class BaseToolCache:
|
|||||||
self.metrics["avg_init_time"] = (init_time / new_total) + (
|
self.metrics["avg_init_time"] = (init_time / new_total) + (
|
||||||
curr_total / new_total
|
curr_total / new_total
|
||||||
) * self.metrics["avg_init_time"]
|
) * self.metrics["avg_init_time"]
|
||||||
return val
|
|
||||||
|
|
||||||
if key in self.cache:
|
def query(self, key):
|
||||||
self.metrics["hit"] += 1
|
if not self.enable:
|
||||||
val = self.cache[key]
|
value, init_time = self._init_with_timer(key)
|
||||||
else:
|
self.update_time(init_time)
|
||||||
# Cache miss or disabled.
|
return value
|
||||||
val = _init_with_timer(key)
|
|
||||||
|
|
||||||
if self.enable:
|
with self.lock_cache:
|
||||||
|
if key in self.cache:
|
||||||
|
entry = self.cache[key]
|
||||||
|
cache_hit = True
|
||||||
|
else:
|
||||||
|
entry = MapEntry(Event(), None)
|
||||||
|
self.cache[key] = entry
|
||||||
|
cache_hit = False
|
||||||
|
|
||||||
|
with self.lock_metrics:
|
||||||
self.metrics["total"] += 1
|
self.metrics["total"] += 1
|
||||||
self.cache[key] = val
|
if cache_hit:
|
||||||
return val
|
self.metrics["hit"] += 1
|
||||||
|
|
||||||
|
if cache_hit:
|
||||||
|
entry.event.wait()
|
||||||
|
else:
|
||||||
|
entry.value, init_time = self._init_with_timer(key)
|
||||||
|
self.update_time(init_time)
|
||||||
|
entry.event.set()
|
||||||
|
return entry.value
|
||||||
|
|
||||||
def init_value(self, key):
|
def init_value(self, key):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_cache_hit_rate(self):
|
def get_cache_hit_rate(self):
|
||||||
if self.metrics["total"] == 0:
|
with self.lock_metrics:
|
||||||
return 0
|
if self.metrics["total"] == 0:
|
||||||
return self.metrics["hit"] / self.metrics["total"]
|
return 0
|
||||||
|
return self.metrics["hit"] / self.metrics["total"]
|
||||||
|
|
||||||
def get_avg_init_time(self):
|
def get_avg_init_time(self):
|
||||||
return self.metrics["avg_init_time"]
|
with self.lock_metrics:
|
||||||
|
return self.metrics["avg_init_time"]
|
||||||
|
|||||||
@@ -13,50 +13,44 @@ limitations under the License.
|
|||||||
|
|
||||||
"""Cache for the compressed finite state machine."""
|
"""Cache for the compressed finite state machine."""
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Tuple, Union
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.constrained import GrammarMatcher, RegexGuide
|
from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide
|
||||||
from sglang.srt.constrained.bnf_cache import BNFCache
|
from sglang.srt.constrained.outlines_jump_forward import (
|
||||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
OutlinesJumpCache,
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap
|
OutlinesJumpForwardMap,
|
||||||
|
)
|
||||||
# from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.constrained.xgrammar_cache import (
|
||||||
|
GrammarMatcher,
|
||||||
|
XGrammarBackend,
|
||||||
|
XGrammarJumpCache,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
|
||||||
|
|
||||||
|
|
||||||
class XGrammarJump:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class JumpHelper:
|
class JumpHelper:
|
||||||
data: Union[List, str]
|
|
||||||
state: int
|
|
||||||
suffix_ids: List[int]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
|
self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
|
||||||
) -> None:
|
) -> None:
|
||||||
self.data = data
|
self.data: Union[List, str] = data
|
||||||
self.state = state
|
self.state: int = state
|
||||||
self.suffix_ids = suffix_ids
|
self.suffix_ids: List[int] = suffix_ids
|
||||||
|
|
||||||
def can_jump(self):
|
def can_jump(self):
|
||||||
return len(self.data) > 0
|
return len(self.data) > 0
|
||||||
|
|
||||||
|
|
||||||
class Grammar:
|
class Grammar:
|
||||||
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
|
|
||||||
jump_map: Union[XGrammarJump, JumpForwardMap, None]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
|
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
|
||||||
jump_map: Union[XGrammarJump, JumpForwardMap, None],
|
jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.grammar = grammar
|
self.grammar = grammar
|
||||||
self.jump_map = jump_map
|
self.jump_map = jump_map
|
||||||
@@ -69,10 +63,10 @@ class Grammar:
|
|||||||
self.grammar = guide, guide.get_next_state(state, token)
|
self.grammar = guide, guide.get_next_state(state, token)
|
||||||
|
|
||||||
def try_jump(self, tokenizer) -> JumpHelper:
|
def try_jump(self, tokenizer) -> JumpHelper:
|
||||||
if isinstance(self.jump_map, XGrammarJump):
|
if isinstance(self.jump_map, XGrammarJumpCache):
|
||||||
assert isinstance(self.grammar, GrammarMatcher)
|
assert isinstance(self.grammar, GrammarMatcher)
|
||||||
return JumpHelper(self.grammar.find_jump_forward_string())
|
return JumpHelper(self.grammar.find_jump_forward_string())
|
||||||
elif isinstance(self.jump_map, JumpForwardMap):
|
elif isinstance(self.jump_map, OutlinesJumpForwardMap):
|
||||||
assert isinstance(self.grammar, Tuple)
|
assert isinstance(self.grammar, Tuple)
|
||||||
|
|
||||||
_, state = self.grammar
|
_, state = self.grammar
|
||||||
@@ -103,7 +97,7 @@ class Grammar:
|
|||||||
if isinstance(helper.data, str):
|
if isinstance(helper.data, str):
|
||||||
return helper.data, -1
|
return helper.data, -1
|
||||||
else:
|
else:
|
||||||
assert isinstance(self.jump_map, JumpForwardMap)
|
assert isinstance(self.jump_map, OutlinesJumpForwardMap)
|
||||||
return self.jump_map.jump_forward_symbol(helper.state)
|
return self.jump_map.jump_forward_symbol(helper.state)
|
||||||
|
|
||||||
def jump_and_retokenize(
|
def jump_and_retokenize(
|
||||||
@@ -129,7 +123,7 @@ class Grammar:
|
|||||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
|
def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
|
||||||
if isinstance(self.grammar, GrammarMatcher):
|
if isinstance(self.grammar, GrammarMatcher):
|
||||||
# Note that this bitmask is a bitset, not bool
|
# Note that this bitmask is a bitset, not bool
|
||||||
bitmask = self.grammar.find_next_token_bitmask()
|
bitmask = self.grammar.get_next_token_bitmask()
|
||||||
# Mask the tokens that are not allowed
|
# Mask the tokens that are not allowed
|
||||||
vocab_mask[
|
vocab_mask[
|
||||||
self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
|
self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
|
||||||
@@ -140,9 +134,7 @@ class Grammar:
|
|||||||
vocab_mask[guide.get_next_instruction(state).tokens] = 0
|
vocab_mask[guide.get_next_instruction(state).tokens] = 0
|
||||||
|
|
||||||
|
|
||||||
class GrammarCache:
|
class GrammarBackend:
|
||||||
grammar_cache: Union[BNFCache, FSMCache]
|
|
||||||
jump_cache: Union[XGrammarJump, JumpForwardCache, None]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -153,38 +145,38 @@ class GrammarCache:
|
|||||||
backend=None,
|
backend=None,
|
||||||
allow_jump=False,
|
allow_jump=False,
|
||||||
):
|
):
|
||||||
|
self.executor = ThreadPoolExecutor()
|
||||||
|
self.backend = backend
|
||||||
|
|
||||||
if backend == "xgrammar":
|
if backend == "xgrammar":
|
||||||
self.grammar_cache = BNFCache(
|
self.grammar_cache = XGrammarBackend(
|
||||||
tokenizer_path=tokenizer_path,
|
tokenizer_path=tokenizer_path,
|
||||||
tokenizer_args_dict=tokenizer_args_dict,
|
tokenizer_args_dict=tokenizer_args_dict,
|
||||||
skip_tokenizer_init=skip_tokenizer_init,
|
skip_tokenizer_init=skip_tokenizer_init,
|
||||||
whitespace_patterns=whitespace_patterns,
|
whitespace_patterns=whitespace_patterns,
|
||||||
)
|
)
|
||||||
self.jump_cache = XGrammarJump() if allow_jump else None
|
self.jump_cache = XGrammarJumpCache() if allow_jump else None
|
||||||
else:
|
else:
|
||||||
assert backend == "outlines"
|
assert backend == "outlines"
|
||||||
self.grammar_cache = FSMCache(
|
self.grammar_cache = OutlinesCache(
|
||||||
tokenizer_path=tokenizer_path,
|
tokenizer_path=tokenizer_path,
|
||||||
tokenizer_args_dict=tokenizer_args_dict,
|
tokenizer_args_dict=tokenizer_args_dict,
|
||||||
skip_tokenizer_init=skip_tokenizer_init,
|
skip_tokenizer_init=skip_tokenizer_init,
|
||||||
constrained_json_whitespace_pattern=whitespace_patterns,
|
constrained_json_whitespace_pattern=whitespace_patterns,
|
||||||
enable=True,
|
|
||||||
)
|
)
|
||||||
self.jump_cache = JumpForwardCache() if allow_jump else None
|
self.jump_cache = OutlinesJumpCache() if allow_jump else None
|
||||||
|
|
||||||
def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
|
def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
|
||||||
if isinstance(self.grammar_cache, BNFCache):
|
if isinstance(self.grammar_cache, XGrammarBackend):
|
||||||
assert not isinstance(self.jump_cache, JumpForwardCache)
|
|
||||||
return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
|
return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
|
||||||
else:
|
else:
|
||||||
jump_map = None
|
|
||||||
guide, regex = self.grammar_cache.query(key)
|
guide, regex = self.grammar_cache.query(key)
|
||||||
if isinstance(self.jump_cache, JumpForwardCache):
|
jump_map = self.jump_cache.query(regex)
|
||||||
jump_map = self.jump_cache.query(regex)
|
|
||||||
return Grammar((guide, 0), jump_map)
|
return Grammar((guide, 0), jump_map)
|
||||||
|
|
||||||
|
def query(self, key: Tuple[str, str], vocab_size: int) -> Future:
|
||||||
|
return self.executor.submit(self._query, key, vocab_size)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
if isinstance(self.grammar_cache, FSMCache):
|
self.grammar_cache.reset()
|
||||||
self.grammar_cache.reset()
|
self.jump_cache.reset()
|
||||||
if isinstance(self.jump_cache, JumpForwardCache):
|
|
||||||
self.jump_cache.reset()
|
|
||||||
|
|||||||
@@ -17,16 +17,17 @@ limitations under the License.
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from interegular import InvalidSyntax, parse_pattern
|
from interegular import InvalidSyntax, parse_pattern
|
||||||
from outlines.fsm.json_schema import build_regex_from_schema
|
from outlines.fsm.guide import RegexGuide
|
||||||
|
from outlines.models.transformers import TransformerTokenizer
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
from sglang.srt.constrained import build_regex_from_object
|
||||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FSMCache(BaseToolCache):
|
class OutlinesCache(BaseToolCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer_path,
|
tokenizer_path,
|
||||||
@@ -74,7 +75,7 @@ class FSMCache(BaseToolCache):
|
|||||||
key_type, key_string = key
|
key_type, key_string = key
|
||||||
if key_type == "json":
|
if key_type == "json":
|
||||||
try:
|
try:
|
||||||
regex = build_regex_from_schema(
|
regex = build_regex_from_object(
|
||||||
key_string,
|
key_string,
|
||||||
whitespace_pattern=self.constrained_json_whitespace_pattern,
|
whitespace_pattern=self.constrained_json_whitespace_pattern,
|
||||||
)
|
)
|
||||||
@@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Faster constrained decoding.
|
Faster constrained decoding with jump forward decoding / compressed finite state machine.
|
||||||
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -23,15 +23,10 @@ import logging
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import interegular
|
import interegular
|
||||||
import outlines.caching
|
|
||||||
from interegular import InvalidSyntax
|
from interegular import InvalidSyntax
|
||||||
|
from outlines.caching import cache as disk_cache
|
||||||
|
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
||||||
|
|
||||||
from sglang.srt.constrained import (
|
|
||||||
FSMInfo,
|
|
||||||
disk_cache,
|
|
||||||
make_byte_level_fsm,
|
|
||||||
make_deterministic_fsm,
|
|
||||||
)
|
|
||||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||||
|
|
||||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||||
@@ -47,7 +42,7 @@ class JumpEdge:
|
|||||||
byte_next_state: int = None
|
byte_next_state: int = None
|
||||||
|
|
||||||
|
|
||||||
class JumpForwardMap:
|
class OutlinesJumpForwardMap:
|
||||||
def __init__(self, regex_string):
|
def __init__(self, regex_string):
|
||||||
@disk_cache()
|
@disk_cache()
|
||||||
def _init_state_to_jump_forward(regex_string):
|
def _init_state_to_jump_forward(regex_string):
|
||||||
@@ -169,12 +164,12 @@ class JumpForwardMap:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class JumpForwardCache(BaseToolCache):
|
class OutlinesJumpCache(BaseToolCache):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def init_value(self, regex):
|
def init_value(self, regex):
|
||||||
forward_map = JumpForwardMap(regex)
|
forward_map = OutlinesJumpForwardMap(regex)
|
||||||
if forward_map.state_to_jump_forward:
|
if forward_map.state_to_jump_forward:
|
||||||
return forward_map
|
return forward_map
|
||||||
else:
|
else:
|
||||||
@@ -182,7 +177,7 @@ class JumpForwardCache(BaseToolCache):
|
|||||||
|
|
||||||
|
|
||||||
def test_main(regex_string):
|
def test_main(regex_string):
|
||||||
jump_forward_map = JumpForwardMap(regex_string)
|
jump_forward_map = OutlinesJumpForwardMap(regex_string)
|
||||||
for state, e in jump_forward_map.state_to_jump_forward.items():
|
for state, e in jump_forward_map.state_to_jump_forward.items():
|
||||||
if e.symbol is not None:
|
if e.symbol is not None:
|
||||||
jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
|
jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
|
||||||
@@ -17,18 +17,29 @@ from typing import Tuple
|
|||||||
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from sglang.srt.constrained import (
|
try:
|
||||||
GrammarMatcher,
|
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
|
||||||
GrammarMatcherInitContext,
|
except ImportError as e:
|
||||||
GrammarMatcherInitContextCache,
|
|
||||||
)
|
class Dummy:
|
||||||
|
pass
|
||||||
|
|
||||||
|
GrammarMatcher = Dummy
|
||||||
|
CompiledGrammar = Dummy
|
||||||
|
CachedGrammarCompiler = Dummy
|
||||||
|
|
||||||
|
|
||||||
MAX_ROLLBACK_TOKENS = 10
|
MAX_ROLLBACK_TOKENS = 10
|
||||||
|
|
||||||
|
|
||||||
class BNFCache:
|
class XGrammarJumpCache:
|
||||||
grammar_cache: GrammarMatcherInitContextCache
|
"""A dummy class."""
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class XGrammarBackend:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer_path,
|
tokenizer_path,
|
||||||
@@ -41,16 +52,16 @@ class BNFCache:
|
|||||||
return
|
return
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
|
||||||
self.grammar_cache = GrammarMatcherInitContextCache(
|
self.grammar_cache: CachedGrammarCompiler = CachedGrammarCompiler(
|
||||||
tokenizer_or_vocab=tokenizer
|
tokenizer_or_vocab=tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
|
def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
|
||||||
key_type, key_string = key
|
key_type, key_string = key
|
||||||
if key_type == "json":
|
if key_type == "json":
|
||||||
return self.grammar_cache.get_init_context_for_json_schema(key_string)
|
return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
|
||||||
elif key_type == "regex":
|
elif key_type == "regex":
|
||||||
raise ValueError(f"regex hasn't been supported by xgrammar yet")
|
raise ValueError("regex hasn't been supported by xgrammar yet")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid key_type: {key_type}")
|
raise ValueError(f"Invalid key_type: {key_type}")
|
||||||
|
|
||||||
@@ -59,3 +70,6 @@ class BNFCache:
|
|||||||
return GrammarMatcher(
|
return GrammarMatcher(
|
||||||
ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
|
ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.grammar_cache.clear()
|
||||||
@@ -29,7 +29,7 @@ import zmq
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained.grammar import GrammarCache
|
from sglang.srt.constrained.grammar import GrammarBackend
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
@@ -234,11 +234,12 @@ class Scheduler:
|
|||||||
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init the FSM cache for constrained generation
|
# Init the grammar cache for constrained generation
|
||||||
self.grammar_cache = None
|
self.grammar_cache = None
|
||||||
|
self.grammar_queue: List[Req] = []
|
||||||
|
|
||||||
if not server_args.skip_tokenizer_init:
|
if not server_args.skip_tokenizer_init:
|
||||||
self.grammar_cache = GrammarCache(
|
self.grammar_cache = GrammarBackend(
|
||||||
server_args.tokenizer_path,
|
server_args.tokenizer_path,
|
||||||
{
|
{
|
||||||
"tokenizer_mode": server_args.tokenizer_mode,
|
"tokenizer_mode": server_args.tokenizer_mode,
|
||||||
@@ -455,7 +456,7 @@ class Scheduler:
|
|||||||
# By default, only return the logprobs for output tokens
|
# By default, only return the logprobs for output tokens
|
||||||
req.logprob_start_len = len(recv_req.input_ids) - 1
|
req.logprob_start_len = len(recv_req.input_ids) - 1
|
||||||
|
|
||||||
# Init regex FSM or BNF
|
# Init grammar cache for this request
|
||||||
if (
|
if (
|
||||||
req.sampling_params.json_schema is not None
|
req.sampling_params.json_schema is not None
|
||||||
or req.sampling_params.regex is not None
|
or req.sampling_params.regex is not None
|
||||||
@@ -488,7 +489,10 @@ class Scheduler:
|
|||||||
self.max_req_len - len(req.origin_input_ids) - 1,
|
self.max_req_len - len(req.origin_input_ids) - 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.waiting_queue.append(req)
|
if req.grammar is not None:
|
||||||
|
self.grammar_queue.append(req)
|
||||||
|
else:
|
||||||
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def handle_embedding_request(
|
def handle_embedding_request(
|
||||||
self,
|
self,
|
||||||
@@ -634,6 +638,17 @@ class Scheduler:
|
|||||||
return self.running_batch
|
return self.running_batch
|
||||||
|
|
||||||
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||||
|
# Check if the grammar queue is ready
|
||||||
|
if self.grammar_queue:
|
||||||
|
new_grammar_queue = []
|
||||||
|
for req in self.grammar_queue:
|
||||||
|
if req.grammar.done():
|
||||||
|
req.grammar = req.grammar.result()
|
||||||
|
self.waiting_queue.append(req)
|
||||||
|
else:
|
||||||
|
new_grammar_queue.append(req)
|
||||||
|
self.grammar_queue = new_grammar_queue
|
||||||
|
|
||||||
# Handle the cases where prefill is not allowed
|
# Handle the cases where prefill is not allowed
|
||||||
if (
|
if (
|
||||||
self.batch_is_full or len(self.waiting_queue) == 0
|
self.batch_is_full or len(self.waiting_queue) == 0
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ from vllm.model_executor.model_loader import get_model
|
|||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||||
from sglang.srt.constrained import disable_cache
|
|
||||||
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
|
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
|
||||||
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
|
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
|
||||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||||
@@ -129,6 +128,8 @@ class ModelRunner:
|
|||||||
if server_args.show_time_cost:
|
if server_args.show_time_cost:
|
||||||
enable_show_time_cost()
|
enable_show_time_cost()
|
||||||
if server_args.disable_disk_cache:
|
if server_args.disable_disk_cache:
|
||||||
|
from outlines.caching import disable_cache
|
||||||
|
|
||||||
disable_cache()
|
disable_cache()
|
||||||
|
|
||||||
global_server_args_dict.update(
|
global_server_args_dict.update(
|
||||||
|
|||||||
@@ -100,8 +100,8 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
except (TypeError, json.decoder.JSONDecodeError):
|
except (TypeError, json.decoder.JSONDecodeError):
|
||||||
print("JSONDecodeError", text)
|
print("JSONDecodeError", text)
|
||||||
raise
|
raise
|
||||||
assert isinstance(js_obj["name"], str)
|
assert isinstance(js_obj["name"], str), f"{js_obj=}"
|
||||||
assert isinstance(js_obj["population"], int)
|
assert isinstance(js_obj["population"], int), f"{js_obj=}"
|
||||||
|
|
||||||
def test_mix_json_and_other(self):
|
def test_mix_json_and_other(self):
|
||||||
json_schemas = [None, None, self.json_schema, self.json_schema] * 10
|
json_schemas = [None, None, self.json_schema, self.json_schema] * 10
|
||||||
|
|||||||
Reference in New Issue
Block a user