Disk FSM cache and adjust code. (#63)
This commit is contained in:
@@ -1,13 +1,16 @@
|
|||||||
from sglang import function, gen, set_default_backend, Runtime
|
from sglang import function, gen, set_default_backend, Runtime
|
||||||
|
|
||||||
|
|
||||||
|
IP_ADDR_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||||
|
|
||||||
|
|
||||||
@function
|
@function
|
||||||
def regex_gen(s):
|
def regex_gen(s):
|
||||||
s += "Q: What is the IP address of the Google DNS servers?\n"
|
s += "Q: What is the IP address of the Google DNS servers?\n"
|
||||||
s += "A: " + gen(
|
s += "A: " + gen(
|
||||||
"answer",
|
"answer",
|
||||||
temperature=0,
|
temperature=0,
|
||||||
regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
|
regex=IP_ADDR_REGEX,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ dependencies = [
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
|
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
|
||||||
"interegular", "lark", "numba", "pydantic"]
|
"interegular", "lark", "numba", "pydantic", "diskcache", "cloudpickle"]
|
||||||
openai = ["openai>=1.0"]
|
openai = ["openai>=1.0"]
|
||||||
anthropic = ["anthropic"]
|
anthropic = ["anthropic"]
|
||||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
||||||
|
|||||||
70
python/sglang/srt/constrained/disk_cache.py
Normal file
70
python/sglang/srt/constrained/disk_cache.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
# Adapted from:
|
||||||
|
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/caching.py
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import cloudpickle
|
||||||
|
from diskcache import Cache
|
||||||
|
|
||||||
|
# Root directory of the on-disk cache: ~/.cache/sglang by default, or whatever
# the SGLANG_CACHE_DIR environment variable points to.
home_dir = os.path.expanduser("~")
cache_dir = os.getenv("SGLANG_CACHE_DIR", f"{home_dir}/.cache/sglang")

# Shared diskcache store, opened when this module is imported. Eviction is
# switched off (policy "none", cull limit 0).
memory = Cache(
    cache_dir,
    eviction_policy="none",
    cull_limit=0,
)

# Process-wide switch consulted by the wrappers built in disk_cache() on every
# call; disable_cache() sets it to False.
_caching_enabled = True
|
||||||
|
|
||||||
|
|
||||||
|
def hash_arguments(*args, **kwargs) -> str:
    """Return an MD5 hex digest that identifies a set of call arguments.

    Positional arguments are hashed in the order given, followed by the
    keyword arguments as ``(name, value)`` pairs sorted by name, so the order
    in which keywords were passed does not change the digest. Every item is
    serialized with ``cloudpickle`` before being fed to the hash.
    """
    digest = hashlib.md5()
    keyword_pairs = sorted(kwargs.items())
    for piece in (*args, *keyword_pairs):
        digest.update(cloudpickle.dumps(piece))
    return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def disk_cache(key_function: Optional[Callable] = None):
    """Build a decorator that memoizes a function's results in the disk cache.

    Works for both plain functions and coroutine functions; the matching
    (sync or async) wrapper is selected when the decorator is applied.

    Parameters
    ----------
    key_function
        Optional callable invoked with the same ``*args, **kwargs`` as the
        decorated function. It must return an iterable of values, which are
        hashed in place of the raw arguments to form the cache key. When
        omitted, all positional and keyword arguments are hashed.

    Returns
    -------
    A decorator. The decorated function consults the module-level
    ``_caching_enabled`` flag on every call and bypasses the cache entirely
    when it is False.
    """
    import functools  # function-scope import keeps the module's import block untouched

    # Sentinel so that a legitimately cached ``None`` is not mistaken for a miss.
    missing = object()

    def decorator(cached_function: Callable):
        # Namespace the key by the function's identity: without it, two
        # different decorated functions called with equal arguments would read
        # and overwrite each other's entries in the shared store.
        namespace = "{}.{}".format(
            getattr(cached_function, "__module__", ""),
            getattr(cached_function, "__qualname__", repr(cached_function)),
        )

        def make_key(args, kwargs) -> str:
            """Derive the cache key for one call (shared by both wrappers)."""
            if key_function:
                return hash_arguments(namespace, *key_function(*args, **kwargs))
            return hash_arguments(namespace, *args, **kwargs)

        @functools.wraps(cached_function)
        def wrapper(*args, **kwargs):
            if not _caching_enabled:
                return cached_function(*args, **kwargs)
            cache_key = make_key(args, kwargs)
            # Single lookup: avoids the check-then-read gap in which the entry
            # could disappear (e.g. a concurrent clear) and raise KeyError.
            cached = memory.get(cache_key, default=missing)
            if cached is not missing:
                return cached
            result = cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        @functools.wraps(cached_function)
        async def async_wrapper(*args, **kwargs):
            if not _caching_enabled:
                return await cached_function(*args, **kwargs)
            cache_key = make_key(args, kwargs)
            cached = memory.get(cache_key, default=missing)
            if cached is not missing:
                return cached
            result = await cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        if asyncio.iscoroutinefunction(cached_function):
            return async_wrapper
        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def disable_cache():
    """Switch disk caching off for the current process.

    Sets the module-level ``_caching_enabled`` flag to False; wrappers created
    by ``disk_cache`` check that flag on each call and then invoke the wrapped
    function directly, without reading from or writing to the cache.
    """
    global _caching_enabled
    _caching_enabled = False
|
||||||
|
|
||||||
|
|
||||||
|
def clear_cache():
    """Remove every entry from the shared on-disk cache ``memory``."""
    # No ``global`` declaration needed: the module-level name is only read
    # here, never rebound.
    memory.clear()
|
||||||
@@ -1,9 +1,10 @@
|
|||||||
# Adapted from:
|
# Adapted from:
|
||||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
|
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/fsm/fsm.py
|
||||||
from typing import List, NewType, Protocol
|
from typing import List, NewType, Protocol, Tuple
|
||||||
|
|
||||||
import interegular
|
import interegular
|
||||||
from lark import Lark
|
from lark import Lark
|
||||||
|
from sglang.srt.constrained.disk_cache import disk_cache
|
||||||
|
|
||||||
# from outlines.fsm.parsing import PartialLark
|
# from outlines.fsm.parsing import PartialLark
|
||||||
from sglang.srt.constrained.regex import (
|
from sglang.srt.constrained.regex import (
|
||||||
@@ -16,16 +17,16 @@ FSMState = NewType("FSMState", int)
|
|||||||
|
|
||||||
|
|
||||||
class FSM(Protocol):
|
class FSM(Protocol):
|
||||||
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
|
def allowed_token_ids(self, state: FSMState) -> List[int]:
|
||||||
...
|
...
|
||||||
|
|
||||||
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
|
def next_state(self, state: FSMState, token_id: int) -> FSMState:
|
||||||
...
|
...
|
||||||
|
|
||||||
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
|
def is_final_state(self, state: FSMState) -> bool:
|
||||||
...
|
...
|
||||||
|
|
||||||
def reset(self) -> None:
|
def copy(self) -> "FSM":
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
@@ -38,17 +39,12 @@ class StopAtTokenFSM(FSM):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
|
||||||
self,
|
|
||||||
tokenizer: "Tokenizer",
|
|
||||||
stop_token_id: int,
|
|
||||||
):
|
|
||||||
self.stop_token_id = stop_token_id
|
self.stop_token_id = stop_token_id
|
||||||
self.num_tokens_generated = 0
|
|
||||||
self.vocabulary = tokenizer.vocabulary.values()
|
self.vocabulary = tokenizer.vocabulary.values()
|
||||||
self.final_states = {1}
|
self.final_states = {1}
|
||||||
|
|
||||||
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
|
def allowed_token_ids(self, state: FSMState) -> List[int]:
|
||||||
"""Generate a list of allowed tokens for the next step.
|
"""Generate a list of allowed tokens for the next step.
|
||||||
|
|
||||||
When in the initial state we allow every token to be generated.
|
When in the initial state we allow every token to be generated.
|
||||||
@@ -58,8 +54,6 @@ class StopAtTokenFSM(FSM):
|
|||||||
----------
|
----------
|
||||||
state
|
state
|
||||||
The current state of the FSM.
|
The current state of the FSM.
|
||||||
idx
|
|
||||||
The index of the current input in the batch.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -71,7 +65,7 @@ class StopAtTokenFSM(FSM):
|
|||||||
else:
|
else:
|
||||||
return [self.stop_token_id]
|
return [self.stop_token_id]
|
||||||
|
|
||||||
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
|
def next_state(self, state: FSMState, token_id: int) -> FSMState:
|
||||||
"""Update the state of the FSM.
|
"""Update the state of the FSM.
|
||||||
|
|
||||||
The FSM stays in the initial state `0` unless the specified stop token
|
The FSM stays in the initial state `0` unless the specified stop token
|
||||||
@@ -84,29 +78,24 @@ class StopAtTokenFSM(FSM):
|
|||||||
The current state of the FSM.
|
The current state of the FSM.
|
||||||
token_id
|
token_id
|
||||||
The id of the token that was just generated.
|
The id of the token that was just generated.
|
||||||
idx
|
|
||||||
The index of the current input in the batch.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
The new state of the FSM.
|
The new state of the FSM.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if idx == 0:
|
|
||||||
self.num_tokens_generated += 1
|
|
||||||
|
|
||||||
if token_id == self.stop_token_id:
|
if token_id == self.stop_token_id:
|
||||||
return FSMState(1)
|
return FSMState(1)
|
||||||
|
|
||||||
return FSMState(0)
|
return FSMState(0)
|
||||||
|
|
||||||
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
|
def is_final_state(self, state: FSMState) -> bool:
|
||||||
"""Determine whether the current state of the FSM is a final state."""
|
"""Determine whether the current state of the FSM is a final state."""
|
||||||
return state in self.final_states
|
return state in self.final_states
|
||||||
|
|
||||||
def reset(self) -> None:
|
def copy(self) -> "StopAtTokenFSM":
|
||||||
"""Reset the FSM to its initial state. Here this only resets the token counter."""
|
"""Create a copy of the FSM."""
|
||||||
self.num_tokens_generated = 0
|
return self
|
||||||
|
|
||||||
|
|
||||||
class RegexFSM(FSM):
|
class RegexFSM(FSM):
|
||||||
@@ -117,32 +106,48 @@ class RegexFSM(FSM):
|
|||||||
regex_string: str,
|
regex_string: str,
|
||||||
tokenizer: "Tokenizer",
|
tokenizer: "Tokenizer",
|
||||||
):
|
):
|
||||||
regex_pattern = interegular.parse_pattern(regex_string)
|
@disk_cache()
|
||||||
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
|
def create_states_mapping(
|
||||||
|
regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]]
|
||||||
|
) -> Tuple[dict, set, set]:
|
||||||
|
"""Create the variables related to the mapping between states and tokens
|
||||||
|
The parameters of the function are used for caching purpose
|
||||||
|
"""
|
||||||
|
regex_pattern = interegular.parse_pattern(regex_string)
|
||||||
|
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
|
||||||
|
(
|
||||||
|
states_to_token_maps,
|
||||||
|
empty_token_ids,
|
||||||
|
) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
|
||||||
|
|
||||||
|
# We make sure that it is possible to generate strings in the language
|
||||||
|
# of the regular expression with the tokens present in the model's
|
||||||
|
# vocabulary.
|
||||||
|
if not any(
|
||||||
|
regex_fsm.finals.intersection(v.values())
|
||||||
|
for v in states_to_token_maps.values()
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"The vocabulary does not allow us to build a sequence that matches the input regex"
|
||||||
|
)
|
||||||
|
|
||||||
|
final_states = regex_fsm.finals | {
|
||||||
|
-1
|
||||||
|
} # Include the EOS token in final states
|
||||||
|
return states_to_token_maps, empty_token_ids, final_states
|
||||||
|
|
||||||
(
|
(
|
||||||
self.states_to_token_maps,
|
self.states_to_token_maps,
|
||||||
self.empty_token_ids,
|
self.empty_token_ids,
|
||||||
) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
|
self.final_states,
|
||||||
|
) = create_states_mapping(
|
||||||
# We make sure that it is possible to generate strings in the language
|
regex_string, tuple(sorted(tokenizer.vocabulary.items()))
|
||||||
# of the regular expression with the tokens present in the model's
|
)
|
||||||
# vocabulary.
|
|
||||||
if not any(
|
|
||||||
regex_fsm.finals.intersection(v.values())
|
|
||||||
for v in self.states_to_token_maps.values()
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"The vocabulary does not allow us to build a sequence that matches the input regex"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.final_states = regex_fsm.finals | {
|
|
||||||
-1
|
|
||||||
} # Include the EOS token in final states
|
|
||||||
self.num_tokens_generated = 0
|
self.num_tokens_generated = 0
|
||||||
self.vocabulary = tokenizer.vocabulary.values()
|
self.vocabulary = tokenizer.vocabulary.values()
|
||||||
self.end_token_id = tokenizer.eos_token_id
|
self.end_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
|
def allowed_token_ids(self, state: FSMState) -> List[int]:
|
||||||
"""Generate a list of allowed tokens for the next step.
|
"""Generate a list of allowed tokens for the next step.
|
||||||
|
|
||||||
The initialization of the FSM builds an index which maps FSM states to a
|
The initialization of the FSM builds an index which maps FSM states to a
|
||||||
@@ -159,8 +164,6 @@ class RegexFSM(FSM):
|
|||||||
----------
|
----------
|
||||||
state
|
state
|
||||||
The current state of the FSM.
|
The current state of the FSM.
|
||||||
idx
|
|
||||||
The index of the current input in the batch.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -174,7 +177,7 @@ class RegexFSM(FSM):
|
|||||||
else:
|
else:
|
||||||
return list(next_tokens_to_end_states.keys())
|
return list(next_tokens_to_end_states.keys())
|
||||||
|
|
||||||
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
|
def next_state(self, state: FSMState, token_id: int) -> FSMState:
|
||||||
"""Update the state of the FSM.
|
"""Update the state of the FSM.
|
||||||
|
|
||||||
We use the index to determine to which state the FSM should transition
|
We use the index to determine to which state the FSM should transition
|
||||||
@@ -186,17 +189,12 @@ class RegexFSM(FSM):
|
|||||||
The current state of the FSM.
|
The current state of the FSM.
|
||||||
token_id
|
token_id
|
||||||
The id of the token that was just generated.
|
The id of the token that was just generated.
|
||||||
idx
|
|
||||||
The index of the current input in the batch.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
The new state of the FSM.
|
The new state of the FSM.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if idx == 0:
|
|
||||||
self.num_tokens_generated += 1
|
|
||||||
|
|
||||||
if token_id == self.end_token_id:
|
if token_id == self.end_token_id:
|
||||||
return FSMState(-1)
|
return FSMState(-1)
|
||||||
|
|
||||||
@@ -207,24 +205,22 @@ class RegexFSM(FSM):
|
|||||||
|
|
||||||
return FSMState(next_state)
|
return FSMState(next_state)
|
||||||
|
|
||||||
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
|
def is_final_state(self, state: FSMState) -> bool:
|
||||||
"""Determine whether the current state of the FSM is a final state."""
|
"""Determine whether the current state of the FSM is a final state."""
|
||||||
return state in self.final_states
|
return state in self.final_states
|
||||||
|
|
||||||
def reset(self) -> None:
|
def copy(self) -> "RegexFSM":
|
||||||
"""Reset the FSM to its initial state. Here this only resets the token counter."""
|
"""Create a copy of the FSM."""
|
||||||
self.num_tokens_generated = 0
|
return self
|
||||||
|
|
||||||
|
|
||||||
class CFGFSM(FSM):
|
class CFGFSM(FSM):
|
||||||
"""FSM to generate text that is in the language of a context-free grammar."""
|
"""FSM to generate text that is in the language of a context-free grammar."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
|
||||||
self,
|
self.cfg_string = cfg_string
|
||||||
cfg_string: str,
|
self.tokenizer = tokenizer
|
||||||
tokenizer: "Tokenizer",
|
|
||||||
):
|
|
||||||
# self.parser = PartialLark(cfg_string, parser="lalr")
|
|
||||||
self.parser = Lark(
|
self.parser = Lark(
|
||||||
cfg_string,
|
cfg_string,
|
||||||
parser="lalr",
|
parser="lalr",
|
||||||
@@ -239,59 +235,52 @@ class CFGFSM(FSM):
|
|||||||
self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
|
self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
|
||||||
self.terminal_regexps["$END"] = tokenizer.eos_token
|
self.terminal_regexps["$END"] = tokenizer.eos_token
|
||||||
|
|
||||||
self.tokenizer = tokenizer
|
self.generation = ""
|
||||||
self.num_tokens_generated = 0
|
self.reset_state = False
|
||||||
self.generations: List[str] = []
|
self.allow_eos = False
|
||||||
self.regex_fsms: List[RegexFSM] = []
|
self.done = False
|
||||||
self.reset_state: List[bool] = []
|
self.regex_fsm: RegexFSM
|
||||||
self.allow_eos: List[bool] = []
|
|
||||||
self.done: List[bool] = []
|
|
||||||
|
|
||||||
def _set_next_regex_fsm(self, idx: int = 0) -> None:
|
def _set_next_regex_fsm(self) -> None:
|
||||||
"""Use the CFG incremental parser to set the next regex FSM.
|
"""Use the CFG incremental parser to set the next regex FSM.
|
||||||
|
|
||||||
Check what the CFG incremental parser proposes next.
|
Check what the CFG incremental parser proposes next:
|
||||||
If the only proposal is the EOS token,
|
- If the only proposal is the EOS token we set the state to done and
|
||||||
we set the state to done and return.
|
return.
|
||||||
If there are other proposals,
|
- If there are other proposals, we set a new regex FSM and return.
|
||||||
we set a new regex FSM and return.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
interactive = self.parser.parse_interactive(self.generations[idx])
|
interactive = self.parser.parse_interactive(self.generation)
|
||||||
interactive.exhaust_lexer()
|
interactive.exhaust_lexer()
|
||||||
options = {self.terminal_regexps[x] for x in interactive.accepts()}
|
options = {self.terminal_regexps[x] for x in interactive.accepts()}
|
||||||
|
|
||||||
if self.terminal_regexps["$END"] in options:
|
if self.terminal_regexps["$END"] in options:
|
||||||
options.remove(self.terminal_regexps["$END"])
|
options.remove(self.terminal_regexps["$END"])
|
||||||
if len(options) == 0:
|
if len(options) == 0:
|
||||||
self.done[idx] = True
|
self.done = True
|
||||||
return
|
return
|
||||||
self.allow_eos[idx] = True
|
self.allow_eos = True
|
||||||
options.add("")
|
options.add("")
|
||||||
assert len(options) > 1
|
assert len(options) > 1
|
||||||
|
|
||||||
regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
|
regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
|
||||||
args = (
|
self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
|
||||||
regex_string,
|
self.reset_state = True
|
||||||
self.tokenizer,
|
|
||||||
)
|
|
||||||
if len(self.regex_fsms) <= idx:
|
|
||||||
self.regex_fsms.append(RegexFSM(*args))
|
|
||||||
else:
|
|
||||||
self.regex_fsms[idx] = RegexFSM(*args)
|
|
||||||
self.reset_state[idx] = True
|
|
||||||
|
|
||||||
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
|
def allowed_token_ids(self, state: FSMState) -> List[int]:
|
||||||
"""Generate a list of allowed tokens for the next step.
|
"""Generate a list of allowed tokens for the next step.
|
||||||
|
|
||||||
Upon initialization, the CFG incremental parser is used to determine the first regex.
|
Upon initialization, the CFG incremental parser is used to determine the
|
||||||
|
first regex.
|
||||||
|
|
||||||
This regex is used for proposals until either:
|
This regex is used for proposals until either:
|
||||||
- the regex is exhausted, and its only remaining option is the EOS token,
|
|
||||||
in which case we always transition to the next regex
|
- The regex is exhausted, and its only remaining option is the EOS
|
||||||
- the regex can be exhausted, but the EOS token is not the only remaining option,
|
token, in which case we always transition to the next regex
|
||||||
in which case we transition to the next regex with probability P (TODO)
|
- The regex can be exhausted, but the EOS token is not the only
|
||||||
or remove the possibility of generating the EOS token and continue with the current regex
|
remaining option, in which case we transition to the next regex with
|
||||||
|
probability P (TODO) or remove the possibility of generating the EOS
|
||||||
|
token and continue with the current regex
|
||||||
|
|
||||||
The CFG incremental parser is allowed to propose the EOS token from any final state,
|
The CFG incremental parser is allowed to propose the EOS token from any final state,
|
||||||
and once it is generated, the FSM will continue to always generate the EOS token.
|
and once it is generated, the FSM will continue to always generate the EOS token.
|
||||||
@@ -300,22 +289,14 @@ class CFGFSM(FSM):
|
|||||||
----------
|
----------
|
||||||
state
|
state
|
||||||
The current state of the FSM.
|
The current state of the FSM.
|
||||||
idx
|
|
||||||
The index of the current input in the batch.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
A list that contains the tokens to mask.
|
A list that contains the tokens to mask.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if len(self.generations) <= idx:
|
if self.generation != "":
|
||||||
self.generations.append("")
|
proposal = self.regex_fsm.allowed_token_ids(state)
|
||||||
self.reset_state.append(False)
|
|
||||||
self.allow_eos.append(False)
|
|
||||||
self.done.append(False)
|
|
||||||
|
|
||||||
if len(self.regex_fsms) > idx:
|
|
||||||
proposal = self.regex_fsms[idx].allowed_token_ids(state)
|
|
||||||
if self.tokenizer.eos_token_id not in proposal:
|
if self.tokenizer.eos_token_id not in proposal:
|
||||||
return proposal
|
return proposal
|
||||||
if set(proposal) != {self.tokenizer.eos_token_id}:
|
if set(proposal) != {self.tokenizer.eos_token_id}:
|
||||||
@@ -323,23 +304,23 @@ class CFGFSM(FSM):
|
|||||||
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
|
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
|
||||||
return proposal
|
return proposal
|
||||||
|
|
||||||
self._set_next_regex_fsm(idx)
|
self._set_next_regex_fsm()
|
||||||
|
|
||||||
if self.done[idx]:
|
if self.done:
|
||||||
return [self.tokenizer.eos_token_id]
|
return [self.tokenizer.eos_token_id]
|
||||||
|
|
||||||
if self.reset_state[idx]:
|
if self.reset_state:
|
||||||
state = FSMState(0)
|
state = FSMState(0)
|
||||||
|
|
||||||
proposal = self.regex_fsms[idx].allowed_token_ids(state)
|
proposal = self.regex_fsm.allowed_token_ids(state)
|
||||||
if self.allow_eos[idx]:
|
if self.allow_eos:
|
||||||
self.allow_eos[idx] = False
|
self.allow_eos = False
|
||||||
else:
|
else:
|
||||||
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
|
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
|
||||||
assert len(proposal) > 0
|
assert len(proposal) > 0
|
||||||
return proposal
|
return proposal
|
||||||
|
|
||||||
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
|
def next_state(self, state: FSMState, token_id: int) -> FSMState:
|
||||||
"""Update the state of the FSM.
|
"""Update the state of the FSM.
|
||||||
|
|
||||||
Transitions the underlying regex FSM to its next state.
|
Transitions the underlying regex FSM to its next state.
|
||||||
@@ -352,34 +333,26 @@ class CFGFSM(FSM):
|
|||||||
The current state of the FSM.
|
The current state of the FSM.
|
||||||
token_id
|
token_id
|
||||||
The id of the token that was just generated.
|
The id of the token that was just generated.
|
||||||
idx
|
|
||||||
The index of the current input in the batch.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
The new state of the FSM.
|
The new state of the FSM.
|
||||||
"""
|
"""
|
||||||
if idx == 0:
|
|
||||||
self.num_tokens_generated += 1
|
|
||||||
if token_id == self.tokenizer.eos_token_id:
|
if token_id == self.tokenizer.eos_token_id:
|
||||||
self.done[idx] = True
|
self.done = True
|
||||||
return FSMState(-1)
|
return FSMState(-1)
|
||||||
if self.reset_state[idx]:
|
if self.reset_state:
|
||||||
self.reset_state[idx] = False
|
self.reset_state = False
|
||||||
state = FSMState(0)
|
state = FSMState(0)
|
||||||
|
|
||||||
self.generations[idx] += self.tokenizer.decode([token_id])[0]
|
self.generation += self.tokenizer.decode([token_id])[0]
|
||||||
|
|
||||||
return self.regex_fsms[idx].next_state(state, token_id, idx)
|
return self.regex_fsm.next_state(state, token_id)
|
||||||
|
|
||||||
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
|
def is_final_state(self, state: FSMState) -> bool:
|
||||||
"""Return whether the current state of the FSM is a final state."""
|
"""Return whether the current state of the FSM is a final state."""
|
||||||
return self.done[idx]
|
return self.done
|
||||||
|
|
||||||
def reset(self) -> None:
|
def copy(self) -> "CFGFSM":
|
||||||
"""Reset the FSM to its initial state, so it can be called on a fresh batch on inputs."""
|
"""Create a copy of the FSM."""
|
||||||
self.num_tokens_generated = 0
|
return CFGFSM(self.cfg_string, self.tokenizer)
|
||||||
self.generations = []
|
|
||||||
self.regex_fsms = []
|
|
||||||
self.reset_state = []
|
|
||||||
self.done = []
|
|
||||||
|
|||||||
@@ -1,41 +1,17 @@
|
|||||||
import threading
|
|
||||||
|
|
||||||
from sglang.srt.constrained.fsm import RegexFSM
|
from sglang.srt.constrained.fsm import RegexFSM
|
||||||
from sglang.srt.constrained.tokenizer import TransformerTokenizer
|
from sglang.srt.constrained.tokenizer import TransformerTokenizer
|
||||||
|
|
||||||
|
|
||||||
def get_fsm(regex, tokenizer, fsm_cache_entry):
|
|
||||||
outlines_tokenizer = TransformerTokenizer(tokenizer)
|
|
||||||
fsm = RegexFSM(regex, outlines_tokenizer)
|
|
||||||
fsm_cache_entry.fsm = fsm
|
|
||||||
fsm_cache_entry.event.set()
|
|
||||||
|
|
||||||
|
|
||||||
class FSMCacheEntry:
|
|
||||||
def __init__(self):
|
|
||||||
self.fsm = None
|
|
||||||
self.event = threading.Event()
|
|
||||||
|
|
||||||
|
|
||||||
class FSMCache:
|
class FSMCache:
|
||||||
def __init__(self, tokenizer):
|
def __init__(self, tokenizer_path, tokenizer_args_dict):
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
self.tokenizer = tokenizer
|
self.outlines_tokenizer = TransformerTokenizer(
|
||||||
|
tokenizer_path, **tokenizer_args_dict
|
||||||
|
)
|
||||||
|
|
||||||
def init_fsm_in_background(self, regex):
|
def init_fsm(self, regex):
|
||||||
if regex not in self.cache:
|
if regex not in self.cache:
|
||||||
self.cache[regex] = FSMCacheEntry()
|
fsm = RegexFSM(regex, self.outlines_tokenizer)
|
||||||
threading.Thread(
|
self.cache[regex] = fsm
|
||||||
target=get_fsm,
|
|
||||||
args=(
|
|
||||||
regex,
|
|
||||||
self.tokenizer,
|
|
||||||
self.cache[regex],
|
|
||||||
),
|
|
||||||
).start()
|
|
||||||
|
|
||||||
def get_fsm(self, regex):
|
return self.cache[regex]
|
||||||
self.init_fsm_in_background(regex)
|
|
||||||
entry = self.cache[regex]
|
|
||||||
entry.event.wait()
|
|
||||||
return entry.fsm
|
|
||||||
|
|||||||
@@ -2,17 +2,7 @@
|
|||||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
|
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
|
||||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
|
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import (
|
from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union
|
||||||
TYPE_CHECKING,
|
|
||||||
Dict,
|
|
||||||
Hashable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Protocol,
|
|
||||||
Set,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -50,15 +40,6 @@ class Tokenizer(Protocol, Hashable):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
|
||||||
|
|
||||||
__all__ = ["transformers"]
|
|
||||||
|
|
||||||
|
|
||||||
KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
|
|
||||||
|
|
||||||
|
|
||||||
def get_llama_tokenizer_types():
|
def get_llama_tokenizer_types():
|
||||||
"""Get all the Llama tokenizer types/classes that need work-arounds.
|
"""Get all the Llama tokenizer types/classes that need work-arounds.
|
||||||
|
|
||||||
@@ -101,76 +82,17 @@ def get_llama_tokenizer_types():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Transformer:
|
|
||||||
"""Represents a `transformers` model."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: "PreTrainedModel",
|
|
||||||
tokenizer: "PreTrainedTokenizer",
|
|
||||||
):
|
|
||||||
self.device = model.device
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
|
|
||||||
@torch.inference_mode
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor,
|
|
||||||
attention_mask: torch.LongTensor,
|
|
||||||
past_key_values: Optional[Tuple] = None,
|
|
||||||
) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
|
|
||||||
"""Compute a forward pass through the transformer model.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
input_ids
|
|
||||||
The input token ids. Must be one or two dimensional.
|
|
||||||
attention_mask
|
|
||||||
The attention mask. Must be one or two dimensional.
|
|
||||||
past_key_values
|
|
||||||
A tuple of tuples containing the cached key and value tensors for each
|
|
||||||
attention head.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
The computed logits and the new cached key and value tensors.
|
|
||||||
|
|
||||||
"""
|
|
||||||
assert 0 < input_ids.ndim < 3
|
|
||||||
|
|
||||||
if past_key_values:
|
|
||||||
input_ids = input_ids[..., -1].unsqueeze(-1)
|
|
||||||
|
|
||||||
output = self.model(
|
|
||||||
input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
return_dict=True,
|
|
||||||
output_attentions=False,
|
|
||||||
output_hidden_states=False,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
)
|
|
||||||
|
|
||||||
return output.logits, output.past_key_values
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor,
|
|
||||||
attention_mask: torch.LongTensor,
|
|
||||||
past_key_values: Optional[Tuple] = None,
|
|
||||||
) -> torch.FloatTensor:
|
|
||||||
logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
|
|
||||||
next_token_logits = logits[..., -1, :]
|
|
||||||
|
|
||||||
return next_token_logits, kv_cache
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerTokenizer(Tokenizer):
|
class TransformerTokenizer(Tokenizer):
|
||||||
"""Represents a tokenizer for models in the `transformers` library."""
|
"""Represents a tokenizer for models in the `transformers` library."""
|
||||||
|
|
||||||
def __init__(self, tokenizer):
|
def __init__(self, model_name: str, **kwargs):
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
kwargs.setdefault("padding_side", "left")
|
||||||
|
self.model_name = model_name
|
||||||
# TODO: Do something to make this hashable?
|
# TODO: Do something to make this hashable?
|
||||||
self.tokenizer = tokenizer
|
self.kwargs = kwargs
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
|
||||||
self.eos_token_id = self.tokenizer.eos_token_id
|
self.eos_token_id = self.tokenizer.eos_token_id
|
||||||
self.eos_token = self.tokenizer.eos_token
|
self.eos_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
@@ -212,55 +134,10 @@ class TransformerTokenizer(Tokenizer):
|
|||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, type(self)):
|
if isinstance(other, type(self)):
|
||||||
return False
|
return other.model_name == self.model_name and other.kwargs == self.kwargs
|
||||||
# TODO(lsyin): the lru_cache for the TransoformerTokenizer is useless ?
|
|
||||||
# return other.model_name == self.model_name and other.kwargs == self.kwargs
|
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
from datasets.fingerprint import Hasher
|
from datasets.fingerprint import Hasher
|
||||||
|
|
||||||
return hash(Hasher.hash(self.tokenizer))
|
return hash(Hasher.hash(self.tokenizer))
|
||||||
|
|
||||||
|
|
||||||
def transformers(
|
|
||||||
model_name: str,
|
|
||||||
device: Optional[str] = None,
|
|
||||||
model_kwargs: dict = {},
|
|
||||||
tokenizer_kwargs: dict = {},
|
|
||||||
):
|
|
||||||
"""Instantiate a model from the `transformers` library and its tokenizer.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
model_name
|
|
||||||
The name of the model as listed on Hugging Face's model page.
|
|
||||||
device
|
|
||||||
The device(s) on which the model should be loaded. This overrides
|
|
||||||
the `device_map` entry in `model_kwargs` when provided.
|
|
||||||
model_kwargs
|
|
||||||
A dictionary that contains the keyword arguments to pass to the
|
|
||||||
`from_pretrained` method when loading the model.
|
|
||||||
tokenizer_kwargs
|
|
||||||
A dictionary that contains the keyword arguments to pass to the
|
|
||||||
`from_pretrained` method when loading the tokenizer.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
A `TransformersModel` model instance.
|
|
||||||
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"The `transformers` library needs to be installed in order to use `transformers` models."
|
|
||||||
)
|
|
||||||
|
|
||||||
if device is not None:
|
|
||||||
model_kwargs["device_map"] = device
|
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
|
|
||||||
tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs)
|
|
||||||
|
|
||||||
return Transformer(model, tokenizer)
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class Req:
|
|||||||
|
|
||||||
# for constrained decoding
|
# for constrained decoding
|
||||||
self.regex_fsm = None
|
self.regex_fsm = None
|
||||||
self.regex_fsm_state = None
|
self.regex_fsm_state = 0
|
||||||
|
|
||||||
def max_new_tokens(self):
|
def max_new_tokens(self):
|
||||||
return self.sampling_params.max_new_tokens
|
return self.sampling_params.max_new_tokens
|
||||||
|
|||||||
@@ -111,7 +111,13 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
self.stream_interval = server_args.stream_interval
|
self.stream_interval = server_args.stream_interval
|
||||||
|
|
||||||
# Init the FSM cache for constrained generation
|
# Init the FSM cache for constrained generation
|
||||||
self.regex_fsm_cache = FSMCache(self.tokenizer)
|
self.regex_fsm_cache = FSMCache(
|
||||||
|
server_args.tokenizer_path,
|
||||||
|
{
|
||||||
|
"tokenizer_mode": server_args.tokenizer_mode,
|
||||||
|
"trust_remote_code": server_args.trust_remote_code,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Init new token estimation
|
# Init new token estimation
|
||||||
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
|
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
|
||||||
@@ -213,6 +219,10 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
req.stream = recv_req.stream
|
req.stream = recv_req.stream
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
|
||||||
|
# Init regex fsm
|
||||||
|
if req.sampling_params.regex is not None:
|
||||||
|
req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
|
||||||
|
|
||||||
# Truncate long prompts
|
# Truncate long prompts
|
||||||
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
|
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
|
||||||
req.sampling_params.max_new_tokens = min(
|
req.sampling_params.max_new_tokens = min(
|
||||||
@@ -322,11 +332,10 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
self.model_config.vocab_size, self.int_token_logit_bias
|
self.model_config.vocab_size, self.int_token_logit_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
# init the regex fsm before first sampling
|
# Reset regex fsm state before first sampling due to retractions
|
||||||
for req in batch.reqs:
|
for req in batch.reqs:
|
||||||
if req.sampling_params.regex is not None:
|
if req.sampling_params.regex is not None:
|
||||||
req.regex_fsm_state = 0
|
req.regex_fsm_state = 0
|
||||||
req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
|
|
||||||
|
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
# Forward
|
# Forward
|
||||||
|
|||||||
Reference in New Issue
Block a user