diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml
index 9630ca718..7b59054fe 100644
--- a/.github/workflows/e2e-test.yml
+++ b/.github/workflows/e2e-test.yml
@@ -18,7 +18,7 @@ concurrency:
cancel-in-progress: true
jobs:
- pr-e2e-test:
+ e2e-test:
runs-on: self-hosted
env:
@@ -38,7 +38,7 @@ jobs:
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall
pip install --upgrade transformers
- - name: Benchmark Serving
+ - name: Benchmark Serving Throughput
run: |
cd /data/zhyncs/venv && source ./bin/activate && cd -
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --port 8413 --disable-radix-cache &
diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml
index dc464fa8c..f1c069ea5 100644
--- a/.github/workflows/unit-test.yml
+++ b/.github/workflows/unit-test.yml
@@ -59,3 +59,10 @@ jobs:
cd test/srt
python3 test_openai_server.py
+
+ - name: Test Accuracy
+ run: |
+ cd /data/zhyncs/venv && source ./bin/activate && cd -
+
+ cd test/srt
+ python3 test_eval_accuracy.py
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index b52e114fd..253aab355 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -21,7 +21,7 @@ import sys
import time
import traceback
import warnings
-from argparse import ArgumentParser as FlexibleArgumentParser
+from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple, Union
@@ -868,14 +868,12 @@ def set_ulimit(target_soft_limit=65535):
if __name__ == "__main__":
- parser = FlexibleArgumentParser(
- description="Benchmark the online serving throughput."
- )
+ parser = ArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
type=str,
- required=True,
choices=list(ASYNC_REQUEST_FUNCS.keys()),
+ default="sglang",
help="Must specify a backend, depending on the LLM Inference Engine.",
)
parser.add_argument(
diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
new file mode 100644
index 000000000..3729ef7ab
--- /dev/null
+++ b/python/sglang/test/run_eval.py
@@ -0,0 +1,99 @@
+"""
+Usage:
+python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
+"""
+
+import argparse
+import json
+import os
+import time
+
+from sglang.test.simple_eval_common import (
+ ChatCompletionSampler,
+ download_dataset,
+ make_report,
+ set_ulimit,
+)
+from sglang.test.simple_eval_mmlu import MMLUEval
+
+
+def run_eval(args):
+ if "OPENAI_API_KEY" not in os.environ:
+ os.environ["OPENAI_API_KEY"] = "EMPTY"
+
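+    # Prefer an explicit --base-url; otherwise build one from --host and --port.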
+ base_url = (
+ f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
+ )
+
+ if args.eval_name == "mmlu":
+ dataset_path = "mmlu.csv"
+
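+        # Download the MMLU CSV from the openaipublic simple-evals bucket on first use.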
+ if not os.path.exists(dataset_path):
+ download_dataset(
+ dataset_path,
+ "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
+ )
+ eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
+ else:
+ raise ValueError(f"Invalid eval name: {args.eval_name}")
+
+ sampler = ChatCompletionSampler(
+ model=args.model,
+ max_tokens=2048,
+ base_url=base_url,
+ )
+
+ # Run eval
+ tic = time.time()
+ result = eval_obj(sampler)
+ latency = time.time() - tic
+
+ # Dump reports
+ metrics = result.metrics | {"score": result.score}
+ file_stem = f"mmlu_{sampler.model.replace('/', '_')}"
+ report_filename = f"/tmp/{file_stem}.html"
+ print(f"Writing report to {report_filename}")
+ with open(report_filename, "w") as fh:
+ fh.write(make_report(result))
+ metrics = result.metrics | {"score": result.score}
+ print(metrics)
+ result_filename = f"/tmp/{file_stem}.json"
+ with open(result_filename, "w") as f:
+ f.write(json.dumps(metrics, indent=2))
+ print(f"Writing results to {result_filename}")
+
+ # Print results
+ print(f"Total latency: {latency:.3f} s")
+ print(f"Score: {metrics['score']:.3f}")
+
+ return metrics
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--base-url",
+ type=str,
+ default=None,
+ help="Server or API base url if not using http host and port.",
+ )
+ parser.add_argument(
+ "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+ )
+ parser.add_argument(
+ "--port",
+ type=int,
+ help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
+ )
+ parser.add_argument("--eval-name", type=str, default="mmlu")
+ parser.add_argument("--num-examples", type=int)
+ parser.add_argument("--num-threads", type=int, default=64)
+ set_ulimit()
+ args = parser.parse_args()
+
+ run_eval(args)
diff --git a/python/sglang/test/simple_eval_common.py b/python/sglang/test/simple_eval_common.py
new file mode 100644
index 000000000..75c26f0f0
--- /dev/null
+++ b/python/sglang/test/simple_eval_common.py
@@ -0,0 +1,456 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+import base64
+import os
+import resource
+import time
+from collections import defaultdict
+from dataclasses import dataclass, field
+from multiprocessing.pool import ThreadPool
+from typing import Any
+
+import jinja2
+import numpy as np
+import openai
+import requests
+from openai import OpenAI
+from tqdm import tqdm
+
+OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
+OPENAI_SYSTEM_MESSAGE_CHATGPT = (
+ "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
+ + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
+)
+
+
+Message = dict[str, Any] # keys role, content
+MessageList = list[Message]
+
+
+class SamplerBase:
+ """
+ Base class for defining a sampling model, which can be evaluated,
+ or used as part of the grading process.
+ """
+
+ def __call__(self, message_list: MessageList) -> str:
+ raise NotImplementedError()
+
+
+@dataclass
+class EvalResult:
+ """
+ Result of running an evaluation (usually consisting of many samples)
+ """
+
+ score: float | None # top-line metric
+ metrics: dict[str, float] | None # other metrics
+ htmls: list[str] # strings of valid HTML
+ convos: list[MessageList] # sampled conversations
+
+
+@dataclass
+class SingleEvalResult:
+ """
+ Result of evaluating a single sample
+ """
+
+ score: float | None
+ metrics: dict[str, float] = field(default_factory=dict)
+ html: str | None = None
+ convo: MessageList | None = None # sampled conversation
+
+
+class Eval:
+ """
+ Base class for defining an evaluation.
+ """
+
+ def __call__(self, sampler: SamplerBase) -> EvalResult:
+ raise NotImplementedError()
+
+
+class ChatCompletionSampler(SamplerBase):
+ """
+ Sample from OpenAI's chat completion API
+ """
+
+ def __init__(
+ self,
+        base_url: str | None = None,
+ model: str | None = None,
+ system_message: str | None = None,
+ temperature: float = 0.0,
+ max_tokens: int = 2048,
+ ):
+ self.client = OpenAI(base_url=base_url)
+
+ if model is None:
+ model = self.client.models.list().data[0].id
+
+ self.model = model
+ self.system_message = system_message
+ self.temperature = temperature
+ self.max_tokens = max_tokens
+ self.image_format = "url"
+
+ def _handle_image(
+ self,
+ image: str,
+ encoding: str = "base64",
+ format: str = "png",
+ fovea: int = 768,
+ ):
+ new_image = {
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/{format};{encoding},{image}",
+ },
+ }
+ return new_image
+
+ def _handle_text(self, text: str):
+ return {"type": "text", "text": text}
+
+ def _pack_message(self, role: str, content: Any):
+ return {"role": str(role), "content": content}
+
+ def __call__(self, message_list: MessageList) -> str:
+ if self.system_message:
+ message_list = [
+ self._pack_message("system", self.system_message)
+ ] + message_list
+ trial = 0
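+        # Retry with exponential backoff; bad requests return an empty completion instead.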
+ while True:
+ try:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=message_list,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+ return response.choices[0].message.content
+            # NOTE: BadRequestError is triggered once for MMMU; please uncomment if you are rerunning MMMU
+ except openai.BadRequestError as e:
+ print("Bad Request Error", e)
+ return ""
+ except Exception as e:
+                exception_backoff = 2**trial  # exponential backoff
+                print(
+                    f"Rate limit exception, waiting {exception_backoff} sec before retry {trial}",
+ e,
+ )
+ time.sleep(exception_backoff)
+ trial += 1
+            # other unknown errors are retried with exponential backoff
+
+
+QUERY_TEMPLATE_MULTICHOICE = """
+Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
+
+{Question}
+
+A) {A}
+B) {B}
+C) {C}
+D) {D}
+""".strip()
+
+ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
+ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
+
+
+EQUALITY_TEMPLATE = r"""
+Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
+
+Examples:
+
+ Expression 1: $2x+3$
+ Expression 2: $3+2x$
+
+Yes
+
+ Expression 1: 3/2
+ Expression 2: 1.5
+
+Yes
+
+ Expression 1: $x^2+2x+1$
+ Expression 2: $y^2+2y+1$
+
+No
+
+ Expression 1: $x^2+2x+1$
+ Expression 2: $(x+1)^2$
+
+Yes
+
+ Expression 1: 3245/5
+ Expression 2: 649
+
+No
+(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
+
+ Expression 1: 2/(-3)
+ Expression 2: -2/3
+
+Yes
+(trivial simplifications are allowed)
+
+ Expression 1: 72 degrees
+ Expression 2: 72
+
+Yes
+(give benefit of the doubt to units)
+
+ Expression 1: 64
+ Expression 2: 64 square feet
+
+Yes
+(give benefit of the doubt to units)
+
+---
+
+YOUR TASK
+
+
+Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
+
+ Expression 1: %(expression1)s
+ Expression 2: %(expression2)s
+""".strip()
+
+
+HTML_JINJA = """
+
Prompt conversation
+{% for message in prompt_messages %}
+{{ message_to_html(message) | safe }}
+{% endfor %}
+Sampled message
+{{ message_to_html(next_message) | safe }}
+Results
+Correct Answer: {{ correct_answer }}
+Extracted Answer: {{ extracted_answer }}
+Score: {{ score }}
+"""
+
+
+def format_multichoice_question(row):
+ return QUERY_TEMPLATE_MULTICHOICE.format(**row)
+
+
+def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
+ prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
+ response = sampler([dict(content=prompt, role="user")])
+ return response.lower().strip() == "yes"
+
+
+def _compute_stat(values: list, stat: str):
+ if stat == "mean":
+ return np.mean(values)
+ elif stat == "std":
+ return np.std(values)
+ elif stat == "min":
+ return np.min(values)
+ elif stat == "max":
+ return np.max(values)
+ else:
+ raise ValueError(f"Unknown {stat =}")
+
+
+def aggregate_results(
+ single_eval_results: list[SingleEvalResult],
+ default_stats: tuple[str] = ("mean", "std"),
+ name2stats: dict[str, tuple[str]] | None = None,
+) -> EvalResult:
+ """
+ Aggregate results from multiple evaluations into a single EvalResult.
+ """
+ name2stats = name2stats or {}
+ name2values = defaultdict(list)
+ htmls = []
+ convos = []
+ for single_eval_result in single_eval_results:
+ for name, value in single_eval_result.metrics.items():
+ name2values[name].append(value)
+ if single_eval_result.score is not None:
+ name2values["score"].append(single_eval_result.score)
+ htmls.append(single_eval_result.html)
+ convos.append(single_eval_result.convo)
+ final_metrics = {}
+ for name, values in name2values.items():
+ stats = name2stats.get(name, default_stats)
+ for stat in stats:
+ key = name if stat == "mean" else f"{name}:{stat}"
+ final_metrics[key] = _compute_stat(values, stat)
+ return EvalResult(
+ score=final_metrics.pop("score", None),
+ metrics=final_metrics,
+ htmls=htmls,
+ convos=convos,
+ )
+
+
+def map_with_progress(f: callable, xs: list[Any], num_threads: int):
+ """
+ Apply f to each element of xs, using a ThreadPool, and show progress.
+ """
+ if os.getenv("debug"):
+ return list(map(f, tqdm(xs, total=len(xs))))
+ else:
+ with ThreadPool(min(num_threads, len(xs))) as pool:
+ return list(tqdm(pool.imap(f, xs), total=len(xs)))
+
+
+jinja_env = jinja2.Environment(
+ loader=jinja2.BaseLoader(),
+ undefined=jinja2.StrictUndefined,
+ autoescape=jinja2.select_autoescape(["html", "xml"]),
+)
+_message_template = """
+
+
+ {{ role }}
+ {% if variant %}({{ variant }}){% endif %}
+
+
+
+"""
+
+
+def message_to_html(message: Message) -> str:
+ """
+ Generate HTML snippet (inside a ) for a message.
+ """
+ return jinja_env.from_string(_message_template).render(
+ role=message["role"],
+ content=message["content"],
+ variant=message.get("variant", None),
+ )
+
+
+jinja_env.globals["message_to_html"] = message_to_html
+
+
+_report_template = """
+
+
+
+
+
+ {% if metrics %}
+
Metrics
+
+
+ | Metric |
+ Value |
+
+
+ | Score |
+ {{ score | float | round(3) }} |
+
+ {% for name, value in metrics.items() %}
+
+ | {{ name }} |
+ {{ value }} |
+
+ {% endfor %}
+
+ {% endif %}
+
Examples
+ {% for html in htmls %}
+ {{ html | safe }}
+
+ {% endfor %}
+
+
+"""
+
+
+def make_report(eval_result: EvalResult) -> str:
+ """
+ Create a standalone HTML report from an EvalResult.
+ """
+ return jinja_env.from_string(_report_template).render(
+ score=eval_result.score,
+ metrics=eval_result.metrics,
+ htmls=eval_result.htmls,
+ )
+
+
+def make_report_from_example_htmls(htmls: list[str]):
+ """
+ Create a standalone HTML report from a list of example htmls
+ """
+ return jinja_env.from_string(_report_template).render(
+ score=None, metrics={}, htmls=htmls
+ )
+
+
+def download_dataset(path, url):
+ print(f"Downloading dataset {path} from {url}")
+ try:
+ response = requests.get(url, stream=True)
+ response.raise_for_status()
+
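+        # Stream the response to disk in fixed-size chunks, with a tqdm progress bar.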
+ total_size = int(response.headers.get("content-length", 0))
+ block_size = 8192
+
+ with open(path, "wb") as f, tqdm(
+ desc="Downloading",
+ total=total_size,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as progress_bar:
+ for data in response.iter_content(block_size):
+ size = f.write(data)
+ progress_bar.update(size)
+
+ print(f"Dataset downloaded and saved to {path}")
+ except requests.RequestException as e:
+ raise Exception(f"Failed to download dataset: {e}")
+
+
+def set_ulimit(target_soft_limit=65535):
+ resource_type = resource.RLIMIT_NOFILE
+ current_soft, current_hard = resource.getrlimit(resource_type)
+
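+    # Raise the soft limit on open file descriptors so many concurrent requests do not fail.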
+ if current_soft < target_soft_limit:
+ try:
+ resource.setrlimit(resource_type, (target_soft_limit, current_hard))
+ except ValueError as e:
+ print(f"Fail to set RLIMIT_NOFILE: {e}")
diff --git a/python/sglang/test/simple_eval_mmlu.py b/python/sglang/test/simple_eval_mmlu.py
new file mode 100644
index 000000000..3c0287510
--- /dev/null
+++ b/python/sglang/test/simple_eval_mmlu.py
@@ -0,0 +1,120 @@
+# Adapted from https://github.com/openai/simple-evals/
+
+"""
+Measuring Massive Multitask Language Understanding
+Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
+https://arxiv.org/abs/2009.03300
+"""
+
+import random
+import re
+
+import pandas
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+ ANSWER_PATTERN_MULTICHOICE,
+ HTML_JINJA,
+ Eval,
+ EvalResult,
+ SamplerBase,
+ SingleEvalResult,
+ format_multichoice_question,
+)
+
+subject2category = {
+ "abstract_algebra": "stem",
+ "anatomy": "other",
+ "astronomy": "stem",
+ "business_ethics": "other",
+ "clinical_knowledge": "other",
+ "college_biology": "stem",
+ "college_chemistry": "stem",
+ "college_computer_science": "stem",
+ "college_mathematics": "stem",
+ "college_medicine": "other",
+ "college_physics": "stem",
+ "computer_security": "stem",
+ "conceptual_physics": "stem",
+ "econometrics": "social_sciences",
+ "electrical_engineering": "stem",
+ "elementary_mathematics": "stem",
+ "formal_logic": "humanities",
+ "global_facts": "other",
+ "high_school_biology": "stem",
+ "high_school_chemistry": "stem",
+ "high_school_computer_science": "stem",
+ "high_school_european_history": "humanities",
+ "high_school_geography": "social_sciences",
+ "high_school_government_and_politics": "social_sciences",
+ "high_school_macroeconomics": "social_sciences",
+ "high_school_mathematics": "stem",
+ "high_school_microeconomics": "social_sciences",
+ "high_school_physics": "stem",
+ "high_school_psychology": "social_sciences",
+ "high_school_statistics": "stem",
+ "high_school_us_history": "humanities",
+ "high_school_world_history": "humanities",
+ "human_aging": "other",
+ "human_sexuality": "social_sciences",
+ "international_law": "humanities",
+ "jurisprudence": "humanities",
+ "logical_fallacies": "humanities",
+ "machine_learning": "stem",
+ "management": "other",
+ "marketing": "other",
+ "medical_genetics": "other",
+ "miscellaneous": "other",
+ "moral_disputes": "humanities",
+ "moral_scenarios": "humanities",
+ "nutrition": "other",
+ "philosophy": "humanities",
+ "prehistory": "humanities",
+ "professional_accounting": "other",
+ "professional_law": "humanities",
+ "professional_medicine": "other",
+ "professional_psychology": "social_sciences",
+ "public_relations": "social_sciences",
+ "security_studies": "social_sciences",
+ "sociology": "social_sciences",
+ "us_foreign_policy": "social_sciences",
+ "virology": "other",
+ "world_religions": "humanities",
+}
+
+
+class MMLUEval(Eval):
+ def __init__(self, filename: str, num_examples: int | None, num_threads: int):
+ df = pandas.read_csv(filename)
+ examples = [row.to_dict() for _, row in df.iterrows()]
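+        # Subsample with a fixed seed so repeated runs evaluate the same examples.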
+ if num_examples:
+ examples = random.Random(0).sample(examples, num_examples)
+ self.examples = examples
+ self.num_threads = num_threads
+
+ def __call__(self, sampler: SamplerBase) -> EvalResult:
+ def fn(row: dict):
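+            # Grade one example: prompt the model, extract the answer letter, and score it.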
+ prompt_messages = [
+ sampler._pack_message(
+ content=format_multichoice_question(row), role="user"
+ )
+ ]
+ response_text = sampler(prompt_messages)
+ match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
+ extracted_answer = match.group(1) if match else None
+ score = 1.0 if extracted_answer == row["Answer"] else 0.0
+ html = common.jinja_env.from_string(HTML_JINJA).render(
+ prompt_messages=prompt_messages,
+ next_message=dict(content=response_text, role="assistant"),
+ score=score,
+ correct_answer=row["Answer"],
+ extracted_answer=extracted_answer,
+ )
+ convo = prompt_messages + [dict(content=response_text, role="assistant")]
+ category = subject2category.get(row["Subject"], "other")
+ return SingleEvalResult(
+ html=html, score=score, metrics={category: score}, convo=convo
+ )
+
+ results = common.map_with_progress(fn, self.examples, self.num_threads)
+ return common.aggregate_results(results)
diff --git a/python/sglang/test/test_conversation.py b/python/sglang/test/test_conversation.py
deleted file mode 100644
index e6d9f396a..000000000
--- a/python/sglang/test/test_conversation.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.managers.openai_api.protocol import (
- ChatCompletionMessageContentImagePart,
- ChatCompletionMessageContentImageURL,
- ChatCompletionMessageContentTextPart,
- ChatCompletionMessageGenericParam,
- ChatCompletionMessageUserParam,
- ChatCompletionRequest,
-)
-
-
-def test_chat_completion_to_conv_image():
- """Test that we can convert a chat image request to a convo"""
- request = ChatCompletionRequest(
- model="default",
- messages=[
- ChatCompletionMessageGenericParam(
- role="system", content="You are a helpful AI assistant"
- ),
- ChatCompletionMessageUserParam(
- role="user",
- content=[
- ChatCompletionMessageContentTextPart(
- type="text", text="Describe this image"
- ),
- ChatCompletionMessageContentImagePart(
- type="image_url",
- image_url=ChatCompletionMessageContentImageURL(
- url="https://someurl.com"
- ),
- ),
- ],
- ),
- ],
- )
- conv = generate_chat_conv(request, "vicuna_v1.1")
- assert conv.messages == [
-        ["USER", "Describe this image <image>"],
- ["ASSISTANT", None],
- ]
- assert conv.system_message == "You are a helpful AI assistant"
- assert conv.image_data == ["https://someurl.com"]
-
-
-if __name__ == "__main__":
- test_chat_completion_to_conv_image()
diff --git a/python/sglang/test/test_openai_protocol.py b/python/sglang/test/test_openai_protocol.py
deleted file mode 100644
index cade4728c..000000000
--- a/python/sglang/test/test_openai_protocol.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from sglang.srt.managers.openai_api.protocol import (
- ChatCompletionMessageContentImagePart,
- ChatCompletionMessageContentImageURL,
- ChatCompletionMessageContentTextPart,
- ChatCompletionMessageGenericParam,
- ChatCompletionMessageUserParam,
- ChatCompletionRequest,
-)
-
-
-def test_chat_completion_request_image():
- """Test that Chat Completion Requests with images can be converted."""
-
- image_request = {
- "model": "default",
- "messages": [
- {"role": "system", "content": "You are a helpful AI assistant"},
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Describe this image"},
- {"type": "image_url", "image_url": {"url": "https://someurl.com"}},
- ],
- },
- ],
- "temperature": 0,
- "max_tokens": 64,
- }
- request = ChatCompletionRequest(**image_request)
- assert len(request.messages) == 2
- assert request.messages[0] == ChatCompletionMessageGenericParam(
- role="system", content="You are a helpful AI assistant"
- )
- assert request.messages[1] == ChatCompletionMessageUserParam(
- role="user",
- content=[
- ChatCompletionMessageContentTextPart(
- type="text", text="Describe this image"
- ),
- ChatCompletionMessageContentImagePart(
- type="image_url",
- image_url=ChatCompletionMessageContentImageURL(
- url="https://someurl.com"
- ),
- ),
- ],
- )
-
-
-if __name__ == "__main__":
- test_chat_completion_request_image()
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index af7f3765e..4348b57e9 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -1,6 +1,8 @@
"""Common utilities for testing and benchmarking"""
import asyncio
+import subprocess
+import time
from functools import partial
import numpy as np
@@ -11,6 +13,8 @@ from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import get_exception_traceback
+MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
assert url is not None
@@ -379,3 +383,31 @@ def get_call_select(args):
raise
return func
+
+
+def popen_launch_server(model, port, timeout, *args):
+ command = [
+ "python3",
+ "-m",
+ "sglang.launch_server",
+ "--model-path",
+ model,
+ "--host",
+ "localhost",
+ "--port",
+ str(port),
+ *args,
+ ]
+ process = subprocess.Popen(command, stdout=None, stderr=None)
+ base_url = f"http://localhost:{port}/v1"
+
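+    # Poll the OpenAI-compatible /models endpoint until the server is up or the timeout expires.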
+ start_time = time.time()
+ while time.time() - start_time < timeout:
+ try:
+ response = requests.get(f"{base_url}/models")
+ if response.status_code == 200:
+ return process
+ except requests.RequestException:
+ pass
+ time.sleep(10)
+ raise TimeoutError("Server failed to start within the timeout period.")
diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py
index f9d79ed29..7accd349f 100644
--- a/test/lang/test_srt_backend.py
+++ b/test/lang/test_srt_backend.py
@@ -14,6 +14,7 @@ from sglang.test.test_programs import (
test_stream,
test_tool_use,
)
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST
class TestSRTBackend(unittest.TestCase):
@@ -21,7 +22,7 @@ class TestSRTBackend(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+ cls.backend = sgl.Runtime(model_path=MODEL_NAME_FOR_TEST)
sgl.set_default_backend(cls.backend)
@classmethod
diff --git a/test/srt/old/test_curl.sh b/test/srt/deprecated/test_curl.sh
similarity index 100%
rename from test/srt/old/test_curl.sh
rename to test/srt/deprecated/test_curl.sh
diff --git a/test/srt/old/test_flashinfer.py b/test/srt/deprecated/test_flashinfer.py
similarity index 100%
rename from test/srt/old/test_flashinfer.py
rename to test/srt/deprecated/test_flashinfer.py
diff --git a/test/srt/old/test_httpserver_classify.py b/test/srt/deprecated/test_httpserver_classify.py
similarity index 100%
rename from test/srt/old/test_httpserver_classify.py
rename to test/srt/deprecated/test_httpserver_classify.py
diff --git a/test/srt/old/test_httpserver_concurrent.py b/test/srt/deprecated/test_httpserver_concurrent.py
similarity index 100%
rename from test/srt/old/test_httpserver_concurrent.py
rename to test/srt/deprecated/test_httpserver_concurrent.py
diff --git a/test/srt/old/test_httpserver_decode.py b/test/srt/deprecated/test_httpserver_decode.py
similarity index 100%
rename from test/srt/old/test_httpserver_decode.py
rename to test/srt/deprecated/test_httpserver_decode.py
diff --git a/test/srt/old/test_httpserver_decode_stream.py b/test/srt/deprecated/test_httpserver_decode_stream.py
similarity index 100%
rename from test/srt/old/test_httpserver_decode_stream.py
rename to test/srt/deprecated/test_httpserver_decode_stream.py
diff --git a/test/srt/old/test_httpserver_llava.py b/test/srt/deprecated/test_httpserver_llava.py
similarity index 100%
rename from test/srt/old/test_httpserver_llava.py
rename to test/srt/deprecated/test_httpserver_llava.py
diff --git a/test/srt/old/test_httpserver_reuse.py b/test/srt/deprecated/test_httpserver_reuse.py
similarity index 100%
rename from test/srt/old/test_httpserver_reuse.py
rename to test/srt/deprecated/test_httpserver_reuse.py
diff --git a/test/srt/old/test_jump_forward.py b/test/srt/deprecated/test_jump_forward.py
similarity index 100%
rename from test/srt/old/test_jump_forward.py
rename to test/srt/deprecated/test_jump_forward.py
diff --git a/test/srt/old/test_openai_server.py b/test/srt/deprecated/test_openai_server.py
similarity index 100%
rename from test/srt/old/test_openai_server.py
rename to test/srt/deprecated/test_openai_server.py
diff --git a/test/srt/old/test_robust.py b/test/srt/deprecated/test_robust.py
similarity index 100%
rename from test/srt/old/test_robust.py
rename to test/srt/deprecated/test_robust.py
diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py
new file mode 100644
index 000000000..d392dc4c0
--- /dev/null
+++ b/test/srt/test_eval_accuracy.py
@@ -0,0 +1,43 @@
+import json
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestAccuracy(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ port = 30000
+
+ cls.model = MODEL_NAME_FOR_TEST
+ cls.base_url = f"http://localhost:{port}"
+ cls.process = popen_launch_server(cls.model, port, timeout=300)
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_child_process(cls.process.pid)
+
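+    # A modest score threshold keeps this a quick correctness smoke test rather than a benchmark.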
+ def test_mmlu(self):
+ args = SimpleNamespace(
+ base_url=self.base_url,
+ model=self.model,
+ eval_name="mmlu",
+ num_examples=20,
+ num_threads=20,
+ )
+
+ metrics = run_eval(args)
+ assert metrics["score"] >= 0.5
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
+
+ # t = TestAccuracy()
+ # t.setUpClass()
+ # t.test_mmlu()
+ # t.tearDownClass()
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index e15c2ba88..76a105a62 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -1,47 +1,21 @@
import json
-import subprocess
-import time
import unittest
import openai
-import requests
from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
class TestOpenAIServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
- model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
port = 30000
- timeout = 300
- command = [
- "python3",
- "-m",
- "sglang.launch_server",
- "--model-path",
- model,
- "--host",
- "localhost",
- "--port",
- str(port),
- ]
- cls.process = subprocess.Popen(command, stdout=None, stderr=None)
+ cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:{port}/v1"
- cls.model = model
-
- start_time = time.time()
- while time.time() - start_time < timeout:
- try:
- response = requests.get(f"{cls.base_url}/models")
- if response.status_code == 200:
- return
- except requests.RequestException:
- pass
- time.sleep(10)
- raise TimeoutError("Server failed to start within the timeout period.")
+ cls.process = popen_launch_server(cls.model, port, timeout=300)
@classmethod
def tearDownClass(cls):
@@ -178,8 +152,6 @@ class TestOpenAIServer(unittest.TestCase):
is_first = True
for response in generator:
- print(response)
-
data = response.choices[0].delta
if is_first:
data.role == "assistant"
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
new file mode 100644
index 000000000..345467858
--- /dev/null
+++ b/test/srt/test_srt_endpoint.py
@@ -0,0 +1,64 @@
+import json
+import unittest
+
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestSRTEndpoint(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ port = 30000
+
+ cls.model = MODEL_NAME_FOR_TEST
+ cls.base_url = f"http://localhost:{port}"
+ cls.process = popen_launch_server(cls.model, port, timeout=300)
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_child_process(cls.process.pid)
+
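+    # Helper that sends one request to the native /generate endpoint with the given options.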
+ def run_decode(
+ self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
+ ):
+ response = requests.post(
+ self.base_url + "/generate",
+ json={
+ "text": "The capital of France is",
+ "sampling_params": {
+ "temperature": 0 if n == 1 else 0.5,
+ "max_new_tokens": 32,
+ "n": n,
+ },
+ "stream": False,
+ "return_logprob": return_logprob,
+ "top_logprobs_num": top_logprobs_num,
+ "return_text_in_logprobs": return_text,
+ "logprob_start_len": 0,
+ },
+ )
+ print(json.dumps(response.json()))
+ print("=" * 100)
+
+ def test_simple_decode(self):
+ self.run_decode()
+
+ def test_parallel_sample(self):
+ self.run_decode(n=3)
+
+ def test_logprob(self):
+ for top_logprobs_num in [0, 3]:
+ for return_text in [True, False]:
+ self.run_decode(
+ return_logprob=True,
+ top_logprobs_num=top_logprobs_num,
+ return_text=return_text,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")