fast regex decode
Auto-detect constant str path in regex FSM, then extend instead.
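The idea: while decoding under a regex constraint, any FSM state with exactly one outgoing symbol has a forced continuation, so the runtime can extend the request with the entire forced substring in one step instead of sampling it token by token. A minimal standalone sketch of the forced-path walk (the transition map here is hypothetical; the real one is built by FastForwardMap below):

def walk_forced_path(state_to_fast_forward, state):
    # Follow single-symbol transitions until the FSM reaches a branching state.
    chars = []
    while state in state_to_fast_forward:
        symbol, state = state_to_fast_forward[state]
        chars.append(symbol)
    return "".join(chars), state

# Hypothetical map for r"ab(c|d)": states 0 and 1 each have one outgoing symbol.
print(walk_forced_path({0: ("a", 1), 1: ("b", 2)}, 0))  # -> ('ab', 2)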
@@ -91,12 +91,32 @@ def run_program_batch(
     if num_threads == 1:
         rets = []
-        for arguments in batch_arguments:
-            rets.append(
-                run_program(
-                    program, backend, (), arguments, default_sampling_para, False, True
-                )
-            )
+        if progress_bar:
+            for arguments in tqdm.tqdm(batch_arguments):
+                rets.append(
+                    run_program(
+                        program,
+                        backend,
+                        (),
+                        arguments,
+                        default_sampling_para,
+                        False,
+                        True,
+                    )
+                )
+        else:
+            for arguments in batch_arguments:
+                rets.append(
+                    run_program(
+                        program,
+                        backend,
+                        (),
+                        arguments,
+                        default_sampling_para,
+                        False,
+                        True,
+                    )
+                )
     else:
         if progress_bar:
             pbar = tqdm.tqdm(total=len(batch_arguments))
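The hunk above also threads a progress_bar flag through single-threaded batch execution. A hedged usage sketch through the public decorator API (the program body, and the assumption that run_batch forwards the new flag, are illustrative, not part of this diff):

import sglang as sgl

@sgl.function
def qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", max_tokens=32)

states = qa.run_batch(
    [{"question": "What is 1+1?"}, {"question": "Name a prime."}],
    progress_bar=True,  # shows a tqdm bar while the batch runs
)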
python/sglang/srt/constrained/fast_forward.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import interegular
from sglang.srt.constrained.disk_cache import disk_cache
from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


class FastForwardMap:
    def __init__(self, regex_string):
        @disk_cache()
        def _init_state_to_fast_forward(regex_string):
            regex_pattern = interegular.parse_pattern(regex_string)
            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())

            fsm_info: FSMInfo = regex_fsm.fsm_info

            symbol_to_id = fsm_info.alphabet_symbol_mapping
            id_to_symbol = {}
            for symbol, id_ in symbol_to_id.items():
                id_to_symbol.setdefault(id_, []).append(symbol)

            transitions = fsm_info.transitions
            dirty_states = set()
            state_to_fast_forward = {}

            for (state, id_), next_state in transitions.items():
                if state in dirty_states:
                    continue
                if state in state_to_fast_forward:
                    # A second outgoing transition means the state branches.
                    dirty_states.add(state)
                    del state_to_fast_forward[state]
                    continue
                if len(id_to_symbol[id_]) > 1:
                    # Multiple symbols share this transition class, so the
                    # next character is not uniquely determined.
                    dirty_states.add(state)
                    continue

                state_to_fast_forward[state] = (id_to_symbol[id_][0], next_state)

            return state_to_fast_forward

        self.state_to_fast_forward = _init_state_to_fast_forward(regex_string)

    def valid_states(self):
        return self.state_to_fast_forward.keys()

    def fast_forward(self, state):
        if state not in self.state_to_fast_forward:
            return None

        fast_forward_str = ""
        next_state = None
        while state in self.state_to_fast_forward:
            symbol, next_state = self.state_to_fast_forward[state]
            fast_forward_str += symbol
            state = next_state
        return fast_forward_str, next_state


class FastForwardCache:
    def __init__(self):
        self.cache = {}

    def init_fast_forward_map(self, regex_string):
        if regex_string not in self.cache:
            fast_forward_map = FastForwardMap(regex_string)
            self.cache[regex_string] = fast_forward_map
        return self.cache[regex_string]


def test_main():
regex_string = r"The google's DNS sever address is " + IP_REGEX
    fast_forward_map = FastForwardMap(regex_string)
    for state in fast_forward_map.valid_states():
        print(state, f'"{fast_forward_map.fast_forward(state)}"')


if __name__ == "__main__":
    test_main()
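A quick way to poke at the map above (assumes the sglang package is importable; state 0 is the FSM start state, matching regex_fsm_state = 0 elsewhere in this commit):

from sglang.srt.constrained.fast_forward import FastForwardMap

ff_map = FastForwardMap(r"name: (alice|bob)")
res = ff_map.fast_forward(0)
if res is not None:
    fast_forward_str, next_state = res
    # The constant prefix "name: " is forced; decoding resumes at the branch.
    print(repr(fast_forward_str), next_state)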
@@ -1,6 +1,8 @@
 from sglang.srt.constrained.fsm import RegexFSM
 from sglang.srt.constrained.tokenizer import TransformerTokenizer
 
+_enable_memory_cache = True
+
 
 class FSMCache:
     def __init__(self, tokenizer_path, tokenizer_args_dict):
@@ -10,8 +12,10 @@ class FSMCache:
         )
 
     def init_fsm(self, regex):
-        if regex not in self.cache:
-            fsm = RegexFSM(regex, self.outlines_tokenizer)
-            self.cache[regex] = fsm
-        return self.cache[regex]
+        if _enable_memory_cache:
+            if regex not in self.cache:
+                fsm = RegexFSM(regex, self.outlines_tokenizer)
+                self.cache[regex] = fsm
+            return self.cache[regex]
+
+        return RegexFSM(regex, self.outlines_tokenizer)
python/sglang/srt/constrained/json_schema.py (new file, 290 lines)
@@ -0,0 +1,290 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/8a0bafc8d82937babc5d586dd4f72ae844407e0e/outlines/fsm/json_schema.py
import inspect
import json
import re
from typing import Callable, Union

from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
from referencing import Registry, Resource
from referencing._core import Resolver
from referencing.jsonschema import DRAFT202012

STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)'
STRING = f'"{STRING_INNER}*"'
INTEGER = r"(0|[1-9][0-9]*)"
NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"

type_to_regex = {
    "string": STRING,
    "integer": INTEGER,
    "number": NUMBER,
    "boolean": BOOLEAN,
    "null": NULL,
}

def build_regex_from_object(object: Union[str, Callable, BaseModel]):
    """Turn a JSON schema into a regex that matches any JSON object that follows
    this schema.

    JSON Schema is a declarative language that allows one to annotate JSON
    documents with types and descriptions. These schemas can be generated from
    any Python data structure that has type annotations: namedtuples,
    dataclasses, Pydantic models. By ensuring that the generation respects the
    schema, we ensure that the output can be parsed into these objects.
    This function parses the provided schema and builds a generation schedule
    which mixes deterministic generation (fixed strings) and sampling with
    constraints.

    Parameters
    ----------
    object
        A string that represents a JSON Schema, a callable whose signature
        defines the schema, or a Pydantic model.

    Returns
    -------
    A generation schedule. A list of strings that represent the JSON
    schema's structure and regular expressions that define the structure of
    the fields.

    References
    ----------
    .. [0] JSON Schema. https://json-schema.org/

    """
    if isinstance(object, type(BaseModel)):
        schema = object.model_json_schema()
    elif callable(object):
        schema = get_schema_from_signature(object)
    else:
        schema = json.loads(object)

    Validator.check_schema(schema)

    # Build reference resolver
    schema = Resource(contents=schema, specification=DRAFT202012)
    uri = schema.id() if schema.id() is not None else ""
    registry = Registry().with_resource(uri=uri, resource=schema)
    resolver = registry.resolver()

    content = schema.contents
    return to_regex(resolver, content)


def to_regex(resolver: Resolver, instance: dict):
    """Translate a JSON Schema instance into a regex that validates the schema.

    Note
    ----
    Many features of JSON schema are missing:
    - Handle `additionalProperties` keyword
    - Handle types defined as a list
    - Handle constraints on numbers
    - Handle special patterns: `date`, `uri`, etc.

    This does not support recursive definitions.

    Parameters
    ----------
    resolver
        An object that resolves references to other instances within a schema
    instance
        The instance to translate
    """
    whitespace = r"[\n ]*"

    if "properties" in instance:
        regex = ""
        regex += r"\{"
        properties = instance["properties"]
        required_properties = instance.get("required", [])
        is_required = [item in required_properties for item in properties]
        # If at least one property is required, we include the one in the last
        # required position without any comma.
        # For each property before it (optional or required), we add it with a comma after the property.
        # For each property after it (optional), we add it with a comma before the property.
        if any(is_required):
            last_required_pos = max([i for i, value in enumerate(is_required) if value])
            for i, (name, value) in enumerate(properties.items()):
                subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
                subregex += to_regex(resolver, value)
                if i < last_required_pos:
                    subregex = f"{subregex}{whitespace},"
                elif i > last_required_pos:
                    subregex = f"{whitespace},{subregex}"
                regex += subregex if is_required[i] else f"({subregex})?"
        # If no property is required, we have to create a possible pattern for each property in which
        # it is the last one necessarily present. Then, we add the others as optional before and after,
        # following the same strategy as described above.
        # The whole block is made optional to allow the case in which no property is returned.
        else:
            property_subregexes = []
            for i, (name, value) in enumerate(properties.items()):
                subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
                subregex += to_regex(resolver, value)
                property_subregexes.append(subregex)
            possible_patterns = []
            for i in range(len(property_subregexes)):
                pattern = ""
                for subregex in property_subregexes[:i]:
                    pattern += f"({subregex}{whitespace},)?"
                pattern += property_subregexes[i]
                for subregex in property_subregexes[i + 1 :]:
                    pattern += f"({whitespace},{subregex})?"
                possible_patterns.append(pattern)
            regex += f"({'|'.join(possible_patterns)})?"

        regex += f"{whitespace}" + r"\}"

        return regex
    # To validate against allOf, the given data must be valid against all of the
    # given subschemas.
    elif "allOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["allOf"]]
        subregexes_str = [f"{subregex}" for subregex in subregexes]
        return rf"({''.join(subregexes_str)})"

    # To validate against `anyOf`, the given data must be valid against
    # any (one or more) of the given subschemas.
    elif "anyOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["anyOf"]]
        return rf"({'|'.join(subregexes)})"

    # To validate against oneOf, the given data must be valid against exactly
    # one of the given subschemas.
    elif "oneOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["oneOf"]]

        xor_patterns = []
        # json schema validation ensures there are no overlapping schemas in oneOf
        for subregex in subregexes:
            other_subregexes = filter(lambda r: r != subregex, subregexes)
            other_subregexes_str = "|".join([f"{s}" for s in other_subregexes])
            negative_lookahead = f"(?!.*({other_subregexes_str}))"
            xor_patterns.append(f"({subregex}){negative_lookahead}")

        return rf"({'|'.join(xor_patterns)})"

    # The enum keyword is used to restrict a value to a fixed set of values. It
    # must be an array with at least one element, where each element is unique.
    elif "enum" in instance:
        choices = []
        for choice in instance["enum"]:
            if type(choice) in [int, float, bool, type(None)]:
                # json.dumps yields the JSON spellings true/false/null.
                choices.append(re.escape(json.dumps(choice)))
            elif type(choice) == str:
                choices.append(f'"{re.escape(choice)}"')

        return f"({'|'.join(choices)})"

    elif "$ref" in instance:
        path = f"{instance['$ref']}"
        instance = resolver.lookup(path).contents
        return to_regex(resolver, instance)
    # The type keyword may either be a string or an array:
    # - If it's a string, it is the name of one of the basic types.
    # - If it is an array, it must be an array of strings, where each string is
    #   the name of one of the basic types, and each element is unique. In this
    #   case, the JSON snippet is valid if it matches any of the given types.
    elif "type" in instance:
        instance_type = instance["type"]
        if instance_type == "string":
            if "maxLength" in instance or "minLength" in instance:
                max_items = instance.get("maxLength", "")
                min_items = instance.get("minLength", "")
                if max_items != "" and min_items != "":
                    if int(max_items) < int(min_items):
                        raise ValueError(
                            "maxLength must be greater than or equal to minLength"
                        )
                return f'"{STRING_INNER}{{{min_items},{max_items}}}"'
            elif "pattern" in instance:
                pattern = instance["pattern"]
                if pattern[0] == "^" and pattern[-1] == "$":
                    return rf'(^"{pattern[1:-1]}"$)'
                else:
                    return rf'("{pattern}")'
            else:
                return type_to_regex["string"]

        elif instance_type == "number":
            return type_to_regex["number"]

        elif instance_type == "integer":
            return type_to_regex["integer"]

        elif instance_type == "array":
            min_items = instance.get("minItems", "0")
            max_items = instance.get("maxItems", "")
            if min_items == max_items:
                num_repeats = "{" + str(int(min_items) - 1) + "}"
            else:
                num_repeats = "*"

            if "items" in instance:
                items_regex = to_regex(resolver, instance["items"])
                return rf"\[({items_regex})(,({items_regex})){num_repeats}\]"
            else:
                # Here we need to make the choice to exclude generating a list of
                # objects if the specification of the object is not given, even
                # though a JSON array that contains an object here would be valid
                # under the specification.
                types = [
                    {"type": "boolean"},
                    {"type": "null"},
                    {"type": "number"},
                    {"type": "integer"},
                    {"type": "string"},
                ]
                regexes = [to_regex(resolver, t) for t in types]
                return (
                    rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)})){num_repeats}\]"
                )

        elif instance_type == "boolean":
            return type_to_regex["boolean"]

        elif instance_type == "null":
            return type_to_regex["null"]

        elif isinstance(instance_type, list):
            # Here we need to make the choice to exclude generating an object
            # if the specification of the object is not given, even though a JSON
            # object that contains an object here would be valid under the specification.
            regexes = [
                to_regex(resolver, {"type": t}) for t in instance_type if t != "object"
            ]
            return rf"({'|'.join(regexes)})"

    raise NotImplementedError(
        f"""Could not translate the instance {instance} to a
    regular expression. Make sure it is valid against the JSON Schema specification.
    If it is, please open an issue on the Outlines repository."""
    )


def get_schema_from_signature(fn: Callable) -> dict:
    """Turn a function signature into a JSON schema.

    Every JSON object valid against the output JSON Schema can be passed
    to `fn` using the ** unpacking syntax.

    """
    signature = inspect.signature(fn)
    arguments = {}
    for name, arg in signature.parameters.items():
        if arg.annotation == inspect._empty:
            raise ValueError("Each argument must have a type annotation")
        else:
            arguments[name] = (arg.annotation, ...)

    model = create_model("Arguments", **arguments)

    return model.model_json_schema()
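A usage sketch for the adapter above (assumes pydantic v2, matching the model_json_schema() call, and that the module is importable):

import re

from pydantic import BaseModel

from sglang.srt.constrained.json_schema import build_regex_from_object

class User(BaseModel):
    name: str
    age: int

pattern = build_regex_from_object(User)
# Both fields are required, so a well-formed object must match in full.
print(re.fullmatch(pattern, '{ "name": "alice", "age": 30 }') is not None)  # True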
@@ -60,6 +60,8 @@ class DetokenizerManager:
             if first_token.startswith("▁"):
                 output_strs[i] = " " + output_strs[i]
 
+            output_strs[i] = recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
+
         self.send_to_tokenizer.send_pyobj(
             BatchStrOut(
                 recv_obj.rids,
@@ -59,6 +59,7 @@ class GenerateReqInput:
 @dataclass
 class TokenizedGenerateReqInput:
     rid: str
+    input_text: str
     input_ids: List[int]
     pixel_values: List[float]
     image_hash: int
@@ -73,6 +74,7 @@ class TokenizedGenerateReqInput:
 class BatchTokenIDOut:
     rids: List[str]
     output_tokens: List[List[int]]
+    output_and_fast_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
     meta_info: List[Dict]
@@ -23,6 +23,7 @@ class FinishReason(Enum):
 class Req:
     def __init__(self, rid):
         self.rid = rid
+        self.input_text = None
         self.input_ids = []
         self.output_ids = []
         self.pixel_values = None
@@ -48,10 +49,44 @@ class Req:
         # for constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
+        self.fast_forward_map = None
+        self.output_and_fast_forward_str = ""
 
     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
 
+    def tokenize_fast_forward(self, fast_forward_str, next_state):
+        old_output_str = self.tokenizer.decode(self.output_ids)
+        if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"):
+            old_output_str = " " + old_output_str
+        new_input_string = (
+            self.input_text
+            + self.output_and_fast_forward_str
+            + old_output_str
+            + fast_forward_str
+        )
+        new_input_ids = self.tokenizer.encode(new_input_string)
+        fast_forward_tokens_len = (
+            len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
+        )
+        # print("=" * 100)
+        # print(f"Catch fast forward:\n{fast_forward_str}")
+        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
+        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+
+        self.input_ids = new_input_ids
+        self.output_ids = []
+        self.sampling_params.max_new_tokens = max(
+            self.sampling_params.max_new_tokens - fast_forward_tokens_len, 0
+        )
+        self.regex_fsm_state = next_state
+        self.output_and_fast_forward_str = (
+            self.output_and_fast_forward_str + old_output_str + fast_forward_str
+        )
+
+        # print(f"Output and fast forward str:\n{self.output_and_fast_forward_str}")
+        # print("*" * 100)
+
     def check_finished(self):
         if self.finished:
             return
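Why tokenize_fast_forward re-encodes the full string instead of appending token ids for the forced text: the token at the junction can merge with what follows. A hedged illustration with a public tokenizer (gpt2 is a stand-in; the server uses the model's own tokenizer and the exact splits differ):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

old_text = "The answer is tru"
forced = "e. The next question is"
old_ids = tok.encode(old_text)
new_ids = tok.encode(old_text + forced)
# Re-encoding the joined string lets the junction merge into a single token,
# so the jump length is derived by subtraction, as in tokenize_fast_forward.
fast_forward_tokens_len = len(new_ids) - len(old_ids)
print(fast_forward_tokens_len)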
@@ -263,6 +298,8 @@ class Batch:
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
+            req.regex_fsm_state = 0
+
             # TODO: apply more fine-grained retraction
 
             token_indices = self.req_to_token_pool.req_to_token[
@@ -274,6 +311,46 @@
 
         return retracted_reqs
 
+    def check_for_fast_forward(self):
+        fast_forward_reqs = []
+        filter_indices = [i for i in range(len(self.reqs))]
+
+        req_pool_indices_cpu = None
+
+        for i, req in enumerate(self.reqs):
+            if req.fast_forward_map is not None:
+                res = req.fast_forward_map.fast_forward(req.regex_fsm_state)
+                if res is not None:
+                    fast_forward_str, next_state = res
+                    if len(fast_forward_str) <= 1:
+                        continue
+
+                    # insert the old request into tree_cache
+                    token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
+                    req_pool_idx = req_pool_indices_cpu[i]
+                    indices = self.req_to_token_pool.req_to_token[
+                        req_pool_idx, : len(token_ids_in_memory)
+                    ]
+                    prefix_len = self.tree_cache.insert(
+                        token_ids_in_memory, indices.clone()
+                    )
+                    self.token_to_kv_pool.free(indices[:prefix_len])
+                    self.req_to_token_pool.free(req_pool_idx)
+                    self.tree_cache.dec_ref_counter(req.last_node)
+
+                    # fast forward
+                    req.tokenize_fast_forward(fast_forward_str, next_state)
+
+                    fast_forward_reqs.append(req)
+                    filter_indices.remove(i)
+
+        if len(filter_indices) < len(self.reqs):
+            self.filter_batch(filter_indices)
+
+        return fast_forward_reqs
+
     def prepare_for_decode(self, input_ids=None):
         if input_ids is None:
             input_ids = [
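Note the len(fast_forward_str) <= 1 guard above: a one-character jump saves at most one decode step, yet still pays for re-tokenization, a radix-cache insert, and a prefill-style re-extend, so presumably only multi-character forced paths are worth taking.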
@@ -21,6 +21,7 @@ from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.utils import (
     get_exception_traceback,
     get_int_token_logit_bias,
@@ -45,6 +46,7 @@ class ModelRpcServer(rpyc.Service):
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
+        self.no_regex_fast_forward = server_args.no_regex_fast_forward
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -118,6 +120,7 @@ class ModelRpcServer(rpyc.Service):
                 "trust_remote_code": server_args.trust_remote_code,
             },
         )
+        self.fast_forward_cache = FastForwardCache()
 
         # Init new token estimation
         self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
@@ -201,6 +204,7 @@
         recv_req: TokenizedGenerateReqInput,
     ):
         req = Req(recv_req.rid)
+        req.input_text = recv_req.input_text
         req.input_ids = recv_req.input_ids
         req.pixel_values = recv_req.pixel_values
         req.image_size = recv_req.image_size
@@ -223,6 +227,10 @@
         # Init regex fsm
         if req.sampling_params.regex is not None:
             req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
+            if not self.no_regex_fast_forward:
+                req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map(
+                    req.sampling_params.regex
+                )
 
         # Truncate long prompts
         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
@@ -334,11 +342,6 @@
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
-        # Reset regex fsm state before first sampling due to retractions
-        for req in batch.reqs:
-            if req.sampling_params.regex is not None:
-                req.regex_fsm_state = 0
-
         if batch.extend_num_tokens != 0:
             # Forward
             logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
@@ -388,6 +391,13 @@
                 self.min_new_token_ratio,
             )
 
+        if not self.no_regex_fast_forward:
+            # check for fast forward
+            fast_forward_reqs = batch.check_for_fast_forward()
+            self.forward_queue.extend(fast_forward_reqs)
+            if batch.is_empty():
+                return
+
         # Update batch tensors
         self.decode_forward_ct += 1
         batch.prepare_for_decode()
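Taken together, each decode step now peels forced-path requests off the batch before running the model. A simplified sketch of the control flow in the hunk above (method names from this diff, everything else elided):

def forward_decode_step(self, batch):
    if not self.no_regex_fast_forward:
        # Requests sitting on a forced FSM path leave the decode batch and are
        # re-queued as extend (prefill) requests with the forced string applied.
        self.forward_queue.extend(batch.check_for_fast_forward())
        if batch.is_empty():
            return
    # Otherwise proceed with a normal decode step.
    self.decode_forward_ct += 1
    batch.prepare_for_decode()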
@@ -408,6 +418,7 @@
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
         output_tokens = []
+        output_and_fast_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
         output_meta_info = []
@@ -425,6 +436,7 @@
             ):
                 output_rids.append(req.rid)
                 output_tokens.append(req.output_ids)
+                output_and_fast_forward_strs.append(req.output_and_fast_forward_str)
                 output_hit_stop_str.append(req.hit_stop_str)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
@@ -445,6 +457,7 @@
                 BatchTokenIDOut(
                     output_rids,
                     output_tokens,
+                    output_and_fast_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
                     output_meta_info,
@@ -157,6 +157,7 @@ class TokenizerManager:
         )
         tokenized_obj = TokenizedGenerateReqInput(
             rid=rid,
+            input_text=obj.text,
             input_ids=input_ids,
             pixel_values=pixel_values,
             image_hash=image_hash,
@@ -23,6 +23,7 @@ class ServerArgs:
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
+    no_regex_fast_forward: bool = False
 
     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -150,6 +151,11 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
+        parser.add_argument(
+            "--no-regex-fast-forward",
+            action="store_true",
+            help="Disable regex fast forward",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
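A hedged sketch of toggling the feature programmatically (no_regex_fast_forward comes from the hunks above; model_path is assumed to be an existing ServerArgs field):

from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-hf",  # assumed field; any model path works
    no_regex_fast_forward=True,  # opt out of regex fast forward
)
print(server_args.no_regex_fast_forward)

On the command line, the same switch is --no-regex-fast-forward.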