diff --git a/benchmark/dspy/README.md b/benchmark/dspy/README.md index f321f9570..a2b213aa6 100644 --- a/benchmark/dspy/README.md +++ b/benchmark/dspy/README.md @@ -9,6 +9,12 @@ Turn off cache at https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa2 cache_turn_on = False ``` +or set the environment variable + +``` +export DSP_CACHEBOOL=false +``` + ## Benchmark SGLang ``` python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 diff --git a/benchmark/generative_agents/README.md b/benchmark/generative_agents/README.md index 3801e2b1b..393a9ce83 100644 --- a/benchmark/generative_agents/README.md +++ b/benchmark/generative_agents/README.md @@ -28,5 +28,11 @@ python3 bench_other.py --num-events 1000 --backend vllm --parallel 1 ### Benchmark guidance ``` -python3 bench_other.py --num-events 1000 --backend guidance --parallel 1 +python3 bench_other.py --num-events 1000 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf +``` + +### Benchmark lmql + +``` +python3 bench_other.py --num-events 1000 --backend lmql --parallel 1 ``` diff --git a/benchmark/generative_agents/bench_other.py b/benchmark/generative_agents/bench_other.py index 8dc462483..48f6ebc40 100644 --- a/benchmark/generative_agents/bench_other.py +++ b/benchmark/generative_agents/bench_other.py @@ -1,8 +1,6 @@ import argparse import json import time -from functools import partial -from pathlib import Path from agent_functions import ( action_location_object_prompt, @@ -13,12 +11,7 @@ from agent_functions import ( ) from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl @@ -36,48 +29,27 @@ def main(args): states = [] # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_lightllm, url=url) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_vllm, url=url) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_srt_raw, url=url) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def call_generate(prompt, temperature, max_tokens, stop): - out = ( - model - + prompt - + gen( - name="result", - max_tokens=max_tokens, - temperature=temperature, - stop=stop, - ) - ) - return out["result"] - - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_generate = get_call_generate(args) def get_one_answer(arg): answer = call_generate(**arg, temperature=0) states.append(answer) + async def get_one_answer_async(arg): + answer = await call_generate(**arg, temperature=0) + states.append(answer) + tic = time.time() # we always sequentially execute agent calls to maintain its dependency - for arg in tqdm(arguments): - get_one_answer(arg) + if args.backend != "lmql": + for arg in tqdm(arguments): + get_one_answer(arg) + else: + import asyncio + + loop = asyncio.get_event_loop() + for arg in tqdm(arguments): + loop.run_until_complete(get_one_answer_async(arg)) latency = time.time() - tic print(f"Latency: {latency:.3f}") diff --git a/benchmark/gsm8k/README.md b/benchmark/gsm8k/README.md index ffd2bcf9d..cb68d269d 100644 --- a/benchmark/gsm8k/README.md +++ b/benchmark/gsm8k/README.md @@ -38,7 +38,7 @@ python3 bench_other.py --num-questions 200 --backend lightllm ### Benchmark guidance ``` -python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 +python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf ``` diff --git a/benchmark/gsm8k/bench_other.py b/benchmark/gsm8k/bench_other.py index 254ffc6e7..2815a079e 100644 --- a/benchmark/gsm8k/bench_other.py +++ b/benchmark/gsm8k/bench_other.py @@ -5,17 +5,11 @@ import json import re import time from concurrent.futures import ThreadPoolExecutor -from functools import partial import numpy as np from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl INVALID = -9999999 @@ -63,54 +57,7 @@ def main(args): states = [None] * len(labels) # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_lightllm, url=url) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_vllm, url=url) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_srt_raw, url=url) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def call_generate(prompt, temperature, max_tokens, stop): - out = ( - model - + prompt - + gen( - name="answer", - max_tokens=max_tokens, - temperature=temperature, - stop=stop, - ) - ) - return out["answer"] - - elif args.backend == "lmql": - import lmql - - model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}") - - @lmql.query(model=model) - async def program(question): - '''lmql - """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question") - return ANSWER - ''' - - async def call_generate(prompt, temperature, max_tokens, stop): - return await program(question=prompt, temperature=0) - - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_generate = get_call_generate(args) # Run requests if args.backend != "lmql": @@ -130,7 +77,13 @@ def main(args): get_one_answer(i) else: with ThreadPoolExecutor(args.parallel) as executor: - executor.map(get_one_answer, list(range(len(questions)))) + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + else: # Use asyncio async def batched_call(batch_size): diff --git a/benchmark/hellaswag/README.md b/benchmark/hellaswag/README.md index c2b7b2aa2..b3e7abc30 100644 --- a/benchmark/hellaswag/README.md +++ b/benchmark/hellaswag/README.md @@ -38,7 +38,7 @@ python3 bench_other.py --num-questions 200 --backend lightllm ### Benchmark guidance ``` -CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 +CUDA_VISIBLE_DEVICES=0,1 python3 bench_other.py --num-questions 200 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf ``` diff --git a/benchmark/hellaswag/bench_other.py b/benchmark/hellaswag/bench_other.py index 3436b06ce..5b9ba797b 100644 --- a/benchmark/hellaswag/bench_other.py +++ b/benchmark/hellaswag/bench_other.py @@ -3,15 +3,11 @@ import asyncio import json import time from concurrent.futures import ThreadPoolExecutor -from functools import partial import numpy as np +from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_select_lightllm, - call_select_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select from sglang.utils import read_jsonl @@ -47,47 +43,7 @@ def main(args): preds = [None] * len(labels) # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - call_select = partial(call_select_lightllm, url=url) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - call_select = partial(call_select_vllm, url=url) - elif args.backend == "guidance": - from guidance import models, select - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def call_select(context, choices): - out = model + context + select(choices, name="answer") - return choices.index(out["answer"]) - - call_select("Hello,", ["world", "earth"]) - - elif args.backend == "lmql": - import lmql - - model = lmql.model( - "meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}" - ) - - @lmql.query(model=model) - async def program(ctx, choices): - '''lmql - """{ctx}[ANSWER]""" where ANSWER in set(choices) - return ANSWER - ''' - - async def call_select(context, choices): - answer = await program(ctx=context, choices=choices, temperature=0) - return choices.index(answer) - - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_select = get_call_select(args) # Run requests if args.backend != "lmql": @@ -99,11 +55,17 @@ def main(args): tic = time.time() if args.parallel == 1: - for i in range(len(questions)): + for i in tqdm(range(len(questions))): get_one_answer(i) else: with ThreadPoolExecutor(args.parallel) as executor: - executor.map(get_one_answer, list(range(len(questions)))) + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + else: # Use asyncio async def batched_call(batch_size): diff --git a/benchmark/json_decode_regex/README.md b/benchmark/json_decode_regex/README.md index 399b89154..ec575a6b1 100644 --- a/benchmark/json_decode_regex/README.md +++ b/benchmark/json_decode_regex/README.md @@ -36,7 +36,7 @@ python3 bench_sglang.py --num-questions 10 ``` -### Benchmark vllm +### Benchmark Outlines + vLLM Run Llama-7B @@ -47,7 +47,7 @@ python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2 Benchmark ``` -python3 bench_other.py --backend vllm --num-questions 10 +python3 bench_other.py --backend outlines --num-questions 10 ``` @@ -56,5 +56,5 @@ python3 bench_other.py --backend vllm --num-questions 10 Run Llama-7B and benchmark ``` -python3 bench_other.py --backend guidance --num-questions 10 --parallel 1 +python3 bench_other.py --backend guidance --num-questions 10 --parallel 1 --n-ctx 4096 --model-path path/to/gguf ``` diff --git a/benchmark/json_decode_regex/bench_other.py b/benchmark/json_decode_regex/bench_other.py index 4532644c6..bbe22835a 100644 --- a/benchmark/json_decode_regex/bench_other.py +++ b/benchmark/json_decode_regex/bench_other.py @@ -7,10 +7,7 @@ from functools import partial from tqdm import tqdm from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_outlines, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]" @@ -50,41 +47,11 @@ def main(args): states = [None] * len(arguments) # Select backend - if args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_outlines, url=url, temperature=0) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def generate(prompt, max_tokens, stop=None, regex=None): - out = ( - model - + prompt - + gen( - name="answer", - max_tokens=max_tokens, - temperature=0, - stop=stop, - regex=regex, - ) - ) - return out["answer"] - - # warmup - for _ in range(3): - generate("Hello!" * 10, max_tokens=64, stop=None) - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_generate = partial(get_call_generate(args), temperature=0) # Run requests def get_one_answer(i): - states[i] = json_decode(generate=generate, **arguments[i]) + states[i] = json_decode(generate=call_generate, **arguments[i]) tic = time.time() if args.parallel == 1: @@ -92,7 +59,12 @@ def main(args): get_one_answer(i) else: with ThreadPoolExecutor(args.parallel) as executor: - rets = executor.map(get_one_answer, list(range(len(arguments)))) + rets = list( + tqdm( + executor.map(get_one_answer, list(range(len(arguments)))), + total=len(arguments), + ) + ) for _ in rets: pass diff --git a/benchmark/json_jump_forward/README.md b/benchmark/json_jump_forward/README.md index 95737cf8f..1a34d3669 100644 --- a/benchmark/json_jump_forward/README.md +++ b/benchmark/json_jump_forward/README.md @@ -39,7 +39,7 @@ python3 bench_sglang.py --mode city ``` -### Benchmark vllm +### Benchmark Outlines + vLLM Run Llama-7B @@ -50,13 +50,13 @@ python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2 Benchmark Character Generation ```bash -python3 bench_other.py --mode character --backend vllm +python3 bench_other.py --mode character --backend outlines ``` Benchmark City Information Retrieval ```bash -python3 bench_other.py --mode city --backend vllm +python3 bench_other.py --mode city --backend outlines ``` ### Benchmark guidance @@ -64,11 +64,25 @@ python3 bench_other.py --mode city --backend vllm Run Llama-7B and benchmark character generation ```bash -python3 bench_other.py --mode character --backend guidance --parallel 1 +python3 bench_other.py --mode character --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf ``` Run Llama-7B and benchmark city information retrieval ```bash -python3 bench_other.py --mode city --backend guidance --parallel 1 +python3 bench_other.py --mode city --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf +``` + +### Benchmark lmql + +Run Llama-7B and benchmark character generation + +``` +python3 bench_other.py --mode character --backend lmql --parallel 1 +``` + +Run Llama-7B and benchmark city information retrieval + +``` +python3 bench_other.py --mode city --backend lmql --parallel 1 ``` diff --git a/benchmark/json_jump_forward/bench_other.py b/benchmark/json_jump_forward/bench_other.py index bb8fdc1dd..9eb5c58b3 100644 --- a/benchmark/json_jump_forward/bench_other.py +++ b/benchmark/json_jump_forward/bench_other.py @@ -7,10 +7,7 @@ from functools import partial import guidance from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_outlines, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl # there are some FSM bugs with json regex converted from pydantic model @@ -85,6 +82,29 @@ def character_maker(lm, name): return lm +async def call_generate_lmql( + prompt, temperature, max_tokens, regex, max_len=4096, model=None, **kwargs +): + assert model is not None + import lmql + + @lmql.query(model=model) + async def program(question, max_tokens, regex): + '''lmql + """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and REGEX(ANSWER, regex) + return ANSWER + ''' + + return await program( + question=prompt, + temperature=temperature, + max_tokens=max_tokens, + max_len=max_len, + regex=regex, + **kwargs, + ) + + @guidance def city_maker(lm, document): regex_str_no_quote = r"[\w\d\s]+" @@ -119,38 +139,68 @@ def bench_character(args): states = [None] * len(arguments) # Select backend - if args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_outlines, url=url, temperature=0) + if args.backend == "outlines": + call_generate = partial(get_call_generate(args), temperature=0) - def func(i): - states[i] = character_gen(**arguments[i], generate=generate) + def get_one_answer(i): + states[i] = character_gen(**arguments[i], generate=call_generate) - get_one_answer = func elif args.backend == "guidance": model = guidance.models.LlamaCpp( - args.llama_cpp_model_path, + args.model_path, n_gpu_layers=-1, - n_ctx=4096, + n_ctx=args.n_ctx, ) - def func(i): + def get_one_answer(i): lm = model + character_maker(**arguments[i]) states[i] = lm - get_one_answer = func + elif args.backend == "lmql": + import asyncio + + import lmql + + model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}") + call_generate = partial( + call_generate_lmql, + model=model, + max_tokens=256, + regex=character_regex, + ) + + async def get_one_answer_async(i): + states[i] = await call_generate(prompt=arguments[i]["name"], temperature=0) + else: raise ValueError(f"Invalid backend: {args.backend}") tic = time.time() - if args.parallel == 1: - for i in tqdm(range(len(arguments))): - get_one_answer(i) + + if args.backend != "lmql": + if args.parallel == 1: + for i in tqdm(range(len(arguments))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + rets = list( + tqdm( + executor.map(get_one_answer, list(range(len(arguments)))), + total=len(arguments), + ) + ) + for _ in rets: + pass else: - with ThreadPoolExecutor(args.parallel) as executor: - rets = executor.map(get_one_answer, list(range(len(arguments)))) - for _ in rets: - pass + batches = [] + for i in range(0, len(arguments), args.parallel): + batches.append(list(range(i, min(i + args.parallel, len(arguments))))) + loop = asyncio.get_event_loop() + + for bt in tqdm(batches): + loop.run_until_complete( + asyncio.gather(*[get_one_answer_async(i) for i in bt]) + ) latency = time.time() - tic @@ -166,26 +216,23 @@ def bench_city_doc(args): states = [None] * len(arguments) # Select backend - if args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_outlines, url=url, temperature=0) + if args.backend == "outlines": + call_generate = partial(get_call_generate(args), temperature=0) - def func(i): - states[i] = city_gen(**arguments[i], generate=generate) + def get_one_answer(i): + states[i] = city_gen(**arguments[i], generate=call_generate) - get_one_answer = func elif args.backend == "guidance": model = guidance.models.LlamaCpp( - args.llama_cpp_model_path, + args.model_path, n_gpu_layers=-1, - n_ctx=4096, + n_ctx=args.n_ctx, ) - def func(i): + def get_one_answer(i): lm = model + city_maker(**arguments[i]) states[i] = lm - get_one_answer = func else: raise ValueError(f"Invalid backend: {args.backend}") @@ -237,10 +284,5 @@ if __name__ == "__main__": parser.add_argument( "--mode", type=str, default="character", choices=["character", "city"] ) - parser.add_argument( - "--llama-cpp-model-path", - type=str, - default="/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf", - ) args = add_common_other_args_and_parse(parser) main(args) diff --git a/benchmark/llm_judge/README.md b/benchmark/llm_judge/README.md index e4516bf10..08255b641 100644 --- a/benchmark/llm_judge/README.md +++ b/benchmark/llm_judge/README.md @@ -23,5 +23,11 @@ python3 bench_other.py --backend vllm --num-questions 25 ### Benchmark guidance ``` -python3 bench_other.py --backend guidance --num-questions 25 --parallel 1 +python3 bench_other.py --backend guidance --num-questions 25 --parallel 1 --n-ctx 4096 --model-path path/to/gguf ``` + +### Benchmark lmql + +``` +python3 bench_other.py --backend lmql --num-questions 25 --parallel 1 +``` \ No newline at end of file diff --git a/benchmark/llm_judge/bench_other.py b/benchmark/llm_judge/bench_other.py index 5fdc2c4ce..2231bcdbb 100644 --- a/benchmark/llm_judge/bench_other.py +++ b/benchmark/llm_judge/bench_other.py @@ -6,12 +6,7 @@ from functools import partial from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency." @@ -54,53 +49,77 @@ def multi_dimension_judge(article, generate): return s +async def multi_dimension_judge_async(article, generate): + s = system_prompt + s += "\n```\n" + article + "\n```\n\n" + + judges = [] + for i in range(len(dimension_prompts)): + comp = await generate( + s + + "USER: Please judge the quality based on the following metric. " + + dimension_prompts[i] + + " Please provide a single-paragraph judgement. " + + "Focus on the provided metric and do not say other things. " + 'End your judgement paragraph with the word "END"\nJUDGE:', + max_tokens=256, + stop="END", + ) + judges.append(comp) + + s += "I will judge the quality based on the following metrics.\n" + for i in range(len(dimension_prompts)): + s += dimension_prompts[i].split(":")[0] + ": " + judges[i].strip() + "\n" + + s += "In summary, on a scale of 1 to 10, I would give the article a score of" + s += await generate(s, max_tokens=2, stop=None) + + return s + + def main(args): lines = read_jsonl(args.data_path)[: args.num_questions] states = [None] * len(lines) # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_lightllm, url=url, temperature=0) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_vllm, url=url, temperature=0) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_srt_raw, url=url, temperature=0) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def generate(prompt, max_tokens, stop): - out = ( - model - + prompt - + gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop) - ) - return out["answer"] - - # warmup - generate("Hello!", max_tokens=8, stop=None) - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_generate = partial(get_call_generate(args), temperature=0) # Run requests - def get_one_answer(i): - states[i] = multi_dimension_judge(lines[i], generate) - tic = time.time() - if args.parallel == 1: - for i in tqdm(range(len(lines))): - get_one_answer(i) + + if args.backend != "lmql": + + def get_one_answer(i): + states[i] = multi_dimension_judge(lines[i], call_generate) + + if args.parallel == 1: + for i in tqdm(range(len(lines))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(lines)))), + total=len(lines), + ) + ) + else: - with ThreadPoolExecutor(args.parallel) as executor: - executor.map(get_one_answer, list(range(len(lines)))) + import asyncio + + async def get_one_answer_async(i): + states[i] = await multi_dimension_judge_async(lines[i], call_generate) + + batches = [] + for i in range(0, len(lines), args.parallel): + batches.append(list(range(i, min(i + args.parallel, len(lines))))) + + loop = asyncio.get_event_loop() + for bt in tqdm(batches): + loop.run_until_complete( + asyncio.gather(*[get_one_answer_async(i) for i in bt]) + ) + latency = time.time() - tic # Compute accuracy diff --git a/benchmark/long_json_decode/README.md b/benchmark/long_json_decode/README.md index 6d52030a5..37fceee13 100644 --- a/benchmark/long_json_decode/README.md +++ b/benchmark/long_json_decode/README.md @@ -22,7 +22,7 @@ python3 bench_other.py --backend vllm --num-questions 5 ### Benchmark guidance ``` -python3 bench_other.py --backend guidance --num-questions 5 --parallel 1 +python3 bench_other.py --backend guidance --num-questions 5 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf ``` diff --git a/benchmark/long_json_decode/bench_other.py b/benchmark/long_json_decode/bench_other.py index 0627d9928..a83c797c4 100644 --- a/benchmark/long_json_decode/bench_other.py +++ b/benchmark/long_json_decode/bench_other.py @@ -6,12 +6,7 @@ from functools import partial from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl @@ -44,40 +39,11 @@ def main(args): states = [None] * len(arguments) # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_lightllm, url=url, temperature=0) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_vllm, url=url, temperature=0) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_srt_raw, url=url, temperature=0) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", - n_gpu_layers=-1, - n_ctx=11000, - ) - - def generate(prompt, max_tokens, stop): - out = ( - model - + prompt - + gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop) - ) - return out["answer"] - - # warmup - generate("Hello!", max_tokens=8, stop=None) - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_generate = partial(get_call_generate(args), temperature=0) # Run requests def get_one_answer(i): - states[i] = json_decode(generate=generate, **arguments[i]) + states[i] = json_decode(generate=call_generate, **arguments[i]) tic = time.time() if args.parallel == 1: @@ -85,7 +51,13 @@ def main(args): get_one_answer(i) else: with ThreadPoolExecutor(args.parallel) as executor: - executor.map(get_one_answer, list(range(len(arguments)))) + list( + tqdm( + executor.map(get_one_answer, list(range(len(arguments)))), + total=len(arguments), + ) + ) + latency = time.time() - tic # Compute accuracy diff --git a/benchmark/mmlu/README.md b/benchmark/mmlu/README.md index 3bc1fa439..9aa01d617 100644 --- a/benchmark/mmlu/README.md +++ b/benchmark/mmlu/README.md @@ -46,7 +46,7 @@ python3 bench_other.py --nsub 10 --backend lightllm ### Benchmark guidance ``` -python3 bench_other.py --nsub 10 --backend guidance --parallel 1 +python3 bench_other.py --nsub 10 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf ``` diff --git a/benchmark/mmlu/bench_other.py b/benchmark/mmlu/bench_other.py index aecdc3204..c5d48dac6 100644 --- a/benchmark/mmlu/bench_other.py +++ b/benchmark/mmlu/bench_other.py @@ -4,19 +4,13 @@ import json import os import time from concurrent.futures import ThreadPoolExecutor -from functools import partial import numpy as np import pandas as pd import tiktoken from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate choices = ["A", "B", "C", "D"] @@ -53,10 +47,7 @@ def gen_prompt(train_df, subject, k=-1): return prompt -model_initialized = None - - -def evaluate(args, subject, dev_df, test_df): +def evaluate(args, subject, dev_df, test_df, call_generate): prompts = [] labels = [] @@ -78,62 +69,6 @@ def evaluate(args, subject, dev_df, test_df): preds = [None] * len(prompts) max_tokens = 1 - # Select backend - global model_initialized - - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_lightllm, url=url, stop=None) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_vllm, url=url, stop=None) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_srt_raw, url=url, stop=None) - elif args.backend == "guidance": - from guidance import gen, models - - if model_initialized is None: - model = models.LlamaCpp( - "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - model_initialized = model - else: - model = model_initialized - - def call_generate(prompt, temperature, max_tokens): - out = ( - model - + prompt - + gen(name="answer", max_tokens=max_tokens, temperature=0) - ) - return out["answer"] - - # warmup - call_generate("Hello,", temperature=1.0, max_tokens=8) - - elif args.backend == "lmql": - import lmql - - model = lmql.model( - "meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}" - ) - - @lmql.query(model=model) - async def program(question): - '''lmql - """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 2 - return ANSWER - ''' - - async def call_generate(prompt, temperature, max_tokens): - return await program(question=prompt, temperature=temperature) - - else: - raise ValueError(f"Invalid backend: {args.backend}") - # Run requests if args.backend != "lmql": # Use thread pool @@ -190,6 +125,9 @@ def main(args): all_latencies = [] num_requests = 0 + # Select backend + call_generate = get_call_generate(args) + for subject in tqdm(subjects[: args.nsub]): dev_df = pd.read_csv( os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None @@ -198,7 +136,7 @@ def main(args): os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None ) - cors, acc, latency = evaluate(args, subject, dev_df, test_df) + cors, acc, latency = evaluate(args, subject, dev_df, test_df, call_generate) all_cors.append(cors) all_latencies.append(latency) num_requests += len(test_df) diff --git a/benchmark/mtbench/README.md b/benchmark/mtbench/README.md index e32a6eaab..28623fc7b 100644 --- a/benchmark/mtbench/README.md +++ b/benchmark/mtbench/README.md @@ -1,3 +1,9 @@ +## Download Dataset + +```sh +wget -O question.jsonl https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl +``` + ## Run benchmark ### Benchmark sglang diff --git a/benchmark/mtbench/bench_other.py b/benchmark/mtbench/bench_other.py index f45c5c0a5..2c321e8a1 100644 --- a/benchmark/mtbench/bench_other.py +++ b/benchmark/mtbench/bench_other.py @@ -4,16 +4,11 @@ import os import time import uuid from concurrent.futures import ThreadPoolExecutor -from functools import partial from fastchat.model import get_conversation_template +from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate def load_questions(filename): @@ -50,17 +45,7 @@ def main(args): conv_main = get_conversation_template(model_id) # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_lightllm, url=url, stop=None) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_vllm, url=url, stop=None) - elif args.backend == "srt": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_srt, url=url, stop=None) - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_generate = get_call_generate(args) answers = [None] * len(questions) @@ -83,11 +68,17 @@ def main(args): # Run requests tic = time.time() if args.parallel == 1: - for i in range(len(questions)): + for i in tqdm(range(len(questions))): get_answer(i) else: with ThreadPoolExecutor(args.parallel) as executor: - executor.map(get_answer, list(range(len(questions)))) + list( + tqdm( + executor.map(get_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + latency = time.time() - tic print(f"#questions: {len(questions)}, Latency: {latency:.2f}") diff --git a/benchmark/multi_chain_reasoning/README.md b/benchmark/multi_chain_reasoning/README.md index 67f627681..b9c3f4d85 100644 --- a/benchmark/multi_chain_reasoning/README.md +++ b/benchmark/multi_chain_reasoning/README.md @@ -39,5 +39,11 @@ python3 bench_other.py --num-questions 64 --backend lightllm ### Benchmark guidance ``` -python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 +python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf +``` + +### Benchmark lmql + +``` +python3 bench_other.py --num-questions 64 --backend lmql --parallel 1 ``` diff --git a/benchmark/multi_chain_reasoning/bench_other.py b/benchmark/multi_chain_reasoning/bench_other.py index 147909c48..e0ff2be45 100644 --- a/benchmark/multi_chain_reasoning/bench_other.py +++ b/benchmark/multi_chain_reasoning/bench_other.py @@ -5,16 +5,11 @@ import json import re import time from concurrent.futures import ThreadPoolExecutor -from functools import partial import numpy as np +from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl INVALID = -9999999 @@ -67,6 +62,32 @@ def multi_chain_gsm8k(question, num_chains, call_generate): return s +async def multi_chain_gsm8k_async(question, num_chains, call_generate): + s = "Question: " + question + "\n" + # s += call_generate(s + "Answer: " + prompt_lib[0], max_tokens=256, + # stop="Question", temperature=0) + # return s + + comps = [] + for i in range(num_chains): + comps.append( + await call_generate( + s + "Answer: " + prompt_lib[i % num_chains], + max_tokens=256, + temperature=0.3, + stop="Question", + ) + ) + + s += "Answer: To answer this question, here are some possible solutions. " + s += "After considering all of them, I will do a majority vote.\n\n" + for i in range(num_chains): + s += f"Solution {i+1}: " + comps[i].strip() + "\n\n" + s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is " + s += await call_generate(s, max_tokens=16, temperature=0, stop=None) + return s + + def main(args): lines = read_jsonl(args.data_path) @@ -83,71 +104,7 @@ def main(args): states = [None] * len(labels) # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_lightllm, url=url) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_vllm, url=url) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_srt_raw, url=url) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def call_generate(prompt, temperature, max_tokens, stop): - out = ( - model - + prompt - + gen( - name="answer", - max_tokens=max_tokens, - temperature=temperature, - stop=stop, - ) - ) - return out["answer"] - - # def multi_chain_gsm8k(question, num_chains, call_generate): - # s = model + "Question: " + question + "\n" - - # comps = [] - # for i in range(num_chains): - # comps.append(call_generate(s + "Answer: " + prompt_lib[i % num_chains], - # max_tokens=256, temperature=0.3, stop="Question")) - - # s += "Answer: To answer this question, here are some possible solutions. " - # s += "After considering all of them, I will do a majority vote.\n\n" - # for i in range(num_chains): - # s += f"Solution {i+1}: " + comps[i].strip() + "\n\n" - # s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is " - # return call_generate(s, max_tokens=16, temperature=0, stop=None) - - elif args.backend == "lmql": - import lmql - - model = lmql.model( - "meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}" - ) - - @lmql.query(model=model) - async def program(question): - '''lmql - """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < 257 and STOPS_AT(ANSWER, "Question") - return ANSWER - ''' - - async def call_generate(prompt, temperature, max_tokens, stop): - return await program(question=prompt, temperature=0) - - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_generate = get_call_generate(args) # Run requests if args.backend != "lmql": @@ -158,31 +115,35 @@ def main(args): tic = time.time() if args.parallel == 1: - for i in range(len(questions)): + for i in tqdm(range(len(questions))): get_one_answer(i) else: with ThreadPoolExecutor(args.parallel) as executor: - executor.map(get_one_answer, list(range(len(questions)))) + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + else: # Use asyncio - async def batched_call(batch_size): - for i in range(0, len(questions), batch_size): - tasks = [] - for q in questions[i : i + batch_size]: - tasks.append( - call_generate( - few_shot_examples + q, - temperature=0, - max_tokens=256, - stop="Question", - ) - ) - rets = await asyncio.gather(*tasks) - for j in range(len(rets)): - states[i + j] = get_answer_value(rets[j]) + async def get_one_answer_asyncio(i): + answer = await multi_chain_gsm8k_async( + questions[i], args.num_chains, call_generate + ) + states[i] = answer tic = time.time() - asyncio.run(batched_call(batch_size=args.parallel)) + loop = asyncio.get_event_loop() + batches = [ + list(range(i, min(i + args.parallel, len(questions)))) + for i in range(0, len(questions), args.parallel) + ] + for bt in tqdm(batches): + tasks = [get_one_answer_asyncio(k) for k in bt] + loop.run_until_complete(asyncio.gather(*tasks)) + latency = time.time() - tic preds = [] diff --git a/benchmark/multi_document_qa/README.md b/benchmark/multi_document_qa/README.md index 96b0a3ad6..09f49c78f 100644 --- a/benchmark/multi_document_qa/README.md +++ b/benchmark/multi_document_qa/README.md @@ -22,7 +22,7 @@ python3 bench_other.py --backend vllm --num-questions 64 ### Benchmark guidance ``` -python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 +python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 11000 --model-path path/to/code-llama/gguf ``` diff --git a/benchmark/multi_document_qa/bench_other.py b/benchmark/multi_document_qa/bench_other.py index cb263b0a7..97ff41686 100644 --- a/benchmark/multi_document_qa/bench_other.py +++ b/benchmark/multi_document_qa/bench_other.py @@ -6,12 +6,7 @@ from functools import partial from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl USER_PREFIX = "[INST] " @@ -60,40 +55,11 @@ def main(args): states = [None] * len(arguments) # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_lightllm, url=url, temperature=0) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_vllm, url=url, temperature=0) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_srt_raw, url=url, temperature=0) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", - n_gpu_layers=-1, - n_ctx=11000, - ) - - def generate(prompt, max_tokens, stop): - out = ( - model - + prompt - + gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop) - ) - return out["answer"] - - # warmup - generate("Hello!", max_tokens=8, stop=None) - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_generate = partial(get_call_generate(args), temperature=0) # Run requests def get_one_answer(i): - states[i] = multi_document_qa(generate=generate, **arguments[i]) + states[i] = multi_document_qa(generate=call_generate, **arguments[i]) tic = time.time() if args.parallel == 1: @@ -101,7 +67,13 @@ def main(args): get_one_answer(i) else: with ThreadPoolExecutor(args.parallel) as executor: - executor.map(get_one_answer, list(range(len(labels)))) + list( + tqdm( + executor.map(get_one_answer, list(range(len(labels)))), + total=len(labels), + ) + ) + latency = time.time() - tic # Compute accuracy diff --git a/benchmark/multi_turn_chat/README.md b/benchmark/multi_turn_chat/README.md index bddd40276..0fb5b21fa 100644 --- a/benchmark/multi_turn_chat/README.md +++ b/benchmark/multi_turn_chat/README.md @@ -56,11 +56,11 @@ python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm Benchmark Llama-7B (short output) ``` -python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 +python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf ``` Benchmark Llama-7B (long output) ``` -python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --long +python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf --long ``` diff --git a/benchmark/multi_turn_chat/bench_other.py b/benchmark/multi_turn_chat/bench_other.py index c86f0435e..81d67ab7b 100644 --- a/benchmark/multi_turn_chat/bench_other.py +++ b/benchmark/multi_turn_chat/bench_other.py @@ -2,61 +2,16 @@ import json import time from argparse import ArgumentParser from concurrent.futures import ThreadPoolExecutor +from functools import partial -import requests from data_gen import gen_arguments from tqdm import tqdm from vllm.transformers_utils.tokenizer import get_tokenizer -from sglang.test.test_utils import add_common_other_args_and_parse +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text -def get_generate(args): - # Select backend - if args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - - def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1): - data = { - "prompt": prompt, - "temperature": temperature, - "max_tokens": max_tokens, - "ignore_eos": True, - "stop": stop, - "stream": False, - "n": n, - } - res = requests.post(url, json=data) - assert res.status_code == 200 - return res.json()["text"][0][len(prompt) :] - - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def generate(prompt, max_tokens, stop=None): - out = ( - model - + prompt - + gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop) - ) - return out["answer"] - - # warmup - for _ in range(3): - generate("Hello!" * 10, max_tokens=64, stop=None) - else: - raise ValueError(f"Invalid backend: {args.backend}") - - return generate - - def multi_turns(generate, qas): s = "" for qa in qas: @@ -75,10 +30,10 @@ def main(args): states = [None] * args.num_qa - generate = get_generate(args) + call_generate = partial(get_call_generate(args), temperature=0) def get_one_answer(i): - states[i] = multi_turns(generate=generate, **multi_qas[i]) + states[i] = multi_turns(generate=call_generate, **multi_qas[i]) tic = time.time() if args.parallel == 1: @@ -86,7 +41,12 @@ def main(args): get_one_answer(i) else: with ThreadPoolExecutor(args.parallel) as executor: - rets = executor.map(get_one_answer, list(range(len(multi_qas)))) + rets = list( + tqdm( + executor.map(get_one_answer, list(range(len(multi_qas)))), + total=len(multi_qas), + ) + ) for _ in rets: pass diff --git a/benchmark/react/README.md b/benchmark/react/README.md index 58de673fa..24b3e95df 100644 --- a/benchmark/react/README.md +++ b/benchmark/react/README.md @@ -24,5 +24,11 @@ python3 bench_other.py --num-questions 100 --backend vllm ### Benchmark guidance ``` -python3 bench_other.py --num-questions 100 --backend guidance --parallel 1 +python3 bench_other.py --num-questions 100 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf ``` + +### Benchmark lmql + +``` +python3 bench_other.py --num-questions 100 --backend lmql --parallel 1 +``` \ No newline at end of file diff --git a/benchmark/react/bench_other.py b/benchmark/react/bench_other.py index dc70a3355..a850bdfb3 100644 --- a/benchmark/react/bench_other.py +++ b/benchmark/react/bench_other.py @@ -2,17 +2,10 @@ import argparse import json import time from concurrent.futures import ThreadPoolExecutor -from functools import partial -from pathlib import Path from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl @@ -97,42 +90,7 @@ def main(args): states = [] # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_lightllm, url=url) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_vllm, url=url) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_srt_raw, url=url) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def call_generate(prompt, temperature, max_tokens, stop): - out = ( - model - + prompt - + gen( - name="result", - max_tokens=max_tokens, - temperature=temperature, - stop=stop, - ) - ) - return out["result"] - - # warmup - call_generate("Hello,", 1.0, 8, ".") - - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_generate = get_call_generate(args) def run_single_agent(argument): question = argument["question"] @@ -161,13 +119,60 @@ def main(args): states.append(answer) + async def run_single_agent_async(argument): + question = argument["question"] + triplets = argument["triplets"] + prompt = get_prompt(question) + for i in range(1, len(triplets) + 2): + prompt += "Thought " + str(i) + ":" + states.append(prompt) + answer = await call_generate( + prompt, max_tokens=200, temperature=0, stop="Observation", max_len=4096 + ) + if i > len(triplets): + break + prompt += ( + triplets[i - 1]["thought"] + + "\nAction " + + str(i) + + ":" + + triplets[i - 1]["action"] + + "\nObservation " + + str(i) + + ":" + + triplets[i - 1]["observation"] + + "\n" + ) + + states.append(answer) + tic = time.time() - if args.parallel == 1: - for arg in tqdm(arguments): - run_single_agent(arg) + + if args.backend != "lmql": + if args.parallel == 1: + for arg in tqdm(arguments): + run_single_agent(arg) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(run_single_agent, arguments), total=len(arguments) + ) + ) + else: - with ThreadPoolExecutor(args.parallel) as executor: - executor.map(run_single_agent, arguments) + import asyncio + + loop = asyncio.get_event_loop() + batches = [ + [] for _ in range((len(arguments) + args.parallel - 1) // args.parallel) + ] + for i, arg in enumerate(arguments): + batches[i // args.parallel].append(arg) + for bt in tqdm(batches): + tasks = [run_single_agent_async(arg) for arg in bt] + loop.run_until_complete(asyncio.gather(*tasks)) + latency = time.time() - tic print(f"Latency: {latency:.3f}") diff --git a/benchmark/tip_suggestion/README.md b/benchmark/tip_suggestion/README.md index 81ca47d04..be15da623 100644 --- a/benchmark/tip_suggestion/README.md +++ b/benchmark/tip_suggestion/README.md @@ -23,5 +23,11 @@ python3 bench_other.py --backend vllm --num-questions 64 ### Benchmark guidance ``` -python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 +python3 bench_other.py --backend guidance --num-questions 32 --parallel 1 --n-ctx 4096 --model-path path/to/gguf ``` + +### Benchmark lmql + +``` +python3 bench_other.py --backend lmql --num-questions 32 --parallel 1 +``` \ No newline at end of file diff --git a/benchmark/tip_suggestion/bench_other.py b/benchmark/tip_suggestion/bench_other.py index 46da00227..fcc4fd624 100644 --- a/benchmark/tip_suggestion/bench_other.py +++ b/benchmark/tip_suggestion/bench_other.py @@ -6,12 +6,7 @@ from functools import partial from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl number = 5 @@ -70,48 +65,43 @@ def main(args): states = [None] * len(lines) # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_lightllm, url=url, temperature=0) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_vllm, url=url, temperature=0) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - generate = partial(call_generate_srt_raw, url=url, temperature=0) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def generate(prompt, max_tokens, stop): - out = ( - model - + prompt - + gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop) - ) - return out["answer"] - - # warmup - generate("Hello!", max_tokens=8, stop=None) - else: - raise ValueError(f"Invalid backend: {args.backend}") + call_generate = partial(get_call_generate(args), temperature=0) # Run requests - def get_one_answer(i): - states[i] = suggest_tips(lines[i]["topic"], generate) - tic = time.time() - if args.parallel == 1: - for i in tqdm(range(len(lines))): - get_one_answer(i) + if args.backend != "lmql": + + def get_one_answer(i): + states[i] = suggest_tips(lines[i]["topic"], call_generate) + + if args.parallel == 1: + for i in tqdm(range(len(lines))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(lines)))), + total=len(lines), + ) + ) + else: - with ThreadPoolExecutor(args.parallel) as executor: - executor.map(get_one_answer, list(range(len(lines)))) + import asyncio + + from lmql_funcs import suggest_tips_async + + async def get_one_answer_async(i): + states[i] = await suggest_tips_async(lines[i]["topic"], call_generate) + + batches = [] + for i in range(0, len(lines), args.parallel): + batches.append(list(range(i, min(i + args.parallel, len(lines))))) + loop = asyncio.get_event_loop() + for batch in tqdm(batches): + loop.run_until_complete( + asyncio.gather(*[get_one_answer_async(i) for i in batch]) + ) latency = time.time() - tic # Compute accuracy diff --git a/benchmark/tip_suggestion/lmql_funcs.py b/benchmark/tip_suggestion/lmql_funcs.py new file mode 100644 index 000000000..7790bbe95 --- /dev/null +++ b/benchmark/tip_suggestion/lmql_funcs.py @@ -0,0 +1,50 @@ +number = 5 + + +async def expand_tip_async(topic, tip, generate): + s = ( + """Please expand a tip for a topic into a detailed paragraph. + +Topic: staying healthy +Tip: Regular Exercise +Paragraph: Incorporate physical activity into your daily routine. This doesn't necessarily mean intense gym workouts; it can be as simple as walking, cycling, or yoga. Regular exercise helps in maintaining a healthy weight, improves cardiovascular health, boosts mental health, and can enhance cognitive function, which is crucial for fields that require intense intellectual engagement. + +Topic: building a campfire +Tip: Choose the Right Location +Paragraph: Always build your campfire in a safe spot. This means selecting a location that's away from trees, bushes, and other flammable materials. Ideally, use a fire ring if available. If you're building a fire pit, it should be on bare soil or on a bed of stones, not on grass or near roots which can catch fire underground. Make sure the area above is clear of low-hanging branches. + +Topic: writing a blog post +Tip: structure your content effectively +Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement. + +Topic: """ + + topic + + "\nTip: " + + tip + + "\nParagraph:" + ) + return await generate(s, max_tokens=128, stop="\n\n") + + +async def suggest_tips_async(topic, generate): + s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n" + s += "USER: Give some tips for " + topic + ".\n" + s += ( + "ASSISTANT: Okay. Here are " + + str(number) + + " concise tips, each under 8 words:\n" + ) + + tips = [] + for i in range(1, 1 + number): + s += f"{i}." + # NOTE: stop is different due to lmql does not support a list of stop tokens + tip = await generate(s, max_tokens=24, stop=".\n") + s += tip + ".\n" + tips.append(tip) + + paragraphs = [await expand_tip_async(topic, tip, generate=generate) for tip in tips] + + for i in range(1, 1 + number): + s += f"Tip {i}:" + paragraphs[i - 1] + "\n" + return s diff --git a/benchmark/tree_of_thought_deep/README.md b/benchmark/tree_of_thought_deep/README.md index cfebbac9b..bf5ab1638 100644 --- a/benchmark/tree_of_thought_deep/README.md +++ b/benchmark/tree_of_thought_deep/README.md @@ -41,5 +41,11 @@ python3 bench_other.py --num-questions 32 --backend lightllm ### Benchmark guidance ``` -python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 +python3 bench_other.py --num-questions 8 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf +``` + +### Benchmark lmql + +``` +python3 bench_other.py --num-questions 8 --backend lmql --parallel 1 ``` diff --git a/benchmark/tree_of_thought_deep/bench_other.py b/benchmark/tree_of_thought_deep/bench_other.py index 57a629768..21c7df351 100644 --- a/benchmark/tree_of_thought_deep/bench_other.py +++ b/benchmark/tree_of_thought_deep/bench_other.py @@ -5,17 +5,11 @@ import re import time from collections import Counter from concurrent.futures import ThreadPoolExecutor -from functools import partial import numpy as np from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl INVALID = -9999999 @@ -139,69 +133,50 @@ def main(args): arguments = [{"question": q, "num_branches": num_branches} for q in questions] # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_lightllm, url=url) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_vllm, url=url) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_srt_raw, url=url) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def call_generate(prompt, temperature, max_tokens, stop, n): - if n == 1: - out = ( - model - + prompt - + gen( - name="answer", - max_tokens=max_tokens, - temperature=temperature, - stop=stop, - ) - ) - return out["answer"] - else: - rets = [] - for i in range(n): - out = ( - model - + prompt - + gen( - name="answer", - max_tokens=max_tokens, - temperature=temperature, - stop=stop, - ) - ) - rets.append(out["answer"]) - return rets - - # warmup - call_generate("Hello,", 1.0, 8, ".", 1) + call_generate = get_call_generate(args) # Run requests states = [None] * len(questions) - def get_one_answer(i): - states[i] = tree_search(**arguments[i], call_generate=call_generate) - tic = time.time() - if args.parallel == 1: - for i in tqdm(range(len(questions))): - get_one_answer(i) + if args.backend != "lmql": + + def get_one_answer(i): + states[i] = tree_search(**arguments[i], call_generate=call_generate) + + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + else: - with ThreadPoolExecutor(args.parallel) as executor: - executor.map(get_one_answer, list(range(len(questions)))) + import asyncio + + from lmql_funcs import tree_search_async + + async def get_one_answer_async(i): + states[i] = await tree_search_async( + **arguments[i], call_generate=call_generate + ) + + batches = [ + [] for _ in range((len(questions) + args.parallel - 1) // args.parallel) + ] + for i in range(len(questions)): + batches[i // args.parallel].append(i) + + loop = asyncio.get_event_loop() + for bt in tqdm(batches): + tasks = [get_one_answer_async(k) for k in bt] + loop.run_until_complete(asyncio.gather(*tasks)) + latency = time.time() - tic answers_text = [] diff --git a/benchmark/tree_of_thought_deep/lmql_funcs.py b/benchmark/tree_of_thought_deep/lmql_funcs.py new file mode 100644 index 000000000..c783cdbe3 --- /dev/null +++ b/benchmark/tree_of_thought_deep/lmql_funcs.py @@ -0,0 +1,82 @@ +from bench_other import ( + ASSISTANT_PREFIX, + ASSISTANT_SUFFIX, + USER_PREFIX, + USER_SUFFIX, + temp, +) + + +async def propose_plan_async(s, question, num_branches, call_generate): + s += ( + USER_PREFIX + + """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + + question + + USER_SUFFIX + ) + + s += ASSISTANT_PREFIX + comps = await call_generate( + s, max_tokens=256, temperature=temp, stop=None, n=num_branches + ) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +async def execute_plan_async(s, num_branches, call_generate): + s += ( + USER_PREFIX + + """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + + USER_SUFFIX + ) + s += ASSISTANT_PREFIX + comps = await call_generate( + s, max_tokens=256, temperature=temp, stop=None, n=num_branches + ) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +async def reflect_solution_async(s, num_branches, call_generate): + s += ( + USER_PREFIX + + """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + + USER_SUFFIX + ) + s += ASSISTANT_PREFIX + comps = await call_generate( + s, max_tokens=256, temperature=temp, stop=None, n=num_branches + ) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +async def get_final_answer_async(s, num_branches, call_generate): + s += ( + USER_PREFIX + + """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""" + + USER_SUFFIX + ) + s += ASSISTANT_PREFIX + comps = await call_generate( + s, max_tokens=256, temperature=temp, stop=None, n=num_branches + ) + return [s + comp + ASSISTANT_SUFFIX for comp in comps] + + +async def tree_search_async(question, num_branches, call_generate): + plan_forks = await propose_plan_async("", question, num_branches, call_generate) + + sol_states = [] + for plan in plan_forks: + forks = await execute_plan_async(plan, num_branches, call_generate) + sol_states.extend(forks) + + ref_states = [] + for sol in sol_states: + forks = await reflect_solution_async(sol, num_branches, call_generate) + ref_states.extend(forks) + + solutions = [] + for sol in ref_states: + ans = await get_final_answer_async(sol, num_branches, call_generate) + solutions.append(ans) + + return solutions diff --git a/benchmark/tree_of_thought_v0/README.md b/benchmark/tree_of_thought_v0/README.md index 760b24c46..821bb20d1 100644 --- a/benchmark/tree_of_thought_v0/README.md +++ b/benchmark/tree_of_thought_v0/README.md @@ -39,5 +39,5 @@ python3 bench_other.py --num-questions 32 --backend lightllm ### Benchmark guidance ``` -python3 bench_other.py --num-questions 32 --backend guidance --parallel 1 +python3 bench_other.py --num-questions 32 --backend guidance --parallel 1 --n-ctx 4096 --model-path path/to/gguf ``` diff --git a/benchmark/tree_of_thought_v0/bench_other.py b/benchmark/tree_of_thought_v0/bench_other.py index b200da479..86e133577 100644 --- a/benchmark/tree_of_thought_v0/bench_other.py +++ b/benchmark/tree_of_thought_v0/bench_other.py @@ -5,17 +5,11 @@ import re import time from collections import Counter from concurrent.futures import ThreadPoolExecutor -from functools import partial import numpy as np from tqdm import tqdm -from sglang.test.test_utils import ( - add_common_other_args_and_parse, - call_generate_lightllm, - call_generate_srt_raw, - call_generate_vllm, -) +from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl INVALID = -9999999 @@ -119,52 +113,7 @@ def main(args): arguments = [{"question": q, "num_branches": num_branches} for q in questions] # Select backend - if args.backend == "lightllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_lightllm, url=url) - elif args.backend == "vllm": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_vllm, url=url) - elif args.backend == "srt-raw": - url = f"{args.host}:{args.port}/generate" - call_generate = partial(call_generate_srt_raw, url=url) - elif args.backend == "guidance": - from guidance import gen, models - - model = models.LlamaCpp( - "/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", - n_gpu_layers=-1, - n_ctx=4096, - ) - - def call_generate(prompt, temperature, max_tokens, stop, n): - if n == 1: - out = ( - model - + prompt - + gen( - name="answer", - max_tokens=max_tokens, - temperature=temperature, - stop=stop, - ) - ) - return out["answer"] - else: - rets = [] - for i in range(n): - out = ( - model - + prompt - + gen( - name="answer", - max_tokens=max_tokens, - temperature=temperature, - stop=stop, - ) - ) - rets.append(out["answer"]) - return rets + call_generate = get_call_generate(args) # Run requests states = [None] * len(questions) @@ -178,7 +127,13 @@ def main(args): get_one_answer(i) else: with ThreadPoolExecutor(args.parallel) as executor: - executor.map(get_one_answer, list(range(len(questions)))) + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + latency = time.time() - tic answers_text = [] diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 4d5e18211..1b0d8fe6e 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1,14 +1,20 @@ """Common utilities for testing and benchmarking""" +import asyncio +from functools import partial + import numpy as np import requests from sglang.backend.openai import OpenAI from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.global_config import global_config +from sglang.srt.utils import get_exception_traceback -def call_generate_lightllm(prompt, temperature, max_tokens, stop, url): +def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): + assert url is not None + data = { "inputs": prompt, "parameters": { @@ -23,7 +29,9 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop, url): return pred -def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1): +def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None): + assert url is not None + data = { "prompt": prompt, "temperature": temperature, @@ -41,8 +49,10 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1): def call_generate_outlines( - prompt, temperature, max_tokens, url, stop=[], regex=None, n=1 + prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None ): + assert url is not None + data = { "prompt": prompt, "temperature": temperature, @@ -60,7 +70,9 @@ def call_generate_outlines( return pred -def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url): +def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None): + assert url is not None + data = { "text": prompt, "sampling_params": { @@ -76,7 +88,71 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url): return pred -def call_select_lightllm(context, choices, url): +def call_generate_guidance( + prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None +): + assert model is not None + from guidance import gen + + rets = [] + for _ in range(n): + out = ( + model + + prompt + + gen( + name="answer", + max_tokens=max_tokens, + temperature=temperature, + stop=stop, + regex=regex, + ) + ) + rets.append(out["answer"]) + return rets if n > 1 else rets[0] + + +async def call_generate_lmql( + prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs +): + assert model is not None + import lmql + + if stop != None: + + @lmql.query(model=model) + async def program(question, max_tokens, stop): + '''lmql + """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop) + return ANSWER + ''' + + else: + + @lmql.query(model=model) + async def program(question, max_tokens): + '''lmql + """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens + return ANSWER + ''' + + tasks = [ + program( + question=prompt, + temperature=temperature, + max_tokens=max_tokens, + stop=stop, + max_len=max_len, + **kwargs, + ) + for _ in range(n) + ] + rets = await asyncio.gather(*tasks) + return rets if n > 1 else rets[0] + + +def call_select_lightllm(context, choices, url=None): + assert url is not None + scores = [] for i in range(len(choices)): data = { @@ -91,7 +167,9 @@ def call_select_lightllm(context, choices, url): return np.argmax(scores) -def call_select_vllm(context, choices, url): +def call_select_vllm(context, choices, url=None): + assert url is not None + scores = [] for i in range(len(choices)): data = { @@ -113,6 +191,31 @@ def call_select_vllm(context, choices, url): """ +def call_select_guidance(context, choices, model=None): + assert model is not None + from guidance import select + + out = model + context + select(choices, name="answer") + return choices.index(out["answer"]) + + +async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None): + assert model is not None + import lmql + + @lmql.query(model=model) + async def program(ctx, choices): + '''lmql + """{ctx}[ANSWER]""" where ANSWER in set(choices) + return ANSWER + ''' + + answer = await program( + ctx=context, choices=choices, temperature=temperature, max_len=max_len + ) + return choices.index(answer) + + def add_common_other_args_and_parse(parser): parser.add_argument("--parallel", type=int, default=64) parser.add_argument("--host", type=str, default="http://127.0.0.1") @@ -121,8 +224,17 @@ def add_common_other_args_and_parse(parser): "--backend", type=str, required=True, - choices=["vllm", "lightllm", "guidance", "lmql", "srt-raw", "llama.cpp"], + choices=[ + "vllm", + "outlines", + "lightllm", + "guidance", + "lmql", + "srt-raw", + "llama.cpp", + ], ) + parser.add_argument("--n-ctx", type=int, default=4096) parser.add_argument( "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf" ) @@ -132,6 +244,7 @@ def add_common_other_args_and_parse(parser): if args.port is None: default_port = { "vllm": 21000, + "outlines": 21000, "lightllm": 22000, "lmql": 23000, "srt-raw": 30000, @@ -161,3 +274,77 @@ def select_sglang_backend(args): else: raise ValueError(f"Invalid backend: {args.backend}") return backend + + +def _get_call_generate(args): + if args.backend == "lightllm": + return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate") + elif args.backend == "vllm": + return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate") + elif args.backend == "srt-raw": + return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate") + elif args.backend == "outlines": + return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate") + elif args.backend == "guidance": + from guidance import models + + model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx) + call_generate = partial(call_generate_guidance, model=model) + call_generate("Hello,", 1.0, 8, ".") + return call_generate + elif args.backend == "lmql": + import lmql + + model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}") + return partial(call_generate_lmql, model=model) + else: + raise ValueError(f"Invalid backend: {args.backend}") + + +def _get_call_select(args): + if args.backend == "lightllm": + return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate") + elif args.backend == "vllm": + return partial(call_select_vllm, url=f"{args.host}:{args.port}/generate") + elif args.backend == "guidance": + from guidance import models + + model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx) + call_select = partial(call_select_guidance, model=model) + + call_select("Hello,", ["world", "earth"]) + return call_select + + elif args.backend == "lmql": + import lmql + + model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}") + return partial(call_select_lmql, model=model) + else: + raise ValueError(f"Invalid backend: {args.backend}") + + +def get_call_generate(args): + call_generate = _get_call_generate(args) + + def func(*args, **kwargs): + try: + return call_generate(*args, **kwargs) + except Exception: + print("Exception in call_generate:\n" + get_exception_traceback()) + raise + + return func + + +def get_call_select(args): + call_select = _get_call_select(args) + + def func(*args, **kwargs): + try: + return call_select(*args, **kwargs) + except Exception: + print("Exception in call_select:\n" + get_exception_traceback()) + raise + + return func