diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py index a8708dfea..8ce7f05b6 100644 --- a/python/sglang/srt/constrained/__init__.py +++ b/python/sglang/srt/constrained/__init__.py @@ -13,25 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. """ -"""For constrained decoding.""" - import json from typing import Dict, Optional, Union from pydantic import BaseModel -try: - from outlines.caching import cache as disk_cache - from outlines.caching import disable_cache - from outlines.fsm.guide import RegexGuide - from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm - from outlines.models.transformers import TransformerTokenizer -except ImportError as e: - print( - f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n' - ) - raise - try: from outlines.fsm.json_schema import build_regex_from_object except ImportError: @@ -51,31 +37,6 @@ except ImportError: return build_regex_from_schema(schema, whitespace_pattern) -try: - from xgrammar import ( - GrammarMatcher, - GrammarMatcherInitContext, - GrammarMatcherInitContextCache, - ) -except ImportError as e: - - class Dummy: - pass - - GrammarMatcher = Dummy - GrammarMatcherInitContext = Dummy - GrammarMatcherInitContextCache = Dummy - __all__ = [ - "RegexGuide", - "FSMInfo", - "make_deterministic_fsm", "build_regex_from_object", - "TransformerTokenizer", - "disk_cache", - "disable_cache", - "make_byte_level_fsm", - "GrammarMatcher", - "GrammarMatcherInitContext", - "GrammarMatcherInitContextCache", ] diff --git a/python/sglang/srt/constrained/base_tool_cache.py b/python/sglang/srt/constrained/base_tool_cache.py index fa1aff5ea..f137ad16e 100644 --- a/python/sglang/srt/constrained/base_tool_cache.py +++ b/python/sglang/srt/constrained/base_tool_cache.py @@ -13,25 +13,47 @@ See the License for the specific language governing permissions and limitations under the License. """ -"""Base tool cache for constrained decoding tools.""" +"""Base cache class for constrained decoding tools.""" import time +from dataclasses import dataclass +from threading import Event, Lock +from typing import Any, Dict, Tuple + + +@dataclass +class MapEntry: + event: Event + value: Any + + def __iter__(self): + return iter((self.event, self.value)) class BaseToolCache: + def __init__(self, enable=True): - self.enable = enable + self.enable: bool = enable + self.cache: Dict[str, MapEntry] = {} + self.metrics: Dict[str, Any] = {} + self.lock_cache: Lock = Lock() + self.lock_metrics: Lock = Lock() self.reset() def reset(self): - self.cache = {} - self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0} + with self.lock_cache: + self.cache = {} + with self.lock_metrics: + self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0} - def query(self, key): - def _init_with_timer(key): - start = time.monotonic() - val = self.init_value(key) - init_time = time.monotonic() - start + def _init_with_timer(self, key) -> Tuple[Any, float]: + start = time.monotonic() + val = self.init_value(key) + init_time = time.monotonic() - start + return val, init_time + + def update_time(self, init_time): + with self.lock_metrics: curr_total = self.metrics["total"] new_total = curr_total + 1 @@ -39,27 +61,44 @@ class BaseToolCache: self.metrics["avg_init_time"] = (init_time / new_total) + ( curr_total / new_total ) * self.metrics["avg_init_time"] - return val - if key in self.cache: - self.metrics["hit"] += 1 - val = self.cache[key] - else: - # Cache miss or disabled. - val = _init_with_timer(key) + def query(self, key): + if not self.enable: + value, init_time = self._init_with_timer(key) + self.update_time(init_time) + return value - if self.enable: + with self.lock_cache: + if key in self.cache: + entry = self.cache[key] + cache_hit = True + else: + entry = MapEntry(Event(), None) + self.cache[key] = entry + cache_hit = False + + with self.lock_metrics: self.metrics["total"] += 1 - self.cache[key] = val - return val + if cache_hit: + self.metrics["hit"] += 1 + + if cache_hit: + entry.event.wait() + else: + entry.value, init_time = self._init_with_timer(key) + self.update_time(init_time) + entry.event.set() + return entry.value def init_value(self, key): raise NotImplementedError() def get_cache_hit_rate(self): - if self.metrics["total"] == 0: - return 0 - return self.metrics["hit"] / self.metrics["total"] + with self.lock_metrics: + if self.metrics["total"] == 0: + return 0 + return self.metrics["hit"] / self.metrics["total"] def get_avg_init_time(self): - return self.metrics["avg_init_time"] + with self.lock_metrics: + return self.metrics["avg_init_time"] diff --git a/python/sglang/srt/constrained/grammar.py b/python/sglang/srt/constrained/grammar.py index 0281539b8..3f9dfb8a0 100644 --- a/python/sglang/srt/constrained/grammar.py +++ b/python/sglang/srt/constrained/grammar.py @@ -13,50 +13,44 @@ limitations under the License. """Cache for the compressed finite state machine.""" import logging -from typing import List, Optional, Tuple, Union +from concurrent.futures import Future, ThreadPoolExecutor +from typing import List, Tuple, Union import torch -from sglang.srt.constrained import GrammarMatcher, RegexGuide -from sglang.srt.constrained.bnf_cache import BNFCache -from sglang.srt.constrained.fsm_cache import FSMCache -from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap - -# from sglang.srt.managers.schedule_batch import Req +from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide +from sglang.srt.constrained.outlines_jump_forward import ( + OutlinesJumpCache, + OutlinesJumpForwardMap, +) +from sglang.srt.constrained.xgrammar_cache import ( + GrammarMatcher, + XGrammarBackend, + XGrammarJumpCache, +) logger = logging.getLogger(__name__) -INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 - - -class XGrammarJump: - pass - class JumpHelper: - data: Union[List, str] - state: int - suffix_ids: List[int] def __init__( self, data: Union[List, str] = "", state: int = -1, suffix_ids=[] ) -> None: - self.data = data - self.state = state - self.suffix_ids = suffix_ids + self.data: Union[List, str] = data + self.state: int = state + self.suffix_ids: List[int] = suffix_ids def can_jump(self): return len(self.data) > 0 class Grammar: - grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]] - jump_map: Union[XGrammarJump, JumpForwardMap, None] def __init__( self, grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]], - jump_map: Union[XGrammarJump, JumpForwardMap, None], + jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None], ) -> None: self.grammar = grammar self.jump_map = jump_map @@ -69,10 +63,10 @@ class Grammar: self.grammar = guide, guide.get_next_state(state, token) def try_jump(self, tokenizer) -> JumpHelper: - if isinstance(self.jump_map, XGrammarJump): + if isinstance(self.jump_map, XGrammarJumpCache): assert isinstance(self.grammar, GrammarMatcher) return JumpHelper(self.grammar.find_jump_forward_string()) - elif isinstance(self.jump_map, JumpForwardMap): + elif isinstance(self.jump_map, OutlinesJumpForwardMap): assert isinstance(self.grammar, Tuple) _, state = self.grammar @@ -103,7 +97,7 @@ class Grammar: if isinstance(helper.data, str): return helper.data, -1 else: - assert isinstance(self.jump_map, JumpForwardMap) + assert isinstance(self.jump_map, OutlinesJumpForwardMap) return self.jump_map.jump_forward_symbol(helper.state) def jump_and_retokenize( @@ -129,7 +123,7 @@ class Grammar: def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int): if isinstance(self.grammar, GrammarMatcher): # Note that this bitmask is a bitset, not bool - bitmask = self.grammar.find_next_token_bitmask() + bitmask = self.grammar.get_next_token_bitmask() # Mask the tokens that are not allowed vocab_mask[ self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size) @@ -140,9 +134,7 @@ class Grammar: vocab_mask[guide.get_next_instruction(state).tokens] = 0 -class GrammarCache: - grammar_cache: Union[BNFCache, FSMCache] - jump_cache: Union[XGrammarJump, JumpForwardCache, None] +class GrammarBackend: def __init__( self, @@ -153,38 +145,38 @@ class GrammarCache: backend=None, allow_jump=False, ): + self.executor = ThreadPoolExecutor() + self.backend = backend + if backend == "xgrammar": - self.grammar_cache = BNFCache( + self.grammar_cache = XGrammarBackend( tokenizer_path=tokenizer_path, tokenizer_args_dict=tokenizer_args_dict, skip_tokenizer_init=skip_tokenizer_init, whitespace_patterns=whitespace_patterns, ) - self.jump_cache = XGrammarJump() if allow_jump else None + self.jump_cache = XGrammarJumpCache() if allow_jump else None else: assert backend == "outlines" - self.grammar_cache = FSMCache( + self.grammar_cache = OutlinesCache( tokenizer_path=tokenizer_path, tokenizer_args_dict=tokenizer_args_dict, skip_tokenizer_init=skip_tokenizer_init, constrained_json_whitespace_pattern=whitespace_patterns, - enable=True, ) - self.jump_cache = JumpForwardCache() if allow_jump else None + self.jump_cache = OutlinesJumpCache() if allow_jump else None - def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar: - if isinstance(self.grammar_cache, BNFCache): - assert not isinstance(self.jump_cache, JumpForwardCache) + def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar: + if isinstance(self.grammar_cache, XGrammarBackend): return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache) else: - jump_map = None guide, regex = self.grammar_cache.query(key) - if isinstance(self.jump_cache, JumpForwardCache): - jump_map = self.jump_cache.query(regex) + jump_map = self.jump_cache.query(regex) return Grammar((guide, 0), jump_map) + def query(self, key: Tuple[str, str], vocab_size: int) -> Future: + return self.executor.submit(self._query, key, vocab_size) + def reset(self): - if isinstance(self.grammar_cache, FSMCache): - self.grammar_cache.reset() - if isinstance(self.jump_cache, JumpForwardCache): - self.jump_cache.reset() + self.grammar_cache.reset() + self.jump_cache.reset() diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/outlines_cache.py similarity index 93% rename from python/sglang/srt/constrained/fsm_cache.py rename to python/sglang/srt/constrained/outlines_cache.py index 192431fda..8971d5a5e 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/outlines_cache.py @@ -17,16 +17,17 @@ limitations under the License. import logging from interegular import InvalidSyntax, parse_pattern -from outlines.fsm.json_schema import build_regex_from_schema +from outlines.fsm.guide import RegexGuide +from outlines.models.transformers import TransformerTokenizer from transformers import AutoTokenizer -from sglang.srt.constrained import RegexGuide, TransformerTokenizer +from sglang.srt.constrained import build_regex_from_object from sglang.srt.constrained.base_tool_cache import BaseToolCache logger = logging.getLogger(__name__) -class FSMCache(BaseToolCache): +class OutlinesCache(BaseToolCache): def __init__( self, tokenizer_path, @@ -74,7 +75,7 @@ class FSMCache(BaseToolCache): key_type, key_string = key if key_type == "json": try: - regex = build_regex_from_schema( + regex = build_regex_from_object( key_string, whitespace_pattern=self.constrained_json_whitespace_pattern, ) diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/outlines_jump_forward.py similarity index 94% rename from python/sglang/srt/constrained/jump_forward.py rename to python/sglang/srt/constrained/outlines_jump_forward.py index 1ebc8b217..2439db276 100644 --- a/python/sglang/srt/constrained/jump_forward.py +++ b/python/sglang/srt/constrained/outlines_jump_forward.py @@ -14,7 +14,7 @@ limitations under the License. """ """ -Faster constrained decoding. +Faster constrained decoding with jump forward decoding / compressed finite state machine. Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/ """ @@ -23,15 +23,10 @@ import logging from collections import defaultdict import interegular -import outlines.caching from interegular import InvalidSyntax +from outlines.caching import cache as disk_cache +from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm -from sglang.srt.constrained import ( - FSMInfo, - disk_cache, - make_byte_level_fsm, - make_deterministic_fsm, -) from sglang.srt.constrained.base_tool_cache import BaseToolCache IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" @@ -47,7 +42,7 @@ class JumpEdge: byte_next_state: int = None -class JumpForwardMap: +class OutlinesJumpForwardMap: def __init__(self, regex_string): @disk_cache() def _init_state_to_jump_forward(regex_string): @@ -169,12 +164,12 @@ class JumpForwardMap: ) -class JumpForwardCache(BaseToolCache): +class OutlinesJumpCache(BaseToolCache): def __init__(self): super().__init__() def init_value(self, regex): - forward_map = JumpForwardMap(regex) + forward_map = OutlinesJumpForwardMap(regex) if forward_map.state_to_jump_forward: return forward_map else: @@ -182,7 +177,7 @@ class JumpForwardCache(BaseToolCache): def test_main(regex_string): - jump_forward_map = JumpForwardMap(regex_string) + jump_forward_map = OutlinesJumpForwardMap(regex_string) for state, e in jump_forward_map.state_to_jump_forward.items(): if e.symbol is not None: jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state) diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/xgrammar_cache.py similarity index 68% rename from python/sglang/srt/constrained/bnf_cache.py rename to python/sglang/srt/constrained/xgrammar_cache.py index 19765731b..180f67d49 100644 --- a/python/sglang/srt/constrained/bnf_cache.py +++ b/python/sglang/srt/constrained/xgrammar_cache.py @@ -17,18 +17,29 @@ from typing import Tuple from transformers import AutoTokenizer -from sglang.srt.constrained import ( - GrammarMatcher, - GrammarMatcherInitContext, - GrammarMatcherInitContextCache, -) +try: + from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher +except ImportError as e: + + class Dummy: + pass + + GrammarMatcher = Dummy + CompiledGrammar = Dummy + CachedGrammarCompiler = Dummy + MAX_ROLLBACK_TOKENS = 10 -class BNFCache: - grammar_cache: GrammarMatcherInitContextCache +class XGrammarJumpCache: + """A dummy class.""" + def reset(self): + pass + + +class XGrammarBackend: def __init__( self, tokenizer_path, @@ -41,16 +52,16 @@ class BNFCache: return tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) - self.grammar_cache = GrammarMatcherInitContextCache( + self.grammar_cache: CachedGrammarCompiler = CachedGrammarCompiler( tokenizer_or_vocab=tokenizer ) - def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext: + def get_context(self, key: Tuple[str, str]) -> CompiledGrammar: key_type, key_string = key if key_type == "json": - return self.grammar_cache.get_init_context_for_json_schema(key_string) + return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string) elif key_type == "regex": - raise ValueError(f"regex hasn't been supported by xgrammar yet") + raise ValueError("regex hasn't been supported by xgrammar yet") else: raise ValueError(f"Invalid key_type: {key_type}") @@ -59,3 +70,6 @@ class BNFCache: return GrammarMatcher( ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size ) + + def reset(self): + self.grammar_cache.clear() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2dc1944d5..5cf96d4c3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -29,7 +29,7 @@ import zmq from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.constrained.grammar import GrammarCache +from sglang.srt.constrained.grammar import GrammarBackend from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( @@ -234,11 +234,12 @@ class Scheduler: self.chunked_prefill_size is not None and server_args.enable_mixed_chunk ) - # Init the FSM cache for constrained generation + # Init the grammar cache for constrained generation self.grammar_cache = None + self.grammar_queue: List[Req] = [] if not server_args.skip_tokenizer_init: - self.grammar_cache = GrammarCache( + self.grammar_cache = GrammarBackend( server_args.tokenizer_path, { "tokenizer_mode": server_args.tokenizer_mode, @@ -455,7 +456,7 @@ class Scheduler: # By default, only return the logprobs for output tokens req.logprob_start_len = len(recv_req.input_ids) - 1 - # Init regex FSM or BNF + # Init grammar cache for this request if ( req.sampling_params.json_schema is not None or req.sampling_params.regex is not None @@ -488,7 +489,10 @@ class Scheduler: self.max_req_len - len(req.origin_input_ids) - 1, ) - self.waiting_queue.append(req) + if req.grammar is not None: + self.grammar_queue.append(req) + else: + self.waiting_queue.append(req) def handle_embedding_request( self, @@ -634,6 +638,17 @@ class Scheduler: return self.running_batch def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: + # Check if the grammar queue is ready + if self.grammar_queue: + new_grammar_queue = [] + for req in self.grammar_queue: + if req.grammar.done(): + req.grammar = req.grammar.result() + self.waiting_queue.append(req) + else: + new_grammar_queue.append(req) + self.grammar_queue = new_grammar_queue + # Handle the cases where prefill is not allowed if ( self.batch_is_full or len(self.waiting_queue) == 0 diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1dde62943..8b06d2cea 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -39,7 +39,6 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry from sglang.srt.configs.model_config import AttentionArch, ModelConfig -from sglang.srt.constrained import disable_cache from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend from sglang.srt.layers.attention.triton_backend import TritonAttnBackend @@ -129,6 +128,8 @@ class ModelRunner: if server_args.show_time_cost: enable_show_time_cost() if server_args.disable_disk_cache: + from outlines.caching import disable_cache + disable_cache() global_server_args_dict.update( diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 88368fba8..23c7cc260 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -100,8 +100,8 @@ class TestJSONConstrained(unittest.TestCase): except (TypeError, json.decoder.JSONDecodeError): print("JSONDecodeError", text) raise - assert isinstance(js_obj["name"], str) - assert isinstance(js_obj["population"], int) + assert isinstance(js_obj["name"], str), f"{js_obj=}" + assert isinstance(js_obj["population"], int), f"{js_obj=}" def test_mix_json_and_other(self): json_schemas = [None, None, self.json_schema, self.json_schema] * 10