Improve the structure of CI (#911)
19  .github/workflows/unit-test.yml  (vendored)
@@ -37,23 +37,12 @@ jobs:
           pip install --upgrade transformers
           pip install accelerate

-      - name: Test Frontend Language with SRT Backend
+      - name: Test Frontend Language
         run: |
           cd test/lang
-          python3 test_srt_backend.py
+          python3 run_suite.py --suite minimal

-      - name: Test OpenAI API Server
+      - name: Test Backend Runtime
         run: |
           cd test/srt
-          python3 test_openai_server.py
-
-      - name: Test Accuracy
-        run: |
-          cd test/srt
-          python3 test_eval_accuracy.py
-          python3 models/test_causal_models.py
-
-      - name: Test Frontend Language with OpenAI Backend
-        run: |
-          cd test/lang
-          python3 test_openai_backend.py
+          python3 run_suite.py --suite minimal
@@ -1,102 +0,0 @@
-# SRT Unit Tests
-
-### Latency Alignment
-Make sure your changes do not slow down the following benchmarks
-```
-# single gpu
-python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256
-python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 1 --input-len 512 --output-len 256
-
-# multiple gpu
-python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 32 --input-len 8192 --output-len 1
-python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 1 --input-len 8100 --output-len 32
-
-# moe model
-python -m sglang.bench_latency --model-path databricks/dbrx-base --tp 8 --mem-fraction-static 0.6 --batch 4 --input-len 1024 --output-len 32
-```
-
-### High-level API
-
-```
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
-```
-
-```
-cd test/lang
-python3 test_srt_backend.py
-```
-
-### Performance
-
-#### MMLU
-```
-cd benchmark/mmlu
-```
-Follow README.md to download the data.
-
-```
-python3 bench_sglang.py --nsub 3
-
-# Expected performance on A10G
-# Total latency: 8.200
-# Average accuracy: 0.413
-```
-
-#### GSM-8K
-```
-cd benchmark/gsm8k
-```
-Follow README.md to download the data.
-
-```
-python3 bench_sglang.py --num-q 200
-
-# Expected performance on A10G
-# Latency: 32.103
-# Accuracy: 0.250
-```
-
-#### More
-Please also test `benchmark/hellaswag`, `benchmark/latency_throughput`.
-
-### More Models
-
-#### LLaVA
-
-```
-python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
-```
-
-```
-cd benchmark/llava_bench
-python3 bench_sglang.py
-
-# Expected performance on A10G
-# Latency: 50.031
-```
-
-## SGLang Unit Tests
-```
-export ANTHROPIC_API_KEY=
-export OPENAI_API_KEY=
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
-```
-
-```
-cd test/lang
-python3 run_all.py
-```
-
-## OpenAI API server
-```
-cd test/srt
-python test_openai_server.py
-```
-
-## Code Formatting
-```
-pip3 install pre-commit
-cd sglang
-pre-commit install
-pre-commit run --all-files
-```
@@ -20,8 +20,10 @@ dependencies = [
 ]

 [project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
-       "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
+srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "jsonlines",
+       "packaging", "pillow", "psutil", "pydantic", "python-multipart",
+       "torch", "uvicorn", "uvloop", "zmq",
+       "vllm==0.5.3.post1", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
@@ -10,7 +10,6 @@ import time

 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
-    download_dataset,
     make_report,
     set_ulimit,
 )
@@ -27,14 +26,26 @@ def run_eval(args):
     if args.eval_name == "mmlu":
         from sglang.test.simple_eval_mmlu import MMLUEval

-        dataset_path = "mmlu.csv"
-        if not os.path.exists(dataset_path):
-            download_dataset(
-                dataset_path,
-                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
-            )
-        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+        filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
+        eval_obj = MMLUEval(filename, args.num_examples, args.num_threads)
+    elif args.eval_name == "math":
+        from sglang.test.simple_eval_math import MathEval
+
+        equality_checker = ChatCompletionSampler(model="gpt-4-turbo")
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
+        )
+        eval_obj = MathEval(
+            filename, equality_checker, args.num_examples, args.num_threads
+        )
+    elif args.eval_name == "gpqa":
+        from sglang.test.simple_eval_gpqa import GPQAEval
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
+        )
+        eval_obj = GPQAEval(filename, args.num_examples, args.num_threads)
     elif args.eval_name == "humaneval":
         from sglang.test.simple_eval_humaneval import HumanEval
@@ -97,7 +108,7 @@ if __name__ == "__main__":
     )
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
-    parser.add_argument("--num-threads", type=int, default=64)
+    parser.add_argument("--num-threads", type=int, default=512)
     set_ulimit()
     args = parser.parse_args()
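For reviewers who want to exercise the reworked eval entry point by hand: the new test files below drive `run_eval` with a `SimpleNamespace` standing in for parsed CLI args, and the same pattern works interactively. A minimal sketch, assuming a server is already listening on the port used throughout this commit (the model path here is illustrative):

```
from types import SimpleNamespace

from sglang.test.run_eval import run_eval

# Assumes a running server, e.g.:
#   python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
args = SimpleNamespace(
    base_url="http://localhost:30000",
    model="meta-llama/Llama-2-7b-chat-hf",  # illustrative model path
    eval_name="mmlu",
    num_examples=20,
    num_threads=20,
)
metrics = run_eval(args)
print(metrics["score"])
```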
92  python/sglang/test/simple_eval_gpqa.py  (new file)
@@ -0,0 +1,92 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+GPQA: A Graduate-Level Google-Proof Q&A Benchmark
+David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
+https://arxiv.org/abs/2311.12022
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN_MULTICHOICE,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    MessageList,
+    SamplerBase,
+    SingleEvalResult,
+    format_multichoice_question,
+)
+
+
+class GPQAEval(Eval):
+    def __init__(
+        self,
+        filename: str,
+        num_examples: int | None,
+        num_threads: int,
+        n_repeats: int = 1,
+    ):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        rng = random.Random(0)
+        if num_examples:
+            assert n_repeats == 1, "n_repeats only supported for num_examples"
+            examples = rng.sample(examples, num_examples)
+        examples = examples * n_repeats
+        examples = [
+            example | {"permutation": rng.sample(range(4), 4)} for example in examples
+        ]
+        self.examples = examples
+        self.n_repeats = n_repeats
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            choices = [
+                row["Correct Answer"],
+                row["Incorrect Answer 1"],
+                row["Incorrect Answer 2"],
+                row["Incorrect Answer 3"],
+            ]
+            choices = [choices[i] for i in row["permutation"]]
+            correct_index = choices.index(row["Correct Answer"])
+            correct_answer = "ABCD"[correct_index]
+            choices_dict = dict(
+                A=choices[0],
+                B=choices[1],
+                C=choices[2],
+                D=choices[3],
+                Question=row["Question"],
+            )
+            prompt_messages = [
+                sampler._pack_message(
+                    content=format_multichoice_question(choices_dict), role="user"
+                )
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = 1.0 if extracted_answer == correct_answer else 0.0
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=correct_answer,
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(
+                html=html,
+                score=score,
+                convo=convo,
+                metrics={"chars": len(response_text)},
+            )
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
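A note on the `permutation` field above: each example stores its own shuffle of the four choices, so the correct letter is derived after shuffling rather than read from the dataset. A self-contained sketch of the same idea (the values are made up):

```
import random

rng = random.Random(0)
correct = "Paris"
choices = [correct, "Lyon", "Nice", "Lille"]  # correct answer stored first

permutation = rng.sample(range(4), 4)         # a random ordering of indices 0-3
shuffled = [choices[i] for i in permutation]  # order presented to the model
correct_letter = "ABCD"[shuffled.index(correct)]
print(shuffled, correct_letter)
```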
72  python/sglang/test/simple_eval_math.py  (new file)
@@ -0,0 +1,72 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+Measuring Mathematical Problem Solving With the MATH Dataset
+Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
+https://arxiv.org/abs/2103.03874
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+    check_equality,
+)
+
+QUERY_TEMPLATE = """
+Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
+
+{Question}
+
+Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
+""".strip()
+
+
+class MathEval(Eval):
+    def __init__(
+        self,
+        filename: str,
+        equality_checker: SamplerBase,
+        num_examples: int | None,
+        num_threads: int,
+    ):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        if num_examples:
+            examples = random.Random(0).sample(examples, num_examples)
+        self.examples = examples
+        self.equality_checker = equality_checker
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            prompt_messages = [
+                sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = float(
+                check_equality(self.equality_checker, row["Answer"], extracted_answer)
+            )
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=row["Answer"],
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(html=html, score=score, convo=convo)
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
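Unlike the multiple-choice evals, MATH grades by semantic equality: the extracted answer goes to the LLM-based `check_equality` judge (a `ChatCompletionSampler` in `run_eval`), since `1/2` and `0.5` should both score. The extraction step itself is a plain regex over the final "Answer:" line. A sketch of the expected shape; the real `ANSWER_PATTERN` lives in `simple_eval_common` and the pattern below is an assumption:

```
import re

# Assumed shape of ANSWER_PATTERN: capture whatever follows "Answer:".
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"

response_text = "Step 1: ...\nStep 2: ...\nAnswer: 42"
match = re.search(ANSWER_PATTERN, response_text)
extracted_answer = match.group(1) if match else None
print(extracted_answer)  # 42
```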
@@ -1,9 +1,14 @@
 """Common utilities for testing and benchmarking"""

+import argparse
 import asyncio
+import multiprocessing
 import subprocess
+import threading
 import time
+import unittest
 from functools import partial
+from typing import Callable, Optional

 import numpy as np
 import requests
@@ -247,7 +252,7 @@ async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=
     return choices.index(answer)


-def add_common_other_args_and_parse(parser):
+def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=None)
@@ -286,7 +291,7 @@ def add_common_other_args_and_parse(parser):
     return args


-def add_common_sglang_args_and_parse(parser):
+def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
@@ -296,7 +301,7 @@ def add_common_sglang_args_and_parse(parser):
     return args


-def select_sglang_backend(args):
+def select_sglang_backend(args: argparse.Namespace):
     if args.backend.startswith("srt"):
         if args.backend == "srt-no-parallel":
             global_config.enable_parallel_decoding = False
@@ -309,7 +314,7 @@ def select_sglang_backend(args):
     return backend


-def _get_call_generate(args):
+def _get_call_generate(args: argparse.Namespace):
     if args.backend == "lightllm":
         return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "vllm":
@@ -336,7 +341,7 @@ def _get_call_generate(args):
         raise ValueError(f"Invalid backend: {args.backend}")


-def _get_call_select(args):
+def _get_call_select(args: argparse.Namespace):
     if args.backend == "lightllm":
         return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "vllm":
@@ -359,7 +364,7 @@ def _get_call_select(args):
         raise ValueError(f"Invalid backend: {args.backend}")


-def get_call_generate(args):
+def get_call_generate(args: argparse.Namespace):
     call_generate = _get_call_generate(args)

     def func(*args, **kwargs):
@@ -372,7 +377,7 @@ def get_call_generate(args):
     return func


-def get_call_select(args):
+def get_call_select(args: argparse.Namespace):
     call_select = _get_call_select(args)

     def func(*args, **kwargs):
@@ -385,7 +390,12 @@ def get_call_select(args):
     return func


-def popen_launch_server(model, port, timeout, *args):
+def popen_launch_server(
+    model: str, base_url: str, timeout: float, other_args: tuple = ()
+):
+    _, host, port = base_url.split(":")
+    host = host[2:]
+
     command = [
         "python3",
         "-m",
@@ -393,21 +403,81 @@ def popen_launch_server(model, port, timeout, *args):
         "--model-path",
         model,
         "--host",
-        "localhost",
+        host,
         "--port",
-        str(port),
-        *args,
+        port,
+        *other_args,
     ]
     process = subprocess.Popen(command, stdout=None, stderr=None)
-    base_url = f"http://localhost:{port}/v1"

     start_time = time.time()
     while time.time() - start_time < timeout:
         try:
-            response = requests.get(f"{base_url}/models")
+            response = requests.get(f"{base_url}/v1/models")
             if response.status_code == 200:
                 return process
         except requests.RequestException:
             pass
         time.sleep(10)
     raise TimeoutError("Server failed to start within the timeout period.")


+def run_with_timeout(
+    func: Callable,
+    args: tuple = (),
+    kwargs: Optional[dict] = None,
+    timeout: float = None,
+):
+    """Run a function with timeout."""
+    ret_value = []
+
+    def _target_func():
+        ret_value.append(func(*args, **(kwargs or {})))
+
+    t = threading.Thread(target=_target_func)
+    t.start()
+    t.join(timeout=timeout)
+    if t.is_alive():
+        raise TimeoutError()
+
+    if not ret_value:
+        raise RuntimeError()
+
+    return ret_value[0]
+
+
+def run_unittest_files(files: list[str], timeout_per_file: float):
+    tic = time.time()
+    success = True
+
+    for filename in files:
+
+        def func():
+            print(f"\n\nRun {filename}\n\n")
+            ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
+
+        p = multiprocessing.Process(target=func)
+
+        def run_one_file():
+            p.start()
+            p.join()
+
+        try:
+            run_with_timeout(run_one_file, timeout=timeout_per_file)
+            if p.exitcode != 0:
+                success = False
+                break
+        except TimeoutError:
+            p.terminate()
+            time.sleep(5)
+            print(
+                f"\nTimeout after {timeout_per_file} seconds when running {filename}\n"
+            )
+            return False
+
+    if success:
+        print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
+    else:
+        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
+
+    return 0 if success else -1
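The reworked `popen_launch_server` now takes the full base URL instead of a bare port, and returns only once `GET {base_url}/v1/models` responds. A sketch of how the new helpers compose, mirroring the test files later in this commit (model path and port are illustrative):

```
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import popen_launch_server

base_url = "http://localhost:30000"
process = popen_launch_server(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",  # illustrative model path
    base_url,
    timeout=300,
    other_args=("--chunked-prefill-size", "32"),
)
try:
    ...  # hit {base_url}/v1/... endpoints here
finally:
    kill_child_process(process.pid)
```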
@@ -12,6 +12,7 @@ import urllib.request
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
+from typing import Union

 import numpy as np
 import requests
@@ -25,7 +26,7 @@ def get_exception_traceback():
     return err_str


-def is_same_type(values):
+def is_same_type(values: list):
     """Return whether the elements in values are of the same type."""
     if len(values) <= 1:
         return True
@@ -45,7 +46,7 @@ def read_jsonl(filename: str):
     return rets


-def dump_state_text(filename, states, mode="w"):
+def dump_state_text(filename: str, states: list, mode: str = "w"):
     """Dump program state in a text file."""
     from sglang.lang.interpreter import ProgramState
@@ -105,7 +106,7 @@ def http_request(
     return HttpResponse(e)


-def encode_image_base64(image_path):
+def encode_image_base64(image_path: Union[str, bytes]):
     """Encode an image in base64."""
     if isinstance(image_path, str):
         with open(image_path, "rb") as image_file:
@@ -144,7 +145,7 @@ def encode_frame(frame):
     return frame_bytes


-def encode_video_base64(video_path, num_frames=16):
+def encode_video_base64(video_path: str, num_frames: int = 16):
     import cv2  # pip install opencv-python-headless

     cap = cv2.VideoCapture(video_path)
@@ -190,7 +191,7 @@ def encode_video_base64(video_path, num_frames=16):
     return video_base64


-def _is_chinese_char(cp):
+def _is_chinese_char(cp: int):
     """Checks whether CP is the codepoint of a CJK character."""
     # This defines a "chinese character" as anything in the CJK Unicode block:
     # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
@@ -215,7 +216,7 @@ def _is_chinese_char(cp):
     return False


-def find_printable_text(text):
+def find_printable_text(text: str):
     """Returns the longest printable substring of text that contains only entire words."""
     # Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99
@@ -234,26 +235,7 @@ def find_printable_text(text):
     return text[: text.rfind(" ") + 1]


-def run_with_timeout(func, args=(), kwargs=None, timeout=None):
-    """Run a function with timeout."""
-    ret_value = []
-
-    def _target_func():
-        ret_value.append(func(*args, **(kwargs or {})))
-
-    t = threading.Thread(target=_target_func)
-    t.start()
-    t.join(timeout=timeout)
-    if t.is_alive():
-        raise TimeoutError()
-
-    if not ret_value:
-        raise RuntimeError()
-
-    return ret_value[0]
-
-
-def graceful_registry(sub_module_name):
+def graceful_registry(sub_module_name: str):
     def graceful_shutdown(signum, frame):
         logger.info(
             f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
@@ -265,7 +247,9 @@ def graceful_registry(sub_module_name):


 class LazyImport:
-    def __init__(self, module_name, class_name):
+    """Lazy import to make `import sglang` run faster."""
+
+    def __init__(self, module_name: str, class_name: str):
         self.module_name = module_name
         self.class_name = class_name
         self._module = None
@@ -276,7 +260,7 @@ class LazyImport:
         self._module = getattr(module, self.class_name)
         return self._module

-    def __getattr__(self, name):
+    def __getattr__(self, name: str):
         module = self._load()
         return getattr(module, name)
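`LazyImport` defers the real import until first attribute access, which is what keeps `import sglang` fast. A sketch of the pattern; the module and class names are illustrative, not taken from this diff:

```
from sglang.utils import LazyImport

# Nothing heavy is imported at this point.
RuntimeEndpoint = LazyImport(
    "sglang.lang.backend.runtime_endpoint", "RuntimeEndpoint"  # illustrative path
)

# First attribute access misses on the wrapper, so __getattr__ fires,
# _load() performs the real import, and the lookup is forwarded.
print(RuntimeEndpoint.__name__)
```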
@@ -1,8 +0,0 @@
-isort python
-black python
-
-isort test
-black test
-
-isort benchmark
-black benchmark
@@ -1,6 +0,0 @@
-docker run --name tgi --rm -ti --gpus all --network host \
-    -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
-    ghcr.io/huggingface/text-generation-inference:1.3.0 \
-    --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
-    --max-input-length 2048 --max-total-tokens 4096 \
-    --port 24000
26  test/README.md  (new file)
@@ -0,0 +1,26 @@
+# Run Unit Tests
+
+## Test Frontend Language
+```
+cd sglang/test/lang
+export OPENAI_API_KEY=sk-*****
+
+# Run a single file
+python3 test_openai_backend.py
+
+# Run a suite
+python3 run_suite.py --suite minimal
+```
+
+## Test Backend Runtime
+```
+cd sglang/test/srt
+
+# Run a single file
+python3 test_eval_accuracy.py
+
+# Run a suite
+python3 run_suite.py --suite minimal
+```
@@ -1,50 +1,17 @@
 import argparse
 import glob
-import multiprocessing
-import os
-import time
-import unittest

-from sglang.utils import run_with_timeout
+from sglang.test.test_utils import run_unittest_files

 suites = {
-    "minimal": ["test_openai_backend.py", "test_srt_backend.py"],
+    "minimal": ["test_srt_backend.py", "test_openai_backend.py"],
 }


-def run_unittest_files(files, args):
-    for filename in files:
-
-        def func():
-            print(filename)
-            ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
-
-        p = multiprocessing.Process(target=func)
-
-        def run_one_file():
-            p.start()
-            p.join()
-
-        try:
-            run_with_timeout(run_one_file, timeout=args.time_limit_per_file)
-            if p.exitcode != 0:
-                return False
-        except TimeoutError:
-            p.terminate()
-            time.sleep(5)
-            print(
-                f"\nTimeout after {args.time_limit_per_file} seconds "
-                f"when running {filename}"
-            )
-            return False
-
-    return True
-
-
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument(
-        "--time-limit-per-file",
+        "--timeout-per-file",
         type=int,
         default=1000,
         help="The time limit for running one file in seconds.",
@@ -63,12 +30,5 @@ if __name__ == "__main__":
     else:
         files = suites[args.suite]

-    tic = time.time()
-    success = run_unittest_files(files, args)
-
-    if success:
-        print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
-    else:
-        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
-
-    exit(0 if success else -1)
+    exit_code = run_unittest_files(files, args.timeout_per_file)
+    exit(exit_code)
@@ -18,6 +18,7 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

 MODELS = [
+    # (model_name, tp_size)
     ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
     # ("meta-llama/Meta-Llama-3.1-8B-Instruct", 2),
 ]
40  test/srt/run_suite.py  (new file)
@@ -0,0 +1,40 @@
+import argparse
+import glob
+
+from sglang.test.test_utils import run_unittest_files
+
+suites = {
+    "minimal": [
+        "test_openai_server.py",
+        "test_eval_accuracy.py",
+        "test_chunked_prefill.py",
+        "test_torch_compile.py",
+        "models/test_causal_models.py",
+    ],
+}
+
+
+if __name__ == "__main__":
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument(
+        "--timeout-per-file",
+        type=int,
+        default=1000,
+        help="The time limit for running one file in seconds.",
+    )
+    arg_parser.add_argument(
+        "--suite",
+        type=str,
+        default=list(suites.keys())[0],
+        choices=list(suites.keys()) + ["all"],
+        help="The suite to run",
+    )
+    args = arg_parser.parse_args()
+
+    if args.suite == "all":
+        files = glob.glob("**/test_*.py", recursive=True)
+    else:
+        files = suites[args.suite]
+
+    exit_code = run_unittest_files(files, args.timeout_per_file)
+    exit(exit_code)
45  test/srt/test_chunked_prefill.py  (new file)
@@ -0,0 +1,45 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestAccuracy(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = MODEL_NAME_FOR_TEST
+        cls.base_url = f"http://localhost:30000"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=300,
+            other_args=["--chunked-prefill-size", "32"],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=20,
+            num_threads=20,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.5
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")
+
+    # t = TestAccuracy()
+    # t.setUpClass()
+    # t.test_mmlu()
+    # t.tearDownClass()
@@ -1,4 +1,3 @@
-import json
 import unittest
 from types import SimpleNamespace
@@ -11,11 +10,9 @@ class TestAccuracy(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-        port = 30000
-
         cls.model = MODEL_NAME_FOR_TEST
-        cls.base_url = f"http://localhost:{port}"
-        cls.process = popen_launch_server(cls.model, port, timeout=300)
+        cls.base_url = f"http://localhost:30000"
+        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)

     @classmethod
     def tearDownClass(cls):
@@ -11,11 +11,10 @@ class TestOpenAIServer(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-        port = 30000
-
         cls.model = MODEL_NAME_FOR_TEST
-        cls.base_url = f"http://localhost:{port}/v1"
-        cls.process = popen_launch_server(cls.model, port, timeout=300)
+        cls.base_url = f"http://localhost:30000"
+        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
+        cls.base_url += "/v1"

     @classmethod
     def tearDownClass(cls):
42  test/srt/test_torch_compile.py  (new file)
@@ -0,0 +1,42 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestAccuracy(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = MODEL_NAME_FOR_TEST
+        cls.base_url = f"http://localhost:30000"
+        cls.process = popen_launch_server(
+            cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"]
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=20,
+            num_threads=20,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.5
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")
+
+    # t = TestAccuracy()
+    # t.setUpClass()
+    # t.test_mmlu()
+    # t.tearDownClass()