Fix grammar backend for tensor parallelism (#2020)

python/sglang/srt/constrained/base_grammar_backend.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""The baseclass of backends for grammar-guided constrained decoding."""
+
+from concurrent.futures import Future, ThreadPoolExecutor
+from dataclasses import dataclass
+from threading import Event, Lock
+from typing import Any, Optional, Tuple
+
+
+@dataclass
+class CacheEntry:
+    value: Any
+    event: Event
+
+
+class BaseGrammarObject:
+    pass
+
+
+class BaseGrammarBackend:
+    def __init__(self):
+        self.executor = ThreadPoolExecutor()
+        self.cache = {}
+        self.cache_lock = Lock()
+
+    def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
+        with self.cache_lock:
+            if key in self.cache:
+                cache_hit = True
+                entry = self.cache[key]
+            else:
+                cache_hit = False
+                entry = CacheEntry(None, Event())
+                self.cache[key] = entry
+
+        if cache_hit:
+            entry.event.wait()
+        else:
+            entry.value = self.init_value_impl(key)
+            entry.event.set()
+        return entry.value.copy()
+
+    def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
+        raise NotImplementedError()
+
+    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
+        with self.cache_lock:
+            entry = self.cache.get(key)
+            if not entry or not entry.event.is_set():
+                return None
+        return self.cache[key].value.copy()
+
+    def get_future_value(self, key: Tuple[str, str]) -> Future:
+        return self.executor.submit(self.init_value, key)
+
+    def reset(self):
+        with self.cache_lock:
+            self.cache.clear()
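
The new backend collapses the old per-tool caches into one keyed cache with future-based access: `init_value` lets the first caller compile while concurrent callers block on the `Event`, and every caller receives its own `copy()`. A minimal sketch of that contract, assuming the module above is importable; `EchoGrammar` and `EchoBackend` are hypothetical names, not part of this commit:

import time

from sglang.srt.constrained.base_grammar_backend import (
    BaseGrammarBackend,
    BaseGrammarObject,
)

class EchoGrammar(BaseGrammarObject):  # hypothetical
    def __init__(self, key):
        self.key = key

    def copy(self):
        # Each consumer gets its own object so per-request state never aliases.
        return EchoGrammar(self.key)

class EchoBackend(BaseGrammarBackend):  # hypothetical
    def __init__(self):
        super().__init__()
        self.compilations = 0

    def init_value_impl(self, key):
        self.compilations += 1  # should run once per distinct key
        time.sleep(0.1)         # stand-in for slow grammar compilation
        return EchoGrammar(key)

backend = EchoBackend()
futures = [backend.get_future_value(("regex", r"\d+")) for _ in range(8)]
grammars = [f.result() for f in futures]
assert backend.compilations == 1            # compiled once ...
assert len({id(g) for g in grammars}) == 8  # ... copied per caller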
python/sglang/srt/constrained/base_tool_cache.py (deleted, 102 lines)
@@ -1,102 +0,0 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-"""Base cache class for constrained decoding tools."""
-
-import time
-from dataclasses import dataclass
-from threading import Event, Lock
-from typing import Any, Dict, Tuple
-
-
-@dataclass
-class MapEntry:
-    event: Event
-    value: Any
-
-    def __iter__(self):
-        return iter((self.event, self.value))
-
-
-class BaseToolCache:
-    def __init__(self, enable=True):
-        self.enable: bool = enable
-        self.cache: Dict[str, MapEntry] = {}
-        self.metrics: Dict[str, Any] = {}
-        self.lock_cache: Lock = Lock()
-        self.lock_metrics: Lock = Lock()
-        self.reset()
-
-    def reset(self):
-        with self.lock_cache:
-            self.cache = {}
-        with self.lock_metrics:
-            self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
-
-    def _init_with_timer(self, key) -> Tuple[Any, float]:
-        start = time.monotonic()
-        val = self.init_value(key)
-        init_time = time.monotonic() - start
-        return val, init_time
-
-    def update_time(self, init_time):
-        with self.lock_metrics:
-            curr_total = self.metrics["total"]
-            new_total = curr_total + 1
-
-            # Update average init time without old_avg * old_total to avoid overflow.
-            self.metrics["avg_init_time"] = (init_time / new_total) + (
-                curr_total / new_total
-            ) * self.metrics["avg_init_time"]
-
-    def query(self, key):
-        if not self.enable:
-            value, init_time = self._init_with_timer(key)
-            self.update_time(init_time)
-            return value
-
-        with self.lock_cache:
-            if key in self.cache:
-                entry = self.cache[key]
-                cache_hit = True
-            else:
-                entry = MapEntry(Event(), None)
-                self.cache[key] = entry
-                cache_hit = False
-
-        with self.lock_metrics:
-            self.metrics["total"] += 1
-            if cache_hit:
-                self.metrics["hit"] += 1
-
-        if cache_hit:
-            entry.event.wait()
-        else:
-            entry.value, init_time = self._init_with_timer(key)
-            self.update_time(init_time)
-            entry.event.set()
-        return entry.value
-
-    def init_value(self, key):
-        raise NotImplementedError()
-
-    def get_cache_hit_rate(self):
-        with self.lock_metrics:
-            return self.metrics["hit"] / max(self.metrics["total"], 1)
-
-    def get_avg_init_time(self):
-        with self.lock_metrics:
-            return self.metrics["avg_init_time"]
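
For reference, the deleted `update_time` maintained the running mean as avg_n = x_n / n + ((n - 1) / n) * avg_{n-1}, which avoids the `old_avg * old_total` product that could grow without bound on a long-lived server. A quick self-contained check of that identity (illustrative, not from the commit):

import random

# Verify the incremental form equals the plain arithmetic mean.
times = [random.uniform(0.01, 2.0) for _ in range(1000)]
avg = 0.0
for n, t in enumerate(times, start=1):
    avg = (t / n) + ((n - 1) / n) * avg
assert abs(avg - sum(times) / len(times)) < 1e-9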
python/sglang/srt/constrained/outlines_backend.py
@@ -17,20 +17,17 @@ limitations under the License.
 
 import json
 import logging
-from concurrent.futures import Future, ThreadPoolExecutor
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from interegular import InvalidSyntax, parse_pattern
 from outlines.fsm.guide import RegexGuide
 from outlines.models.transformers import TransformerTokenizer
-from pydantic import BaseModel
 
-from sglang.srt.constrained.base_tool_cache import BaseToolCache
-from sglang.srt.constrained.outlines_jump_forward import (
-    OutlinesJumpForwardCache,
-    OutlinesJumpForwardMap,
-)
+from sglang.srt.constrained.base_grammar_backend import (
+    BaseGrammarBackend,
+    BaseGrammarObject,
+)
+from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
 
 logger = logging.getLogger(__name__)
@@ -41,6 +38,7 @@ except ImportError:
     # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
     # which only accepts string schema as input.
    from outlines.fsm.json_schema import build_regex_from_schema
+    from pydantic import BaseModel
 
     def build_regex_from_object(
         object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
@@ -54,16 +52,15 @@ except ImportError:
         return build_regex_from_schema(schema, whitespace_pattern)
 
 
-class OutlinesGrammar:
+class OutlinesGrammar(BaseGrammarObject):
     def __init__(
         self,
         guide: RegexGuide,
-        state: int,
         jump_forward_map: Union[OutlinesJumpForwardMap, None],
     ) -> None:
         self.guide = guide
-        self.state = state
         self.jump_forward_map = jump_forward_map
+        self.state = 0
 
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
@@ -105,46 +102,18 @@ class OutlinesGrammar:
         vocab_mask.fill_(1)
         vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
 
+    def copy(self):
+        return OutlinesGrammar(self.guide, self.jump_forward_map)
 
-class OutlinesGrammarBackend:
+
+class OutlinesGrammarBackend(BaseGrammarBackend):
     def __init__(
         self,
         tokenizer,
-        whitespace_patterns: bool,
+        whitespace_pattern: bool,
        allow_jump_forward: bool,
     ):
-        self.executor = ThreadPoolExecutor()
-        self.grammar_cache = OutlinesCache(
-            tokenizer,
-            whitespace_pattern=whitespace_patterns,
-        )
-        self.jump_forward_cache = (
-            OutlinesJumpForwardCache() if allow_jump_forward else None
-        )
-
-    def _query(self, key: Tuple[str, str]) -> OutlinesGrammar:
-        guide, regex = self.grammar_cache.query(key)
-        jump_forward_map = (
-            self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
-        )
-        return OutlinesGrammar(guide, 0, jump_forward_map)
-
-    def query(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self._query, key)
-
-    def reset(self):
-        self.grammar_cache.reset()
-        if self.jump_forward_cache:
-            self.jump_forward_cache.reset()
-
-
-class OutlinesCache(BaseToolCache):
-    def __init__(
-        self,
-        tokenizer,
-        whitespace_pattern=None,
-    ):
-        super().__init__(enable=True)
-
+        super().__init__()
         try:
             self.outlines_tokenizer = TransformerTokenizer(tokenizer)
@@ -167,9 +136,10 @@ class OutlinesCache(BaseToolCache):
             self.outlines_tokenizer.vocabulary = (
                 self.outlines_tokenizer.tokenizer.get_vocab()
             )
+        self.allow_jump_forward = allow_jump_forward
         self.whitespace_pattern = whitespace_pattern
 
-    def init_value(self, key):
+    def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
         key_type, key_string = key
         if key_type == "json":
             try:
@@ -186,18 +156,10 @@ class OutlinesCache(BaseToolCache):
             regex = key_string
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
-        try:
-            parse_pattern(regex)
-        except InvalidSyntax as e:
-            logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
-            return None, regex
-
-        ret = RegexGuide(regex, self.outlines_tokenizer), regex
-        return ret
-
-    def _query(self, key: Tuple[str, str]):
-        guide, regex = self.grammar_cache.query(key)
-        jump_forward_map = (
-            self.jump_forward_cache.query(regex) if self.jump_forward_cache else None
-        )
-        return OutlinesGrammar(guide, 0, jump_forward_map)
+        guide = RegexGuide(regex, self.outlines_tokenizer)
+        if self.allow_jump_forward:
+            jump_forward_map = OutlinesJumpForwardMap(regex)
+        else:
+            jump_forward_map = None
+        return OutlinesGrammar(guide, jump_forward_map)
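
With this rewrite the backend inherits caching, futures, and reset from `BaseGrammarBackend`, decoding state lives on the grammar object (`self.state = 0`), and each consumer gets an independent walker via `copy()`. A rough usage sketch, assuming an outlines-compatible Hugging Face tokenizer ("gpt2" is an arbitrary choice):

from transformers import AutoTokenizer

from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed model
backend = OutlinesGrammarBackend(
    tokenizer,
    whitespace_pattern=None,
    allow_jump_forward=True,
)

# Non-blocking path used by the scheduler: the Future resolves to a
# per-request OutlinesGrammar copy whose walk starts at state 0.
future = backend.get_future_value(("regex", r"(yes|no)"))
grammar = future.result()

# Once compiled, later requests for the same key can skip the grammar queue.
cached = backend.get_cached_value(("regex", r"(yes|no)"))
assert cached is not None and cached is not grammar  # independent copies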
python/sglang/srt/constrained/outlines_jump_forward.py
@@ -27,8 +27,6 @@ from interegular import InvalidSyntax
 from outlines.caching import cache as disk_cache
 from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
 
-from sglang.srt.constrained.base_tool_cache import BaseToolCache
-
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
 logger = logging.getLogger(__name__)
@@ -42,92 +40,90 @@ class JumpEdge:
     byte_next_state: int = None
 
 
+@disk_cache()
+def init_state_to_jump_forward(regex_string):
+    try:
+        regex_pattern = interegular.parse_pattern(regex_string)
+    except InvalidSyntax as e:
+        logger.warning(f"skip invalid regex: {regex_string}, {e=}")
+        return
+
+    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
+    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+
+    fsm_info: FSMInfo = regex_fsm.fsm_info
+
+    symbol_to_id = fsm_info.alphabet_symbol_mapping
+    id_to_symbol = {}
+    for symbol, id_ in symbol_to_id.items():
+        id_to_symbol.setdefault(id_, []).append(symbol)
+
+    transitions = fsm_info.transitions
+
+    outgoings_ct = defaultdict(int)
+    # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
+    for s in fsm_info.finals:
+        outgoings_ct[s] = 1
+
+    state_to_jump_forward = {}
+    for (state, id_), next_state in transitions.items():
+        if id_ == fsm_info.alphabet_anything_value:
+            # Arbitrarily symbol cannot be recognized as jump forward
+            continue
+
+        symbols = id_to_symbol[id_]
+        for c in symbols:
+            if len(c) > 1:
+                # Skip byte level transitions like c = "5E"
+                continue
+
+            outgoings_ct[state] += 1
+            if outgoings_ct[state] > 1:
+                if state in state_to_jump_forward:
+                    del state_to_jump_forward[state]
+                break
+
+            state_to_jump_forward[state] = JumpEdge(
+                symbol=c,
+                symbol_next_state=next_state,
+            )
+
+    # Process the byte level jump forward
+    outgoings_ct = defaultdict(int)
+    for s in fsm_info.finals:
+        outgoings_ct[s] = 1
+
+    for (state, id_), next_state in transitions.items():
+        if id_ == fsm_info.alphabet_anything_value:
+            continue
+        symbols = id_to_symbol[id_]
+        for c in symbols:
+            byte_ = None
+            if len(c) == 1 and ord(c) < 0x80:
+                # ASCII character
+                byte_ = ord(c)
+            elif len(c) > 1:
+                # FIXME: This logic is due to the leading \x00
+                # https://github.com/outlines-dev/outlines/pull/930
+                byte_ = int(symbols[0][1:], 16)
+
+            if byte_ is not None:
+                outgoings_ct[state] += 1
+                if outgoings_ct[state] > 1:
+                    if state in state_to_jump_forward:
+                        del state_to_jump_forward[state]
+                    break
+                e = state_to_jump_forward.get(state, JumpEdge())
+                e.byte = byte_
+                e.byte_next_state = next_state
+                state_to_jump_forward[state] = e
+
+    return state_to_jump_forward
+
+
 class OutlinesJumpForwardMap:
     def __init__(self, regex_string):
-        @disk_cache()
-        def _init_state_to_jump_forward(regex_string):
-            try:
-                regex_pattern = interegular.parse_pattern(regex_string)
-            except InvalidSyntax as e:
-                logger.warning(f"skip invalid regex: {regex_string}, {e=}")
-                self.state_to_jump_forward = None
-                return
-
-            byte_fsm = make_byte_level_fsm(
-                regex_pattern.to_fsm().reduce(), keep_utf8=True
-            )
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-
-            fsm_info: FSMInfo = regex_fsm.fsm_info
-
-            symbol_to_id = fsm_info.alphabet_symbol_mapping
-            id_to_symbol = {}
-            for symbol, id_ in symbol_to_id.items():
-                id_to_symbol.setdefault(id_, []).append(symbol)
-
-            transitions = fsm_info.transitions
-
-            outgoings_ct = defaultdict(int)
-            # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
-            for s in fsm_info.finals:
-                outgoings_ct[s] = 1
-
-            state_to_jump_forward = {}
-            for (state, id_), next_state in transitions.items():
-                if id_ == fsm_info.alphabet_anything_value:
-                    # Arbitrarily symbol cannot be recognized as jump forward
-                    continue
-
-                symbols = id_to_symbol[id_]
-                for c in symbols:
-                    if len(c) > 1:
-                        # Skip byte level transitions like c = "5E"
-                        continue
-
-                    outgoings_ct[state] += 1
-                    if outgoings_ct[state] > 1:
-                        if state in state_to_jump_forward:
-                            del state_to_jump_forward[state]
-                        break
-
-                    state_to_jump_forward[state] = JumpEdge(
-                        symbol=c,
-                        symbol_next_state=next_state,
-                    )
-
-            # Process the byte level jump forward
-            outgoings_ct = defaultdict(int)
-            for s in fsm_info.finals:
-                outgoings_ct[s] = 1
-
-            for (state, id_), next_state in transitions.items():
-                if id_ == fsm_info.alphabet_anything_value:
-                    continue
-                symbols = id_to_symbol[id_]
-                for c in symbols:
-                    byte_ = None
-                    if len(c) == 1 and ord(c) < 0x80:
-                        # ASCII character
-                        byte_ = ord(c)
-                    elif len(c) > 1:
-                        # FIXME: This logic is due to the leading \x00
-                        # https://github.com/outlines-dev/outlines/pull/930
-                        byte_ = int(symbols[0][1:], 16)
-
-                    if byte_ is not None:
-                        outgoings_ct[state] += 1
-                        if outgoings_ct[state] > 1:
-                            if state in state_to_jump_forward:
-                                del state_to_jump_forward[state]
-                            break
-                        e = state_to_jump_forward.get(state, JumpEdge())
-                        e.byte = byte_
-                        e.byte_next_state = next_state
-                        state_to_jump_forward[state] = e
-
-            return state_to_jump_forward
-
-        self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
+        self.state_to_jump_forward = init_state_to_jump_forward(regex_string)
 
     def jump_forward_symbol(self, state):
         jump_forward_str = ""
@@ -164,18 +160,6 @@ class OutlinesJumpForwardMap:
         )
 
 
-class OutlinesJumpForwardCache(BaseToolCache):
-    def __init__(self):
-        super().__init__()
-
-    def init_value(self, regex):
-        forward_map = OutlinesJumpForwardMap(regex)
-        if forward_map.state_to_jump_forward:
-            return forward_map
-        else:
-            return None
-
-
 def test_main(regex_string):
     jump_forward_map = OutlinesJumpForwardMap(regex_string)
     for state, e in jump_forward_map.state_to_jump_forward.items():
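
The jump-forward analysis itself is unchanged; it only moved from a closure inside `OutlinesJumpForwardMap.__init__` to a module-level function so `@disk_cache()` memoizes it per regex, which also makes the deleted `OutlinesJumpForwardCache` wrapper unnecessary. A small sketch in the spirit of the file's `test_main` (the regex is an arbitrary example): every FSM state with exactly one outgoing edge maps to a `JumpEdge`, so decoding can emit the forced characters without sampling.

from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap

# Arbitrary example regex with long forced segments.
jump_map = OutlinesJumpForwardMap(r"The answer is (yes|no)\.")
for state, edge in jump_map.state_to_jump_forward.items():
    print(state, edge)  # states along "The answer is " carry a single forced edge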
python/sglang/srt/constrained/xgrammar_backend.py
@@ -15,38 +15,36 @@ limitations under the License.
 
 """Constrained decoding with xgrammar backend."""
 
-from concurrent.futures import Future, ThreadPoolExecutor
 from typing import List, Tuple
 
 import torch
+from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
 
-try:
-    from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
-
-    import_error = None
-except ImportError as e:
-    import_error = e
-
-    class Dummy:
-        pass
-
-    GrammarMatcher = CompiledGrammar = CachedGrammarCompiler = Dummy
+from sglang.srt.constrained.base_grammar_backend import (
+    BaseGrammarBackend,
+    BaseGrammarObject,
+)
 
 MAX_ROLLBACK_TOKENS = 10
 
 
-class XGrammarGrammar:
+class XGrammarGrammar(BaseGrammarObject):
 
-    def __init__(self, matcher: GrammarMatcher, vocab_size: int) -> None:
+    def __init__(
+        self, matcher: GrammarMatcher, vocab_size: int, ctx: CompiledGrammar
+    ) -> None:
         self.matcher = matcher
         self.vocab_size = vocab_size
+        self.ctx = ctx
 
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
 
     def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]:
-        return [], self.matcher.find_jump_forward_string()
+        s = self.matcher.find_jump_forward_string()
+        if s:
+            return [], s
+        return None
 
     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         _, data = helper
@@ -77,51 +75,40 @@ class XGrammarGrammar:
             self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
         ] = 1
 
+    def copy(self):
+        matcher = GrammarMatcher(
+            self.ctx,
+            max_rollback_tokens=MAX_ROLLBACK_TOKENS,
+            mask_vocab_size=self.vocab_size,
+        )
+        return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
 
-class XGrammarGrammarBackend:
+
+class XGrammarGrammarBackend(BaseGrammarBackend):
     def __init__(
         self,
         tokenizer,
         vocab_size: int,
     ):
-        if import_error:
-            raise import_error
-
-        self.executor = ThreadPoolExecutor()
-        self.grammar_cache = XGrammarCache(tokenizer, vocab_size)
-        self.vocab_size = vocab_size
-
-    def _query(self, key: Tuple[str, str]) -> XGrammarGrammar:
-        return XGrammarGrammar(self.grammar_cache.query(key), self.vocab_size)
-
-    def query(self, key: Tuple[str, str]) -> Future:
-        return self.executor.submit(self._query, key)
-
-    def reset(self):
-        self.grammar_cache.reset()
-
-
-class XGrammarCache:
-    def __init__(self, tokenizer, vocab_size: int):
+        super().__init__()
         self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
         self.vocab_size = vocab_size
 
-    def get_context(self, key: Tuple[str, str]) -> CompiledGrammar:
+    def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
         key_type, key_string = key
         if key_type == "json":
-            return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
+            ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
        elif key_type == "regex":
             raise ValueError("regex hasn't been supported by xgrammar yet")
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
-    def query(self, key: Tuple[str, str]) -> GrammarMatcher:
-        ctx = self.get_context(key)
-        return GrammarMatcher(
+        matcher = GrammarMatcher(
             ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
             mask_vocab_size=self.vocab_size,
         )
+        return XGrammarGrammar(matcher, self.vocab_size, ctx)
 
     def reset(self):
         self.grammar_cache.clear()
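
Storing `ctx` (the `CompiledGrammar`) on the grammar object is what makes `copy()` cheap: compilation happens once per schema, and each request gets a fresh `GrammarMatcher` over the shared grammar. A sketch mirroring the calls used in this diff, against the xgrammar API as of this commit (the "gpt2" tokenizer and the schema are arbitrary choices; later xgrammar releases changed these interfaces):

from transformers import AutoTokenizer
from xgrammar import CachedGrammarCompiler, GrammarMatcher

MAX_ROLLBACK_TOKENS = 10
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed model
vocab_size = tokenizer.vocab_size

# Compile once per schema (expensive), then stamp out cheap per-request matchers.
compiler = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
ctx = compiler.get_compiled_grammar_for_json_schema('{"type": "object"}')

def fresh_matcher() -> GrammarMatcher:
    # Mirrors XGrammarGrammar.copy(): shared compiled grammar, new matcher state.
    return GrammarMatcher(
        ctx,
        max_rollback_tokens=MAX_ROLLBACK_TOKENS,
        mask_vocab_size=vocab_size,
    )

req_a, req_b = fresh_matcher(), fresh_matcher()  # independent decoding states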
python/sglang/srt/managers/schedule_batch.py
@@ -37,6 +37,7 @@ import torch
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -248,7 +249,7 @@ class Req:
         self.embedding = None
 
         # Constrained decoding
-        self.grammar = None
+        self.grammar: Optional[BaseGrammarObject] = None
 
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
python/sglang/srt/managers/scheduler.py
@@ -244,7 +244,7 @@ class Scheduler:
 
             self.grammar_backend = OutlinesGrammarBackend(
                 self.tokenizer,
-                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                 allow_jump_forward=not server_args.disable_jump_forward,
             )
         elif server_args.grammar_backend == "xgrammar":
@@ -467,21 +467,6 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init grammar cache for this request
-        if (
-            req.sampling_params.json_schema is not None
-            or req.sampling_params.regex is not None
-        ):
-            assert self.grammar_backend is not None
-            if req.sampling_params.json_schema is not None:
-                req.grammar = self.grammar_backend.query(
-                    ("json", req.sampling_params.json_schema),
-                )
-            elif req.sampling_params.regex is not None:
-                req.grammar = self.grammar_backend.query(
-                    ("regex", req.sampling_params.regex)
-                )
-
         # Truncate prompts that are too long
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
@@ -499,7 +484,24 @@ class Scheduler:
                 self.max_req_len - len(req.origin_input_ids) - 1,
             )
 
-        if req.grammar is not None:
+        # Init grammar cache for this request
+        add_to_grammar_queue = False
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
+            assert self.grammar_backend is not None
+            if req.sampling_params.json_schema is not None:
+                key = ("json", req.sampling_params.json_schema)
+            elif req.sampling_params.regex is not None:
+                key = ("regex", req.sampling_params.regex)
+
+            req.grammar = self.grammar_backend.get_cached_value(key)
+            if not req.grammar:
+                req.grammar = self.grammar_backend.get_future_value(key)
+                add_to_grammar_queue = True
+
+        if add_to_grammar_queue:
             self.grammar_queue.append(req)
         else:
             self.waiting_queue.append(req)
@@ -650,14 +652,7 @@ class Scheduler:
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
         if self.grammar_queue:
-            new_grammar_queue = []
-            for req in self.grammar_queue:
-                try:
-                    req.grammar = req.grammar.result(timeout=0.05)
-                    self.waiting_queue.append(req)
-                except futures._base.TimeoutError:
-                    new_grammar_queue.append(req)
-            self.grammar_queue = new_grammar_queue
+            self.move_ready_grammar_requests()
 
         # Handle the cases where prefill is not allowed
         if (
@@ -1145,6 +1140,30 @@ class Scheduler:
             )
         )
 
+    def move_ready_grammar_requests(self):
+        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+        num_ready_reqs = 0
+        for req in self.grammar_queue:
+            try:
+                req.grammar = req.grammar.result(timeout=0.05)
+                num_ready_reqs += 1
+            except futures._base.TimeoutError:
+                break
+
+        if self.tp_size > 1:
+            # Sync across TP ranks to make sure they have the same number of ready requests
+            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            torch.distributed.all_reduce(
+                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+            )
+            num_ready_reqs_max = tensor.item()
+            for i in range(num_ready_reqs, num_ready_reqs_max):
+                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
+            num_ready_reqs = num_ready_reqs_max
+
+        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
+        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
@@ -1152,9 +1171,8 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            if self.grammar_backend is not None:
+            if self.grammar_backend:
                 self.grammar_backend.reset()
-            # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
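
This synchronization is the heart of the TP fix: each rank polls its grammar futures with a 50 ms timeout, so ranks can disagree on how many grammars are ready; if one rank moved a request into `waiting_queue` while another left it queued, the next batch would differ across ranks and collective ops would hang. Reducing with MAX and forcing slower ranks to block on `.result()` makes every rank move the same prefix of `grammar_queue`. A standalone two-process illustration of the collective on a CPU (gloo) group, analogous to `tp_cpu_group` (the address, port, and ready counts are arbitrary):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # CPU group, like the scheduler's tp_cpu_group.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500",
        rank=rank, world_size=world_size,
    )
    # Pretend rank 0 saw 3 ready grammar futures while rank 1 saw only 1.
    num_ready_reqs = 3 if rank == 0 else 1
    tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    # Every rank now commits to moving the same number of requests,
    # blocking on .result() for any future it has not yet resolved.
    assert tensor.item() == 3
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)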