diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index 9630ca718..7b59054fe 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -18,7 +18,7 @@ concurrency: cancel-in-progress: true jobs: - pr-e2e-test: + e2e-test: runs-on: self-hosted env: @@ -38,7 +38,7 @@ jobs: pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall pip install --upgrade transformers - - name: Benchmark Serving + - name: Benchmark Serving Throughput run: | cd /data/zhyncs/venv && source ./bin/activate && cd - python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --port 8413 --disable-radix-cache & diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index dc464fa8c..f1c069ea5 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -59,3 +59,10 @@ jobs: cd test/srt python3 test_openai_server.py + + - name: Test Accuracy + run: | + cd /data/zhyncs/venv && source ./bin/activate && cd - + + cd test/srt + python3 test_eval_accuracy.py diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index b52e114fd..253aab355 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -21,7 +21,7 @@ import sys import time import traceback import warnings -from argparse import ArgumentParser as FlexibleArgumentParser +from argparse import ArgumentParser from dataclasses import dataclass, field from datetime import datetime from typing import AsyncGenerator, List, Optional, Tuple, Union @@ -868,14 +868,12 @@ def set_ulimit(target_soft_limit=65535): if __name__ == "__main__": - parser = FlexibleArgumentParser( - description="Benchmark the online serving throughput." - ) + parser = ArgumentParser(description="Benchmark the online serving throughput.") parser.add_argument( "--backend", type=str, - required=True, choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", help="Must specify a backend, depending on the LLM Inference Engine.", ) parser.add_argument( diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py new file mode 100644 index 000000000..3729ef7ab --- /dev/null +++ b/python/sglang/test/run_eval.py @@ -0,0 +1,99 @@ +""" +Usage: +python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10 +""" + +import argparse +import json +import os +import time + +from sglang.test.simple_eval_common import ( + ChatCompletionSampler, + download_dataset, + make_report, + set_ulimit, +) +from sglang.test.simple_eval_mmlu import MMLUEval + + +def run_eval(args): + if "OPENAI_API_KEY" not in os.environ: + os.environ["OPENAI_API_KEY"] = "EMPTY" + + base_url = ( + f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" + ) + + if args.eval_name == "mmlu": + dataset_path = "mmlu.csv" + + if not os.path.exists(dataset_path): + download_dataset( + dataset_path, + "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", + ) + eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads) + else: + raise ValueError(f"Invalid eval name: {args.eval_name}") + + sampler = ChatCompletionSampler( + model=args.model, + max_tokens=2048, + base_url=base_url, + ) + + # Run eval + tic = time.time() + result = eval_obj(sampler) + latency = time.time() - tic + + # Dump reports + metrics = result.metrics | {"score": result.score} + file_stem = f"mmlu_{sampler.model.replace('/', '_')}" + report_filename = f"/tmp/{file_stem}.html" + print(f"Writing report to {report_filename}") + with open(report_filename, "w") as fh: + fh.write(make_report(result)) + metrics = result.metrics | {"score": result.score} + print(metrics) + result_filename = f"/tmp/{file_stem}.json" + with open(result_filename, "w") as f: + f.write(json.dumps(metrics, indent=2)) + print(f"Writing results to {result_filename}") + + # Print results + print(f"Total latency: {latency:.3f} s") + print(f"Score: {metrics['score']:.3f}") + + return metrics + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--model", + type=str, + help="Name or path of the model. If not set, the default model will request /v1/models for conf.", + ) + parser.add_argument("--eval-name", type=str, default="mmlu") + parser.add_argument("--num-examples", type=int) + parser.add_argument("--num-threads", type=int, default=64) + set_ulimit() + args = parser.parse_args() + + run_eval(args) diff --git a/python/sglang/test/simple_eval_common.py b/python/sglang/test/simple_eval_common.py new file mode 100644 index 000000000..75c26f0f0 --- /dev/null +++ b/python/sglang/test/simple_eval_common.py @@ -0,0 +1,456 @@ +# Adapted from https://github.com/openai/simple-evals/ + +import base64 +import os +import resource +import time +from collections import defaultdict +from dataclasses import dataclass, field +from multiprocessing.pool import ThreadPool +from typing import Any + +import jinja2 +import numpy as np +import openai +import requests +from openai import OpenAI +from tqdm import tqdm + +OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant." +OPENAI_SYSTEM_MESSAGE_CHATGPT = ( + "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture." + + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01" +) + + +Message = dict[str, Any] # keys role, content +MessageList = list[Message] + + +class SamplerBase: + """ + Base class for defining a sampling model, which can be evaluated, + or used as part of the grading process. + """ + + def __call__(self, message_list: MessageList) -> str: + raise NotImplementedError() + + +@dataclass +class EvalResult: + """ + Result of running an evaluation (usually consisting of many samples) + """ + + score: float | None # top-line metric + metrics: dict[str, float] | None # other metrics + htmls: list[str] # strings of valid HTML + convos: list[MessageList] # sampled conversations + + +@dataclass +class SingleEvalResult: + """ + Result of evaluating a single sample + """ + + score: float | None + metrics: dict[str, float] = field(default_factory=dict) + html: str | None = None + convo: MessageList | None = None # sampled conversation + + +class Eval: + """ + Base class for defining an evaluation. + """ + + def __call__(self, sampler: SamplerBase) -> EvalResult: + raise NotImplementedError() + + +class ChatCompletionSampler(SamplerBase): + """ + Sample from OpenAI's chat completion API + """ + + def __init__( + self, + base_url: str = None, + model: str | None = None, + system_message: str | None = None, + temperature: float = 0.0, + max_tokens: int = 2048, + ): + self.client = OpenAI(base_url=base_url) + + if model is None: + model = self.client.models.list().data[0].id + + self.model = model + self.system_message = system_message + self.temperature = temperature + self.max_tokens = max_tokens + self.image_format = "url" + + def _handle_image( + self, + image: str, + encoding: str = "base64", + format: str = "png", + fovea: int = 768, + ): + new_image = { + "type": "image_url", + "image_url": { + "url": f"data:image/{format};{encoding},{image}", + }, + } + return new_image + + def _handle_text(self, text: str): + return {"type": "text", "text": text} + + def _pack_message(self, role: str, content: Any): + return {"role": str(role), "content": content} + + def __call__(self, message_list: MessageList) -> str: + if self.system_message: + message_list = [ + self._pack_message("system", self.system_message) + ] + message_list + trial = 0 + while True: + try: + response = self.client.chat.completions.create( + model=self.model, + messages=message_list, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + return response.choices[0].message.content + # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU + except openai.BadRequestError as e: + print("Bad Request Error", e) + return "" + except Exception as e: + exception_backoff = 2**trial # expontial back off + print( + f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec", + e, + ) + time.sleep(exception_backoff) + trial += 1 + # unknown error shall throw exception + + +QUERY_TEMPLATE_MULTICHOICE = """ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. + +{Question} + +A) {A} +B) {B} +C) {C} +D) {D} +""".strip() + +ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])" +ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)" + + +EQUALITY_TEMPLATE = r""" +Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications + +Examples: + + Expression 1: $2x+3$ + Expression 2: $3+2x$ + +Yes + + Expression 1: 3/2 + Expression 2: 1.5 + +Yes + + Expression 1: $x^2+2x+1$ + Expression 2: $y^2+2y+1$ + +No + + Expression 1: $x^2+2x+1$ + Expression 2: $(x+1)^2$ + +Yes + + Expression 1: 3245/5 + Expression 2: 649 + +No +(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications) + + Expression 1: 2/(-3) + Expression 2: -2/3 + +Yes +(trivial simplifications are allowed) + + Expression 1: 72 degrees + Expression 2: 72 + +Yes +(give benefit of the doubt to units) + + Expression 1: 64 + Expression 2: 64 square feet + +Yes +(give benefit of the doubt to units) + +--- + +YOUR TASK + + +Respond with only "Yes" or "No" (without quotes). Do not include a rationale. + + Expression 1: %(expression1)s + Expression 2: %(expression2)s +""".strip() + + +HTML_JINJA = """ +

Prompt conversation

+{% for message in prompt_messages %} +{{ message_to_html(message) | safe }} +{% endfor %} +

Sampled message

+{{ message_to_html(next_message) | safe }} +

Results

+

Correct Answer: {{ correct_answer }}

+

Extracted Answer: {{ extracted_answer }}

+

Score: {{ score }}

+""" + + +def format_multichoice_question(row): + return QUERY_TEMPLATE_MULTICHOICE.format(**row) + + +def check_equality(sampler: SamplerBase, expr1: str, expr2: str): + prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2} + response = sampler([dict(content=prompt, role="user")]) + return response.lower().strip() == "yes" + + +def _compute_stat(values: list, stat: str): + if stat == "mean": + return np.mean(values) + elif stat == "std": + return np.std(values) + elif stat == "min": + return np.min(values) + elif stat == "max": + return np.max(values) + else: + raise ValueError(f"Unknown {stat =}") + + +def aggregate_results( + single_eval_results: list[SingleEvalResult], + default_stats: tuple[str] = ("mean", "std"), + name2stats: dict[str, tuple[str]] | None = None, +) -> EvalResult: + """ + Aggregate results from multiple evaluations into a single EvalResult. + """ + name2stats = name2stats or {} + name2values = defaultdict(list) + htmls = [] + convos = [] + for single_eval_result in single_eval_results: + for name, value in single_eval_result.metrics.items(): + name2values[name].append(value) + if single_eval_result.score is not None: + name2values["score"].append(single_eval_result.score) + htmls.append(single_eval_result.html) + convos.append(single_eval_result.convo) + final_metrics = {} + for name, values in name2values.items(): + stats = name2stats.get(name, default_stats) + for stat in stats: + key = name if stat == "mean" else f"{name}:{stat}" + final_metrics[key] = _compute_stat(values, stat) + return EvalResult( + score=final_metrics.pop("score", None), + metrics=final_metrics, + htmls=htmls, + convos=convos, + ) + + +def map_with_progress(f: callable, xs: list[Any], num_threads: int): + """ + Apply f to each element of xs, using a ThreadPool, and show progress. + """ + if os.getenv("debug"): + return list(map(f, tqdm(xs, total=len(xs)))) + else: + with ThreadPool(min(num_threads, len(xs))) as pool: + return list(tqdm(pool.imap(f, xs), total=len(xs))) + + +jinja_env = jinja2.Environment( + loader=jinja2.BaseLoader(), + undefined=jinja2.StrictUndefined, + autoescape=jinja2.select_autoescape(["html", "xml"]), +) +_message_template = """ +
+
+ {{ role }} + {% if variant %}({{ variant }}){% endif %} +
+
+
{{ content }}
+
+
+""" + + +def message_to_html(message: Message) -> str: + """ + Generate HTML snippet (inside a
) for a message. + """ + return jinja_env.from_string(_message_template).render( + role=message["role"], + content=message["content"], + variant=message.get("variant", None), + ) + + +jinja_env.globals["message_to_html"] = message_to_html + + +_report_template = """ + + + + + + {% if metrics %} +

Metrics

+ + + + + + + + + + {% for name, value in metrics.items() %} + + + + + {% endfor %} +
MetricValue
Score{{ score | float | round(3) }}
{{ name }}{{ value }}
+ {% endif %} +

Examples

+ {% for html in htmls %} + {{ html | safe }} +
+ {% endfor %} + + +""" + + +def make_report(eval_result: EvalResult) -> str: + """ + Create a standalone HTML report from an EvalResult. + """ + return jinja_env.from_string(_report_template).render( + score=eval_result.score, + metrics=eval_result.metrics, + htmls=eval_result.htmls, + ) + + +def make_report_from_example_htmls(htmls: list[str]): + """ + Create a standalone HTML report from a list of example htmls + """ + return jinja_env.from_string(_report_template).render( + score=None, metrics={}, htmls=htmls + ) + + +def download_dataset(path, url): + print(f"Downloading dataset {path} from {url}") + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + block_size = 8192 + + with open(path, "wb") as f, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(block_size): + size = f.write(data) + progress_bar.update(size) + + print(f"Dataset downloaded and saved to {path}") + except requests.RequestException as e: + raise Exception(f"Failed to download dataset: {e}") + + +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Fail to set RLIMIT_NOFILE: {e}") diff --git a/python/sglang/test/simple_eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py new file mode 100644 index 000000000..3c0287510 --- /dev/null +++ b/python/sglang/test/simple_eval_mmlu.py @@ -0,0 +1,120 @@ +# Adapted from https://github.com/openai/simple-evals/ + +""" +Measuring Massive Multitask Language Understanding +Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt +https://arxiv.org/abs/2009.03300 +""" + +import random +import re + +import pandas + +from sglang.test import simple_eval_common as common +from sglang.test.simple_eval_common import ( + ANSWER_PATTERN_MULTICHOICE, + HTML_JINJA, + Eval, + EvalResult, + SamplerBase, + SingleEvalResult, + format_multichoice_question, +) + +subject2category = { + "abstract_algebra": "stem", + "anatomy": "other", + "astronomy": "stem", + "business_ethics": "other", + "clinical_knowledge": "other", + "college_biology": "stem", + "college_chemistry": "stem", + "college_computer_science": "stem", + "college_mathematics": "stem", + "college_medicine": "other", + "college_physics": "stem", + "computer_security": "stem", + "conceptual_physics": "stem", + "econometrics": "social_sciences", + "electrical_engineering": "stem", + "elementary_mathematics": "stem", + "formal_logic": "humanities", + "global_facts": "other", + "high_school_biology": "stem", + "high_school_chemistry": "stem", + "high_school_computer_science": "stem", + "high_school_european_history": "humanities", + "high_school_geography": "social_sciences", + "high_school_government_and_politics": "social_sciences", + "high_school_macroeconomics": "social_sciences", + "high_school_mathematics": "stem", + "high_school_microeconomics": "social_sciences", + "high_school_physics": "stem", + "high_school_psychology": "social_sciences", + "high_school_statistics": "stem", + "high_school_us_history": "humanities", + "high_school_world_history": "humanities", + "human_aging": "other", + "human_sexuality": "social_sciences", + "international_law": "humanities", + "jurisprudence": "humanities", + "logical_fallacies": "humanities", + "machine_learning": "stem", + "management": "other", + "marketing": "other", + "medical_genetics": "other", + "miscellaneous": "other", + "moral_disputes": "humanities", + "moral_scenarios": "humanities", + "nutrition": "other", + "philosophy": "humanities", + "prehistory": "humanities", + "professional_accounting": "other", + "professional_law": "humanities", + "professional_medicine": "other", + "professional_psychology": "social_sciences", + "public_relations": "social_sciences", + "security_studies": "social_sciences", + "sociology": "social_sciences", + "us_foreign_policy": "social_sciences", + "virology": "other", + "world_religions": "humanities", +} + + +class MMLUEval(Eval): + def __init__(self, filename: str, num_examples: int | None, num_threads: int): + df = pandas.read_csv(filename) + examples = [row.to_dict() for _, row in df.iterrows()] + if num_examples: + examples = random.Random(0).sample(examples, num_examples) + self.examples = examples + self.num_threads = num_threads + + def __call__(self, sampler: SamplerBase) -> EvalResult: + def fn(row: dict): + prompt_messages = [ + sampler._pack_message( + content=format_multichoice_question(row), role="user" + ) + ] + response_text = sampler(prompt_messages) + match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text) + extracted_answer = match.group(1) if match else None + score = 1.0 if extracted_answer == row["Answer"] else 0.0 + html = common.jinja_env.from_string(HTML_JINJA).render( + prompt_messages=prompt_messages, + next_message=dict(content=response_text, role="assistant"), + score=score, + correct_answer=row["Answer"], + extracted_answer=extracted_answer, + ) + convo = prompt_messages + [dict(content=response_text, role="assistant")] + category = subject2category.get(row["Subject"], "other") + return SingleEvalResult( + html=html, score=score, metrics={category: score}, convo=convo + ) + + results = common.map_with_progress(fn, self.examples, self.num_threads) + return common.aggregate_results(results) diff --git a/python/sglang/test/test_conversation.py b/python/sglang/test/test_conversation.py deleted file mode 100644 index e6d9f396a..000000000 --- a/python/sglang/test/test_conversation.py +++ /dev/null @@ -1,46 +0,0 @@ -from sglang.srt.conversation import generate_chat_conv -from sglang.srt.managers.openai_api.protocol import ( - ChatCompletionMessageContentImagePart, - ChatCompletionMessageContentImageURL, - ChatCompletionMessageContentTextPart, - ChatCompletionMessageGenericParam, - ChatCompletionMessageUserParam, - ChatCompletionRequest, -) - - -def test_chat_completion_to_conv_image(): - """Test that we can convert a chat image request to a convo""" - request = ChatCompletionRequest( - model="default", - messages=[ - ChatCompletionMessageGenericParam( - role="system", content="You are a helpful AI assistant" - ), - ChatCompletionMessageUserParam( - role="user", - content=[ - ChatCompletionMessageContentTextPart( - type="text", text="Describe this image" - ), - ChatCompletionMessageContentImagePart( - type="image_url", - image_url=ChatCompletionMessageContentImageURL( - url="https://someurl.com" - ), - ), - ], - ), - ], - ) - conv = generate_chat_conv(request, "vicuna_v1.1") - assert conv.messages == [ - ["USER", "Describe this image"], - ["ASSISTANT", None], - ] - assert conv.system_message == "You are a helpful AI assistant" - assert conv.image_data == ["https://someurl.com"] - - -if __name__ == "__main__": - test_chat_completion_to_conv_image() diff --git a/python/sglang/test/test_openai_protocol.py b/python/sglang/test/test_openai_protocol.py deleted file mode 100644 index cade4728c..000000000 --- a/python/sglang/test/test_openai_protocol.py +++ /dev/null @@ -1,51 +0,0 @@ -from sglang.srt.managers.openai_api.protocol import ( - ChatCompletionMessageContentImagePart, - ChatCompletionMessageContentImageURL, - ChatCompletionMessageContentTextPart, - ChatCompletionMessageGenericParam, - ChatCompletionMessageUserParam, - ChatCompletionRequest, -) - - -def test_chat_completion_request_image(): - """Test that Chat Completion Requests with images can be converted.""" - - image_request = { - "model": "default", - "messages": [ - {"role": "system", "content": "You are a helpful AI assistant"}, - { - "role": "user", - "content": [ - {"type": "text", "text": "Describe this image"}, - {"type": "image_url", "image_url": {"url": "https://someurl.com"}}, - ], - }, - ], - "temperature": 0, - "max_tokens": 64, - } - request = ChatCompletionRequest(**image_request) - assert len(request.messages) == 2 - assert request.messages[0] == ChatCompletionMessageGenericParam( - role="system", content="You are a helpful AI assistant" - ) - assert request.messages[1] == ChatCompletionMessageUserParam( - role="user", - content=[ - ChatCompletionMessageContentTextPart( - type="text", text="Describe this image" - ), - ChatCompletionMessageContentImagePart( - type="image_url", - image_url=ChatCompletionMessageContentImageURL( - url="https://someurl.com" - ), - ), - ], - ) - - -if __name__ == "__main__": - test_chat_completion_request_image() diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index af7f3765e..4348b57e9 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1,6 +1,8 @@ """Common utilities for testing and benchmarking""" import asyncio +import subprocess +import time from functools import partial import numpy as np @@ -11,6 +13,8 @@ from sglang.lang.backend.openai import OpenAI from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.utils import get_exception_traceback +MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct" + def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): assert url is not None @@ -379,3 +383,31 @@ def get_call_select(args): raise return func + + +def popen_launch_server(model, port, timeout, *args): + command = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + "localhost", + "--port", + str(port), + *args, + ] + process = subprocess.Popen(command, stdout=None, stderr=None) + base_url = f"http://localhost:{port}/v1" + + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"{base_url}/models") + if response.status_code == 200: + return process + except requests.RequestException: + pass + time.sleep(10) + raise TimeoutError("Server failed to start within the timeout period.") diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index f9d79ed29..7accd349f 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -14,6 +14,7 @@ from sglang.test.test_programs import ( test_stream, test_tool_use, ) +from sglang.test.test_utils import MODEL_NAME_FOR_TEST class TestSRTBackend(unittest.TestCase): @@ -21,7 +22,7 @@ class TestSRTBackend(unittest.TestCase): @classmethod def setUpClass(cls): - cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct") + cls.backend = sgl.Runtime(model_path=MODEL_NAME_FOR_TEST) sgl.set_default_backend(cls.backend) @classmethod diff --git a/test/srt/old/test_curl.sh b/test/srt/deprecated/test_curl.sh similarity index 100% rename from test/srt/old/test_curl.sh rename to test/srt/deprecated/test_curl.sh diff --git a/test/srt/old/test_flashinfer.py b/test/srt/deprecated/test_flashinfer.py similarity index 100% rename from test/srt/old/test_flashinfer.py rename to test/srt/deprecated/test_flashinfer.py diff --git a/test/srt/old/test_httpserver_classify.py b/test/srt/deprecated/test_httpserver_classify.py similarity index 100% rename from test/srt/old/test_httpserver_classify.py rename to test/srt/deprecated/test_httpserver_classify.py diff --git a/test/srt/old/test_httpserver_concurrent.py b/test/srt/deprecated/test_httpserver_concurrent.py similarity index 100% rename from test/srt/old/test_httpserver_concurrent.py rename to test/srt/deprecated/test_httpserver_concurrent.py diff --git a/test/srt/old/test_httpserver_decode.py b/test/srt/deprecated/test_httpserver_decode.py similarity index 100% rename from test/srt/old/test_httpserver_decode.py rename to test/srt/deprecated/test_httpserver_decode.py diff --git a/test/srt/old/test_httpserver_decode_stream.py b/test/srt/deprecated/test_httpserver_decode_stream.py similarity index 100% rename from test/srt/old/test_httpserver_decode_stream.py rename to test/srt/deprecated/test_httpserver_decode_stream.py diff --git a/test/srt/old/test_httpserver_llava.py b/test/srt/deprecated/test_httpserver_llava.py similarity index 100% rename from test/srt/old/test_httpserver_llava.py rename to test/srt/deprecated/test_httpserver_llava.py diff --git a/test/srt/old/test_httpserver_reuse.py b/test/srt/deprecated/test_httpserver_reuse.py similarity index 100% rename from test/srt/old/test_httpserver_reuse.py rename to test/srt/deprecated/test_httpserver_reuse.py diff --git a/test/srt/old/test_jump_forward.py b/test/srt/deprecated/test_jump_forward.py similarity index 100% rename from test/srt/old/test_jump_forward.py rename to test/srt/deprecated/test_jump_forward.py diff --git a/test/srt/old/test_openai_server.py b/test/srt/deprecated/test_openai_server.py similarity index 100% rename from test/srt/old/test_openai_server.py rename to test/srt/deprecated/test_openai_server.py diff --git a/test/srt/old/test_robust.py b/test/srt/deprecated/test_robust.py similarity index 100% rename from test/srt/old/test_robust.py rename to test/srt/deprecated/test_robust.py diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py new file mode 100644 index 000000000..d392dc4c0 --- /dev/null +++ b/test/srt/test_eval_accuracy.py @@ -0,0 +1,43 @@ +import json +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server + + +class TestAccuracy(unittest.TestCase): + + @classmethod + def setUpClass(cls): + port = 30000 + + cls.model = MODEL_NAME_FOR_TEST + cls.base_url = f"http://localhost:{port}" + cls.process = popen_launch_server(cls.model, port, timeout=300) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=20, + num_threads=20, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # t = TestAccuracy() + # t.setUpClass() + # t.test_mmlu() + # t.tearDownClass() diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index e15c2ba88..76a105a62 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -1,47 +1,21 @@ import json -import subprocess -import time import unittest import openai -import requests from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server class TestOpenAIServer(unittest.TestCase): @classmethod def setUpClass(cls): - model = "meta-llama/Meta-Llama-3.1-8B-Instruct" port = 30000 - timeout = 300 - command = [ - "python3", - "-m", - "sglang.launch_server", - "--model-path", - model, - "--host", - "localhost", - "--port", - str(port), - ] - cls.process = subprocess.Popen(command, stdout=None, stderr=None) + cls.model = MODEL_NAME_FOR_TEST cls.base_url = f"http://localhost:{port}/v1" - cls.model = model - - start_time = time.time() - while time.time() - start_time < timeout: - try: - response = requests.get(f"{cls.base_url}/models") - if response.status_code == 200: - return - except requests.RequestException: - pass - time.sleep(10) - raise TimeoutError("Server failed to start within the timeout period.") + cls.process = popen_launch_server(cls.model, port, timeout=300) @classmethod def tearDownClass(cls): @@ -178,8 +152,6 @@ class TestOpenAIServer(unittest.TestCase): is_first = True for response in generator: - print(response) - data = response.choices[0].delta if is_first: data.role == "assistant" diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py new file mode 100644 index 000000000..345467858 --- /dev/null +++ b/test/srt/test_srt_endpoint.py @@ -0,0 +1,64 @@ +import json +import unittest + +import requests + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server + + +class TestSRTEndpoint(unittest.TestCase): + + @classmethod + def setUpClass(cls): + port = 30000 + + cls.model = MODEL_NAME_FOR_TEST + cls.base_url = f"http://localhost:{port}" + cls.process = popen_launch_server(cls.model, port, timeout=300) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def run_decode( + self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1 + ): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": 32, + "n": n, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "return_text_in_logprobs": return_text, + "logprob_start_len": 0, + }, + ) + print(json.dumps(response.json())) + print("=" * 100) + + def test_simple_decode(self): + self.run_decode() + + def test_parallel_sample(self): + self.run_decode(n=3) + + def test_logprob(self): + for top_logprobs_num in [0, 3]: + for return_text in [True, False]: + self.run_decode( + return_logprob=True, + top_logprobs_num=top_logprobs_num, + return_text=return_text, + ) + + +if __name__ == "__main__": + unittest.main(warnings="ignore")