Json Decode && Multi-Turns (#4)
This commit is contained in:
2
3rdparty/flashinfer
vendored
2
3rdparty/flashinfer
vendored
Submodule 3rdparty/flashinfer updated: 00cf5f46fd...88b9496e1a
61
benchmark/json_regex_decode/README.md
Normal file
61
benchmark/json_regex_decode/README.md
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
## Run benchmark
|
||||||
|
|
||||||
|
### Build dataset
|
||||||
|
```
|
||||||
|
pip install wikipedia
|
||||||
|
python3 build_dataset.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Dependencies
|
||||||
|
|
||||||
|
```
|
||||||
|
llama_cpp_python 0.2.19
|
||||||
|
guidance 0.1.10
|
||||||
|
vllm 0.2.5
|
||||||
|
outlines 0.0.22
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark sglang
|
||||||
|
|
||||||
|
Run llama-7b
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
Run mixtral-8x7b
|
||||||
|
(If you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`.)
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_sglang.py --num-questions 10
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Benchmark vllm
|
||||||
|
|
||||||
|
Run llama-7b
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_other.py --backend vllm --num-questions 10
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Benchmark guidance
|
||||||
|
|
||||||
|
Run llama-7b and benchmark
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_other.py --backend guidance --num-questions 10 --parallel 1
|
||||||
|
```
|
||||||
125
benchmark/json_regex_decode/bench_other.py
Normal file
125
benchmark/json_regex_decode/bench_other.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_other_args_and_parse,
|
||||||
|
call_generate_outlines,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
|
||||||
|
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
def json_decode(document, generate):
    """Build a JSON city record by filling each field with a constrained call.

    Args:
        document: Text of the page to extract from; embedded in the prompt.
        generate: Callable invoked as ``generate(prompt, max_tokens=...,
            regex=...)`` that returns the generated text for one field.

    Returns:
        The full prompt followed by the generated JSON object, as one string.
    """
    # (JSON key, token budget, regex constraining the generated value)
    field_specs = (
        ("name", 8, REGEX_STRING + ","),
        ("country", 8, REGEX_STRING + ","),
        ("latitude", 8, REGEX_FLOAT + ","),
        ("population", 8, REGEX_INT + ","),
        ("top 3 landmarks", 24, REGEX_LIST),
    )

    s = "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    s += "{\n"
    for key, budget, pattern in field_specs:
        # Each call sees everything produced so far, including earlier fields.
        s += f' "{key}": '
        s += generate(s, max_tokens=budget, regex=pattern) + "\n"
    s += "}\n"

    return s
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Run the JSON regex-decoding benchmark against a non-sglang backend.

    Reads documents from ``args.data_path``, runs ``json_decode`` on at most
    ``args.num_questions`` of them through the selected backend, prints the
    wall-clock latency, dumps the generated states to
    ``tmp_output_<backend>.txt`` and appends one JSON record to
    ``args.result_file``.

    Args:
        args: Parsed command-line namespace; this function reads ``data_path``,
            ``num_questions``, ``backend``, ``host``, ``port``, ``parallel``
            and ``result_file``.

    Raises:
        ValueError: If ``args.backend`` is neither ``"vllm"`` nor ``"guidance"``.
    """
    lines = read_jsonl(args.data_path)
    # Slicing caps the run at --num-questions and tolerates shorter datasets.
    arguments = [
        {"document": line["document"]} for line in lines[: args.num_questions]
    ]
    states = [None] * len(arguments)

    # Select backend
    if args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_outlines, url=url, temperature=0)
    elif args.backend == "guidance":
        from guidance import gen, models

        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def generate(prompt, max_tokens, stop=None, regex=None):
            out = (
                model
                + prompt
                + gen(
                    name="answer",
                    max_tokens=max_tokens,
                    temperature=0,
                    stop=stop,
                    regex=regex,
                )
            )
            return out["answer"]

        # warmup
        for _ in range(3):
            generate("Hello!" * 10, max_tokens=64, stop=None)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    def get_one_answer(i):
        states[i] = json_decode(generate=generate, **arguments[i])

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(arguments))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            rets = executor.map(get_one_answer, list(range(len(arguments))))
            # Drain the iterator so worker exceptions are re-raised here.
            for _ in rets:
                pass

    latency = time.time() - tic

    # Report latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_regex_decode",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            # Record what was actually run; the dataset may hold fewer rows
            # than --num-questions.
            "num_requests": len(arguments),
            "other": {
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Benchmark-specific flags; the shared helper adds the remaining common
    # options and performs the parse.
    cli = argparse.ArgumentParser()
    cli.add_argument("--data-path", type=str, default="questions.jsonl")
    cli.add_argument("--num-questions", type=int, default=20)
    main(add_common_other_args_and_parse(cli))
|
||||||
100
benchmark/json_regex_decode/bench_sglang.py
Normal file
100
benchmark/json_regex_decode/bench_sglang.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
@sgl.function
def json_warm_up(s):
    """Warm-up program that generates a small regex-constrained JSON object.

    It emits the same five fields as the benchmark program and prints the text
    captured under the ``json_output`` variable scope.
    """
    # (JSON key, sglang variable name, token budget, value regex)
    schema = (
        ("name", "name", 8, REGEX_STRING + ","),
        ("country", "country", 8, REGEX_STRING + ","),
        ("latitude", "latitude", 8, REGEX_FLOAT + ","),
        ("population", "population", 8, REGEX_INT + ","),
        ("top 3 landmarks", "landmarks", 24, REGEX_LIST),
    )

    s += "The information about Hogwarts is in the following JSON format.\n"
    with s.var_scope("json_output"):
        s += "{\n"
        for key, var, budget, pattern in schema:
            s += f' "{key}": ' + sgl.gen(var, max_tokens=budget, regex=pattern) + "\n"
        s += "}\n"
    print(f'The warmp up json result is:\n{s["json_output"]}')
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
@sgl.function
def json_decode(s, document):
    """Extract a city's facts from ``document`` as regex-constrained JSON.

    Args:
        s: The sglang program state that the prompt is appended to.
        document: Page text embedded in the prompt.

    The generated object is captured under the ``json_output`` variable scope.
    """
    # (JSON key, sglang variable name, token budget, value regex)
    schema = (
        ("name", "name", 8, REGEX_STRING + ","),
        ("country", "country", 8, REGEX_STRING + ","),
        ("latitude", "latitude", 8, REGEX_FLOAT + ","),
        ("population", "population", 8, REGEX_INT + ","),
        ("top 3 landmarks", "landmarks", 24, REGEX_LIST),
    )

    s += "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    with s.var_scope("json_output"):
        s += "{\n"
        for key, var, budget, pattern in schema:
            s += f' "{key}": ' + sgl.gen(var, max_tokens=budget, regex=pattern) + "\n"
        s += "}\n"
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Run the JSON regex-decoding benchmark through sglang.

    Reads documents from ``args.data_path``, runs the warm-up program once,
    then runs ``json_decode`` as a batch over at most ``args.num_questions``
    documents. Prints the wall-clock latency, dumps the states to
    ``tmp_output_<backend>.txt``, writes each generated JSON object to
    ``tmp_<backend>_json_results.txt`` and appends one JSON record to
    ``args.result_file``.

    Args:
        args: Parsed command-line namespace; this function reads ``data_path``,
            ``num_questions``, ``backend``, ``parallel`` and ``result_file``,
            and passes the whole namespace to ``select_sglang_backend``.
    """
    lines = read_jsonl(args.data_path)
    # Slicing caps the run at --num-questions and tolerates shorter datasets.
    arguments = [
        {"document": line["document"]} for line in lines[: args.num_questions]
    ]

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Warm up
    json_warm_up.run().sync()

    # Run requests
    tic = time.time()
    states = json_decode.run_batch(arguments, temperature=0, num_threads=args.parallel)
    # Wait for every program to finish before stopping the clock.
    for state in states:
        state.sync()
    latency = time.time() - tic

    # Report latency
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(f"tmp_{args.backend}_json_results.txt", "w") as fout:
        for state in states:
            fout.write(state["json_output"] + "\n")

    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_regex_decode",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            # Record what was actually run; the dataset may hold fewer rows
            # than --num-questions.
            "num_requests": len(arguments),
            "other": {
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Benchmark-specific flags; the shared helper adds the remaining common
    # options and performs the parse.
    cli = argparse.ArgumentParser()
    cli.add_argument("--data-path", type=str, default="questions.jsonl")
    cli.add_argument("--num-questions", type=int, default=20)
    main(add_common_sglang_args_and_parse(cli))
|
||||||
58
benchmark/json_regex_decode/build_dataset.py
Normal file
58
benchmark/json_regex_decode/build_dataset.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
import wikipedia
|
||||||
|
|
||||||
|
model_path = "meta-llama/Llama-2-7b-chat-hf"
|
||||||
|
t = transformers.AutoTokenizer.from_pretrained(model_path)
|
||||||
|
# Cities whose Wikipedia pages are fetched to build the benchmark dataset.
# Names are passed verbatim to the page lookup, so they must be spelled
# correctly rather than relying on fuzzy matching.
city_names = [
    "los angeles",
    "london",
    "tokyo",
    "beijing",
    "singapore",
    "paris",
    "dubai",
    "sydney",
    "moscow",
    "rome",
    "toronto",
    "rio de janeiro",
    "istanbul",
    "berlin",
    "auckland",
    "buenos aires",
    "mexico city",
    "mumbai",
    "seoul",
    "bangkok",
    "cairo",
    "athens",
    "jerusalem",
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_content(city_name):
    """Fetch a city's Wikipedia text and cut it to roughly 3000 tokens.

    The cut point is estimated by scaling the character count with the ratio
    of the target token count to the full page's token count, so the result
    is approximate; the actual truncated token count is printed.

    Args:
        city_name: Page title handed to ``wikipedia.page``.

    Returns:
        The truncated page text.
    """
    expected_tokens = 3000

    page_text = str(wikipedia.page(city_name).content).replace("\n\n", "\n")
    full_ids = t.encode(page_text)

    # Proportional estimate: keep the same fraction of characters as of tokens.
    keep_ratio = expected_tokens / len(full_ids)
    truncate_content = page_text[: int(keep_ratio * len(page_text))]
    truncated_ids = t.encode(truncate_content)

    # Count token
    print(
        f"city_name: {city_name}, #tokens: {len(full_ids)}, #truncate tokens: {len(truncated_ids)}"
    )

    return truncate_content
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Write one JSON line per city: {"document": <truncated page text>}.
    with open("questions.jsonl", "w") as fout:
        for city in city_names:
            record = {"document": get_content(city)}
            fout.write(json.dumps(record) + "\n")
|
||||||
66
benchmark/multi_turns/README.md
Normal file
66
benchmark/multi_turns/README.md
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
### Benchmark sglang
|
||||||
|
|
||||||
|
Run llama-7b
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
Run mixtral-8x7b
|
||||||
|
(If you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`.)
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark(short output)
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark(long output)
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --long
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark vLLM
|
||||||
|
|
||||||
|
Run llama-7b
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
|
||||||
|
```
|
||||||
|
|
||||||
|
Run mixtral-8x7b
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model mistralai/Mixtral-8x7B-Instruct-v0.1 --disable-log-requests --port 21000 --tensor-parallel-size 8
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark(short output)
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark(long output)
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm --long
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark guidance
|
||||||
|
|
||||||
|
Benchmark llama-7b(short output)
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark llama-7b(long output)
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --long
|
||||||
|
```
|
||||||
133
benchmark/multi_turns/bench_other.py
Normal file
133
benchmark/multi_turns/bench_other.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from sglang.test.test_utils import add_common_other_args_and_parse
|
||||||
|
from sglang.utils import dump_state_text
|
||||||
|
from tqdm import tqdm
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
from data_gen import gen_arguments
|
||||||
|
|
||||||
|
|
||||||
|
def get_generate(args):
    """Return a ``generate(prompt, max_tokens, stop=None, ...)`` callable.

    Args:
        args: Parsed command-line namespace. ``backend`` selects the
            implementation; ``host`` and ``port`` are read for ``"vllm"``.

    Returns:
        A function that takes a prompt and returns only the newly generated
        text for it.

    Raises:
        ValueError: If ``args.backend`` is neither ``"vllm"`` nor ``"guidance"``.
    """
    backend_name = args.backend
    if backend_name not in ("vllm", "guidance"):
        raise ValueError(f"Invalid backend: {args.backend}")

    if backend_name == "vllm":
        url = f"{args.host}:{args.port}/generate"

        def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1):
            payload = dict(
                prompt=prompt,
                temperature=temperature,
                max_tokens=max_tokens,
                ignore_eos=True,
                stop=stop,
                stream=False,
                n=n,
            )
            res = requests.post(url, json=payload)
            assert res.status_code == 200
            # The returned text starts with the prompt; strip it off.
            completion = res.json()["text"][0]
            return completion[len(prompt) :]

        return generate

    # Only "guidance" can reach this point.
    from guidance import gen, models

    model = models.LlamaCpp(
        "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
        n_gpu_layers=-1,
        n_ctx=4096,
    )

    def generate(prompt, max_tokens, stop=None):
        answer_gen = gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
        out = model + prompt + answer_gen
        return out["answer"]

    # warmup
    for _ in range(3):
        generate("Hello!" * 10, max_tokens=64, stop=None)

    return generate
|
||||||
|
|
||||||
|
|
||||||
|
def multi_turns(generate, qas):
    """Play a multi-turn conversation and return the full transcript.

    Args:
        generate: Callable invoked as ``generate(context, max_tokens=...)``
            that returns the text to append for the current turn.
        qas: Iterable of dicts with a ``"prompt"`` string and a
            ``"new_tokens"`` generation budget.

    Returns:
        All prompts and generated answers concatenated in turn order.
    """
    transcript = ""
    for turn in qas:
        # Each turn is conditioned on the whole conversation so far.
        context = transcript + turn["prompt"]
        transcript = context + generate(context, max_tokens=turn["new_tokens"])

    return transcript
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Run the multi-turn benchmark against a non-sglang backend.

    Generates the synthetic conversations, replays each one through the
    selected backend (sequentially when ``args.parallel == 1``, otherwise on a
    thread pool), prints the wall-clock latency, dumps the transcripts to
    ``tmp_output_<backend>.txt`` and appends one JSON record to
    ``args.result_file``.

    Args:
        args: Parsed command-line namespace.
    """
    print(args)

    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
    multi_qas = gen_arguments(args, tokenizer)
    states = [None] * args.num_qa
    generate = get_generate(args)

    def run_conversation(idx):
        states[idx] = multi_turns(generate=generate, **multi_qas[idx])

    conversation_ids = list(range(len(multi_qas)))

    tic = time.time()
    if args.parallel != 1:
        with ThreadPoolExecutor(args.parallel) as pool:
            # Consuming the results re-raises any worker exception here.
            for _ in pool.map(run_conversation, conversation_ids):
                pass
    else:
        for idx in tqdm(range(len(multi_qas))):
            run_conversation(idx)
    latency = time.time() - tic

    # Report latency
    print(f"Latency: {latency:.3f}")

    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        record = {
            "task": "multi_turns",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_qa,
            "num_turns": args.turns,
            "other": {
                "parallel": args.parallel,
                "output_mode": "long" if args.long else "short",
            },
        }
        fout.write(json.dumps(record) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    parser = ArgumentParser()

    # Integer workload knobs, registered in their original order.
    int_options = (
        ("--turns", 4),
        ("--num-qa", 20),
        ("--min-len-q", 256),
        ("--max-len-q", 512),
        ("--min-len-a", 4),
        ("--max-len-a", 8),
    )
    for flag, default in int_options:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--tokenizer", type=str, required=True)
    for flag in ("--trust-remote-code", "--long"):
        parser.add_argument(flag, action="store_true")
    args = add_common_other_args_and_parse(parser)

    # --long overrides the answer-length range and the number of conversations.
    if args.long:
        args.min_len_a, args.max_len_a = 256, 512
        args.num_qa = 20
    main(args)
|
||||||
77
benchmark/multi_turns/bench_sglang.py
Normal file
77
benchmark/multi_turns/bench_sglang.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
from data_gen import gen_arguments
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
def multi_turns(s, qas):
    """Replay a multi-turn conversation on the sglang state ``s``.

    Args:
        s: The sglang program state that text is appended to.
        qas: Iterable of dicts with a ``"prompt"`` string and a
            ``"new_tokens"`` generation budget.
    """
    for turn in qas:
        budget = turn["new_tokens"]
        s += turn["prompt"]
        # ignore_eos=True is passed so generation is not cut short at EOS.
        s += sgl.gen(max_tokens=budget, ignore_eos=True)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """Run the multi-turn benchmark through sglang.

    Generates the synthetic conversations, runs them as one batch, prints the
    wall-clock latency, dumps the states to ``tmp_output_<backend>.txt`` and
    appends one JSON record to ``args.result_file``.

    Args:
        args: Parsed command-line namespace.
    """
    print(args)

    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
    multi_qas = gen_arguments(args, tokenizer)
    backend = select_sglang_backend(args)

    tic = time.time()
    states = multi_turns.run_batch(
        multi_qas,
        temperature=0,
        backend=backend,
        num_threads=args.parallel,
    )
    # Wait for every program to finish before stopping the clock.
    for finished in states:
        finished.sync()
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        record = {
            "task": "multi_turns",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_qa,
            "num_turns": args.turns,
            "other": {
                "parallel": args.parallel,
                "output_mode": "long" if args.long else "short",
            },
        }
        fout.write(json.dumps(record) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    parser = ArgumentParser()

    # Integer workload knobs, registered in their original order.
    int_options = (
        ("--turns", 4),
        ("--num-qa", 20),
        ("--min-len-q", 256),
        ("--max-len-q", 512),
        ("--min-len-a", 4),
        ("--max-len-a", 8),
    )
    for flag, default in int_options:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--tokenizer", type=str, required=True)
    for flag in ("--trust-remote-code", "--long"):
        parser.add_argument(flag, action="store_true")
    args = add_common_sglang_args_and_parse(parser)

    # --long overrides the answer-length range and the number of conversations.
    if args.long:
        args.min_len_a, args.max_len_a = 256, 512
        args.num_qa = 20
    main(args)
|
||||||
29
benchmark/multi_turns/data_gen.py
Normal file
29
benchmark/multi_turns/data_gen.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
|
random.seed(42)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_prompt(tokenizer, token_num):
    """Return a random alphanumeric string of at least ``token_num`` tokens.

    Starts from ``token_num`` random characters and appends one character at
    a time until the tokenizer reports enough tokens.

    Args:
        tokenizer: Callable whose result exposes ``input_ids`` for a string.
        token_num: Minimum number of tokens the result must encode to.

    Returns:
        The generated string.
    """
    alphabet = string.ascii_letters + string.digits
    text = "".join(random.choices(alphabet, k=token_num))
    while True:
        if len(tokenizer(text).input_ids) >= token_num:
            return text
        text += random.choice(alphabet)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_arguments(args, tokenizer):
    """Build the synthetic multi-turn workload.

    Args:
        args: Namespace providing ``num_qa``, ``turns``, ``min_len_q``,
            ``max_len_q``, ``min_len_a`` and ``max_len_a``.
        tokenizer: Passed to ``gen_prompt`` to size each prompt.

    Returns:
        A list of ``args.num_qa`` dicts, each ``{"qas": [...]}`` holding
        ``args.turns`` entries of the form
        ``{"prompt": str, "new_tokens": int}``.
    """
    multi_qas = []
    for _ in range(args.num_qa):
        turns = []
        for _ in range(args.turns):
            # Draw both lengths before generating the prompt so the random
            # stream is consumed in a fixed order.
            question_len = random.randint(args.min_len_q, args.max_len_q)
            answer_len = random.randint(args.min_len_a, args.max_len_a)
            turns.append(
                {
                    "prompt": gen_prompt(tokenizer, question_len),
                    "new_tokens": answer_len,
                }
            )
        multi_qas.append({"qas": turns})

    return multi_qas
|
||||||
@@ -37,6 +37,7 @@ def gen(
|
|||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
|
ignore_eos: Optional[bool] = None,
|
||||||
dtype: Optional[type] = None,
|
dtype: Optional[type] = None,
|
||||||
choices: Optional[List[str]] = None,
|
choices: Optional[List[str]] = None,
|
||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
@@ -60,6 +61,7 @@ def gen(
|
|||||||
top_k,
|
top_k,
|
||||||
frequency_penalty,
|
frequency_penalty,
|
||||||
presence_penalty,
|
presence_penalty,
|
||||||
|
ignore_eos,
|
||||||
dtype,
|
dtype,
|
||||||
regex,
|
regex,
|
||||||
)
|
)
|
||||||
@@ -74,6 +76,7 @@ def gen_int(
|
|||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
|
ignore_eos: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
return SglGen(
|
return SglGen(
|
||||||
name,
|
name,
|
||||||
@@ -84,6 +87,7 @@ def gen_int(
|
|||||||
top_k,
|
top_k,
|
||||||
frequency_penalty,
|
frequency_penalty,
|
||||||
presence_penalty,
|
presence_penalty,
|
||||||
|
ignore_eos,
|
||||||
int,
|
int,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -98,6 +102,7 @@ def gen_string(
|
|||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
|
ignore_eos: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
return SglGen(
|
return SglGen(
|
||||||
name,
|
name,
|
||||||
@@ -108,6 +113,7 @@ def gen_string(
|
|||||||
top_k,
|
top_k,
|
||||||
frequency_penalty,
|
frequency_penalty,
|
||||||
presence_penalty,
|
presence_penalty,
|
||||||
|
ignore_eos,
|
||||||
str,
|
str,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from sglang.backend.base_backend import BaseBackend
|
from sglang.backend.base_backend import BaseBackend
|
||||||
from sglang.lang.chat_template import get_chat_template
|
from sglang.lang.chat_template import get_chat_template
|
||||||
from sglang.lang.interpreter import StreamExecutor
|
from sglang.lang.interpreter import StreamExecutor
|
||||||
from sglang.lang.ir import SamplingParams
|
from sglang.lang.ir import SglSamplingParams
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import anthropic
|
import anthropic
|
||||||
@@ -28,7 +28,7 @@ class Anthropic(BaseBackend):
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
s: StreamExecutor,
|
s: StreamExecutor,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
prompt = s.text_
|
prompt = s.text_
|
||||||
ret = anthropic.Anthropic().completions.create(
|
ret = anthropic.Anthropic().completions.create(
|
||||||
@@ -43,7 +43,7 @@ class Anthropic(BaseBackend):
|
|||||||
def generate_stream(
|
def generate_stream(
|
||||||
self,
|
self,
|
||||||
s: StreamExecutor,
|
s: StreamExecutor,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
prompt = s.text_
|
prompt = s.text_
|
||||||
generator = anthropic.Anthropic().completions.create(
|
generator = anthropic.Anthropic().completions.create(
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Union
|
|||||||
|
|
||||||
from sglang.lang.chat_template import get_chat_template
|
from sglang.lang.chat_template import get_chat_template
|
||||||
from sglang.lang.interpreter import StreamExecutor
|
from sglang.lang.interpreter import StreamExecutor
|
||||||
from sglang.lang.ir import SamplingParams
|
from sglang.lang.ir import SglSamplingParams
|
||||||
|
|
||||||
|
|
||||||
class BaseBackend:
|
class BaseBackend:
|
||||||
@@ -48,14 +48,14 @@ class BaseBackend:
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
s: StreamExecutor,
|
s: StreamExecutor,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def generate_stream(
|
def generate_stream(
|
||||||
self,
|
self,
|
||||||
s: StreamExecutor,
|
s: StreamExecutor,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from sglang.backend.base_backend import BaseBackend
|
from sglang.backend.base_backend import BaseBackend
|
||||||
from sglang.lang.chat_template import get_chat_template
|
from sglang.lang.chat_template import get_chat_template
|
||||||
from sglang.lang.interpreter import StreamExecutor
|
from sglang.lang.interpreter import StreamExecutor
|
||||||
from sglang.lang.ir import SamplingParams
|
from sglang.lang.ir import SglSamplingParams
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import openai
|
import openai
|
||||||
@@ -73,7 +73,7 @@ class OpenAI(BaseBackend):
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
s: StreamExecutor,
|
s: StreamExecutor,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
if sampling_params.dtype is None:
|
if sampling_params.dtype is None:
|
||||||
if self.is_chat_model:
|
if self.is_chat_model:
|
||||||
@@ -122,7 +122,7 @@ class OpenAI(BaseBackend):
|
|||||||
def generate_stream(
|
def generate_stream(
|
||||||
self,
|
self,
|
||||||
s: StreamExecutor,
|
s: StreamExecutor,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
if sampling_params.dtype is None:
|
if sampling_params.dtype is None:
|
||||||
if self.is_chat_model:
|
if self.is_chat_model:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from sglang.backend.base_backend import BaseBackend
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||||
from sglang.lang.interpreter import StreamExecutor
|
from sglang.lang.interpreter import StreamExecutor
|
||||||
from sglang.lang.ir import SamplingParams, SglArgument
|
from sglang.lang.ir import SglSamplingParams, SglArgument
|
||||||
from sglang.utils import encode_image_base64, find_printable_text, http_request
|
from sglang.utils import encode_image_base64, find_printable_text, http_request
|
||||||
|
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
s: StreamExecutor,
|
s: StreamExecutor,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
if sampling_params.dtype is None:
|
if sampling_params.dtype is None:
|
||||||
data = {
|
data = {
|
||||||
@@ -87,7 +87,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
def generate_stream(
|
def generate_stream(
|
||||||
self,
|
self,
|
||||||
s: StreamExecutor,
|
s: StreamExecutor,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
if sampling_params.dtype is None:
|
if sampling_params.dtype is None:
|
||||||
data = {
|
data = {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import List, Optional, Union
|
|||||||
from sglang.backend.base_backend import BaseBackend
|
from sglang.backend.base_backend import BaseBackend
|
||||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||||
from sglang.lang.interpreter import StreamExecutor
|
from sglang.lang.interpreter import StreamExecutor
|
||||||
from sglang.lang.ir import SamplingParams
|
from sglang.lang.ir import SglSamplingParams
|
||||||
from sglang.utils import http_request
|
from sglang.utils import http_request
|
||||||
|
|
||||||
|
|
||||||
@@ -138,7 +138,7 @@ class TGI(BaseBackend):
|
|||||||
self,
|
self,
|
||||||
s: StreamExecutor,
|
s: StreamExecutor,
|
||||||
choices: List[str],
|
choices: List[str],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
decision = self.retry_for_expected(
|
decision = self.retry_for_expected(
|
||||||
s.text_,
|
s.text_,
|
||||||
@@ -152,7 +152,7 @@ class TGI(BaseBackend):
|
|||||||
s: StreamExecutor,
|
s: StreamExecutor,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
stop: Union[str, List[str]],
|
stop: Union[str, List[str]],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
dtype: Optional[str] = None,
|
dtype: Optional[str] = None,
|
||||||
):
|
):
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import List, Union
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
|
from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
|
||||||
from sglang.lang.ir import (
|
from sglang.lang.ir import (
|
||||||
SamplingParams,
|
SglSamplingParams,
|
||||||
SglArgument,
|
SglArgument,
|
||||||
SglConstantText,
|
SglConstantText,
|
||||||
SglExpr,
|
SglExpr,
|
||||||
@@ -140,7 +140,7 @@ class CompiledFunction:
|
|||||||
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
|
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
|
||||||
kwargs.update(self.function.bind_arguments)
|
kwargs.update(self.function.bind_arguments)
|
||||||
|
|
||||||
default_sampling_para = SamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -173,7 +173,7 @@ class CompiledFunction:
|
|||||||
|
|
||||||
backend = backend or global_config.default_backend
|
backend = backend or global_config.default_backend
|
||||||
|
|
||||||
default_sampling_para = SamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
|||||||
@@ -292,7 +292,7 @@ class StreamExecutor:
|
|||||||
|
|
||||||
assert isinstance(other, SglExpr), f"{other}"
|
assert isinstance(other, SglExpr), f"{other}"
|
||||||
|
|
||||||
if isinstance(other, (SglConstantText, SglArgument)):
|
if isinstance(other, SglConstantText):
|
||||||
self._execute_fill(other.value)
|
self._execute_fill(other.value)
|
||||||
elif isinstance(other, SglGen):
|
elif isinstance(other, SglGen):
|
||||||
self._execute_gen(other)
|
self._execute_gen(other)
|
||||||
@@ -332,8 +332,6 @@ class StreamExecutor:
|
|||||||
|
|
||||||
def _execute_image(self, expr: SglImage):
|
def _execute_image(self, expr: SglImage):
|
||||||
path = expr.path
|
path = expr.path
|
||||||
if isinstance(path, SglArgument):
|
|
||||||
path = path.value
|
|
||||||
|
|
||||||
base64_data = encode_image_base64(path)
|
base64_data = encode_image_base64(path)
|
||||||
|
|
||||||
@@ -419,7 +417,7 @@ class StreamExecutor:
|
|||||||
"role": expr.role,
|
"role": expr.role,
|
||||||
"content": [{"type": "text", "text": new_text}],
|
"content": [{"type": "text", "text": new_text}],
|
||||||
}
|
}
|
||||||
for (image_path, image_base64_data) in self.cur_images:
|
for image_path, image_base64_data in self.cur_images:
|
||||||
last_msg["content"].append(
|
last_msg["content"].append(
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
@@ -480,6 +478,7 @@ class StreamExecutor:
|
|||||||
"top_k",
|
"top_k",
|
||||||
"frequency_penalty",
|
"frequency_penalty",
|
||||||
"presence_penalty",
|
"presence_penalty",
|
||||||
|
"ignore_eos",
|
||||||
"dtype",
|
"dtype",
|
||||||
"regex",
|
"regex",
|
||||||
]:
|
]:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
|||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class SamplingParams:
|
class SglSamplingParams:
|
||||||
max_new_tokens: int = 16
|
max_new_tokens: int = 16
|
||||||
stop: Union[str, List[str]] = ()
|
stop: Union[str, List[str]] = ()
|
||||||
temperature: float = 1.0
|
temperature: float = 1.0
|
||||||
@@ -21,13 +21,14 @@ class SamplingParams:
|
|||||||
top_k: int = -1 # -1 means disable
|
top_k: int = -1 # -1 means disable
|
||||||
frequency_penalty: float = 0.0
|
frequency_penalty: float = 0.0
|
||||||
presence_penalty: float = 0.0
|
presence_penalty: float = 0.0
|
||||||
|
ignore_eos: bool = False
|
||||||
|
|
||||||
# for constrained generation, not included in to_xxx_kwargs
|
# for constrained generation, not included in to_xxx_kwargs
|
||||||
dtype: Optional[str] = None
|
dtype: Optional[str] = None
|
||||||
regex: Optional[str] = None
|
regex: Optional[str] = None
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
return SamplingParams(
|
return SglSamplingParams(
|
||||||
self.max_new_tokens,
|
self.max_new_tokens,
|
||||||
self.stop,
|
self.stop,
|
||||||
self.temperature,
|
self.temperature,
|
||||||
@@ -67,6 +68,7 @@ class SamplingParams:
|
|||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
"frequency_penalty": self.frequency_penalty,
|
"frequency_penalty": self.frequency_penalty,
|
||||||
"presence_penalty": self.presence_penalty,
|
"presence_penalty": self.presence_penalty,
|
||||||
|
"ignore_eos": self.ignore_eos,
|
||||||
"regex": self.regex,
|
"regex": self.regex,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,13 +100,14 @@ class SglFunction:
|
|||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
|
ignore_eos: bool = False,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
backend=None,
|
backend=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
from sglang.lang.interpreter import run_program
|
from sglang.lang.interpreter import run_program
|
||||||
|
|
||||||
default_sampling_para = SamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -112,9 +115,9 @@ class SglFunction:
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
|
ignore_eos=ignore_eos,
|
||||||
)
|
)
|
||||||
backend = backend or global_config.default_backend
|
backend = backend or global_config.default_backend
|
||||||
kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
|
|
||||||
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
|
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
|
||||||
|
|
||||||
def run_batch(
|
def run_batch(
|
||||||
@@ -128,6 +131,7 @@ class SglFunction:
|
|||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
|
ignore_eos: bool = False,
|
||||||
backend=None,
|
backend=None,
|
||||||
num_threads: Union[str, int] = "auto",
|
num_threads: Union[str, int] = "auto",
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
@@ -139,7 +143,7 @@ class SglFunction:
|
|||||||
return []
|
return []
|
||||||
assert isinstance(batch_kwargs[0], dict)
|
assert isinstance(batch_kwargs[0], dict)
|
||||||
|
|
||||||
default_sampling_para = SamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -147,11 +151,9 @@ class SglFunction:
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
|
ignore_eos=ignore_eos,
|
||||||
)
|
)
|
||||||
backend = backend or global_config.default_backend
|
backend = backend or global_config.default_backend
|
||||||
batch_kwargs = [
|
|
||||||
{k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
|
|
||||||
]
|
|
||||||
return run_program_batch(
|
return run_program_batch(
|
||||||
self,
|
self,
|
||||||
backend,
|
backend,
|
||||||
@@ -321,12 +323,13 @@ class SglGen(SglExpr):
|
|||||||
top_k,
|
top_k,
|
||||||
frequency_penalty,
|
frequency_penalty,
|
||||||
presence_penalty,
|
presence_penalty,
|
||||||
|
ignore_eos,
|
||||||
dtype,
|
dtype,
|
||||||
regex,
|
regex,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
self.sampling_params = SamplingParams(
|
self.sampling_params = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -334,6 +337,7 @@ class SglGen(SglExpr):
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
|
ignore_eos=ignore_eos,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
regex=regex,
|
regex=regex,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -40,7 +40,8 @@ def extract_prefix_by_tracing(program, backend):
|
|||||||
try:
|
try:
|
||||||
with TracingScope(tracer):
|
with TracingScope(tracer):
|
||||||
tracer.ret_value = program.func(tracer, **arguments)
|
tracer.ret_value = program.func(tracer, **arguments)
|
||||||
except StopTracing:
|
except (StopTracing, TypeError):
|
||||||
|
# Some exceptions may not be caught
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Run and cache prefix
|
# Run and cache prefix
|
||||||
|
|||||||
12
python/sglang/srt/backend_config.py
Normal file
12
python/sglang/srt/backend_config.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Backend configurations, may vary with different serving platforms.
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BackendConfig:
    """Backend configuration values that may vary with the serving platform."""

    # Duration, in seconds, that the router passes to ``asyncio.sleep`` after a
    # forward step whose outputs include a finished request, giving a
    # follow-up request time to arrive before the next step is scheduled.
    extend_dependency_time: float = 0.03
|
||||||
|
|
||||||
|
|
||||||
|
GLOBAL_BACKEND_CONFIG = BackendConfig()
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import uvloop
|
import uvloop
|
||||||
import zmq
|
import zmq
|
||||||
@@ -8,6 +7,7 @@ import zmq.asyncio
|
|||||||
from sglang.srt.managers.router.model_rpc import ModelRpcClient
|
from sglang.srt.managers.router.model_rpc import ModelRpcClient
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import get_exception_traceback
|
from sglang.srt.utils import get_exception_traceback
|
||||||
|
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
@@ -28,6 +28,9 @@ class RouterManager:
|
|||||||
self.model_client = model_client
|
self.model_client = model_client
|
||||||
self.recv_reqs = []
|
self.recv_reqs = []
|
||||||
|
|
||||||
|
# Init Some Configs
|
||||||
|
self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
|
||||||
|
|
||||||
async def loop_for_forward(self):
|
async def loop_for_forward(self):
|
||||||
while True:
|
while True:
|
||||||
next_step_input = list(self.recv_reqs)
|
next_step_input = list(self.recv_reqs)
|
||||||
@@ -37,7 +40,12 @@ class RouterManager:
|
|||||||
for obj in out_pyobjs:
|
for obj in out_pyobjs:
|
||||||
self.send_to_detokenizer.send_pyobj(obj)
|
self.send_to_detokenizer.send_pyobj(obj)
|
||||||
|
|
||||||
# await for a while to accept input requests
|
# Sleep asynchronously so subsequent requests can be received, avoiding cache misses
|
||||||
|
if len(out_pyobjs) != 0:
|
||||||
|
has_finished = any([obj.finished for obj in out_pyobjs])
|
||||||
|
if has_finished:
|
||||||
|
await asyncio.sleep(self.extend_dependency_time)
|
||||||
|
|
||||||
await asyncio.sleep(0.001)
|
await asyncio.sleep(0.001)
|
||||||
|
|
||||||
async def loop_for_recv_requests(self):
|
async def loop_for_recv_requests(self):
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
|
|||||||
from sglang.srt.managers.router.radix_cache import RadixCache
|
from sglang.srt.managers.router.radix_cache import RadixCache
|
||||||
from sglang.srt.managers.router.scheduler import Scheduler
|
from sglang.srt.managers.router.scheduler import Scheduler
|
||||||
from sglang.srt.model_config import ModelConfig
|
from sglang.srt.model_config import ModelConfig
|
||||||
from sglang.srt.sampling_params import SamplingParams
|
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_exception_traceback,
|
get_exception_traceback,
|
||||||
@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
if self.running_batch.is_empty():
|
if self.running_batch.is_empty():
|
||||||
self.running_batch = None
|
self.running_batch = None
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
# check the available size
|
||||||
|
available_size = (
|
||||||
|
self.token_to_kv_pool.available_size()
|
||||||
|
+ self.tree_cache.evictable_size()
|
||||||
|
)
|
||||||
|
if available_size != self.max_total_num_token:
|
||||||
|
logger.warning(
|
||||||
|
"Warning: "
|
||||||
|
f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
|
||||||
|
"KV cache pool leak detected!"
|
||||||
|
)
|
||||||
|
|
||||||
if self.running_batch is not None and self.tp_rank == 0:
|
if self.running_batch is not None and self.tp_rank == 0:
|
||||||
if self.decode_forward_ct >= 20:
|
if self.decode_forward_ct >= 20:
|
||||||
@@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
token_ids = tuple(req.input_ids + req.output_ids)
|
token_ids = tuple(req.input_ids + req.output_ids)
|
||||||
seq_len = len(token_ids) - 1
|
seq_len = len(token_ids) - 1
|
||||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
|
indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
|
||||||
prefix_len = self.tree_cache.insert(token_ids, indices.clone())
|
prefix_len = self.tree_cache.insert(
|
||||||
|
token_ids[:seq_len], indices.clone()
|
||||||
|
)
|
||||||
|
|
||||||
self.token_to_kv_pool.free(indices[:prefix_len])
|
self.token_to_kv_pool.free(indices[:prefix_len])
|
||||||
self.req_to_token_pool.free(req_pool_idx)
|
self.req_to_token_pool.free(req_pool_idx)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class Scheduler:
|
|||||||
self.tree_cache = tree_cache
|
self.tree_cache = tree_cache
|
||||||
|
|
||||||
def new_token_estimation_ratio(self):
|
def new_token_estimation_ratio(self):
|
||||||
return 0.4 if self.schedule_heuristic != "fcfs" else 0.5
|
return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
|
||||||
|
|
||||||
def get_priority_queue(self, forward_queue):
|
def get_priority_queue(self, forward_queue):
|
||||||
if self.schedule_heuristic == "lpm":
|
if self.schedule_heuristic == "lpm":
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6
|
|||||||
class SamplingParams:
|
class SamplingParams:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
max_new_tokens: int = 16,
|
||||||
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
|
||||||
max_new_tokens: int = 16,
|
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
skip_special_tokens: bool = True,
|
skip_special_tokens: bool = True,
|
||||||
dtype: Optional[str] = None,
|
dtype: Optional[str] = None,
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ class ServerArgs:
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
self.tokenizer_path = self.model_path
|
self.tokenizer_path = self.model_path
|
||||||
|
if self.tp_size > 1:
|
||||||
|
self.mem_fraction_static = 0.8
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
|
|||||||
@@ -38,6 +38,26 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
|
|||||||
return pred
|
return pred
|
||||||
|
|
||||||
|
|
||||||
|
def call_generate_outlines(
    prompt, temperature, max_tokens, url, stop=None, regex=None, n=1
):
    """POST a generation request to ``url`` and return the generated text.

    Args:
        prompt: Prompt string sent to the server.
        temperature: Sampling temperature forwarded in the request.
        max_tokens: Value forwarded as ``max_tokens`` in the request.
        url: Endpoint that receives the JSON payload.
        stop: Stop sequences forwarded in the request; ``None`` (the default)
            is sent as an empty list.
        regex: Optional regex forwarded as ``regex`` in the request.
        n: Number of completions requested.

    Returns:
        A single string when ``n == 1``, otherwise a list of strings. The
        first ``len(prompt)`` characters are cut from each returned text
        (presumably the server echoes the prompt -- confirm against the
        server's response format).

    Raises:
        AssertionError: If the response status code is not 200.
    """
    data = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        # Build a fresh list per call instead of sharing a mutable default.
        "stop": [] if stop is None else stop,
        "regex": regex,
        "n": n,
    }
    res = requests.post(url, json=data)
    assert res.status_code == 200
    if n == 1:
        pred = res.json()["text"][0][len(prompt) :]
    else:
        pred = [x[len(prompt) :] for x in res.json()["text"]]
    return pred
|
||||||
|
|
||||||
|
|
||||||
def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
|
def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
|
||||||
data = {
|
data = {
|
||||||
"text": prompt,
|
"text": prompt,
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ def dump_state_text(filename, states, mode="w"):
|
|||||||
if isinstance(s, str):
|
if isinstance(s, str):
|
||||||
pass
|
pass
|
||||||
elif isinstance(s, ProgramState):
|
elif isinstance(s, ProgramState):
|
||||||
s = s.text().strip()
|
s = s.text()
|
||||||
else:
|
else:
|
||||||
s = str(s)
|
s = str(s)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user