From ca13f3b8c58e419c04e706bb5a6711073f466aa0 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sun, 21 Jan 2024 13:26:11 +0800 Subject: [PATCH] Disk FSM cache and adjust code. (#63) --- examples/quick_start/srt_example_regex.py | 5 +- python/pyproject.toml | 2 +- python/sglang/srt/constrained/disk_cache.py | 70 ++++++ python/sglang/srt/constrained/fsm.py | 231 ++++++++---------- python/sglang/srt/constrained/fsm_cache.py | 40 +-- python/sglang/srt/constrained/tokenizer.py | 141 +---------- .../sglang/srt/managers/router/infer_batch.py | 2 +- .../sglang/srt/managers/router/model_rpc.py | 15 +- 8 files changed, 207 insertions(+), 299 deletions(-) create mode 100644 python/sglang/srt/constrained/disk_cache.py diff --git a/examples/quick_start/srt_example_regex.py b/examples/quick_start/srt_example_regex.py index 8f85aec5e..0dcae15ea 100644 --- a/examples/quick_start/srt_example_regex.py +++ b/examples/quick_start/srt_example_regex.py @@ -1,13 +1,16 @@ from sglang import function, gen, set_default_backend, Runtime +IP_ADDR_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" + + @function def regex_gen(s): s += "Q: What is the IP address of the Google DNS servers?\n" s += "A: " + gen( "answer", temperature=0, - regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)", + regex=IP_ADDR_REGEX, ) diff --git a/python/pyproject.toml b/python/pyproject.toml index 0df941460..6479ef48b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ [project.optional-dependencies] srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", - "interegular", "lark", "numba", "pydantic"] + "interegular", "lark", "numba", "pydantic", "diskcache", "cloudpickle"] openai = ["openai>=1.0"] anthropic = ["anthropic"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"] diff --git a/python/sglang/srt/constrained/disk_cache.py b/python/sglang/srt/constrained/disk_cache.py new file mode 100644 index 000000000..1855895e6 --- /dev/null +++ b/python/sglang/srt/constrained/disk_cache.py @@ -0,0 +1,70 @@ +# Adapted from: +# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/caching.py +import asyncio +import hashlib +import os +from typing import Callable, Optional + +import cloudpickle +from diskcache import Cache + +home_dir = os.path.expanduser("~") +cache_dir = os.environ.get("SGLANG_CACHE_DIR", f"{home_dir}/.cache/sglang") +memory = Cache(cache_dir, eviction_policy="none", cull_limit=0) +_caching_enabled = True + + +def hash_arguments(*args, **kwargs) -> str: + """Create a hash out of the args and kwargs provided""" + result = hashlib.md5() + for item in list(args) + sorted(kwargs.items()): + result.update(cloudpickle.dumps(item)) + return result.hexdigest() + + +def disk_cache(key_function: Optional[Callable] = None): + def decorator(cached_function: Callable): + def wrapper(*args, **kwargs): + if not _caching_enabled: + return cached_function(*args, **kwargs) + if key_function: + key_args = key_function(*args, **kwargs) + cache_key = hash_arguments(*key_args) + else: + cache_key = hash_arguments(*args, **kwargs) + if cache_key in memory: + return memory[cache_key] + result = cached_function(*args, **kwargs) + memory[cache_key] = result + return result + + async def async_wrapper(*args, **kwargs): + if not _caching_enabled: + return await cached_function(*args, **kwargs) + if key_function: + key_args = key_function(*args, **kwargs) + cache_key = hash_arguments(*key_args) + else: + cache_key = hash_arguments(*args, **kwargs) + if cache_key in memory: + return memory[cache_key] + result = await cached_function(*args, **kwargs) + memory[cache_key] = result + return result + + if asyncio.iscoroutinefunction(cached_function): + return async_wrapper + else: + return wrapper + + return decorator + + +def disable_cache(): + global _caching_enabled + _caching_enabled = False + + +def clear_cache(): + global memory + memory.clear() diff --git a/python/sglang/srt/constrained/fsm.py b/python/sglang/srt/constrained/fsm.py index ceec5d3e5..8da6366ac 100644 --- a/python/sglang/srt/constrained/fsm.py +++ b/python/sglang/srt/constrained/fsm.py @@ -1,9 +1,10 @@ # Adapted from: -# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py -from typing import List, NewType, Protocol +# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/fsm/fsm.py +from typing import List, NewType, Protocol, Tuple import interegular from lark import Lark +from sglang.srt.constrained.disk_cache import disk_cache # from outlines.fsm.parsing import PartialLark from sglang.srt.constrained.regex import ( @@ -16,16 +17,16 @@ FSMState = NewType("FSMState", int) class FSM(Protocol): - def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + def allowed_token_ids(self, state: FSMState) -> List[int]: ... - def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + def next_state(self, state: FSMState, token_id: int) -> FSMState: ... - def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + def is_final_state(self, state: FSMState) -> bool: ... - def reset(self) -> None: + def copy(self) -> "FSM": ... @@ -38,17 +39,12 @@ class StopAtTokenFSM(FSM): """ - def __init__( - self, - tokenizer: "Tokenizer", - stop_token_id: int, - ): + def __init__(self, tokenizer: "Tokenizer", stop_token_id: int): self.stop_token_id = stop_token_id - self.num_tokens_generated = 0 self.vocabulary = tokenizer.vocabulary.values() self.final_states = {1} - def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + def allowed_token_ids(self, state: FSMState) -> List[int]: """Generate a list of allowed tokens for the next step. When in the initial state we allow every token to be generated. @@ -58,8 +54,6 @@ class StopAtTokenFSM(FSM): ---------- state The current state of the FSM. - idx - The index of the current input in the batch. Returns ------- @@ -71,7 +65,7 @@ class StopAtTokenFSM(FSM): else: return [self.stop_token_id] - def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + def next_state(self, state: FSMState, token_id: int) -> FSMState: """Update the state of the FSM. The FSM stays in the initial state `0` unless the specified stop token @@ -84,29 +78,24 @@ class StopAtTokenFSM(FSM): The current state of the FSM. token_id The id of the token that was just generated. - idx - The index of the current input in the batch. Returns ------- The new state of the FSM. """ - if idx == 0: - self.num_tokens_generated += 1 - if token_id == self.stop_token_id: return FSMState(1) return FSMState(0) - def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + def is_final_state(self, state: FSMState) -> bool: """Determine whether the current state of the FSM is a final state.""" return state in self.final_states - def reset(self) -> None: - """Reset the FSM to its initial state. Here this only resets the token counter.""" - self.num_tokens_generated = 0 + def copy(self) -> "StopAtTokenFSM": + """Create a copy of the FSM.""" + return self class RegexFSM(FSM): @@ -117,32 +106,48 @@ class RegexFSM(FSM): regex_string: str, tokenizer: "Tokenizer", ): - regex_pattern = interegular.parse_pattern(regex_string) - regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + @disk_cache() + def create_states_mapping( + regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]] + ) -> Tuple[dict, set, set]: + """Create the variables related to the mapping between states and tokens + The parameters of the function are used for caching purpose + """ + regex_pattern = interegular.parse_pattern(regex_string) + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + ( + states_to_token_maps, + empty_token_ids, + ) = create_fsm_index_tokenizer(regex_fsm, tokenizer) + + # We make sure that it is possible to generate strings in the language + # of the regular expression with the tokens present in the model's + # vocabulary. + if not any( + regex_fsm.finals.intersection(v.values()) + for v in states_to_token_maps.values() + ): + raise ValueError( + "The vocabulary does not allow us to build a sequence that matches the input regex" + ) + + final_states = regex_fsm.finals | { + -1 + } # Include the EOS token in final states + return states_to_token_maps, empty_token_ids, final_states + ( self.states_to_token_maps, self.empty_token_ids, - ) = create_fsm_index_tokenizer(regex_fsm, tokenizer) - - # We make sure that it is possible to generate strings in the language - # of the regular expression with the tokens present in the model's - # vocabulary. - if not any( - regex_fsm.finals.intersection(v.values()) - for v in self.states_to_token_maps.values() - ): - raise ValueError( - "The vocabulary does not allow us to build a sequence that matches the input regex" - ) - - self.final_states = regex_fsm.finals | { - -1 - } # Include the EOS token in final states + self.final_states, + ) = create_states_mapping( + regex_string, tuple(sorted(tokenizer.vocabulary.items())) + ) self.num_tokens_generated = 0 self.vocabulary = tokenizer.vocabulary.values() self.end_token_id = tokenizer.eos_token_id - def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + def allowed_token_ids(self, state: FSMState) -> List[int]: """Generate a list of allowed tokens for the next step. The initialization of the FSM builds an index which maps FSM states to a @@ -159,8 +164,6 @@ class RegexFSM(FSM): ---------- state The current state of the FSM. - idx - The index of the current input in the batch. Returns ------- @@ -174,7 +177,7 @@ class RegexFSM(FSM): else: return list(next_tokens_to_end_states.keys()) - def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + def next_state(self, state: FSMState, token_id: int) -> FSMState: """Update the state of the FSM. We use the index to determine to which state the FSM should transition @@ -186,17 +189,12 @@ class RegexFSM(FSM): The current state of the FSM. token_id The id of the token that was just generated. - idx - The index of the current input in the batch. Returns ------- The new state of the FSM. """ - if idx == 0: - self.num_tokens_generated += 1 - if token_id == self.end_token_id: return FSMState(-1) @@ -207,24 +205,22 @@ class RegexFSM(FSM): return FSMState(next_state) - def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + def is_final_state(self, state: FSMState) -> bool: """Determine whether the current state of the FSM is a final state.""" return state in self.final_states - def reset(self) -> None: - """Reset the FSM to its initial state. Here this only resets the token counter.""" - self.num_tokens_generated = 0 + def copy(self) -> "RegexFSM": + """Create a copy of the FSM.""" + return self class CFGFSM(FSM): """FSM to generate text that is in the language of a context-free grammar.""" - def __init__( - self, - cfg_string: str, - tokenizer: "Tokenizer", - ): - # self.parser = PartialLark(cfg_string, parser="lalr") + def __init__(self, cfg_string: str, tokenizer: "Tokenizer"): + self.cfg_string = cfg_string + self.tokenizer = tokenizer + self.parser = Lark( cfg_string, parser="lalr", @@ -239,59 +235,52 @@ class CFGFSM(FSM): self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp() self.terminal_regexps["$END"] = tokenizer.eos_token - self.tokenizer = tokenizer - self.num_tokens_generated = 0 - self.generations: List[str] = [] - self.regex_fsms: List[RegexFSM] = [] - self.reset_state: List[bool] = [] - self.allow_eos: List[bool] = [] - self.done: List[bool] = [] + self.generation = "" + self.reset_state = False + self.allow_eos = False + self.done = False + self.regex_fsm: RegexFSM - def _set_next_regex_fsm(self, idx: int = 0) -> None: + def _set_next_regex_fsm(self) -> None: """Use the CFG incremental parser to set the next regex FSM. - Check what the CFG incremental parser proposes next. - If the only proposal is the EOS token, - we set the state to done and return. - If there are other proposals, - we set a new regex FSM and return. + Check what the CFG incremental parser proposes next: + - If the only proposal is the EOS token we set the state to done and + return. + - If there are other proposals, we set a new regex FSM and return. """ - interactive = self.parser.parse_interactive(self.generations[idx]) + interactive = self.parser.parse_interactive(self.generation) interactive.exhaust_lexer() options = {self.terminal_regexps[x] for x in interactive.accepts()} if self.terminal_regexps["$END"] in options: options.remove(self.terminal_regexps["$END"]) if len(options) == 0: - self.done[idx] = True + self.done = True return - self.allow_eos[idx] = True + self.allow_eos = True options.add("") assert len(options) > 1 regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")" - args = ( - regex_string, - self.tokenizer, - ) - if len(self.regex_fsms) <= idx: - self.regex_fsms.append(RegexFSM(*args)) - else: - self.regex_fsms[idx] = RegexFSM(*args) - self.reset_state[idx] = True + self.regex_fsm = RegexFSM(regex_string, self.tokenizer) + self.reset_state = True - def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]: + def allowed_token_ids(self, state: FSMState) -> List[int]: """Generate a list of allowed tokens for the next step. - Upon initialization, the CFG incremental parser is used to determine the first regex. + Upon initialization, the CFG incremental parser is used to determine the + first regex. This regex is used for proposals until either: - - the regex is exhausted, and its only remaining option is the EOS token, - in which case we always transition to the next regex - - the regex can be exhausted, but the EOS token is not the only remaining option, - in which case we transition to the next regex with probability P (TODO) - or remove the possibility of generating the EOS token and continue with the current regex + + - The regex is exhausted, and its only remaining option is the EOS + token, in which case we always transition to the next regex + - The regex can be exhausted, but the EOS token is not the only + remaining option, in which case we transition to the next regex with + probability P (TODO) or remove the possibility of generating the EOS + token and continue with the current regex The CFG incremental parser is allowed to propose the EOS token from any final state, and once it is generated, the FSM will continue to always generate the EOS token. @@ -300,22 +289,14 @@ class CFGFSM(FSM): ---------- state The current state of the FSM. - idx - The index of the current input in the batch. Returns ------- A list that contains the tokens to mask. """ - if len(self.generations) <= idx: - self.generations.append("") - self.reset_state.append(False) - self.allow_eos.append(False) - self.done.append(False) - - if len(self.regex_fsms) > idx: - proposal = self.regex_fsms[idx].allowed_token_ids(state) + if self.generation != "": + proposal = self.regex_fsm.allowed_token_ids(state) if self.tokenizer.eos_token_id not in proposal: return proposal if set(proposal) != {self.tokenizer.eos_token_id}: @@ -323,23 +304,23 @@ class CFGFSM(FSM): proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] return proposal - self._set_next_regex_fsm(idx) + self._set_next_regex_fsm() - if self.done[idx]: + if self.done: return [self.tokenizer.eos_token_id] - if self.reset_state[idx]: + if self.reset_state: state = FSMState(0) - proposal = self.regex_fsms[idx].allowed_token_ids(state) - if self.allow_eos[idx]: - self.allow_eos[idx] = False + proposal = self.regex_fsm.allowed_token_ids(state) + if self.allow_eos: + self.allow_eos = False else: proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] assert len(proposal) > 0 return proposal - def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState: + def next_state(self, state: FSMState, token_id: int) -> FSMState: """Update the state of the FSM. Transitions the underlying regex FSM to its next state. @@ -352,34 +333,26 @@ class CFGFSM(FSM): The current state of the FSM. token_id The id of the token that was just generated. - idx - The index of the current input in the batch. Returns ------- The new state of the FSM. """ - if idx == 0: - self.num_tokens_generated += 1 if token_id == self.tokenizer.eos_token_id: - self.done[idx] = True + self.done = True return FSMState(-1) - if self.reset_state[idx]: - self.reset_state[idx] = False + if self.reset_state: + self.reset_state = False state = FSMState(0) - self.generations[idx] += self.tokenizer.decode([token_id])[0] + self.generation += self.tokenizer.decode([token_id])[0] - return self.regex_fsms[idx].next_state(state, token_id, idx) + return self.regex_fsm.next_state(state, token_id) - def is_final_state(self, state: FSMState, idx: int = 0) -> bool: + def is_final_state(self, state: FSMState) -> bool: """Return whether the current state of the FSM is a final state.""" - return self.done[idx] + return self.done - def reset(self) -> None: - """Reset the FSM to its initial state, so it can be called on a fresh batch on inputs.""" - self.num_tokens_generated = 0 - self.generations = [] - self.regex_fsms = [] - self.reset_state = [] - self.done = [] + def copy(self) -> "CFGFSM": + """Create a copy of the FSM.""" + return CFGFSM(self.cfg_string, self.tokenizer) diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index bd6c6a073..00be13c8f 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -1,41 +1,17 @@ -import threading - from sglang.srt.constrained.fsm import RegexFSM from sglang.srt.constrained.tokenizer import TransformerTokenizer -def get_fsm(regex, tokenizer, fsm_cache_entry): - outlines_tokenizer = TransformerTokenizer(tokenizer) - fsm = RegexFSM(regex, outlines_tokenizer) - fsm_cache_entry.fsm = fsm - fsm_cache_entry.event.set() - - -class FSMCacheEntry: - def __init__(self): - self.fsm = None - self.event = threading.Event() - - class FSMCache: - def __init__(self, tokenizer): + def __init__(self, tokenizer_path, tokenizer_args_dict): self.cache = {} - self.tokenizer = tokenizer + self.outlines_tokenizer = TransformerTokenizer( + tokenizer_path, **tokenizer_args_dict + ) - def init_fsm_in_background(self, regex): + def init_fsm(self, regex): if regex not in self.cache: - self.cache[regex] = FSMCacheEntry() - threading.Thread( - target=get_fsm, - args=( - regex, - self.tokenizer, - self.cache[regex], - ), - ).start() + fsm = RegexFSM(regex, self.outlines_tokenizer) + self.cache[regex] = fsm - def get_fsm(self, regex): - self.init_fsm_in_background(regex) - entry = self.cache[regex] - entry.event.wait() - return entry.fsm + return self.cache[regex] diff --git a/python/sglang/srt/constrained/tokenizer.py b/python/sglang/srt/constrained/tokenizer.py index ac1c8ebed..6853dd9ee 100644 --- a/python/sglang/srt/constrained/tokenizer.py +++ b/python/sglang/srt/constrained/tokenizer.py @@ -2,17 +2,7 @@ # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py from abc import abstractmethod -from typing import ( - TYPE_CHECKING, - Dict, - Hashable, - List, - Optional, - Protocol, - Set, - Tuple, - Union, -) +from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union import numpy as np import torch @@ -50,15 +40,6 @@ class Tokenizer(Protocol, Hashable): ... -if TYPE_CHECKING: - from transformers import PreTrainedModel, PreTrainedTokenizer - -__all__ = ["transformers"] - - -KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...] - - def get_llama_tokenizer_types(): """Get all the Llama tokenizer types/classes that need work-arounds. @@ -101,76 +82,17 @@ def get_llama_tokenizer_types(): ) -class Transformer: - """Represents a `transformers` model.""" - - def __init__( - self, - model: "PreTrainedModel", - tokenizer: "PreTrainedTokenizer", - ): - self.device = model.device - self.model = model - self.tokenizer = tokenizer - - @torch.inference_mode - def forward( - self, - input_ids: torch.LongTensor, - attention_mask: torch.LongTensor, - past_key_values: Optional[Tuple] = None, - ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]: - """Compute a forward pass through the transformer model. - - Parameters - ---------- - input_ids - The input token ids. Must be one or two dimensional. - attention_mask - The attention mask. Must be one or two dimensional. - past_key_values - A tuple of tuples containing the cached key and value tensors for each - attention head. - - Returns - ------- - The computed logits and the new cached key and value tensors. - - """ - assert 0 < input_ids.ndim < 3 - - if past_key_values: - input_ids = input_ids[..., -1].unsqueeze(-1) - - output = self.model( - input_ids, - attention_mask=attention_mask, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - past_key_values=past_key_values, - ) - - return output.logits, output.past_key_values - - def __call__( - self, - input_ids: torch.LongTensor, - attention_mask: torch.LongTensor, - past_key_values: Optional[Tuple] = None, - ) -> torch.FloatTensor: - logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values) - next_token_logits = logits[..., -1, :] - - return next_token_logits, kv_cache - - class TransformerTokenizer(Tokenizer): """Represents a tokenizer for models in the `transformers` library.""" - def __init__(self, tokenizer): + def __init__(self, model_name: str, **kwargs): + from transformers import AutoTokenizer + + kwargs.setdefault("padding_side", "left") + self.model_name = model_name # TODO: Do something to make this hashable? - self.tokenizer = tokenizer + self.kwargs = kwargs + self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs) self.eos_token_id = self.tokenizer.eos_token_id self.eos_token = self.tokenizer.eos_token @@ -212,55 +134,10 @@ class TransformerTokenizer(Tokenizer): def __eq__(self, other): if isinstance(other, type(self)): - return False - # TODO(lsyin): the lru_cache for the TransoformerTokenizer is useless ? - # return other.model_name == self.model_name and other.kwargs == self.kwargs + return other.model_name == self.model_name and other.kwargs == self.kwargs return NotImplemented def __hash__(self): from datasets.fingerprint import Hasher return hash(Hasher.hash(self.tokenizer)) - - -def transformers( - model_name: str, - device: Optional[str] = None, - model_kwargs: dict = {}, - tokenizer_kwargs: dict = {}, -): - """Instantiate a model from the `transformers` library and its tokenizer. - - Parameters - ---------- - model_name - The name of the model as listed on Hugging Face's model page. - device - The device(s) on which the model should be loaded. This overrides - the `device_map` entry in `model_kwargs` when provided. - model_kwargs - A dictionary that contains the keyword arguments to pass to the - `from_pretrained` method when loading the model. - tokenizer_kwargs - A dictionary that contains the keyword arguments to pass to the - `from_pretrained` method when loading the tokenizer. - - Returns - ------- - A `TransformersModel` model instance. - - """ - try: - from transformers import AutoModelForCausalLM - except ImportError: - raise ImportError( - "The `transformers` library needs to be installed in order to use `transformers` models." - ) - - if device is not None: - model_kwargs["device_map"] = device - - model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) - tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs) - - return Transformer(model, tokenizer) diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index 3cc61dd08..1e1a93c9b 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -45,7 +45,7 @@ class Req: # for constrained decoding self.regex_fsm = None - self.regex_fsm_state = None + self.regex_fsm_state = 0 def max_new_tokens(self): return self.sampling_params.max_new_tokens diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 9de71b60b..b4425cf00 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -111,7 +111,13 @@ class ModelRpcServer(rpyc.Service): self.stream_interval = server_args.stream_interval # Init the FSM cache for constrained generation - self.regex_fsm_cache = FSMCache(self.tokenizer) + self.regex_fsm_cache = FSMCache( + server_args.tokenizer_path, + { + "tokenizer_mode": server_args.tokenizer_mode, + "trust_remote_code": server_args.trust_remote_code, + }, + ) # Init new token estimation self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0) @@ -213,6 +219,10 @@ class ModelRpcServer(rpyc.Service): req.stream = recv_req.stream req.tokenizer = self.tokenizer + # Init regex fsm + if req.sampling_params.regex is not None: + req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex) + # Truncate long prompts req.input_ids = req.input_ids[: self.model_config.context_len - 1] req.sampling_params.max_new_tokens = min( @@ -322,11 +332,10 @@ class ModelRpcServer(rpyc.Service): self.model_config.vocab_size, self.int_token_logit_bias ) - # init the regex fsm before first sampling + # Reset regex fsm state before first sampling due to retractions for req in batch.reqs: if req.sampling_params.regex is not None: req.regex_fsm_state = 0 - req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex) if batch.extend_num_tokens != 0: # Forward