Organize Benchmark (#381)
@@ -1,14 +1,20 @@
 """Common utilities for testing and benchmarking"""

+import asyncio
+from functools import partial
+
 import numpy as np
 import requests

 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
+from sglang.srt.utils import get_exception_traceback


-def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
+def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
+    assert url is not None
+
     data = {
         "inputs": prompt,
         "parameters": {
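
With the new keyword defaults, callers no longer pass url positionally; the benchmark scripts bind it once with functools.partial (see _get_call_generate below), and the assert fails fast if the binding is missing. A minimal usage sketch, assuming a LightLLM server is already serving on the default port configured later in this file; the prompt and stop values are illustrative only:

pred = call_generate_lightllm(
    "The capital of France is",
    temperature=0,
    max_tokens=16,
    stop=["\n"],
    url="http://127.0.0.1:22000/generate",
)
print(pred)
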
@@ -23,7 +29,9 @@ def call_generate_lightllm(prompt, temperature, max_tokens, stop, url):
     return pred


-def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
+def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
+    assert url is not None
+
     data = {
         "prompt": prompt,
         "temperature": temperature,
@@ -41,8 +49,10 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):


 def call_generate_outlines(
-    prompt, temperature, max_tokens, url, stop=[], regex=None, n=1
+    prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None
 ):
+    assert url is not None
+
     data = {
         "prompt": prompt,
         "temperature": temperature,
@@ -60,7 +70,9 @@ def call_generate_outlines(
     return pred


-def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
+def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
+    assert url is not None
+
     data = {
         "text": prompt,
         "sampling_params": {
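
The four HTTP backends above share one shape: build a JSON payload, POST it to url, and pull the completion out of the response; only the field names differ per server, and the hunks elide those bodies. A hedged sketch of the shared pattern (the helper name and the response field are illustrative, not taken from this commit):

import requests

def _post_generate(url, data, field):
    # POST the backend-specific payload and read the generated text back.
    res = requests.post(url, json=data)
    assert res.status_code == 200
    return res.json()[field]  # illustrative field name; each server differs
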
@@ -76,7 +88,71 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
     return pred


-def call_select_lightllm(context, choices, url):
+def call_generate_guidance(
+    prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
+):
+    assert model is not None
+    from guidance import gen
+
+    rets = []
+    for _ in range(n):
+        out = (
+            model
+            + prompt
+            + gen(
+                name="answer",
+                max_tokens=max_tokens,
+                temperature=temperature,
+                stop=stop,
+                regex=regex,
+            )
+        )
+        rets.append(out["answer"])
+    return rets if n > 1 else rets[0]
+
+
+async def call_generate_lmql(
+    prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
+):
+    assert model is not None
+    import lmql
+
+    if stop != None:
+
+        @lmql.query(model=model)
+        async def program(question, max_tokens, stop):
+            '''lmql
+            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
+            return ANSWER
+            '''
+
+    else:
+
+        @lmql.query(model=model)
+        async def program(question, max_tokens):
+            '''lmql
+            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
+            return ANSWER
+            '''
+
+    tasks = [
+        program(
+            question=prompt,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            stop=stop,
+            max_len=max_len,
+            **kwargs,
+        )
+        for _ in range(n)
+    ]
+    rets = await asyncio.gather(*tasks)
+    return rets if n > 1 else rets[0]
+
+
+def call_select_lightllm(context, choices, url=None):
+    assert url is not None
+
     scores = []
     for i in range(len(choices)):
         data = {
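
Both local-model wrappers keep the same return convention as the HTTP helpers: a bare string when n == 1, a list of n samples otherwise. call_generate_lmql is a coroutine that builds one lmql program per sample and awaits them together with asyncio.gather, so synchronous callers need an event loop. A usage sketch, assuming lmql is installed and model is an lmql.model(...) handle as built in _get_call_generate below:

answer = asyncio.run(
    call_generate_lmql(
        "Q: What is 2 + 2?\nA:",
        temperature=0,
        max_tokens=8,
        stop="\n",
        model=model,  # assumed: an lmql.model(...) endpoint handle
    )
)
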
@@ -91,7 +167,9 @@ def call_select_lightllm(context, choices, url):
     return np.argmax(scores)


-def call_select_vllm(context, choices, url):
+def call_select_vllm(context, choices, url=None):
+    assert url is not None
+
     scores = []
     for i in range(len(choices)):
         data = {
@@ -113,6 +191,31 @@ def call_select_vllm(context, choices, url):
     """


+def call_select_guidance(context, choices, model=None):
+    assert model is not None
+    from guidance import select
+
+    out = model + context + select(choices, name="answer")
+    return choices.index(out["answer"])
+
+
+async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
+    assert model is not None
+    import lmql
+
+    @lmql.query(model=model)
+    async def program(ctx, choices):
+        '''lmql
+        """{ctx}[ANSWER]""" where ANSWER in set(choices)
+        return ANSWER
+        '''
+
+    answer = await program(
+        ctx=context, choices=choices, temperature=temperature, max_len=max_len
+    )
+    return choices.index(answer)
+
+
 def add_common_other_args_and_parse(parser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
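
All four call_select_* variants return the index of the winning choice (np.argmax over per-choice scores for the HTTP backends, choices.index for guidance and lmql), so benchmark code can treat them interchangeably. For example, with the guidance backend (assumes a guidance.models.LlamaCpp handle as constructed in _get_call_select below):

idx = call_select_guidance(
    "The opposite of hot is ",
    ["cold", "warm"],
    model=model,  # assumed: a guidance.models.LlamaCpp instance
)
print(idx)  # 0 if the model prefers "cold"
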
@@ -121,8 +224,17 @@ def add_common_other_args_and_parse(parser):
         "--backend",
         type=str,
         required=True,
-        choices=["vllm", "lightllm", "guidance", "lmql", "srt-raw", "llama.cpp"],
+        choices=[
+            "vllm",
+            "outlines",
+            "lightllm",
+            "guidance",
+            "lmql",
+            "srt-raw",
+            "llama.cpp",
+        ],
     )
+    parser.add_argument("--n-ctx", type=int, default=4096)
     parser.add_argument(
         "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
     )
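
A sketch of how a benchmark script consumes this helper: it registers its own flags first, then lets add_common_other_args_and_parse parse the command line, presumably returning the parsed namespace and (per the hunk below) filling in a backend-specific default port when --port is omitted. The script-specific flag here is illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num-questions", type=int, default=200)  # illustrative
args = add_common_other_args_and_parse(parser)
# e.g. invoked as: python bench_script.py --backend outlines --num-questions 50
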
@@ -132,6 +244,7 @@ def add_common_other_args_and_parse(parser):
     if args.port is None:
         default_port = {
             "vllm": 21000,
+            "outlines": 21000,
             "lightllm": 22000,
             "lmql": 23000,
             "srt-raw": 30000,
@@ -161,3 +274,77 @@ def select_sglang_backend(args):
     else:
         raise ValueError(f"Invalid backend: {args.backend}")
     return backend
+
+
+def _get_call_generate(args):
+    if args.backend == "lightllm":
+        return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "vllm":
+        return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "srt-raw":
+        return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "outlines":
+        return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "guidance":
+        from guidance import models
+
+        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
+        call_generate = partial(call_generate_guidance, model=model)
+        call_generate("Hello,", 1.0, 8, ".")
+        return call_generate
+    elif args.backend == "lmql":
+        import lmql
+
+        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
+        return partial(call_generate_lmql, model=model)
+    else:
+        raise ValueError(f"Invalid backend: {args.backend}")
+
+
+def _get_call_select(args):
+    if args.backend == "lightllm":
+        return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "vllm":
+        return partial(call_select_vllm, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "guidance":
+        from guidance import models
+
+        model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
+        call_select = partial(call_select_guidance, model=model)
+
+        call_select("Hello,", ["world", "earth"])
+        return call_select
+
+    elif args.backend == "lmql":
+        import lmql
+
+        model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
+        return partial(call_select_lmql, model=model)
+    else:
+        raise ValueError(f"Invalid backend: {args.backend}")
+
+
+def get_call_generate(args):
+    call_generate = _get_call_generate(args)
+
+    def func(*args, **kwargs):
+        try:
+            return call_generate(*args, **kwargs)
+        except Exception:
+            print("Exception in call_generate:\n" + get_exception_traceback())
+            raise
+
+    return func
+
+
+def get_call_select(args):
+    call_select = _get_call_select(args)
+
+    def func(*args, **kwargs):
+        try:
+            return call_select(*args, **kwargs)
+        except Exception:
+            print("Exception in call_select:\n" + get_exception_traceback())
+            raise
+
+    return func
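
Taken together, a driver script resolves the backend once and then calls the wrapped function from its worker threads; the try/except wrappers above print a full traceback before re-raising, so failures inside a thread pool stay visible. A hedged end-to-end sketch (run with a --backend flag on the command line, since that argument is required):

import argparse

args = add_common_other_args_and_parse(argparse.ArgumentParser())
call_generate = get_call_generate(args)

# One calling convention across every backend:
pred = call_generate("Hello, my name is", temperature=0, max_tokens=16)
print(pred)
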