From ba069a24d3e116b37399cf3ebd295c97c49ae6fd Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 12 Nov 2024 21:17:38 -0800 Subject: [PATCH] Fix grammar backend (#2018) --- python/sglang/srt/constrained/__init__.py | 29 +-- .../sglang/srt/constrained/base_tool_cache.py | 4 +- python/sglang/srt/constrained/grammar.py | 182 ---------------- .../srt/constrained/outlines_backend.py | 203 ++++++++++++++++++ .../sglang/srt/constrained/outlines_cache.py | 96 --------- .../srt/constrained/outlines_jump_forward.py | 2 +- .../srt/constrained/xgrammar_backend.py | 127 +++++++++++ .../sglang/srt/constrained/xgrammar_cache.py | 75 ------- python/sglang/srt/managers/schedule_batch.py | 14 +- python/sglang/srt/managers/scheduler.py | 66 +++--- .../srt/sampling/sampling_batch_info.py | 5 +- python/sglang/srt/server_args.py | 10 +- test/srt/test_json_constrained.py | 22 +- 13 files changed, 401 insertions(+), 434 deletions(-) delete mode 100644 python/sglang/srt/constrained/grammar.py create mode 100644 python/sglang/srt/constrained/outlines_backend.py delete mode 100644 python/sglang/srt/constrained/outlines_cache.py create mode 100644 python/sglang/srt/constrained/xgrammar_backend.py delete mode 100644 python/sglang/srt/constrained/xgrammar_cache.py diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index 8ce7f05b6..1ea79f924 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -13,30 +13,5 @@ See the License for the specific language governing permissions and limitations under the License. """ -import json -from typing import Dict, Optional, Union - -from pydantic import BaseModel - -try: - from outlines.fsm.json_schema import build_regex_from_object -except ImportError: - # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema, - # which only accepts string schema as input. - from outlines.fsm.json_schema import build_regex_from_schema - - def build_regex_from_object( - object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None - ): - if isinstance(object, type(BaseModel)): - schema = json.dumps(object.model_json_schema()) - elif isinstance(object, Dict): - schema = json.dumps(object) - else: - schema = object - return build_regex_from_schema(schema, whitespace_pattern) - - -__all__ = [ - "build_regex_from_object", -] +# TODO(lmzheng): make this an optional dependency +from sglang.srt.constrained.outlines_backend import build_regex_from_object diff --git a/python/sglang/srt/constrained/base_tool_cache.py b/python/sglang/srt/constrained/base_tool_cache.py index f137ad16e..1910eb730 100644 --- a/python/sglang/srt/constrained/base_tool_cache.py +++ b/python/sglang/srt/constrained/base_tool_cache.py @@ -95,9 +95,7 @@ class BaseToolCache: def get_cache_hit_rate(self): with self.lock_metrics: - if self.metrics["total"] == 0: - return 0 - return self.metrics["hit"] / self.metrics["total"] + return self.metrics["hit"] / max(self.metrics["total"], 1) def get_avg_init_time(self): with self.lock_metrics: diff --git a/python/sglang/srt/constrained/grammar.py b/python/sglang/srt/constrained/grammar.py deleted file mode 100644 index 3f9dfb8a0..000000000 --- a/python/sglang/srt/constrained/grammar.py +++ /dev/null @@ -1,182 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -"""Cache for the compressed finite state machine.""" -import logging -from concurrent.futures import Future, ThreadPoolExecutor -from typing import List, Tuple, Union - -import torch - -from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide -from sglang.srt.constrained.outlines_jump_forward import ( - OutlinesJumpCache, - OutlinesJumpForwardMap, -) -from sglang.srt.constrained.xgrammar_cache import ( - GrammarMatcher, - XGrammarBackend, - XGrammarJumpCache, -) - -logger = logging.getLogger(__name__) - - -class JumpHelper: - - def __init__( - self, data: Union[List, str] = "", state: int = -1, suffix_ids=[] - ) -> None: - self.data: Union[List, str] = data - self.state: int = state - self.suffix_ids: List[int] = suffix_ids - - def can_jump(self): - return len(self.data) > 0 - - -class Grammar: - - def __init__( - self, - grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]], - jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None], - ) -> None: - self.grammar = grammar - self.jump_map = jump_map - - def accept_token(self, token: int): - if isinstance(self.grammar, GrammarMatcher): - assert self.grammar.accept_token(token) - else: - guide, state = self.grammar - self.grammar = guide, guide.get_next_state(state, token) - - def try_jump(self, tokenizer) -> JumpHelper: - if isinstance(self.jump_map, XGrammarJumpCache): - assert isinstance(self.grammar, GrammarMatcher) - return JumpHelper(self.grammar.find_jump_forward_string()) - elif isinstance(self.jump_map, OutlinesJumpForwardMap): - assert isinstance(self.grammar, Tuple) - - _, state = self.grammar - jump_forward_bytes = self.jump_map.jump_forward_byte(state) - if jump_forward_bytes is None or len(jump_forward_bytes) == 0: - return JumpHelper() # can't jump - - # preprocess the jump forward string - suffix_bytes = [] - continuation_range = range(0x80, 0xC0) - cur_state = state - while ( - len(jump_forward_bytes) - and jump_forward_bytes[0][0] in continuation_range - ): - # continuation bytes - byte_edge = jump_forward_bytes.pop(0) - suffix_bytes.append(byte_edge[0]) - cur_state = byte_edge[1] - - suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes] - suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens) - return JumpHelper(suffix_ids, cur_state, suffix_bytes) - else: - return JumpHelper() # can't jump - - def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]: - if isinstance(helper.data, str): - return helper.data, -1 - else: - assert isinstance(self.jump_map, OutlinesJumpForwardMap) - return self.jump_map.jump_forward_symbol(helper.state) - - def jump_and_retokenize( - self, old_output_ids: List[int], new_output_ids: List[int], next_state: int - ): - if isinstance(self.grammar, GrammarMatcher): - k = 0 - for i, old_id in enumerate(old_output_ids): - if old_id == new_output_ids[i]: - k = i + 1 - else: - break - - # rollback to the last token that is the same - if k < len(old_output_ids): - self.grammar.rollback(len(old_output_ids) - k) - - for i in range(k, len(new_output_ids)): - assert self.grammar.accept_token(new_output_ids[i]) - else: - self.grammar = self.grammar[0], next_state - - def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int): - if isinstance(self.grammar, GrammarMatcher): - # Note that this bitmask is a bitset, not bool - bitmask = self.grammar.get_next_token_bitmask() - # Mask the tokens that are not allowed - vocab_mask[ - self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size) - ] = 1 - else: - guide, state = self.grammar - vocab_mask.fill_(1) - vocab_mask[guide.get_next_instruction(state).tokens] = 0 - - -class GrammarBackend: - - def __init__( - self, - tokenizer_path, - tokenizer_args_dict, - skip_tokenizer_init=False, - whitespace_patterns=None, - backend=None, - allow_jump=False, - ): - self.executor = ThreadPoolExecutor() - self.backend = backend - - if backend == "xgrammar": - self.grammar_cache = XGrammarBackend( - tokenizer_path=tokenizer_path, - tokenizer_args_dict=tokenizer_args_dict, - skip_tokenizer_init=skip_tokenizer_init, - whitespace_patterns=whitespace_patterns, - ) - self.jump_cache = XGrammarJumpCache() if allow_jump else None - else: - assert backend == "outlines" - self.grammar_cache = OutlinesCache( - tokenizer_path=tokenizer_path, - tokenizer_args_dict=tokenizer_args_dict, - skip_tokenizer_init=skip_tokenizer_init, - constrained_json_whitespace_pattern=whitespace_patterns, - ) - self.jump_cache = OutlinesJumpCache() if allow_jump else None - - def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar: - if isinstance(self.grammar_cache, XGrammarBackend): - return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache) - else: - guide, regex = self.grammar_cache.query(key) - jump_map = self.jump_cache.query(regex) - return Grammar((guide, 0), jump_map) - - def query(self, key: Tuple[str, str], vocab_size: int) -> Future: - return self.executor.submit(self._query, key, vocab_size) - - def reset(self): - self.grammar_cache.reset() - self.jump_cache.reset() diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py new file mode 100644 index 000000000..434d94b57 --- /dev/null +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -0,0 +1,203 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Constrained decoding with outlines backend.""" + +import json +import logging +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Dict, List, Optional, Tuple, Union + +import torch +from interegular import InvalidSyntax, parse_pattern +from outlines.fsm.guide import RegexGuide +from outlines.models.transformers import TransformerTokenizer +from pydantic import BaseModel + +from sglang.srt.constrained.base_tool_cache import BaseToolCache +from sglang.srt.constrained.outlines_jump_forward import ( + OutlinesJumpForwardCache, + OutlinesJumpForwardMap, +) + +logger = logging.getLogger(__name__) + + +try: + from outlines.fsm.json_schema import build_regex_from_object +except ImportError: + # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema, + # which only accepts string schema as input. + from outlines.fsm.json_schema import build_regex_from_schema + + def build_regex_from_object( + object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None + ): + if isinstance(object, type(BaseModel)): + schema = json.dumps(object.model_json_schema()) + elif isinstance(object, Dict): + schema = json.dumps(object) + else: + schema = object + return build_regex_from_schema(schema, whitespace_pattern) + + +class OutlinesGrammar: + def __init__( + self, + guide: RegexGuide, + state: int, + jump_forward_map: Union[OutlinesJumpForwardMap, None], + ) -> None: + self.guide = guide + self.state = state + self.jump_forward_map = jump_forward_map + + def accept_token(self, token: int): + self.state = self.guide.get_next_state(self.state, token) + + def try_jump_forward(self, tokenizer) -> Optional[Tuple]: + if not self.jump_forward_map: + return None + + jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state) + if jump_forward_bytes is None or len(jump_forward_bytes) <= 1: + return None + + # preprocess the jump forward string + suffix_bytes = [] + continuation_range = range(0x80, 0xC0) + cur_state = self.state + while ( + len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range + ): + # continuation bytes + byte_edge = jump_forward_bytes.pop(0) + suffix_bytes.append(byte_edge[0]) + cur_state = byte_edge[1] + + suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes] + suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens) + return suffix_ids, cur_state + + def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: + _, cur_state = helper + return self.jump_forward_map.jump_forward_symbol(cur_state) + + def jump_and_retokenize( + self, old_output_ids: List[int], new_output_ids: List[int], next_state: int + ): + self.state = next_state + + def fill_vocab_mask(self, vocab_mask: torch.Tensor): + vocab_mask.fill_(1) + vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0 + + +class OutlinesGrammarBackend: + def __init__( + self, + tokenizer, + whitespace_patterns: bool, + allow_jump_forward: bool, + ): + self.executor = ThreadPoolExecutor() + self.grammar_cache = OutlinesCache( + tokenizer, + whitespace_pattern=whitespace_patterns, + ) + self.jump_forward_cache = ( + OutlinesJumpForwardCache() if allow_jump_forward else None + ) + + def _query(self, key: Tuple[str, str]) -> OutlinesGrammar: + guide, regex = self.grammar_cache.query(key) + jump_forward_map = ( + self.jump_forward_cache.query(regex) if self.jump_forward_cache else None + ) + return OutlinesGrammar(guide, 0, jump_forward_map) + + def query(self, key: Tuple[str, str]) -> Future: + return self.executor.submit(self._query, key) + + def reset(self): + self.grammar_cache.reset() + if self.jump_forward_cache: + self.jump_forward_cache.reset() + + +class OutlinesCache(BaseToolCache): + def __init__( + self, + tokenizer, + whitespace_pattern=None, + ): + super().__init__(enable=True) + + try: + self.outlines_tokenizer = TransformerTokenizer(tokenizer) + except AttributeError: + # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0) + origin_pad_token_id = tokenizer.pad_token_id + + def fset(self, value): + self._value = value + + type(tokenizer).pad_token_id = property( + fget=type(tokenizer).pad_token_id.fget, fset=fset + ) + self.outlines_tokenizer = TransformerTokenizer(tokenizer) + self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id + self.outlines_tokenizer.pad_token_id = origin_pad_token_id + self.outlines_tokenizer.pad_token = ( + self.outlines_tokenizer.tokenizer.pad_token + ) + self.outlines_tokenizer.vocabulary = ( + self.outlines_tokenizer.tokenizer.get_vocab() + ) + self.whitespace_pattern = whitespace_pattern + + def init_value(self, key): + key_type, key_string = key + if key_type == "json": + try: + regex = build_regex_from_object( + key_string, + whitespace_pattern=self.whitespace_pattern, + ) + except NotImplementedError as e: + logger.warning( + f"skip invalid json schema: json_schema={key_string}, {e=}" + ) + return None, key_string + elif key_type == "regex": + regex = key_string + else: + raise ValueError(f"Invalid key_type: {key_type}") + try: + parse_pattern(regex) + except InvalidSyntax as e: + logger.warning(f"skip invalid regex guide: {regex=}, {e=}") + return None, regex + + ret = RegexGuide(regex, self.outlines_tokenizer), regex + return ret + + def _query(self, key: Tuple[str, str]): + guide, regex = self.grammar_cache.query(key) + jump_forward_map = ( + self.jump_forward_cache.query(regex) if self.jump_forward_cache else None + ) + return OutlinesGrammar(guide, 0, jump_forward_map) diff --git a/python/sglang/srt/constrained/outlines_cache.py b/python/sglang/srt/constrained/outlines_cache.py deleted file mode 100644 index 8971d5a5e..000000000 --- a/python/sglang/srt/constrained/outlines_cache.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -"""Cache for the compressed finite state machine.""" -import logging - -from interegular import InvalidSyntax, parse_pattern -from outlines.fsm.guide import RegexGuide -from outlines.models.transformers import TransformerTokenizer -from transformers import AutoTokenizer - -from sglang.srt.constrained import build_regex_from_object -from sglang.srt.constrained.base_tool_cache import BaseToolCache - -logger = logging.getLogger(__name__) - - -class OutlinesCache(BaseToolCache): - def __init__( - self, - tokenizer_path, - tokenizer_args_dict, - enable=True, - skip_tokenizer_init=False, - constrained_json_whitespace_pattern=None, - ): - super().__init__(enable=enable) - - if ( - skip_tokenizer_init - or tokenizer_path.endswith(".json") - or tokenizer_path.endswith(".model") - ): - # Do not support TiktokenTokenizer or SentencePieceTokenizer - return - - tokenizer_args_dict.setdefault("padding_side", "left") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) - try: - self.outlines_tokenizer = TransformerTokenizer(tokenizer) - except AttributeError: - # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0) - origin_pad_token_id = tokenizer.pad_token_id - - def fset(self, value): - self._value = value - - type(tokenizer).pad_token_id = property( - fget=type(tokenizer).pad_token_id.fget, fset=fset - ) - self.outlines_tokenizer = TransformerTokenizer(tokenizer) - self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id - self.outlines_tokenizer.pad_token_id = origin_pad_token_id - self.outlines_tokenizer.pad_token = ( - self.outlines_tokenizer.tokenizer.pad_token - ) - self.outlines_tokenizer.vocabulary = ( - self.outlines_tokenizer.tokenizer.get_vocab() - ) - self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern - - def init_value(self, key): - key_type, key_string = key - if key_type == "json": - try: - regex = build_regex_from_object( - key_string, - whitespace_pattern=self.constrained_json_whitespace_pattern, - ) - except NotImplementedError as e: - logger.warning( - f"skip invalid json schema: json_schema={key_string}, {e=}" - ) - return None, key_string - elif key_type == "regex": - regex = key_string - else: - raise ValueError(f"Invalid key_type: {key_type}") - try: - parse_pattern(regex) - except InvalidSyntax as e: - logger.warning(f"skip invalid regex guide: {regex=}, {e=}") - return None, regex - return RegexGuide(regex, self.outlines_tokenizer), regex diff --git a/python/sglang/srt/constrained/outlines_jump_forward.py b/python/sglang/srt/constrained/outlines_jump_forward.py index 2439db276..e3dd6b166 100644 --- a/python/sglang/srt/constrained/outlines_jump_forward.py +++ b/python/sglang/srt/constrained/outlines_jump_forward.py @@ -164,7 +164,7 @@ class OutlinesJumpForwardMap: ) -class OutlinesJumpCache(BaseToolCache): +class OutlinesJumpForwardCache(BaseToolCache): def __init__(self): super().__init__() diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py new file mode 100644 index 000000000..812865da6 --- /dev/null +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -0,0 +1,127 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Constrained decoding with xgrammar backend.""" + +from concurrent.futures import Future, ThreadPoolExecutor +from typing import List, Tuple + +import torch + +try: + from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher + + import_error = None +except ImportError as e: + import_error = e + + class Dummy: + pass + + GrammarMatcher = CompiledGrammar = CachedGrammarCompiler = Dummy + + +MAX_ROLLBACK_TOKENS = 10 + + +class XGrammarGrammar: + + def __init__(self, matcher: GrammarMatcher, vocab_size: int) -> None: + self.matcher = matcher + self.vocab_size = vocab_size + + def accept_token(self, token: int): + assert self.matcher.accept_token(token) + + def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]: + return [], self.matcher.find_jump_forward_string() + + def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: + _, data = helper + return data, -1 + + def jump_and_retokenize( + self, old_output_ids: List[int], new_output_ids: List[int], next_state: int + ): + k = 0 + for i, old_id in enumerate(old_output_ids): + if old_id == new_output_ids[i]: + k = i + 1 + else: + break + + # rollback to the last token that is the same + if k < len(old_output_ids): + self.matcher.rollback(len(old_output_ids) - k) + + for i in range(k, len(new_output_ids)): + assert self.matcher.accept_token(new_output_ids[i]) + + def fill_vocab_mask(self, vocab_mask: torch.Tensor): + # Note that this bitmask is a bitset, not bool + bitmask = self.matcher.get_next_token_bitmask() + # Mask the tokens that are not allowed + vocab_mask[ + self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size) + ] = 1 + + +class XGrammarGrammarBackend: + def __init__( + self, + tokenizer, + vocab_size: int, + ): + if import_error: + raise import_error + + self.executor = ThreadPoolExecutor() + self.grammar_cache = XGrammarCache(tokenizer, vocab_size) + self.vocab_size = vocab_size + + def _query(self, key: Tuple[str, str]) -> XGrammarGrammar: + return XGrammarGrammar(self.grammar_cache.query(key), self.vocab_size) + + def query(self, key: Tuple[str, str]) -> Future: + return self.executor.submit(self._query, key) + + def reset(self): + self.grammar_cache.reset() + + +class XGrammarCache: + def __init__(self, tokenizer, vocab_size: int): + self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer) + self.vocab_size = vocab_size + + def get_context(self, key: Tuple[str, str]) -> CompiledGrammar: + key_type, key_string = key + if key_type == "json": + return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string) + elif key_type == "regex": + raise ValueError("regex hasn't been supported by xgrammar yet") + else: + raise ValueError(f"Invalid key_type: {key_type}") + + def query(self, key: Tuple[str, str]) -> GrammarMatcher: + ctx = self.get_context(key) + return GrammarMatcher( + ctx, + max_rollback_tokens=MAX_ROLLBACK_TOKENS, + mask_vocab_size=self.vocab_size, + ) + + def reset(self): + self.grammar_cache.clear() diff --git a/python/sglang/srt/constrained/xgrammar_cache.py b/python/sglang/srt/constrained/xgrammar_cache.py deleted file mode 100644 index 180f67d49..000000000 --- a/python/sglang/srt/constrained/xgrammar_cache.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -"""Cache for the compressed finite state machine.""" - -from typing import Tuple - -from transformers import AutoTokenizer - -try: - from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher -except ImportError as e: - - class Dummy: - pass - - GrammarMatcher = Dummy - CompiledGrammar = Dummy - CachedGrammarCompiler = Dummy - - -MAX_ROLLBACK_TOKENS = 10 - - -class XGrammarJumpCache: - """A dummy class.""" - - def reset(self): - pass - - -class XGrammarBackend: - def __init__( - self, - tokenizer_path, - tokenizer_args_dict, - skip_tokenizer_init=False, - whitespace_patterns=None, - ): - # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init - if skip_tokenizer_init: - return - - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) - self.grammar_cache: CachedGrammarCompiler = CachedGrammarCompiler( - tokenizer_or_vocab=tokenizer - ) - - def get_context(self, key: Tuple[str, str]) -> CompiledGrammar: - key_type, key_string = key - if key_type == "json": - return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string) - elif key_type == "regex": - raise ValueError("regex hasn't been supported by xgrammar yet") - else: - raise ValueError(f"Invalid key_type: {key_type}") - - def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher: - ctx = self.get_context(key) - return GrammarMatcher( - ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size - ) - - def reset(self): - self.grammar_cache.clear() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 96ed7c8a8..a3bcbeba5 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -37,7 +37,6 @@ import torch from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.constrained.grammar import Grammar from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool @@ -249,7 +248,7 @@ class Req: self.embedding = None # Constrained decoding - self.grammar: Optional[Grammar] = None + self.grammar = None # The number of cached tokens, that were already cached in the KV cache self.cached_tokens = 0 @@ -359,8 +358,6 @@ class Req: return def jump_forward_and_retokenize(self, jump_forward_str, next_state): - assert self.grammar is not None and self.tokenizer is not None - if self.origin_input_text is None: # Recovering text can only use unpadded ids self.origin_input_text = self.tokenizer.decode( @@ -809,9 +806,10 @@ class ScheduleBatch: for i, req in enumerate(self.reqs): if req.grammar is not None: - jump_helper = req.grammar.try_jump(req.tokenizer) - if jump_helper.can_jump(): - suffix_ids = jump_helper.suffix_ids + jump_helper = req.grammar.try_jump_forward(req.tokenizer) + if jump_helper: + suffix_ids, _ = jump_helper + # Current ids, for cache and revert cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1] cur_output_ids = req.output_ids @@ -827,6 +825,8 @@ class ScheduleBatch: next_state, ) = req.grammar.jump_forward_str_state(jump_helper) + # Make the incrementally decoded text part of jump_forward_str + # so that the UTF-8 will not corrupt jump_forward_str = new_text + jump_forward_str if not req.jump_forward_and_retokenize( jump_forward_str, next_state diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5cf96d4c3..2cdb69f1a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -21,6 +21,7 @@ import threading import time import warnings from collections import deque +from concurrent import futures from types import SimpleNamespace from typing import List, Optional @@ -29,7 +30,6 @@ import zmq from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.constrained.grammar import GrammarBackend from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( @@ -100,7 +100,7 @@ class Scheduler: self.tp_rank = tp_rank self.tp_size = server_args.tp_size self.schedule_policy = server_args.schedule_policy - self.disable_regex_jump_forward = server_args.disable_regex_jump_forward + self.disable_jump_forward = server_args.disable_jump_forward self.lora_paths = server_args.lora_paths self.max_loras_per_batch = server_args.max_loras_per_batch self.enable_overlap = server_args.enable_overlap_schedule @@ -234,22 +234,33 @@ class Scheduler: self.chunked_prefill_size is not None and server_args.enable_mixed_chunk ) - # Init the grammar cache for constrained generation - self.grammar_cache = None + # Init the grammar backend for constrained generation self.grammar_queue: List[Req] = [] - if not server_args.skip_tokenizer_init: - self.grammar_cache = GrammarBackend( - server_args.tokenizer_path, - { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, - }, - skip_tokenizer_init=server_args.skip_tokenizer_init, - whitespace_patterns=server_args.constrained_json_whitespace_pattern, - backend=server_args.grammar_backend, - allow_jump=not server_args.disable_regex_jump_forward, - ) + if server_args.grammar_backend == "outlines": + from sglang.srt.constrained.outlines_backend import ( + OutlinesGrammarBackend, + ) + + self.grammar_backend = OutlinesGrammarBackend( + self.tokenizer, + whitespace_patterns=server_args.constrained_json_whitespace_pattern, + allow_jump_forward=not server_args.disable_jump_forward, + ) + elif server_args.grammar_backend == "xgrammar": + from sglang.srt.constrained.xgrammar_backend import ( + XGrammarGrammarBackend, + ) + + self.grammar_backend = XGrammarGrammarBackend( + self.tokenizer, vocab_size=self.model_config.vocab_size + ) + else: + raise ValueError( + f"Invalid grammar backend: {server_args.grammar_backend}" + ) + else: + self.grammar_backend = None # Init new token estimation assert ( @@ -461,15 +472,14 @@ class Scheduler: req.sampling_params.json_schema is not None or req.sampling_params.regex is not None ): - assert self.grammar_cache is not None + assert self.grammar_backend is not None if req.sampling_params.json_schema is not None: - req.grammar = self.grammar_cache.query( + req.grammar = self.grammar_backend.query( ("json", req.sampling_params.json_schema), - self.model_config.vocab_size, ) elif req.sampling_params.regex is not None: - req.grammar = self.grammar_cache.query( - ("regex", req.sampling_params.regex), self.model_config.vocab_size + req.grammar = self.grammar_backend.query( + ("regex", req.sampling_params.regex) ) # Truncate prompts that are too long @@ -638,14 +648,14 @@ class Scheduler: return self.running_batch def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: - # Check if the grammar queue is ready + # Check if the grammar is ready in the grammar queue if self.grammar_queue: new_grammar_queue = [] for req in self.grammar_queue: - if req.grammar.done(): - req.grammar = req.grammar.result() + try: + req.grammar = req.grammar.result(timeout=0.05) self.waiting_queue.append(req) - else: + except futures._base.TimeoutError: new_grammar_queue.append(req) self.grammar_queue = new_grammar_queue @@ -783,7 +793,7 @@ class Scheduler: ) # Check for jump-forward - if not self.disable_regex_jump_forward: + if not self.disable_jump_forward: jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) self.waiting_queue.extend(jump_forward_reqs) if batch.is_empty(): @@ -1142,8 +1152,8 @@ class Scheduler: ): self.tree_cache.reset() self.tree_cache_metrics = {"total": 0, "hit": 0} - if self.grammar_cache is not None: - self.grammar_cache.reset() + if self.grammar_backend is not None: + self.grammar_backend.reset() # TODO(dark): reset the bnf cache self.req_to_token_pool.clear() self.token_to_kv_pool.clear() diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 6afd48cc8..17369d31a 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, List, Optional import torch import sglang.srt.sampling.penaltylib as penaltylib -from sglang.srt.constrained.grammar import Grammar if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -31,7 +30,7 @@ class SamplingBatchInfo: logit_bias: torch.Tensor = None vocab_mask: Optional[torch.Tensor] = None - grammars: Optional[List[Optional[Grammar]]] = None + grammars: Optional[List] = None # Penalizer penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None @@ -146,7 +145,7 @@ class SamplingBatchInfo: ) for i, grammar in enumerate(self.grammars): if grammar is not None: - grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size) + grammar.fill_vocab_mask(self.vocab_mask[i]) def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): if self.penalizer_orchestrator: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 769f435cd..7003d2c53 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -111,7 +111,7 @@ class ServerArgs: disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False disable_radix_cache: bool = False - disable_regex_jump_forward: bool = False + disable_jump_forward: bool = False disable_cuda_graph: bool = False disable_cuda_graph_padding: bool = False disable_disk_cache: bool = False @@ -574,7 +574,7 @@ class ServerArgs: type=str, choices=["xgrammar", "outlines"], default=ServerArgs.grammar_backend, - help="Choose the backend for constrained decoding.", + help="Choose the backend for grammar-guided decoding.", ) # Optimization/debug options @@ -594,9 +594,9 @@ class ServerArgs: help="Disable RadixAttention for prefix caching.", ) parser.add_argument( - "--disable-regex-jump-forward", + "--disable-jump-forward", action="store_true", - help="Disable regex jump-forward.", + help="Disable jump-forward for grammar-guided decoding.", ) parser.add_argument( "--disable-cuda-graph", @@ -616,7 +616,6 @@ class ServerArgs: parser.add_argument( "--disable-custom-all-reduce", action="store_true", - default=False, help="Disable the custom all-reduce kernel and fall back to NCCL.", ) parser.add_argument( @@ -688,7 +687,6 @@ class ServerArgs: ) parser.add_argument( "--delete-ckpt-after-loading", - default=ServerArgs.delete_ckpt_after_loading, action="store_true", help="Delete the model checkpoint after loading the model.", ) diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 23c7cc260..41d9b0c90 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -61,18 +61,27 @@ class TestJSONConstrained(unittest.TestCase): "logprob_start_len": 0, }, ) - print(json.dumps(response.json())) + ret = response.json() + print(json.dumps(ret)) print("=" * 100) if not json_schema: return + # Make sure the json output is valid try: - js_obj = json.loads(response.json()["text"]) + js_obj = json.loads(ret["text"]) except (TypeError, json.decoder.JSONDecodeError): raise - assert isinstance(js_obj["name"], str) - assert isinstance(js_obj["population"], int) + + self.assertIsInstance(js_obj["name"], str) + self.assertIsInstance(js_obj["population"], int) + + # Make sure jump forward is triggered + self.assertGreater( + ret["meta_info"]["completion_tokens"], + ret["meta_info"]["completion_tokens_wo_jump_forward"], + ) def test_json_generate(self): self.run_decode(json_schema=self.json_schema) @@ -100,8 +109,9 @@ class TestJSONConstrained(unittest.TestCase): except (TypeError, json.decoder.JSONDecodeError): print("JSONDecodeError", text) raise - assert isinstance(js_obj["name"], str), f"{js_obj=}" - assert isinstance(js_obj["population"], int), f"{js_obj=}" + + self.assertIsInstance(js_obj["name"], str) + self.assertIsInstance(js_obj["population"], int) def test_mix_json_and_other(self): json_schemas = [None, None, self.json_schema, self.json_schema] * 10