Organize Benchmark (#381)

Author: Liangsheng Yin
Date: 2024-05-05 16:14:17 +08:00
Committed by: GitHub
Parent: 183df47282
Commit: 14522e6a26
36 changed files with 829 additions and 809 deletions


@@ -1,14 +1,20 @@
"""Common utilities for testing and benchmarking"""
+import asyncio
+from functools import partial
import numpy as np
import requests
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
+from sglang.srt.utils import get_exception_traceback


-def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
+def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
+    assert url is not None
    data = {
        "inputs": prompt,
        "parameters": {
@@ -23,7 +29,9 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
    return pred
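
# Illustrative usage sketch (editorial example, not part of this commit):
# calling the LightLLM helper against a server on its default port from the
# table below (22000). The vllm and srt-raw helpers follow the same calling
# convention and differ only in the JSON payload they send.
def _example_lightllm_call():  # hypothetical helper name
    return call_generate_lightllm(
        "Question: What is the capital of France?\nAnswer:",
        temperature=0,
        max_tokens=16,
        stop=["\n"],
        url="http://127.0.0.1:22000/generate",
    )
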
-def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
+def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
+    assert url is not None
    data = {
        "prompt": prompt,
        "temperature": temperature,
@@ -41,8 +49,10 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
def call_generate_outlines(
-    prompt, temperature, max_tokens, url, stop=[], regex=None, n=1
+    prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None
):
+    assert url is not None
    data = {
        "prompt": prompt,
        "temperature": temperature,
@@ -60,7 +70,9 @@ def call_generate_outlines(
    return pred
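
# Illustrative usage sketch (editorial example, not part of this commit): the
# outlines helper adds regex-constrained decoding on top of the vLLM-style
# payload; here the completion is constrained to digits. Port 21000 is the
# shared vllm/outlines default from the table below.
def _example_outlines_call():  # hypothetical helper name
    return call_generate_outlines(
        "The answer to 6 x 7 is ",
        temperature=0,
        max_tokens=8,
        regex=r"[0-9]+",
        url="http://127.0.0.1:21000/generate",
    )
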
-def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
+def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
+    assert url is not None
    data = {
        "text": prompt,
        "sampling_params": {
@@ -76,7 +88,71 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
    return pred

-def call_select_lightllm(context, choices, url):

+def call_generate_guidance(
+    prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
+):
+    assert model is not None
+    from guidance import gen
+
+    rets = []
+    for _ in range(n):
+        out = (
+            model
+            + prompt
+            + gen(
+                name="answer",
+                max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop,
+                regex=regex,
+            )
+        )
+        rets.append(out["answer"])
+    return rets if n > 1 else rets[0]
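
# Illustrative usage sketch (editorial example, not part of this commit):
# guidance runs in-process on a llama.cpp model, so this helper is bound to a
# model object instead of a URL (see _get_call_generate below). The model path
# is a placeholder.
def _example_guidance_call():  # hypothetical helper name
    from guidance import models

    model = models.LlamaCpp("path/to/model.gguf", n_gpu_layers=-1, n_ctx=4096)
    return call_generate_guidance(
        "Q: 2 + 2 = ", temperature=0, max_tokens=4, stop="\n", model=model
    )
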
+async def call_generate_lmql(
+    prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
+):
+    assert model is not None
+    import lmql
+
+    if stop is not None:
+
+        @lmql.query(model=model)
+        async def program(question, max_tokens, stop):
+            '''lmql
+            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
+            return ANSWER
+            '''
+
+    else:
+
+        @lmql.query(model=model)
+        async def program(question, max_tokens):
+            '''lmql
+            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
+            return ANSWER
+            '''
+
+    tasks = [
+        program(
+            question=prompt,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            stop=stop,
+            max_len=max_len,
+            **kwargs,
+        )
+        for _ in range(n)
+    ]
+    rets = await asyncio.gather(*tasks)
+    return rets if n > 1 else rets[0]
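
# Illustrative usage sketch (editorial example, not part of this commit):
# call_generate_lmql is a coroutine, so callers outside an event loop wrap it
# in asyncio.run(); the endpoint mirrors the lmql default port below (23000).
def _example_lmql_call():  # hypothetical helper name
    import lmql

    model = lmql.model(
        "meta-llama/Llama-2-7b-chat-hf", endpoint="http://127.0.0.1:23000"
    )
    return asyncio.run(
        call_generate_lmql("Q: 2 + 2 =", temperature=0, max_tokens=4, model=model)
    )
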
+def call_select_lightllm(context, choices, url=None):
+    assert url is not None
    scores = []
    for i in range(len(choices)):
        data = {
@@ -91,7 +167,9 @@ def call_select_lightllm(context, choices, url):
    return np.argmax(scores)
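
# Illustrative usage sketch (editorial example, not part of this commit): the
# select helpers score each candidate continuation and return the argmax index
# into choices.
def _example_select_call():  # hypothetical helper name
    choices = ["positive", "negative"]
    idx = call_select_lightllm(
        "Review: great movie! Sentiment:",
        choices,
        url="http://127.0.0.1:22000/generate",
    )
    return choices[idx]
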
-def call_select_vllm(context, choices, url):
+def call_select_vllm(context, choices, url=None):
+    assert url is not None
    scores = []
    for i in range(len(choices)):
        data = {
@@ -113,6 +191,31 @@ def call_select_vllm(context, choices, url):
"""

+def call_select_guidance(context, choices, model=None):
+    assert model is not None
+    from guidance import select
+
+    out = model + context + select(choices, name="answer")
+    return choices.index(out["answer"])

+async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
+    assert model is not None
+    import lmql
+
+    @lmql.query(model=model)
+    async def program(ctx, choices):
+        '''lmql
+        """{ctx}[ANSWER]""" where ANSWER in set(choices)
+        return ANSWER
+        '''
+
+    answer = await program(
+        ctx=context, choices=choices, temperature=temperature, max_len=max_len
+    )
+    return choices.index(answer)
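
# Illustrative usage sketch (editorial example, not part of this commit): like
# the generate variant, call_select_lmql is a coroutine and resolves to an
# index into choices.
def _example_lmql_select():  # hypothetical helper name
    import lmql

    model = lmql.model(
        "meta-llama/Llama-2-7b-chat-hf", endpoint="http://127.0.0.1:23000"
    )
    choices = ["blue", "green"]
    idx = asyncio.run(call_select_lmql("The sky is", choices, model=model))
    return choices[idx]
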
def add_common_other_args_and_parse(parser):
    parser.add_argument("--parallel", type=int, default=64)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
@@ -121,8 +224,17 @@ def add_common_other_args_and_parse(parser):
"--backend",
type=str,
required=True,
choices=["vllm", "lightllm", "guidance", "lmql", "srt-raw", "llama.cpp"],
choices=[
"vllm",
"outlines",
"lightllm",
"guidance",
"lmql",
"srt-raw",
"llama.cpp",
],
    )
+    parser.add_argument("--n-ctx", type=int, default=4096)
    parser.add_argument(
        "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
    )
@@ -132,6 +244,7 @@ def add_common_other_args_and_parse(parser):
    if args.port is None:
        default_port = {
            "vllm": 21000,
+            "outlines": 21000,
            "lightllm": 22000,
            "lmql": 23000,
            "srt-raw": 30000,
@@ -161,3 +274,77 @@ def select_sglang_backend(args):
    else:
        raise ValueError(f"Invalid backend: {args.backend}")
    return backend
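
# Illustrative usage sketch (editorial example, not part of this commit): a
# benchmark script wires these pieces together roughly as below. The script
# name is hypothetical, and this assumes add_common_other_args_and_parse
# returns the parsed namespace, as its use of args above suggests.
#
#     python3 bench_example.py --backend vllm --parallel 16
#
def _example_main():  # hypothetical helper name
    import argparse

    parser = argparse.ArgumentParser()
    args = add_common_other_args_and_parse(parser)
    call_generate = get_call_generate(args)
    return call_generate("Hello,", temperature=0, max_tokens=8)
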
+def _get_call_generate(args):
+    if args.backend == "lightllm":
+        return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "vllm":
+        return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "srt-raw":
+        return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "outlines":
+        return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "guidance":
+        from guidance import models
+
+        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
+        call_generate = partial(call_generate_guidance, model=model)
+        call_generate("Hello,", 1.0, 8, ".")  # warm-up call to load the model
+        return call_generate
+    elif args.backend == "lmql":
+        import lmql
+
+        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
+        return partial(call_generate_lmql, model=model)
+    else:
+        raise ValueError(f"Invalid backend: {args.backend}")
+
+def _get_call_select(args):
+    if args.backend == "lightllm":
+        return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "vllm":
+        return partial(call_select_vllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "guidance":
+        from guidance import models
+
+        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
+        call_select = partial(call_select_guidance, model=model)
+        call_select("Hello,", ["world", "earth"])  # warm-up call to load the model
+        return call_select
+    elif args.backend == "lmql":
+        import lmql
+
+        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
+        return partial(call_select_lmql, model=model)
+    else:
+        raise ValueError(f"Invalid backend: {args.backend}")
+
+def get_call_generate(args):
+    call_generate = _get_call_generate(args)
+
+    def func(*args, **kwargs):
+        try:
+            return call_generate(*args, **kwargs)
+        except Exception:
+            print("Exception in call_generate:\n" + get_exception_traceback())
+            raise
+
+    return func
+
+def get_call_select(args):
+    call_select = _get_call_select(args)
+
+    def func(*args, **kwargs):
+        try:
+            return call_select(*args, **kwargs)
+        except Exception:
+            print("Exception in call_select:\n" + get_exception_traceback())
+            raise
+
+    return func
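
# Illustrative usage sketch (editorial example, not part of this commit):
# benchmarks can fan the wrapped callable out over a thread pool sized by
# --parallel, so the traceback printed by func() identifies failing requests.
def _example_benchmark_loop(args, prompts):  # hypothetical helper name
    from concurrent.futures import ThreadPoolExecutor

    call_generate = get_call_generate(args)
    with ThreadPoolExecutor(args.parallel) as executor:
        return list(
            executor.map(
                lambda p: call_generate(p, temperature=0, max_tokens=64), prompts
            )
        )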