Fix grammar backend for tensor parallelism (#2020)
This commit is contained in:
72
python/sglang/srt/constrained/base_grammar_backend.py
Normal file
72
python/sglang/srt/constrained/base_grammar_backend.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""The baseclass of backends for grammar-guided constrained decoding."""
|
||||
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from threading import Event, Lock
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """One cache slot: the grammar object plus the event that signals it is ready."""

    value: Any
    event: Event
|
||||
|
||||
|
||||
class BaseGrammarObject:
    """Marker base class for backend-specific grammar objects (see subclasses)."""
|
||||
|
||||
|
||||
class BaseGrammarBackend:
    """Thread-safe, cached factory for grammar objects.

    Subclasses implement ``init_value_impl``. This base class de-duplicates
    concurrent requests for the same key (only one thread builds; the rest
    wait on the entry's event) and hands every caller its own ``.copy()`` of
    the cached grammar object.
    """

    def __init__(self):
        # Pool used by get_future_value to build grammars off the hot path.
        self.executor = ThreadPoolExecutor()
        # key -> CacheEntry; entry.event is set once entry.value is populated.
        self.cache = {}
        self.cache_lock = Lock()

    def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
        """Return a fresh copy of the grammar object for `key`, building it at most once."""
        with self.cache_lock:
            if key in self.cache:
                cache_hit = True
                entry = self.cache[key]
            else:
                cache_hit = False
                entry = CacheEntry(None, Event())
                self.cache[key] = entry

        if cache_hit:
            # Another thread owns construction for this key; wait until done.
            entry.event.wait()
        else:
            # Fix: set the event even when init_value_impl raises; otherwise
            # every concurrent caller for this key would block forever on wait().
            try:
                entry.value = self.init_value_impl(key)
            finally:
                entry.event.set()
        return entry.value.copy()

    def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
        """Build the grammar object for `key`. Must be overridden by subclasses."""
        raise NotImplementedError()

    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
        """Return a copy of the cached value for `key`, or None if absent or still building."""
        with self.cache_lock:
            entry = self.cache.get(key)
            if not entry or not entry.event.is_set():
                return None
            # Read the value under the lock; copy outside to shorten the hold.
            value = entry.value
        return value.copy()

    def get_future_value(self, key: Tuple[str, str]) -> Future:
        """Schedule init_value(key) on the executor and return its Future."""
        return self.executor.submit(self.init_value, key)

    def reset(self):
        """Drop every cached entry."""
        with self.cache_lock:
            self.cache.clear()
|
||||
@@ -1,102 +0,0 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""Base cache class for constrained decoding tools."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from threading import Event, Lock
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
|
||||
@dataclass
class MapEntry:
    """A cache slot that unpacks as ``event, value`` for convenience."""

    event: Event
    value: Any

    def __iter__(self):
        # Support `event, value = entry` destructuring.
        yield self.event
        yield self.value
|
||||
|
||||
|
||||
class BaseToolCache:
    """Thread-safe cache for constrained-decoding tool objects, with hit/latency metrics.

    Subclasses implement ``init_value``. Concurrent queries for the same key
    are de-duplicated: one thread builds, the others wait on the entry's event.
    """

    def __init__(self, enable=True):
        self.enable: bool = enable
        self.cache: Dict[str, MapEntry] = {}
        self.metrics: Dict[str, Any] = {}
        self.lock_cache: Lock = Lock()
        self.lock_metrics: Lock = Lock()
        self.reset()

    def reset(self):
        """Drop all cached entries and zero out the metrics."""
        with self.lock_cache:
            self.cache = {}
        with self.lock_metrics:
            self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}

    def _init_with_timer(self, key) -> Tuple[Any, float]:
        """Build the value for `key` and report how long construction took."""
        begin = time.monotonic()
        value = self.init_value(key)
        elapsed = time.monotonic() - begin
        return value, elapsed

    def update_time(self, init_time):
        """Fold one construction latency into the running average."""
        with self.lock_metrics:
            old_total = self.metrics["total"]
            denom = old_total + 1
            # Incremental mean: avoids old_avg * old_total, which could overflow.
            self.metrics["avg_init_time"] = (
                init_time / denom
                + (old_total / denom) * self.metrics["avg_init_time"]
            )

    def query(self, key):
        """Return the value for `key`, building and caching it on first use."""
        if not self.enable:
            # Caching disabled: build fresh every time, still record latency.
            value, elapsed = self._init_with_timer(key)
            self.update_time(elapsed)
            return value

        with self.lock_cache:
            cache_hit = key in self.cache
            if cache_hit:
                entry = self.cache[key]
            else:
                entry = MapEntry(Event(), None)
                self.cache[key] = entry

        with self.lock_metrics:
            self.metrics["total"] += 1
            if cache_hit:
                self.metrics["hit"] += 1

        if cache_hit:
            # Some thread is (or was) building this entry; wait for it.
            entry.event.wait()
        else:
            entry.value, elapsed = self._init_with_timer(key)
            self.update_time(elapsed)
            entry.event.set()
        return entry.value

    def init_value(self, key):
        """Build the value for `key`. Must be overridden by subclasses."""
        raise NotImplementedError()

    def get_cache_hit_rate(self):
        """Fraction of queries served from cache (0.0 when no queries yet)."""
        with self.lock_metrics:
            return self.metrics["hit"] / max(self.metrics["total"], 1)

    def get_avg_init_time(self):
        """Running average of value-construction latency in seconds."""
        with self.lock_metrics:
            return self.metrics["avg_init_time"]
|
||||
@@ -17,20 +17,17 @@ limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from interegular import InvalidSyntax, parse_pattern
|
||||
from outlines.fsm.guide import RegexGuide
|
||||
from outlines.models.transformers import TransformerTokenizer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||
from sglang.srt.constrained.outlines_jump_forward import (
|
||||
OutlinesJumpForwardCache,
|
||||
OutlinesJumpForwardMap,
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
)
|
||||
from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,6 +38,7 @@ except ImportError:
|
||||
# Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
|
||||
# which only accepts string schema as input.
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
def build_regex_from_object(
|
||||
object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
|
||||
@@ -54,16 +52,15 @@ except ImportError:
|
||||
return build_regex_from_schema(schema, whitespace_pattern)
|
||||
|
||||
|
||||
class OutlinesGrammar:
|
||||
class OutlinesGrammar(BaseGrammarObject):
|
||||
def __init__(
|
||||
self,
|
||||
guide: RegexGuide,
|
||||
state: int,
|
||||
jump_forward_map: Union[OutlinesJumpForwardMap, None],
|
||||
) -> None:
|
||||
self.guide = guide
|
||||
self.state = state
|
||||
self.jump_forward_map = jump_forward_map
|
||||
self.state = 0
|
||||
|
||||
def accept_token(self, token: int):
|
||||
self.state = self.guide.get_next_state(self.state, token)
|
||||
@@ -105,46 +102,18 @@ class OutlinesGrammar:
|
||||
vocab_mask.fill_(1)
|
||||
vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
|
||||
|
||||
def copy(self):
|
||||
return OutlinesGrammar(self.guide, self.jump_forward_map)
|
||||
|
||||
class OutlinesGrammarBackend:
|
||||
|
||||
class OutlinesGrammarBackend(BaseGrammarBackend):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
whitespace_patterns: bool,
|
||||
whitespace_pattern: bool,
|
||||
allow_jump_forward: bool,
|
||||
):
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.grammar_cache = OutlinesCache(
|
||||
tokenizer,
|
||||
whitespace_pattern=whitespace_patterns,
|
||||
)
|
||||
self.jump_forward_cache = (
|
||||
OutlinesJumpForwardCache() if allow_jump_forward else None
|
||||
)
|
||||
|
||||
def _query(self, key: Tuple[str, str]) -> OutlinesGrammar:
|
||||
guide, regex = self.grammar_cache.query(key)
|
||||
jump_forward_map = (
|
||||
self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
|
||||
)
|
||||
return OutlinesGrammar(guide, 0, jump_forward_map)
|
||||
|
||||
def query(self, key: Tuple[str, str]) -> Future:
|
||||
return self.executor.submit(self._query, key)
|
||||
|
||||
def reset(self):
|
||||
self.grammar_cache.reset()
|
||||
if self.jump_forward_cache:
|
||||
self.jump_forward_cache.reset()
|
||||
|
||||
|
||||
class OutlinesCache(BaseToolCache):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
whitespace_pattern=None,
|
||||
):
|
||||
super().__init__(enable=True)
|
||||
super().__init__()
|
||||
|
||||
try:
|
||||
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
|
||||
@@ -167,9 +136,10 @@ class OutlinesCache(BaseToolCache):
|
||||
self.outlines_tokenizer.vocabulary = (
|
||||
self.outlines_tokenizer.tokenizer.get_vocab()
|
||||
)
|
||||
self.allow_jump_forward = allow_jump_forward
|
||||
self.whitespace_pattern = whitespace_pattern
|
||||
|
||||
def init_value(self, key):
|
||||
def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
|
||||
key_type, key_string = key
|
||||
if key_type == "json":
|
||||
try:
|
||||
@@ -186,18 +156,10 @@ class OutlinesCache(BaseToolCache):
|
||||
regex = key_string
|
||||
else:
|
||||
raise ValueError(f"Invalid key_type: {key_type}")
|
||||
try:
|
||||
parse_pattern(regex)
|
||||
except InvalidSyntax as e:
|
||||
logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
|
||||
return None, regex
|
||||
|
||||
ret = RegexGuide(regex, self.outlines_tokenizer), regex
|
||||
return ret
|
||||
|
||||
def _query(self, key: Tuple[str, str]):
|
||||
guide, regex = self.grammar_cache.query(key)
|
||||
jump_forward_map = (
|
||||
self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
|
||||
)
|
||||
return OutlinesGrammar(guide, 0, jump_forward_map)
|
||||
guide = RegexGuide(regex, self.outlines_tokenizer)
|
||||
if self.allow_jump_forward:
|
||||
jump_forward_map = OutlinesJumpForwardMap(regex)
|
||||
else:
|
||||
jump_forward_map = None
|
||||
return OutlinesGrammar(guide, jump_forward_map)
|
||||
|
||||
@@ -27,8 +27,6 @@ from interegular import InvalidSyntax
|
||||
from outlines.caching import cache as disk_cache
|
||||
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
|
||||
|
||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||
|
||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,92 +40,90 @@ class JumpEdge:
|
||||
byte_next_state: int = None
|
||||
|
||||
|
||||
@disk_cache()
def init_state_to_jump_forward(regex_string):
    """Precompute, per FSM state, the unique forced next transition ("jump forward").

    Builds a byte-level deterministic FSM for `regex_string` and returns a dict
    mapping state -> JumpEdge for every state that has exactly one outgoing
    transition (so decoding can skip ahead without sampling). Returns None
    (implicitly) when the regex cannot be parsed. Results are memoized on disk
    via outlines' @disk_cache.
    """
    try:
        regex_pattern = interegular.parse_pattern(regex_string)
    except InvalidSyntax as e:
        # Best-effort: an unparsable regex simply gets no jump-forward map.
        logger.warning(f"skip invalid regex: {regex_string}, {e=}")
        return

    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
    regex_fsm, _ = make_deterministic_fsm(byte_fsm)

    fsm_info: FSMInfo = regex_fsm.fsm_info

    # Invert the alphabet mapping: transition id -> all symbols sharing that id.
    symbol_to_id = fsm_info.alphabet_symbol_mapping
    id_to_symbol = {}
    for symbol, id_ in symbol_to_id.items():
        id_to_symbol.setdefault(id_, []).append(symbol)

    transitions = fsm_info.transitions

    # First pass: character-level jump-forward edges.
    outgoings_ct = defaultdict(int)
    # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
    for s in fsm_info.finals:
        outgoings_ct[s] = 1

    state_to_jump_forward = {}
    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            # Arbitrarily symbol cannot be recognized as jump forward
            continue

        symbols = id_to_symbol[id_]
        for c in symbols:
            if len(c) > 1:
                # Skip byte level transitions like c = "5E"
                continue

            outgoings_ct[state] += 1
            if outgoings_ct[state] > 1:
                # More than one way out of this state: it cannot be a forced
                # jump, so discard any edge recorded earlier for it.
                if state in state_to_jump_forward:
                    del state_to_jump_forward[state]
                break

            state_to_jump_forward[state] = JumpEdge(
                symbol=c,
                symbol_next_state=next_state,
            )

    # Process the byte level jump forward
    # Second pass with a fresh outgoing-edge counter: annotate (or create)
    # edges with byte-level information.
    outgoings_ct = defaultdict(int)
    for s in fsm_info.finals:
        outgoings_ct[s] = 1

    for (state, id_), next_state in transitions.items():
        if id_ == fsm_info.alphabet_anything_value:
            continue
        symbols = id_to_symbol[id_]
        for c in symbols:
            byte_ = None
            if len(c) == 1 and ord(c) < 0x80:
                # ASCII character
                byte_ = ord(c)
            elif len(c) > 1:
                # FIXME: This logic is due to the leading \x00
                # https://github.com/outlines-dev/outlines/pull/930
                # presumably symbols like "5E" are hex byte labels with a
                # leading sentinel char — confirm against outlines' encoding
                byte_ = int(symbols[0][1:], 16)

            if byte_ is not None:
                outgoings_ct[state] += 1
                if outgoings_ct[state] > 1:
                    if state in state_to_jump_forward:
                        del state_to_jump_forward[state]
                    break
                # Extend the char-level edge for this state if one exists,
                # otherwise start a byte-only edge.
                e = state_to_jump_forward.get(state, JumpEdge())
                e.byte = byte_
                e.byte_next_state = next_state
                state_to_jump_forward[state] = e

    return state_to_jump_forward
|
||||
|
||||
|
||||
class OutlinesJumpForwardMap:
|
||||
def __init__(self, regex_string):
|
||||
@disk_cache()
|
||||
def _init_state_to_jump_forward(regex_string):
|
||||
try:
|
||||
regex_pattern = interegular.parse_pattern(regex_string)
|
||||
except InvalidSyntax as e:
|
||||
logger.warning(f"skip invalid regex: {regex_string}, {e=}")
|
||||
self.state_to_jump_forward = None
|
||||
return
|
||||
|
||||
byte_fsm = make_byte_level_fsm(
|
||||
regex_pattern.to_fsm().reduce(), keep_utf8=True
|
||||
)
|
||||
regex_fsm, _ = make_deterministic_fsm(byte_fsm)
|
||||
|
||||
fsm_info: FSMInfo = regex_fsm.fsm_info
|
||||
|
||||
symbol_to_id = fsm_info.alphabet_symbol_mapping
|
||||
id_to_symbol = {}
|
||||
for symbol, id_ in symbol_to_id.items():
|
||||
id_to_symbol.setdefault(id_, []).append(symbol)
|
||||
|
||||
transitions = fsm_info.transitions
|
||||
|
||||
outgoings_ct = defaultdict(int)
|
||||
# NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
|
||||
for s in fsm_info.finals:
|
||||
outgoings_ct[s] = 1
|
||||
|
||||
state_to_jump_forward = {}
|
||||
for (state, id_), next_state in transitions.items():
|
||||
if id_ == fsm_info.alphabet_anything_value:
|
||||
# Arbitrarily symbol cannot be recognized as jump forward
|
||||
continue
|
||||
|
||||
symbols = id_to_symbol[id_]
|
||||
for c in symbols:
|
||||
if len(c) > 1:
|
||||
# Skip byte level transitions like c = "5E"
|
||||
continue
|
||||
|
||||
outgoings_ct[state] += 1
|
||||
if outgoings_ct[state] > 1:
|
||||
if state in state_to_jump_forward:
|
||||
del state_to_jump_forward[state]
|
||||
break
|
||||
|
||||
state_to_jump_forward[state] = JumpEdge(
|
||||
symbol=c,
|
||||
symbol_next_state=next_state,
|
||||
)
|
||||
|
||||
# Process the byte level jump forward
|
||||
outgoings_ct = defaultdict(int)
|
||||
for s in fsm_info.finals:
|
||||
outgoings_ct[s] = 1
|
||||
|
||||
for (state, id_), next_state in transitions.items():
|
||||
if id_ == fsm_info.alphabet_anything_value:
|
||||
continue
|
||||
symbols = id_to_symbol[id_]
|
||||
for c in symbols:
|
||||
byte_ = None
|
||||
if len(c) == 1 and ord(c) < 0x80:
|
||||
# ASCII character
|
||||
byte_ = ord(c)
|
||||
elif len(c) > 1:
|
||||
# FIXME: This logic is due to the leading \x00
|
||||
# https://github.com/outlines-dev/outlines/pull/930
|
||||
byte_ = int(symbols[0][1:], 16)
|
||||
|
||||
if byte_ is not None:
|
||||
outgoings_ct[state] += 1
|
||||
if outgoings_ct[state] > 1:
|
||||
if state in state_to_jump_forward:
|
||||
del state_to_jump_forward[state]
|
||||
break
|
||||
e = state_to_jump_forward.get(state, JumpEdge())
|
||||
e.byte = byte_
|
||||
e.byte_next_state = next_state
|
||||
state_to_jump_forward[state] = e
|
||||
|
||||
return state_to_jump_forward
|
||||
|
||||
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
|
||||
self.state_to_jump_forward = init_state_to_jump_forward(regex_string)
|
||||
|
||||
def jump_forward_symbol(self, state):
|
||||
jump_forward_str = ""
|
||||
@@ -164,18 +160,6 @@ class OutlinesJumpForwardMap:
|
||||
)
|
||||
|
||||
|
||||
class OutlinesJumpForwardCache(BaseToolCache):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def init_value(self, regex):
|
||||
forward_map = OutlinesJumpForwardMap(regex)
|
||||
if forward_map.state_to_jump_forward:
|
||||
return forward_map
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def test_main(regex_string):
|
||||
jump_forward_map = OutlinesJumpForwardMap(regex_string)
|
||||
for state, e in jump_forward_map.state_to_jump_forward.items():
|
||||
|
||||
@@ -15,38 +15,36 @@ limitations under the License.
|
||||
|
||||
"""Constrained decoding with xgrammar backend."""
|
||||
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
|
||||
|
||||
try:
|
||||
from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
|
||||
|
||||
import_error = None
|
||||
except ImportError as e:
|
||||
import_error = e
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
GrammarMatcher = CompiledGrammar = CachedGrammarCompiler = Dummy
|
||||
|
||||
from sglang.srt.constrained.base_grammar_backend import (
|
||||
BaseGrammarBackend,
|
||||
BaseGrammarObject,
|
||||
)
|
||||
|
||||
MAX_ROLLBACK_TOKENS = 10
|
||||
|
||||
|
||||
class XGrammarGrammar:
|
||||
class XGrammarGrammar(BaseGrammarObject):
|
||||
|
||||
def __init__(self, matcher: GrammarMatcher, vocab_size: int) -> None:
|
||||
def __init__(
|
||||
self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
|
||||
) -> None:
|
||||
self.matcher = matcher
|
||||
self.vocab_size = vocab_size
|
||||
self.ctx = ctx
|
||||
|
||||
def accept_token(self, token: int):
|
||||
assert self.matcher.accept_token(token)
|
||||
|
||||
def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
|
||||
return [], self.matcher.find_jump_forward_string()
|
||||
s = self.matcher.find_jump_forward_string()
|
||||
if s:
|
||||
return [], s
|
||||
return None
|
||||
|
||||
def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
|
||||
_, data = helper
|
||||
@@ -77,51 +75,40 @@ class XGrammarGrammar:
|
||||
self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
|
||||
] = 1
|
||||
|
||||
def copy(self):
|
||||
matcher = GrammarMatcher(
|
||||
self.ctx,
|
||||
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
||||
mask_vocab_size=self.vocab_size,
|
||||
)
|
||||
return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
|
||||
|
||||
class XGrammarGrammarBackend:
|
||||
|
||||
class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
vocab_size: int,
|
||||
):
|
||||
if import_error:
|
||||
raise import_error
|
||||
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.grammar_cache = XGrammarCache(tokenizer, vocab_size)
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
def _query(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||
return XGrammarGrammar(self.grammar_cache.query(key), self.vocab_size)
|
||||
|
||||
def query(self, key: Tuple[str, str]) -> Future:
|
||||
return self.executor.submit(self._query, key)
|
||||
|
||||
def reset(self):
|
||||
self.grammar_cache.reset()
|
||||
|
||||
|
||||
class XGrammarCache:
|
||||
def __init__(self, tokenizer, vocab_size: int):
|
||||
super().__init__()
|
||||
self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
|
||||
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||
key_type, key_string = key
|
||||
if key_type == "json":
|
||||
return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
|
||||
ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
|
||||
elif key_type == "regex":
|
||||
raise ValueError("regex hasn't been supported by xgrammar yet")
|
||||
else:
|
||||
raise ValueError(f"Invalid key_type: {key_type}")
|
||||
|
||||
def query(self, key: Tuple[str, str]) -> GrammarMatcher:
|
||||
ctx = self.get_context(key)
|
||||
return GrammarMatcher(
|
||||
matcher = GrammarMatcher(
|
||||
ctx,
|
||||
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
||||
mask_vocab_size=self.vocab_size,
|
||||
)
|
||||
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
||||
|
||||
def reset(self):
|
||||
self.grammar_cache.clear()
|
||||
|
||||
@@ -37,6 +37,7 @@ import torch
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
@@ -248,7 +249,7 @@ class Req:
|
||||
self.embedding = None
|
||||
|
||||
# Constrained decoding
|
||||
self.grammar = None
|
||||
self.grammar: Optional[BaseGrammarObject] = None
|
||||
|
||||
# The number of cached tokens, that were already cached in the KV cache
|
||||
self.cached_tokens = 0
|
||||
|
||||
@@ -244,7 +244,7 @@ class Scheduler:
|
||||
|
||||
self.grammar_backend = OutlinesGrammarBackend(
|
||||
self.tokenizer,
|
||||
whitespace_patterns=server_args.constrained_json_whitespace_pattern,
|
||||
whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||
allow_jump_forward=not server_args.disable_jump_forward,
|
||||
)
|
||||
elif server_args.grammar_backend == "xgrammar":
|
||||
@@ -467,21 +467,6 @@ class Scheduler:
|
||||
# By default, only return the logprobs for output tokens
|
||||
req.logprob_start_len = len(recv_req.input_ids) - 1
|
||||
|
||||
# Init grammar cache for this request
|
||||
if (
|
||||
req.sampling_params.json_schema is not None
|
||||
or req.sampling_params.regex is not None
|
||||
):
|
||||
assert self.grammar_backend is not None
|
||||
if req.sampling_params.json_schema is not None:
|
||||
req.grammar = self.grammar_backend.query(
|
||||
("json", req.sampling_params.json_schema),
|
||||
)
|
||||
elif req.sampling_params.regex is not None:
|
||||
req.grammar = self.grammar_backend.query(
|
||||
("regex", req.sampling_params.regex)
|
||||
)
|
||||
|
||||
# Truncate prompts that are too long
|
||||
if len(req.origin_input_ids) > self.max_req_input_len:
|
||||
logger.warning(
|
||||
@@ -499,7 +484,24 @@ class Scheduler:
|
||||
self.max_req_len - len(req.origin_input_ids) - 1,
|
||||
)
|
||||
|
||||
if req.grammar is not None:
|
||||
# Init grammar cache for this request
|
||||
add_to_grammar_queue = False
|
||||
if (
|
||||
req.sampling_params.json_schema is not None
|
||||
or req.sampling_params.regex is not None
|
||||
):
|
||||
assert self.grammar_backend is not None
|
||||
if req.sampling_params.json_schema is not None:
|
||||
key = ("json", req.sampling_params.json_schema)
|
||||
elif req.sampling_params.regex is not None:
|
||||
key = ("regex", req.sampling_params.regex)
|
||||
|
||||
req.grammar = self.grammar_backend.get_cached_value(key)
|
||||
if not req.grammar:
|
||||
req.grammar = self.grammar_backend.get_future_value(key)
|
||||
add_to_grammar_queue = True
|
||||
|
||||
if add_to_grammar_queue:
|
||||
self.grammar_queue.append(req)
|
||||
else:
|
||||
self.waiting_queue.append(req)
|
||||
@@ -650,14 +652,7 @@ class Scheduler:
|
||||
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||
# Check if the grammar is ready in the grammar queue
|
||||
if self.grammar_queue:
|
||||
new_grammar_queue = []
|
||||
for req in self.grammar_queue:
|
||||
try:
|
||||
req.grammar = req.grammar.result(timeout=0.05)
|
||||
self.waiting_queue.append(req)
|
||||
except futures._base.TimeoutError:
|
||||
new_grammar_queue.append(req)
|
||||
self.grammar_queue = new_grammar_queue
|
||||
self.move_ready_grammar_requests()
|
||||
|
||||
# Handle the cases where prefill is not allowed
|
||||
if (
|
||||
@@ -1145,6 +1140,30 @@ class Scheduler:
|
||||
)
|
||||
)
|
||||
|
||||
def move_ready_grammar_requests(self):
|
||||
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
|
||||
num_ready_reqs = 0
|
||||
for req in self.grammar_queue:
|
||||
try:
|
||||
req.grammar = req.grammar.result(timeout=0.05)
|
||||
num_ready_reqs += 1
|
||||
except futures._base.TimeoutError:
|
||||
break
|
||||
|
||||
if self.tp_size > 1:
|
||||
# Sync across TP ranks to make sure they have the same number of ready requests
|
||||
tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
|
||||
torch.distributed.all_reduce(
|
||||
tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
|
||||
)
|
||||
num_ready_reqs_max = tensor.item()
|
||||
for i in range(num_ready_reqs, num_ready_reqs_max):
|
||||
self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
|
||||
num_ready_reqs = num_ready_reqs_max
|
||||
|
||||
self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
|
||||
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
||||
|
||||
def flush_cache(self):
|
||||
"""Flush the memory pool and cache."""
|
||||
if len(self.waiting_queue) == 0 and (
|
||||
@@ -1152,9 +1171,8 @@ class Scheduler:
|
||||
):
|
||||
self.tree_cache.reset()
|
||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||
if self.grammar_backend is not None:
|
||||
if self.grammar_backend:
|
||||
self.grammar_backend.reset()
|
||||
# TODO(dark): reset the bnf cache
|
||||
self.req_to_token_pool.clear()
|
||||
self.token_to_kv_pool.clear()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user