Organize Benchmark (#381)

This commit is contained in:
Liangsheng Yin
2024-05-05 16:14:17 +08:00
committed by GitHub
parent 183df47282
commit 14522e6a26
36 changed files with 829 additions and 809 deletions

View File

@@ -9,6 +9,12 @@ Turn off cache at https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa2
cache_turn_on = False cache_turn_on = False
``` ```
or set the environment variable
```
export DSP_CACHEBOOL=false
```
## Benchmark SGLang ## Benchmark SGLang
``` ```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000

View File

@@ -28,5 +28,11 @@ python3 bench_other.py --num-events 1000 --backend vllm --parallel 1
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --num-events 1000 --backend guidance --parallel 1 python3 bench_other.py --num-events 1000 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --num-events 1000 --backend lmql --parallel 1
``` ```

View File

@@ -1,8 +1,6 @@
import argparse import argparse
import json import json
import time import time
from functools import partial
from pathlib import Path
from agent_functions import ( from agent_functions import (
action_location_object_prompt, action_location_object_prompt,
@@ -13,12 +11,7 @@ from agent_functions import (
) )
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
@@ -36,48 +29,27 @@ def main(args):
states = [] states = []
# Select backend # Select backend
if args.backend == "lightllm": call_generate = get_call_generate(args)
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop):
out = (
model
+ prompt
+ gen(
name="result",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["result"]
else:
raise ValueError(f"Invalid backend: {args.backend}")
def get_one_answer(arg): def get_one_answer(arg):
answer = call_generate(**arg, temperature=0) answer = call_generate(**arg, temperature=0)
states.append(answer) states.append(answer)
async def get_one_answer_async(arg):
answer = await call_generate(**arg, temperature=0)
states.append(answer)
tic = time.time() tic = time.time()
# we always sequentially execute agent calls to maintain its dependency # we always sequentially execute agent calls to maintain its dependency
for arg in tqdm(arguments): if args.backend != "lmql":
get_one_answer(arg) for arg in tqdm(arguments):
get_one_answer(arg)
else:
import asyncio
loop = asyncio.get_event_loop()
for arg in tqdm(arguments):
loop.run_until_complete(get_one_answer_async(arg))
latency = time.time() - tic latency = time.time() - tic
print(f"Latency: {latency:.3f}") print(f"Latency: {latency:.3f}")

View File

@@ -38,7 +38,7 @@ python3 bench_other.py --num-questions 200 --backend lightllm
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
``` ```

View File

@@ -5,17 +5,11 @@ import json
import re import re
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999 INVALID = -9999999
@@ -63,54 +57,7 @@ def main(args):
states = [None] * len(labels) states = [None] * len(labels)
# Select backend # Select backend
if args.backend == "lightllm": call_generate = get_call_generate(args)
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["answer"]
elif args.backend == "lmql":
import lmql
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
@lmql.query(model=model)
async def program(question):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question")
return ANSWER
'''
async def call_generate(prompt, temperature, max_tokens, stop):
return await program(question=prompt, temperature=0)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests # Run requests
if args.backend != "lmql": if args.backend != "lmql":
@@ -130,7 +77,13 @@ def main(args):
get_one_answer(i) get_one_answer(i)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions)))) list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
else: else:
# Use asyncio # Use asyncio
async def batched_call(batch_size): async def batched_call(batch_size):

View File

@@ -38,7 +38,7 @@ python3 bench_other.py --num-questions 200 --backend lightllm
### Benchmark guidance ### Benchmark guidance
``` ```
CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
``` ```

View File

@@ -3,15 +3,11 @@ import asyncio
import json import json
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
add_common_other_args_and_parse,
call_select_lightllm,
call_select_vllm,
)
from sglang.utils import read_jsonl from sglang.utils import read_jsonl
@@ -47,47 +43,7 @@ def main(args):
preds = [None] * len(labels) preds = [None] * len(labels)
# Select backend # Select backend
if args.backend == "lightllm": call_select = get_call_select(args)
url = f"{args.host}:{args.port}/generate"
call_select = partial(call_select_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_select = partial(call_select_vllm, url=url)
elif args.backend == "guidance":
from guidance import models, select
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_select(context, choices):
out = model + context + select(choices, name="answer")
return choices.index(out["answer"])
call_select("Hello,", ["world", "earth"])
elif args.backend == "lmql":
import lmql
model = lmql.model(
"meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
)
@lmql.query(model=model)
async def program(ctx, choices):
'''lmql
"""{ctx}[ANSWER]""" where ANSWER in set(choices)
return ANSWER
'''
async def call_select(context, choices):
answer = await program(ctx=context, choices=choices, temperature=0)
return choices.index(answer)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests # Run requests
if args.backend != "lmql": if args.backend != "lmql":
@@ -99,11 +55,17 @@ def main(args):
tic = time.time() tic = time.time()
if args.parallel == 1: if args.parallel == 1:
for i in range(len(questions)): for i in tqdm(range(len(questions))):
get_one_answer(i) get_one_answer(i)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions)))) list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
else: else:
# Use asyncio # Use asyncio
async def batched_call(batch_size): async def batched_call(batch_size):

View File

@@ -36,7 +36,7 @@ python3 bench_sglang.py --num-questions 10
``` ```
### Benchmark vllm ### Benchmark Outlines + vLLM
Run Llama-7B Run Llama-7B
@@ -47,7 +47,7 @@ python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2
Benchmark Benchmark
``` ```
python3 bench_other.py --backend vllm --num-questions 10 python3 bench_other.py --backend outlines --num-questions 10
``` ```
@@ -56,5 +56,5 @@ python3 bench_other.py --backend vllm --num-questions 10
Run Llama-7B and benchmark Run Llama-7B and benchmark
``` ```
python3 bench_other.py --backend guidance --num-questions 10 --parallel 1 python3 bench_other.py --backend guidance --num-questions 10 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
``` ```

View File

@@ -7,10 +7,7 @@ from functools import partial
from tqdm import tqdm from tqdm import tqdm
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_outlines,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]" REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
@@ -50,41 +47,11 @@ def main(args):
states = [None] * len(arguments) states = [None] * len(arguments)
# Select backend # Select backend
if args.backend == "vllm": call_generate = partial(get_call_generate(args), temperature=0)
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_outlines, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop=None, regex=None):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=0,
stop=stop,
regex=regex,
)
)
return out["answer"]
# warmup
for _ in range(3):
generate("Hello!" * 10, max_tokens=64, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests # Run requests
def get_one_answer(i): def get_one_answer(i):
states[i] = json_decode(generate=generate, **arguments[i]) states[i] = json_decode(generate=call_generate, **arguments[i])
tic = time.time() tic = time.time()
if args.parallel == 1: if args.parallel == 1:
@@ -92,7 +59,12 @@ def main(args):
get_one_answer(i) get_one_answer(i)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(arguments)))) rets = list(
tqdm(
executor.map(get_one_answer, list(range(len(arguments)))),
total=len(arguments),
)
)
for _ in rets: for _ in rets:
pass pass

View File

@@ -39,7 +39,7 @@ python3 bench_sglang.py --mode city
``` ```
### Benchmark vllm ### Benchmark Outlines + vLLM
Run Llama-7B Run Llama-7B
@@ -50,13 +50,13 @@ python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2
Benchmark Character Generation Benchmark Character Generation
```bash ```bash
python3 bench_other.py --mode character --backend vllm python3 bench_other.py --mode character --backend outlines
``` ```
Benchmark City Information Retrieval Benchmark City Information Retrieval
```bash ```bash
python3 bench_other.py --mode city --backend vllm python3 bench_other.py --mode city --backend outlines
``` ```
### Benchmark guidance ### Benchmark guidance
@@ -64,11 +64,25 @@ python3 bench_other.py --mode city --backend vllm
Run Llama-7B and benchmark character generation Run Llama-7B and benchmark character generation
```bash ```bash
python3 bench_other.py --mode character --backend guidance --parallel 1 python3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
``` ```
Run Llama-7B and benchmark city information retrieval Run Llama-7B and benchmark city information retrieval
```bash ```bash
python3 bench_other.py --mode city --backend guidance --parallel 1 python3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
Run Llama-7B and benchmark character generation
```
python3 bench_other.py --mode character --backend lmql --parallel 1
```
Run Llama-7B and benchmark city information retrieval
```
python3 bench_other.py --mode city --backend lmql --parallel 1
``` ```

View File

@@ -7,10 +7,7 @@ from functools import partial
import guidance import guidance
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_outlines,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
# there are some FSM bugs with json regex converted from pydantic model # there are some FSM bugs with json regex converted from pydantic model
@@ -85,6 +82,29 @@ def character_maker(lm, name):
return lm return lm
async def call_generate_lmql(
prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs
):
assert model is not None
import lmql
@lmql.query(model=model)
async def program(question, max_tokens, regex):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex)
return ANSWER
'''
return await program(
question=prompt,
temperature=temperature,
max_tokens=max_tokens,
max_len=max_len,
regex=regex,
**kwargs,
)
@guidance @guidance
def city_maker(lm, document): def city_maker(lm, document):
regex_str_no_quote = r"[\w\d\s]+" regex_str_no_quote = r"[\w\d\s]+"
@@ -119,38 +139,68 @@ def bench_character(args):
states = [None] * len(arguments) states = [None] * len(arguments)
# Select backend # Select backend
if args.backend == "vllm": if args.backend == "outlines":
url = f"{args.host}:{args.port}/generate" call_generate = partial(get_call_generate(args), temperature=0)
generate = partial(call_generate_outlines, url=url, temperature=0)
def func(i): def get_one_answer(i):
states[i] = character_gen(**arguments[i], generate=generate) states[i] = character_gen(**arguments[i], generate=call_generate)
get_one_answer = func
elif args.backend == "guidance": elif args.backend == "guidance":
model = guidance.models.LlamaCpp( model = guidance.models.LlamaCpp(
args.llama_cpp_model_path, args.model_path,
n_gpu_layers=-1, n_gpu_layers=-1,
n_ctx=4096, n_ctx=args.n_ctx,
) )
def func(i): def get_one_answer(i):
lm = model + character_maker(**arguments[i]) lm = model + character_maker(**arguments[i])
states[i] = lm states[i] = lm
get_one_answer = func elif args.backend == "lmql":
import asyncio
import lmql
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
call_generate = partial(
call_generate_lmql,
model=model,
max_tokens=256,
regex=character_regex,
)
async def get_one_answer_async(i):
states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0)
else: else:
raise ValueError(f"Invalid backend: {args.backend}") raise ValueError(f"Invalid backend: {args.backend}")
tic = time.time() tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(arguments))): if args.backend != "lmql":
get_one_answer(i) if args.parallel == 1:
for i in tqdm(range(len(arguments))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = list(
tqdm(
executor.map(get_one_answer, list(range(len(arguments)))),
total=len(arguments),
)
)
for _ in rets:
pass
else: else:
with ThreadPoolExecutor(args.parallel) as executor: batches = []
rets = executor.map(get_one_answer, list(range(len(arguments)))) for i in range(0, len(arguments), args.parallel):
for _ in rets: batches.append(list(range(i, min(i + args.parallel, len(arguments)))))
pass loop = asyncio.get_event_loop()
for bt in tqdm(batches):
loop.run_until_complete(
asyncio.gather(*[get_one_answer_async(i) for i in bt])
)
latency = time.time() - tic latency = time.time() - tic
@@ -166,26 +216,23 @@ def bench_city_doc(args):
states = [None] * len(arguments) states = [None] * len(arguments)
# Select backend # Select backend
if args.backend == "vllm": if args.backend == "outlines":
url = f"{args.host}:{args.port}/generate" call_generate = partial(get_call_generate(args), temperature=0)
generate = partial(call_generate_outlines, url=url, temperature=0)
def func(i): def get_one_answer(i):
states[i] = city_gen(**arguments[i], generate=generate) states[i] = city_gen(**arguments[i], generate=call_generate)
get_one_answer = func
elif args.backend == "guidance": elif args.backend == "guidance":
model = guidance.models.LlamaCpp( model = guidance.models.LlamaCpp(
args.llama_cpp_model_path, args.model_path,
n_gpu_layers=-1, n_gpu_layers=-1,
n_ctx=4096, n_ctx=args.n_ctx,
) )
def func(i): def get_one_answer(i):
lm = model + city_maker(**arguments[i]) lm = model + city_maker(**arguments[i])
states[i] = lm states[i] = lm
get_one_answer = func
else: else:
raise ValueError(f"Invalid backend: {args.backend}") raise ValueError(f"Invalid backend: {args.backend}")
@@ -237,10 +284,5 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--mode", type=str, default="character", choices=["character", "city"] "--mode", type=str, default="character", choices=["character", "city"]
) )
parser.add_argument(
"--llama-cpp-model-path",
type=str,
default="/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
)
args = add_common_other_args_and_parse(parser) args = add_common_other_args_and_parse(parser)
main(args) main(args)

View File

@@ -23,5 +23,11 @@ python3 bench_other.py --backend vllm --num-questions 25
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --backend guidance --num-questions 25 --parallel 1 python3 bench_other.py --backend guidance --num-questions 25 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
``` ```
### Benchmark lmql
```
python3 bench_other.py --backend lmql --num-questions 25 --parallel 1
```

View File

@@ -6,12 +6,7 @@ from functools import partial
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency." system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
@@ -54,53 +49,77 @@ def multi_dimension_judge(article, generate):
return s return s
async def multi_dimension_judge_async(article, generate):
s = system_prompt
s += "\n```\n" + article + "\n```\n\n"
judges = []
for i in range(len(dimension_prompts)):
comp = await generate(
s
+ "USER: Please judge the quality based on the following metric. "
+ dimension_prompts[i]
+ " Please provide a single-paragraph judgement. "
+ "Focus on the provided metric and do not say other things. "
'End your judgement paragraph with the word "END"\nJUDGE:',
max_tokens=256,
stop="END",
)
judges.append(comp)
s += "I will judge the quality based on the following metrics.\n"
for i in range(len(dimension_prompts)):
s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n"
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
s += await generate(s, max_tokens=2, stop=None)
return s
def main(args): def main(args):
lines = read_jsonl(args.data_path)[: args.num_questions] lines = read_jsonl(args.data_path)[: args.num_questions]
states = [None] * len(lines) states = [None] * len(lines)
# Select backend # Select backend
if args.backend == "lightllm": call_generate = partial(get_call_generate(args), temperature=0)
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_lightllm, url=url, temperature=0)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_vllm, url=url, temperature=0)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"]
# warmup
generate("Hello!", max_tokens=8, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests # Run requests
def get_one_answer(i):
states[i] = multi_dimension_judge(lines[i], generate)
tic = time.time() tic = time.time()
if args.parallel == 1:
for i in tqdm(range(len(lines))): if args.backend != "lmql":
get_one_answer(i)
def get_one_answer(i):
states[i] = multi_dimension_judge(lines[i], call_generate)
if args.parallel == 1:
for i in tqdm(range(len(lines))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(lines)))),
total=len(lines),
)
)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: import asyncio
executor.map(get_one_answer, list(range(len(lines))))
async def get_one_answer_async(i):
states[i] = await multi_dimension_judge_async(lines[i], call_generate)
batches = []
for i in range(0, len(lines), args.parallel):
batches.append(list(range(i, min(i + args.parallel, len(lines)))))
loop = asyncio.get_event_loop()
for bt in tqdm(batches):
loop.run_until_complete(
asyncio.gather(*[get_one_answer_async(i) for i in bt])
)
latency = time.time() - tic latency = time.time() - tic
# Compute accuracy # Compute accuracy

View File

@@ -22,7 +22,7 @@ python3 bench_other.py --backend vllm --num-questions 5
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --backend guidance --num-questions 5 --parallel 1 python3 bench_other.py --backend guidance --num-questions 5 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf
``` ```

View File

@@ -6,12 +6,7 @@ from functools import partial
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
@@ -44,40 +39,11 @@ def main(args):
states = [None] * len(arguments) states = [None] * len(arguments)
# Select backend # Select backend
if args.backend == "lightllm": call_generate = partial(get_call_generate(args), temperature=0)
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_lightllm, url=url, temperature=0)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_vllm, url=url, temperature=0)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf",
n_gpu_layers=-1,
n_ctx=11000,
)
def generate(prompt, max_tokens, stop):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"]
# warmup
generate("Hello!", max_tokens=8, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests # Run requests
def get_one_answer(i): def get_one_answer(i):
states[i] = json_decode(generate=generate, **arguments[i]) states[i] = json_decode(generate=call_generate, **arguments[i])
tic = time.time() tic = time.time()
if args.parallel == 1: if args.parallel == 1:
@@ -85,7 +51,13 @@ def main(args):
get_one_answer(i) get_one_answer(i)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(arguments)))) list(
tqdm(
executor.map(get_one_answer, list(range(len(arguments)))),
total=len(arguments),
)
)
latency = time.time() - tic latency = time.time() - tic
# Compute accuracy # Compute accuracy

View File

@@ -46,7 +46,7 @@ python3 bench_other.py --nsub 10 --backend lightllm
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --nsub 10 --backend guidance --parallel 1 python3 bench_other.py --nsub 10 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
``` ```

View File

@@ -4,19 +4,13 @@ import json
import os import os
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import tiktoken import tiktoken
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
choices = ["A", "B", "C", "D"] choices = ["A", "B", "C", "D"]
@@ -53,10 +47,7 @@ def gen_prompt(train_df, subject, k=-1):
return prompt return prompt
model_initialized = None def evaluate(args, subject, dev_df, test_df, call_generate):
def evaluate(args, subject, dev_df, test_df):
prompts = [] prompts = []
labels = [] labels = []
@@ -78,62 +69,6 @@ def evaluate(args, subject, dev_df, test_df):
preds = [None] * len(prompts) preds = [None] * len(prompts)
max_tokens = 1 max_tokens = 1
# Select backend
global model_initialized
if args.backend == "lightllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url, stop=None)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url, stop=None)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url, stop=None)
elif args.backend == "guidance":
from guidance import gen, models
if model_initialized is None:
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
model_initialized = model
else:
model = model_initialized
def call_generate(prompt, temperature, max_tokens):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0)
)
return out["answer"]
# warmup
call_generate("Hello,", temperature=1.0, max_tokens=8)
elif args.backend == "lmql":
import lmql
model = lmql.model(
"meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
)
@lmql.query(model=model)
async def program(question):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 2
return ANSWER
'''
async def call_generate(prompt, temperature, max_tokens):
return await program(question=prompt, temperature=temperature)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests # Run requests
if args.backend != "lmql": if args.backend != "lmql":
# Use thread pool # Use thread pool
@@ -190,6 +125,9 @@ def main(args):
all_latencies = [] all_latencies = []
num_requests = 0 num_requests = 0
# Select backend
call_generate = get_call_generate(args)
for subject in tqdm(subjects[: args.nsub]): for subject in tqdm(subjects[: args.nsub]):
dev_df = pd.read_csv( dev_df = pd.read_csv(
os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
@@ -198,7 +136,7 @@ def main(args):
os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
) )
cors, acc, latency = evaluate(args, subject, dev_df, test_df) cors, acc, latency = evaluate(args, subject, dev_df, test_df, call_generate)
all_cors.append(cors) all_cors.append(cors)
all_latencies.append(latency) all_latencies.append(latency)
num_requests += len(test_df) num_requests += len(test_df)

View File

@@ -1,3 +1,9 @@
## Download Dataset
```sh
wget -O question.jsonl https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl
```
## Run benchmark ## Run benchmark
### Benchmark sglang ### Benchmark sglang

View File

@@ -4,16 +4,11 @@ import os
import time import time
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial
from fastchat.model import get_conversation_template from fastchat.model import get_conversation_template
from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt,
call_generate_vllm,
)
def load_questions(filename): def load_questions(filename):
@@ -50,17 +45,7 @@ def main(args):
conv_main = get_conversation_template(model_id) conv_main = get_conversation_template(model_id)
# Select backend # Select backend
if args.backend == "lightllm": call_generate = get_call_generate(args)
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url, stop=None)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url, stop=None)
elif args.backend == "srt":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt, url=url, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
answers = [None] * len(questions) answers = [None] * len(questions)
@@ -83,11 +68,17 @@ def main(args):
# Run requests # Run requests
tic = time.time() tic = time.time()
if args.parallel == 1: if args.parallel == 1:
for i in range(len(questions)): for i in tqdm(range(len(questions))):
get_answer(i) get_answer(i)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_answer, list(range(len(questions)))) list(
tqdm(
executor.map(get_answer, list(range(len(questions)))),
total=len(questions),
)
)
latency = time.time() - tic latency = time.time() - tic
print(f"#questions: {len(questions)}, Latency: {latency:.2f}") print(f"#questions: {len(questions)}, Latency: {latency:.2f}")

View File

@@ -39,5 +39,11 @@ python3 bench_other.py --num-questions 64 --backend lightllm
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --num-questions 64 --backend lmql --parallel 1
``` ```

View File

@@ -5,16 +5,11 @@ import json
import re import re
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999 INVALID = -9999999
@@ -67,6 +62,32 @@ def multi_chain_gsm8k(question, num_chains, call_generate):
return s return s
async def multi_chain_gsm8k_async(question, num_chains, call_generate):
s = "Question: " + question + "\n"
# s += call_generate(s + "Answer: " + prompt_lib[0], max_tokens=256,
# stop="Question", temperature=0)
# return s
comps = []
for i in range(num_chains):
comps.append(
await call_generate(
s + "Answer: " + prompt_lib[i % num_chains],
max_tokens=256,
temperature=0.3,
stop="Question",
)
)
s += "Answer: To answer this question, here are some possible solutions. "
s += "After considering all of them, I will do a majority vote.\n\n"
for i in range(num_chains):
s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
s += await call_generate(s, max_tokens=16, temperature=0, stop=None)
return s
def main(args): def main(args):
lines = read_jsonl(args.data_path) lines = read_jsonl(args.data_path)
@@ -83,71 +104,7 @@ def main(args):
states = [None] * len(labels) states = [None] * len(labels)
# Select backend # Select backend
if args.backend == "lightllm": call_generate = get_call_generate(args)
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["answer"]
# def multi_chain_gsm8k(question, num_chains, call_generate):
# s = model + "Question: " + question + "\n"
# comps = []
# for i in range(num_chains):
# comps.append(call_generate(s + "Answer: " + prompt_lib[i % num_chains],
# max_tokens=256, temperature=0.3, stop="Question"))
# s += "Answer: To answer this question, here are some possible solutions. "
# s += "After considering all of them, I will do a majority vote.\n\n"
# for i in range(num_chains):
# s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
# s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
# return call_generate(s, max_tokens=16, temperature=0, stop=None)
elif args.backend == "lmql":
import lmql
model = lmql.model(
"meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
)
@lmql.query(model=model)
async def program(question):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question")
return ANSWER
'''
async def call_generate(prompt, temperature, max_tokens, stop):
return await program(question=prompt, temperature=0)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests # Run requests
if args.backend != "lmql": if args.backend != "lmql":
@@ -158,31 +115,35 @@ def main(args):
tic = time.time() tic = time.time()
if args.parallel == 1: if args.parallel == 1:
for i in range(len(questions)): for i in tqdm(range(len(questions))):
get_one_answer(i) get_one_answer(i)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions)))) list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
else: else:
# Use asyncio # Use asyncio
async def batched_call(batch_size): async def get_one_answer_asyncio(i):
for i in range(0, len(questions), batch_size): answer = await multi_chain_gsm8k_async(
tasks = [] questions[i], args.num_chains, call_generate
for q in questions[i : i + batch_size]: )
tasks.append( states[i] = answer
call_generate(
few_shot_examples + q,
temperature=0,
max_tokens=256,
stop="Question",
)
)
rets = await asyncio.gather(*tasks)
for j in range(len(rets)):
states[i + j] = get_answer_value(rets[j])
tic = time.time() tic = time.time()
asyncio.run(batched_call(batch_size=args.parallel)) loop = asyncio.get_event_loop()
batches = [
list(range(i, min(i + args.parallel, len(questions))))
for i in range(0, len(questions), args.parallel)
]
for bt in tqdm(batches):
tasks = [get_one_answer_asyncio(k) for k in bt]
loop.run_until_complete(asyncio.gather(*tasks))
latency = time.time() - tic latency = time.time() - tic
preds = [] preds = []

View File

@@ -22,7 +22,7 @@ python3 bench_other.py --backend vllm --num-questions 64
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf
``` ```

View File

@@ -6,12 +6,7 @@ from functools import partial
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
USER_PREFIX = "[INST] " USER_PREFIX = "[INST] "
@@ -60,40 +55,11 @@ def main(args):
states = [None] * len(arguments) states = [None] * len(arguments)
# Select backend # Select backend
if args.backend == "lightllm": call_generate = partial(get_call_generate(args), temperature=0)
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_lightllm, url=url, temperature=0)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_vllm, url=url, temperature=0)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf",
n_gpu_layers=-1,
n_ctx=11000,
)
def generate(prompt, max_tokens, stop):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"]
# warmup
generate("Hello!", max_tokens=8, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests # Run requests
def get_one_answer(i): def get_one_answer(i):
states[i] = multi_document_qa(generate=generate, **arguments[i]) states[i] = multi_document_qa(generate=call_generate, **arguments[i])
tic = time.time() tic = time.time()
if args.parallel == 1: if args.parallel == 1:
@@ -101,7 +67,13 @@ def main(args):
get_one_answer(i) get_one_answer(i)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(labels)))) list(
tqdm(
executor.map(get_one_answer, list(range(len(labels)))),
total=len(labels),
)
)
latency = time.time() - tic latency = time.time() - tic
# Compute accuracy # Compute accuracy

View File

@@ -56,11 +56,11 @@ python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm
Benchmark Llama-7B (short output) Benchmark Llama-7B (short output)
``` ```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
``` ```
Benchmark Llama-7B (long output) Benchmark Llama-7B (long output)
``` ```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --long python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf --long
``` ```

View File

@@ -2,61 +2,16 @@ import json
import time import time
from argparse import ArgumentParser from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial
import requests
from data_gen import gen_arguments from data_gen import gen_arguments
from tqdm import tqdm from tqdm import tqdm
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from sglang.test.test_utils import add_common_other_args_and_parse from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text from sglang.utils import dump_state_text
def get_generate(args):
# Select backend
if args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1):
data = {
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
"ignore_eos": True,
"stop": stop,
"stream": False,
"n": n,
}
res = requests.post(url, json=data)
assert res.status_code == 200
return res.json()["text"][0][len(prompt) :]
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop=None):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"]
# warmup
for _ in range(3):
generate("Hello!" * 10, max_tokens=64, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
return generate
def multi_turns(generate, qas): def multi_turns(generate, qas):
s = "" s = ""
for qa in qas: for qa in qas:
@@ -75,10 +30,10 @@ def main(args):
states = [None] * args.num_qa states = [None] * args.num_qa
generate = get_generate(args) call_generate = partial(get_call_generate(args), temperature=0)
def get_one_answer(i): def get_one_answer(i):
states[i] = multi_turns(generate=generate, **multi_qas[i]) states[i] = multi_turns(generate=call_generate, **multi_qas[i])
tic = time.time() tic = time.time()
if args.parallel == 1: if args.parallel == 1:
@@ -86,7 +41,12 @@ def main(args):
get_one_answer(i) get_one_answer(i)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(multi_qas)))) rets = list(
tqdm(
executor.map(get_one_answer, list(range(len(multi_qas)))),
total=len(multi_qas),
)
)
for _ in rets: for _ in rets:
pass pass

View File

@@ -24,5 +24,11 @@ python3 bench_other.py --num-questions 100 --backend vllm
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --num-questions 100 --backend guidance --parallel 1 python3 bench_other.py --num-questions 100 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
``` ```
### Benchmark lmql
```
python3 bench_other.py --num-questions 100 --backend lmql --parallel 1
```

View File

@@ -2,17 +2,10 @@ import argparse
import json import json
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
@@ -97,42 +90,7 @@ def main(args):
states = [] states = []
# Select backend # Select backend
if args.backend == "lightllm": call_generate = get_call_generate(args)
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop):
out = (
model
+ prompt
+ gen(
name="result",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["result"]
# warmup
call_generate("Hello,", 1.0, 8, ".")
else:
raise ValueError(f"Invalid backend: {args.backend}")
def run_single_agent(argument): def run_single_agent(argument):
question = argument["question"] question = argument["question"]
@@ -161,13 +119,60 @@ def main(args):
states.append(answer) states.append(answer)
async def run_single_agent_async(argument):
question = argument["question"]
triplets = argument["triplets"]
prompt = get_prompt(question)
for i in range(1, len(triplets) + 2):
prompt += "Thought " + str(i) + ":"
states.append(prompt)
answer = await call_generate(
prompt, max_tokens=200, temperature=0, stop="Observation", max_len=4096
)
if i > len(triplets):
break
prompt += (
triplets[i - 1]["thought"]
+ "\nAction "
+ str(i)
+ ":"
+ triplets[i - 1]["action"]
+ "\nObservation "
+ str(i)
+ ":"
+ triplets[i - 1]["observation"]
+ "\n"
)
states.append(answer)
tic = time.time() tic = time.time()
if args.parallel == 1:
for arg in tqdm(arguments): if args.backend != "lmql":
run_single_agent(arg) if args.parallel == 1:
for arg in tqdm(arguments):
run_single_agent(arg)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(run_single_agent, arguments), total=len(arguments)
)
)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: import asyncio
executor.map(run_single_agent, arguments)
loop = asyncio.get_event_loop()
batches = [
[] for _ in range((len(arguments) + args.parallel - 1) // args.parallel)
]
for i, arg in enumerate(arguments):
batches[i // args.parallel].append(arg)
for bt in tqdm(batches):
tasks = [run_single_agent_async(arg) for arg in bt]
loop.run_until_complete(asyncio.gather(*tasks))
latency = time.time() - tic latency = time.time() - tic
print(f"Latency: {latency:.3f}") print(f"Latency: {latency:.3f}")

View File

@@ -23,5 +23,11 @@ python3 bench_other.py --backend vllm --num-questions 64
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 4096 --model-path path/to/gguf
``` ```
### Benchmark lmql
```
python3 bench_other.py --backend lmql --num-questions 32 --parallel 1
```

View File

@@ -6,12 +6,7 @@ from functools import partial
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
number = 5 number = 5
@@ -70,48 +65,43 @@ def main(args):
states = [None] * len(lines) states = [None] * len(lines)
# Select backend # Select backend
if args.backend == "lightllm": call_generate = partial(get_call_generate(args), temperature=0)
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_lightllm, url=url, temperature=0)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_vllm, url=url, temperature=0)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_srt_raw, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop):
out = (
model
+ prompt
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
)
return out["answer"]
# warmup
generate("Hello!", max_tokens=8, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
# Run requests # Run requests
def get_one_answer(i):
states[i] = suggest_tips(lines[i]["topic"], generate)
tic = time.time() tic = time.time()
if args.parallel == 1: if args.backend != "lmql":
for i in tqdm(range(len(lines))):
get_one_answer(i) def get_one_answer(i):
states[i] = suggest_tips(lines[i]["topic"], call_generate)
if args.parallel == 1:
for i in tqdm(range(len(lines))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(lines)))),
total=len(lines),
)
)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: import asyncio
executor.map(get_one_answer, list(range(len(lines))))
from lmql_funcs import suggest_tips_async
async def get_one_answer_async(i):
states[i] = await suggest_tips_async(lines[i]["topic"], call_generate)
batches = []
for i in range(0, len(lines), args.parallel):
batches.append(list(range(i, min(i + args.parallel, len(lines)))))
loop = asyncio.get_event_loop()
for batch in tqdm(batches):
loop.run_until_complete(
asyncio.gather(*[get_one_answer_async(i) for i in batch])
)
latency = time.time() - tic latency = time.time() - tic
# Compute accuracy # Compute accuracy

View File

@@ -0,0 +1,50 @@
number = 5
async def expand_tip_async(topic, tip, generate):
s = (
"""Please expand a tip for a topic into a detailed paragraph.
Topic: staying healthy
Tip: Regular Exercise
Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement.
Topic: building a campfire
Tip: Choose the Right Location
Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches.
Topic: writing a blog post
Tip: structure your content effectively
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.
Topic: """
+ topic
+ "\nTip: "
+ tip
+ "\nParagraph:"
)
return await generate(s, max_tokens=128, stop="\n\n")
async def suggest_tips_async(topic, generate):
s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
s += "USER: Give some tips for " + topic + ".\n"
s += (
"ASSISTANT: Okay. Here are "
+ str(number)
+ " concise tips, each under 8 words:\n"
)
tips = []
for i in range(1, 1 + number):
s += f"{i}."
# NOTE: stop is different due to lmql does not support a list of stop tokens
tip = await generate(s, max_tokens=24, stop=".\n")
s += tip + ".\n"
tips.append(tip)
paragraphs = [await expand_tip_async(topic, tip, generate=generate) for tip in tips]
for i in range(1, 1 + number):
s += f"Tip {i}:" + paragraphs[i - 1] + "\n"
return s

View File

@@ -41,5 +41,11 @@ python3 bench_other.py --num-questions 32 --backend lightllm
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
```
### Benchmark lmql
```
python3 bench_other.py --num-questions 8 --backend lmql --parallel 1
``` ```

View File

@@ -5,17 +5,11 @@ import re
import time import time
from collections import Counter from collections import Counter
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999 INVALID = -9999999
@@ -139,69 +133,50 @@ def main(args):
arguments = [{"question": q, "num_branches": num_branches} for q in questions] arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend # Select backend
if args.backend == "lightllm": call_generate = get_call_generate(args)
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop, n):
if n == 1:
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["answer"]
else:
rets = []
for i in range(n):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
rets.append(out["answer"])
return rets
# warmup
call_generate("Hello,", 1.0, 8, ".", 1)
# Run requests # Run requests
states = [None] * len(questions) states = [None] * len(questions)
def get_one_answer(i):
states[i] = tree_search(**arguments[i], call_generate=call_generate)
tic = time.time() tic = time.time()
if args.parallel == 1: if args.backend != "lmql":
for i in tqdm(range(len(questions))):
get_one_answer(i) def get_one_answer(i):
states[i] = tree_search(**arguments[i], call_generate=call_generate)
if args.parallel == 1:
for i in tqdm(range(len(questions))):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: import asyncio
executor.map(get_one_answer, list(range(len(questions))))
from lmql_funcs import tree_search_async
async def get_one_answer_async(i):
states[i] = await tree_search_async(
**arguments[i], call_generate=call_generate
)
batches = [
[] for _ in range((len(questions) + args.parallel - 1) // args.parallel)
]
for i in range(len(questions)):
batches[i // args.parallel].append(i)
loop = asyncio.get_event_loop()
for bt in tqdm(batches):
tasks = [get_one_answer_async(k) for k in bt]
loop.run_until_complete(asyncio.gather(*tasks))
latency = time.time() - tic latency = time.time() - tic
answers_text = [] answers_text = []

View File

@@ -0,0 +1,82 @@
from bench_other import (
ASSISTANT_PREFIX,
ASSISTANT_SUFFIX,
USER_PREFIX,
USER_SUFFIX,
temp,
)
async def propose_plan_async(s, question, num_branches, call_generate):
s += (
USER_PREFIX
+ """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
+ question
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = await call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
async def execute_plan_async(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = await call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
async def reflect_solution_async(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = await call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
async def get_final_answer_async(s, num_branches, call_generate):
s += (
USER_PREFIX
+ """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
+ USER_SUFFIX
)
s += ASSISTANT_PREFIX
comps = await call_generate(
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
)
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
async def tree_search_async(question, num_branches, call_generate):
plan_forks = await propose_plan_async("", question, num_branches, call_generate)
sol_states = []
for plan in plan_forks:
forks = await execute_plan_async(plan, num_branches, call_generate)
sol_states.extend(forks)
ref_states = []
for sol in sol_states:
forks = await reflect_solution_async(sol, num_branches, call_generate)
ref_states.extend(forks)
solutions = []
for sol in ref_states:
ans = await get_final_answer_async(sol, num_branches, call_generate)
solutions.append(ans)
return solutions

View File

@@ -39,5 +39,5 @@ python3 bench_other.py --num-questions 32 --backend lightllm
### Benchmark guidance ### Benchmark guidance
``` ```
python3 bench_other.py --num-questions 32 --backend guidance --parallel 1 python3 bench_other.py --num-questions 32 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf
``` ```

View File

@@ -5,17 +5,11 @@ import re
import time import time
from collections import Counter from collections import Counter
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import ( from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
add_common_other_args_and_parse,
call_generate_lightllm,
call_generate_srt_raw,
call_generate_vllm,
)
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import dump_state_text, read_jsonl
INVALID = -9999999 INVALID = -9999999
@@ -119,52 +113,7 @@ def main(args):
arguments = [{"question": q, "num_branches": num_branches} for q in questions] arguments = [{"question": q, "num_branches": num_branches} for q in questions]
# Select backend # Select backend
if args.backend == "lightllm": call_generate = get_call_generate(args)
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_lightllm, url=url)
elif args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_vllm, url=url)
elif args.backend == "srt-raw":
url = f"{args.host}:{args.port}/generate"
call_generate = partial(call_generate_srt_raw, url=url)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_generate(prompt, temperature, max_tokens, stop, n):
if n == 1:
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
return out["answer"]
else:
rets = []
for i in range(n):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
)
rets.append(out["answer"])
return rets
# Run requests # Run requests
states = [None] * len(questions) states = [None] * len(questions)
@@ -178,7 +127,13 @@ def main(args):
get_one_answer(i) get_one_answer(i)
else: else:
with ThreadPoolExecutor(args.parallel) as executor: with ThreadPoolExecutor(args.parallel) as executor:
executor.map(get_one_answer, list(range(len(questions)))) list(
tqdm(
executor.map(get_one_answer, list(range(len(questions)))),
total=len(questions),
)
)
latency = time.time() - tic latency = time.time() - tic
answers_text = [] answers_text = []

View File

@@ -1,14 +1,20 @@
"""Common utilities for testing and benchmarking""" """Common utilities for testing and benchmarking"""
import asyncio
from functools import partial
import numpy as np import numpy as np
import requests import requests
from sglang.backend.openai import OpenAI from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.utils import get_exception_traceback
def call_generate_lightllm(prompt, temperature, max_tokens, stop, url): def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
assert url is not None
data = { data = {
"inputs": prompt, "inputs": prompt,
"parameters": { "parameters": {
@@ -23,7 +29,9 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
return pred return pred
def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1): def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
assert url is not None
data = { data = {
"prompt": prompt, "prompt": prompt,
"temperature": temperature, "temperature": temperature,
@@ -41,8 +49,10 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
def call_generate_outlines( def call_generate_outlines(
prompt, temperature, max_tokens, url, stop=[], regex=None, n=1 prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None
): ):
assert url is not None
data = { data = {
"prompt": prompt, "prompt": prompt,
"temperature": temperature, "temperature": temperature,
@@ -60,7 +70,9 @@ def call_generate_outlines(
return pred return pred
def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url): def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
assert url is not None
data = { data = {
"text": prompt, "text": prompt,
"sampling_params": { "sampling_params": {
@@ -76,7 +88,71 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
return pred return pred
def call_select_lightllm(context, choices, url): def call_generate_guidance(
prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
):
assert model is not None
from guidance import gen
rets = []
for _ in range(n):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
regex=regex,
)
)
rets.append(out["answer"])
return rets if n > 1 else rets[0]
async def call_generate_lmql(
prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
):
assert model is not None
import lmql
if stop != None:
@lmql.query(model=model)
async def program(question, max_tokens, stop):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
return ANSWER
'''
else:
@lmql.query(model=model)
async def program(question, max_tokens):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
return ANSWER
'''
tasks = [
program(
question=prompt,
temperature=temperature,
max_tokens=max_tokens,
stop=stop,
max_len=max_len,
**kwargs,
)
for _ in range(n)
]
rets = await asyncio.gather(*tasks)
return rets if n > 1 else rets[0]
def call_select_lightllm(context, choices, url=None):
assert url is not None
scores = [] scores = []
for i in range(len(choices)): for i in range(len(choices)):
data = { data = {
@@ -91,7 +167,9 @@ def call_select_lightllm(context, choices, url):
return np.argmax(scores) return np.argmax(scores)
def call_select_vllm(context, choices, url): def call_select_vllm(context, choices, url=None):
assert url is not None
scores = [] scores = []
for i in range(len(choices)): for i in range(len(choices)):
data = { data = {
@@ -113,6 +191,31 @@ def call_select_vllm(context, choices, url):
""" """
def call_select_guidance(context, choices, model=None):
assert model is not None
from guidance import select
out = model + context + select(choices, name="answer")
return choices.index(out["answer"])
async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
assert model is not None
import lmql
@lmql.query(model=model)
async def program(ctx, choices):
'''lmql
"""{ctx}[ANSWER]""" where ANSWER in set(choices)
return ANSWER
'''
answer = await program(
ctx=context, choices=choices, temperature=temperature, max_len=max_len
)
return choices.index(answer)
def add_common_other_args_and_parse(parser): def add_common_other_args_and_parse(parser):
parser.add_argument("--parallel", type=int, default=64) parser.add_argument("--parallel", type=int, default=64)
parser.add_argument("--host", type=str, default="http://127.0.0.1") parser.add_argument("--host", type=str, default="http://127.0.0.1")
@@ -121,8 +224,17 @@ def add_common_other_args_and_parse(parser):
"--backend", "--backend",
type=str, type=str,
required=True, required=True,
choices=["vllm", "lightllm", "guidance", "lmql", "srt-raw", "llama.cpp"], choices=[
"vllm",
"outlines",
"lightllm",
"guidance",
"lmql",
"srt-raw",
"llama.cpp",
],
) )
parser.add_argument("--n-ctx", type=int, default=4096)
parser.add_argument( parser.add_argument(
"--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
) )
@@ -132,6 +244,7 @@ def add_common_other_args_and_parse(parser):
if args.port is None: if args.port is None:
default_port = { default_port = {
"vllm": 21000, "vllm": 21000,
"outlines": 21000,
"lightllm": 22000, "lightllm": 22000,
"lmql": 23000, "lmql": 23000,
"srt-raw": 30000, "srt-raw": 30000,
@@ -161,3 +274,77 @@ def select_sglang_backend(args):
else: else:
raise ValueError(f"Invalid backend: {args.backend}") raise ValueError(f"Invalid backend: {args.backend}")
return backend return backend
def _get_call_generate(args):
if args.backend == "lightllm":
return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "vllm":
return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "srt-raw":
return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
elif args.backend == "outlines":
return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
elif args.backend == "guidance":
from guidance import models
model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
call_generate = partial(call_generate_guidance, model=model)
call_generate("Hello,", 1.0, 8, ".")
return call_generate
elif args.backend == "lmql":
import lmql
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
return partial(call_generate_lmql, model=model)
else:
raise ValueError(f"Invalid backend: {args.backend}")
def _get_call_select(args):
if args.backend == "lightllm":
return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "vllm":
return partial(call_select_vllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "guidance":
from guidance import models
model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
call_select = partial(call_select_guidance, model=model)
call_select("Hello,", ["world", "earth"])
return call_select
elif args.backend == "lmql":
import lmql
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
return partial(call_select_lmql, model=model)
else:
raise ValueError(f"Invalid backend: {args.backend}")
def get_call_generate(args):
call_generate = _get_call_generate(args)
def func(*args, **kwargs):
try:
return call_generate(*args, **kwargs)
except Exception:
print("Exception in call_generate:\n" + get_exception_traceback())
raise
return func
def get_call_select(args):
call_select = _get_call_select(args)
def func(*args, **kwargs):
try:
return call_select(*args, **kwargs)
except Exception:
print("Exception in call_select:\n" + get_exception_traceback())
raise
return func