Disk FSM cache and adjust code. (#63)
pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
-       "interegular", "lark", "numba", "pydantic"]
+       "interegular", "lark", "numba", "pydantic", "diskcache", "cloudpickle"]
 openai = ["openai>=1.0"]
 anthropic = ["anthropic"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
python/sglang/srt/constrained/disk_cache.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+# Adapted from:
+# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/caching.py
+import asyncio
+import hashlib
+import os
+from typing import Callable, Optional
+
+import cloudpickle
+from diskcache import Cache
+
+home_dir = os.path.expanduser("~")
+cache_dir = os.environ.get("SGLANG_CACHE_DIR", f"{home_dir}/.cache/sglang")
+memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
+_caching_enabled = True
+
+
+def hash_arguments(*args, **kwargs) -> str:
+    """Create a hash out of the args and kwargs provided"""
+    result = hashlib.md5()
+    for item in list(args) + sorted(kwargs.items()):
+        result.update(cloudpickle.dumps(item))
+    return result.hexdigest()
+
+
+def disk_cache(key_function: Optional[Callable] = None):
+    def decorator(cached_function: Callable):
+        def wrapper(*args, **kwargs):
+            if not _caching_enabled:
+                return cached_function(*args, **kwargs)
+            if key_function:
+                key_args = key_function(*args, **kwargs)
+                cache_key = hash_arguments(*key_args)
+            else:
+                cache_key = hash_arguments(*args, **kwargs)
+            if cache_key in memory:
+                return memory[cache_key]
+            result = cached_function(*args, **kwargs)
+            memory[cache_key] = result
+            return result
+
+        async def async_wrapper(*args, **kwargs):
+            if not _caching_enabled:
+                return await cached_function(*args, **kwargs)
+            if key_function:
+                key_args = key_function(*args, **kwargs)
+                cache_key = hash_arguments(*key_args)
+            else:
+                cache_key = hash_arguments(*args, **kwargs)
+            if cache_key in memory:
+                return memory[cache_key]
+            result = await cached_function(*args, **kwargs)
+            memory[cache_key] = result
+            return result
+
+        if asyncio.iscoroutinefunction(cached_function):
+            return async_wrapper
+        else:
+            return wrapper
+
+    return decorator
+
+
+def disable_cache():
+    global _caching_enabled
+    _caching_enabled = False
+
+
+def clear_cache():
+    global memory
+    memory.clear()
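A quick usage sketch of the decorator above (the function and its arguments are illustrative, not part of this commit). Results are keyed by an MD5 over the cloudpickled arguments and persisted under `~/.cache/sglang` (or `SGLANG_CACHE_DIR`), so a repeated call with equal arguments is answered from disk:

```python
from sglang.srt.constrained.disk_cache import disk_cache

@disk_cache()
def slow_build(regex_string: str, vocabulary: tuple) -> dict:
    # stands in for expensive, deterministic work (e.g. FSM construction)
    return {"regex": regex_string, "vocab_size": len(vocabulary)}

first = slow_build(r"\d+", (("0", 0), ("1", 1)))   # computed, then written to disk
second = slow_build(r"\d+", (("0", 0), ("1", 1)))  # served from the diskcache entry
assert first == second
```

Note the `eviction_policy="none"` and `cull_limit=0` settings: entries are kept until `clear_cache()` is called, which fits artifacts that are expensive to recompute and cheap to store.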
python/sglang/srt/constrained/fsm.py
@@ -1,9 +1,10 @@
 # Adapted from:
-# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
-from typing import List, NewType, Protocol
+# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/fsm/fsm.py
+from typing import List, NewType, Protocol, Tuple

 import interegular
 from lark import Lark
+from sglang.srt.constrained.disk_cache import disk_cache

 # from outlines.fsm.parsing import PartialLark
 from sglang.srt.constrained.regex import (
@@ -16,16 +17,16 @@ FSMState = NewType("FSMState", int)


 class FSM(Protocol):
-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         ...

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         ...

-    def reset(self) -> None:
+    def copy(self) -> "FSM":
         ...
@@ -38,17 +39,12 @@ class StopAtTokenFSM(FSM):

     """

-    def __init__(
-        self,
-        tokenizer: "Tokenizer",
-        stop_token_id: int,
-    ):
+    def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
         self.stop_token_id = stop_token_id
         self.num_tokens_generated = 0
         self.vocabulary = tokenizer.vocabulary.values()
         self.final_states = {1}

-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

         When in the initial state we allow every token to be generated.
@@ -58,8 +54,6 @@ class StopAtTokenFSM(FSM):
         ----------
         state
             The current state of the FSM.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
@@ -71,7 +65,7 @@ class StopAtTokenFSM(FSM):
         else:
             return [self.stop_token_id]

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

         The FSM stays in the initial state `0` unless the specified stop token
@@ -84,29 +78,24 @@ class StopAtTokenFSM(FSM):
             The current state of the FSM.
         token_id
             The id of the token that was just generated.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         The new state of the FSM.

         """
-        if idx == 0:
-            self.num_tokens_generated += 1
-
         if token_id == self.stop_token_id:
             return FSMState(1)

         return FSMState(0)

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state in self.final_states

-    def reset(self) -> None:
-        """Reset the FSM to its initial state. Here this only resets the token counter."""
-        self.num_tokens_generated = 0
+    def copy(self) -> "StopAtTokenFSM":
+        """Create a copy of the FSM."""
+        return self


 class RegexFSM(FSM):
@@ -117,32 +106,48 @@ class RegexFSM(FSM):
         regex_string: str,
         tokenizer: "Tokenizer",
     ):
-        regex_pattern = interegular.parse_pattern(regex_string)
-        regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+        @disk_cache()
+        def create_states_mapping(
+            regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]]
+        ) -> Tuple[dict, set, set]:
+            """Create the variables related to the mapping between states and tokens
+            The parameters of the function are used for caching purpose
+            """
+            regex_pattern = interegular.parse_pattern(regex_string)
+            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
+            (
+                states_to_token_maps,
+                empty_token_ids,
+            ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
+
+            # We make sure that it is possible to generate strings in the language
+            # of the regular expression with the tokens present in the model's
+            # vocabulary.
+            if not any(
+                regex_fsm.finals.intersection(v.values())
+                for v in states_to_token_maps.values()
+            ):
+                raise ValueError(
+                    "The vocabulary does not allow us to build a sequence that matches the input regex"
+                )
+
+            final_states = regex_fsm.finals | {
+                -1
+            }  # Include the EOS token in final states
+            return states_to_token_maps, empty_token_ids, final_states
+
         (
             self.states_to_token_maps,
             self.empty_token_ids,
-        ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
-
-        # We make sure that it is possible to generate strings in the language
-        # of the regular expression with the tokens present in the model's
-        # vocabulary.
-        if not any(
-            regex_fsm.finals.intersection(v.values())
-            for v in self.states_to_token_maps.values()
-        ):
-            raise ValueError(
-                "The vocabulary does not allow us to build a sequence that matches the input regex"
-            )
-
-        self.final_states = regex_fsm.finals | {
-            -1
-        }  # Include the EOS token in final states
+            self.final_states,
+        ) = create_states_mapping(
+            regex_string, tuple(sorted(tokenizer.vocabulary.items()))
+        )
         self.num_tokens_generated = 0
         self.vocabulary = tokenizer.vocabulary.values()
         self.end_token_id = tokenizer.eos_token_id

-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Generate a list of allowed tokens for the next step.

         The initialization of the FSM builds an index which maps FSM states to a
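Note how the call site flattens the tokenizer vocabulary `dict` into `tuple(sorted(...))` before handing it to the cached function: that second parameter exists purely to key the disk cache (the body closes over `tokenizer` itself). A small illustration, not from the commit, of why the normalization gives a stable key:

```python
# Two equal dicts can have different insertion orders, and pickling preserves
# insertion order, so hashing the raw dicts could miss the cache.
# Sorting the items into a tuple removes that ambiguity.
vocab_a = {"a": 1, "b": 2}
vocab_b = {"b": 2, "a": 1}
key_a = tuple(sorted(vocab_a.items()))
key_b = tuple(sorted(vocab_b.items()))
assert key_a == key_b == (("a", 1), ("b", 2))
```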
@@ -159,8 +164,6 @@ class RegexFSM(FSM):
         ----------
         state
             The current state of the FSM.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
@@ -174,7 +177,7 @@ class RegexFSM(FSM):
         else:
             return list(next_tokens_to_end_states.keys())

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

         We use the index to determine to which state the FSM should transition
@@ -186,17 +189,12 @@ class RegexFSM(FSM):
             The current state of the FSM.
         token_id
             The id of the token that was just generated.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         The new state of the FSM.

         """
-        if idx == 0:
-            self.num_tokens_generated += 1
-
         if token_id == self.end_token_id:
             return FSMState(-1)

@@ -207,24 +205,22 @@ class RegexFSM(FSM):

         return FSMState(next_state)

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         """Determine whether the current state of the FSM is a final state."""
         return state in self.final_states

-    def reset(self) -> None:
-        """Reset the FSM to its initial state. Here this only resets the token counter."""
-        self.num_tokens_generated = 0
+    def copy(self) -> "RegexFSM":
+        """Create a copy of the FSM."""
+        return self


 class CFGFSM(FSM):
     """FSM to generate text that is in the language of a context-free grammar."""

-    def __init__(
-        self,
-        cfg_string: str,
-        tokenizer: "Tokenizer",
-    ):
-        # self.parser = PartialLark(cfg_string, parser="lalr")
+    def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
+        self.cfg_string = cfg_string
+        self.tokenizer = tokenizer
+
         self.parser = Lark(
             cfg_string,
             parser="lalr",
@@ -239,59 +235,52 @@ class CFGFSM(FSM):
             self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
         self.terminal_regexps["$END"] = tokenizer.eos_token

-        self.tokenizer = tokenizer
-        self.num_tokens_generated = 0
-        self.generations: List[str] = []
-        self.regex_fsms: List[RegexFSM] = []
-        self.reset_state: List[bool] = []
-        self.allow_eos: List[bool] = []
-        self.done: List[bool] = []
+        self.generation = ""
+        self.reset_state = False
+        self.allow_eos = False
+        self.done = False
+        self.regex_fsm: RegexFSM

-    def _set_next_regex_fsm(self, idx: int = 0) -> None:
+    def _set_next_regex_fsm(self) -> None:
         """Use the CFG incremental parser to set the next regex FSM.

-        Check what the CFG incremental parser proposes next.
-        If the only proposal is the EOS token,
-        we set the state to done and return.
-        If there are other proposals,
-        we set a new regex FSM and return.
+        Check what the CFG incremental parser proposes next:
+        - If the only proposal is the EOS token we set the state to done and
+          return.
+        - If there are other proposals, we set a new regex FSM and return.

         """
-        interactive = self.parser.parse_interactive(self.generations[idx])
+        interactive = self.parser.parse_interactive(self.generation)
         interactive.exhaust_lexer()
         options = {self.terminal_regexps[x] for x in interactive.accepts()}

         if self.terminal_regexps["$END"] in options:
             options.remove(self.terminal_regexps["$END"])
             if len(options) == 0:
-                self.done[idx] = True
+                self.done = True
                 return
-            self.allow_eos[idx] = True
+            self.allow_eos = True
             options.add("")
             assert len(options) > 1

         regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
-        args = (
-            regex_string,
-            self.tokenizer,
-        )
-        if len(self.regex_fsms) <= idx:
-            self.regex_fsms.append(RegexFSM(*args))
-        else:
-            self.regex_fsms[idx] = RegexFSM(*args)
-        self.reset_state[idx] = True
+        self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
+        self.reset_state = True

-    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
+    def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

-        Upon initialization, the CFG incremental parser is used to determine the first regex.
+        Upon initialization, the CFG incremental parser is used to determine the
+        first regex.

         This regex is used for proposals until either:
-        - the regex is exhausted, and its only remaining option is the EOS token,
-          in which case we always transition to the next regex
-        - the regex can be exhausted, but the EOS token is not the only remaining option,
-          in which case we transition to the next regex with probability P (TODO)
-          or remove the possibility of generating the EOS token and continue with the current regex
+        - The regex is exhausted, and its only remaining option is the EOS
+          token, in which case we always transition to the next regex
+        - The regex can be exhausted, but the EOS token is not the only
+          remaining option, in which case we transition to the next regex with
+          probability P (TODO) or remove the possibility of generating the EOS
+          token and continue with the current regex

         The CFG incremental parser is allowed to propose the EOS token from any final state,
         and once it is generated, the FSM will continue to always generate the EOS token.
@@ -300,22 +289,14 @@ class CFGFSM(FSM):
         ----------
         state
             The current state of the FSM.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         A list that contains the tokens to mask.

         """
-        if len(self.generations) <= idx:
-            self.generations.append("")
-            self.reset_state.append(False)
-            self.allow_eos.append(False)
-            self.done.append(False)
-
-        if len(self.regex_fsms) > idx:
-            proposal = self.regex_fsms[idx].allowed_token_ids(state)
+        if self.generation != "":
+            proposal = self.regex_fsm.allowed_token_ids(state)
             if self.tokenizer.eos_token_id not in proposal:
                 return proposal
             if set(proposal) != {self.tokenizer.eos_token_id}:
@@ -323,23 +304,23 @@
                 proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
                 return proposal

-        self._set_next_regex_fsm(idx)
+        self._set_next_regex_fsm()

-        if self.done[idx]:
+        if self.done:
             return [self.tokenizer.eos_token_id]

-        if self.reset_state[idx]:
+        if self.reset_state:
             state = FSMState(0)

-        proposal = self.regex_fsms[idx].allowed_token_ids(state)
-        if self.allow_eos[idx]:
-            self.allow_eos[idx] = False
+        proposal = self.regex_fsm.allowed_token_ids(state)
+        if self.allow_eos:
+            self.allow_eos = False
         else:
             proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
         assert len(proposal) > 0
         return proposal

-    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
+    def next_state(self, state: FSMState, token_id: int) -> FSMState:
         """Update the state of the FSM.

         Transitions the underlying regex FSM to its next state.
@@ -352,34 +333,26 @@
             The current state of the FSM.
         token_id
             The id of the token that was just generated.
-        idx
-            The index of the current input in the batch.

         Returns
         -------
         The new state of the FSM.
         """
-        if idx == 0:
-            self.num_tokens_generated += 1
         if token_id == self.tokenizer.eos_token_id:
-            self.done[idx] = True
+            self.done = True
             return FSMState(-1)
-        if self.reset_state[idx]:
-            self.reset_state[idx] = False
+        if self.reset_state:
+            self.reset_state = False
             state = FSMState(0)

-        self.generations[idx] += self.tokenizer.decode([token_id])[0]
+        self.generation += self.tokenizer.decode([token_id])[0]

-        return self.regex_fsms[idx].next_state(state, token_id, idx)
+        return self.regex_fsm.next_state(state, token_id)

-    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
+    def is_final_state(self, state: FSMState) -> bool:
         """Return whether the current state of the FSM is a final state."""
-        return self.done[idx]
+        return self.done

-    def reset(self) -> None:
-        """Reset the FSM to its initial state, so it can be called on a fresh batch on inputs."""
-        self.num_tokens_generated = 0
-        self.generations = []
-        self.regex_fsms = []
-        self.reset_state = []
-        self.done = []
+    def copy(self) -> "CFGFSM":
+        """Create a copy of the FSM."""
+        return CFGFSM(self.cfg_string, self.tokenizer)
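With the batched `idx` parameter and `reset()` gone, each request owns one FSM object plus a plain integer state. A minimal decode-loop sketch against the new protocol (the `sample` callback is a stand-in, not the server's real sampler):

```python
from sglang.srt.constrained.fsm import FSM, FSMState

def constrained_decode(fsm: FSM, sample) -> list:
    """Walk an FSM from its start state; `sample` picks one id from a list."""
    state = FSMState(0)  # every FSM in this file starts at state 0
    output = []
    while not fsm.is_final_state(state):
        allowed = fsm.allowed_token_ids(state)  # token mask for this step
        token_id = sample(allowed)              # e.g. lambda ids: ids[0]
        state = fsm.next_state(state, token_id)
        output.append(token_id)
    return output
```

`copy()` replaces `reset()`: the stateless FSMs (`StopAtTokenFSM`, `RegexFSM`) return themselves, while `CFGFSM`, which accumulates `self.generation`, returns a fresh instance.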
python/sglang/srt/constrained/fsm_cache.py
@@ -1,41 +1,17 @@
-import threading
-
 from sglang.srt.constrained.fsm import RegexFSM
 from sglang.srt.constrained.tokenizer import TransformerTokenizer


-def get_fsm(regex, tokenizer, fsm_cache_entry):
-    outlines_tokenizer = TransformerTokenizer(tokenizer)
-    fsm = RegexFSM(regex, outlines_tokenizer)
-    fsm_cache_entry.fsm = fsm
-    fsm_cache_entry.event.set()
-
-
-class FSMCacheEntry:
-    def __init__(self):
-        self.fsm = None
-        self.event = threading.Event()
-
-
 class FSMCache:
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer_path, tokenizer_args_dict):
         self.cache = {}
-        self.tokenizer = tokenizer
+        self.outlines_tokenizer = TransformerTokenizer(
+            tokenizer_path, **tokenizer_args_dict
+        )

-    def init_fsm_in_background(self, regex):
+    def init_fsm(self, regex):
         if regex not in self.cache:
-            self.cache[regex] = FSMCacheEntry()
-            threading.Thread(
-                target=get_fsm,
-                args=(
-                    regex,
-                    self.tokenizer,
-                    self.cache[regex],
-                ),
-            ).start()
+            fsm = RegexFSM(regex, self.outlines_tokenizer)
+            self.cache[regex] = fsm

     def get_fsm(self, regex):
-        self.init_fsm_in_background(regex)
-        entry = self.cache[regex]
-        entry.event.wait()
-        return entry.fsm
+        return self.cache[regex]
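A hedged usage sketch of the simplified cache (the tokenizer path and args below are placeholders): `init_fsm` pays the construction cost once per distinct regex, `get_fsm` is then a plain dictionary lookup, and the `@disk_cache()` inside `RegexFSM.__init__` persists the expensive state mapping across process restarts:

```python
cache = FSMCache(
    "meta-llama/Llama-2-7b-hf",    # placeholder tokenizer path
    {"trust_remote_code": False},  # forwarded to AutoTokenizer.from_pretrained
)
cache.init_fsm(r"(yes|no)")        # builds the RegexFSM on first sight
fsm = cache.get_fsm(r"(yes|no)")   # later calls: dict hit, no rebuild
```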
python/sglang/srt/constrained/tokenizer.py
@@ -2,17 +2,7 @@
 # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
 # https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
 from abc import abstractmethod
-from typing import (
-    TYPE_CHECKING,
-    Dict,
-    Hashable,
-    List,
-    Optional,
-    Protocol,
-    Set,
-    Tuple,
-    Union,
-)
+from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union

 import numpy as np
 import torch
@@ -50,15 +40,6 @@ class Tokenizer(Protocol, Hashable):
         ...


-if TYPE_CHECKING:
-    from transformers import PreTrainedModel, PreTrainedTokenizer
-
-__all__ = ["transformers"]
-
-
-KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
-
-
 def get_llama_tokenizer_types():
     """Get all the Llama tokenizer types/classes that need work-arounds.
@@ -101,76 +82,17 @@
     )


-class Transformer:
-    """Represents a `transformers` model."""
-
-    def __init__(
-        self,
-        model: "PreTrainedModel",
-        tokenizer: "PreTrainedTokenizer",
-    ):
-        self.device = model.device
-        self.model = model
-        self.tokenizer = tokenizer
-
-    @torch.inference_mode
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        past_key_values: Optional[Tuple] = None,
-    ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
-        """Compute a forward pass through the transformer model.
-
-        Parameters
-        ----------
-        input_ids
-            The input token ids. Must be one or two dimensional.
-        attention_mask
-            The attention mask. Must be one or two dimensional.
-        past_key_values
-            A tuple of tuples containing the cached key and value tensors for each
-            attention head.
-
-        Returns
-        -------
-        The computed logits and the new cached key and value tensors.
-
-        """
-        assert 0 < input_ids.ndim < 3
-
-        if past_key_values:
-            input_ids = input_ids[..., -1].unsqueeze(-1)
-
-        output = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            return_dict=True,
-            output_attentions=False,
-            output_hidden_states=False,
-            past_key_values=past_key_values,
-        )
-
-        return output.logits, output.past_key_values
-
-    def __call__(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
-        past_key_values: Optional[Tuple] = None,
-    ) -> torch.FloatTensor:
-        logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
-        next_token_logits = logits[..., -1, :]
-
-        return next_token_logits, kv_cache
-
-
 class TransformerTokenizer(Tokenizer):
     """Represents a tokenizer for models in the `transformers` library."""

-    def __init__(self, tokenizer):
-        # TODO: Do something to make this hashable?
-        self.tokenizer = tokenizer
+    def __init__(self, model_name: str, **kwargs):
+        from transformers import AutoTokenizer
+
+        kwargs.setdefault("padding_side", "left")
+        self.model_name = model_name
+        self.kwargs = kwargs
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
         self.eos_token_id = self.tokenizer.eos_token_id
         self.eos_token = self.tokenizer.eos_token
@@ -212,55 +134,10 @@ class TransformerTokenizer(Tokenizer):

     def __eq__(self, other):
         if isinstance(other, type(self)):
-            return False
-            # TODO(lsyin): the lru_cache for the TransoformerTokenizer is useless ?
-            # return other.model_name == self.model_name and other.kwargs == self.kwargs
+            return other.model_name == self.model_name and other.kwargs == self.kwargs
         return NotImplemented

     def __hash__(self):
         from datasets.fingerprint import Hasher

         return hash(Hasher.hash(self.tokenizer))
-
-
-def transformers(
-    model_name: str,
-    device: Optional[str] = None,
-    model_kwargs: dict = {},
-    tokenizer_kwargs: dict = {},
-):
-    """Instantiate a model from the `transformers` library and its tokenizer.
-
-    Parameters
-    ----------
-    model_name
-        The name of the model as listed on Hugging Face's model page.
-    device
-        The device(s) on which the model should be loaded. This overrides
-        the `device_map` entry in `model_kwargs` when provided.
-    model_kwargs
-        A dictionary that contains the keyword arguments to pass to the
-        `from_pretrained` method when loading the model.
-    tokenizer_kwargs
-        A dictionary that contains the keyword arguments to pass to the
-        `from_pretrained` method when loading the tokenizer.
-
-    Returns
-    -------
-    A `TransformersModel` model instance.
-
-    """
-    try:
-        from transformers import AutoModelForCausalLM
-    except ImportError:
-        raise ImportError(
-            "The `transformers` library needs to be installed in order to use `transformers` models."
-        )
-
-    if device is not None:
-        model_kwargs["device_map"] = device
-
-    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
-    tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs)
-
-    return Transformer(model, tokenizer)
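Because `model_name` and `kwargs` are now stored on the instance, two tokenizers built from the same configuration compare equal again, which is what makes caching keyed on the tokenizer meaningful. A rough sketch (the model name is illustrative, and it must be downloadable for this to actually run):

```python
t1 = TransformerTokenizer("gpt2")
t2 = TransformerTokenizer("gpt2")
assert t1 == t2  # same model_name and same (defaulted) kwargs
```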
@@ -45,7 +45,7 @@ class Req:

         # for constrained decoding
         self.regex_fsm = None
-        self.regex_fsm_state = None
+        self.regex_fsm_state = 0

     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
@@ -111,7 +111,13 @@ class ModelRpcServer(rpyc.Service):
         self.stream_interval = server_args.stream_interval

         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(self.tokenizer)
+        self.regex_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+        )

         # Init new token estimation
         self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
@@ -213,6 +219,10 @@ class ModelRpcServer(rpyc.Service):
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer

+        # Init regex fsm
+        if req.sampling_params.regex is not None:
+            req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
+
         # Truncate long prompts
         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
@@ -322,11 +332,10 @@
             self.model_config.vocab_size, self.int_token_logit_bias
         )

-        # init the regex fsm before first sampling
+        # Reset regex fsm state before first sampling due to retractions
         for req in batch.reqs:
             if req.sampling_params.regex is not None:
                 req.regex_fsm_state = 0
-                req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)

         if batch.extend_num_tokens != 0:
             # Forward
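Putting the server-side pieces together, the new request lifecycle is roughly (a paraphrase of the hunks above, not verbatim server code):

```python
# On request arrival: compile, or disk-load, the FSM once per distinct regex.
if req.sampling_params.regex is not None:
    req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)

# Before the first sampling of a (possibly retracted and rescheduled) batch:
# only the per-request walk position is reset; the compiled FSM is reused.
for req in batch.reqs:
    if req.sampling_params.regex is not None:
        req.regex_fsm_state = 0
```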