Files
sglang/python/sglang/test/test_utils.py
2024-09-15 08:52:18 -07:00

587 lines
17 KiB
Python

"""Common utilities for testing and benchmarking"""
import argparse
import asyncio
import os
import subprocess
import threading
import time
from functools import partial
from types import SimpleNamespace
from typing import Callable, List, Optional
import numpy as np
import requests
import torch
import torch.nn.functional as F
from sglang.bench_serving import run_benchmark
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.utils import kill_child_process
from sglang.utils import get_exception_traceback
# Default model identifiers used by the tests (regular and MoE variants).
DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Passed as the `timeout` of popen_launch_server by run_bench_serving; that
# function compares it against time.time() differences, i.e. seconds.
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
# Comma-separated model lists for the nightly evaluation.
# NOTE(review): "TP1"/"TP2" presumably refer to the tensor-parallel size and
# "FP8" to FP8-quantized checkpoints -- confirm against the nightly test runner.
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Meta-Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Meta-Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8"
def is_in_ci():
    """Tell whether the current process runs on a CI runner.

    The answer comes from the ``SGLANG_IS_IN_CI`` environment variable: only
    the exact string ``"true"`` counts as CI, and an unset variable is treated
    like ``"false"``.
    """
    flag = os.environ.get("SGLANG_IS_IN_CI", "false")
    return flag == "true"
# CI runners and local runs use different port/URL pairs for the SRT test
# runner and the default test server.
DEFAULT_PORT_FOR_SRT_TEST_RUNNER, DEFAULT_URL_FOR_TEST = (
    (5157, "http://127.0.0.1:6157")
    if is_in_ci()
    else (1157, "http://127.0.0.1:2157")
)
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
    """Request one completion from a LightLLM-style generate endpoint.

    The prompt and sampling parameters are posted to ``url`` as JSON and the
    first entry of the response's ``generated_text`` field is returned.

    Raises:
        AssertionError: if ``url`` is ``None`` or the server does not reply
            with HTTP 200.
    """
    assert url is not None
    sampling = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "stop_sequences": stop,
    }
    response = requests.post(url, json={"inputs": prompt, "parameters": sampling})
    assert response.status_code == 200
    return response.json()["generated_text"][0]
def call_generate_vllm(prompt, temperature, max_tokens, stop=None, n=1, url=None):
    """Request ``n`` completions from a vLLM-style generate endpoint.

    Each text in the response's ``text`` field has its first ``len(prompt)``
    characters removed before being returned.

    Returns:
        A single string when ``n == 1``, otherwise a list with one string per
        returned text.

    Raises:
        AssertionError: if ``url`` is ``None`` or the server does not reply
            with HTTP 200.
    """
    assert url is not None
    payload = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
        "n": n,
    }
    response = requests.post(url, json=payload)
    assert response.status_code == 200

    texts = response.json()["text"]
    offset = len(prompt)
    if n != 1:
        return [text[offset:] for text in texts]
    return texts[0][offset:]
def call_generate_outlines(
    prompt, temperature, max_tokens, stop=None, regex=None, n=1, url=None
):
    """Request ``n`` (optionally regex-constrained) completions from an Outlines-style endpoint.

    Args:
        prompt: Text to complete.
        temperature: Sampling temperature forwarded to the server.
        max_tokens: Value of the ``max_tokens`` request field.
        stop: Stop sequences. ``None`` (the default) is sent as an empty list,
            which avoids sharing one mutable default list between calls.
        regex: Optional regular expression forwarded in the ``regex`` field.
        n: Number of completions to request.
        url: Full URL of the generate endpoint (required).

    Returns:
        A single string when ``n == 1``, otherwise a list of strings. Each
        returned text has its first ``len(prompt)`` characters removed.

    Raises:
        AssertionError: if ``url`` is ``None`` or the server does not reply
            with HTTP 200.
    """
    assert url is not None
    if stop is None:
        stop = []
    data = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
        "regex": regex,
        "n": n,
    }
    res = requests.post(url, json=data)
    assert res.status_code == 200

    texts = res.json()["text"]
    if n == 1:
        return texts[0][len(prompt) :]
    return [x[len(prompt) :] for x in texts]
def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
    """Request one completion through the raw SRT generate HTTP API.

    Posts the prompt under ``text`` together with a ``sampling_params`` object
    and returns the ``text`` field of the JSON response.

    Raises:
        AssertionError: if ``url`` is ``None`` or the server does not reply
            with HTTP 200.
    """
    assert url is not None
    sampling_params = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "stop": stop,
    }
    response = requests.post(
        url, json={"text": prompt, "sampling_params": sampling_params}
    )
    assert response.status_code == 200
    return response.json()["text"]
def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None):
    """Placeholder for the ``gserver`` backend; always raises ``NotImplementedError``."""
    raise NotImplementedError()
def call_generate_guidance(
    prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
):
    """Generate ``n`` completions of ``prompt`` with a guidance model.

    Each completion appends ``prompt`` and a ``gen`` call (named ``"answer"``)
    to ``model`` and reads the ``"answer"`` entry of the result. ``guidance``
    is imported lazily so it is only required when this backend is used.

    Returns:
        A list of answers when ``n > 1``, otherwise the first answer.

    Raises:
        AssertionError: if ``model`` is ``None``.
    """
    assert model is not None
    from guidance import gen

    def sample_once():
        completed = (
            model
            + prompt
            + gen(
                name="answer",
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop,
                regex=regex,
            )
        )
        return completed["answer"]

    answers = [sample_once() for _ in range(n)]
    if n > 1:
        return answers
    return answers[0]
async def call_generate_lmql(
    prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
):
    """Generate ``n`` completions of ``prompt`` with an LMQL query.

    Two query variants are defined: one that also constrains the answer with
    ``STOPS_AT(ANSWER, stop)`` (used when ``stop`` is given) and one that only
    limits the token count. The ``n`` query invocations are awaited together
    with ``asyncio.gather``. ``lmql`` is imported lazily so it is only required
    when this backend is used.

    Args:
        prompt: Text inserted as ``{question}`` in the query.
        temperature: Forwarded to the query call.
        max_tokens: Upper bound used in the ``len(TOKENS(ANSWER))`` constraint.
        stop: Optional stop value for ``STOPS_AT``.
        n: Number of completions.
        max_len: Forwarded to the query call.
        model: LMQL model handle (required).
        **kwargs: Extra keyword arguments forwarded to the query call.

    Returns:
        A list of answers when ``n > 1``, otherwise the first answer.

    Raises:
        AssertionError: if ``model`` is ``None``.
    """
    assert model is not None
    import lmql

    # Identity comparison is the correct way to test for None (PEP 8).
    if stop is not None:

        @lmql.query(model=model)
        async def program(question, max_tokens, stop):
            '''lmql
            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
            return ANSWER
            '''

    else:

        @lmql.query(model=model)
        async def program(question, max_tokens):
            '''lmql
            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
            return ANSWER
            '''

    # NOTE(review): ``stop`` is forwarded to both variants, including the one
    # whose signature does not declare it -- confirm that lmql tolerates the
    # extra keyword argument.
    tasks = [
        program(
            question=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            max_len=max_len,
            **kwargs,
        )
        for _ in range(n)
    ]
    rets = await asyncio.gather(*tasks)
    return rets if n > 1 else rets[0]
def call_select_lightllm(context, choices, url=None):
    """Pick one of ``choices`` as continuation of ``context`` via a LightLLM-style endpoint.

    One request is sent per choice, but the response body is not inspected:
    every choice receives the constant score ``0``, so ``np.argmax`` returns
    the index of the first choice.

    Raises:
        AssertionError: if ``url`` is ``None`` or a request does not return
            HTTP 200.
    """
    assert url is not None
    scores = []
    for choice in choices:
        payload = {
            "inputs": context + choice,
            "parameters": {
                "max_new_tokens": 1,
            },
        }
        response = requests.post(url, json=payload)
        assert response.status_code == 200
        # Placeholder score: nothing is read back from the response.
        scores.append(0)
    return np.argmax(scores)
def call_select_vllm(context, choices, url=None):
    """Pick the choice with the highest ``prompt_score`` reported by a vLLM-style endpoint.

    One request is sent per choice with ``context + choice`` as the prompt.
    The score is read from the ``prompt_score`` field of the JSON response and
    defaults to ``0`` when the field is absent.

    Returns:
        The index (as returned by ``np.argmax``) of the best-scoring choice.

    Raises:
        AssertionError: if ``url`` is ``None`` or a request does not return
            HTTP 200.
    """
    assert url is not None
    scores = []
    for choice in choices:
        payload = {
            "prompt": context + choice,
            "max_tokens": 1,
            "prompt_logprobs": 1,
        }
        response = requests.post(url, json=payload)
        assert response.status_code == 200
        score = response.json().get("prompt_score", 0)
        scores.append(score)
    return np.argmax(scores)
"""
Modify vllm/entrypoints/api_server.py
if final_output.prompt_logprobs is not None:
score = np.mean([prob[t_id] for t_id, prob in zip(final_output.prompt_token_ids[1:], final_output.prompt_logprobs[1:])])
ret["prompt_score"] = score
"""
def call_select_guidance(context, choices, model=None):
    """Let a guidance model choose among ``choices`` and return the chosen index.

    ``guidance`` is imported lazily so it is only required when this backend
    is used.

    Raises:
        AssertionError: if ``model`` is ``None``.
    """
    assert model is not None
    from guidance import select

    state = model + context + select(choices, name="answer")
    picked = state["answer"]
    return choices.index(picked)
async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
    """Let an LMQL query choose among ``choices`` and return the chosen index.

    The query constrains ``ANSWER`` to be a member of ``set(choices)``.
    ``lmql`` is imported lazily so it is only required when this backend is
    used.

    Raises:
        AssertionError: if ``model`` is ``None``.
    """
    assert model is not None
    import lmql

    @lmql.query(model=model)
    async def program(ctx, choices):
        '''lmql
        """{ctx}[ANSWER]""" where ANSWER in set(choices)
        return ANSWER
        '''

    query_call = program(
        ctx=context, choices=choices, temperature=temperature, max_len=max_len
    )
    picked = await query_call
    return choices.index(picked)
def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
    """Register the CLI options shared by the non-sglang baselines and parse them.

    ``--backend`` is mandatory. When ``--port`` is omitted, it is filled in
    from a per-backend table; backends without an entry keep ``None``.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    supported_backends = [
        "vllm",
        "outlines",
        "lightllm",
        "gserver",
        "guidance",
        "lmql",
        "srt-raw",
        "llama.cpp",
    ]
    fallback_ports = {
        "vllm": 21000,
        "outlines": 21000,
        "lightllm": 22000,
        "lmql": 23000,
        "srt-raw": 30000,
        "gserver": 9988,
    }

    parser.add_argument("--parallel", type=int, default=64)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument(
        "--backend", type=str, required=True, choices=supported_backends
    )
    parser.add_argument("--n-ctx", type=int, default=4096)
    parser.add_argument(
        "--model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
    )
    parser.add_argument("--result-file", type=str, default="result.jsonl")

    parsed = parser.parse_args()
    if parsed.port is None:
        parsed.port = fallback_ports.get(parsed.backend)
    return parsed
def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
    """Register the CLI options shared by the sglang benchmarks and parse them.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    common_options = (
        ("--parallel", int, 64),
        ("--host", str, "http://127.0.0.1"),
        ("--port", int, 30000),
        ("--backend", str, "srt"),
        ("--result-file", str, "result.jsonl"),
    )
    for flag, value_type, default in common_options:
        parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args()
def select_sglang_backend(args: argparse.Namespace):
    """Create the sglang backend object named by ``args.backend``.

    * Names starting with ``"srt"`` yield a ``RuntimeEndpoint`` at
      ``{args.host}:{args.port}``; the exact name ``"srt-no-parallel"``
      additionally sets ``global_config.enable_parallel_encoding`` to False.
    * Names starting with ``"gpt-"`` yield an ``OpenAI`` backend built from
      that name.

    Raises:
        ValueError: for any other backend name.
    """
    name = args.backend
    if name.startswith("srt"):
        if name == "srt-no-parallel":
            global_config.enable_parallel_encoding = False
        return RuntimeEndpoint(f"{args.host}:{args.port}")
    if name.startswith("gpt-"):
        return OpenAI(name)
    raise ValueError(f"Invalid backend: {args.backend}")
def _get_call_generate(args: argparse.Namespace):
    """Build the generate callable for the baseline backend in ``args.backend``.

    HTTP backends get their endpoint URL bound with ``functools.partial``;
    the ``guidance`` and ``lmql`` backends construct a model object first and
    bind it as ``model``.

    Raises:
        ValueError: if the backend name has no generate implementation here.
    """
    backend = args.backend

    def generate_url():
        return f"{args.host}:{args.port}/generate"

    if backend == "lightllm":
        return partial(call_generate_lightllm, url=generate_url())
    if backend == "vllm":
        return partial(call_generate_vllm, url=generate_url())
    if backend == "srt-raw":
        return partial(call_generate_srt_raw, url=generate_url())
    if backend == "gserver":
        return partial(call_generate_gserver, url=f"{args.host}:{args.port}")
    if backend == "outlines":
        return partial(call_generate_outlines, url=generate_url())
    if backend == "guidance":
        from guidance import models

        llama = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
        generate = partial(call_generate_guidance, model=llama)
        # One throwaway generation whose result is discarded (presumably a warm-up).
        generate("Hello,", 1.0, 8, ".")
        return generate
    if backend == "lmql":
        import lmql

        lmql_model = lmql.model(
            args.model_path, endpoint=f"{args.host}:{args.port}"
        )
        return partial(call_generate_lmql, model=lmql_model)
    raise ValueError(f"Invalid backend: {args.backend}")
def _get_call_select(args: argparse.Namespace):
    """Build the select callable for the baseline backend in ``args.backend``.

    HTTP backends get their endpoint URL bound with ``functools.partial``;
    the ``guidance`` and ``lmql`` backends construct a model object first and
    bind it as ``model``.

    Raises:
        ValueError: if the backend name has no select implementation here.
    """
    backend = args.backend

    if backend == "lightllm":
        return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
    if backend == "vllm":
        return partial(call_select_vllm, url=f"{args.host}:{args.port}/generate")
    if backend == "guidance":
        from guidance import models

        llama = models.LlamaCpp(args.model_path, n_gpu_layers=-1, n_ctx=args.n_ctx)
        select_fn = partial(call_select_guidance, model=llama)
        # One throwaway selection whose result is discarded (presumably a warm-up).
        select_fn("Hello,", ["world", "earth"])
        return select_fn
    if backend == "lmql":
        import lmql

        lmql_model = lmql.model(
            args.model_path, endpoint=f"{args.host}:{args.port}"
        )
        return partial(call_select_lmql, model=lmql_model)
    raise ValueError(f"Invalid backend: {args.backend}")
def get_call_generate(args: argparse.Namespace):
    """Return the backend's generate callable wrapped with traceback printing.

    The wrapper forwards all arguments unchanged. When the underlying call
    raises, the traceback is printed and the same exception is re-raised.
    """
    call_generate = _get_call_generate(args)

    def wrapper(*call_args, **call_kwargs):
        try:
            result = call_generate(*call_args, **call_kwargs)
        except Exception:
            print("Exception in call_generate:\n" + get_exception_traceback())
            raise
        return result

    return wrapper
def get_call_select(args: argparse.Namespace):
    """Return the backend's select callable wrapped with traceback printing.

    The wrapper forwards all arguments unchanged. When the underlying call
    raises, the traceback is printed and the same exception is re-raised.
    """
    call_select = _get_call_select(args)

    def wrapper(*call_args, **call_kwargs):
        try:
            result = call_select(*call_args, **call_kwargs)
        except Exception:
            print("Exception in call_select:\n" + get_exception_traceback())
            raise
        return result

    return wrapper
def popen_launch_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: Optional[str] = None,
    other_args: tuple = (),
    env: Optional[dict] = None,
    return_stdout_stderr: bool = False,
):
    """Start ``sglang.launch_server`` in a subprocess and wait until it answers.

    Readiness is detected by polling ``{base_url}/v1/models`` every 10 seconds
    until it returns HTTP 200.

    Args:
        model: Value passed to ``--model-path``.
        base_url: URL of the form ``scheme://host:port``; host and port are
            extracted by splitting on ``":"``.
        timeout: Maximum time in seconds to wait for the server to answer.
        api_key: Optional key, passed as ``--api-key`` and used in the
            ``Authorization`` header of the readiness probe.
        other_args: Extra command-line arguments appended to the launch command.
        env: Environment for the child process (``None`` inherits the current one).
        return_stdout_stderr: When True, stdout and stderr of the child are
            captured through text-mode pipes instead of being inherited.

    Returns:
        The ``subprocess.Popen`` handle of the running server.

    Raises:
        TimeoutError: if the server does not answer within ``timeout``. The
            launched process is killed first so it is not leaked.
    """
    _, host, port = base_url.split(":")
    host = host[2:]  # strip the "//" that follows the scheme

    command = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        *other_args,
    ]
    if api_key:
        command += ["--api-key", api_key]

    if return_stdout_stderr:
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            text=True,
        )
    else:
        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

    headers = {
        "Content-Type": "application/json; charset=utf-8",
        "Authorization": f"Bearer {api_key}",
    }
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            # Bound every probe: without a request timeout a server that
            # accepts the connection but never replies would block forever
            # and defeat the overall launch timeout.
            response = requests.get(
                f"{base_url}/v1/models", headers=headers, timeout=10
            )
            if response.status_code == 200:
                return process
        except requests.RequestException:
            # Server not reachable yet; retry after the sleep below.
            pass
        time.sleep(10)

    # The caller never receives the handle on failure, so clean up here.
    kill_child_process(process.pid)
    raise TimeoutError("Server failed to start within the timeout period.")
def run_with_timeout(
    func: Callable,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    timeout: Optional[float] = None,
):
    """Run a function with timeout.

    ``func(*args, **kwargs)`` is executed in a separate thread and this call
    waits at most ``timeout`` seconds for it (``None`` waits indefinitely).

    Returns:
        Whatever ``func`` returned.

    Raises:
        TimeoutError: if ``func`` is still running after ``timeout`` seconds.
            The worker thread is not stopped; it keeps running in the
            background.
        RuntimeError: if ``func`` did not produce a return value. When it
            raised an ``Exception``, that exception is attached as
            ``__cause__`` so the failure reason is not lost.
    """
    ret_value = []
    errors = []

    def _target_func():
        try:
            ret_value.append(func(*args, **(kwargs or {})))
        except Exception as exc:
            # Hand the failure over to the waiting thread instead of losing it.
            errors.append(exc)

    t = threading.Thread(target=_target_func)
    t.start()
    t.join(timeout=timeout)
    if t.is_alive():
        raise TimeoutError()

    if errors:
        raise RuntimeError(f"{func!r} raised {errors[0]!r}") from errors[0]
    if not ret_value:
        raise RuntimeError()

    return ret_value[0]
def run_unittest_files(files: List[str], timeout_per_file: float):
    """Run each test file with ``python3`` in its own subprocess, one after another.

    Args:
        files: Paths of the test files, resolved against the current working
            directory.
        timeout_per_file: Maximum number of seconds a single file may run.

    Returns:
        ``0`` when every file finished with return code 0, ``-1`` when a file
        timed out (the remaining files are then skipped).

    Raises:
        AssertionError: if a file exits with a non-zero return code.
    """
    tic = time.time()
    success = True

    for filename in files:
        # Shared with run_one_file through ``nonlocal`` so the timeout handler
        # below can reach the child. The previous ``global process`` did not
        # bind the variable assigned inside the nested function, which made
        # the handler fail with a NameError.
        process = None

        def run_one_file(filename):
            nonlocal process

            filename = os.path.join(os.getcwd(), filename)
            print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
            process = subprocess.Popen(
                ["python3", filename], stdout=None, stderr=None, env=os.environ
            )
            process.wait()
            return process.returncode

        try:
            ret_code = run_with_timeout(
                run_one_file, args=(filename,), timeout=timeout_per_file
            )
            assert ret_code == 0
        except TimeoutError:
            if process is not None:
                kill_child_process(process.pid)
            time.sleep(5)
            print(
                f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
                flush=True,
            )
            success = False
            break

    if success:
        print(f"Success. Time elapsed: {time.time() - tic:.2f}s", flush=True)
    else:
        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s", flush=True)

    return 0 if success else -1
def get_similarities(vec1, vec2):
    """Return the cosine similarity of ``vec1`` and ``vec2`` along dimension 0.

    Both inputs are converted with ``torch.tensor`` first; the result is a
    torch tensor.
    """
    first = torch.tensor(vec1)
    second = torch.tensor(vec2)
    return F.cosine_similarity(first, second, dim=0)
def run_bench_serving(model, num_prompts, request_rate, other_server_args):
    """Launch a server for ``model`` and run the serving benchmark against it.

    The server is started at ``DEFAULT_URL_FOR_TEST`` and is always killed
    once the benchmark call returns or raises.

    Returns:
        The result object of ``run_benchmark``.

    Raises:
        AssertionError: if the ``"completed"`` entry of the result differs
            from ``num_prompts``.
    """
    # Bring up the server under test.
    base_url = DEFAULT_URL_FOR_TEST
    server = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_server_args,
    )

    # Settings handed to the benchmark entry point.
    bench_settings = dict(
        backend="sglang",
        base_url=base_url,
        host=None,
        port=None,
        dataset_name="random",
        dataset_path="",
        model=None,
        tokenizer=None,
        num_prompts=num_prompts,
        sharegpt_output_len=None,
        random_input_len=4096,
        random_output_len=2048,
        random_range_ratio=0.0,
        request_rate=request_rate,
        multi=None,
        seed=0,
        output_file=None,
        disable_tqdm=False,
        disable_stream=False,
        disable_ignore_eos=False,
        extra_request_body=None,
    )
    bench_args = SimpleNamespace(**bench_settings)

    try:
        result = run_benchmark(bench_args)
    finally:
        kill_child_process(server.pid)

    assert result["completed"] == num_prompts
    return result
def run_bench_latency(model, other_args):
    """Run ``sglang.bench_latency`` for ``model`` and return the parsed throughput.

    The benchmark runs in a subprocess with batch size 1, input 128 and
    output 8. Its stdout and stderr are printed, and the throughput is taken
    from the second-to-last space-separated field of the third line from the
    end of stdout. The child process is killed afterwards in every case.
    """
    command = [
        "python3",
        "-m",
        "sglang.bench_latency",
        "--model-path",
        model,
        "--batch-size",
        "1",
        "--input",
        "128",
        "--output",
        "8",
    ]
    command.extend(other_args)

    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    try:
        raw_out, raw_err = process.communicate()
        output = raw_out.decode()
        error = raw_err.decode()
        print(f"Output: {output}", flush=True)
        print(f"Error: {error}", flush=True)

        throughput_line = output.split("\n")[-3]
        output_throughput = float(throughput_line.split(" ")[-2])
    finally:
        kill_child_process(process.pid)

    return output_throughput