release initial code
Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu> Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com> Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
385
python/sglang/srt/constrained/fsm.py
Normal file
385
python/sglang/srt/constrained/fsm.py
Normal file
@@ -0,0 +1,385 @@
|
||||
# Adapted from:
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
|
||||
from typing import List, NewType, Protocol
|
||||
|
||||
import interegular
|
||||
from lark import Lark
|
||||
|
||||
# from outlines.fsm.parsing import PartialLark
|
||||
from sglang.srt.constrained.regex import (
|
||||
create_fsm_index_tokenizer,
|
||||
make_deterministic_fsm,
|
||||
)
|
||||
from sglang.srt.constrained.tokenizer import Tokenizer
|
||||
|
||||
# Distinct static type for the integer labels used as FSM states in this module.
FSMState = NewType("FSMState", int)
|
||||
|
||||
|
||||
class FSM(Protocol):
    """Structural interface implemented by the token-level state machines below."""

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        """Return the token ids that may be generated from `state`.

        `idx` is the index of the current input in the batch.
        """

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        """Return the state reached from `state` once `token_id` was generated."""

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        """Return whether `state` is a final state."""

    def reset(self) -> None:
        """Reset the machine's internal bookkeeping."""
|
||||
|
||||
|
||||
class StopAtTokenFSM(FSM):
    """Two-state FSM that lets any token through until a stop token shows up.

    State `0` is the free-running state; state `1` is entered once
    `stop_token_id` (usually the model's EOS token) has been generated and is
    the only final state.

    """

    def __init__(
        self,
        tokenizer: "Tokenizer",
        stop_token_id: int,
    ):
        self.final_states = {1}
        self.vocabulary = tokenizer.vocabulary.values()
        self.num_tokens_generated = 0
        self.stop_token_id = stop_token_id

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        """List the token ids that may be generated next.

        Parameters
        ----------
        state
            The current state of the FSM.
        idx
            The index of the current input in the batch (unused here).

        Returns
        -------
        Every token id of the vocabulary while in state `0`, otherwise a list
        holding only `stop_token_id`.

        """
        if state != 0:
            return [self.stop_token_id]
        return list(self.vocabulary)

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        """Compute the state that follows the generation of `token_id`.

        Parameters
        ----------
        state
            The current state of the FSM (not consulted; only the token matters).
        token_id
            The id of the token that was just generated.
        idx
            The index of the current input in the batch. The token counter is
            only advanced for index `0`.

        Returns
        -------
        State `1` when `token_id` is the stop token, state `0` otherwise.

        """
        if idx == 0:
            self.num_tokens_generated += 1

        reached_stop = token_id == self.stop_token_id
        return FSMState(1 if reached_stop else 0)

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        """Tell whether `state` belongs to the set of final states."""
        return state in self.final_states

    def reset(self) -> None:
        """Put the token counter back to zero; the FSM holds no other mutable state."""
        self.num_tokens_generated = 0
|
||||
|
||||
|
||||
class RegexFSM(FSM):
    """FSM to generate text that is in the language of a regular expression."""

    def __init__(
        self,
        regex_string: str,
        tokenizer: "Tokenizer",
    ):
        """Compile `regex_string` and index it against the tokenizer's vocabulary.

        Raises
        ------
        ValueError
            If no token transition of the index leads to a final state of the
            regex FSM, i.e. the vocabulary cannot produce a matching string.

        """
        regex_pattern = interegular.parse_pattern(regex_string)
        regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
        (
            self.states_to_token_maps,
            self.empty_token_ids,
        ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)

        # We make sure that it is possible to generate strings in the language
        # of the regular expression with the tokens present in the model's
        # vocabulary.
        if not any(
            regex_fsm.finals.intersection(v.values())
            for v in self.states_to_token_maps.values()
        ):
            raise ValueError(
                "The vocabulary does not allow us to build a sequence that matches the input regex"
            )

        # `-1` is the extra state entered once the EOS token has been generated.
        self.final_states = regex_fsm.finals | {-1}
        self.num_tokens_generated = 0
        self.vocabulary = tokenizer.vocabulary.values()
        self.end_token_id = tokenizer.eos_token_id

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        """Generate a list of allowed tokens for the next step.

        The initialization of the FSM builds an index which maps FSM states to a
        map from authorized tokens to the state in which the FSM needs to move
        if said token is generated. Therefore the authorized tokens at the
        current state are the keys of the map stored in the index for the
        current state.

        If the current state is not contained in the index, generation cannot
        continue from it and only the EOS token is authorized.

        Parameters
        ----------
        state
            The current state of the FSM.
        idx
            The index of the current input in the batch.

        Returns
        -------
        A list that contains the ids of the allowed tokens.

        """
        next_tokens_to_end_states = self.states_to_token_maps.get(state)

        if next_tokens_to_end_states is None:
            return [self.end_token_id]
        return list(next_tokens_to_end_states.keys())

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        """Update the state of the FSM.

        We use the index to determine to which state the FSM should transition
        given the token that was just generated.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.
        idx
            The index of the current input in the batch. The token counter is
            only advanced for index `0`.

        Returns
        -------
        The new state of the FSM. `-1` is returned when the EOS token was
        generated, when `state` has no entry in the index, or when `token_id`
        is not an authorized token for `state`.

        """
        if idx == 0:
            self.num_tokens_generated += 1

        if token_id == self.end_token_id:
            return FSMState(-1)

        # States missing from the index (such as `-1`) are handled like in
        # `allowed_token_ids`: nothing but EOS may follow them, so any other
        # token lands in the terminal state instead of raising `KeyError`.
        last_token_to_end_state = self.states_to_token_maps.get(state)
        if last_token_to_end_state is None:
            return FSMState(-1)

        next_state = last_token_to_end_state.get(token_id)
        if next_state is None:
            next_state = -1

        return FSMState(next_state)

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        """Determine whether the current state of the FSM is a final state."""
        return state in self.final_states

    def reset(self) -> None:
        """Reset the FSM to its initial state. Here this only resets the token counter."""
        self.num_tokens_generated = 0
|
||||
|
||||
|
||||
class CFGFSM(FSM):
    """FSM to generate text that is in the language of a context-free grammar."""

    def __init__(
        self,
        cfg_string: str,
        tokenizer: "Tokenizer",
    ):
        """Build the LALR parser for `cfg_string` and collect its terminal regexps."""
        # NOTE: the code this was adapted from used `PartialLark` here:
        # self.parser = PartialLark(cfg_string, parser="lalr")
        self.parser = Lark(
            cfg_string,
            parser="lalr",
            lexer="contextual",
            propagate_positions=False,
            maybe_placeholders=False,
            regex=True,
        )
        self.terminal_regexps = dict()
        for terminal in self.parser.terminals:
            if terminal.pattern is not None:
                self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
        self.terminal_regexps["$END"] = tokenizer.eos_token

        self.tokenizer = tokenizer
        self.num_tokens_generated = 0
        # Per-sequence bookkeeping, indexed by the position in the batch.
        self.generations: List[str] = []
        self.regex_fsms: List[RegexFSM] = []
        self.reset_state: List[bool] = []
        self.allow_eos: List[bool] = []
        self.done: List[bool] = []

    def _set_next_regex_fsm(self, idx: int = 0) -> None:
        """Use the CFG incremental parser to set the next regex FSM.

        Check what the CFG incremental parser proposes next.
        If the only proposal is the EOS token,
        we set the state to done and return.
        If there are other proposals,
        we set a new regex FSM and return.

        """
        interactive = self.parser.parse_interactive(self.generations[idx])
        interactive.exhaust_lexer()
        options = {self.terminal_regexps[x] for x in interactive.accepts()}

        if self.terminal_regexps["$END"] in options:
            options.remove(self.terminal_regexps["$END"])
            if len(options) == 0:
                self.done[idx] = True
                return
            # The grammar may end here: let EOS through once and accept the
            # empty string as an alternative.
            self.allow_eos[idx] = True
            options.add("")
            assert len(options) > 1

        regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
        args = (
            regex_string,
            self.tokenizer,
        )
        if len(self.regex_fsms) <= idx:
            self.regex_fsms.append(RegexFSM(*args))
        else:
            self.regex_fsms[idx] = RegexFSM(*args)
        # The fresh regex FSM has to be entered from its initial state.
        self.reset_state[idx] = True

    def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
        """Generate a list of allowed tokens for the next step.

        Upon initialization, the CFG incremental parser is used to determine the first regex.

        This regex is used for proposals until either:
        - the regex is exhausted, and its only remaining option is the EOS token,
          in which case we always transition to the next regex
        - the regex can be exhausted, but the EOS token is not the only remaining option,
          in which case we transition to the next regex with probability P (TODO)
          or remove the possibility of generating the EOS token and continue with the current regex

        The CFG incremental parser is allowed to propose the EOS token from any final state,
        and once it is generated, the FSM will continue to always generate the EOS token.

        Parameters
        ----------
        state
            The current state of the FSM.
        idx
            The index of the current input in the batch.

        Returns
        -------
        A list that contains the ids of the allowed tokens.

        """
        # First call for this batch index: create its bookkeeping slots.
        # NOTE(review): this appends a single slot, so it presumably relies on
        # indices being first seen in increasing order -- confirm with callers.
        if len(self.generations) <= idx:
            self.generations.append("")
            self.reset_state.append(False)
            self.allow_eos.append(False)
            self.done.append(False)

        if len(self.regex_fsms) > idx:
            proposal = self.regex_fsms[idx].allowed_token_ids(state)
            if self.tokenizer.eos_token_id not in proposal:
                return proposal
            if set(proposal) != {self.tokenizer.eos_token_id}:
                # TODO: THIS NEEDS TO BE SAMPLED -- either move on to the next
                # regex or strip EOS and continue with the current one. For
                # now the proposal is returned unchanged.
                return proposal

        self._set_next_regex_fsm(idx)

        if self.done[idx]:
            return [self.tokenizer.eos_token_id]

        if self.reset_state[idx]:
            state = FSMState(0)

        proposal = self.regex_fsms[idx].allowed_token_ids(state)
        if self.allow_eos[idx]:
            self.allow_eos[idx] = False
        else:
            proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
            assert len(proposal) > 0
        return proposal

    def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
        """Update the state of the FSM.

        Transitions the underlying regex FSM to its next state.
        If the EOS token was generated, transition permanently to the final state.
        Update stored partial generations for subsequent incremental parsing.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.
        idx
            The index of the current input in the batch.

        Returns
        -------
        The new state of the FSM.

        """
        if idx == 0:
            self.num_tokens_generated += 1
        if token_id == self.tokenizer.eos_token_id:
            self.done[idx] = True
            return FSMState(-1)
        if self.reset_state[idx]:
            self.reset_state[idx] = False
            state = FSMState(0)

        self.generations[idx] += self.tokenizer.decode([token_id])[0]

        return self.regex_fsms[idx].next_state(state, token_id, idx)

    def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
        """Return whether the current state of the FSM is a final state."""
        return self.done[idx]

    def reset(self) -> None:
        """Reset the FSM to its initial state, so it can be called on a fresh batch of inputs."""
        self.num_tokens_generated = 0
        self.generations = []
        self.regex_fsms = []
        self.reset_state = []
        # `allow_eos` grows in lockstep with the other per-sequence lists; if it
        # is not cleared, flags from the previous batch are read by index.
        self.allow_eos = []
        self.done = []
|
||||
41
python/sglang/srt/constrained/fsm_cache.py
Normal file
41
python/sglang/srt/constrained/fsm_cache.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import threading
|
||||
|
||||
from sglang.srt.constrained.fsm import RegexFSM
|
||||
from sglang.srt.constrained.tokenizer import TransformerTokenizer
|
||||
|
||||
|
||||
def get_fsm(regex, tokenizer, fsm_cache_entry):
    """Build the `RegexFSM` for `regex` and publish it on `fsm_cache_entry`.

    The tokenizer is wrapped in a `TransformerTokenizer`, the resulting FSM is
    stored in `fsm_cache_entry.fsm`, and `fsm_cache_entry.event` is set so that
    waiters are released.

    The event is set even when construction fails (the exception still
    propagates to the caller of this function); `fsm_cache_entry.fsm` then
    keeps its previous value. Without this, anything waiting on the event
    would block forever after a failed build.
    """
    try:
        outlines_tokenizer = TransformerTokenizer(tokenizer)
        fsm = RegexFSM(regex, outlines_tokenizer)
        fsm_cache_entry.fsm = fsm
    finally:
        fsm_cache_entry.event.set()
|
||||
|
||||
|
||||
class FSMCacheEntry:
    """Slot for one cached FSM plus the event that signals it is available."""

    def __init__(self):
        # Set once `fsm` has been filled in.
        self.event = threading.Event()
        # Stays `None` until an FSM is stored here.
        self.fsm = None
|
||||
|
||||
|
||||
class FSMCache:
    """Cache of FSM entries keyed by regex, filled by background threads."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.cache = {}

    def init_fsm_in_background(self, regex):
        """Start building the FSM for `regex` in a new thread unless it is already cached."""
        if regex in self.cache:
            return

        self.cache[regex] = FSMCacheEntry()
        builder = threading.Thread(
            target=get_fsm,
            args=(regex, self.tokenizer, self.cache[regex]),
        )
        builder.start()

    def get_fsm(self, regex):
        """Return the FSM for `regex`, blocking until its entry's event is set."""
        self.init_fsm_in_background(regex)
        cached = self.cache[regex]
        cached.event.wait()
        return cached.fsm
|
||||
586
python/sglang/srt/constrained/regex.py
Normal file
586
python/sglang/srt/constrained/regex.py
Normal file
@@ -0,0 +1,586 @@
|
||||
# Adapted from:
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py
|
||||
from collections import namedtuple
|
||||
from functools import lru_cache
|
||||
from typing import Dict, Generator, List, Sequence, Set, Tuple
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
|
||||
from numba.typed.typedobjectutils import _nonoptional
|
||||
from sglang.srt.constrained.tokenizer import Tokenizer
|
||||
|
||||
|
||||
class BetterAlphabet(Alphabet):
    """Alphabet whose lookups fall back to the `anything_else` transition key."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        symbol_mapping = self._symbol_mapping
        assert anything_else in symbol_mapping
        # Transition key used for every symbol that has no entry of its own.
        self.anything_value = symbol_mapping[anything_else]

    def __getitem__(self, item):
        """Return the transition key of `item`, or `anything_value` if it is unknown."""
        fallback = self.anything_value
        return self._symbol_mapping.get(item, fallback)

    def copy(self):
        """Return a new `BetterAlphabet` built from a copy of the symbol mapping."""
        mapping_copy = self._symbol_mapping.copy()
        return BetterAlphabet(mapping_copy)
|
||||
|
||||
|
||||
class BetterFSM(FSM):
    """FSM that also carries flattened transition tables and a lazily built `FSMInfo`.

    On top of the base class state, instances hold `flat_transition_map`,
    `trans_key_to_states` and a cached `_fsm_info` (see the `fsm_info` property).
    """

    # Maps `(from_state, transition_key)` to the destination state.
    flat_transition_map: Dict[Tuple[int, int], int]
    # Maps a transition key to the states that have an outgoing transition for it.
    # NOTE(review): `__init__` stores sets here, not lists as annotated.
    trans_key_to_states: Dict[int, List[int]]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Attributes are written through `__dict__` instead of plain assignment.
        # NOTE(review): presumably because the base class restricts attribute
        # assignment -- confirm against `interegular.fsm.FSM`.
        if not isinstance(self.alphabet, BetterAlphabet):
            self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping)

        # Flatten the nested `{from_state: {trans_key: to_state}}` map and
        # record, for each transition key, the states it can be taken from.
        flat_transition_map = {}
        trans_key_to_states = {}
        for from_state, trans_map in self.map.items():
            for trans_key, to_state in trans_map.items():
                flat_transition_map[(from_state, trans_key)] = to_state
                trans_key_to_states.setdefault(trans_key, set()).add(from_state)

        self.__dict__["trans_key_to_states"] = trans_key_to_states
        self.__dict__["flat_transition_map"] = flat_transition_map
        # Filled in on first access of `fsm_info`.
        self.__dict__["_fsm_info"] = None

    def copy(self):
        """Return a new `BetterFSM` built from copies of this FSM's components.

        `map` is copied with `dict.copy()`, i.e. only its outer level.
        """
        return BetterFSM(
            alphabet=self.alphabet.copy(),
            states=self.states.copy(),
            initial=self.initial,
            finals=self.finals.copy(),
            map=self.map.copy(),
            __no_validation__=True,
        )

    @property
    def fsm_info(self):
        """Return the `FSMInfo` for this FSM, building and caching it on first use."""
        if self._fsm_info is None:
            # Pack the Python containers into structured numpy arrays before
            # handing them to the jitted `create_fsm_info`.
            flat_transition_map_items = np.fromiter(
                ((a[0], a[1], b) for a, b in self.flat_transition_map.items()),
                dtype=np.dtype("i8, i8, i8"),
            )
            trans_key_to_states_items = np.fromiter(
                ((k, z) for k, v in self.trans_key_to_states.items() for z in v),
                dtype=np.dtype("i8, i8"),
            )
            # The `anything_else` entry is left out of the array; its key is
            # passed separately below as `anything_value`.
            alphabet_symbol_mapping_items = np.fromiter(
                (
                    it
                    for it in self.alphabet._symbol_mapping.items()
                    if it[0] != anything_else
                ),
                dtype=np.dtype("U1, i8"),
            )
            nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8"))
            self.__dict__["_fsm_info"] = create_fsm_info(
                self.initial,
                nb_finals,
                flat_transition_map_items,
                trans_key_to_states_items,
                self.alphabet.anything_value,
                alphabet_symbol_mapping_items,
            )

        return self._fsm_info
|
||||
|
||||
|
||||
# Numba type objects used as key/value types of the typed dicts built in
# `create_fsm_info`.
# List of 64-bit integers.
nb_int_list_type = numba.types.ListType(numba.int64)
# Pair of 64-bit integers.
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
# Unicode character sequence of length 1.
nb_unichar_1_type = numba.types.UnicodeCharSeq(1)
|
||||
|
||||
|
||||
@numba.njit(cache=True)
def create_fsm_info(
    py_initial,
    py_finals,
    flat_transition_map_items,
    trans_key_to_states_items,
    py_anything_value,
    alphabet_symbol_mapping_items,
):
    """Assemble an `FSMInfo` made of numba containers from flat item sequences.

    Parameters
    ----------
    py_initial
        The initial state; converted to `numba.int64`.
    py_finals
        Iterable of final states; collected into a set.
    flat_transition_map_items
        Records of `(from_state, transition_key, to_state)`.
    trans_key_to_states_items
        Records of `(transition_key, state)`.
    py_anything_value
        Transition key of the "anything else" symbol; converted to `numba.int64`.
    alphabet_symbol_mapping_items
        Records of `(symbol, transition_key)`.

    Returns
    -------
    An `FSMInfo` tuple.
    """
    # transition key -> list of states, in the order the records are given
    trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type)
    for trans_key_and_state in trans_key_to_states_items:
        trans_key_to_states.setdefault(
            trans_key_and_state[0], numba.typed.List.empty_list(numba.int64)
        ).append(trans_key_and_state[1])

    # (from_state, transition key) -> to_state
    flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64)
    for trans_key_and_state in flat_transition_map_items:
        flat_transition_map[
            (trans_key_and_state[0], trans_key_and_state[1])
        ] = trans_key_and_state[2]

    # single-character symbol -> transition key
    alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64)
    for symbol_and_trans_key in alphabet_symbol_mapping_items:
        alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]

    initial = numba.int64(py_initial)

    finals = set()
    for final in py_finals:
        finals.add(final)

    anything_value = numba.int64(py_anything_value)

    return FSMInfo(
        initial,
        finals,
        flat_transition_map,
        trans_key_to_states,
        anything_value,
        alphabet_symbol_map,
    )
|
||||
|
||||
|
||||
# Record returned by `create_fsm_info`; the field order matches the positional
# arguments passed to `FSMInfo(...)` there.
_FSM_INFO_FIELDS = (
    "initial",
    "finals",
    "transitions",
    "trans_key_to_states",
    "alphabet_anything_value",
    "alphabet_symbol_mapping",
)
FSMInfo = namedtuple("FSMInfo", _FSM_INFO_FIELDS)
|
||||
|
||||
|
||||
def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
    """Rebuild `fsm` as a `BetterFSM` with reproducible transition keys and state labels.

    Transition keys are renumbered following the sorted order of the symbols
    they stand for, and states are renumbered in the order a traversal from
    the initial state (which becomes `0`) discovers them.

    Returns
    -------
    The relabelled FSM and the map from old state labels to new ones.
    """
    # Renumber transition keys by the sorted symbols attached to each of them.
    ordered_transitions = sorted(
        fsm.alphabet.by_transition.items(), key=lambda item: sorted(item[1])
    )
    key_renumbering = {}
    for new_key, (old_key, _) in enumerate(ordered_transitions):
        key_renumbering[old_key] = new_key

    relabelled_alphabet = BetterAlphabet(
        {
            symbol: key_renumbering[old_key]
            for symbol, old_key in fsm.alphabet._symbol_mapping.items()
        }
    )

    # Same transition map as the input, expressed with the new transition keys.
    rekeyed_map = {}
    for source, outgoing in fsm.map.items():
        rekeyed_map[source] = {
            key_renumbering[old_key]: target for old_key, target in outgoing.items()
        }

    # Label states in discovery order, visiting outgoing transitions by
    # increasing key and exploring the most recently discovered state first.
    state_renumbering = {fsm.initial: 0}
    last_label = 0
    visited = {fsm.initial}
    stack = [fsm.initial]
    while stack:
        current = stack.pop()
        outgoing = sorted(rekeyed_map[current].items(), key=lambda pair: pair[0])
        for _, target in outgoing:
            if target not in visited:
                stack.append(target)
                visited.add(target)
            if target not in state_renumbering:
                last_label += 1
                state_renumbering[target] = last_label

    # Apply the new state labels, with both levels of the map sorted by key.
    relabelled_rows = []
    for source, outgoing in rekeyed_map.items():
        new_source = state_renumbering[source]
        new_outgoing = [
            (trans_key, state_renumbering[target])
            for trans_key, target in outgoing.items()
        ]
        new_outgoing.sort(key=lambda pair: pair[0])
        relabelled_rows.append((new_source, dict(new_outgoing)))
    relabelled_rows.sort(key=lambda row: row[0])
    relabelled_map = dict(relabelled_rows)

    relabelled_finals = frozenset(
        sorted(state_renumbering[final] for final in fsm.finals)
    )
    relabelled_states = frozenset(sorted(relabelled_map.keys()))

    relabelled_fsm = BetterFSM(
        relabelled_alphabet, relabelled_states, 0, relabelled_finals, relabelled_map
    )

    return relabelled_fsm, state_renumbering
|
||||
|
||||
|
||||
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
    fsm_transitions: Dict[Tuple[int, int], int],
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    fsm_initial: int,
    fsm_finals: Set[int],
    input_string: str,
    start_state: int,
    full_match: bool = True,
) -> List[int]:
    """Feed `input_string` to the FSM from `start_state` and list the states visited.

    Jitted counterpart of `walk_fsm` that works on the unpacked `FSMInfo` fields.

    Parameters
    ----------
    fsm_transitions
        Map from `(state, transition_key)` to the next state.
    alphabet_symbol_mapping
        Map from symbol to transition key.
    alphabet_anything_value
        Transition key used for symbols absent from `alphabet_symbol_mapping`.
    fsm_initial
        Not used by this function.
    fsm_finals
        The final states.
    input_string
        The symbols to consume.
    start_state
        The state the walk starts from.
    full_match
        When true, the whole string must be consumed and end in a final state.

    Returns
    -------
    The states entered after each symbol. An empty list is returned when a
    symbol has no transition, unless `full_match` is false and a final state
    was reached earlier, in which case the states up to the last final one are
    returned. With `full_match`, an empty list is also returned when the last
    state entered is not final.
    """
    state = start_state
    accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
    # Number of symbols consumed when a final state was last entered (0 = never).
    last_final_idx: int = numba.uint64(0)

    for i, symbol in enumerate(input_string):
        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

        new_state = fsm_transitions.get((state, trans_key))

        if new_state is None:
            if not full_match and last_final_idx > 0:
                return accepted_states[:last_final_idx]

            return numba.typed.List.empty_list(numba.int64)

        state = new_state

        if state in fsm_finals:
            last_final_idx = numba.uint64(i + 1)

        accepted_states.append(_nonoptional(state))

    # NOTE(review): `i` is only assigned inside the loop, so this check reads
    # an unassigned variable when `input_string` is empty -- confirm callers
    # never pass an empty string.
    if full_match and last_final_idx - 1 != i:
        return numba.typed.List.empty_list(numba.int64)

    return accepted_states
|
||||
|
||||
|
||||
def walk_fsm(
    fsm: "BetterFSM",
    input_string: str,
    start_state: int,
    full_match: bool = True,
) -> List[int]:
    """Feed `input_string` to `fsm` from `start_state` and list the states visited.

    Parameters
    ----------
    fsm
        The FSM to walk. Its `finals`, `flat_transition_map` and alphabet
        (`_symbol_mapping`, `anything_value`) are read.
    input_string
        The symbols to consume.
    start_state
        The state the walk starts from.
    full_match
        When true, the whole string must be consumed and end in a final state.

    Returns
    -------
    The states entered after each symbol. An empty list is returned when a
    symbol has no transition, unless `full_match` is false and a final state
    was reached earlier, in which case the states up to the last final one are
    returned. With `full_match`, an empty list is also returned when the last
    state entered is not final. An empty `input_string` always yields an empty
    list.
    """
    fsm_finals = fsm.finals
    alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
    alphabet_anything_value = fsm.alphabet.anything_value
    fsm_transitions = fsm.flat_transition_map

    state = start_state
    accepted_states: List[int] = []
    # Number of symbols consumed when a final state was last entered (0 = never).
    last_final_idx: int = 0

    for i, symbol in enumerate(input_string):
        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

        new_state = fsm_transitions.get((state, trans_key))

        if new_state is None:
            if not full_match and last_final_idx > 0:
                return accepted_states[:last_final_idx]

            return []

        state = new_state

        if state in fsm_finals:
            last_final_idx = i + 1

        accepted_states.append(state)

    # Compare against the input length rather than the loop variable, which is
    # unbound when `input_string` is empty.
    if full_match and last_final_idx != len(input_string):
        return []

    return accepted_states
|
||||
|
||||
|
||||
def fsm_union(
    fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
    """Construct an FSM representing the union of the FSMs in `fsms`.

    This is an updated version of `interegular.fsm.FSM.union` made to return an
    extra map of component FSMs to the sets of state transitions that
    correspond to them in the new FSM.

    Returns
    -------
    The union FSM (relabelled by `make_deterministic_fsm`) and a map from the
    index of each component FSM to a tuple of: its `(from, to)` transitions in
    the union FSM, its final states in the union FSM, and a map from its own
    states to the union states they take part in.

    """

    alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])

    indexed_fsms = tuple(enumerate(fsms))

    # A state of the union is a dict `{component index: component state}`
    # holding the components that are still "alive".
    initial = {i: fsm.initial for (i, fsm) in indexed_fsms}

    # Dedicated function accepting a "superset" and returning the next
    # "superset" obtained by following this transition in the new FSM
    def follow(current_state, new_transition: int):
        next = {}
        for i, f in indexed_fsms:
            old_transition = new_to_old[i][new_transition]
            if (
                i in current_state
                and current_state[i] in f.map
                and old_transition in f.map[current_state[i]]
            ):
                next[i] = f.map[current_state[i]][old_transition]
        if not next:
            raise OblivionError
        return next

    # `states` doubles as the work list: the position of a superset in it is
    # the label of the corresponding state in the union FSM.
    states = [initial]
    finals: Set[int] = set()
    map: Dict[int, Dict[int, int]] = {}

    # Map component FSMs to their new state-to-state transitions, finals, and a
    # map translating component FSM states to aggregate FSM states
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ] = {}

    i = 0
    while i < len(states):
        state = states[i]

        # Add to the finals of the aggregate FSM whenever we hit a final in a
        # component FSM
        if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
            finals.add(i)

        # Compute the map for this state
        map[i] = {}
        for transition in alphabet.by_transition:
            try:
                next = follow(state, transition)
            except OblivionError:
                # Reached an oblivion state; don't list it
                continue
            else:
                try:
                    # TODO: Seems like this could--and should--be avoided
                    j = states.index(next)
                except ValueError:
                    # First time this superset is seen: give it the next label.
                    j = len(states)
                    states.append(next)

                map[i][transition] = j

                # Record the `i -> j` transition for every component that
                # takes part in it.
                for fsm_id, fsm_state in next.items():
                    (
                        fsm_transitions,
                        fsm_finals,
                        fsm_old_to_new,
                    ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
                    old_from = state[fsm_id]
                    old_to = fsm_state
                    fsm_old_to_new.setdefault(old_from, set()).add(i)
                    fsm_old_to_new.setdefault(old_to, set()).add(j)
                    fsm_transitions.add((i, j))
                    if fsm_state in fsms[fsm_id].finals:
                        fsm_finals.add(j)

        i += 1

    fsm = FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=map,
        __no_validation__=True,
    )

    # Relabel the union FSM, then translate the per-component bookkeeping to
    # the new state labels.
    fsm, old_to_new_states = make_deterministic_fsm(fsm)
    _fsms_to_trans_finals = {
        fsm_id: (
            {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
            {old_to_new_states[s] for s in finals},
            {
                old_state: {old_to_new_states[new_state] for new_state in new_states}
                for old_state, new_states in old_to_new.items()
            },
        )
        for fsm_id, (transitions, finals, old_to_new) in sorted(
            fsms_to_trans_finals.items(), key=lambda x: x[0]
        )
    }

    return (
        fsm,
        _fsms_to_trans_finals,
    )
|
||||
|
||||
|
||||
def get_sub_fsms_from_seq(
    state_seq: Sequence[int],
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ],
) -> Generator[Tuple[int, bool, bool], None, None]:
    """Yield the sub-FSMs whose transitions cover every step of `state_seq`.

    Parameters
    ----------
    state_seq
        A state sequence.
    fsms_to_trans_finals
        A map from FSM indices to tuples containing sets of their state
        transitions and sets of the final/accept states.

    Returns
    -------
    A generator of `(index, can_continue, is_final)` tuples, in the iteration
    order of `fsms_to_trans_finals`: the sub-FSM index, whether that sub-FSM
    has a transition leaving the last state of the sequence, and whether the
    last state of the sequence is one of its final states.
    """
    observed_steps = set(zip(state_seq[:-1], state_seq[1:]))
    last_fsm_state = state_seq[-1]

    for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items():
        # A sub-FSM that lacks one of the observed steps cannot have produced
        # the sequence.
        if not observed_steps.issubset(transitions):
            continue

        can_continue = any(last_fsm_state == from_s for (from_s, to_s) in transitions)
        is_final = state_seq[-1] in finals
        yield fsm_idx, can_continue, is_final
|
||||
|
||||
|
||||
@numba.njit(cache=True, nogil=True)
def state_scan_tokens(
    fsm_transitions: Dict[Tuple[int, int], int],
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    fsm_initial: int,
    fsm_finals: Set[int],
    vocabulary: Dict[str, List[int]],
    start_state: int,
) -> Set[Tuple[int, int]]:
    """Find the tokens that can be consumed entirely starting from `start_state`.

    Every token string of `vocabulary` is walked through the FSM with
    `_walk_fsm` (`full_match=False`). Tokens whose walk covers fewer states
    than the token has characters are skipped.

    Returns
    -------
    A set of `(token_id, end_state)` pairs, one for each id of every token
    that was kept, where `end_state` is the last state of its walk.
    """
    res = set()

    for token, token_ids in vocabulary.items():
        state_seq = _walk_fsm(
            fsm_transitions,
            alphabet_symbol_mapping,
            alphabet_anything_value,
            fsm_initial,
            fsm_finals,
            token,
            start_state,
            False,
        )

        # The walk stopped before the end of the token: it cannot be generated
        # from this state.
        if state_seq is not None and len(state_seq) < len(token):
            continue

        # Every id that decodes to this token string leads to the same state.
        for token_id in token_ids:
            res.add((token_id, state_seq[-1]))

    return res
|
||||
|
||||
|
||||
def create_fsm_index_end_to_end(
    fsm_info: FSMInfo,
    vocabulary: Dict[str, List[int]],
) -> Dict[int, Set[Tuple[int, int]]]:
    """Build the state-to-vocabulary index by scanning every state reachable via tokens.

    Starting from the initial state, each pending state is scanned with
    `state_scan_tokens`; the `(token_id, end_state)` pairs found are stored
    under that state and the end states not fully processed yet are queued.
    States for which no token is found get no entry.
    """
    # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
    # code, too.
    index: Dict[int, Set[Tuple[int, int]]] = {}
    processed: Set[int] = set()
    pending = {fsm_info.initial}

    while pending:
        current = pending.pop()

        found_pairs = state_scan_tokens(
            fsm_info.transitions,
            fsm_info.alphabet_symbol_mapping,
            fsm_info.alphabet_anything_value,
            fsm_info.initial,
            fsm_info.finals,
            vocabulary,
            current,
        )

        for pair in found_pairs:
            index.setdefault(current, set()).add(pair)
            reached = pair[1]
            if reached not in processed:
                pending.add(reached)

        # Marked only after the scan, exactly once per pass over `current`.
        processed.add(current)

    return index
|
||||
|
||||
|
||||
# TODO: Cannot cache typed collections to disk, yet. See
|
||||
# https://github.com/numba/numba/issues/4698
|
||||
@lru_cache
def reduced_vocabulary(tokenizer: "Tokenizer"):
    """Group the tokenizer's token ids by the string they convert to.

    Special tokens are skipped. Returns a numba typed dict from token string
    to the list of ids that convert to it, and a set with the ids of the
    tokens that convert to an empty string.
    """
    ids_by_string = numba.typed.Dict.empty(
        numba.types.string, numba.types.ListType(numba.int64)
    )
    ids_without_string = set()

    for token, token_idx in tokenizer.vocabulary.items():
        if token in tokenizer.special_tokens:
            continue

        token_str = tokenizer.convert_token_to_string(token)

        if not token_str:
            ids_without_string.add(numba.int64(token_idx))
            continue

        same_string_ids = ids_by_string.setdefault(
            token_str,
            numba.typed.List.empty_list(numba.int64),
        )
        same_string_ids.append(numba.int64(token_idx))

    return ids_by_string, ids_without_string
|
||||
|
||||
|
||||
def create_fsm_index_tokenizer(
    fsm: BetterFSM,
    tokenizer: "Tokenizer",
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
    """Construct an FSM index from a tokenizer.

    The index is built with the end-to-end approach of
    `create_fsm_index_end_to_end`, using the vocabulary returned by
    `reduced_vocabulary`.

    .. warning::

        `fsm` needs to be deterministically ordered so that future caching makes sense.

    Returns
    -------
    A tuple of the per-state ``{token_id: end_state}`` maps and the set of
    token ids that decode to an empty string.

    """
    vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)
    index = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)

    # Allow transitions to EOS from all terminals FSM states that are
    # reachable
    # TODO: Do we really need this anymore?
    for final_state in fsm.fsm_info.finals:
        pairs = index.get(final_state)
        if pairs is None:
            # This final state has no entry in the index; nothing to extend.
            continue
        pairs.add((tokenizer.eos_token_id, final_state))

    # Convert to token-to-end-state maps
    token_maps = {}
    for state, pairs in index.items():
        token_maps[state] = dict(pairs)

    return token_maps, empty_token_ids
|
||||
266
python/sglang/srt/constrained/tokenizer.py
Normal file
266
python/sglang/srt/constrained/tokenizer.py
Normal file
@@ -0,0 +1,266 @@
|
||||
# Adapted from:
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
|
||||
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
|
||||
from abc import abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
Hashable,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
class Tokenizer(Protocol, Hashable):
    """Structural interface for the tokenizers used by the constrained-generation code.

    Implementations must also be hashable, since the class derives from
    ``Hashable``.
    """

    eos_token: str
    eos_token_id: int
    pad_token_id: int
    # Maps a raw token string to its token id.
    vocabulary: Dict[str, int]
    # Raw token *strings* of the special tokens: `TransformerTokenizer` fills
    # this from `all_special_tokens`, and `reduced_vocabulary` tests the `str`
    # keys of `vocabulary` for membership in it.
    special_tokens: Set[str]

    @abstractmethod
    def encode(
        self, prompt: Union[str, List[str]]
    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        """Translate the input prompts into NumPy arrays of token ids and attention mask."""
        ...

    @abstractmethod
    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        """Translate an array of token ids to a string or list of strings."""
        ...

    @abstractmethod
    def convert_token_to_string(self, token: str) -> str:
        """Convert a token to its equivalent string.

        This is for instance useful for BPE tokenizers where whitespaces are
        represented by the special character `Ġ`. This prevents matching a raw
        token that includes `Ġ` with a string.

        """
        ...
|
||||
|
||||
|
||||
if TYPE_CHECKING:
    # Needed by static type checkers only; the annotations that mention these
    # classes are written as strings, so nothing is imported at runtime here.
    from transformers import PreTrainedModel, PreTrainedTokenizer

# Only the `transformers` factory function is exported by a star-import.
__all__ = ["transformers"]

# Type of the cache returned by `Transformer.forward`: a variable-length tuple
# of tensor pairs (presumably one (key, value) pair per layer -- TODO confirm).
KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
|
||||
|
||||
|
||||
def get_llama_tokenizer_types():
    """Return the four Llama tokenizer classes that need work-arounds.

    Each class is imported from `transformers` on its own. A class that cannot
    be imported is replaced by an empty stand-in class of the same name, so the
    returned tuple always has four entries and can be handed to `isinstance`.

    The order is: `LlamaTokenizer`, `LlamaTokenizerFast`,
    `CodeLlamaTokenizer`, `CodeLlamaTokenizerFast`.

    """
    resolved = []

    try:
        from transformers.models.llama import LlamaTokenizer
    except ImportError:

        class LlamaTokenizer:  # type: ignore
            pass

    resolved.append(LlamaTokenizer)

    try:
        from transformers.models.llama import LlamaTokenizerFast
    except ImportError:

        class LlamaTokenizerFast:  # type: ignore
            pass

    resolved.append(LlamaTokenizerFast)

    try:
        from transformers.models.code_llama import CodeLlamaTokenizer
    except ImportError:

        class CodeLlamaTokenizer:  # type: ignore
            pass

    resolved.append(CodeLlamaTokenizer)

    try:
        from transformers.models.code_llama import CodeLlamaTokenizerFast
    except ImportError:

        class CodeLlamaTokenizerFast:  # type: ignore
            pass

    resolved.append(CodeLlamaTokenizerFast)

    return tuple(resolved)
|
||||
|
||||
|
||||
class Transformer:
    """Represents a `transformers` model."""

    def __init__(
        self,
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
    ):
        """Store the model, its tokenizer and the device the model lives on."""
        self.device = model.device
        self.model = model
        self.tokenizer = tokenizer

    # `inference_mode` is instantiated explicitly: the bare
    # `@torch.inference_mode` form is only accepted by newer PyTorch releases,
    # while the called form works on every release that has the context manager.
    @torch.inference_mode()
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        past_key_values: Optional[Tuple] = None,
    ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
        """Compute a forward pass through the transformer model.

        Parameters
        ----------
        input_ids
            The input token ids. Must be one or two dimensional.
        attention_mask
            The attention mask. Must be one or two dimensional.
        past_key_values
            A tuple of tuples containing the cached key and value tensors for each
            attention head.

        Returns
        -------
        The computed logits and the new cached key and value tensors.

        """
        assert 0 < input_ids.ndim < 3

        if past_key_values:
            # With a cache available only the last token of each sequence is
            # passed to the model.
            input_ids = input_ids[..., -1].unsqueeze(-1)

        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
            past_key_values=past_key_values,
        )

        return output.logits, output.past_key_values

    def __call__(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        past_key_values: Optional[Tuple] = None,
    ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
        """Run `forward` and keep only the logits of the last position.

        Returns
        -------
        A ``(next_token_logits, kv_cache)`` tuple, where ``next_token_logits``
        is ``logits[..., -1, :]`` and ``kv_cache`` is the cache returned by
        `forward`.

        """
        logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
        next_token_logits = logits[..., -1, :]

        return next_token_logits, kv_cache
|
||||
|
||||
|
||||
class TransformerTokenizer(Tokenizer):
    """Represents a tokenizer for models in the `transformers` library."""

    def __init__(self, tokenizer):
        """Wrap an already-instantiated `transformers` tokenizer.

        When the wrapped tokenizer has no padding token id, its EOS token is
        used for padding instead; the wrapped tokenizer's ``pad_token_id`` is
        updated as well.
        """
        self.tokenizer = tokenizer
        self.eos_token_id = self.tokenizer.eos_token_id
        self.eos_token = self.tokenizer.eos_token

        # Compare against `None` rather than relying on truthiness: 0 is a
        # legitimate padding id and must not be overwritten with the EOS id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.pad_token_id = self.eos_token_id
            # Keep `pad_token` defined on this path too.
            self.pad_token = self.eos_token
        else:
            self.pad_token_id = self.tokenizer.pad_token_id
            self.pad_token = self.tokenizer.pad_token

        self.special_tokens = set(self.tokenizer.all_special_tokens)

        self.vocabulary = self.tokenizer.get_vocab()
        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())

    def encode(
        self, prompt: Union[str, List[str]], **kwargs
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        """Tokenize `prompt` and return ``(input_ids, attention_mask)``.

        ``padding=True`` and ``return_tensors="pt"`` are always forced,
        overriding any value given for them in `kwargs`.
        """
        kwargs["padding"] = True
        kwargs["return_tensors"] = "pt"
        output = self.tokenizer(prompt, **kwargs)
        return output["input_ids"], output["attention_mask"]

    def decode(self, token_ids: torch.LongTensor) -> List[str]:
        """Batch-decode `token_ids` into strings, skipping special tokens."""
        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return text

    def convert_token_to_string(self, token: str) -> str:
        """Convert a single raw token into the string it stands for."""
        from transformers.file_utils import SPIECE_UNDERLINE

        string = self.tokenizer.convert_tokens_to_string([token])

        if self.is_llama:
            # A hack to handle missing spaces to HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

        return string

    def __eq__(self, other):
        """Two wrappers are equal when they wrap the very same tokenizer object.

        Returning `False` for every same-type comparison would make even
        ``x == x`` false and would keep `lru_cache` from ever matching a second
        wrapper around the same tokenizer. Identity of the wrapped object also
        keeps `__eq__` in line with `__hash__`, which is derived from that
        object.
        """
        if isinstance(other, type(self)):
            return self.tokenizer is other.tokenizer
        return NotImplemented

    def __hash__(self):
        """Hash the wrapped tokenizer through the `datasets` fingerprint hasher."""
        from datasets.fingerprint import Hasher

        return hash(Hasher.hash(self.tokenizer))
|
||||
|
||||
|
||||
def transformers(
    model_name: str,
    device: Optional[str] = None,
    model_kwargs: dict = {},
    tokenizer_kwargs: dict = {},
):
    """Instantiate a model from the `transformers` library and its tokenizer.

    Parameters
    ----------
    model_name
        The name of the model as listed on Hugging Face's model page.
    device
        The device(s) on which the model should be loaded. This overrides
        the `device_map` entry in `model_kwargs` when provided.
    model_kwargs
        A dictionary that contains the keyword arguments to pass to the
        `from_pretrained` method when loading the model.
    tokenizer_kwargs
        A dictionary that contains the keyword arguments to pass to the
        `from_pretrained` method when loading the tokenizer.

    Returns
    -------
    A `Transformer` model instance.

    Raises
    ------
    ImportError
        If the `transformers` library is not installed.

    """
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer
    except ImportError as err:
        raise ImportError(
            "The `transformers` library needs to be installed in order to use `transformers` models."
        ) from err

    # Work on a copy: writing `device_map` into the caller's dict (or into the
    # shared `{}` default) would leak the setting into later calls.
    model_kwargs = dict(model_kwargs)
    if device is not None:
        model_kwargs["device_map"] = device

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    # `TransformerTokenizer` wraps an instantiated tokenizer object, so the
    # tokenizer has to be loaded here instead of handing over the model name.
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
    tokenizer = TransformerTokenizer(hf_tokenizer)

    return Transformer(model, tokenizer)
|
||||
Reference in New Issue
Block a user