fast regex decode

Auto-detect constant str path in regex FSM, then extend instead.
This commit is contained in:
Liangsheng Yin
2024-01-25 01:16:25 +08:00
committed by GitHub
parent 711d343530
commit 01ee0fbc05
16 changed files with 968 additions and 16 deletions

View File

@@ -91,12 +91,32 @@ def run_program_batch(
if num_threads == 1:
rets = []
for arguments in batch_arguments:
rets.append(
run_program(
program, backend, (), arguments, default_sampling_para, False, True
if progress_bar:
for arguments in tqdm.tqdm(batch_arguments):
rets.append(
run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
else:
for arguments in batch_arguments:
rets.append(
run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
)
else:
if progress_bar:
pbar = tqdm.tqdm(total=len(batch_arguments))

View File

@@ -0,0 +1,78 @@
import interegular
from sglang.srt.constrained.disk_cache import disk_cache
from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
# Matches a dotted-quad IPv4 address (each octet 0-255); used by test_main below.
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
class FastForwardMap:
    """Map regex-FSM states to their longest constant-string continuation.

    A state is "fast-forwardable" when it has exactly one outgoing
    transition and that transition's alphabet class corresponds to exactly
    one concrete symbol.  Chains of such states spell out a constant
    substring of the regex that can be appended without sampling.
    """

    def __init__(self, regex_string):
        @disk_cache()
        def _init_state_to_fast_forward(regex_string):
            # Build the deterministic FSM for the regex.
            parsed = interegular.parse_pattern(regex_string)
            det_fsm, _ = make_deterministic_fsm(parsed.to_fsm().reduce())
            info: FSMInfo = det_fsm.fsm_info

            # Invert the alphabet map: transition-class id -> concrete symbols.
            id_to_symbol = {}
            for symbol, trans_id in info.alphabet_symbol_mapping.items():
                id_to_symbol.setdefault(trans_id, []).append(symbol)

            # A state stays in the map only while it has a single outgoing
            # transition over a single-symbol class; once disqualified it is
            # marked "dirty" and excluded permanently.
            dirty = set()
            forward = {}
            for (src, trans_id), dst in info.transitions.items():
                if src in dirty:
                    continue
                if src in forward:
                    # Second outgoing transition seen: not deterministic.
                    dirty.add(src)
                    del forward[src]
                    continue
                symbols = id_to_symbol[trans_id]
                if len(symbols) > 1:
                    # Transition class matches several symbols: not constant.
                    dirty.add(src)
                    continue
                forward[src] = (symbols[0], dst)
            return forward

        self.state_to_fast_forward = _init_state_to_fast_forward(regex_string)

    def valid_states(self):
        """States from which at least one character can be fast-forwarded."""
        return self.state_to_fast_forward.keys()

    def fast_forward(self, state):
        """Return (constant_str, landing_state) from `state`, or None if N/A."""
        table = self.state_to_fast_forward
        if state not in table:
            return None
        pieces = []
        while state in table:
            symbol, state = table[state]
            pieces.append(symbol)
        return "".join(pieces), state
class FastForwardCache:
    """Memoize FastForwardMap objects, keyed by their regex string."""

    def __init__(self):
        # regex string -> FastForwardMap
        self.cache = {}

    def init_fast_forward_map(self, regex_string):
        """Return the cached map for `regex_string`, building it on first use."""
        try:
            return self.cache[regex_string]
        except KeyError:
            built = FastForwardMap(regex_string)
            self.cache[regex_string] = built
            return built
def test_main():
    """Smoke test: print the fast-forward string reachable from each valid state."""
    regex_string = r"The google's DNS sever address is " + IP_REGEX
    ff_map = FastForwardMap(regex_string)
    for state in ff_map.valid_states():
        print(state, f'"{ff_map.fast_forward(state)}"')


if __name__ == "__main__":
    test_main()

View File

@@ -1,6 +1,8 @@
from sglang.srt.constrained.fsm import RegexFSM
from sglang.srt.constrained.tokenizer import TransformerTokenizer
_enable_memory_cache = True
class FSMCache:
def __init__(self, tokenizer_path, tokenizer_args_dict):
@@ -10,8 +12,10 @@ class FSMCache:
)
def init_fsm(self, regex):
if regex not in self.cache:
fsm = RegexFSM(regex, self.outlines_tokenizer)
self.cache[regex] = fsm
if _enable_memory_cache:
if regex not in self.cache:
fsm = RegexFSM(regex, self.outlines_tokenizer)
self.cache[regex] = fsm
return self.cache[regex]
return self.cache[regex]
return RegexFSM(regex, self.outlines_tokenizer)

View File

@@ -0,0 +1,290 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/8a0bafc8d82937babc5d586dd4f72ae844407e0e/outlines/fsm/json_schema.py
import inspect
import json
import re
from typing import Callable, Union
from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
from referencing import Registry, Resource
from referencing._core import Resolver
from referencing.jsonschema import DRAFT202012
# One JSON string character: anything except a control char, '"' or '\',
# or any backslash escape pair.
STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)'
STRING = f'"{STRING_INNER}*"'
# NOTE(review): INTEGER carries no sign, and NUMBER's exponent requires an
# explicit '+' or '-' (so "-3" as an integer or "1e5" will not match) —
# confirm this mirrors the upstream outlines grammar intentionally.
INTEGER = r"(0|[1-9][0-9]*)"
NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"

# Dispatch table: JSON Schema primitive type name -> regex fragment.
type_to_regex = {
    "string": STRING,
    "integer": INTEGER,
    "number": NUMBER,
    "boolean": BOOLEAN,
    "null": NULL,
}
def build_regex_from_object(object: Union[str, Callable, BaseModel]):
    """Compile a JSON Schema into a regex matching conforming JSON documents.

    The schema may be given three ways: as a JSON Schema string, as a
    type-annotated callable (its signature defines the schema), or as a
    Pydantic ``BaseModel`` subclass.  Generation constrained by the
    resulting regex is guaranteed to parse back into the schema's shape.

    Parameters
    ----------
    object
        A JSON Schema string, a type-annotated callable, or a Pydantic
        model class.

    Returns
    -------
    A regex string describing the structure of valid JSON documents.

    References
    ----------
    .. [0] JSON Schema. https://json-schema.org/
    """
    # Normalize the three accepted inputs down to a plain schema dict.
    if isinstance(object, type(BaseModel)):
        raw_schema = object.model_json_schema()
    elif callable(object):
        raw_schema = get_schema_from_signature(object)
    else:
        raw_schema = json.loads(object)

    Validator.check_schema(raw_schema)

    # Wrap the schema in a registry/resolver so `$ref` lookups work
    # during translation.
    resource = Resource(contents=raw_schema, specification=DRAFT202012)
    resource_id = resource.id()
    base_uri = resource_id if resource_id is not None else ""
    registry = Registry().with_resource(uri=base_uri, resource=resource)
    return to_regex(registry.resolver(), resource.contents)
def to_regex(resolver: Resolver, instance: dict):
    """Translate a JSON Schema instance into a regex that validates the schema.

    Note
    ----
    Many features of JSON schema are missing:
    - Handle `additionalProperties` keyword
    - Handle types defined as a list
    - Handle constraints on numbers
    - Handle special patterns: `date`, `uri`, etc.

    This does not support recursive definitions.

    Parameters
    ----------
    resolver
        An object that resolves references to other instances within a schema
    instance
        The instance to translate

    Raises
    ------
    ValueError
        If a string schema declares maxLength smaller than minLength.
    TypeError
        If an enum member has a type that cannot be rendered as a regex.
    NotImplementedError
        If the instance uses a JSON Schema feature not implemented here.
    """
    whitespace = r"[\n ]*"

    if "properties" in instance:
        regex = ""
        regex += r"\{"
        properties = instance["properties"]
        required_properties = instance.get("required", [])
        is_required = [item in required_properties for item in properties]
        # If at least one property is required, we include the one in the latest position
        # without any comma.
        # For each property before it (optional or required), we add with a comma after the property.
        # For each property after it (optional), we add with a comma before the property.
        if any(is_required):
            last_required_pos = max([i for i, value in enumerate(is_required) if value])
            for i, (name, value) in enumerate(properties.items()):
                subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
                subregex += to_regex(resolver, value)
                if i < last_required_pos:
                    subregex = f"{subregex}{whitespace},"
                elif i > last_required_pos:
                    subregex = f"{whitespace},{subregex}"
                regex += subregex if is_required[i] else f"({subregex})?"
        # If no property is required, we have to create a possible pattern for each property in which
        # it's the last one necessarily present. Then, we add the others as optional before and after
        # following the same strategy as described above.
        # The whole block is made optional to allow the case in which no property is returned.
        else:
            property_subregexes = []
            for i, (name, value) in enumerate(properties.items()):
                subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
                subregex += to_regex(resolver, value)
                property_subregexes.append(subregex)
            possible_patterns = []
            for i in range(len(property_subregexes)):
                pattern = ""
                for subregex in property_subregexes[:i]:
                    pattern += f"({subregex}{whitespace},)?"
                pattern += property_subregexes[i]
                for subregex in property_subregexes[i + 1 :]:
                    pattern += f"({whitespace},{subregex})?"
                possible_patterns.append(pattern)
            regex += f"({'|'.join(possible_patterns)})?"
        regex += f"{whitespace}" + r"\}"
        return regex

    # To validate against allOf, the given data must be valid against all of the
    # given subschemas.
    elif "allOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["allOf"]]
        subregexes_str = [f"{subregex}" for subregex in subregexes]
        return rf"({''.join(subregexes_str)})"

    # To validate against `anyOf`, the given data must be valid against
    # any (one or more) of the given subschemas.
    elif "anyOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["anyOf"]]
        return rf"({'|'.join(subregexes)})"

    # To validate against oneOf, the given data must be valid against exactly
    # one of the given subschemas.
    elif "oneOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["oneOf"]]
        xor_patterns = []
        # json schema validation ensured there is no overlapping schemas in oneOf
        for subregex in subregexes:
            other_subregexes = filter(lambda r: r != subregex, subregexes)
            other_subregexes_str = "|".join([f"{s}" for s in other_subregexes])
            negative_lookahead = f"(?!.*({other_subregexes_str}))"
            xor_patterns.append(f"({subregex}){negative_lookahead}")
        return rf"({'|'.join(xor_patterns)})"

    # The enum keyword is used to restrict a value to a fixed set of values. It
    # must be an array with at least one element, where each element is unique.
    elif "enum" in instance:
        choices = []
        for choice in instance["enum"]:
            if choice is None:
                # FIX: the original tested `type(choice) in [..., None]`,
                # which never matches (type(None) is NoneType, not None), so
                # null enum members were silently dropped from the regex.
                choices.append(NULL)
            elif isinstance(choice, bool):
                # FIX: booleans must render as the JSON literals true/false,
                # not str(True) == "True".  Checked before int/float because
                # bool is a subclass of int.
                choices.append("true" if choice else "false")
            elif type(choice) in [int, float]:
                choices.append(re.escape(str(choice)))
            elif type(choice) == str:
                choices.append(f'"{re.escape(choice)}"')
            else:
                # Previously unsupported members were silently skipped,
                # yielding a regex that could never match them; fail loudly.
                raise TypeError(f"Unsupported enum member: {choice!r}")
        return f"({'|'.join(choices)})"

    elif "$ref" in instance:
        path = f"{instance['$ref']}"
        instance = resolver.lookup(path).contents
        return to_regex(resolver, instance)

    # The type keyword may either be a string or an array:
    # - If it's a string, it is the name of one of the basic types.
    # - If it is an array, it must be an array of strings, where each string is
    #   the name of one of the basic types, and each element is unique. In this
    #   case, the JSON snippet is valid if it matches any of the given types.
    elif "type" in instance:
        instance_type = instance["type"]
        if instance_type == "string":
            if "maxLength" in instance or "minLength" in instance:
                max_items = instance.get("maxLength", "")
                min_items = instance.get("minLength", "")
                # FIX: the original wrapped this comparison in
                # `try: ... raise ValueError(...) except ValueError: pass`,
                # so the error it raised was swallowed by its own handler and
                # the bounds were never actually validated.  Only validate
                # when both bounds are present.
                if max_items != "" and min_items != "":
                    if int(max_items) < int(min_items):
                        raise ValueError(
                            "maxLength must be greater than or equal to minLength"
                        )
                return f'"{STRING_INNER}{{{min_items},{max_items}}}"'
            elif "pattern" in instance:
                pattern = instance["pattern"]
                if pattern[0] == "^" and pattern[-1] == "$":
                    return rf'(^"{pattern[1:-1]}"$)'
                else:
                    return rf'("{pattern}")'
            else:
                return type_to_regex["string"]

        elif instance_type == "number":
            return type_to_regex["number"]

        elif instance_type == "integer":
            return type_to_regex["integer"]

        elif instance_type == "array":
            min_items = instance.get("minItems", "0")
            max_items = instance.get("maxItems", "")
            if min_items == max_items:
                if int(min_items) == 0:
                    # FIX: the original emitted the invalid quantifier `{-1}`
                    # when minItems == maxItems == 0; an exact-zero-length
                    # array is just the empty brackets.
                    return r"\[\]"
                num_repeats = "{" + str(int(min_items) - 1) + "}"
            else:
                num_repeats = "*"

            if "items" in instance:
                items_regex = to_regex(resolver, instance["items"])
                return rf"\[({items_regex})(,({items_regex})){num_repeats}\]"
            else:
                # Here we need to make the choice to exclude generating list of objects
                # if the specification of the object is not given, even though a JSON
                # object that contains an object here would be valid under the specification.
                types = [
                    {"type": "boolean"},
                    {"type": "null"},
                    {"type": "number"},
                    {"type": "integer"},
                    {"type": "string"},
                ]
                regexes = [to_regex(resolver, t) for t in types]
                return (
                    rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)})){num_repeats}\]"
                )

        elif instance_type == "boolean":
            return type_to_regex["boolean"]

        elif instance_type == "null":
            return type_to_regex["null"]

        elif isinstance(instance_type, list):
            # Here we need to make the choice to exclude generating an object
            # if the specification of the object is not given, even though a JSON
            # object that contains an object here would be valid under the specification.
            regexes = [
                to_regex(resolver, {"type": t}) for t in instance_type if t != "object"
            ]
            return rf"({'|'.join(regexes)})"

    raise NotImplementedError(
        f"""Could not translate the instance {instance} to a
regular expression. Make sure it is valid to the JSON Schema specification. If
it is, please open an issue on the Outlines repository"""
    )
def get_schema_from_signature(fn: Callable) -> str:
    """Turn a function signature into a JSON schema.

    Every JSON object valid to the output JSON Schema can be passed
    to `fn` using the ** unpacking syntax.

    Raises ValueError if any parameter is missing a type annotation.
    """
    params = inspect.signature(fn).parameters
    fields = {}
    for name, param in params.items():
        if param.annotation == inspect._empty:
            raise ValueError("Each argument must have a type annotation")
        # (annotation, ...) marks the field as required for pydantic.
        fields[name] = (param.annotation, ...)
    return create_model("Arguments", **fields).model_json_schema()

View File

@@ -60,6 +60,8 @@ class DetokenizerManager:
if first_token.startswith(""):
output_strs[i] = " " + output_strs[i]
output_strs[i] = recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
self.send_to_tokenizer.send_pyobj(
BatchStrOut(
recv_obj.rids,

View File

@@ -59,6 +59,7 @@ class GenerateReqInput:
@dataclass
class TokenizedGenerateReqInput:
rid: str
input_text: str
input_ids: List[int]
pixel_values: List[float]
image_hash: int
@@ -73,6 +74,7 @@ class TokenizedGenerateReqInput:
class BatchTokenIDOut:
rids: List[str]
output_tokens: List[List[int]]
output_and_fast_forward_strs: List[str]
hit_stop_str: List[Optional[str]]
skip_special_tokens: List[bool]
meta_info: List[Dict]

View File

@@ -23,6 +23,7 @@ class FinishReason(Enum):
class Req:
def __init__(self, rid):
self.rid = rid
self.input_text = None
self.input_ids = []
self.output_ids = []
self.pixel_values = None
@@ -48,10 +49,44 @@ class Req:
# for constrained decoding
self.regex_fsm = None
self.regex_fsm_state = 0
self.fast_forward_map = None
self.output_and_fast_forward_str = ""
def max_new_tokens(self):
    """Return the remaining number of new tokens this request may generate."""
    sampling_params = self.sampling_params
    return sampling_params.max_new_tokens
def tokenize_fast_forward(self, fast_forward_str, next_state):
    """Splice a constant (fast-forward) string into this request.

    Re-tokenizes prompt + all text produced so far + the fast-forward
    string as a single new prompt, so the request can be re-run as an
    extend step with the regex FSM positioned at `next_state`.
    """
    # Text decoded from the tokens generated in the current round.
    old_output_str = self.tokenizer.decode(self.output_ids)
    # NOTE(review): startswith("") is always True, so the leading space is
    # added unconditionally — the empty argument looks like a special
    # character lost in rendering (presumably the sentencepiece "▁"
    # word-boundary marker); confirm against the original repository.
    if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith(""):
        old_output_str = " " + old_output_str
    new_input_string = (
        self.input_text
        + self.output_and_fast_forward_str
        + old_output_str
        + fast_forward_str
    )
    new_input_ids = self.tokenizer.encode(new_input_string)
    # Net token count attributed to the fast-forward text; re-tokenization
    # may merge tokens across the old prompt/output boundary, so this is a
    # difference of lengths rather than len(encode(fast_forward_str)).
    fast_forward_tokens_len = (
        len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
    )
    # print("=" * 100)
    # print(f"Catch fast forward:\n{fast_forward_str}")
    # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
    # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
    self.input_ids = new_input_ids
    self.output_ids = []
    # Charge the consumed tokens against the generation budget (floor at 0).
    self.sampling_params.max_new_tokens = max(
        self.sampling_params.max_new_tokens - fast_forward_tokens_len, 0
    )
    self.regex_fsm_state = next_state
    # Accumulate all text emitted outside the current output_ids so the
    # detokenizer can prepend it to text decoded later.
    self.output_and_fast_forward_str = (
        self.output_and_fast_forward_str + old_output_str + fast_forward_str
    )
    # print(f"Output and fast forward str:\n{self.output_and_fast_forward_str}")
    # print("*" * 100)
def check_finished(self):
if self.finished:
return
@@ -263,6 +298,8 @@ class Batch:
req.last_node = None
req.extend_input_len = 0
req.output_ids = []
req.regex_fsm_state = 0
# TODO: apply more fine-grained retraction
token_indices = self.req_to_token_pool.req_to_token[
@@ -274,6 +311,46 @@ class Batch:
return retracted_reqs
def check_for_fast_forward(self):
    """Pull requests whose regex FSM sits on a constant-string path out of
    the decoding batch so they can be re-extended with that string.

    Returns the list of fast-forwarded requests; those requests are also
    removed from this batch in place via filter_batch.
    """
    fast_forward_reqs = []
    # Indices of requests that remain in the batch after this pass.
    filter_indices = [i for i in range(len(self.reqs))]
    req_pool_indices_cpu = None
    for i, req in enumerate(self.reqs):
        if req.fast_forward_map is not None:
            res = req.fast_forward_map.fast_forward(req.regex_fsm_state)
            if res is not None:
                fast_forward_str, next_state = res
                # A one-character jump is not worth the cost of a re-extend.
                if len(fast_forward_str) <= 1:
                    continue
                # insert the old request into tree_cache
                # NOTE(review): the last token is excluded — presumably its
                # KV entry is not yet written at this point; confirm.
                token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
                # Device->host copy is done lazily, only once per batch.
                if req_pool_indices_cpu is None:
                    req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
                req_pool_idx = req_pool_indices_cpu[i]
                indices = self.req_to_token_pool.req_to_token[
                    req_pool_idx, : len(token_ids_in_memory)
                ]
                prefix_len = self.tree_cache.insert(
                    token_ids_in_memory, indices.clone()
                )
                # The prefix already held by the radix cache keeps its own
                # references; release this request's duplicate KV slots and
                # its request-pool slot.
                self.token_to_kv_pool.free(indices[:prefix_len])
                self.req_to_token_pool.free(req_pool_idx)
                self.tree_cache.dec_ref_counter(req.last_node)

                # fast forward
                req.tokenize_fast_forward(fast_forward_str, next_state)
                fast_forward_reqs.append(req)
                filter_indices.remove(i)
    if len(filter_indices) < len(self.reqs):
        self.filter_batch(filter_indices)
    return fast_forward_reqs
def prepare_for_decode(self, input_ids=None):
if input_ids is None:
input_ids = [

View File

@@ -21,6 +21,7 @@ from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.constrained.fast_forward import FastForwardCache
from sglang.srt.utils import (
get_exception_traceback,
get_int_token_logit_bias,
@@ -45,6 +46,7 @@ class ModelRpcServer(rpyc.Service):
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.no_regex_fast_forward = server_args.no_regex_fast_forward
# Init model and tokenizer
self.model_config = ModelConfig(
@@ -118,6 +120,7 @@ class ModelRpcServer(rpyc.Service):
"trust_remote_code": server_args.trust_remote_code,
},
)
self.fast_forward_cache = FastForwardCache()
# Init new token estimation
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
@@ -201,6 +204,7 @@ class ModelRpcServer(rpyc.Service):
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid)
req.input_text = recv_req.input_text
req.input_ids = recv_req.input_ids
req.pixel_values = recv_req.pixel_values
req.image_size = recv_req.image_size
@@ -223,6 +227,10 @@ class ModelRpcServer(rpyc.Service):
# Init regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
if not self.no_regex_fast_forward:
req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map(
req.sampling_params.regex
)
# Truncate long prompts
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
@@ -334,11 +342,6 @@ class ModelRpcServer(rpyc.Service):
self.model_config.vocab_size, self.int_token_logit_bias
)
# Reset regex fsm state before first sampling due to retractions
for req in batch.reqs:
if req.sampling_params.regex is not None:
req.regex_fsm_state = 0
if batch.extend_num_tokens != 0:
# Forward
logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
@@ -388,6 +391,13 @@ class ModelRpcServer(rpyc.Service):
self.min_new_token_ratio,
)
if not self.no_regex_fast_forward:
# check for fast forward
fast_forward_reqs = batch.check_for_fast_forward()
self.forward_queue.extend(fast_forward_reqs)
if batch.is_empty():
return
# Update batch tensors
self.decode_forward_ct += 1
batch.prepare_for_decode()
@@ -408,6 +418,7 @@ class ModelRpcServer(rpyc.Service):
def handle_finished_requests(self, batch: Batch):
output_rids = []
output_tokens = []
output_and_fast_forward_strs = []
output_hit_stop_str = []
output_skip_special_tokens = []
output_meta_info = []
@@ -425,6 +436,7 @@ class ModelRpcServer(rpyc.Service):
):
output_rids.append(req.rid)
output_tokens.append(req.output_ids)
output_and_fast_forward_strs.append(req.output_and_fast_forward_str)
output_hit_stop_str.append(req.hit_stop_str)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
@@ -445,6 +457,7 @@ class ModelRpcServer(rpyc.Service):
BatchTokenIDOut(
output_rids,
output_tokens,
output_and_fast_forward_strs,
output_hit_stop_str,
output_skip_special_tokens,
output_meta_info,

View File

@@ -157,6 +157,7 @@ class TokenizerManager:
)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
input_text=obj.text,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,

View File

@@ -23,6 +23,7 @@ class ServerArgs:
disable_log_stats: bool = False
log_stats_interval: int = 10
log_level: str = "info"
no_regex_fast_forward: bool = False
def __post_init__(self):
if self.tokenizer_path is None:
@@ -150,6 +151,11 @@ class ServerArgs:
default=ServerArgs.log_stats_interval,
help="Log stats interval in second.",
)
parser.add_argument(
"--no-regex-fast-forward",
action="store_true",
help="Disable regex fast forward",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):