release initial code

Co-authored-by: Ying Sheng <sqy1415@gmail.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com>
Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Lianmin Zheng
2024-01-08 04:37:50 +00:00
parent f6d40df0ee
commit 22085081bb
145 changed files with 17802 additions and 2 deletions

View File

@@ -0,0 +1,385 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/fsm.py
from typing import List, NewType, Protocol
import interegular
from lark import Lark
# from outlines.fsm.parsing import PartialLark
from sglang.srt.constrained.regex import (
create_fsm_index_tokenizer,
make_deterministic_fsm,
)
from sglang.srt.constrained.tokenizer import Tokenizer
FSMState = NewType("FSMState", int)
class FSM(Protocol):
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
...
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
...
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
...
def reset(self) -> None:
...
class StopAtTokenFSM(FSM):
"""FSM to generate text until a specified token id is generated or
a specified number of tokens has been generated.
Text is usually produced until the EOS token is generated by the
model.
"""
def __init__(
self,
tokenizer: "Tokenizer",
stop_token_id: int,
):
self.stop_token_id = stop_token_id
self.num_tokens_generated = 0
self.vocabulary = tokenizer.vocabulary.values()
self.final_states = {1}
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
"""Generate a list of allowed tokens for the next step.
When in the initial state we allow every token to be generated.
In the final state the only allowed token is `stop_token_id`.
Parameters
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
A list that contains the tokens to mask.
"""
if state == 0:
return list(self.vocabulary)
else:
return [self.stop_token_id]
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
"""Update the state of the FSM.
The FSM stays in the initial state `0` unless the specified stop token
has been generated or the maximum number of tokens has been reached. In
which case the FSM moves to the final state `1`.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.stop_token_id:
return FSMState(1)
return FSMState(0)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state in self.final_states
def reset(self) -> None:
"""Reset the FSM to its initial state. Here this only resets the token counter."""
self.num_tokens_generated = 0
class RegexFSM(FSM):
"""FSM to generate text that is in the language of a regular expression."""
def __init__(
self,
regex_string: str,
tokenizer: "Tokenizer",
):
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
(
self.states_to_token_maps,
self.empty_token_ids,
) = create_fsm_index_tokenizer(regex_fsm, tokenizer)
# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values())
for v in self.states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)
self.final_states = regex_fsm.finals | {
-1
} # Include the EOS token in final states
self.num_tokens_generated = 0
self.vocabulary = tokenizer.vocabulary.values()
self.end_token_id = tokenizer.eos_token_id
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
"""Generate a list of allowed tokens for the next step.
The initialization of the FSM builds an index which maps FSM states to a
map from authorized tokens to the state in which the FSM needs to move
if said token is generated. Therefore the authorized tokens at the
current state are the keys of the map returned by the value of the index
for current state.
If the current state is not contained in the end this means that we are
in a final state of the FSM. We only authorize EOS tokens in the final
state.
Parameters
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
A list that contains the tokens to mask.
"""
next_tokens_to_end_states = self.states_to_token_maps.get(state)
if next_tokens_to_end_states is None:
return [self.end_token_id]
else:
return list(next_tokens_to_end_states.keys())
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
"""Update the state of the FSM.
We use the index to determine to which state the FSM should transition
given the token that was just generated.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.end_token_id:
return FSMState(-1)
last_token_to_end_state = self.states_to_token_maps[state]
next_state = last_token_to_end_state.get(token_id)
if next_state is None:
next_state = -1
return FSMState(next_state)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
"""Determine whether the current state of the FSM is a final state."""
return state in self.final_states
def reset(self) -> None:
"""Reset the FSM to its initial state. Here this only resets the token counter."""
self.num_tokens_generated = 0
class CFGFSM(FSM):
"""FSM to generate text that is in the language of a context-free grammar."""
def __init__(
self,
cfg_string: str,
tokenizer: "Tokenizer",
):
# self.parser = PartialLark(cfg_string, parser="lalr")
self.parser = Lark(
cfg_string,
parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
)
self.terminal_regexps = dict()
for terminal in self.parser.terminals:
if terminal.pattern is not None:
self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
self.terminal_regexps["$END"] = tokenizer.eos_token
self.tokenizer = tokenizer
self.num_tokens_generated = 0
self.generations: List[str] = []
self.regex_fsms: List[RegexFSM] = []
self.reset_state: List[bool] = []
self.allow_eos: List[bool] = []
self.done: List[bool] = []
def _set_next_regex_fsm(self, idx: int = 0) -> None:
"""Use the CFG incremental parser to set the next regex FSM.
Check what the CFG incremental parser proposes next.
If the only proposal is the EOS token,
we set the state to done and return.
If there are other proposals,
we set a new regex FSM and return.
"""
interactive = self.parser.parse_interactive(self.generations[idx])
interactive.exhaust_lexer()
options = {self.terminal_regexps[x] for x in interactive.accepts()}
if self.terminal_regexps["$END"] in options:
options.remove(self.terminal_regexps["$END"])
if len(options) == 0:
self.done[idx] = True
return
self.allow_eos[idx] = True
options.add("")
assert len(options) > 1
regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
args = (
regex_string,
self.tokenizer,
)
if len(self.regex_fsms) <= idx:
self.regex_fsms.append(RegexFSM(*args))
else:
self.regex_fsms[idx] = RegexFSM(*args)
self.reset_state[idx] = True
def allowed_token_ids(self, state: FSMState, idx: int = 0) -> List[int]:
"""Generate a list of allowed tokens for the next step.
Upon initialization, the CFG incremental parser is used to determine the first regex.
This regex is used for proposals until either:
- the regex is exhausted, and its only remaining option is the EOS token,
in which case we always transition to the next regex
- the regex can be exhausted, but the EOS token is not the only remaining option,
in which case we transition to the next regex with probability P (TODO)
or remove the possibility of generating the EOS token and continue with the current regex
The CFG incremental parser is allowed to propose the EOS token from any final state,
and once it is generated, the FSM will continue to always generate the EOS token.
Parameters
----------
state
The current state of the FSM.
idx
The index of the current input in the batch.
Returns
-------
A list that contains the tokens to mask.
"""
if len(self.generations) <= idx:
self.generations.append("")
self.reset_state.append(False)
self.allow_eos.append(False)
self.done.append(False)
if len(self.regex_fsms) > idx:
proposal = self.regex_fsms[idx].allowed_token_ids(state)
if self.tokenizer.eos_token_id not in proposal:
return proposal
if set(proposal) != {self.tokenizer.eos_token_id}:
if False: # TODO: THIS NEEDS TO BE SAMPLED
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
return proposal
self._set_next_regex_fsm(idx)
if self.done[idx]:
return [self.tokenizer.eos_token_id]
if self.reset_state[idx]:
state = FSMState(0)
proposal = self.regex_fsms[idx].allowed_token_ids(state)
if self.allow_eos[idx]:
self.allow_eos[idx] = False
else:
proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
assert len(proposal) > 0
return proposal
def next_state(self, state: FSMState, token_id: int, idx: int = 0) -> FSMState:
"""Update the state of the FSM.
Transitions the underlying regex FSM to its next state.
If at max tokens or EOS token, transition permanently to the final state.
Update stored partial generations for subsequent incremental parsing.
Parameters
----------
state
The current state of the FSM.
token_id
The id of the token that was just generated.
idx
The index of the current input in the batch.
Returns
-------
The new state of the FSM.
"""
if idx == 0:
self.num_tokens_generated += 1
if token_id == self.tokenizer.eos_token_id:
self.done[idx] = True
return FSMState(-1)
if self.reset_state[idx]:
self.reset_state[idx] = False
state = FSMState(0)
self.generations[idx] += self.tokenizer.decode([token_id])[0]
return self.regex_fsms[idx].next_state(state, token_id, idx)
def is_final_state(self, state: FSMState, idx: int = 0) -> bool:
"""Return whether the current state of the FSM is a final state."""
return self.done[idx]
def reset(self) -> None:
"""Reset the FSM to its initial state, so it can be called on a fresh batch on inputs."""
self.num_tokens_generated = 0
self.generations = []
self.regex_fsms = []
self.reset_state = []
self.done = []

View File

@@ -0,0 +1,41 @@
import threading
from sglang.srt.constrained.fsm import RegexFSM
from sglang.srt.constrained.tokenizer import TransformerTokenizer
def get_fsm(regex, tokenizer, fsm_cache_entry):
outlines_tokenizer = TransformerTokenizer(tokenizer)
fsm = RegexFSM(regex, outlines_tokenizer)
fsm_cache_entry.fsm = fsm
fsm_cache_entry.event.set()
class FSMCacheEntry:
def __init__(self):
self.fsm = None
self.event = threading.Event()
class FSMCache:
def __init__(self, tokenizer):
self.cache = {}
self.tokenizer = tokenizer
def init_fsm_in_background(self, regex):
if regex not in self.cache:
self.cache[regex] = FSMCacheEntry()
threading.Thread(
target=get_fsm,
args=(
regex,
self.tokenizer,
self.cache[regex],
),
).start()
def get_fsm(self, regex):
self.init_fsm_in_background(regex)
entry = self.cache[regex]
entry.event.wait()
return entry.fsm

View File

@@ -0,0 +1,586 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py
from collections import namedtuple
from functools import lru_cache
from typing import Dict, Generator, List, Sequence, Set, Tuple
import numba
import numpy as np
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
from numba.typed.typedobjectutils import _nonoptional
from sglang.srt.constrained.tokenizer import Tokenizer
class BetterAlphabet(Alphabet):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert anything_else in self._symbol_mapping
self.anything_value = self._symbol_mapping[anything_else]
def __getitem__(self, item):
return self._symbol_mapping.get(item, self.anything_value)
def copy(self):
return BetterAlphabet(self._symbol_mapping.copy())
class BetterFSM(FSM):
flat_transition_map: Dict[Tuple[int, int], int]
trans_key_to_states: Dict[int, List[int]]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not isinstance(self.alphabet, BetterAlphabet):
self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping)
flat_transition_map = {}
trans_key_to_states = {}
for from_state, trans_map in self.map.items():
for trans_key, to_state in trans_map.items():
flat_transition_map[(from_state, trans_key)] = to_state
trans_key_to_states.setdefault(trans_key, set()).add(from_state)
self.__dict__["trans_key_to_states"] = trans_key_to_states
self.__dict__["flat_transition_map"] = flat_transition_map
self.__dict__["_fsm_info"] = None
def copy(self):
return BetterFSM(
alphabet=self.alphabet.copy(),
states=self.states.copy(),
initial=self.initial,
finals=self.finals.copy(),
map=self.map.copy(),
__no_validation__=True,
)
@property
def fsm_info(self):
if self._fsm_info is None:
flat_transition_map_items = np.fromiter(
((a[0], a[1], b) for a, b in self.flat_transition_map.items()),
dtype=np.dtype("i8, i8, i8"),
)
trans_key_to_states_items = np.fromiter(
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
dtype=np.dtype("i8, i8"),
)
alphabet_symbol_mapping_items = np.fromiter(
(
it
for it in self.alphabet._symbol_mapping.items()
if it[0] != anything_else
),
dtype=np.dtype("U1, i8"),
)
nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8"))
self.__dict__["_fsm_info"] = create_fsm_info(
self.initial,
nb_finals,
flat_transition_map_items,
trans_key_to_states_items,
self.alphabet.anything_value,
alphabet_symbol_mapping_items,
)
return self._fsm_info
nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_1_type = numba.types.UnicodeCharSeq(1)
@numba.njit(cache=True)
def create_fsm_info(
py_initial,
py_finals,
flat_transition_map_items,
trans_key_to_states_items,
py_anything_value,
alphabet_symbol_mapping_items,
):
trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type)
for trans_key_and_state in trans_key_to_states_items:
trans_key_to_states.setdefault(
trans_key_and_state[0], numba.typed.List.empty_list(numba.int64)
).append(trans_key_and_state[1])
flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64)
for trans_key_and_state in flat_transition_map_items:
flat_transition_map[
(trans_key_and_state[0], trans_key_and_state[1])
] = trans_key_and_state[2]
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64)
for symbol_and_trans_key in alphabet_symbol_mapping_items:
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]
initial = numba.int64(py_initial)
finals = set()
for final in py_finals:
finals.add(final)
anything_value = numba.int64(py_anything_value)
return FSMInfo(
initial,
finals,
flat_transition_map,
trans_key_to_states,
anything_value,
alphabet_symbol_map,
)
FSMInfo = namedtuple(
"FSMInfo",
[
"initial",
"finals",
"transitions",
"trans_key_to_states",
"alphabet_anything_value",
"alphabet_symbol_mapping",
],
)
def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
"""Construct an equivalent FSM with deterministic state labels."""
old_to_new_trans_keys = {
trans_key: i
for i, (trans_key, _) in enumerate(
sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1]))
)
}
new_symbol_mapping = {
symbol: old_to_new_trans_keys[trans_key]
for symbol, trans_key in fsm.alphabet._symbol_mapping.items()
}
new_alphabet = BetterAlphabet(new_symbol_mapping)
new_map = {
from_state: {
old_to_new_trans_keys[trans_key]: to_state
for trans_key, to_state in trans_map.items()
}
for from_state, trans_map in fsm.map.items()
}
old_to_new_states = {}
old_to_new_states[fsm.initial] = 0
i = 0
seen = {fsm.initial}
old_state_queue = [fsm.initial]
while old_state_queue:
old_state = old_state_queue.pop(-1)
transitions = new_map[old_state]
sorted_transitions = sorted(transitions.items(), key=lambda v: v[0])
for _, old_state in sorted_transitions:
if old_state not in seen:
old_state_queue.append(old_state)
seen.add(old_state)
if old_state not in old_to_new_states:
i += 1
old_to_new_states[old_state] = i
new_map = dict(
sorted(
(
(
old_to_new_states[from_state],
dict(
sorted(
(
(trans_key, old_to_new_states[to_state])
for trans_key, to_state in trans_map.items()
),
key=lambda v: v[0],
)
),
)
for from_state, trans_map in new_map.items()
),
key=lambda v: v[0],
)
)
new_initial = 0
new_finals = frozenset(
sorted(old_to_new_states[old_state] for old_state in fsm.finals)
)
new_states = frozenset(sorted(new_map.keys()))
new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map)
return new_fsm, old_to_new_states
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: str,
start_state: int,
full_match: bool = True,
) -> List[int]:
state = start_state
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
last_final_idx: int = numba.uint64(0)
for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
new_state = fsm_transitions.get((state, trans_key))
if new_state is None:
if not full_match and last_final_idx > 0:
return accepted_states[:last_final_idx]
return numba.typed.List.empty_list(numba.int64)
state = new_state
if state in fsm_finals:
last_final_idx = numba.uint64(i + 1)
accepted_states.append(_nonoptional(state))
if full_match and last_final_idx - 1 != i:
return numba.typed.List.empty_list(numba.int64)
return accepted_states
def walk_fsm(
fsm: BetterFSM,
input_string: str,
start_state: int,
full_match: bool = True,
) -> List[int]:
fsm_finals = fsm.finals
state = start_state
accepted_states: List[int] = []
last_final_idx: int = 0
alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
alphabet_anything_value = fsm.alphabet.anything_value
fsm_transitions = fsm.flat_transition_map
for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
new_state = fsm_transitions.get((state, trans_key))
if new_state is None:
if not full_match and last_final_idx > 0:
return accepted_states[:last_final_idx]
return []
state = new_state
if state in fsm_finals:
last_final_idx = i + 1
accepted_states.append(state)
if full_match and last_final_idx - 1 != i:
return []
return accepted_states
def fsm_union(
fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
"""Construct an FSM representing the union of the FSMs in `fsms`.
This is an updated version of `interegular.fsm.FSM.union` made to return an
extra map of component FSMs to the sets of state transitions that
correspond to them in the new FSM.
"""
alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])
indexed_fsms = tuple(enumerate(fsms))
initial = {i: fsm.initial for (i, fsm) in indexed_fsms}
# Dedicated function accepting a "superset" and returning the next
# "superset" obtained by following this transition in the new FSM
def follow(current_state, new_transition: int):
next = {}
for i, f in indexed_fsms:
old_transition = new_to_old[i][new_transition]
if (
i in current_state
and current_state[i] in f.map
and old_transition in f.map[current_state[i]]
):
next[i] = f.map[current_state[i]][old_transition]
if not next:
raise OblivionError
return next
states = [initial]
finals: Set[int] = set()
map: Dict[int, Dict[int, int]] = {}
# Map component FSMs to their new state-to-state transitions, finals, and a
# map translating component FSM states to aggregate FSM states
fsms_to_trans_finals: Dict[
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
] = {}
i = 0
while i < len(states):
state = states[i]
# Add to the finals of the aggregate FSM whenever we hit a final in a
# component FSM
if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
finals.add(i)
# Compute the map for this state
map[i] = {}
for transition in alphabet.by_transition:
try:
next = follow(state, transition)
except OblivionError:
# Reached an oblivion state; don't list it
continue
else:
try:
# TODO: Seems like this could--and should--be avoided
j = states.index(next)
except ValueError:
j = len(states)
states.append(next)
map[i][transition] = j
for fsm_id, fsm_state in next.items():
(
fsm_transitions,
fsm_finals,
fsm_old_to_new,
) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
old_from = state[fsm_id]
old_to = fsm_state
fsm_old_to_new.setdefault(old_from, set()).add(i)
fsm_old_to_new.setdefault(old_to, set()).add(j)
fsm_transitions.add((i, j))
if fsm_state in fsms[fsm_id].finals:
fsm_finals.add(j)
i += 1
fsm = FSM(
alphabet=alphabet,
states=range(len(states)),
initial=0,
finals=finals,
map=map,
__no_validation__=True,
)
fsm, old_to_new_states = make_deterministic_fsm(fsm)
_fsms_to_trans_finals = {
fsm_id: (
{(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
{old_to_new_states[s] for s in finals},
{
old_state: {old_to_new_states[new_state] for new_state in new_states}
for old_state, new_states in old_to_new.items()
},
)
for fsm_id, (transitions, finals, old_to_new) in sorted(
fsms_to_trans_finals.items(), key=lambda x: x[0]
)
}
return (
fsm,
_fsms_to_trans_finals,
)
def get_sub_fsms_from_seq(
state_seq: Sequence[int],
fsms_to_trans_finals: Dict[
int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
],
) -> Generator[Tuple[int, bool, bool], None, None]:
"""Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.
Parameters
----------
state_seq
A state sequence.
fsms_to_trans_finals
A map from FSM indices to tuples containing sets of their state transitions
and sets of the final/accept states.
Returns
-------
A generator returning tuples containing each sub-FSM index (in the order
they were union-ed to construct `fsm`) and booleans indicating whether or
not there is another valid transition from the last state in the sequence
for the associated sub-FSM (i.e. if the FSM can continue
accepting/matching) and whether or not the sequence ends in a final state
of the sub-FSM.
"""
state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
last_fsm_state = state_seq[-1]
yield from (
(
# The sub-FMS index
fsm_idx,
# Is there another possible transition in this sub-FSM?
any(last_fsm_state == from_s for (from_s, to_s) in transitions),
# Is this sub-FSM in a final state?
state_seq[-1] in finals,
)
for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items()
if state_seq_transitions.issubset(transitions)
)
@numba.njit(cache=True, nogil=True)
def state_scan_tokens(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: Dict[str, List[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()
for token, token_ids in vocabulary.items():
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
fsm_initial,
fsm_finals,
token,
start_state,
False,
)
if state_seq is not None and len(state_seq) < len(token):
continue
for token_id in token_ids:
res.add((token_id, state_seq[-1]))
return res
def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: Dict[str, List[int]],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""
# TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
# code, too.
states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
seen: Set[int] = set()
next_states = {fsm_info.initial}
while next_states:
start_state = next_states.pop()
token_ids_end_states = state_scan_tokens(
fsm_info.transitions,
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
fsm_info.initial,
fsm_info.finals,
vocabulary,
start_state,
)
for token_id_and_end_state in token_ids_end_states:
states_to_token_subsets.setdefault(start_state, set()).add(
token_id_and_end_state
)
end_state = token_id_and_end_state[1]
if end_state not in seen:
next_states.add(end_state)
seen.add(start_state)
return states_to_token_subsets
# TODO: Cannot cache typed collections to disk, yet. See
# https://github.com/numba/numba/issues/4698
@lru_cache
def reduced_vocabulary(tokenizer: "Tokenizer"):
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
vocabulary = numba.typed.Dict.empty(
numba.types.string, numba.types.ListType(numba.int64)
)
empty_token_ids = set()
for token, token_idx in tokenizer.vocabulary.items():
if token in tokenizer.special_tokens:
continue
token_str = tokenizer.convert_token_to_string(token)
if token_str:
vocabulary.setdefault(
token_str,
numba.typed.List.empty_list(numba.int64),
).append(numba.int64(token_idx))
else:
empty_token_ids.add(numba.int64(token_idx))
return vocabulary, empty_token_ids
def create_fsm_index_tokenizer(
fsm: BetterFSM,
tokenizer: "Tokenizer",
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
"""Construct an FMS index from a tokenizer.
This uses the end-to-end approach of `create_fsm_index_end_to_end`.
.. warning::
`fsm` needs to be deterministically ordered so that future caching makes sense.
"""
vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)
states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)
# Allow transitions to EOS from all terminals FSM states that are
# reachable
# TODO: Do we really need this anymore?
for state in fsm.fsm_info.finals:
subset = states_to_token_subsets.get(state)
if subset is not None:
subset.add((tokenizer.eos_token_id, state))
# Convert to token-to-end-state maps
states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}
return states_to_token_subsets, empty_token_ids

View File

@@ -0,0 +1,266 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
Dict,
Hashable,
List,
Optional,
Protocol,
Set,
Tuple,
Union,
)
import numpy as np
import torch
from numpy.typing import NDArray
class Tokenizer(Protocol, Hashable):
eos_token: str
eos_token_id: int
pad_token_id: int
vocabulary: Dict[str, int]
special_tokens: Set[int]
@abstractmethod
def encode(
self, prompt: Union[str, List[str]]
) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
"""Translate the input prompts into NumPy arrays of token ids and attention mask."""
...
@abstractmethod
def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
"""Translate an array of token ids to a string or list of strings."""
...
@abstractmethod
def convert_token_to_string(self, token: str) -> str:
"""Convert a token to its equivalent string.
This is for instance useful for BPE tokenizers where whitespaces are
represented by the special characted `Ġ`. This prevents matching a raw
token that includes `Ġ` with a string.
"""
...
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
__all__ = ["transformers"]
KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
def get_llama_tokenizer_types():
"""Get all the Llama tokenizer types/classes that need work-arounds.
When they can't be imported, a dummy class is created.
"""
try:
from transformers.models.llama import LlamaTokenizer
except ImportError:
class LlamaTokenizer: # type: ignore
pass
try:
from transformers.models.llama import LlamaTokenizerFast
except ImportError:
class LlamaTokenizerFast: # type: ignore
pass
try:
from transformers.models.code_llama import CodeLlamaTokenizer
except ImportError:
class CodeLlamaTokenizer: # type: ignore
pass
try:
from transformers.models.code_llama import CodeLlamaTokenizerFast
except ImportError:
class CodeLlamaTokenizerFast: # type: ignore
pass
return (
LlamaTokenizer,
LlamaTokenizerFast,
CodeLlamaTokenizer,
CodeLlamaTokenizerFast,
)
class Transformer:
"""Represents a `transformers` model."""
def __init__(
self,
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
):
self.device = model.device
self.model = model
self.tokenizer = tokenizer
@torch.inference_mode
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
past_key_values: Optional[Tuple] = None,
) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
"""Compute a forward pass through the transformer model.
Parameters
----------
input_ids
The input token ids. Must be one or two dimensional.
attention_mask
The attention mask. Must be one or two dimensional.
past_key_values
A tuple of tuples containing the cached key and value tensors for each
attention head.
Returns
-------
The computed logits and the new cached key and value tensors.
"""
assert 0 < input_ids.ndim < 3
if past_key_values:
input_ids = input_ids[..., -1].unsqueeze(-1)
output = self.model(
input_ids,
attention_mask=attention_mask,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
past_key_values=past_key_values,
)
return output.logits, output.past_key_values
def __call__(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
past_key_values: Optional[Tuple] = None,
) -> torch.FloatTensor:
logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
next_token_logits = logits[..., -1, :]
return next_token_logits, kv_cache
class TransformerTokenizer(Tokenizer):
"""Represents a tokenizer for models in the `transformers` library."""
def __init__(self, tokenizer):
# TODO: Do something to make this hashable?
self.tokenizer = tokenizer
self.eos_token_id = self.tokenizer.eos_token_id
self.eos_token = self.tokenizer.eos_token
if not self.tokenizer.pad_token_id:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.pad_token_id = self.eos_token_id
else:
self.pad_token_id = self.tokenizer.pad_token_id
self.pad_token = self.tokenizer.pad_token
self.special_tokens = set(self.tokenizer.all_special_tokens)
self.vocabulary = self.tokenizer.get_vocab()
self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())
def encode(
self, prompt: Union[str, List[str]], **kwargs
) -> Tuple[torch.LongTensor, torch.LongTensor]:
kwargs["padding"] = True
kwargs["return_tensors"] = "pt"
output = self.tokenizer(prompt, **kwargs)
return output["input_ids"], output["attention_mask"]
def decode(self, token_ids: torch.LongTensor) -> List[str]:
text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
return text
def convert_token_to_string(self, token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = self.tokenizer.convert_tokens_to_string([token])
if self.is_llama:
# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
def __eq__(self, other):
if isinstance(other, type(self)):
return False
# TODO(lsyin): the lru_cache for the TransoformerTokenizer is useless ?
# return other.model_name == self.model_name and other.kwargs == self.kwargs
return NotImplemented
def __hash__(self):
from datasets.fingerprint import Hasher
return hash(Hasher.hash(self.tokenizer))
def transformers(
model_name: str,
device: Optional[str] = None,
model_kwargs: dict = {},
tokenizer_kwargs: dict = {},
):
"""Instantiate a model from the `transformers` library and its tokenizer.
Parameters
----------
model_name
The name of the model as listed on Hugging Face's model page.
device
The device(s) on which the model should be loaded. This overrides
the `device_map` entry in `model_kwargs` when provided.
model_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the model.
tokenizer_kwargs
A dictionary that contains the keyword arguments to pass to the
`from_pretrained` method when loading the tokenizer.
Returns
-------
A `TransformersModel` model instance.
"""
try:
from transformers import AutoModelForCausalLM
except ImportError:
raise ImportError(
"The `transformers` library needs to be installed in order to use `transformers` models."
)
if device is not None:
model_kwargs["device_map"] = device
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
tokenizer = TransformerTokenizer(model_name, **tokenizer_kwargs)
return Transformer(model, tokenizer)