import outlines (#168)
pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
 [project.optional-dependencies]
 srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
        "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
-       "pydantic", "referencing", "diskcache", "cloudpickle", "pillow"]
+       "pydantic", "referencing", "diskcache", "cloudpickle", "pillow", "outlines>=0.0.27"]
 openai = ["openai>=1.0", "numpy"]
 anthropic = ["anthropic", "numpy"]
 all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
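With this change, installing the srt extra (for example, pip install "sglang[srt]") pulls in outlines>=0.0.27, whose FSM utilities replace the vendored copies deleted below.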
python/sglang/srt/constrained/__init__.py (new file)
@@ -0,0 +1,16 @@
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_object
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer

__all__ = [
    "RegexFSM",
    "FSMInfo",
    "make_deterministic_fsm",
    "build_regex_from_object",
    "TransformerTokenizer",
    "disk_cache",
    "disable_cache",
]
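The shim above re-exports the outlines API under the old sglang.srt.constrained names so existing call sites keep working. A minimal sketch of the re-exported interface (illustrative only; "gpt2" and the pattern are placeholder values, not from this commit):

from sglang.srt.constrained import RegexFSM, TransformerTokenizer

tokenizer = TransformerTokenizer("gpt2")   # wraps a Hugging Face tokenizer
fsm = RegexFSM(r"(yes|no)", tokenizer)     # compile the regex into a token-level FSM

state = 0
allowed = fsm.allowed_token_ids(state)     # ids the sampler may pick next
state = fsm.next_state(state, allowed[0])  # advance as if the first allowed id were sampled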
python/sglang/srt/constrained/disk_cache.py (deleted file)
@@ -1,70 +0,0 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/caching.py
import asyncio
import hashlib
import os
from typing import Callable, Optional

import cloudpickle
from diskcache import Cache

home_dir = os.path.expanduser("~")
cache_dir = os.environ.get("SGLANG_CACHE_DIR", f"{home_dir}/.cache/sglang")
memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
_caching_enabled = True


def hash_arguments(*args, **kwargs) -> str:
    """Create a hash out of the args and kwargs provided."""
    result = hashlib.md5()
    for item in list(args) + sorted(kwargs.items()):
        result.update(cloudpickle.dumps(item))
    return result.hexdigest()


def disk_cache(key_function: Optional[Callable] = None):
    def decorator(cached_function: Callable):
        def wrapper(*args, **kwargs):
            if not _caching_enabled:
                return cached_function(*args, **kwargs)
            if key_function:
                key_args = key_function(*args, **kwargs)
                cache_key = hash_arguments(*key_args)
            else:
                cache_key = hash_arguments(*args, **kwargs)
            if cache_key in memory:
                return memory[cache_key]
            result = cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        async def async_wrapper(*args, **kwargs):
            if not _caching_enabled:
                return await cached_function(*args, **kwargs)
            if key_function:
                key_args = key_function(*args, **kwargs)
                cache_key = hash_arguments(*key_args)
            else:
                cache_key = hash_arguments(*args, **kwargs)
            if cache_key in memory:
                return memory[cache_key]
            result = await cached_function(*args, **kwargs)
            memory[cache_key] = result
            return result

        if asyncio.iscoroutinefunction(cached_function):
            return async_wrapper
        else:
            return wrapper

    return decorator


def disable_cache():
    global _caching_enabled
    _caching_enabled = False


def clear_cache():
    global memory
    memory.clear()
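For context, here is roughly how the deleted decorator was used (a sketch under the pre-commit module path; compile_index is an invented stand-in for an expensive computation):

from sglang.srt.constrained.disk_cache import disk_cache, disable_cache

@disk_cache()
def compile_index(regex_string: str) -> dict:
    # stands in for an expensive FSM-index construction
    return {"pattern": regex_string}

compile_index(r"[0-9]+")  # computed, then stored under ~/.cache/sglang (or SGLANG_CACHE_DIR)
compile_index(r"[0-9]+")  # same arguments: served from the diskcache store
disable_cache()           # later calls bypass the cache entirely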
python/sglang/srt/constrained/fsm.py (deleted file)
@@ -1,358 +0,0 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/6c6966cfa24e9c120494ebb317c6126aa2ae94af/outlines/fsm/fsm.py
from typing import List, NewType, Protocol, Tuple

import interegular
from lark import Lark
from sglang.srt.constrained.disk_cache import disk_cache

# from outlines.fsm.parsing import PartialLark
from sglang.srt.constrained.regex import (
    create_fsm_index_tokenizer,
    make_deterministic_fsm,
)
from sglang.srt.constrained.tokenizer import Tokenizer

FSMState = NewType("FSMState", int)


class FSM(Protocol):
    def allowed_token_ids(self, state: FSMState) -> List[int]:
        ...

    def next_state(self, state: FSMState, token_id: int) -> FSMState:
        ...

    def is_final_state(self, state: FSMState) -> bool:
        ...

    def copy(self) -> "FSM":
        ...


class StopAtTokenFSM(FSM):
    """FSM to generate text until a specified token id is generated or
    a specified number of tokens has been generated.

    Text is usually produced until the EOS token is generated by the
    model.

    """

    def __init__(self, tokenizer: "Tokenizer", stop_token_id: int):
        self.stop_token_id = stop_token_id
        self.vocabulary = tokenizer.vocabulary.values()
        self.final_states = {1}

    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Generate a list of allowed tokens for the next step.

        When in the initial state we allow every token to be generated.
        In the final state the only allowed token is `stop_token_id`.

        Parameters
        ----------
        state
            The current state of the FSM.

        Returns
        -------
        A list that contains the tokens to mask.

        """
        if state == 0:
            return list(self.vocabulary)
        else:
            return [self.stop_token_id]

    def next_state(self, state: FSMState, token_id: int) -> FSMState:
        """Update the state of the FSM.

        The FSM stays in the initial state `0` unless the specified stop token
        has been generated or the maximum number of tokens has been reached, in
        which case the FSM moves to the final state `1`.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.

        Returns
        -------
        The new state of the FSM.

        """
        if token_id == self.stop_token_id:
            return FSMState(1)

        return FSMState(0)

    def is_final_state(self, state: FSMState) -> bool:
        """Determine whether the current state of the FSM is a final state."""
        return state in self.final_states

    def copy(self) -> "StopAtTokenFSM":
        """Create a copy of the FSM."""
        return self


class RegexFSM(FSM):
    """FSM to generate text that is in the language of a regular expression."""

    def __init__(
        self,
        regex_string: str,
        tokenizer: "Tokenizer",
    ):
        @disk_cache()
        def create_states_mapping(
            regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int]]
        ) -> Tuple[dict, set, set]:
            """Create the variables related to the mapping between states and tokens.
            The parameters of the function are used for caching purposes.
            """
            regex_pattern = interegular.parse_pattern(regex_string)
            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
            (
                states_to_token_maps,
                empty_token_ids,
            ) = create_fsm_index_tokenizer(regex_fsm, tokenizer)

            # We make sure that it is possible to generate strings in the language
            # of the regular expression with the tokens present in the model's
            # vocabulary.
            if not any(
                regex_fsm.finals.intersection(v.values())
                for v in states_to_token_maps.values()
            ):
                raise ValueError(
                    "The vocabulary does not allow us to build a sequence that matches the input regex"
                )

            final_states = regex_fsm.finals | {
                -1
            }  # Include the EOS token in final states
            return states_to_token_maps, empty_token_ids, final_states

        (
            self.states_to_token_maps,
            self.empty_token_ids,
            self.final_states,
        ) = create_states_mapping(
            regex_string, tuple(sorted(tokenizer.vocabulary.items()))
        )
        self.num_tokens_generated = 0
        self.vocabulary = tokenizer.vocabulary.values()
        self.end_token_id = tokenizer.eos_token_id

    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Generate a list of allowed tokens for the next step.

        The initialization of the FSM builds an index which maps FSM states to a
        map from authorized tokens to the state in which the FSM needs to move
        if said token is generated. Therefore the authorized tokens at the
        current state are the keys of the map returned by the value of the index
        for the current state.

        If the current state is not contained in the index, this means that we
        are in a final state of the FSM. We only authorize EOS tokens in the
        final state.

        Parameters
        ----------
        state
            The current state of the FSM.

        Returns
        -------
        A list that contains the tokens to mask.

        """
        next_tokens_to_end_states = self.states_to_token_maps.get(state)

        if next_tokens_to_end_states is None:
            return [self.end_token_id]
        else:
            return list(next_tokens_to_end_states.keys())

    def next_state(self, state: FSMState, token_id: int) -> FSMState:
        """Update the state of the FSM.

        We use the index to determine to which state the FSM should transition
        given the token that was just generated.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.

        Returns
        -------
        The new state of the FSM.

        """
        if token_id == self.end_token_id:
            return FSMState(-1)

        last_token_to_end_state = self.states_to_token_maps[state]
        next_state = last_token_to_end_state.get(token_id)
        if next_state is None:
            next_state = -1

        return FSMState(next_state)

    def is_final_state(self, state: FSMState) -> bool:
        """Determine whether the current state of the FSM is a final state."""
        return state in self.final_states

    def copy(self) -> "RegexFSM":
        """Create a copy of the FSM."""
        return self


class CFGFSM(FSM):
    """FSM to generate text that is in the language of a context-free grammar."""

    def __init__(self, cfg_string: str, tokenizer: "Tokenizer"):
        self.cfg_string = cfg_string
        self.tokenizer = tokenizer

        self.parser = Lark(
            cfg_string,
            parser="lalr",
            lexer="contextual",
            propagate_positions=False,
            maybe_placeholders=False,
            regex=True,
        )
        self.terminal_regexps = dict()
        for terminal in self.parser.terminals:
            if terminal.pattern is not None:
                self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp()
        self.terminal_regexps["$END"] = tokenizer.eos_token

        self.generation = ""
        self.reset_state = False
        self.allow_eos = False
        self.done = False
        self.regex_fsm: RegexFSM

    def _set_next_regex_fsm(self) -> None:
        """Use the CFG incremental parser to set the next regex FSM.

        Check what the CFG incremental parser proposes next:
        - If the only proposal is the EOS token, we set the state to done and
          return.
        - If there are other proposals, we set a new regex FSM and return.

        """
        interactive = self.parser.parse_interactive(self.generation)
        interactive.exhaust_lexer()
        options = {self.terminal_regexps[x] for x in interactive.accepts()}

        if self.terminal_regexps["$END"] in options:
            options.remove(self.terminal_regexps["$END"])
            if len(options) == 0:
                self.done = True
                return
            self.allow_eos = True
            options.add("")
            assert len(options) > 1

        regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")"
        self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
        self.reset_state = True

    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Generate a list of allowed tokens for the next step.

        Upon initialization, the CFG incremental parser is used to determine the
        first regex.

        This regex is used for proposals until either:

        - The regex is exhausted, and its only remaining option is the EOS
          token, in which case we always transition to the next regex
        - The regex can be exhausted, but the EOS token is not the only
          remaining option, in which case we transition to the next regex with
          probability P (TODO) or remove the possibility of generating the EOS
          token and continue with the current regex

        The CFG incremental parser is allowed to propose the EOS token from any final state,
        and once it is generated, the FSM will continue to always generate the EOS token.

        Parameters
        ----------
        state
            The current state of the FSM.

        Returns
        -------
        A list that contains the tokens to mask.

        """
        if self.generation != "":
            proposal = self.regex_fsm.allowed_token_ids(state)
            if self.tokenizer.eos_token_id not in proposal:
                return proposal
            if set(proposal) != {self.tokenizer.eos_token_id}:
                if False:  # TODO: THIS NEEDS TO BE SAMPLED
                    proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
                    return proposal

        self._set_next_regex_fsm()

        if self.done:
            return [self.tokenizer.eos_token_id]

        if self.reset_state:
            state = FSMState(0)

        proposal = self.regex_fsm.allowed_token_ids(state)
        if self.allow_eos:
            self.allow_eos = False
        else:
            proposal = [x for x in proposal if x != self.tokenizer.eos_token_id]
            assert len(proposal) > 0
        return proposal

    def next_state(self, state: FSMState, token_id: int) -> FSMState:
        """Update the state of the FSM.

        Transitions the underlying regex FSM to its next state.
        If at max tokens or EOS token, transition permanently to the final state.
        Update stored partial generations for subsequent incremental parsing.

        Parameters
        ----------
        state
            The current state of the FSM.
        token_id
            The id of the token that was just generated.

        Returns
        -------
        The new state of the FSM.
        """
        if token_id == self.tokenizer.eos_token_id:
            self.done = True
            return FSMState(-1)
        if self.reset_state:
            self.reset_state = False
            state = FSMState(0)

        self.generation += self.tokenizer.decode([token_id])[0]

        return self.regex_fsm.next_state(state, token_id)

    def is_final_state(self, state: FSMState) -> bool:
        """Return whether the current state of the FSM is a final state."""
        return self.done

    def copy(self) -> "CFGFSM":
        """Create a copy of the FSM."""
        return CFGFSM(self.cfg_string, self.tokenizer)
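To make the control flow of these classes concrete, here is a sketch of the decoding loop they are designed to gate (sample_from stands in for the model's sampler; this is not code from the commit):

def constrained_decode(fsm, sample_from, max_tokens):
    state, out = 0, []
    for _ in range(max_tokens):
        allowed = fsm.allowed_token_ids(state)  # logits outside this set get masked
        token_id = sample_from(allowed)         # the model chooses among the allowed ids
        out.append(token_id)
        state = fsm.next_state(state, token_id)
        if fsm.is_final_state(state):
            break
    return out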
@@ -1,6 +1,5 @@
+from sglang.srt.constrained import RegexFSM, TransformerTokenizer
 from sglang.srt.constrained.base_cache import BaseCache
-from sglang.srt.constrained.fsm import RegexFSM
-from sglang.srt.constrained.tokenizer import TransformerTokenizer


 class FSMCache(BaseCache):
@@ -1,290 +0,0 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/8a0bafc8d82937babc5d586dd4f72ae844407e0e/outlines/fsm/json_schema.py
import inspect
import json
import re
from typing import Callable, Union

from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
from referencing import Registry, Resource
from referencing._core import Resolver
from referencing.jsonschema import DRAFT202012

STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)'
STRING = f'"{STRING_INNER}*"'
INTEGER = r"(0|[1-9][0-9]*)"
NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"

type_to_regex = {
    "string": STRING,
    "integer": INTEGER,
    "number": NUMBER,
    "boolean": BOOLEAN,
    "null": NULL,
}


def build_regex_from_object(object: Union[str, Callable, BaseModel]):
    """Turn a JSON schema into a regex that matches any JSON object that follows
    this schema.

    JSON Schema is a declarative language that allows one to annotate JSON
    documents with types and descriptions. These schemas can be generated from
    any Python datastructure that has type annotations: namedtuples,
    dataclasses, Pydantic models. By ensuring that the generation respects the
    schema, we ensure that the output can be parsed into these objects.
    This function parses the provided schema and builds a generation schedule
    which mixes deterministic generation (fixed strings) and sampling with
    constraints.

    Parameters
    ----------
    schema
        A string that represents a JSON Schema.

    Returns
    -------
    A generation schedule. A list of strings that represent the JSON
    schema's structure and regular expressions that define the structure of
    the fields.

    References
    ----------
    .. [0] JSON Schema. https://json-schema.org/

    """

    if isinstance(object, type(BaseModel)):
        schema = object.model_json_schema()
    elif callable(object):
        schema = get_schema_from_signature(object)
    else:
        schema = json.loads(object)

    Validator.check_schema(schema)

    # Build reference resolver
    schema = Resource(contents=schema, specification=DRAFT202012)
    uri = schema.id() if schema.id() is not None else ""
    registry = Registry().with_resource(uri=uri, resource=schema)
    resolver = registry.resolver()

    content = schema.contents
    return to_regex(resolver, content)


def to_regex(resolver: Resolver, instance: dict):
    """Translate a JSON Schema instance into a regex that validates the schema.

    Note
    ----
    Many features of JSON schema are missing:
    - Handle `additionalProperties` keyword
    - Handle types defined as a list
    - Handle constraints on numbers
    - Handle special patterns: `date`, `uri`, etc.

    This does not support recursive definitions.

    Parameters
    ----------
    resolver
        An object that resolves references to other instances within a schema
    instance
        The instance to translate
    """
    whitespace = r"[\n ]*"

    if "properties" in instance:
        regex = ""
        regex += r"\{"
        properties = instance["properties"]
        required_properties = instance.get("required", [])
        is_required = [item in required_properties for item in properties]
        # If at least one property is required, we include the one in the latest position
        # without any comma.
        # For each property before it (optional or required), we add a comma after the property.
        # For each property after it (optional), we add a comma before the property.
        if any(is_required):
            last_required_pos = max([i for i, value in enumerate(is_required) if value])
            for i, (name, value) in enumerate(properties.items()):
                subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
                subregex += to_regex(resolver, value)
                if i < last_required_pos:
                    subregex = f"{subregex}{whitespace},"
                elif i > last_required_pos:
                    subregex = f"{whitespace},{subregex}"
                regex += subregex if is_required[i] else f"({subregex})?"
        # If no property is required, we have to create a possible pattern for each property in which
        # it's the last one necessarily present. Then we add the others as optional before and after,
        # following the same strategy as described above.
        # The whole block is made optional to allow the case in which no property is returned.
        else:
            property_subregexes = []
            for i, (name, value) in enumerate(properties.items()):
                subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
                subregex += to_regex(resolver, value)
                property_subregexes.append(subregex)
            possible_patterns = []
            for i in range(len(property_subregexes)):
                pattern = ""
                for subregex in property_subregexes[:i]:
                    pattern += f"({subregex}{whitespace},)?"
                pattern += property_subregexes[i]
                for subregex in property_subregexes[i + 1 :]:
                    pattern += f"({whitespace},{subregex})?"
                possible_patterns.append(pattern)
            regex += f"({'|'.join(possible_patterns)})?"

        regex += f"{whitespace}" + r"\}"

        return regex

    # To validate against allOf, the given data must be valid against all of the
    # given subschemas.
    elif "allOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["allOf"]]
        subregexes_str = [f"{subregex}" for subregex in subregexes]
        return rf"({''.join(subregexes_str)})"

    # To validate against `anyOf`, the given data must be valid against
    # any (one or more) of the given subschemas.
    elif "anyOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["anyOf"]]
        return rf"({'|'.join(subregexes)})"

    # To validate against oneOf, the given data must be valid against exactly
    # one of the given subschemas.
    elif "oneOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["oneOf"]]

        xor_patterns = []
        # JSON schema validation ensured there are no overlapping schemas in oneOf
        for subregex in subregexes:
            other_subregexes = filter(lambda r: r != subregex, subregexes)
            other_subregexes_str = "|".join([f"{s}" for s in other_subregexes])
            negative_lookahead = f"(?!.*({other_subregexes_str}))"
            xor_patterns.append(f"({subregex}){negative_lookahead}")

        return rf"({'|'.join(xor_patterns)})"

    # The enum keyword is used to restrict a value to a fixed set of values. It
    # must be an array with at least one element, where each element is unique.
    elif "enum" in instance:
        choices = []
        for choice in instance["enum"]:
            if type(choice) in [int, float, bool, None]:
                choices.append(re.escape(str(choice)))
            elif type(choice) == str:
                choices.append(f'"{re.escape(choice)}"')

        return f"({'|'.join(choices)})"

    elif "$ref" in instance:
        path = f"{instance['$ref']}"
        instance = resolver.lookup(path).contents
        return to_regex(resolver, instance)

    # The type keyword may either be a string or an array:
    # - If it's a string, it is the name of one of the basic types.
    # - If it is an array, it must be an array of strings, where each string is
    #   the name of one of the basic types, and each element is unique. In this
    #   case, the JSON snippet is valid if it matches any of the given types.
    elif "type" in instance:
        instance_type = instance["type"]
        if instance_type == "string":
            if "maxLength" in instance or "minLength" in instance:
                max_items = instance.get("maxLength", "")
                min_items = instance.get("minLength", "")
                try:
                    if int(max_items) < int(min_items):
                        raise ValueError(
                            "maxLength must be greater than or equal to minLength"
                        )
                except ValueError:
                    pass
                return f'"{STRING_INNER}{{{min_items},{max_items}}}"'
            elif "pattern" in instance:
                pattern = instance["pattern"]
                if pattern[0] == "^" and pattern[-1] == "$":
                    return rf'(^"{pattern[1:-1]}"$)'
                else:
                    return rf'("{pattern}")'
            else:
                return type_to_regex["string"]

        elif instance_type == "number":
            return type_to_regex["number"]

        elif instance_type == "integer":
            return type_to_regex["integer"]

        elif instance_type == "array":
            min_items = instance.get("minItems", "0")
            max_items = instance.get("maxItems", "")
            if min_items == max_items:
                num_repeats = "{" + str(int(min_items) - 1) + "}"
            else:
                num_repeats = "*"

            if "items" in instance:
                items_regex = to_regex(resolver, instance["items"])
                return rf"\[({items_regex})(,({items_regex})){num_repeats}\]"
            else:
                # Here we need to make the choice to exclude generating a list of objects
                # if the specification of the object is not given, even though a JSON
                # object that contains an object here would be valid under the specification.
                types = [
                    {"type": "boolean"},
                    {"type": "null"},
                    {"type": "number"},
                    {"type": "integer"},
                    {"type": "string"},
                ]
                regexes = [to_regex(resolver, t) for t in types]
                return (
                    rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)})){num_repeats}\]"
                )

        elif instance_type == "boolean":
            return type_to_regex["boolean"]

        elif instance_type == "null":
            return type_to_regex["null"]

        elif isinstance(instance_type, list):
            # Here we need to make the choice to exclude generating an object
            # if the specification of the object is not given, even though a JSON
            # object that contains an object here would be valid under the specification.
            regexes = [
                to_regex(resolver, {"type": t}) for t in instance_type if t != "object"
            ]
            return rf"({'|'.join(regexes)})"

    raise NotImplementedError(
        f"""Could not translate the instance {instance} to a
regular expression. Make sure it is valid per the JSON Schema specification. If
it is, please open an issue on the Outlines repository"""
    )


def get_schema_from_signature(fn: Callable) -> str:
    """Turn a function signature into a JSON schema.

    Every JSON object valid to the output JSON Schema can be passed
    to `fn` using the ** unpacking syntax.

    """
    signature = inspect.signature(fn)
    arguments = {}
    for name, arg in signature.parameters.items():
        if arg.annotation == inspect._empty:
            raise ValueError("Each argument must have a type annotation")
        else:
            arguments[name] = (arg.annotation, ...)

    model = create_model("Arguments", **arguments)

    return model.model_json_schema()
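Illustrative usage of the deleted helper (the User model is an invented example): build_regex_from_object turns a Pydantic schema into one regular expression, which can be checked directly with Python's re module:

import re
from pydantic import BaseModel

class User(BaseModel):
    name: str
    age: int

pattern = build_regex_from_object(User)
assert re.fullmatch(pattern, '{ "name": "Ada", "age": 36 }') is not None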
@@ -1,7 +1,6 @@
 import interegular
+from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
 from sglang.srt.constrained.base_cache import BaseCache
-from sglang.srt.constrained.disk_cache import disk_cache
-from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm

 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
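As a sanity check (illustrative, not part of the commit), IP_REGEX rejects out-of-range octets when used with Python's re module:

import re

assert re.fullmatch(IP_REGEX, "192.168.0.1") is not None
assert re.fullmatch(IP_REGEX, "999.168.0.1") is None  # octets are capped at 255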
python/sglang/srt/constrained/regex.py (deleted file)
@@ -1,586 +0,0 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/fsm/regex.py
from collections import namedtuple
from functools import lru_cache
from typing import Dict, Generator, List, Sequence, Set, Tuple

import numba
import numpy as np
from interegular.fsm import FSM, Alphabet, OblivionError, anything_else
from numba.typed.typedobjectutils import _nonoptional
from sglang.srt.constrained.tokenizer import Tokenizer


class BetterAlphabet(Alphabet):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert anything_else in self._symbol_mapping
        self.anything_value = self._symbol_mapping[anything_else]

    def __getitem__(self, item):
        return self._symbol_mapping.get(item, self.anything_value)

    def copy(self):
        return BetterAlphabet(self._symbol_mapping.copy())


class BetterFSM(FSM):
    flat_transition_map: Dict[Tuple[int, int], int]
    trans_key_to_states: Dict[int, List[int]]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if not isinstance(self.alphabet, BetterAlphabet):
            self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping)

        flat_transition_map = {}
        trans_key_to_states = {}
        for from_state, trans_map in self.map.items():
            for trans_key, to_state in trans_map.items():
                flat_transition_map[(from_state, trans_key)] = to_state
                trans_key_to_states.setdefault(trans_key, set()).add(from_state)

        self.__dict__["trans_key_to_states"] = trans_key_to_states
        self.__dict__["flat_transition_map"] = flat_transition_map
        self.__dict__["_fsm_info"] = None

    def copy(self):
        return BetterFSM(
            alphabet=self.alphabet.copy(),
            states=self.states.copy(),
            initial=self.initial,
            finals=self.finals.copy(),
            map=self.map.copy(),
            __no_validation__=True,
        )

    @property
    def fsm_info(self):
        if self._fsm_info is None:
            flat_transition_map_items = np.fromiter(
                ((a[0], a[1], b) for a, b in self.flat_transition_map.items()),
                dtype=np.dtype("i8, i8, i8"),
            )
            trans_key_to_states_items = np.fromiter(
                ((k, z) for k, v in self.trans_key_to_states.items() for z in v),
                dtype=np.dtype("i8, i8"),
            )
            alphabet_symbol_mapping_items = np.fromiter(
                (
                    it
                    for it in self.alphabet._symbol_mapping.items()
                    if it[0] != anything_else
                ),
                dtype=np.dtype("U1, i8"),
            )
            nb_finals = np.fromiter(self.finals, dtype=np.dtype("i8"))
            self.__dict__["_fsm_info"] = create_fsm_info(
                self.initial,
                nb_finals,
                flat_transition_map_items,
                trans_key_to_states_items,
                self.alphabet.anything_value,
                alphabet_symbol_mapping_items,
            )

        return self._fsm_info


nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_1_type = numba.types.UnicodeCharSeq(1)


@numba.njit(cache=True)
def create_fsm_info(
    py_initial,
    py_finals,
    flat_transition_map_items,
    trans_key_to_states_items,
    py_anything_value,
    alphabet_symbol_mapping_items,
):
    trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type)
    for trans_key_and_state in trans_key_to_states_items:
        trans_key_to_states.setdefault(
            trans_key_and_state[0], numba.typed.List.empty_list(numba.int64)
        ).append(trans_key_and_state[1])

    flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64)
    for trans_key_and_state in flat_transition_map_items:
        flat_transition_map[
            (trans_key_and_state[0], trans_key_and_state[1])
        ] = trans_key_and_state[2]

    alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_1_type, numba.int64)
    for symbol_and_trans_key in alphabet_symbol_mapping_items:
        alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]

    initial = numba.int64(py_initial)

    finals = set()
    for final in py_finals:
        finals.add(final)

    anything_value = numba.int64(py_anything_value)

    return FSMInfo(
        initial,
        finals,
        flat_transition_map,
        trans_key_to_states,
        anything_value,
        alphabet_symbol_map,
    )


FSMInfo = namedtuple(
    "FSMInfo",
    [
        "initial",
        "finals",
        "transitions",
        "trans_key_to_states",
        "alphabet_anything_value",
        "alphabet_symbol_mapping",
    ],
)


def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
    """Construct an equivalent FSM with deterministic state labels."""
    old_to_new_trans_keys = {
        trans_key: i
        for i, (trans_key, _) in enumerate(
            sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1]))
        )
    }

    new_symbol_mapping = {
        symbol: old_to_new_trans_keys[trans_key]
        for symbol, trans_key in fsm.alphabet._symbol_mapping.items()
    }

    new_alphabet = BetterAlphabet(new_symbol_mapping)

    new_map = {
        from_state: {
            old_to_new_trans_keys[trans_key]: to_state
            for trans_key, to_state in trans_map.items()
        }
        for from_state, trans_map in fsm.map.items()
    }

    old_to_new_states = {}
    old_to_new_states[fsm.initial] = 0

    i = 0
    seen = {fsm.initial}
    old_state_queue = [fsm.initial]
    while old_state_queue:
        old_state = old_state_queue.pop(-1)
        transitions = new_map[old_state]
        sorted_transitions = sorted(transitions.items(), key=lambda v: v[0])
        for _, old_state in sorted_transitions:
            if old_state not in seen:
                old_state_queue.append(old_state)
                seen.add(old_state)
            if old_state not in old_to_new_states:
                i += 1
                old_to_new_states[old_state] = i

    new_map = dict(
        sorted(
            (
                (
                    old_to_new_states[from_state],
                    dict(
                        sorted(
                            (
                                (trans_key, old_to_new_states[to_state])
                                for trans_key, to_state in trans_map.items()
                            ),
                            key=lambda v: v[0],
                        )
                    ),
                )
                for from_state, trans_map in new_map.items()
            ),
            key=lambda v: v[0],
        )
    )

    new_initial = 0
    new_finals = frozenset(
        sorted(old_to_new_states[old_state] for old_state in fsm.finals)
    )
    new_states = frozenset(sorted(new_map.keys()))

    new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map)

    return new_fsm, old_to_new_states


@numba.njit(nogil=True, cache=True)
def _walk_fsm(
    fsm_transitions: Dict[Tuple[int, int], int],
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    fsm_initial: int,
    fsm_finals: Set[int],
    input_string: str,
    start_state: int,
    full_match: bool = True,
) -> List[int]:
    state = start_state
    accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
    last_final_idx: int = numba.uint64(0)

    for i, symbol in enumerate(input_string):
        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

        new_state = fsm_transitions.get((state, trans_key))

        if new_state is None:
            if not full_match and last_final_idx > 0:
                return accepted_states[:last_final_idx]

            return numba.typed.List.empty_list(numba.int64)

        state = new_state

        if state in fsm_finals:
            last_final_idx = numba.uint64(i + 1)

        accepted_states.append(_nonoptional(state))

    if full_match and last_final_idx - 1 != i:
        return numba.typed.List.empty_list(numba.int64)

    return accepted_states


def walk_fsm(
    fsm: BetterFSM,
    input_string: str,
    start_state: int,
    full_match: bool = True,
) -> List[int]:
    fsm_finals = fsm.finals

    state = start_state
    accepted_states: List[int] = []
    last_final_idx: int = 0

    alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
    alphabet_anything_value = fsm.alphabet.anything_value
    fsm_transitions = fsm.flat_transition_map

    for i, symbol in enumerate(input_string):
        trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

        new_state = fsm_transitions.get((state, trans_key))

        if new_state is None:
            if not full_match and last_final_idx > 0:
                return accepted_states[:last_final_idx]

            return []

        state = new_state

        if state in fsm_finals:
            last_final_idx = i + 1

        accepted_states.append(state)

    if full_match and last_final_idx - 1 != i:
        return []

    return accepted_states


def fsm_union(
    fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
    """Construct an FSM representing the union of the FSMs in `fsms`.

    This is an updated version of `interegular.fsm.FSM.union` made to return an
    extra map of component FSMs to the sets of state transitions that
    correspond to them in the new FSM.

    """

    alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])

    indexed_fsms = tuple(enumerate(fsms))

    initial = {i: fsm.initial for (i, fsm) in indexed_fsms}

    # Dedicated function accepting a "superset" and returning the next
    # "superset" obtained by following this transition in the new FSM
    def follow(current_state, new_transition: int):
        next = {}
        for i, f in indexed_fsms:
            old_transition = new_to_old[i][new_transition]
            if (
                i in current_state
                and current_state[i] in f.map
                and old_transition in f.map[current_state[i]]
            ):
                next[i] = f.map[current_state[i]][old_transition]
        if not next:
            raise OblivionError
        return next

    states = [initial]
    finals: Set[int] = set()
    map: Dict[int, Dict[int, int]] = {}

    # Map component FSMs to their new state-to-state transitions, finals, and a
    # map translating component FSM states to aggregate FSM states
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ] = {}

    i = 0
    while i < len(states):
        state = states[i]

        # Add to the finals of the aggregate FSM whenever we hit a final in a
        # component FSM
        if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
            finals.add(i)

        # Compute the map for this state
        map[i] = {}
        for transition in alphabet.by_transition:
            try:
                next = follow(state, transition)
            except OblivionError:
                # Reached an oblivion state; don't list it
                continue
            else:
                try:
                    # TODO: Seems like this could--and should--be avoided
                    j = states.index(next)
                except ValueError:
                    j = len(states)
                    states.append(next)

                map[i][transition] = j

                for fsm_id, fsm_state in next.items():
                    (
                        fsm_transitions,
                        fsm_finals,
                        fsm_old_to_new,
                    ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
                    old_from = state[fsm_id]
                    old_to = fsm_state
                    fsm_old_to_new.setdefault(old_from, set()).add(i)
                    fsm_old_to_new.setdefault(old_to, set()).add(j)
                    fsm_transitions.add((i, j))
                    if fsm_state in fsms[fsm_id].finals:
                        fsm_finals.add(j)

        i += 1

    fsm = FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=map,
        __no_validation__=True,
    )

    fsm, old_to_new_states = make_deterministic_fsm(fsm)
    _fsms_to_trans_finals = {
        fsm_id: (
            {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
            {old_to_new_states[s] for s in finals},
            {
                old_state: {old_to_new_states[new_state] for new_state in new_states}
                for old_state, new_states in old_to_new.items()
            },
        )
        for fsm_id, (transitions, finals, old_to_new) in sorted(
            fsms_to_trans_finals.items(), key=lambda x: x[0]
        )
    }

    return (
        fsm,
        _fsms_to_trans_finals,
    )


def get_sub_fsms_from_seq(
    state_seq: Sequence[int],
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ],
) -> Generator[Tuple[int, bool, bool], None, None]:
    """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.

    Parameters
    ----------
    state_seq
        A state sequence.
    fsms_to_trans_finals
        A map from FSM indices to tuples containing sets of their state transitions
        and sets of the final/accept states.

    Returns
    -------
    A generator returning tuples containing each sub-FSM index (in the order
    they were union-ed to construct `fsm`) and booleans indicating whether or
    not there is another valid transition from the last state in the sequence
    for the associated sub-FSM (i.e. if the FSM can continue
    accepting/matching) and whether or not the sequence ends in a final state
    of the sub-FSM.
    """
    state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:]))
    last_fsm_state = state_seq[-1]
    yield from (
        (
            # The sub-FSM index
            fsm_idx,
            # Is there another possible transition in this sub-FSM?
            any(last_fsm_state == from_s for (from_s, to_s) in transitions),
            # Is this sub-FSM in a final state?
            state_seq[-1] in finals,
        )
        for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items()
        if state_seq_transitions.issubset(transitions)
    )


@numba.njit(cache=True, nogil=True)
def state_scan_tokens(
    fsm_transitions: Dict[Tuple[int, int], int],
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    fsm_initial: int,
    fsm_finals: Set[int],
    vocabulary: Dict[str, List[int]],
    start_state: int,
) -> Set[Tuple[int, int]]:
    res = set()

    for token, token_ids in vocabulary.items():
        state_seq = _walk_fsm(
            fsm_transitions,
            alphabet_symbol_mapping,
            alphabet_anything_value,
            fsm_initial,
            fsm_finals,
            token,
            start_state,
            False,
        )

        if state_seq is not None and len(state_seq) < len(token):
            continue

        for token_id in token_ids:
            res.add((token_id, state_seq[-1]))

    return res


def create_fsm_index_end_to_end(
    fsm_info: FSMInfo,
    vocabulary: Dict[str, List[int]],
) -> Dict[int, Set[Tuple[int, int]]]:
    """Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""

    # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this
    # code, too.
    states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {}
    seen: Set[int] = set()
    next_states = {fsm_info.initial}

    while next_states:
        start_state = next_states.pop()

        token_ids_end_states = state_scan_tokens(
            fsm_info.transitions,
            fsm_info.alphabet_symbol_mapping,
            fsm_info.alphabet_anything_value,
            fsm_info.initial,
            fsm_info.finals,
            vocabulary,
            start_state,
        )

        for token_id_and_end_state in token_ids_end_states:
            states_to_token_subsets.setdefault(start_state, set()).add(
                token_id_and_end_state
            )
            end_state = token_id_and_end_state[1]
            if end_state not in seen:
                next_states.add(end_state)

        seen.add(start_state)

    return states_to_token_subsets


# TODO: Cannot cache typed collections to disk, yet. See
# https://github.com/numba/numba/issues/4698
@lru_cache
def reduced_vocabulary(tokenizer: "Tokenizer"):
    """Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
    vocabulary = numba.typed.Dict.empty(
        numba.types.string, numba.types.ListType(numba.int64)
    )
    empty_token_ids = set()
    for token, token_idx in tokenizer.vocabulary.items():
        if token in tokenizer.special_tokens:
            continue

        token_str = tokenizer.convert_token_to_string(token)

        if token_str:
            vocabulary.setdefault(
                token_str,
                numba.typed.List.empty_list(numba.int64),
            ).append(numba.int64(token_idx))
        else:
            empty_token_ids.add(numba.int64(token_idx))

    return vocabulary, empty_token_ids


def create_fsm_index_tokenizer(
    fsm: BetterFSM,
    tokenizer: "Tokenizer",
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
    """Construct an FSM index from a tokenizer.

    This uses the end-to-end approach of `create_fsm_index_end_to_end`.

    .. warning::

        `fsm` needs to be deterministically ordered so that future caching makes sense.

    """
    vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)

    states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)

    # Allow transitions to EOS from all terminal FSM states that are
    # reachable
    # TODO: Do we really need this anymore?
    for state in fsm.fsm_info.finals:
        subset = states_to_token_subsets.get(state)
        if subset is not None:
            subset.add((tokenizer.eos_token_id, state))

    # Convert to token-to-end-state maps
    states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}

    return states_to_token_subsets, empty_token_ids
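For orientation, the pipeline these helpers form, end to end (a sketch assuming a Tokenizer instance named tokenizer is available; this mirrors what RegexFSM.__init__ above does internally):

import interegular

pattern = interegular.parse_pattern(r"(yes|no)")
fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce())     # regex -> deterministic FSM
index, empty_ids = create_fsm_index_tokenizer(fsm, tokenizer)  # FSM -> token-level index

# index[state] maps token_id -> next FSM state; a decoder masks to index[state].keys()
allowed_at_start = list(index[0].keys())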
python/sglang/srt/constrained/tokenizer.py (deleted file)
@@ -1,143 +0,0 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/tokenizer.py
# https://github.com/outlines-dev/outlines/blob/0355ab4272a5d7e4d94c4a53a52593f885b81a61/outlines/models/transformers.py
from abc import abstractmethod
from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union

import numpy as np
import torch
from numpy.typing import NDArray


class Tokenizer(Protocol, Hashable):
    eos_token: str
    eos_token_id: int
    pad_token_id: int
    vocabulary: Dict[str, int]
    special_tokens: Set[int]

    @abstractmethod
    def encode(
        self, prompt: Union[str, List[str]]
    ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
        """Translate the input prompts into NumPy arrays of token ids and attention mask."""
        ...

    @abstractmethod
    def decode(self, token_ids: NDArray[np.int64]) -> List[str]:
        """Translate an array of token ids to a string or list of strings."""
        ...

    @abstractmethod
    def convert_token_to_string(self, token: str) -> str:
        """Convert a token to its equivalent string.

        This is for instance useful for BPE tokenizers where whitespaces are
        represented by the special character `Ġ`. This prevents matching a raw
        token that includes `Ġ` with a string.

        """
        ...


def get_llama_tokenizer_types():
    """Get all the Llama tokenizer types/classes that need work-arounds.

    When they can't be imported, a dummy class is created.

    """
    try:
        from transformers.models.llama import LlamaTokenizer
    except ImportError:

        class LlamaTokenizer:  # type: ignore
            pass

    try:
        from transformers.models.llama import LlamaTokenizerFast
    except ImportError:

        class LlamaTokenizerFast:  # type: ignore
            pass

    try:
        from transformers.models.code_llama import CodeLlamaTokenizer
    except ImportError:

        class CodeLlamaTokenizer:  # type: ignore
            pass

    try:
        from transformers.models.code_llama import CodeLlamaTokenizerFast
    except ImportError:

        class CodeLlamaTokenizerFast:  # type: ignore
            pass

    return (
        LlamaTokenizer,
        LlamaTokenizerFast,
        CodeLlamaTokenizer,
        CodeLlamaTokenizerFast,
    )


class TransformerTokenizer(Tokenizer):
    """Represents a tokenizer for models in the `transformers` library."""

    def __init__(self, model_name: str, **kwargs):
        from transformers import AutoTokenizer

        kwargs.setdefault("padding_side", "left")
        self.model_name = model_name
        # TODO: Do something to make this hashable?
        self.kwargs = kwargs
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
        self.eos_token_id = self.tokenizer.eos_token_id
        self.eos_token = self.tokenizer.eos_token

        if not self.tokenizer.pad_token_id:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.pad_token_id = self.eos_token_id
        else:
            self.pad_token_id = self.tokenizer.pad_token_id
            self.pad_token = self.tokenizer.pad_token

        self.special_tokens = set(self.tokenizer.all_special_tokens)

        self.vocabulary = self.tokenizer.get_vocab()
        self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types())

    def encode(
        self, prompt: Union[str, List[str]], **kwargs
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        kwargs["padding"] = True
        kwargs["return_tensors"] = "pt"
        output = self.tokenizer(prompt, **kwargs)
        return output["input_ids"], output["attention_mask"]

    def decode(self, token_ids: torch.LongTensor) -> List[str]:
        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return text

    def convert_token_to_string(self, token: str) -> str:
        from transformers.file_utils import SPIECE_UNDERLINE

        string = self.tokenizer.convert_tokens_to_string([token])

        if self.is_llama:
            # A hack to handle missing spaces in HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

        return string

    def __eq__(self, other):
        if isinstance(other, type(self)):
            return other.model_name == self.model_name and other.kwargs == self.kwargs
        return NotImplemented

    def __hash__(self):
        from datasets.fingerprint import Hasher

        return hash(Hasher.hash(self.tokenizer))
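A small illustration (assuming the "gpt2" tokenizer; not part of the commit) of why convert_token_to_string exists: BPE vocabularies store tokens with marker characters that must be resolved before walking an FSM over plain text:

tok = TransformerTokenizer("gpt2")
print(tok.convert_token_to_string("Ġworld"))  # " world", not "Ġworld"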
@@ -21,7 +21,7 @@ from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import Response, StreamingResponse
 from pydantic import BaseModel
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.constrained.disk_cache import disable_cache
+from sglang.srt.constrained import disable_cache
 from sglang.srt.conversation import (
     Conversation,
     SeparatorStyle,