diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml
index 8b7f88e8b..0bd0896e4 100644
--- a/.github/workflows/unit-test.yml
+++ b/.github/workflows/unit-test.yml
@@ -37,23 +37,12 @@ jobs:
         pip install --upgrade transformers
         pip install accelerate
 
-    - name: Test Frontend Language with SRT Backend
+    - name: Test Frontend Language
       run: |
         cd test/lang
-        python3 test_srt_backend.py
+        python3 run_suite.py --suite minimal
 
-    - name: Test OpenAI API Server
+    - name: Test Backend Runtime
       run: |
         cd test/srt
-        python3 test_openai_server.py
-
-    - name: Test Accuracy
-      run: |
-        cd test/srt
-        python3 test_eval_accuracy.py
-        python3 models/test_causal_models.py
-
-    - name: Test Frontend Language with OpenAI Backend
-      run: |
-        cd test/lang
-        python3 test_openai_backend.py
\ No newline at end of file
+        python3 run_suite.py --suite minimal
diff --git a/docs/en/test_process.md b/docs/en/test_process.md
deleted file mode 100644
index 509fb9ede..000000000
--- a/docs/en/test_process.md
+++ /dev/null
@@ -1,102 +0,0 @@
-# SRT Unit Tests
-
-### Latency Alignment
-Make sure your changes do not slow down the following benchmarks
-```
-# single gpu
-python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 32 --input-len 512 --output-len 256
-python -m sglang.bench_latency --model-path meta-llama/Llama-2-7b-chat-hf --mem-fraction-static 0.8 --batch 1 --input-len 512 --output-len 256
-
-# multiple gpu
-python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 32 --input-len 8192 --output-len 1
-python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-70B --tp 8 --mem-fraction-static 0.6 --batch 1 --input-len 8100 --output-len 32
-
-# moe model
-python -m sglang.bench_latency --model-path databricks/dbrx-base --tp 8 --mem-fraction-static 0.6 --batch 4 --input-len 1024 --output-len 32
-```
-
-### High-level API
-
-```
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
-```
-
-```
-cd test/lang
-python3 test_srt_backend.py
-```
-
-### Performance
-
-#### MMLU
-```
-cd benchmark/mmlu
-```
-Follow README.md to download the data.
-
-```
-python3 bench_sglang.py --nsub 3
-
-# Expected performance on A10G
-# Total latency: 8.200
-# Average accuracy: 0.413
-```
-
-#### GSM-8K
-```
-cd benchmark/gsm8k
-```
-Follow README.md to download the data.
-
-```
-python3 bench_sglang.py --num-q 200
-
-# Expected performance on A10G
-# Latency: 32.103
-# Accuracy: 0.250
-```
-
-#### More
-Please also test `benchmark/hellaswag`, `benchmark/latency_throughput`.
-
-### More Models
-
-#### LLaVA
-
-```
-python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000
-```
-
-```
-cd benchmark/llava_bench
-python3 bench_sglang.py
-
-# Expected performance on A10G
-# Latency: 50.031
-```
-
-## SGLang Unit Tests
-```
-export ANTHROPIC_API_KEY=
-export OPENAI_API_KEY=
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
-```
-
-```
-cd test/lang
-python3 run_all.py
-```
-
-## OpenAI API server
-```
-cd test/srt
-python test_openai_server.py
-```
-
-## Code Formatting
-```
-pip3 install pre-commit
-cd sglang
-pre-commit install
-pre-commit run --all-files
-```
diff --git a/python/pyproject.toml b/python/pyproject.toml
index b59ef852b..ebaa541e0 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -20,8 +20,10 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
-    "psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
+srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "jsonlines",
+    "packaging", "pillow", "psutil", "pydantic", "python-multipart",
+    "torch", "uvicorn", "uvloop", "zmq",
+    "vllm==0.5.3.post1", "outlines>=0.0.44"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
index 178b79c22..6c1f284b1 100644
--- a/python/sglang/test/run_eval.py
+++ b/python/sglang/test/run_eval.py
@@ -10,7 +10,6 @@ import time
 
 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
-    download_dataset,
     make_report,
     set_ulimit,
 )
@@ -27,14 +26,26 @@ def run_eval(args):
     if args.eval_name == "mmlu":
         from sglang.test.simple_eval_mmlu import MMLUEval
 
-        dataset_path = "mmlu.csv"
+        filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
+        eval_obj = MMLUEval(filename, args.num_examples, args.num_threads)
+    elif args.eval_name == "math":
+        from sglang.test.simple_eval_math import MathEval
 
-        if not os.path.exists(dataset_path):
-            download_dataset(
-                dataset_path,
-                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
-            )
-        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+        equality_checker = ChatCompletionSampler(model="gpt-4-turbo")
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
+        )
+        eval_obj = MathEval(
+            filename, equality_checker, args.num_examples, args.num_threads
+        )
+    elif args.eval_name == "gpqa":
+        from sglang.test.simple_eval_gpqa import GPQAEval
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
+        )
+        eval_obj = GPQAEval(filename, args.num_examples, args.num_threads)
     elif args.eval_name == "humaneval":
         from sglang.test.simple_eval_humaneval import HumanEval
 
@@ -97,7 +108,7 @@ if __name__ == "__main__":
     )
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
-    parser.add_argument("--num-threads", type=int, default=64)
+    parser.add_argument("--num-threads", type=int, default=512)
     set_ulimit()
 
     args = parser.parse_args()
diff --git a/python/sglang/test/simple_eval_gpqa.py b/python/sglang/test/simple_eval_gpqa.py
new file mode 100644
index 000000000..46055caa5
--- /dev/null
+++ b/python/sglang/test/simple_eval_gpqa.py
@@ -0,0 +1,92 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+GPQA: A Graduate-Level Google-Proof Q&A Benchmark
+David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
+https://arxiv.org/abs/2311.12022
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN_MULTICHOICE,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    MessageList,
+    SamplerBase,
+    SingleEvalResult,
+    format_multichoice_question,
+)
+
+
+class GPQAEval(Eval):
+    def __init__(
+        self,
+        filename: str,
+        num_examples: int | None,
+        num_threads: int,
+        n_repeats: int = 1,
+    ):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        rng = random.Random(0)
+        if num_examples:
+            assert n_repeats == 1, "n_repeats only supported for num_examples"
+            examples = rng.sample(examples, num_examples)
+        examples = examples * n_repeats
+        examples = [
+            example | {"permutation": rng.sample(range(4), 4)} for example in examples
+        ]
+        self.examples = examples
+        self.n_repeats = n_repeats
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            choices = [
+                row["Correct Answer"],
+                row["Incorrect Answer 1"],
+                row["Incorrect Answer 2"],
+                row["Incorrect Answer 3"],
+            ]
+            choices = [choices[i] for i in row["permutation"]]
+            correct_index = choices.index(row["Correct Answer"])
+            correct_answer = "ABCD"[correct_index]
+            choices_dict = dict(
+                A=choices[0],
+                B=choices[1],
+                C=choices[2],
+                D=choices[3],
+                Question=row["Question"],
+            )
+            prompt_messages = [
+                sampler._pack_message(
+                    content=format_multichoice_question(choices_dict), role="user"
+                )
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = 1.0 if extracted_answer == correct_answer else 0.0
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=correct_answer,
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(
+                html=html,
+                score=score,
+                convo=convo,
+                metrics={"chars": len(response_text)},
+            )
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
diff --git a/python/sglang/test/simple_eval_math.py b/python/sglang/test/simple_eval_math.py
new file mode 100644
index 000000000..4ddb650d9
--- /dev/null
+++ b/python/sglang/test/simple_eval_math.py
@@ -0,0 +1,72 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+Measuring Mathematical Problem Solving With the MATH Dataset
+Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
+https://arxiv.org/abs/2103.03874
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    ANSWER_PATTERN,
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+    check_equality,
+)
+
+QUERY_TEMPLATE = """
+Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
+
+{Question}
+
+Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
+""".strip()
+
+
+class MathEval(Eval):
+    def __init__(
+        self,
+        filename: str,
+        equality_checker: SamplerBase,
+        num_examples: int | None,
+        num_threads: int,
+    ):
+        df = pandas.read_csv(filename)
+        examples = [row.to_dict() for _, row in df.iterrows()]
+        if num_examples:
+            examples = random.Random(0).sample(examples, num_examples)
+        self.examples = examples
+        self.equality_checker = equality_checker
+        self.num_threads = num_threads
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(row: dict):
+            prompt_messages = [
+                sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
+            ]
+            response_text = sampler(prompt_messages)
+            match = re.search(ANSWER_PATTERN, response_text)
+            extracted_answer = match.group(1) if match else None
+            score = float(
+                check_equality(self.equality_checker, row["Answer"], extracted_answer)
+            )
+            html = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=row["Answer"],
+                extracted_answer=extracted_answer,
+            )
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(html=html, score=score, convo=convo)
+
+        results = common.map_with_progress(fn, self.examples, self.num_threads)
+        return common.aggregate_results(results)
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 4348b57e9..2ab009eba 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -1,9 +1,14 @@
 """Common utilities for testing and benchmarking"""
 
+import argparse
 import asyncio
+import multiprocessing
 import subprocess
+import threading
 import time
+import unittest
 from functools import partial
+from typing import Callable, Optional
 
 import numpy as np
 import requests
@@ -247,7 +252,7 @@ async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=
     return choices.index(answer)
 
 
-def add_common_other_args_and_parse(parser):
+def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=None)
@@ -286,7 +291,7 @@
     return args
 
 
-def add_common_sglang_args_and_parse(parser):
+def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
     parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
@@ -296,7 +301,7 @@
     return args
 
 
-def select_sglang_backend(args):
+def select_sglang_backend(args: argparse.Namespace):
     if args.backend.startswith("srt"):
         if args.backend == "srt-no-parallel":
             global_config.enable_parallel_decoding = False
@@ -309,7 +314,7 @@
     return backend
 
 
-def _get_call_generate(args):
+def _get_call_generate(args: argparse.Namespace):
     if args.backend == "lightllm":
         return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "vllm":
@@ -336,7 +341,7 @@
     raise ValueError(f"Invalid backend: {args.backend}")
 
 
-def _get_call_select(args):
+def _get_call_select(args: argparse.Namespace):
     if args.backend == "lightllm":
         return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "vllm":
@@ -359,7 +364,7 @@
     raise ValueError(f"Invalid backend: {args.backend}")
 
 
-def get_call_generate(args):
+def get_call_generate(args: argparse.Namespace):
     call_generate = _get_call_generate(args)
 
     def func(*args, **kwargs):
@@ -372,7 +377,7 @@
     return func
 
 
-def get_call_select(args):
+def get_call_select(args: argparse.Namespace):
     call_select = _get_call_select(args)
 
     def func(*args, **kwargs):
@@ -385,7 +390,12 @@
     return func
 
 
-def popen_launch_server(model, port, timeout, *args):
+def popen_launch_server(
+    model: str, base_url: str, timeout: float, other_args: tuple = ()
+):
+    _, host, port = base_url.split(":")
+    host = host[2:]
+
     command = [
         "python3",
         "-m",
@@ -393,21 +403,81 @@
         "--model-path",
         model,
         "--host",
-        "localhost",
+        host,
         "--port",
-        str(port),
-        *args,
+        port,
+        *other_args,
     ]
     process = subprocess.Popen(command, stdout=None, stderr=None)
 
-    base_url = f"http://localhost:{port}/v1"
     start_time = time.time()
     while time.time() - start_time < timeout:
        try:
-            response = requests.get(f"{base_url}/models")
+            response = requests.get(f"{base_url}/v1/models")
             if response.status_code == 200:
                 return process
         except requests.RequestException:
             pass
         time.sleep(10)
     raise TimeoutError("Server failed to start within the timeout period.")
+
+
+def run_with_timeout(
+    func: Callable,
+    args: tuple = (),
+    kwargs: Optional[dict] = None,
+    timeout: Optional[float] = None,
+):
+    """Run a function with timeout."""
+    ret_value = []
+
+    def _target_func():
+        ret_value.append(func(*args, **(kwargs or {})))
+
+    t = threading.Thread(target=_target_func)
+    t.start()
+    t.join(timeout=timeout)
+    if t.is_alive():
+        raise TimeoutError()
+
+    if not ret_value:
+        raise RuntimeError()
+
+    return ret_value[0]
+
+
+def run_unittest_files(files: list[str], timeout_per_file: float):
+    tic = time.time()
+    success = True
+
+    for filename in files:
+
+        def func():
+            print(f"\n\nRun (unknown)\n\n")
+            ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
+
+        p = multiprocessing.Process(target=func)
+
+        def run_one_file():
+            p.start()
+            p.join()
+
+        try:
+            run_with_timeout(run_one_file, timeout=timeout_per_file)
+            if p.exitcode != 0:
+                success = False
+                break
+        except TimeoutError:
+            p.terminate()
+            time.sleep(5)
+            print(
+                f"\nTimeout after {timeout_per_file} seconds when running (unknown)\n"
+            )
+            return -1
+
+    if success:
+        print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
+    else:
+        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
+
+    return 0 if success else -1
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 838879d5d..27a8c40b8 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -12,6 +12,7 @@ import urllib.request
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
+from typing import Union
 
 import numpy as np
 import requests
@@ -25,7 +26,7 @@ def get_exception_traceback():
     return err_str
 
 
-def is_same_type(values):
+def is_same_type(values: list):
     """Return whether the elements in values are of the same type."""
     if len(values) <= 1:
         return True
@@ -45,7 +46,7 @@
     return rets
 
 
-def dump_state_text(filename, states, mode="w"):
+def dump_state_text(filename: str, states: list, mode: str = "w"):
     """Dump program state in a text file."""
     from sglang.lang.interpreter import ProgramState
 
@@ -105,7 +106,7 @@
     return HttpResponse(e)
 
 
-def encode_image_base64(image_path):
+def encode_image_base64(image_path: Union[str, bytes]):
     """Encode an image in base64."""
     if isinstance(image_path, str):
         with open(image_path, "rb") as image_file:
@@ -144,7 +145,7 @@
     return frame_bytes
 
 
-def encode_video_base64(video_path, num_frames=16):
+def encode_video_base64(video_path: str, num_frames: int = 16):
     import cv2  # pip install opencv-python-headless
 
     cap = cv2.VideoCapture(video_path)
@@ -190,7 +191,7 @@
     return video_base64
 
 
-def _is_chinese_char(cp):
+def _is_chinese_char(cp: int):
     """Checks whether CP is the codepoint of a CJK character."""
     # This defines a "chinese character" as anything in the CJK Unicode block:
     #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
@@ -215,7 +216,7 @@
     return False
 
 
-def find_printable_text(text):
+def find_printable_text(text: str):
     """Returns the longest printable substring of text that contains only entire words."""
     # Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99
 
@@ -234,26 +235,7 @@
     return text[: text.rfind(" ") + 1]
 
 
-def run_with_timeout(func, args=(), kwargs=None, timeout=None):
-    """Run a function with timeout."""
-    ret_value = []
-
-    def _target_func():
-        ret_value.append(func(*args, **(kwargs or {})))
-
-    t = threading.Thread(target=_target_func)
-    t.start()
-    t.join(timeout=timeout)
-    if t.is_alive():
-        raise TimeoutError()
-
-    if not ret_value:
-        raise RuntimeError()
-
-    return ret_value[0]
-
-
-def graceful_registry(sub_module_name):
+def graceful_registry(sub_module_name: str):
     def graceful_shutdown(signum, frame):
         logger.info(
             f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
@@ -265,7 +247,9 @@
 
 
 class LazyImport:
-    def __init__(self, module_name, class_name):
+    """Lazy import to make `import sglang` run faster."""
+
+    def __init__(self, module_name: str, class_name: str):
         self.module_name = module_name
         self.class_name = class_name
         self._module = None
@@ -276,7 +260,7 @@
         self._module = getattr(module, self.class_name)
         return self._module
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str):
         module = self._load()
         return getattr(module, name)
 
diff --git a/test/srt/deprecated/test_curl.sh b/scripts/deprecated/test_curl.sh
similarity index 100%
rename from test/srt/deprecated/test_curl.sh
rename to scripts/deprecated/test_curl.sh
diff --git a/test/srt/deprecated/test_flashinfer.py b/scripts/deprecated/test_flashinfer.py
similarity index 100%
rename from test/srt/deprecated/test_flashinfer.py
rename to scripts/deprecated/test_flashinfer.py
diff --git a/test/srt/deprecated/test_httpserver_classify.py b/scripts/deprecated/test_httpserver_classify.py
similarity index 100%
rename from test/srt/deprecated/test_httpserver_classify.py
rename to scripts/deprecated/test_httpserver_classify.py
diff --git a/test/srt/deprecated/test_httpserver_concurrent.py b/scripts/deprecated/test_httpserver_concurrent.py
similarity index 100%
rename from test/srt/deprecated/test_httpserver_concurrent.py
rename to scripts/deprecated/test_httpserver_concurrent.py
diff --git a/test/srt/deprecated/test_httpserver_decode.py b/scripts/deprecated/test_httpserver_decode.py
similarity index 100%
rename from test/srt/deprecated/test_httpserver_decode.py
rename to scripts/deprecated/test_httpserver_decode.py
diff --git a/test/srt/deprecated/test_httpserver_decode_stream.py b/scripts/deprecated/test_httpserver_decode_stream.py
similarity index 100%
rename from test/srt/deprecated/test_httpserver_decode_stream.py
rename to scripts/deprecated/test_httpserver_decode_stream.py
diff --git a/test/srt/deprecated/test_httpserver_llava.py b/scripts/deprecated/test_httpserver_llava.py
similarity index 100%
rename from test/srt/deprecated/test_httpserver_llava.py
rename to scripts/deprecated/test_httpserver_llava.py
diff --git a/test/srt/deprecated/test_httpserver_reuse.py b/scripts/deprecated/test_httpserver_reuse.py
similarity index 100%
rename from test/srt/deprecated/test_httpserver_reuse.py
rename to scripts/deprecated/test_httpserver_reuse.py
diff --git a/test/srt/deprecated/test_jump_forward.py b/scripts/deprecated/test_jump_forward.py
similarity index 100%
rename from test/srt/deprecated/test_jump_forward.py
rename to scripts/deprecated/test_jump_forward.py
diff --git a/test/srt/deprecated/test_openai_server.py b/scripts/deprecated/test_openai_server.py
similarity index 100%
rename from test/srt/deprecated/test_openai_server.py
rename to scripts/deprecated/test_openai_server.py
diff --git a/test/srt/deprecated/test_robust.py b/scripts/deprecated/test_robust.py
similarity index 100%
rename from test/srt/deprecated/test_robust.py
rename to scripts/deprecated/test_robust.py
diff --git a/scripts/format.sh b/scripts/format.sh
deleted file mode 100644
index a49aed745..000000000
--- a/scripts/format.sh
+++ /dev/null
@@ -1,8 +0,0 @@
-isort python
-black python
-
-isort test
-black test
-
-isort benchmark
-black benchmark
diff --git a/scripts/launch_tgi.sh b/scripts/launch_tgi.sh
deleted file mode 100644
index eeb405475..000000000
--- a/scripts/launch_tgi.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-docker run --name tgi --rm -ti --gpus all --network host \
-    -v /home/ubuntu/model_weights/Llama-2-7b-chat-hf:/Llama-2-7b-chat-hf \
-    ghcr.io/huggingface/text-generation-inference:1.3.0 \
-    --model-id /Llama-2-7b-chat-hf --num-shard 1 --trust-remote-code \
-    --max-input-length 2048 --max-total-tokens 4096 \
-    --port 24000
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 000000000..cdfbbaee8
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,26 @@
+# Run Unit Tests
+
+## Test Frontend Language
+```
+cd sglang/test/lang
+export OPENAI_API_KEY=sk-*****
+
+# Run a single file
+python3 test_openai_backend.py
+
+# Run a suite
+python3 run_suite.py --suite minimal
+```
+
+## Test Backend Runtime
+```
+cd sglang/test/srt
+
+# Run a single file
+python3 test_eval_accuracy.py
+
+# Run a suite
+python3 run_suite.py --suite minimal
+```
+
+
diff --git a/test/lang/run_suite.py b/test/lang/run_suite.py
index 4b0c961ef..379427afa 100644
--- a/test/lang/run_suite.py
+++ b/test/lang/run_suite.py
@@ -1,50 +1,17 @@
 import argparse
 import glob
-import multiprocessing
-import os
-import time
-import unittest
 
-from sglang.utils import run_with_timeout
+from sglang.test.test_utils import run_unittest_files
 
 suites = {
-    "minimal": ["test_openai_backend.py", "test_srt_backend.py"],
+    "minimal": ["test_srt_backend.py", "test_openai_backend.py"],
 }
 
 
-def run_unittest_files(files, args):
-    for filename in files:
-
-        def func():
-            print(filename)
-            ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
-
-        p = multiprocessing.Process(target=func)
-
-        def run_one_file():
-            p.start()
-            p.join()
-
-        try:
-            run_with_timeout(run_one_file, timeout=args.time_limit_per_file)
-            if p.exitcode != 0:
-                return False
-        except TimeoutError:
-            p.terminate()
-            time.sleep(5)
-            print(
-                f"\nTimeout after {args.time_limit_per_file} seconds "
-                f"when running (unknown)"
-            )
-            return False
-
-    return True
-
-
 if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument(
-        "--time-limit-per-file",
+        "--timeout-per-file",
         type=int,
         default=1000,
         help="The time limit for running one file in seconds.",
@@ -63,12 +30,5 @@
     else:
         files = suites[args.suite]
 
-    tic = time.time()
-    success = run_unittest_files(files, args)
-
-    if success:
-        print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
-    else:
-        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
-
-    exit(0 if success else -1)
+    exit_code = run_unittest_files(files, args.timeout_per_file)
+    exit(exit_code)
diff --git a/test/srt/models/test_causal_models.py b/test/srt/models/test_causal_models.py
index 3cec4490a..0522816b3 100644
--- a/test/srt/models/test_causal_models.py
+++ b/test/srt/models/test_causal_models.py
@@ -18,6 +18,7 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 
 MODELS = [
+    # (model_name, tp_size)
     ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
     # ("meta-llama/Meta-Llama-3.1-8B-Instruct", 2),
 ]
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
new file mode 100644
index 000000000..ab9ae0f41
--- /dev/null
+++ b/test/srt/run_suite.py
@@ -0,0 +1,40 @@
+import argparse
+import glob
+
+from sglang.test.test_utils import run_unittest_files
+
+suites = {
+    "minimal": [
+        "test_openai_server.py",
+        "test_eval_accuracy.py",
+        "test_chunked_prefill.py",
+        "test_torch_compile.py",
+        "models/test_causal_models.py",
+    ],
+}
+
+
+if __name__ == "__main__":
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument(
+        "--timeout-per-file",
+        type=int,
+        default=1000,
+        help="The time limit for running one file in seconds.",
+    )
+    arg_parser.add_argument(
+        "--suite",
+        type=str,
+        default=list(suites.keys())[0],
+        choices=list(suites.keys()) + ["all"],
+        help="The suite to run",
+    )
+    args = arg_parser.parse_args()
+
+    if args.suite == "all":
+        files = glob.glob("**/test_*.py", recursive=True)
+    else:
+        files = suites[args.suite]
+
+    exit_code = run_unittest_files(files, args.timeout_per_file)
+    exit(exit_code)
diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py
new file mode 100644
index 000000000..3380f6aa8
--- /dev/null
+++ b/test/srt/test_chunked_prefill.py
@@ -0,0 +1,45 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestAccuracy(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = MODEL_NAME_FOR_TEST
+        cls.base_url = f"http://localhost:30000"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=300,
+            other_args=["--chunked-prefill-size", "32"],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=20,
+            num_threads=20,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.5
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")
+
+    # t = TestAccuracy()
+    # t.setUpClass()
+    # t.test_mmlu()
+    # t.tearDownClass()
diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py
index d392dc4c0..dc3f8266b 100644
--- a/test/srt/test_eval_accuracy.py
+++ b/test/srt/test_eval_accuracy.py
@@ -1,4 +1,3 @@
-import json
 import unittest
 from types import SimpleNamespace
 
@@ -11,11 +10,9 @@ class TestAccuracy(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        port = 30000
-
         cls.model = MODEL_NAME_FOR_TEST
-        cls.base_url = f"http://localhost:{port}"
-        cls.process = popen_launch_server(cls.model, port, timeout=300)
+        cls.base_url = f"http://localhost:30000"
+        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
 
     @classmethod
     def tearDownClass(cls):
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index 76a105a62..a2b934b6b 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -11,11 +11,10 @@ class TestOpenAIServer(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        port = 30000
-
         cls.model = MODEL_NAME_FOR_TEST
-        cls.base_url = f"http://localhost:{port}/v1"
-        cls.process = popen_launch_server(cls.model, port, timeout=300)
+        cls.base_url = f"http://localhost:30000"
+        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
+        cls.base_url += "/v1"
 
     @classmethod
     def tearDownClass(cls):
diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py
new file mode 100644
index 000000000..efd9c4698
--- /dev/null
+++ b/test/srt/test_torch_compile.py
@@ -0,0 +1,42 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestAccuracy(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = MODEL_NAME_FOR_TEST
+        cls.base_url = f"http://localhost:30000"
+        cls.process = popen_launch_server(
+            cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"]
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=20,
+            num_threads=20,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.5
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")
+
+    # t = TestAccuracy()
+    # t.setUpClass()
+    # t.test_mmlu()
+    # t.tearDownClass()