From 01ee0fbc051f4e177ad917ef90ab26904c7d6cab Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Thu, 25 Jan 2024 01:16:25 +0800 Subject: [PATCH] fast regex decode Auto-detect constant str path in regex FSM, then extend instead. --- benchmark/json_fast_forward/README.md | 46 +++ benchmark/json_fast_forward/bench_other.py | 135 ++++++++ benchmark/json_fast_forward/bench_sglang.py | 92 ++++++ benchmark/json_fast_forward/dataset.txt | 50 +++ python/sglang/lang/interpreter.py | 30 +- python/sglang/srt/constrained/fast_forward.py | 78 +++++ python/sglang/srt/constrained/fsm_cache.py | 12 +- python/sglang/srt/constrained/json_schema.py | 290 ++++++++++++++++++ .../srt/managers/detokenizer_manager.py | 2 + python/sglang/srt/managers/io_struct.py | 2 + .../sglang/srt/managers/router/infer_batch.py | 77 +++++ .../sglang/srt/managers/router/model_rpc.py | 23 +- .../sglang/srt/managers/tokenizer_manager.py | 1 + python/sglang/srt/server_args.py | 6 + test/srt/test_fast_forward.py | 137 +++++++++ test/srt/test_robust.py | 3 +- 16 files changed, 968 insertions(+), 16 deletions(-) create mode 100644 benchmark/json_fast_forward/README.md create mode 100644 benchmark/json_fast_forward/bench_other.py create mode 100644 benchmark/json_fast_forward/bench_sglang.py create mode 100644 benchmark/json_fast_forward/dataset.txt create mode 100644 python/sglang/srt/constrained/fast_forward.py create mode 100644 python/sglang/srt/constrained/json_schema.py create mode 100644 test/srt/test_fast_forward.py diff --git a/benchmark/json_fast_forward/README.md b/benchmark/json_fast_forward/README.md new file mode 100644 index 000000000..30643d745 --- /dev/null +++ b/benchmark/json_fast_forward/README.md @@ -0,0 +1,46 @@ +## Run benchmark + +### Dependencies + +``` +llama_cpp_python 0.2.32 +guidance 0.1.10 +vllm 0.2.7 +outlines 0.0.24 +``` + +### Benchmark sglang + +Run Llama-7B + +``` +python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +Benchmark + +``` +python3 bench_sglang.py +``` + +### Benchmark vllm + +Run Llama-7B + +``` +python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +Benchmark + +``` +python3 bench_other.py --backend vllm +``` + +### Benchmark guidance (seems not supported) + +Run Llama-7B and benchmark + +``` +python3 bench_other.py --backend guidance --parallel 1 +``` diff --git a/benchmark/json_fast_forward/bench_other.py b/benchmark/json_fast_forward/bench_other.py new file mode 100644 index 000000000..7db0e2d21 --- /dev/null +++ b/benchmark/json_fast_forward/bench_other.py @@ -0,0 +1,135 @@ +import argparse +import json +import time +from concurrent.futures import ThreadPoolExecutor +from functools import partial + +import guidance +from sglang.test.test_utils import ( + add_common_other_args_and_parse, + call_generate_outlines, +) +from sglang.utils import dump_state_text +from tqdm import tqdm + +# there are some FSM bugs with json regex converted from pydantic model +# here use a string regex instead +# regex_string = build_regex_from_object(HarryPoterRole) +character_regex = ( + r"""\{\n""" + + r""" "name": "[\w\d\s]{1,16}",\n""" + + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n""" + + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n""" + + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n""" + + r""" "wand": \{\n""" + + r""" "wood": "[\w\d\s]{1,16}",\n""" + + r""" "core": "[\w\d\s]{1,16}",\n""" + + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n""" + + r""" \},\n""" + + r""" "alive": "(Alive|Deceased)",\n""" + + r""" "patronus": "[\w\d\s]{1,16}",\n""" + + r""" "bogart": "[\w\d\s]{1,16}"\n""" + + r"""\}""" +) + +# fmt: off +def character_gen(name, generate): + s = name+ " is a character in Harry Potter. Please fill in the following information about him/her.\n" + s += generate(s, max_tokens=256, regex=character_regex) + return s +# fmt: on + + +@guidance +def character_maker(lm, name): + regex_str_no_quote = r"[\w\d\s]+" + regex_float = r"[0-9]+\.[0-9]+" + lm += f"""\ + {name} is a character in Harry Potter. Please fill in the following information about him/her. + {{ + "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}", + "house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}", + "blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}", + "occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}", + "wand": {{ + "wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}", + "core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}", + "length": {guidance.gen('length', max_tokens=10, regex=regex_float)} + }}, + "alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}", + "patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}", + "bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}" + }} + """ + + return lm + + +def main(args): + arguments = [] + with open(args.data_path, "r") as f: + for line in f: + arguments.append({"name": line.strip()}) + arguments = arguments[: args.num_jsons] + + states = [None] * len(arguments) + + # Select backend + if args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_outlines, url=url, temperature=0) + + def func(i): + states[i] = character_gen(**arguments[i], generate=generate) + + get_one_answer = func + elif args.backend == "guidance": + model = guidance.models.LlamaCpp( + "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf", + n_gpu_layers=-1, + n_ctx=4096, + ) + + def func(i): + lm = model + character_maker(**arguments[i]) + states[i] = lm + + get_one_answer = func + else: + raise ValueError(f"Invalid backend: {args.backend}") + + tic = time.time() + if args.parallel == 1: + for i in tqdm(range(len(arguments))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + rets = executor.map(get_one_answer, list(range(len(arguments)))) + for _ in rets: + pass + + latency = time.time() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "json_fast_forward", + "backend": args.backend, + "latency": round(latency, 3), + "num_jsons": args.num_jsons, + "parallel": args.parallel, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="dataset.txt") + parser.add_argument("--num-jsons", type=int, default=50) + args = add_common_other_args_and_parse(parser) + main(args) diff --git a/benchmark/json_fast_forward/bench_sglang.py b/benchmark/json_fast_forward/bench_sglang.py new file mode 100644 index 000000000..6f8c94f17 --- /dev/null +++ b/benchmark/json_fast_forward/bench_sglang.py @@ -0,0 +1,92 @@ +import argparse +import json +import time + +import sglang as sgl +from sglang.test.test_utils import ( + add_common_sglang_args_and_parse, + select_sglang_backend, +) +from sglang.utils import dump_state_text + +# there are some FSM bugs with json regex converted from pydantic model +# here use a string regex instead +# regex_string = build_regex_from_object(HarryPoterRole) +character_regex = ( + r"""\{\n""" + + r""" "name": "[\w\d\s]{1,16}",\n""" + + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n""" + + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n""" + + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n""" + + r""" "wand": \{\n""" + + r""" "wood": "[\w\d\s]{1,16}",\n""" + + r""" "core": "[\w\d\s]{1,16}",\n""" + + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n""" + + r""" \},\n""" + + r""" "alive": "(Alive|Deceased)",\n""" + + r""" "patronus": "[\w\d\s]{1,16}",\n""" + + r""" "bogart": "[\w\d\s]{1,16}"\n""" + + r"""\}""" +) + +# fmt: off +@sgl.function +def character_gen(s, name): + s += name+ " is a character in Harry Potter. Please fill in the following information about him/her.\n" + s += sgl.gen("json_output", max_tokens=256, regex=character_regex) +# fmt: on + + +def bench_character(args): + arguments = [] + with open(args.data_path, "r") as f: + for line in f: + arguments.append({"name": line.strip()}) + arguments = arguments[: args.num_jsons] + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + # Run requests + tic = time.time() + states = character_gen.run_batch( + arguments, + temperature=0, + num_threads=args.parallel, + progress_bar=(args.parallel == 1), + ) + latency = time.time() - tic + + return states, latency + + +def main(args): + states, latency = bench_character(args) + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + with open(f"{args.backend}.json", "w") as fout: + for state in states: + fout.write(state["json_output"] + "\n") + + with open(args.result_file, "a") as fout: + value = { + "task": "json_fast_forward", + "backend": args.backend, + "latency": round(latency, 3), + "num_jsons": args.num_jsons, + "parallel": args.parallel, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="dataset.txt") + parser.add_argument("--num-jsons", type=int, default=50) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/json_fast_forward/dataset.txt b/benchmark/json_fast_forward/dataset.txt new file mode 100644 index 000000000..f3c8b5f60 --- /dev/null +++ b/benchmark/json_fast_forward/dataset.txt @@ -0,0 +1,50 @@ +Harry Potter +Hermione Granger +Ron Weasley +Albus Dumbledore +Severus Snape +Rubeus Hagrid +Draco Malfoy +Ginny Weasley +Fred Weasley +George Weasley +Percy Weasley +Sirius Black +Remus Lupin +Neville Longbottom +Luna Lovegood +Cedric Diggory +Cho Chang +Lord Voldemort +Minerva McGonagall +Filius Flitwick +Dolores Umbridge +Bellatrix Lestrange +Lucius Malfoy +Molly Weasley +Arthur Weasley +Nymphadora Tonks +Dobby +Moaning Myrtle +Peter Pettigrew +Alastor 'Mad-Eye' Moody +Horace Slughorn +Vernon Dursley +Petunia Dursley +Dudley Dursley +Argus Filch +Sybill Trelawney +Gilderoy Lockhart +Fleur Delacour +Viktor Krum +Bill Weasley +Oliver Wood +Cornelius Fudge +Barty Crouch Sr. +Barty Crouch Jr. +Kingsley Shacklebolt +Quirinus Quirrell +Nearly Headless Nick +Aunt Marge +Griphook +Ludo Bagman \ No newline at end of file diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 5949134a6..144a4e79f 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -91,12 +91,32 @@ def run_program_batch( if num_threads == 1: rets = [] - for arguments in batch_arguments: - rets.append( - run_program( - program, backend, (), arguments, default_sampling_para, False, True + if progress_bar: + for arguments in tqdm.tqdm(batch_arguments): + rets.append( + run_program( + program, + backend, + (), + arguments, + default_sampling_para, + False, + True, + ) + ) + else: + for arguments in batch_arguments: + rets.append( + run_program( + program, + backend, + (), + arguments, + default_sampling_para, + False, + True, + ) ) - ) else: if progress_bar: pbar = tqdm.tqdm(total=len(batch_arguments)) diff --git a/python/sglang/srt/constrained/fast_forward.py b/python/sglang/srt/constrained/fast_forward.py new file mode 100644 index 000000000..49ac33ea5 --- /dev/null +++ b/python/sglang/srt/constrained/fast_forward.py @@ -0,0 +1,78 @@ +import interegular +from sglang.srt.constrained.disk_cache import disk_cache +from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm + +IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" + + +class FastForwardMap: + def __init__(self, regex_string): + @disk_cache() + def _init_state_to_fast_forward(regex_string): + regex_pattern = interegular.parse_pattern(regex_string) + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + + fsm_info: FSMInfo = regex_fsm.fsm_info + + symbol_to_id = fsm_info.alphabet_symbol_mapping + id_to_symbol = {} + for symbol, id_ in symbol_to_id.items(): + id_to_symbol.setdefault(id_, []).append(symbol) + + transitions = fsm_info.transitions + dirty_states = set() + state_to_fast_forward = {} + + for (state, id_), next_state in transitions.items(): + if state in dirty_states: + continue + if state in state_to_fast_forward: + dirty_states.add(state) + del state_to_fast_forward[state] + continue + if len(id_to_symbol[id_]) > 1: + dirty_states.add(state) + continue + + state_to_fast_forward[state] = (id_to_symbol[id_][0], next_state) + + return state_to_fast_forward + + self.state_to_fast_forward = _init_state_to_fast_forward(regex_string) + + def valid_states(self): + return self.state_to_fast_forward.keys() + + def fast_forward(self, state): + if state not in self.state_to_fast_forward: + return None + + fast_forward_str = "" + next_state = None + while state in self.state_to_fast_forward: + symbol, next_state = self.state_to_fast_forward[state] + fast_forward_str += symbol + state = next_state + return fast_forward_str, next_state + + +class FastForwardCache: + def __init__(self): + self.cache = {} + + def init_fast_forward_map(self, regex_string): + if regex_string not in self.cache: + fast_forward_map = FastForwardMap(regex_string) + self.cache[regex_string] = fast_forward_map + return self.cache[regex_string] + + +def test_main(): + regex_string = r"The google's DNS sever address is " + IP_REGEX + fast_forward_map = FastForwardMap(regex_string) + for state in fast_forward_map.valid_states(): + print(state, f'"{fast_forward_map.fast_forward(state)}"') + + +if __name__ == "__main__": + test_main() diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index 00be13c8f..7bd815eb1 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -1,6 +1,8 @@ from sglang.srt.constrained.fsm import RegexFSM from sglang.srt.constrained.tokenizer import TransformerTokenizer +_enable_memory_cache = True + class FSMCache: def __init__(self, tokenizer_path, tokenizer_args_dict): @@ -10,8 +12,10 @@ class FSMCache: ) def init_fsm(self, regex): - if regex not in self.cache: - fsm = RegexFSM(regex, self.outlines_tokenizer) - self.cache[regex] = fsm + if _enable_memory_cache: + if regex not in self.cache: + fsm = RegexFSM(regex, self.outlines_tokenizer) + self.cache[regex] = fsm + return self.cache[regex] - return self.cache[regex] + return RegexFSM(regex, self.outlines_tokenizer) diff --git a/python/sglang/srt/constrained/json_schema.py b/python/sglang/srt/constrained/json_schema.py new file mode 100644 index 000000000..9bb09c1eb --- /dev/null +++ b/python/sglang/srt/constrained/json_schema.py @@ -0,0 +1,290 @@ +# Adapted from: +# https://github.com/outlines-dev/outlines/blob/8a0bafc8d82937babc5d586dd4f72ae844407e0e/outlines/fsm/json_schema.py +import inspect +import json +import re +from typing import Callable, Union + +from jsonschema.protocols import Validator +from pydantic import BaseModel, create_model +from referencing import Registry, Resource +from referencing._core import Resolver +from referencing.jsonschema import DRAFT202012 + +STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)' +STRING = f'"{STRING_INNER}*"' +INTEGER = r"(0|[1-9][0-9]*)" +NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" +BOOLEAN = r"(true|false)" +NULL = r"null" + +type_to_regex = { + "string": STRING, + "integer": INTEGER, + "number": NUMBER, + "boolean": BOOLEAN, + "null": NULL, +} + + +def build_regex_from_object(object: Union[str, Callable, BaseModel]): + """Turn a JSON schema into a regex that matches any JSON object that follows + this schema. + + JSON Schema is a declarative language that allows to annotate JSON documents + with types and descriptions. These schemas can be generated from any Python + datastructure that has type annotation: namedtuples, dataclasses, Pydantic + models. And by ensuring that the generation respects the schema we ensure + that the output can be parsed into these objects. + This function parses the provided schema and builds a generation schedule which + mixes deterministic generation (fixed strings), and sampling with constraints. + + Parameters + ---------- + schema + A string that represents a JSON Schema. + + Returns + ------- + A generation schedule. A list of strings that represent the JSON + schema's structure and regular expression that define the structure of + the fields. + + References + ---------- + .. [0] JSON Schema. https://json-schema.org/ + + """ + + if isinstance(object, type(BaseModel)): + schema = object.model_json_schema() + elif callable(object): + schema = get_schema_from_signature(object) + else: + schema = json.loads(object) + + Validator.check_schema(schema) + + # Build reference resolver + schema = Resource(contents=schema, specification=DRAFT202012) + uri = schema.id() if schema.id() is not None else "" + registry = Registry().with_resource(uri=uri, resource=schema) + resolver = registry.resolver() + + content = schema.contents + return to_regex(resolver, content) + + +def to_regex(resolver: Resolver, instance: dict): + """Translate a JSON Schema instance into a regex that validates the schema. + + Note + ---- + Many features of JSON schema are missing: + - Handle `additionalProperties` keyword + - Handle types defined as a list + - Handle constraints on numbers + - Handle special patterns: `date`, `uri`, etc. + + This does not support recursive definitions. + + Parameters + ---------- + resolver + An object that resolves references to other instances within a schema + instance + The instance to translate + """ + whitespace = r"[\n ]*" + + if "properties" in instance: + regex = "" + regex += r"\{" + properties = instance["properties"] + required_properties = instance.get("required", []) + is_required = [item in required_properties for item in properties] + # If at least one property is required, we include the one in the lastest position + # without any comma. + # For each property before it (optional or required), we add with a comma after the property. + # For each property after it (optional), we add with a comma before the property. + if any(is_required): + last_required_pos = max([i for i, value in enumerate(is_required) if value]) + for i, (name, value) in enumerate(properties.items()): + subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}' + subregex += to_regex(resolver, value) + if i < last_required_pos: + subregex = f"{subregex}{whitespace}," + elif i > last_required_pos: + subregex = f"{whitespace},{subregex}" + regex += subregex if is_required[i] else f"({subregex})?" + # If no property is required, we have to create a possible pattern for each property in which + # it's the last one necessarilly present. Then, we add the others as optional before and after + # following the same strategy as described above. + # The whole block is made optional to allow the case in which no property is returned. + else: + property_subregexes = [] + for i, (name, value) in enumerate(properties.items()): + subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}' + subregex += to_regex(resolver, value) + property_subregexes.append(subregex) + possible_patterns = [] + for i in range(len(property_subregexes)): + pattern = "" + for subregex in property_subregexes[:i]: + pattern += f"({subregex}{whitespace},)?" + pattern += property_subregexes[i] + for subregex in property_subregexes[i + 1 :]: + pattern += f"({whitespace},{subregex})?" + possible_patterns.append(pattern) + regex += f"({'|'.join(possible_patterns)})?" + + regex += f"{whitespace}" + r"\}" + + return regex + + # To validate against allOf, the given data must be valid against all of the + # given subschemas. + elif "allOf" in instance: + subregexes = [to_regex(resolver, t) for t in instance["allOf"]] + subregexes_str = [f"{subregex}" for subregex in subregexes] + return rf"({''.join(subregexes_str)})" + + # To validate against `anyOf`, the given data must be valid against + # any (one or more) of the given subschemas. + elif "anyOf" in instance: + subregexes = [to_regex(resolver, t) for t in instance["anyOf"]] + return rf"({'|'.join(subregexes)})" + + # To validate against oneOf, the given data must be valid against exactly + # one of the given subschemas. + elif "oneOf" in instance: + subregexes = [to_regex(resolver, t) for t in instance["oneOf"]] + + xor_patterns = [] + # json schema validation ensured there is no overlapping schemas in oneOf + for subregex in subregexes: + other_subregexes = filter(lambda r: r != subregex, subregexes) + other_subregexes_str = "|".join([f"{s}" for s in other_subregexes]) + negative_lookahead = f"(?!.*({other_subregexes_str}))" + xor_patterns.append(f"({subregex}){negative_lookahead}") + + return rf"({'|'.join(xor_patterns)})" + + # The enum keyword is used to restrict a value to a fixed set of values. It + # must be an array with at least one element, where each element is unique. + elif "enum" in instance: + choices = [] + for choice in instance["enum"]: + if type(choice) in [int, float, bool, None]: + choices.append(re.escape(str(choice))) + elif type(choice) == str: + choices.append(f'"{re.escape(choice)}"') + + return f"({'|'.join(choices)})" + + elif "$ref" in instance: + path = f"{instance['$ref']}" + instance = resolver.lookup(path).contents + return to_regex(resolver, instance) + + # The type keyword may either be a string or an array: + # - If it's a string, it is the name of one of the basic types. + # - If it is an array, it must be an array of strings, where each string is + # the name of one of the basic types, and each element is unique. In this + # case, the JSON snippet is valid if it matches any of the given types. + elif "type" in instance: + instance_type = instance["type"] + if instance_type == "string": + if "maxLength" in instance or "minLength" in instance: + max_items = instance.get("maxLength", "") + min_items = instance.get("minLength", "") + try: + if int(max_items) < int(min_items): + raise ValueError( + "maxLength must be greater than or equal to minLength" + ) + except ValueError: + pass + return f'"{STRING_INNER}{{{min_items},{max_items}}}"' + elif "pattern" in instance: + pattern = instance["pattern"] + if pattern[0] == "^" and pattern[-1] == "$": + return rf'(^"{pattern[1:-1]}"$)' + else: + return rf'("{pattern}")' + else: + return type_to_regex["string"] + + elif instance_type == "number": + return type_to_regex["number"] + + elif instance_type == "integer": + return type_to_regex["integer"] + + elif instance_type == "array": + min_items = instance.get("minItems", "0") + max_items = instance.get("maxItems", "") + if min_items == max_items: + num_repeats = "{" + str(int(min_items) - 1) + "}" + else: + num_repeats = "*" + + if "items" in instance: + items_regex = to_regex(resolver, instance["items"]) + return rf"\[({items_regex})(,({items_regex})){num_repeats}\]" + else: + # Here we need to make the choice to exclude generating list of objects + # if the specification of the object is not given, even though a JSON + # object that contains an object here would be valid under the specification. + types = [ + {"type": "boolean"}, + {"type": "null"}, + {"type": "number"}, + {"type": "integer"}, + {"type": "string"}, + ] + regexes = [to_regex(resolver, t) for t in types] + return ( + rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)})){num_repeats}\]" + ) + + elif instance_type == "boolean": + return type_to_regex["boolean"] + + elif instance_type == "null": + return type_to_regex["null"] + + elif isinstance(instance_type, list): + # Here we need to make the choice to exclude generating an object + # if the specification of the object is not give, even though a JSON + # object that contains an object here would be valid under the specification. + regexes = [ + to_regex(resolver, {"type": t}) for t in instance_type if t != "object" + ] + return rf"({'|'.join(regexes)})" + + raise NotImplementedError( + f"""Could not translate the instance {instance} to a + regular expression. Make sure it is valid to the JSON Schema specification. If + it is, please open an issue on the Outlines repository""" + ) + + +def get_schema_from_signature(fn: Callable) -> str: + """Turn a function signature into a JSON schema. + + Every JSON object valid to the output JSON Schema can be passed + to `fn` using the ** unpacking syntax. + + """ + signature = inspect.signature(fn) + arguments = {} + for name, arg in signature.parameters.items(): + if arg.annotation == inspect._empty: + raise ValueError("Each argument must have a type annotation") + else: + arguments[name] = (arg.annotation, ...) + + model = create_model("Arguments", **arguments) + + return model.model_json_schema() diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 6dd3d79a7..9670d7dda 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -60,6 +60,8 @@ class DetokenizerManager: if first_token.startswith("▁"): output_strs[i] = " " + output_strs[i] + output_strs[i] = recv_obj.output_and_fast_forward_strs[i] + output_strs[i] + self.send_to_tokenizer.send_pyobj( BatchStrOut( recv_obj.rids, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 7d2cbf3a2..c4380c49a 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -59,6 +59,7 @@ class GenerateReqInput: @dataclass class TokenizedGenerateReqInput: rid: str + input_text: str input_ids: List[int] pixel_values: List[float] image_hash: int @@ -73,6 +74,7 @@ class TokenizedGenerateReqInput: class BatchTokenIDOut: rids: List[str] output_tokens: List[List[int]] + output_and_fast_forward_strs: List[str] hit_stop_str: List[Optional[str]] skip_special_tokens: List[bool] meta_info: List[Dict] diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index dd98801df..00ada2955 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -23,6 +23,7 @@ class FinishReason(Enum): class Req: def __init__(self, rid): self.rid = rid + self.input_text = None self.input_ids = [] self.output_ids = [] self.pixel_values = None @@ -48,10 +49,44 @@ class Req: # for constrained decoding self.regex_fsm = None self.regex_fsm_state = 0 + self.fast_forward_map = None + self.output_and_fast_forward_str = "" def max_new_tokens(self): return self.sampling_params.max_new_tokens + def tokenize_fast_forward(self, fast_forward_str, next_state): + old_output_str = self.tokenizer.decode(self.output_ids) + if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"): + old_output_str = " " + old_output_str + new_input_string = ( + self.input_text + + self.output_and_fast_forward_str + + old_output_str + + fast_forward_str + ) + new_input_ids = self.tokenizer.encode(new_input_string) + fast_forward_tokens_len = ( + len(new_input_ids) - len(self.input_ids) - len(self.output_ids) + ) + # print("=" * 100) + # print(f"Catch fast forward:\n{fast_forward_str}") + # print(self.tokenizer.convert_ids_to_tokens(self.input_ids)) + # print(self.tokenizer.convert_ids_to_tokens(new_input_ids)) + + self.input_ids = new_input_ids + self.output_ids = [] + self.sampling_params.max_new_tokens = max( + self.sampling_params.max_new_tokens - fast_forward_tokens_len, 0 + ) + self.regex_fsm_state = next_state + self.output_and_fast_forward_str = ( + self.output_and_fast_forward_str + old_output_str + fast_forward_str + ) + + # print(f"Output and fast forward str:\n{self.output_and_fast_forward_str}") + # print("*" * 100) + def check_finished(self): if self.finished: return @@ -263,6 +298,8 @@ class Batch: req.last_node = None req.extend_input_len = 0 req.output_ids = [] + req.regex_fsm_state = 0 + # TODO: apply more fine-grained retraction token_indices = self.req_to_token_pool.req_to_token[ @@ -274,6 +311,46 @@ class Batch: return retracted_reqs + def check_for_fast_forward(self): + fast_forward_reqs = [] + filter_indices = [i for i in range(len(self.reqs))] + + req_pool_indices_cpu = None + + for i, req in enumerate(self.reqs): + if req.fast_forward_map is not None: + res = req.fast_forward_map.fast_forward(req.regex_fsm_state) + if res is not None: + fast_forward_str, next_state = res + if len(fast_forward_str) <= 1: + continue + + # insert the old request into tree_cache + token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1] + if req_pool_indices_cpu is None: + req_pool_indices_cpu = self.req_pool_indices.cpu().tolist() + req_pool_idx = req_pool_indices_cpu[i] + indices = self.req_to_token_pool.req_to_token[ + req_pool_idx, : len(token_ids_in_memory) + ] + prefix_len = self.tree_cache.insert( + token_ids_in_memory, indices.clone() + ) + self.token_to_kv_pool.free(indices[:prefix_len]) + self.req_to_token_pool.free(req_pool_idx) + self.tree_cache.dec_ref_counter(req.last_node) + + # fast forward + req.tokenize_fast_forward(fast_forward_str, next_state) + + fast_forward_reqs.append(req) + filter_indices.remove(i) + + if len(filter_indices) < len(self.reqs): + self.filter_batch(filter_indices) + + return fast_forward_reqs + def prepare_for_decode(self, input_ids=None): if input_ids is None: input_ids = [ diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 4d77eed03..88f7291d1 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -21,6 +21,7 @@ from sglang.srt.managers.router.radix_cache import RadixCache from sglang.srt.managers.router.scheduler import Scheduler from sglang.srt.model_config import ModelConfig from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.constrained.fast_forward import FastForwardCache from sglang.srt.utils import ( get_exception_traceback, get_int_token_logit_bias, @@ -45,6 +46,7 @@ class ModelRpcServer(rpyc.Service): self.tp_rank = tp_rank self.tp_size = server_args.tp_size self.schedule_heuristic = server_args.schedule_heuristic + self.no_regex_fast_forward = server_args.no_regex_fast_forward # Init model and tokenizer self.model_config = ModelConfig( @@ -118,6 +120,7 @@ class ModelRpcServer(rpyc.Service): "trust_remote_code": server_args.trust_remote_code, }, ) + self.fast_forward_cache = FastForwardCache() # Init new token estimation self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0) @@ -201,6 +204,7 @@ class ModelRpcServer(rpyc.Service): recv_req: TokenizedGenerateReqInput, ): req = Req(recv_req.rid) + req.input_text = recv_req.input_text req.input_ids = recv_req.input_ids req.pixel_values = recv_req.pixel_values req.image_size = recv_req.image_size @@ -223,6 +227,10 @@ class ModelRpcServer(rpyc.Service): # Init regex fsm if req.sampling_params.regex is not None: req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex) + if not self.no_regex_fast_forward: + req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map( + req.sampling_params.regex + ) # Truncate long prompts req.input_ids = req.input_ids[: self.model_config.context_len - 1] @@ -334,11 +342,6 @@ class ModelRpcServer(rpyc.Service): self.model_config.vocab_size, self.int_token_logit_bias ) - # Reset regex fsm state before first sampling due to retractions - for req in batch.reqs: - if req.sampling_params.regex is not None: - req.regex_fsm_state = 0 - if batch.extend_num_tokens != 0: # Forward logits, (logprobs, normalized_logprobs) = self.model_runner.forward( @@ -388,6 +391,13 @@ class ModelRpcServer(rpyc.Service): self.min_new_token_ratio, ) + if not self.no_regex_fast_forward: + # check for fast forward + fast_forward_reqs = batch.check_for_fast_forward() + self.forward_queue.extend(fast_forward_reqs) + if batch.is_empty(): + return + # Update batch tensors self.decode_forward_ct += 1 batch.prepare_for_decode() @@ -408,6 +418,7 @@ class ModelRpcServer(rpyc.Service): def handle_finished_requests(self, batch: Batch): output_rids = [] output_tokens = [] + output_and_fast_forward_strs = [] output_hit_stop_str = [] output_skip_special_tokens = [] output_meta_info = [] @@ -425,6 +436,7 @@ class ModelRpcServer(rpyc.Service): ): output_rids.append(req.rid) output_tokens.append(req.output_ids) + output_and_fast_forward_strs.append(req.output_and_fast_forward_str) output_hit_stop_str.append(req.hit_stop_str) output_skip_special_tokens.append( req.sampling_params.skip_special_tokens @@ -445,6 +457,7 @@ class ModelRpcServer(rpyc.Service): BatchTokenIDOut( output_rids, output_tokens, + output_and_fast_forward_strs, output_hit_stop_str, output_skip_special_tokens, output_meta_info, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index acc35c7d9..a6c49c45d 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -157,6 +157,7 @@ class TokenizerManager: ) tokenized_obj = TokenizedGenerateReqInput( rid=rid, + input_text=obj.text, input_ids=input_ids, pixel_values=pixel_values, image_hash=image_hash, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index dfcae6a59..5fcb6f5c2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -23,6 +23,7 @@ class ServerArgs: disable_log_stats: bool = False log_stats_interval: int = 10 log_level: str = "info" + no_regex_fast_forward: bool = False def __post_init__(self): if self.tokenizer_path is None: @@ -150,6 +151,11 @@ class ServerArgs: default=ServerArgs.log_stats_interval, help="Log stats interval in second.", ) + parser.add_argument( + "--no-regex-fast-forward", + action="store_true", + help="Disable regex fast forward", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/test/srt/test_fast_forward.py b/test/srt/test_fast_forward.py new file mode 100644 index 000000000..a94d15ae5 --- /dev/null +++ b/test/srt/test_fast_forward.py @@ -0,0 +1,137 @@ +import argparse +from enum import Enum + +import sglang as sgl +from pydantic import BaseModel, constr +from sglang.srt.constrained.json_schema import build_regex_from_object +from sglang.test.test_utils import ( + add_common_sglang_args_and_parse, + select_sglang_backend, +) + +IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" + +ip_fast_forward = ( + r"The google's DNS sever address is " + + IP_REGEX + + r" and " + + IP_REGEX + + r". " + + r"The google's website domain name is " + + r"www\.(\w)+\.(\w)+" + + r"." +) + + +# fmt: off +@sgl.function +def regex_gen(s): + s += "Q: What is the IP address of the Google DNS servers?\n" + s += "A: " + sgl.gen( + "answer", + max_tokens=128, + temperature=0, + regex=ip_fast_forward, + ) +# fmt: on + +json_fast_forward = ( + r"""The information about Hogwarts is in the following JSON format\.\n""" + + r"""\n\{\n""" + + r""" "name": "[\w\d\s]*",\n""" + + r""" "country": "[\w\d\s]*",\n""" + + r""" "latitude": [-+]?[0-9]*\.?[0-9]+,\n""" + + r""" "population": [-+]?[0-9]+,\n""" + + r""" "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n""" + + r"""\}\n""" +) + +# fmt: off +@sgl.function +def json_gen(s): + s += sgl.gen( + "json", + max_tokens=128, + temperature=0, + regex=json_fast_forward, + ) +# fmt: on + + +class Weapon(str, Enum): + sword = "sword" + axe = "axe" + mace = "mace" + spear = "spear" + bow = "bow" + crossbow = "crossbow" + + +class Armor(str, Enum): + leather = "leather" + chainmail = "chainmail" + plate = "plate" + + +class Character(BaseModel): + name: constr(max_length=10) + age: int + armor: Armor + weapon: Weapon + strength: int + + +@sgl.function +def character_gen(s): + s += "Give me a character description who is a wizard.\n" + s += sgl.gen( + "character", + max_tokens=128, + temperature=0, + regex=build_regex_from_object(Character), + ) + + +def main(args): + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + state = regex_gen.run(temperature=0) + + print("=" * 20, "IP TEST", "=" * 20) + print(state.text()) + + state = json_gen.run(temperature=0) + + print("=" * 20, "JSON TEST", "=" * 20) + print(state.text()) + + state = character_gen.run(temperature=0) + + print("=" * 20, "CHARACTER TEST", "=" * 20) + print(state.text()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + args = add_common_sglang_args_and_parse(parser) + main(args) + +# ==================== IP TEST ==================== +# Q: What is the IP address of the Google DNS servers? +# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com. +# ==================== JSON TEST ==================== +# The information about Hogwarts is in the following JSON format. + +# { +# "name": "Hogwarts School of Witchcraft and Wizardry", +# "country": "Scotland", +# "latitude": 55.566667, +# "population": 1000, +# "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"], +# } + +# ==================== CHARACTER TEST ==================== +# Give me a character description who is a wizard. +# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 } diff --git a/test/srt/test_robust.py b/test/srt/test_robust.py index 5b479318f..9b4ceaf5f 100644 --- a/test/srt/test_robust.py +++ b/test/srt/test_robust.py @@ -2,14 +2,13 @@ import argparse import random import string +import sglang as sgl from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, ) from vllm.transformers_utils.tokenizer import get_tokenizer -import sglang as sgl - TOKENIZER = None RANDOM_PREFILL_LEN = None RANDOM_DECODE_LEN = None