fast regex decode
Auto-detect constant-string paths in the regex FSM and extend (fast-forward) through them instead of decoding them token by token.
46  benchmark/json_fast_forward/README.md  Normal file
@@ -0,0 +1,46 @@
## Run benchmark

### Dependencies

```
llama_cpp_python 0.2.32
guidance 0.1.10
vllm 0.2.7
outlines 0.0.24
```

### Benchmark sglang

Run Llama-7B

```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Benchmark

```
python3 bench_sglang.py
```

### Benchmark vllm

Run Llama-7B

```
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

Benchmark

```
python3 bench_other.py --backend vllm
```

### Benchmark guidance (seems not supported)

Run Llama-7B and benchmark

```
python3 bench_other.py --backend guidance --parallel 1
```
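As a quick sanity check that regex-constrained generation works against the launched sglang server, a minimal request can be issued from the sglang frontend. This is only a sketch; it assumes the server started with the command above is listening on localhost:30000.

```
import sglang as sgl

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def ip_gen(s):
    s += "The IP address of the Google DNS server is "
    s += sgl.gen(
        "ip",
        max_tokens=32,
        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    )

print(ip_gen.run(temperature=0)["ip"])
```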
135  benchmark/json_fast_forward/bench_other.py  Normal file
@@ -0,0 +1,135 @@
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import guidance
from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_outlines,
)
from sglang.utils import dump_state_text
from tqdm import tqdm

# there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead
# regex_string = build_regex_from_object(HarryPoterRole)
character_regex = (
    r"""\{\n"""
    + r""" "name": "[\w\d\s]{1,16}",\n"""
    + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
    + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
    + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
    + r""" "wand": \{\n"""
    + r""" "wood": "[\w\d\s]{1,16}",\n"""
    + r""" "core": "[\w\d\s]{1,16}",\n"""
    + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
    + r""" \},\n"""
    + r""" "alive": "(Alive|Deceased)",\n"""
    + r""" "patronus": "[\w\d\s]{1,16}",\n"""
    + r""" "bogart": "[\w\d\s]{1,16}"\n"""
    + r"""\}"""
)


# fmt: off
def character_gen(name, generate):
    s = name + " is a character in Harry Potter. Please fill in the following information about him/her.\n"
    s += generate(s, max_tokens=256, regex=character_regex)
    return s
# fmt: on


@guidance
def character_maker(lm, name):
    regex_str_no_quote = r"[\w\d\s]+"
    regex_float = r"[0-9]+\.[0-9]+"
    lm += f"""\
{name} is a character in Harry Potter. Please fill in the following information about him/her.
{{
 "name": "{guidance.gen("name", max_tokens=16, regex=regex_str_no_quote)}",
 "house": "{guidance.select(options=['Gryffindor', 'Slytherin', 'Ravenclaw', 'Hufflepuff'], name='house')}",
 "blood status": "{guidance.select(options=['Pure-blood', 'Half-blood', 'Muggle-born'], name='blood status')}",
 "occupation": "{guidance.select(options=['student', 'teacher', 'auror', 'ministry of magic', 'death eater', 'order of the phoenix'], name='occupation')}",
 "wand": {{
 "wood": "{guidance.gen("wood", max_tokens=16, regex=regex_str_no_quote)}",
 "core": "{guidance.gen('core', max_tokens=16, regex=regex_str_no_quote)}",
 "length": {guidance.gen('length', max_tokens=10, regex=regex_float)}
 }},
 "alive": "{guidance.select(options=['Alive', 'Deceased'], name='alive')}",
 "patronus": "{guidance.gen('patronus', max_tokens=16, regex=regex_str_no_quote)}",
 "bogart": "{guidance.gen('bogart', max_tokens=16, regex=regex_str_no_quote)}"
}}
"""

    return lm


def main(args):
    arguments = []
    with open(args.data_path, "r") as f:
        for line in f:
            arguments.append({"name": line.strip()})
    arguments = arguments[: args.num_jsons]

    states = [None] * len(arguments)

    # Select backend
    if args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_outlines, url=url, temperature=0)

        def func(i):
            states[i] = character_gen(**arguments[i], generate=generate)

        get_one_answer = func
    elif args.backend == "guidance":
        model = guidance.models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def func(i):
            lm = model + character_maker(**arguments[i])
            states[i] = lm

        get_one_answer = func
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(arguments))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            rets = executor.map(get_one_answer, list(range(len(arguments))))
            for _ in rets:
                pass

    latency = time.time() - tic

    # Report latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_fast_forward",
            "backend": args.backend,
            "latency": round(latency, 3),
            "num_jsons": args.num_jsons,
            "parallel": args.parallel,
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="dataset.txt")
    parser.add_argument("--num-jsons", type=int, default=50)
    args = add_common_other_args_and_parse(parser)
    main(args)
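Note that `character_gen` only assumes the `generate` callable accepts `(prompt, max_tokens=..., regex=...)` and returns the generated text. A hypothetical stub like the one below (not part of the benchmark) is enough to dry-run the harness without any server:

```
# Hypothetical stub for dry-running character_gen without a server; it ignores
# the regex and returns a fixed completion. Only the call signature matters here.
def fake_generate(prompt, max_tokens=256, regex=None):
    return '{\n "name": "Harry Potter",\n "house": "Gryffindor"\n}'

print(character_gen("Harry Potter", fake_generate))
```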
92  benchmark/json_fast_forward/bench_sglang.py  Normal file
@@ -0,0 +1,92 @@
import argparse
import json
import time

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text

# there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead
# regex_string = build_regex_from_object(HarryPoterRole)
character_regex = (
    r"""\{\n"""
    + r""" "name": "[\w\d\s]{1,16}",\n"""
    + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
    + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
    + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
    + r""" "wand": \{\n"""
    + r""" "wood": "[\w\d\s]{1,16}",\n"""
    + r""" "core": "[\w\d\s]{1,16}",\n"""
    + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
    + r""" \},\n"""
    + r""" "alive": "(Alive|Deceased)",\n"""
    + r""" "patronus": "[\w\d\s]{1,16}",\n"""
    + r""" "bogart": "[\w\d\s]{1,16}"\n"""
    + r"""\}"""
)


# fmt: off
@sgl.function
def character_gen(s, name):
    s += name + " is a character in Harry Potter. Please fill in the following information about him/her.\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
# fmt: on


def bench_character(args):
    arguments = []
    with open(args.data_path, "r") as f:
        for line in f:
            arguments.append({"name": line.strip()})
    arguments = arguments[: args.num_jsons]

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Run requests
    tic = time.time()
    states = character_gen.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=(args.parallel == 1),
    )
    latency = time.time() - tic

    return states, latency


def main(args):
    states, latency = bench_character(args)

    # Report latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)
    with open(f"{args.backend}.json", "w") as fout:
        for state in states:
            fout.write(state["json_output"] + "\n")

    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_fast_forward",
            "backend": args.backend,
            "latency": round(latency, 3),
            "num_jsons": args.num_jsons,
            "parallel": args.parallel,
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="dataset.txt")
    parser.add_argument("--num-jsons", type=int, default=50)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
50  benchmark/json_fast_forward/dataset.txt  Normal file
@@ -0,0 +1,50 @@
Harry Potter
Hermione Granger
Ron Weasley
Albus Dumbledore
Severus Snape
Rubeus Hagrid
Draco Malfoy
Ginny Weasley
Fred Weasley
George Weasley
Percy Weasley
Sirius Black
Remus Lupin
Neville Longbottom
Luna Lovegood
Cedric Diggory
Cho Chang
Lord Voldemort
Minerva McGonagall
Filius Flitwick
Dolores Umbridge
Bellatrix Lestrange
Lucius Malfoy
Molly Weasley
Arthur Weasley
Nymphadora Tonks
Dobby
Moaning Myrtle
Peter Pettigrew
Alastor 'Mad-Eye' Moody
Horace Slughorn
Vernon Dursley
Petunia Dursley
Dudley Dursley
Argus Filch
Sybill Trelawney
Gilderoy Lockhart
Fleur Delacour
Viktor Krum
Bill Weasley
Oliver Wood
Cornelius Fudge
Barty Crouch Sr.
Barty Crouch Jr.
Kingsley Shacklebolt
Quirinus Quirrell
Nearly Headless Nick
Aunt Marge
Griphook
Ludo Bagman
@@ -91,12 +91,32 @@ def run_program_batch(
     if num_threads == 1:
         rets = []
-        for arguments in batch_arguments:
-            rets.append(
-                run_program(
-                    program, backend, (), arguments, default_sampling_para, False, True
-                )
-            )
+        if progress_bar:
+            for arguments in tqdm.tqdm(batch_arguments):
+                rets.append(
+                    run_program(
+                        program,
+                        backend,
+                        (),
+                        arguments,
+                        default_sampling_para,
+                        False,
+                        True,
+                    )
+                )
+        else:
+            for arguments in batch_arguments:
+                rets.append(
+                    run_program(
+                        program,
+                        backend,
+                        (),
+                        arguments,
+                        default_sampling_para,
+                        False,
+                        True,
+                    )
+                )
     else:
         if progress_bar:
             pbar = tqdm.tqdm(total=len(batch_arguments))
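This hunk only changes the sequential (`num_threads == 1`) path of `run_program_batch`: when `progress_bar=True`, the loop is wrapped in `tqdm`. It is exercised by `bench_sglang.py` above via:

```
states = character_gen.run_batch(
    arguments,
    temperature=0,
    num_threads=args.parallel,
    progress_bar=(args.parallel == 1),
)
```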
78  python/sglang/srt/constrained/fast_forward.py  Normal file
@@ -0,0 +1,78 @@
import interegular
from sglang.srt.constrained.disk_cache import disk_cache
from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


class FastForwardMap:
    """Map each FSM state that has exactly one outgoing single-symbol transition
    to that symbol and its next state, so constant regex segments can be appended
    directly instead of being decoded."""

    def __init__(self, regex_string):
        @disk_cache()
        def _init_state_to_fast_forward(regex_string):
            regex_pattern = interegular.parse_pattern(regex_string)
            regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())

            fsm_info: FSMInfo = regex_fsm.fsm_info

            symbol_to_id = fsm_info.alphabet_symbol_mapping
            id_to_symbol = {}
            for symbol, id_ in symbol_to_id.items():
                id_to_symbol.setdefault(id_, []).append(symbol)

            transitions = fsm_info.transitions
            dirty_states = set()
            state_to_fast_forward = {}

            for (state, id_), next_state in transitions.items():
                if state in dirty_states:
                    continue
                if state in state_to_fast_forward:
                    # More than one outgoing transition: not a constant path.
                    dirty_states.add(state)
                    del state_to_fast_forward[state]
                    continue
                if len(id_to_symbol[id_]) > 1:
                    # The transition accepts more than one symbol: not a constant path.
                    dirty_states.add(state)
                    continue

                state_to_fast_forward[state] = (id_to_symbol[id_][0], next_state)

            return state_to_fast_forward

        self.state_to_fast_forward = _init_state_to_fast_forward(regex_string)

    def valid_states(self):
        return self.state_to_fast_forward.keys()

    def fast_forward(self, state):
        if state not in self.state_to_fast_forward:
            return None

        # Chain single-symbol transitions into one constant string.
        fast_forward_str = ""
        next_state = None
        while state in self.state_to_fast_forward:
            symbol, next_state = self.state_to_fast_forward[state]
            fast_forward_str += symbol
            state = next_state
        return fast_forward_str, next_state


class FastForwardCache:
    def __init__(self):
        self.cache = {}

    def init_fast_forward_map(self, regex_string):
        if regex_string not in self.cache:
            fast_forward_map = FastForwardMap(regex_string)
            self.cache[regex_string] = fast_forward_map
        return self.cache[regex_string]


def test_main():
    regex_string = r"The google's DNS sever address is " + IP_REGEX
    fast_forward_map = FastForwardMap(regex_string)
    for state in fast_forward_map.valid_states():
        print(state, f'"{fast_forward_map.fast_forward(state)}"')


if __name__ == "__main__":
    test_main()
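The idea behind `FastForwardMap`: any FSM state whose single outgoing transition accepts exactly one symbol lies on a constant segment of the regex, and chaining such states yields a string that can be appended without any decoding steps. A self-contained toy illustration of the same chaining logic (a hand-built table, not the interegular FSM used above):

```
# Toy illustration (not part of the module): each entry mirrors
# FastForwardMap.state_to_fast_forward -- a state with exactly one possible
# next character. State 6 (say, the start of a [a-z]+ field) has many
# options, so it is absent from the map.
state_to_fast_forward = {
    0: ('"', 1), 1: ("n", 2), 2: ("a", 3), 3: ("m", 4),
    4: ("e", 5), 5: ('"', 6),
}

def fast_forward(state):
    # Same chaining loop as FastForwardMap.fast_forward
    if state not in state_to_fast_forward:
        return None
    s, next_state = "", None
    while state in state_to_fast_forward:
        ch, next_state = state_to_fast_forward[state]
        s += ch
        state = next_state
    return s, next_state

print(fast_forward(0))  # ('"name"', 6): the whole constant prefix in one step
```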
@@ -1,6 +1,8 @@
 from sglang.srt.constrained.fsm import RegexFSM
 from sglang.srt.constrained.tokenizer import TransformerTokenizer
 
+_enable_memory_cache = True
+
 
 class FSMCache:
     def __init__(self, tokenizer_path, tokenizer_args_dict):
@@ -10,8 +12,10 @@ class FSMCache:
         )
 
     def init_fsm(self, regex):
-        if regex not in self.cache:
-            fsm = RegexFSM(regex, self.outlines_tokenizer)
-            self.cache[regex] = fsm
-        return self.cache[regex]
+        if _enable_memory_cache:
+            if regex not in self.cache:
+                fsm = RegexFSM(regex, self.outlines_tokenizer)
+                self.cache[regex] = fsm
+            return self.cache[regex]
+
+        return RegexFSM(regex, self.outlines_tokenizer)
290  python/sglang/srt/constrained/json_schema.py  Normal file
@@ -0,0 +1,290 @@
# Adapted from:
# https://github.com/outlines-dev/outlines/blob/8a0bafc8d82937babc5d586dd4f72ae844407e0e/outlines/fsm/json_schema.py
import inspect
import json
import re
from typing import Callable, Union

from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
from referencing import Registry, Resource
from referencing._core import Resolver
from referencing.jsonschema import DRAFT202012

STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)'
STRING = f'"{STRING_INNER}*"'
INTEGER = r"(0|[1-9][0-9]*)"
NUMBER = rf"(-)?({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
BOOLEAN = r"(true|false)"
NULL = r"null"

type_to_regex = {
    "string": STRING,
    "integer": INTEGER,
    "number": NUMBER,
    "boolean": BOOLEAN,
    "null": NULL,
}


def build_regex_from_object(object: Union[str, Callable, BaseModel]):
    """Turn a JSON schema into a regex that matches any JSON object that follows
    this schema.

    JSON Schema is a declarative language that allows annotating JSON documents
    with types and descriptions. These schemas can be generated from any Python
    datastructure that has type annotations: namedtuples, dataclasses, Pydantic
    models. By ensuring that the generation respects the schema, we ensure
    that the output can be parsed into these objects.
    This function parses the provided schema and builds a generation schedule that
    mixes deterministic generation (fixed strings) and sampling with constraints.

    Parameters
    ----------
    schema
        A string that represents a JSON Schema.

    Returns
    -------
    A generation schedule. A list of strings that represent the JSON
    schema's structure and regular expressions that define the structure of
    the fields.

    References
    ----------
    .. [0] JSON Schema. https://json-schema.org/

    """
    if isinstance(object, type(BaseModel)):
        schema = object.model_json_schema()
    elif callable(object):
        schema = get_schema_from_signature(object)
    else:
        schema = json.loads(object)

    Validator.check_schema(schema)

    # Build reference resolver
    schema = Resource(contents=schema, specification=DRAFT202012)
    uri = schema.id() if schema.id() is not None else ""
    registry = Registry().with_resource(uri=uri, resource=schema)
    resolver = registry.resolver()

    content = schema.contents
    return to_regex(resolver, content)


def to_regex(resolver: Resolver, instance: dict):
    """Translate a JSON Schema instance into a regex that validates the schema.

    Note
    ----
    Many features of JSON schema are missing:
    - Handle `additionalProperties` keyword
    - Handle types defined as a list
    - Handle constraints on numbers
    - Handle special patterns: `date`, `uri`, etc.

    This does not support recursive definitions.

    Parameters
    ----------
    resolver
        An object that resolves references to other instances within a schema
    instance
        The instance to translate
    """
    whitespace = r"[\n ]*"

    if "properties" in instance:
        regex = ""
        regex += r"\{"
        properties = instance["properties"]
        required_properties = instance.get("required", [])
        is_required = [item in required_properties for item in properties]
        # If at least one property is required, we include the one in the latest position
        # without any comma.
        # For each property before it (optional or required), we add with a comma after the property.
        # For each property after it (optional), we add with a comma before the property.
        if any(is_required):
            last_required_pos = max([i for i, value in enumerate(is_required) if value])
            for i, (name, value) in enumerate(properties.items()):
                subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
                subregex += to_regex(resolver, value)
                if i < last_required_pos:
                    subregex = f"{subregex}{whitespace},"
                elif i > last_required_pos:
                    subregex = f"{whitespace},{subregex}"
                regex += subregex if is_required[i] else f"({subregex})?"
        # If no property is required, we have to create a possible pattern for each property in which
        # it's the last one necessarily present. Then, we add the others as optional before and after,
        # following the same strategy as described above.
        # The whole block is made optional to allow the case in which no property is returned.
        else:
            property_subregexes = []
            for i, (name, value) in enumerate(properties.items()):
                subregex = f'{whitespace}"{name}"{whitespace}:{whitespace}'
                subregex += to_regex(resolver, value)
                property_subregexes.append(subregex)
            possible_patterns = []
            for i in range(len(property_subregexes)):
                pattern = ""
                for subregex in property_subregexes[:i]:
                    pattern += f"({subregex}{whitespace},)?"
                pattern += property_subregexes[i]
                for subregex in property_subregexes[i + 1 :]:
                    pattern += f"({whitespace},{subregex})?"
                possible_patterns.append(pattern)
            regex += f"({'|'.join(possible_patterns)})?"

        regex += f"{whitespace}" + r"\}"

        return regex

    # To validate against allOf, the given data must be valid against all of the
    # given subschemas.
    elif "allOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["allOf"]]
        subregexes_str = [f"{subregex}" for subregex in subregexes]
        return rf"({''.join(subregexes_str)})"

    # To validate against `anyOf`, the given data must be valid against
    # any (one or more) of the given subschemas.
    elif "anyOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["anyOf"]]
        return rf"({'|'.join(subregexes)})"

    # To validate against oneOf, the given data must be valid against exactly
    # one of the given subschemas.
    elif "oneOf" in instance:
        subregexes = [to_regex(resolver, t) for t in instance["oneOf"]]

        xor_patterns = []
        # json schema validation ensured there is no overlapping schemas in oneOf
        for subregex in subregexes:
            other_subregexes = filter(lambda r: r != subregex, subregexes)
            other_subregexes_str = "|".join([f"{s}" for s in other_subregexes])
            negative_lookahead = f"(?!.*({other_subregexes_str}))"
            xor_patterns.append(f"({subregex}){negative_lookahead}")

        return rf"({'|'.join(xor_patterns)})"

    # The enum keyword is used to restrict a value to a fixed set of values. It
    # must be an array with at least one element, where each element is unique.
    elif "enum" in instance:
        choices = []
        for choice in instance["enum"]:
            if type(choice) in [int, float, bool, None]:
                choices.append(re.escape(str(choice)))
            elif type(choice) == str:
                choices.append(f'"{re.escape(choice)}"')

        return f"({'|'.join(choices)})"

    elif "$ref" in instance:
        path = f"{instance['$ref']}"
        instance = resolver.lookup(path).contents
        return to_regex(resolver, instance)

    # The type keyword may either be a string or an array:
    # - If it's a string, it is the name of one of the basic types.
    # - If it is an array, it must be an array of strings, where each string is
    #   the name of one of the basic types, and each element is unique. In this
    #   case, the JSON snippet is valid if it matches any of the given types.
    elif "type" in instance:
        instance_type = instance["type"]
        if instance_type == "string":
            if "maxLength" in instance or "minLength" in instance:
                max_items = instance.get("maxLength", "")
                min_items = instance.get("minLength", "")
                try:
                    if int(max_items) < int(min_items):
                        raise ValueError(
                            "maxLength must be greater than or equal to minLength"
                        )
                except ValueError:
                    pass
                return f'"{STRING_INNER}{{{min_items},{max_items}}}"'
            elif "pattern" in instance:
                pattern = instance["pattern"]
                if pattern[0] == "^" and pattern[-1] == "$":
                    return rf'(^"{pattern[1:-1]}"$)'
                else:
                    return rf'("{pattern}")'
            else:
                return type_to_regex["string"]

        elif instance_type == "number":
            return type_to_regex["number"]

        elif instance_type == "integer":
            return type_to_regex["integer"]

        elif instance_type == "array":
            min_items = instance.get("minItems", "0")
            max_items = instance.get("maxItems", "")
            if min_items == max_items:
                num_repeats = "{" + str(int(min_items) - 1) + "}"
            else:
                num_repeats = "*"

            if "items" in instance:
                items_regex = to_regex(resolver, instance["items"])
                return rf"\[({items_regex})(,({items_regex})){num_repeats}\]"
            else:
                # Here we need to make the choice to exclude generating list of objects
                # if the specification of the object is not given, even though a JSON
                # object that contains an object here would be valid under the specification.
                types = [
                    {"type": "boolean"},
                    {"type": "null"},
                    {"type": "number"},
                    {"type": "integer"},
                    {"type": "string"},
                ]
                regexes = [to_regex(resolver, t) for t in types]
                return (
                    rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)})){num_repeats}\]"
                )

        elif instance_type == "boolean":
            return type_to_regex["boolean"]

        elif instance_type == "null":
            return type_to_regex["null"]

        elif isinstance(instance_type, list):
            # Here we need to make the choice to exclude generating an object
            # if the specification of the object is not given, even though a JSON
            # object that contains an object here would be valid under the specification.
            regexes = [
                to_regex(resolver, {"type": t}) for t in instance_type if t != "object"
            ]
            return rf"({'|'.join(regexes)})"

    raise NotImplementedError(
        f"""Could not translate the instance {instance} to a
    regular expression. Make sure it is valid to the JSON Schema specification. If
    it is, please open an issue on the Outlines repository"""
    )


def get_schema_from_signature(fn: Callable) -> str:
    """Turn a function signature into a JSON schema.

    Every JSON object valid to the output JSON Schema can be passed
    to `fn` using the ** unpacking syntax.

    """
    signature = inspect.signature(fn)
    arguments = {}
    for name, arg in signature.parameters.items():
        if arg.annotation == inspect._empty:
            raise ValueError("Each argument must have a type annotation")
        else:
            arguments[name] = (arg.annotation, ...)

    model = create_model("Arguments", **arguments)

    return model.model_json_schema()
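A small usage sketch of `build_regex_from_object` with a Pydantic model (the exact regex depends on the translation rules above; a conforming JSON string should fully match it):

```
import re
from pydantic import BaseModel

from sglang.srt.constrained.json_schema import build_regex_from_object


class User(BaseModel):
    name: str
    age: int


pattern = build_regex_from_object(User)
print(pattern)  # regex that enforces {"name": "...", "age": ...}

# A conforming JSON string is expected to fully match the generated pattern.
print(bool(re.fullmatch(pattern, '{"name": "Ada", "age": 36}')))
```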
@@ -60,6 +60,8 @@ class DetokenizerManager:
             if first_token.startswith("▁"):
                 output_strs[i] = " " + output_strs[i]
 
+            output_strs[i] = recv_obj.output_and_fast_forward_strs[i] + output_strs[i]
+
         self.send_to_tokenizer.send_pyobj(
             BatchStrOut(
                 recv_obj.rids,
@@ -59,6 +59,7 @@ class GenerateReqInput:
 @dataclass
 class TokenizedGenerateReqInput:
     rid: str
+    input_text: str
     input_ids: List[int]
     pixel_values: List[float]
     image_hash: int
@@ -73,6 +74,7 @@ class TokenizedGenerateReqInput:
 class BatchTokenIDOut:
     rids: List[str]
     output_tokens: List[List[int]]
+    output_and_fast_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
     meta_info: List[Dict]
@@ -23,6 +23,7 @@ class FinishReason(Enum):
 class Req:
     def __init__(self, rid):
         self.rid = rid
+        self.input_text = None
         self.input_ids = []
         self.output_ids = []
         self.pixel_values = None
@@ -48,10 +49,44 @@ class Req:
         # for constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
+        self.fast_forward_map = None
+        self.output_and_fast_forward_str = ""
 
     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
 
+    def tokenize_fast_forward(self, fast_forward_str, next_state):
+        old_output_str = self.tokenizer.decode(self.output_ids)
+        if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"):
+            old_output_str = " " + old_output_str
+        new_input_string = (
+            self.input_text
+            + self.output_and_fast_forward_str
+            + old_output_str
+            + fast_forward_str
+        )
+        new_input_ids = self.tokenizer.encode(new_input_string)
+        fast_forward_tokens_len = (
+            len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
+        )
+        # print("=" * 100)
+        # print(f"Catch fast forward:\n{fast_forward_str}")
+        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
+        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+
+        self.input_ids = new_input_ids
+        self.output_ids = []
+        self.sampling_params.max_new_tokens = max(
+            self.sampling_params.max_new_tokens - fast_forward_tokens_len, 0
+        )
+        self.regex_fsm_state = next_state
+        self.output_and_fast_forward_str = (
+            self.output_and_fast_forward_str + old_output_str + fast_forward_str
+        )
+
+        # print(f"Output and fast forward str:\n{self.output_and_fast_forward_str}")
+        # print("*" * 100)
+
     def check_finished(self):
         if self.finished:
             return
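`tokenize_fast_forward` re-encodes the whole string (prompt, previously decoded output, and the fast-forward string) instead of appending the tokens of the fast-forward string, because token boundaries can merge across the junction. A hedged illustration with a HuggingFace tokenizer (the model name is only an example):

```
# Illustration (sketch): appending token ids is not equivalent to re-tokenizing
# the concatenated text, since merges can cross the boundary.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

prefix = '{\n "name": "Harry'
jump = ' Potter",\n "house": "'

a = tok.encode(prefix, add_special_tokens=False) + tok.encode(jump, add_special_tokens=False)
b = tok.encode(prefix + jump, add_special_tokens=False)
print(a == b)  # often False: the ids around the junction differ
```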
@@ -263,6 +298,8 @@ class Batch:
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
+            req.regex_fsm_state = 0
+
             # TODO: apply more fine-grained retraction
 
             token_indices = self.req_to_token_pool.req_to_token[
@@ -274,6 +311,46 @@ class Batch:
 
         return retracted_reqs
 
+    def check_for_fast_forward(self):
+        fast_forward_reqs = []
+        filter_indices = [i for i in range(len(self.reqs))]
+
+        req_pool_indices_cpu = None
+
+        for i, req in enumerate(self.reqs):
+            if req.fast_forward_map is not None:
+                res = req.fast_forward_map.fast_forward(req.regex_fsm_state)
+                if res is not None:
+                    fast_forward_str, next_state = res
+                    if len(fast_forward_str) <= 1:
+                        continue
+
+                    # insert the old request into tree_cache
+                    token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
+                    req_pool_idx = req_pool_indices_cpu[i]
+                    indices = self.req_to_token_pool.req_to_token[
+                        req_pool_idx, : len(token_ids_in_memory)
+                    ]
+                    prefix_len = self.tree_cache.insert(
+                        token_ids_in_memory, indices.clone()
+                    )
+                    self.token_to_kv_pool.free(indices[:prefix_len])
+                    self.req_to_token_pool.free(req_pool_idx)
+                    self.tree_cache.dec_ref_counter(req.last_node)
+
+                    # fast forward
+                    req.tokenize_fast_forward(fast_forward_str, next_state)
+
+                    fast_forward_reqs.append(req)
+                    filter_indices.remove(i)
+
+        if len(filter_indices) < len(self.reqs):
+            self.filter_batch(filter_indices)
+
+        return fast_forward_reqs
+
     def prepare_for_decode(self, input_ids=None):
         if input_ids is None:
             input_ids = [
@@ -21,6 +21,7 @@ from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.constrained.fast_forward import FastForwardCache
 from sglang.srt.utils import (
     get_exception_traceback,
     get_int_token_logit_bias,
@@ -45,6 +46,7 @@ class ModelRpcServer(rpyc.Service):
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
+        self.no_regex_fast_forward = server_args.no_regex_fast_forward
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -118,6 +120,7 @@ class ModelRpcServer(rpyc.Service):
                 "trust_remote_code": server_args.trust_remote_code,
             },
         )
+        self.fast_forward_cache = FastForwardCache()
 
         # Init new token estimation
         self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
@@ -201,6 +204,7 @@ class ModelRpcServer(rpyc.Service):
         recv_req: TokenizedGenerateReqInput,
     ):
         req = Req(recv_req.rid)
+        req.input_text = recv_req.input_text
         req.input_ids = recv_req.input_ids
         req.pixel_values = recv_req.pixel_values
         req.image_size = recv_req.image_size
@@ -223,6 +227,10 @@ class ModelRpcServer(rpyc.Service):
         # Init regex fsm
         if req.sampling_params.regex is not None:
             req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
+            if not self.no_regex_fast_forward:
+                req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map(
+                    req.sampling_params.regex
+                )
 
         # Truncate long prompts
         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
@@ -334,11 +342,6 @@ class ModelRpcServer(rpyc.Service):
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
-        # Reset regex fsm state before first sampling due to retractions
-        for req in batch.reqs:
-            if req.sampling_params.regex is not None:
-                req.regex_fsm_state = 0
-
         if batch.extend_num_tokens != 0:
             # Forward
             logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
@@ -388,6 +391,13 @@ class ModelRpcServer(rpyc.Service):
                 self.min_new_token_ratio,
             )
 
+        if not self.no_regex_fast_forward:
+            # check for fast forward
+            fast_forward_reqs = batch.check_for_fast_forward()
+            self.forward_queue.extend(fast_forward_reqs)
+            if batch.is_empty():
+                return
+
         # Update batch tensors
         self.decode_forward_ct += 1
         batch.prepare_for_decode()
@@ -408,6 +418,7 @@ class ModelRpcServer(rpyc.Service):
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
         output_tokens = []
+        output_and_fast_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
         output_meta_info = []
@@ -425,6 +436,7 @@ class ModelRpcServer(rpyc.Service):
             ):
                 output_rids.append(req.rid)
                 output_tokens.append(req.output_ids)
+                output_and_fast_forward_strs.append(req.output_and_fast_forward_str)
                 output_hit_stop_str.append(req.hit_stop_str)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
@@ -445,6 +457,7 @@ class ModelRpcServer(rpyc.Service):
                 BatchTokenIDOut(
                     output_rids,
                     output_tokens,
+                    output_and_fast_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
                     output_meta_info,
@@ -157,6 +157,7 @@ class TokenizerManager:
             )
             tokenized_obj = TokenizedGenerateReqInput(
                 rid=rid,
+                input_text=obj.text,
                 input_ids=input_ids,
                 pixel_values=pixel_values,
                 image_hash=image_hash,
@@ -23,6 +23,7 @@ class ServerArgs:
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
+    no_regex_fast_forward: bool = False
 
     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -150,6 +151,11 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
+        parser.add_argument(
+            "--no-regex-fast-forward",
+            action="store_true",
+            help="Disable regex fast forward",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
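Fast forward is on by default; the new flag disables it at launch, e.g.:

```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --no-regex-fast-forward
```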
137  test/srt/test_fast_forward.py  Normal file
@@ -0,0 +1,137 @@
import argparse
from enum import Enum

import sglang as sgl
from pydantic import BaseModel, constr
from sglang.srt.constrained.json_schema import build_regex_from_object
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)

IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

ip_fast_forward = (
    r"The google's DNS sever address is "
    + IP_REGEX
    + r" and "
    + IP_REGEX
    + r". "
    + r"The google's website domain name is "
    + r"www\.(\w)+\.(\w)+"
    + r"."
)


# fmt: off
@sgl.function
def regex_gen(s):
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        max_tokens=128,
        temperature=0,
        regex=ip_fast_forward,
    )
# fmt: on

json_fast_forward = (
    r"""The information about Hogwarts is in the following JSON format\.\n"""
    + r"""\n\{\n"""
    + r""" "name": "[\w\d\s]*",\n"""
    + r""" "country": "[\w\d\s]*",\n"""
    + r""" "latitude": [-+]?[0-9]*\.?[0-9]+,\n"""
    + r""" "population": [-+]?[0-9]+,\n"""
    + r""" "top 3 landmarks": \["[\w\d\s]*", "[\w\d\s]*", "[\w\d\s]*"\],\n"""
    + r"""\}\n"""
)

# fmt: off
@sgl.function
def json_gen(s):
    s += sgl.gen(
        "json",
        max_tokens=128,
        temperature=0,
        regex=json_fast_forward,
    )
# fmt: on


class Weapon(str, Enum):
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"


class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"


class Character(BaseModel):
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int


@sgl.function
def character_gen(s):
    s += "Give me a character description who is a wizard.\n"
    s += sgl.gen(
        "character",
        max_tokens=128,
        temperature=0,
        regex=build_regex_from_object(Character),
    )


def main(args):
    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    state = regex_gen.run(temperature=0)

    print("=" * 20, "IP TEST", "=" * 20)
    print(state.text())

    state = json_gen.run(temperature=0)

    print("=" * 20, "JSON TEST", "=" * 20)
    print(state.text())

    state = character_gen.run(temperature=0)

    print("=" * 20, "CHARACTER TEST", "=" * 20)
    print(state.text())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = add_common_sglang_args_and_parse(parser)
    main(args)


# ==================== IP TEST ====================
# Q: What is the IP address of the Google DNS servers?
# A: The google's DNS sever address is 8.8.8.8 and 8.8.4.4. The google's website domain name is www.google.com.
# ==================== JSON TEST ====================
# The information about Hogwarts is in the following JSON format.
#
# {
#  "name": "Hogwarts School of Witchcraft and Wizardry",
#  "country": "Scotland",
#  "latitude": 55.566667,
#  "population": 1000,
#  "top 3 landmarks": ["Hogwarts Castle", "The Great Hall", "The Forbidden Forest"],
# }
#
# ==================== CHARACTER TEST ====================
# Give me a character description who is a wizard.
# { "name" : "Merlin", "age" : 500, "armor" : "chainmail" , "weapon" : "sword" , "strength" : 10 }
@@ -2,14 +2,13 @@ import argparse
 import random
 import string
 
+import sglang as sgl
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-import sglang as sgl
-
 
 TOKENIZER = None
 RANDOM_PREFILL_LEN = None
 RANDOM_DECODE_LEN = None