[Performance] Support both xgrammar and outlines for constrained decoding (#1752)
This commit is contained in:
@@ -51,6 +51,21 @@ except ImportError:
|
|||||||
return build_regex_from_schema(schema, whitespace_pattern)
|
return build_regex_from_schema(schema, whitespace_pattern)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from xgrammar import (
|
||||||
|
GrammarMatcher,
|
||||||
|
GrammarMatcherInitContext,
|
||||||
|
GrammarMatcherInitContextCache,
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
|
||||||
|
class Dummy:
|
||||||
|
pass
|
||||||
|
|
||||||
|
GrammarMatcher = Dummy
|
||||||
|
GrammarMatcherInitContext = Dummy
|
||||||
|
GrammarMatcherInitContextCache = Dummy
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RegexGuide",
|
"RegexGuide",
|
||||||
"FSMInfo",
|
"FSMInfo",
|
||||||
@@ -60,4 +75,7 @@ __all__ = [
|
|||||||
"disk_cache",
|
"disk_cache",
|
||||||
"disable_cache",
|
"disable_cache",
|
||||||
"make_byte_level_fsm",
|
"make_byte_level_fsm",
|
||||||
|
"GrammarMatcher",
|
||||||
|
"GrammarMatcherInitContext",
|
||||||
|
"GrammarMatcherInitContextCache",
|
||||||
]
|
]
|
||||||
|
|||||||
61
python/sglang/srt/constrained/bnf_cache.py
Normal file
61
python/sglang/srt/constrained/bnf_cache.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""Cache for the compressed finite state machine."""
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from sglang.srt.constrained import (
|
||||||
|
GrammarMatcher,
|
||||||
|
GrammarMatcherInitContext,
|
||||||
|
GrammarMatcherInitContextCache,
|
||||||
|
)
|
||||||
|
|
||||||
|
MAX_ROLLBACK_TOKENS = 10
|
||||||
|
|
||||||
|
|
||||||
|
class BNFCache:
|
||||||
|
grammar_cache: GrammarMatcherInitContextCache
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer_path,
|
||||||
|
tokenizer_args_dict,
|
||||||
|
skip_tokenizer_init=False,
|
||||||
|
whitespace_patterns=None,
|
||||||
|
):
|
||||||
|
# TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
|
||||||
|
if skip_tokenizer_init:
|
||||||
|
return
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
|
||||||
|
self.grammar_cache = GrammarMatcherInitContextCache(
|
||||||
|
tokenizer_or_vocab=tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
|
||||||
|
key_type, key_string = key
|
||||||
|
if key_type == "json":
|
||||||
|
return self.grammar_cache.get_init_context_for_json_schema(key_string)
|
||||||
|
elif key_type == "regex":
|
||||||
|
raise ValueError(f"regex hasn't been supported by xgrammar yet")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid key_type: {key_type}")
|
||||||
|
|
||||||
|
def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
|
||||||
|
ctx = self.get_context(key)
|
||||||
|
return GrammarMatcher(
|
||||||
|
ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
|
||||||
|
)
|
||||||
190
python/sglang/srt/constrained/grammar.py
Normal file
190
python/sglang/srt/constrained/grammar.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""Cache for the compressed finite state machine."""
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.constrained import GrammarMatcher, RegexGuide
|
||||||
|
from sglang.srt.constrained.bnf_cache import BNFCache
|
||||||
|
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||||
|
from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap
|
||||||
|
|
||||||
|
# from sglang.srt.managers.schedule_batch import Req
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
|
|
||||||
|
|
||||||
|
class XGrammarJump:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class JumpHelper:
|
||||||
|
data: Union[List, str]
|
||||||
|
state: int
|
||||||
|
suffix_ids: List[int]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
|
||||||
|
) -> None:
|
||||||
|
self.data = data
|
||||||
|
self.state = state
|
||||||
|
self.suffix_ids = suffix_ids
|
||||||
|
|
||||||
|
def can_jump(self):
|
||||||
|
return len(self.data) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class Grammar:
|
||||||
|
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
|
||||||
|
jump_map: Union[XGrammarJump, JumpForwardMap, None]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
|
||||||
|
jump_map: Union[XGrammarJump, JumpForwardMap, None],
|
||||||
|
) -> None:
|
||||||
|
self.grammar = grammar
|
||||||
|
self.jump_map = jump_map
|
||||||
|
|
||||||
|
def accept_token(self, token: int):
|
||||||
|
if isinstance(self.grammar, GrammarMatcher):
|
||||||
|
assert self.grammar.accept_token(token)
|
||||||
|
else:
|
||||||
|
guide, state = self.grammar
|
||||||
|
self.grammar = guide, guide.get_next_state(state, token)
|
||||||
|
|
||||||
|
def try_jump(self, tokenizer) -> JumpHelper:
|
||||||
|
if isinstance(self.jump_map, XGrammarJump):
|
||||||
|
assert isinstance(self.grammar, GrammarMatcher)
|
||||||
|
return JumpHelper(self.grammar.find_jump_forward_string())
|
||||||
|
elif isinstance(self.jump_map, JumpForwardMap):
|
||||||
|
assert isinstance(self.grammar, Tuple)
|
||||||
|
|
||||||
|
_, state = self.grammar
|
||||||
|
jump_forward_bytes = self.jump_map.jump_forward_byte(state)
|
||||||
|
if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
|
||||||
|
return JumpHelper() # can't jump
|
||||||
|
|
||||||
|
# preprocess the jump forward string
|
||||||
|
suffix_bytes = []
|
||||||
|
continuation_range = range(0x80, 0xC0)
|
||||||
|
cur_state = state
|
||||||
|
while (
|
||||||
|
len(jump_forward_bytes)
|
||||||
|
and jump_forward_bytes[0][0] in continuation_range
|
||||||
|
):
|
||||||
|
# continuation bytes
|
||||||
|
byte_edge = jump_forward_bytes.pop(0)
|
||||||
|
suffix_bytes.append(byte_edge[0])
|
||||||
|
cur_state = byte_edge[1]
|
||||||
|
|
||||||
|
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
|
||||||
|
suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
|
||||||
|
return JumpHelper(suffix_ids, cur_state, suffix_bytes)
|
||||||
|
else:
|
||||||
|
return JumpHelper() # can't jump
|
||||||
|
|
||||||
|
def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
|
||||||
|
if isinstance(helper.data, str):
|
||||||
|
return helper.data, -1
|
||||||
|
else:
|
||||||
|
assert isinstance(self.jump_map, JumpForwardMap)
|
||||||
|
return self.jump_map.jump_forward_symbol(helper.state)
|
||||||
|
|
||||||
|
def jump_and_retokenize(
|
||||||
|
self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
|
||||||
|
):
|
||||||
|
if isinstance(self.grammar, GrammarMatcher):
|
||||||
|
k = 0
|
||||||
|
for i, old_id in enumerate(old_output_ids):
|
||||||
|
if old_id == new_output_ids[i]:
|
||||||
|
k = i + 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# rollback to the last token that is the same
|
||||||
|
if k < len(old_output_ids):
|
||||||
|
self.grammar.rollback(len(old_output_ids) - k)
|
||||||
|
|
||||||
|
for i in range(k, len(new_output_ids)):
|
||||||
|
assert self.grammar.accept_token(new_output_ids[i])
|
||||||
|
else:
|
||||||
|
self.grammar = self.grammar[0], next_state
|
||||||
|
|
||||||
|
def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
|
||||||
|
if isinstance(self.grammar, GrammarMatcher):
|
||||||
|
# Note that this bitmask is a bitset, not bool
|
||||||
|
bitmask = self.grammar.find_next_token_bitmask()
|
||||||
|
# Mask the tokens that are not allowed
|
||||||
|
vocab_mask[
|
||||||
|
self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
|
||||||
|
] = 1
|
||||||
|
else:
|
||||||
|
guide, state = self.grammar
|
||||||
|
vocab_mask.fill_(1)
|
||||||
|
vocab_mask[guide.get_next_instruction(state).tokens] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class GrammarCache:
|
||||||
|
grammar_cache: Union[BNFCache, FSMCache]
|
||||||
|
jump_cache: Union[XGrammarJump, JumpForwardCache, None]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer_path,
|
||||||
|
tokenizer_args_dict,
|
||||||
|
skip_tokenizer_init=False,
|
||||||
|
whitespace_patterns=None,
|
||||||
|
backend=None,
|
||||||
|
allow_jump=False,
|
||||||
|
):
|
||||||
|
if backend == "xgrammar":
|
||||||
|
self.grammar_cache = BNFCache(
|
||||||
|
tokenizer_path=tokenizer_path,
|
||||||
|
tokenizer_args_dict=tokenizer_args_dict,
|
||||||
|
skip_tokenizer_init=skip_tokenizer_init,
|
||||||
|
whitespace_patterns=whitespace_patterns,
|
||||||
|
)
|
||||||
|
self.jump_cache = XGrammarJump() if allow_jump else None
|
||||||
|
else:
|
||||||
|
assert backend == "outlines"
|
||||||
|
self.grammar_cache = FSMCache(
|
||||||
|
tokenizer_path=tokenizer_path,
|
||||||
|
tokenizer_args_dict=tokenizer_args_dict,
|
||||||
|
skip_tokenizer_init=skip_tokenizer_init,
|
||||||
|
constrained_json_whitespace_pattern=whitespace_patterns,
|
||||||
|
enable=True,
|
||||||
|
)
|
||||||
|
self.jump_cache = JumpForwardCache() if allow_jump else None
|
||||||
|
|
||||||
|
def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
|
||||||
|
if isinstance(self.grammar_cache, BNFCache):
|
||||||
|
assert not isinstance(self.jump_cache, JumpForwardCache)
|
||||||
|
return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
|
||||||
|
else:
|
||||||
|
jump_map = None
|
||||||
|
guide, regex = self.grammar_cache.query(key)
|
||||||
|
if isinstance(self.jump_cache, JumpForwardCache):
|
||||||
|
jump_map = self.jump_cache.query(regex)
|
||||||
|
return Grammar((guide, 0), jump_map)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
if isinstance(self.grammar_cache, FSMCache):
|
||||||
|
self.grammar_cache.reset()
|
||||||
|
if isinstance(self.jump_cache, JumpForwardCache):
|
||||||
|
self.jump_cache.reset()
|
||||||
@@ -37,8 +37,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained import RegexGuide
|
from sglang.srt.constrained.grammar import Grammar
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
@@ -247,9 +246,7 @@ class Req:
|
|||||||
self.embedding = None
|
self.embedding = None
|
||||||
|
|
||||||
# Constrained decoding
|
# Constrained decoding
|
||||||
self.regex_fsm: RegexGuide = None
|
self.grammar: Optional[Grammar] = None
|
||||||
self.regex_fsm_state: int = 0
|
|
||||||
self.jump_forward_map: JumpForwardMap = None
|
|
||||||
|
|
||||||
# For Qwen2-VL
|
# For Qwen2-VL
|
||||||
self.mrope_position_delta = [] # use mutable object
|
self.mrope_position_delta = [] # use mutable object
|
||||||
@@ -359,6 +356,8 @@ class Req:
|
|||||||
return
|
return
|
||||||
|
|
||||||
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
||||||
|
assert self.grammar is not None and self.tokenizer is not None
|
||||||
|
|
||||||
if self.origin_input_text is None:
|
if self.origin_input_text is None:
|
||||||
# Recovering text can only use unpadded ids
|
# Recovering text can only use unpadded ids
|
||||||
self.origin_input_text = self.tokenizer.decode(
|
self.origin_input_text = self.tokenizer.decode(
|
||||||
@@ -398,7 +397,8 @@ class Req:
|
|||||||
self.surr_offset = self.read_offset - i
|
self.surr_offset = self.read_offset - i
|
||||||
break
|
break
|
||||||
|
|
||||||
self.regex_fsm_state = next_state
|
# update the inner state of the grammar
|
||||||
|
self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
|
||||||
|
|
||||||
if self.return_logprob:
|
if self.return_logprob:
|
||||||
# For fast-forward part's logprobs
|
# For fast-forward part's logprobs
|
||||||
@@ -468,8 +468,8 @@ class ScheduleBatch:
|
|||||||
# Stream
|
# Stream
|
||||||
has_stream: bool = False
|
has_stream: bool = False
|
||||||
|
|
||||||
# Has regex
|
# Has grammar
|
||||||
has_regex: bool = False
|
has_grammar: bool = False
|
||||||
|
|
||||||
# device
|
# device
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
@@ -477,7 +477,7 @@ class ScheduleBatch:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def init_new(
|
def init_new(
|
||||||
cls,
|
cls,
|
||||||
reqs,
|
reqs: List[Req],
|
||||||
req_to_token_pool,
|
req_to_token_pool,
|
||||||
token_to_kv_pool,
|
token_to_kv_pool,
|
||||||
tree_cache,
|
tree_cache,
|
||||||
@@ -491,7 +491,7 @@ class ScheduleBatch:
|
|||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
return_logprob=any(req.return_logprob for req in reqs),
|
return_logprob=any(req.return_logprob for req in reqs),
|
||||||
has_stream=any(req.stream for req in reqs),
|
has_stream=any(req.stream for req in reqs),
|
||||||
has_regex=any(req.regex_fsm for req in reqs),
|
has_grammar=any(req.grammar for req in reqs),
|
||||||
device=req_to_token_pool.device,
|
device=req_to_token_pool.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -803,26 +803,10 @@ class ScheduleBatch:
|
|||||||
keep_indices = set(i for i in range(len(self.reqs)))
|
keep_indices = set(i for i in range(len(self.reqs)))
|
||||||
|
|
||||||
for i, req in enumerate(self.reqs):
|
for i, req in enumerate(self.reqs):
|
||||||
if req.jump_forward_map is not None:
|
if req.grammar is not None:
|
||||||
jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
|
jump_helper = req.grammar.try_jump(req.tokenizer)
|
||||||
req.regex_fsm_state
|
if jump_helper.can_jump():
|
||||||
)
|
suffix_ids = jump_helper.suffix_ids
|
||||||
if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
|
|
||||||
suffix_bytes = []
|
|
||||||
continuation_range = range(0x80, 0xC0)
|
|
||||||
cur_state = req.regex_fsm_state
|
|
||||||
while (
|
|
||||||
len(jump_forward_bytes)
|
|
||||||
and jump_forward_bytes[0][0] in continuation_range
|
|
||||||
):
|
|
||||||
# continuation bytes
|
|
||||||
byte_edge = jump_forward_bytes.pop(0)
|
|
||||||
suffix_bytes.append(byte_edge[0])
|
|
||||||
cur_state = byte_edge[1]
|
|
||||||
|
|
||||||
suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
|
|
||||||
suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
|
|
||||||
|
|
||||||
# Current ids, for cache and revert
|
# Current ids, for cache and revert
|
||||||
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
|
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
|
||||||
cur_output_ids = req.output_ids
|
cur_output_ids = req.output_ids
|
||||||
@@ -836,10 +820,8 @@ class ScheduleBatch:
|
|||||||
(
|
(
|
||||||
jump_forward_str,
|
jump_forward_str,
|
||||||
next_state,
|
next_state,
|
||||||
) = req.jump_forward_map.jump_forward_symbol(cur_state)
|
) = req.grammar.jump_forward_str_state(jump_helper)
|
||||||
|
|
||||||
# Make the incrementally decoded text part of jump_forward_str
|
|
||||||
# so that the UTF-8 will not corrupt
|
|
||||||
jump_forward_str = new_text + jump_forward_str
|
jump_forward_str = new_text + jump_forward_str
|
||||||
if not req.jump_forward_and_retokenize(
|
if not req.jump_forward_and_retokenize(
|
||||||
jump_forward_str, next_state
|
jump_forward_str, next_state
|
||||||
@@ -946,7 +928,7 @@ class ScheduleBatch:
|
|||||||
self.top_logprobs_nums = None
|
self.top_logprobs_nums = None
|
||||||
|
|
||||||
self.has_stream = any(req.stream for req in self.reqs)
|
self.has_stream = any(req.stream for req in self.reqs)
|
||||||
self.has_regex = any(req.regex_fsm for req in self.reqs)
|
self.has_grammar = any(req.grammar for req in self.reqs)
|
||||||
|
|
||||||
self.sampling_info.filter_batch(keep_indices, new_indices)
|
self.sampling_info.filter_batch(keep_indices, new_indices)
|
||||||
|
|
||||||
@@ -979,7 +961,7 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
self.return_logprob = self.return_logprob or other.return_logprob
|
self.return_logprob = self.return_logprob or other.return_logprob
|
||||||
self.has_stream = self.has_stream or other.has_stream
|
self.has_stream = self.has_stream or other.has_stream
|
||||||
self.has_regex = self.has_regex or other.has_regex
|
self.has_grammar = self.has_grammar or other.has_grammar
|
||||||
|
|
||||||
def get_model_worker_batch(self):
|
def get_model_worker_batch(self):
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
@@ -989,13 +971,10 @@ class ScheduleBatch:
|
|||||||
extend_prefix_lens = self.prefix_lens
|
extend_prefix_lens = self.prefix_lens
|
||||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||||
|
|
||||||
if self.has_regex:
|
if self.has_grammar:
|
||||||
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
|
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
||||||
self.sampling_info.regex_fsm_states = [
|
|
||||||
req.regex_fsm_state for req in self.reqs
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
self.sampling_info.regex_fsms = None
|
self.sampling_info.grammars = None
|
||||||
|
|
||||||
global bid
|
global bid
|
||||||
bid += 1
|
bid += 1
|
||||||
|
|||||||
@@ -29,8 +29,7 @@ import zmq
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
from sglang.srt.constrained.grammar import GrammarCache
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
@@ -225,17 +224,20 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Init the FSM cache for constrained generation
|
# Init the FSM cache for constrained generation
|
||||||
|
self.grammar_cache = None
|
||||||
|
|
||||||
if not server_args.skip_tokenizer_init:
|
if not server_args.skip_tokenizer_init:
|
||||||
self.regex_fsm_cache = FSMCache(
|
self.grammar_cache = GrammarCache(
|
||||||
server_args.tokenizer_path,
|
server_args.tokenizer_path,
|
||||||
{
|
{
|
||||||
"tokenizer_mode": server_args.tokenizer_mode,
|
"tokenizer_mode": server_args.tokenizer_mode,
|
||||||
"trust_remote_code": server_args.trust_remote_code,
|
"trust_remote_code": server_args.trust_remote_code,
|
||||||
},
|
},
|
||||||
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
||||||
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
whitespace_patterns=server_args.constrained_json_whitespace_pattern,
|
||||||
|
backend=server_args.grammar_backend,
|
||||||
|
allow_jump=not server_args.disable_regex_jump_forward,
|
||||||
)
|
)
|
||||||
self.jump_forward_cache = JumpForwardCache()
|
|
||||||
|
|
||||||
# Init new token estimation
|
# Init new token estimation
|
||||||
assert (
|
assert (
|
||||||
@@ -402,22 +404,20 @@ class Scheduler:
|
|||||||
# By default, only return the logprobs for output tokens
|
# By default, only return the logprobs for output tokens
|
||||||
req.logprob_start_len = len(recv_req.input_ids) - 1
|
req.logprob_start_len = len(recv_req.input_ids) - 1
|
||||||
|
|
||||||
# Init regex FSM
|
# Init regex FSM or BNF
|
||||||
if (
|
if (
|
||||||
req.sampling_params.json_schema is not None
|
req.sampling_params.json_schema is not None
|
||||||
or req.sampling_params.regex is not None
|
or req.sampling_params.regex is not None
|
||||||
):
|
):
|
||||||
|
assert self.grammar_cache is not None
|
||||||
if req.sampling_params.json_schema is not None:
|
if req.sampling_params.json_schema is not None:
|
||||||
req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
|
req.grammar = self.grammar_cache.query(
|
||||||
("json", req.sampling_params.json_schema)
|
("json", req.sampling_params.json_schema),
|
||||||
|
self.model_config.vocab_size,
|
||||||
)
|
)
|
||||||
elif req.sampling_params.regex is not None:
|
elif req.sampling_params.regex is not None:
|
||||||
req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
|
req.grammar = self.grammar_cache.query(
|
||||||
("regex", req.sampling_params.regex)
|
("regex", req.sampling_params.regex), self.model_config.vocab_size
|
||||||
)
|
|
||||||
if not self.disable_regex_jump_forward:
|
|
||||||
req.jump_forward_map = self.jump_forward_cache.query(
|
|
||||||
computed_regex_string
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Truncate prompts that are too long
|
# Truncate prompts that are too long
|
||||||
@@ -796,10 +796,8 @@ class Scheduler:
|
|||||||
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
if req.regex_fsm is not None:
|
if req.grammar is not None:
|
||||||
req.regex_fsm_state = req.regex_fsm.get_next_state(
|
req.grammar.accept_token(next_token_ids[i])
|
||||||
req.regex_fsm_state, next_token_ids[i]
|
|
||||||
)
|
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
logprob_pt += self.add_logprob_return_values(
|
logprob_pt += self.add_logprob_return_values(
|
||||||
@@ -855,10 +853,8 @@ class Scheduler:
|
|||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.regex_fsm is not None:
|
if req.grammar is not None:
|
||||||
req.regex_fsm_state = req.regex_fsm.get_next_state(
|
req.grammar.accept_token(next_token_id)
|
||||||
req.regex_fsm_state, next_token_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.tree_cache.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
@@ -1056,7 +1052,9 @@ class Scheduler:
|
|||||||
):
|
):
|
||||||
self.tree_cache.reset()
|
self.tree_cache.reset()
|
||||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||||
self.regex_fsm_cache.reset()
|
if self.grammar_cache is not None:
|
||||||
|
self.grammar_cache.reset()
|
||||||
|
# TODO(dark): reset the bnf cache
|
||||||
self.req_to_token_pool.clear()
|
self.req_to_token_pool.clear()
|
||||||
self.token_to_kv_pool.clear()
|
self.token_to_kv_pool.clear()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
from sglang.srt.constrained import RegexGuide
|
from sglang.srt.constrained.grammar import Grammar
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
|
|||||||
# Bias Tensors
|
# Bias Tensors
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
logit_bias: torch.Tensor = None
|
logit_bias: torch.Tensor = None
|
||||||
vocab_mask: torch.Tensor = None
|
vocab_mask: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# FSM states
|
grammars: Optional[List[Optional[Grammar]]] = None
|
||||||
regex_fsms: List[RegexGuide] = None
|
|
||||||
regex_fsm_states: List[int] = None
|
|
||||||
|
|
||||||
# Penalizer
|
# Penalizer
|
||||||
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
||||||
@@ -136,8 +134,7 @@ class SamplingBatchInfo:
|
|||||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||||
|
|
||||||
def update_regex_vocab_mask(self):
|
def update_regex_vocab_mask(self):
|
||||||
has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
|
if not self.grammars or not any(grammar for grammar in self.grammars):
|
||||||
if not has_regex:
|
|
||||||
self.vocab_mask = None
|
self.vocab_mask = None
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -147,12 +144,9 @@ class SamplingBatchInfo:
|
|||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
for i, regex_fsm in enumerate(self.regex_fsms):
|
for i, grammar in enumerate(self.grammars):
|
||||||
if regex_fsm is not None:
|
if grammar is not None:
|
||||||
self.vocab_mask[i].fill_(1)
|
grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
|
||||||
self.vocab_mask[i][
|
|
||||||
regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
|
|
||||||
] = 0
|
|
||||||
|
|
||||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||||
if self.penalizer_orchestrator:
|
if self.penalizer_orchestrator:
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ class ServerArgs:
|
|||||||
# Kernel backend
|
# Kernel backend
|
||||||
attention_backend: Optional[str] = None
|
attention_backend: Optional[str] = None
|
||||||
sampling_backend: Optional[str] = None
|
sampling_backend: Optional[str] = None
|
||||||
|
grammar_backend: Optional[str] = "outlines"
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
disable_flashinfer: bool = False
|
disable_flashinfer: bool = False
|
||||||
@@ -537,6 +538,13 @@ class ServerArgs:
|
|||||||
default=ServerArgs.sampling_backend,
|
default=ServerArgs.sampling_backend,
|
||||||
help="Choose the kernels for sampling layers.",
|
help="Choose the kernels for sampling layers.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grammar-backend",
|
||||||
|
type=str,
|
||||||
|
choices=["xgrammar", "outlines"],
|
||||||
|
default=ServerArgs.grammar_backend,
|
||||||
|
help="Choose the backend for constrained decoding.",
|
||||||
|
)
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user