Json Decode && Mutl-Turns (#4)

This commit is contained in:
Liangsheng Yin
2024-01-15 16:49:29 +08:00
committed by GitHub
parent f652494df1
commit 08ab2a1655
27 changed files with 755 additions and 41 deletions

View File

@@ -0,0 +1,61 @@
## Run benchmark
### Build dataset
```
pip install wikipedia
python3 build_dataset.py
```
### Dependencies
```
llama_cpp_python 0.2.19
guidance 0.1.10
vllm 0.2.5
outlines 0.0.22
```
### Benchmark sglang
Run llama-7b
```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Run mixtral-8x7b
(When there is a CUDA out-of-memory error, try to reduce the `--mem-fraction-static`)
```
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
```
Benchmark
```
python3 bench_sglang.py --num-questions 10
```
### Benchmark vllm
Run llama-7b
```
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
Benchmark
```
python3 bench_other.py --backend vllm --num-questions 10
```
### Benchmark guidance
Run llama-7b and benchmark
```
python3 bench_other.py --backend guidance --num-questions 10 --parallel 1
```

View File

@@ -0,0 +1,125 @@
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_outlines,
)
from sglang.utils import dump_state_text, read_jsonl
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
from tqdm import tqdm
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
# fmt: off
def json_decode(document, generate):
s = "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += "{\n"
s += ' "name": '
s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "country": '
s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "latitude": '
s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": '
s += generate(s, max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": '
s += generate(s, max_tokens=24, regex=REGEX_LIST) + "\n"
s += "}\n"
return s
# fmt: on
def main(args):
lines = read_jsonl(args.data_path)
arguments = []
for i in range(len(lines[: args.num_questions])):
arguments.append(
{
"document": lines[i]["document"],
}
)
states = [None] * len(arguments)
# Select backend
if args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_outlines, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop=None, regex=None):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=0,
stop=stop,
regex=regex,
)
)
return out["answer"]
# warmup
for _ in range(3):
generate("Hello!" * 10, max_tokens=64, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests
def get_one_answer(i):
states[i] = json_decode(generate=generate, **arguments[i])
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(arguments))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(arguments))))
for _ in rets:
pass
latency = time.time() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "json_regex_decode",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=20)
args = add_common_other_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,100 @@
import argparse
import json
import time
import sglang as sgl
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
# fmt: off
@sgl.function
def json_warm_up(s):
s += "The information about Hogwarts is in the following JSON format.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
s += "}\n"
print(f'The warmp up json result is:\n{s["json_output"]}')
# fmt: on
# fmt: off
@sgl.function
def json_decode(s, document):
s += "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
s += "}\n"
# fmt: on
def main(args):
lines = read_jsonl(args.data_path)
arguments = []
for i in range(len(lines[: args.num_questions])):
arguments.append(
{
"document": lines[i]["document"],
}
)
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Warm up
json_warm_up.run().sync()
# Run requests
tic = time.time()
states = json_decode.run_batch(arguments, temperature=0, num_threads=args.parallel)
for state in states:
state.sync()
latency = time.time() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(f"tmp_{args.backend}_json_results.txt", "w") as fout:
for state in states:
fout.write(state["json_output"] + "\n")
with open(args.result_file, "a") as fout:
value = {
"task": "json_regex_decode",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_questions,
"other": {
"parallel": args.parallel,
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="questions.jsonl")
parser.add_argument("--num-questions", type=int, default=20)
args = add_common_sglang_args_and_parse(parser)
main(args)

View File

@@ -0,0 +1,58 @@
import json
import transformers
import wikipedia
model_path = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(model_path)
city_names = [
"los angles",
"london",
"tokyo",
"beijing",
"singapore",
"paris",
"dubai",
"sydney",
"moscow",
"rome",
"toronto",
"rio de janeiro",
"istanbul",
"berlin",
"auckland",
"buenos aires",
"mexico city",
"mumbai",
"seoul",
"bangkok",
"cairo",
"athens",
"jerusalem",
]
def get_content(city_name):
content = str(wikipedia.page(city_name).content)
content = content.replace("\n\n", "\n")
tokens = t.encode(content)
expected_tokens = 3000
truncate_len = int((expected_tokens / len(tokens)) * len(content))
truncate_content = content[:truncate_len]
truncate_tokens = t.encode(truncate_content)
# Count token
print(
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
)
return truncate_content
if __name__ == "__main__":
with open("questions.jsonl", "w") as fout:
for city_name in city_names:
truncate_content = get_content(city_name)
fout.write(json.dumps({"document": truncate_content}) + "\n")

View File

@@ -0,0 +1,66 @@
### Benchmark sglang
Run llama-7b
```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```
Run mixtral-8x7b
(When there is a CUDA out-of-memory error, try to reduce the `--mem-fraction-static`)
```
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
```
Benchmark(short output)
```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf
```
Benchmark(long output)
```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --long
```
### Benchmark vLLM
Run llama-7b
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```
Run mixtral-8x7b
```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model mistralai/Mixtral-8x7B-Instruct-v0.1 --disable-log-requests --port 21000 --tensor-parallel-size 8
```
Benchmark(short output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm
```
Benchmark(long output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm --long
```
### Benchmark guidance
Benchmark llama-7b(short output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1
```
Benchmark llama-7b(long output)
```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --long
```

View File

@@ -0,0 +1,133 @@
import json
import time
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor
import requests
from sglang.test.test_utils import add_common_other_args_and_parse
from sglang.utils import dump_state_text
from tqdm import tqdm
from vllm.transformers_utils.tokenizer import get_tokenizer
from data_gen import gen_arguments
def get_generate(args):
# Select backend
if args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1):
data = {
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
"ignore_eos": True,
"stop": stop,
"stream": False,
"n": n,
}
res = requests.post(url, json=data)
assert res.status_code == 200
return res.json()["text"][0][len(prompt) :]
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop=None):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"]
# warmup
for _ in range(3):
generate("Hello!" * 10, max_tokens=64, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
return generate
def multi_turns(generate, qas):
s = ""
for qa in qas:
s += qa["prompt"]
s += generate(s, max_tokens=qa["new_tokens"])
return s
def main(args):
print(args)
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
multi_qas = gen_arguments(args, tokenizer)
states = [None] * args.num_qa
generate = get_generate(args)
def get_one_answer(i):
states[i] = multi_turns(generate=generate, **multi_qas[i])
tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(multi_qas))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(multi_qas))))
for _ in rets:
pass
latency = time.time() - tic
# Compute accuracy
print(f"Latency: {latency:.3f}")
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_turns",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_qa,
"num_turns": args.turns,
"other": {
"parallel": args.parallel,
"output_mode": "long" if args.long else "short",
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--turns", type=int, default=4)
parser.add_argument("--num-qa", type=int, default=20)
parser.add_argument("--min-len-q", type=int, default=256)
parser.add_argument("--max-len-q", type=int, default=512)
parser.add_argument("--min-len-a", type=int, default=4)
parser.add_argument("--max-len-a", type=int, default=8)
parser.add_argument("--tokenizer", type=str, required=True)
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--long", action="store_true")
args = add_common_other_args_and_parse(parser)
if args.long:
args.min_len_a = 256
args.max_len_a = 512
args.num_qa = 20
main(args)

View File

@@ -0,0 +1,77 @@
import json
import time
from argparse import ArgumentParser
import sglang as sgl
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
from vllm.transformers_utils.tokenizer import get_tokenizer
from data_gen import gen_arguments
@sgl.function
def multi_turns(s, qas):
for qa in qas:
s += qa["prompt"]
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
def main(args):
print(args)
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
multi_qas = gen_arguments(args, tokenizer)
backend = select_sglang_backend(args)
tic = time.time()
states = multi_turns.run_batch(
multi_qas, temperature=0, backend=backend, num_threads=args.parallel
)
for state in states:
state.sync()
latency = time.time() - tic
print(f"Latency: {latency:.3f}")
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
value = {
"task": "multi_turns",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_qa,
"num_turns": args.turns,
"other": {
"parallel": args.parallel,
"output_mode": "long" if args.long else "short",
},
}
fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--turns", type=int, default=4)
parser.add_argument("--num-qa", type=int, default=20)
parser.add_argument("--min-len-q", type=int, default=256)
parser.add_argument("--max-len-q", type=int, default=512)
parser.add_argument("--min-len-a", type=int, default=4)
parser.add_argument("--max-len-a", type=int, default=8)
parser.add_argument("--tokenizer", type=str, required=True)
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--long", action="store_true")
args = add_common_sglang_args_and_parse(parser)
if args.long:
args.min_len_a = 256
args.max_len_a = 512
args.num_qa = 20
main(args)

View File

@@ -0,0 +1,29 @@
import random
import string
random.seed(42)
def gen_prompt(tokenizer, token_num):
cha_set = string.ascii_letters + string.digits
ret = "".join(random.choices(cha_set, k=token_num))
while len(tokenizer(ret).input_ids) < token_num:
ret += random.choice(cha_set)
return ret
def gen_arguments(args, tokenizer):
multi_qas = [{"qas": []} for _ in range(args.num_qa)]
for i in range(args.num_qa):
qas = multi_qas[i]["qas"]
for _ in range(args.turns):
prompt_len = random.randint(args.min_len_q, args.max_len_q)
new_tokens = random.randint(args.min_len_a, args.max_len_a)
qas.append(
{
"prompt": gen_prompt(tokenizer, prompt_len),
"new_tokens": new_tokens,
}
)
return multi_qas