"""Common utilities for testing and benchmarking"""
import argparse
import asyncio
import os
import subprocess
import threading
import time
from functools import partial
from typing import Callable, List, Optional
import numpy as np
import requests
import torch
import torch.nn.functional as F
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.utils import kill_child_process
from sglang.utils import get_exception_traceback
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
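# Use different default ports and URLs when running inside CI (SGLANG_IS_IN_CI).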
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
else:
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157
DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157"
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
assert url is not None
data = {
"inputs": prompt,
"parameters": {
"temperature": temperature,
"max_new_tokens": max_tokens,
"stop_sequences": stop,
},
}
res = requests.post(url, json=data)
assert res.status_code == 200
pred = res.json()["generated_text"][0]
return pred
def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
assert url is not None
data = {
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
"stop": stop,
"n": n,
}
res = requests.post(url, json=data)
assert res.status_code == 200
if n == 1:
pred = res.json()["text"][0][len(prompt) :]
else:
pred = [x[len(prompt) :] for x in res.json()["text"]]
return pred
def call_generate_outlines(
    prompt, temperature, max_tokens, stop=[], regex=None, n=1, url=None
):
    assert url is not None
data = {
"prompt": prompt,
"temperature": temperature,
"max_tokens": max_tokens,
"stop": stop,
"regex": regex,
"n": n,
}
res = requests.post(url, json=data)
assert res.status_code == 200
if n == 1:
pred = res.json()["text"][0][len(prompt) :]
else:
pred = [x[len(prompt) :] for x in res.json()["text"]]
return pred
def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
assert url is not None
data = {
"text": prompt,
"sampling_params": {
"temperature": temperature,
"max_new_tokens": max_tokens,
"stop": stop,
},
}
res = requests.post(url, json=data)
assert res.status_code == 200
obj = res.json()
pred = obj["text"]
    return pred

def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None):
    raise NotImplementedError()
def call_generate_guidance(
prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
):
assert model is not None
from guidance import gen
rets = []
for _ in range(n):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
regex=regex,
)
)
rets.append(out["answer"])
    return rets if n > 1 else rets[0]

async def call_generate_lmql(
    prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
):
    assert model is not None
    import lmql

    if stop is not None:
@lmql.query(model=model)
async def program(question, max_tokens, stop):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
return ANSWER
'''
else:
@lmql.query(model=model)
async def program(question, max_tokens):
'''lmql
"""{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
return ANSWER
'''
tasks = [
program(
question=prompt,
temperature=temperature,
max_tokens=max_tokens,
stop=stop,
max_len=max_len,
**kwargs,
)
for _ in range(n)
]
rets = await asyncio.gather(*tasks)
    return rets if n > 1 else rets[0]

def call_select_lightllm(context, choices, url=None):
assert url is not None
scores = []
for i in range(len(choices)):
data = {
"inputs": context + choices[i],
"parameters": {
"max_new_tokens": 1,
},
}
res = requests.post(url, json=data)
assert res.status_code == 200
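        # No score is extracted from the response, so every choice gets 0 and
        # np.argmax simply returns the first choice.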
scores.append(0)
return np.argmax(scores)
def call_select_vllm(context, choices, url=None):
assert url is not None
scores = []
for i in range(len(choices)):
data = {
"prompt": context + choices[i],
"max_tokens": 1,
"prompt_logprobs": 1,
}
res = requests.post(url, json=data)
assert res.status_code == 200
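        # "prompt_score" is only present if vllm's api_server is patched as described
        # in the note below this function; otherwise it falls back to 0.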
scores.append(res.json().get("prompt_score", 0))
    return np.argmax(scores)

"""
Modify vllm/entrypoints/api_server.py
if final_output.prompt_logprobs is not None:
score = np.mean([prob[t_id] for t_id, prob in zip(final_output.prompt_token_ids[1:], final_output.prompt_logprobs[1:])])
ret["prompt_score"] = score
"""
def call_select_guidance(context, choices, model=None):
assert model is not None
from guidance import select
out = model + context + select(choices, name="answer")
    return choices.index(out["answer"])

async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
assert model is not None
import lmql
@lmql.query(model=model)
async def program(ctx, choices):
'''lmql
"""{ctx}[ANSWER]""" where ANSWER in set(choices)
return ANSWER
'''
answer = await program(
ctx=context, choices=choices, temperature=temperature, max_len=max_len
)
return choices.index(answer)
def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
parser.add_argument("--parallel", type=int, default=64)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
parser.add_argument(
"--backend",
type=str,
required=True,
choices=[
"vllm",
"outlines",
"lightllm",
"gserver",
"guidance",
"lmql",
"srt-raw",
"llama.cpp",
],
)
parser.add_argument("--n-ctx", type=int, default=4096)
parser.add_argument(
"--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
)
parser.add_argument("--result-file", type=str, default="result.jsonl")
args = parser.parse_args()
if args.port is None:
default_port = {
"vllm": 21000,
"outlines": 21000,
"lightllm": 22000,
"lmql": 23000,
"srt-raw": 30000,
"gserver": 9988,
}
args.port = default_port.get(args.backend, None)
return args
def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
parser.add_argument("--parallel", type=int, default=64)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--result-file", type=str, default="result.jsonl")
args = parser.parse_args()
return args
def select_sglang_backend(args: argparse.Namespace):
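    """Select the sglang frontend backend: a RuntimeEndpoint for srt* backends or an
    OpenAI backend for gpt-* model names."""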
if args.backend.startswith("srt"):
if args.backend == "srt-no-parallel":
global_config.enable_parallel_decoding = False
global_config.enable_parallel_encoding = False
backend = RuntimeEndpoint(f"{args.host}:{args.port}")
elif args.backend.startswith("gpt-"):
backend = OpenAI(args.backend)
else:
raise ValueError(f"Invalid backend: {args.backend}")
return backend
def _get_call_generate(args: argparse.Namespace):
if args.backend == "lightllm":
return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "vllm":
return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "srt-raw":
return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
elif args.backend == "gserver":
return partial(call_generate_gserver, url=f"{args.host}:{args.port}")
elif args.backend == "outlines":
return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
elif args.backend == "guidance":
from guidance import models
model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
call_generate = partial(call_generate_guidance, model=model)
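        # Run one short generation so the model is fully loaded before returning.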
call_generate("Hello,", 1.0, 8, ".")
return call_generate
elif args.backend == "lmql":
import lmql
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
return partial(call_generate_lmql, model=model)
else:
raise ValueError(f"Invalid backend: {args.backend}")
def _get_call_select(args: argparse.Namespace):
if args.backend == "lightllm":
return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "vllm":
return partial(call_select_vllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "guidance":
from guidance import models
model = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
call_select = partial(call_select_guidance, model=model)
call_select("Hello,", ["world", "earth"])
return call_select
elif args.backend == "lmql":
import lmql
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
return partial(call_select_lmql, model=model)
else:
raise ValueError(f"Invalid backend: {args.backend}")
def get_call_generate(args: argparse.Namespace):
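    """Return the backend's generate function, wrapped to log a traceback before re-raising."""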
call_generate = _get_call_generate(args)
def func(*args, **kwargs):
try:
return call_generate(*args, **kwargs)
except Exception:
print("Exception in call_generate:\n" + get_exception_traceback())
raise
return func
def get_call_select(args: argparse.Namespace):
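    """Return the backend's select function, wrapped to log a traceback before re-raising."""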
call_select = _get_call_select(args)
def func(*args, **kwargs):
try:
return call_select(*args, **kwargs)
except Exception:
print("Exception in call_select:\n" + get_exception_traceback())
raise
return func
def popen_launch_server(
model: str,
base_url: str,
timeout: float,
api_key: Optional[str] = None,
other_args: tuple = (),
env: Optional[dict] = None,
return_stdout_stderr: bool = False,
):
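    """Launch `python3 -m sglang.launch_server` as a subprocess and poll
    `{base_url}/v1/models` until it responds or `timeout` seconds pass."""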
_, host, port = base_url.split(":")
host = host[2:]
2024-08-01 21:20:17 -07:00
command = [
"python3",
"-m",
"sglang.launch_server",
"--model-path",
model,
"--host",
host,
"--port",
port,
*other_args,
]
if api_key:
command += ["--api-key", api_key]
if return_stdout_stderr:
process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env,
text=True,
)
else:
process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
start_time = time.time()
while time.time() - start_time < timeout:
try:
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {api_key}",
}
response = requests.get(f"{base_url}/v1/models", headers=headers)
if response.status_code == 200:
return process
except requests.RequestException:
pass
time.sleep(10)
raise TimeoutError("Server failed to start within the timeout period.")
def run_with_timeout(
func: Callable,
args: tuple = (),
kwargs: Optional[dict] = None,
    timeout: Optional[float] = None,
):
"""Run a function with timeout."""
ret_value = []
def _target_func():
ret_value.append(func(*args, **(kwargs or {})))
t = threading.Thread(target=_target_func)
t.start()
t.join(timeout=timeout)
if t.is_alive():
raise TimeoutError()
if not ret_value:
raise RuntimeError()
return ret_value[0]
def run_unittest_files(files: List[str], timeout_per_file: float):
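    """Run each test file as `python3 <file>` with a per-file timeout. Returns 0 if every
    file passes and -1 if a file times out; a failing file raises an AssertionError."""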
tic = time.time()
success = True
for filename in files:
global process
def run_one_file(filename):
filename = os.path.join(os.getcwd(), filename)
print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
process = subprocess.Popen(
["python3", filename], stdout=None, stderr=None, env=os.environ
)
process.wait()
return process.returncode
try:
ret_code = run_with_timeout(
run_one_file, args=(filename,), timeout=timeout_per_file
)
assert ret_code == 0
except TimeoutError:
kill_child_process(process.pid)
time.sleep(5)
print(
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
flush=True,
)
success = False
break
if success:
print(f"Success. Time elapsed: {time.time() - tic:.2f}s", flush=True)
else:
print(f"Fail. Time elapsed: {time.time() - tic:.2f}s", flush=True)
    return 0 if success else -1

def get_similarities(vec1, vec2):
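    """Return the cosine similarity between two vectors."""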
return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)