adapt to sglang v0.5.2rc1 on dcu

This commit is contained in:
maxiao
2025-09-04 15:56:33 +08:00
commit 909abb58f5
2320 changed files with 489411 additions and 0 deletions

41
test/README.md Normal file
View File

@@ -0,0 +1,41 @@
# Run Unit Tests
SGLang uses the built-in library [unittest](https://docs.python.org/3/library/unittest.html) as the testing framework.
## Test Backend Runtime
```bash
cd sglang/test/srt
# Run a single file
python3 test_srt_endpoint.py
# Run a single test
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
# Run a suite with multiple files
python3 run_suite.py --suite per-commit
```
## Test Frontend Language
```bash
cd sglang/test/lang
# Run a single file
python3 test_srt_backend.py
```
## Adding or Updating Tests in CI
- Create new test files under `test/srt` or `test/lang` depending on the type of test.
- Ensure they are referenced in the respective `run_suite.py` (e.g., `test/srt/run_suite.py`) so they're picked up in CI. For most small test cases, they can be added to the `per-commit` suite. Sort the test cases alphabetically.
- The CI will run the `per-commit` and `nightly` suites automatically. If you need special setup or custom test groups, you may modify the workflows in [`.github/workflows/`](https://github.com/sgl-project/sglang/tree/main/.github/workflows).
## Writing Elegant Test Cases
- Examine existing tests in [sglang/test](https://github.com/sgl-project/sglang/tree/main/test) for practical examples.
- Keep each test function focused on a single scenario or piece of functionality.
- Give tests descriptive names reflecting their purpose.
- Use robust assertions (e.g., assert, unittest methods) to validate outcomes.
- Clean up resources to avoid side effects and preserve test independence.
- Reduce the test time by using smaller models and reusing the server for multiple test cases.

BIN
test/lang/example_image.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

38
test/lang/run_suite.py Normal file
View File

@@ -0,0 +1,38 @@
import argparse
import glob
from sglang.test.test_utils import TestFile, run_unittest_files
# Mapping from suite name to the curated list of test files it contains.
suites = {
    "per-commit": [
        TestFile("test_srt_backend.py"),
        # Skip this due to some OPENAI_API_KEY issues
        # "test_openai_backend.py",
    ],
}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--timeout-per-file",
        type=int,
        default=1000,
        help="The time limit for running one file in seconds.",
    )
    parser.add_argument(
        "--suite",
        type=str,
        default=list(suites.keys())[0],
        choices=list(suites.keys()) + ["all"],
        help="The suite to run",
    )
    cli_args = parser.parse_args()

    # "all" discovers every test_*.py recursively; any other suite name uses
    # its explicit file list from the `suites` table above.
    if cli_args.suite == "all":
        target_files = glob.glob("**/test_*.py", recursive=True)
    else:
        target_files = suites[cli_args.suite]

    exit(run_unittest_files(target_files, cli_args.timeout_per_file))

View File

@@ -0,0 +1,25 @@
import json
import unittest
from sglang import Anthropic, set_default_backend
from sglang.test.test_programs import test_mt_bench, test_stream
from sglang.test.test_utils import CustomTestCase
class TestAnthropicBackend(CustomTestCase):
    """Smoke tests running the shared frontend programs on the Anthropic backend."""

    backend = None

    @classmethod
    def setUpClass(cls):
        # One shared backend for the whole class keeps the suite cheap.
        cls.backend = Anthropic("claude-3-haiku-20240307")
        set_default_backend(cls.backend)

    def test_mt_bench(self):
        test_mt_bench()

    def test_stream(self):
        test_stream()


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,51 @@
import unittest
import sglang as sgl
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase
class TestBind(CustomTestCase):
    """Tests for SglFunction.bind() and cache() against a local SRT runtime."""

    backend = None

    @classmethod
    def setUpClass(cls):
        cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST)
        sgl.set_default_backend(cls.backend)

    @classmethod
    def tearDownClass(cls):
        # Shut down the runtime so no server process outlives the suite.
        cls.backend.shutdown()

    def test_bind(self):
        @sgl.function
        def few_shot_qa(s, prompt, question):
            s += prompt
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        # Partially apply the prompt argument, then trace the bound program.
        bound_qa = few_shot_qa.bind(
            prompt="The following are questions with answers.\n\n"
        )
        tracer = bound_qa.trace()
        print(tracer.last_node.print_graph_dfs() + "\n")

    def test_cache(self):
        @sgl.function
        def few_shot_qa(s, prompt, question):
            s += prompt
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        bound_qa = few_shot_qa.bind(
            prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
        )
        bound_qa.cache(self.backend)


if __name__ == "__main__":
    unittest.main()

91
test/lang/test_choices.py Normal file
View File

@@ -0,0 +1,91 @@
import unittest
import numpy as np
from sglang.lang.choices import (
greedy_token_selection,
token_length_normalized,
unconditional_likelihood_normalized,
)
from sglang.test.test_utils import CustomTestCase
# Mocked output of a choices (select) call over three candidate strings.
# The inline comments next to each choice show its hypothetical tokenization;
# the logprob lists below are aligned with those tokens.
MOCK_CHOICES_INPUT_DATA = {
    "choices": [
        "organ",  # ["organ"]
        "organism",  # ["organ", "ism"]
        "antidisestablishmentarianism",  # ["ant", "id", "is", "est", "ablish", "ment", "arian", "ism"]
    ],
    "normalized_prompt_logprobs": [-0.1, -0.2, -0.05],
    "input_token_logprobs": [
        [[-0.1, 1, None]],
        [[-0.1, 1, None], [-0.3, 2, None]],
        [
            [-0.4, 3, None],
            [-0.25, 4, None],
            [-0.1, 5, None],
            [-0.01, 6, None],
            [-0.01, 7, None],
            [-0.01, 8, None],
            [-0.01, 9, None],
            [-0.01, 2, None],
        ],
    ],
    "output_token_logprobs": [
        [[-0.1, 10, None]],
        [[-0.1, 10, None]],
        [[-0.1, 10, None]],
    ],
    "unconditional_token_logprobs": [
        [[None, 1, None]],
        [[None, 1, None], [-1.4, 2, None]],
        [
            [None, 3, None],
            [-0.25, 4, None],
            [-0.1, 5, None],
            [-0.01, 6, None],
            [-0.01, 7, None],
            [-0.01, 8, None],
            [-0.01, 9, None],
            [-0.01, 2, None],
        ],
    ],
}


class TestChoices(CustomTestCase):
    """Unit tests for the three choice-selection strategies.

    Uses unittest assertion methods (not bare ``assert``) so the checks
    survive ``python -O`` and report useful failure messages.
    """

    def test_token_length_normalized(self):
        """Confirm 'antidisestablishmentarianism' is selected due to high confidences
        for its later tokens resulting in the highest token-length-normalized prompt
        logprob."""
        decision = token_length_normalized(**MOCK_CHOICES_INPUT_DATA)
        self.assertEqual(decision.decision, "antidisestablishmentarianism")

    def test_greedy_token_selection(self):
        """Confirm 'organ' is selected due to it having the joint highest initial
        token logprob, and a higher average logprob than organism's second token."""
        decision = greedy_token_selection(**MOCK_CHOICES_INPUT_DATA)
        self.assertEqual(decision.decision, "organ")
        self.assertTrue(
            np.allclose(
                decision.meta_info["greedy_logprob_matrix"],
                [
                    [-0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1],
                    [-0.1, -0.3, -0.2, -0.2, -0.2, -0.2, -0.2, -0.2],
                    [-0.4, -0.25, -0.1, -0.01, -0.01, -0.01, -0.01, -0.01],
                ],
                atol=0.01,
            )
        )

    def test_unconditional_likelihood_normalized(self):
        """Confirm 'organism' is selected due to it having the highest average token
        logprob once normalized by the unconditional token logprobs."""
        decision = unconditional_likelihood_normalized(**MOCK_CHOICES_INPUT_DATA)
        self.assertEqual(decision.decision, "organism")
        self.assertTrue(
            np.allclose(
                decision.meta_info["normalized_unconditional_prompt_logprobs"],
                [-0.1, 0.5, -0.05],
                atol=0.01,
            )
        )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,25 @@
import json
import unittest
from sglang import LiteLLM, set_default_backend
from sglang.test.test_programs import test_mt_bench, test_stream
from sglang.test.test_utils import CustomTestCase
class TestLiteLLMBackend(CustomTestCase):
    """Smoke tests running the shared frontend programs on the LiteLLM backend.

    NOTE: this class was previously misnamed ``TestAnthropicBackend`` (copied
    from test_anthropic_backend.py) even though it exercises LiteLLM.
    """

    chat_backend = None

    @classmethod
    def setUpClass(cls):
        cls.chat_backend = LiteLLM("gpt-3.5-turbo")
        set_default_backend(cls.chat_backend)

    def test_mt_bench(self):
        test_mt_bench()

    def test_stream(self):
        test_stream()


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,92 @@
import unittest
from sglang import OpenAI, set_default_backend
from sglang.test.test_programs import (
test_chat_completion_speculative,
test_completion_speculative,
test_decode_int,
test_decode_json,
test_expert_answer,
test_few_shot_qa,
test_image_qa,
test_mt_bench,
test_parallel_decoding,
test_parallel_encoding,
test_react,
test_select,
test_stream,
test_tool_use,
)
from sglang.test.test_utils import CustomTestCase
class TestOpenAIBackend(CustomTestCase):
    """Runs the shared frontend test programs against the OpenAI backends.

    Three backends are created once per class: an instruct model, a chat
    model, and a vision-capable chat model.  Each test first selects the
    backend it needs, then delegates to the shared test program.
    """

    instruct_backend = None
    chat_backend = None
    chat_vision_backend = None

    @classmethod
    def setUpClass(cls):
        cls.instruct_backend = OpenAI("gpt-3.5-turbo-instruct")
        cls.chat_backend = OpenAI("gpt-3.5-turbo")
        cls.chat_vision_backend = OpenAI("gpt-4-turbo")

    def _run(self, backend, program, *args, **kwargs):
        # Select the backend for this test, then run the shared program.
        set_default_backend(backend)
        program(*args, **kwargs)

    def test_few_shot_qa(self):
        self._run(self.instruct_backend, test_few_shot_qa)

    def test_mt_bench(self):
        self._run(self.chat_backend, test_mt_bench)

    def test_select(self):
        self._run(self.instruct_backend, test_select, check_answer=True)

    def test_decode_int(self):
        self._run(self.instruct_backend, test_decode_int)

    def test_decode_json(self):
        self._run(self.instruct_backend, test_decode_json)

    def test_expert_answer(self):
        self._run(self.instruct_backend, test_expert_answer)

    def test_tool_use(self):
        self._run(self.instruct_backend, test_tool_use)

    def test_react(self):
        self._run(self.instruct_backend, test_react)

    def test_parallel_decoding(self):
        self._run(self.instruct_backend, test_parallel_decoding)

    def test_parallel_encoding(self):
        self._run(self.instruct_backend, test_parallel_encoding)

    def test_image_qa(self):
        self._run(self.chat_vision_backend, test_image_qa)

    def test_stream(self):
        self._run(self.instruct_backend, test_stream)

    def test_completion_speculative(self):
        self._run(self.instruct_backend, test_completion_speculative)

    def test_chat_completion_speculative(self):
        self._run(self.chat_backend, test_chat_completion_speculative)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,68 @@
"""
Tests for the separate_reasoning functionality in sglang.
Usage:
python3 -m unittest test/lang/test_separate_reasoning.py
"""
import unittest
from sglang import assistant, gen, separate_reasoning, user
from sglang.lang.ir import SglExprList, SglSeparateReasoning
from sglang.test.test_utils import CustomTestCase
class TestSeparateReasoning(CustomTestCase):
    """IR-level tests for separate_reasoning (no backend execution involved)."""

    def test_separate_reasoning_creation(self):
        """SglSeparateReasoning objects are created and wired up correctly."""
        # A valid model type paired with a gen expression yields a two-element
        # expression list: the original gen, then the reasoning separator.
        gen_expr = gen("test")
        wrapped = separate_reasoning(gen_expr, model_type="deepseek-r1")
        self.assertIsInstance(wrapped, SglExprList)
        self.assertEqual(len(wrapped.expr_list), 2)
        self.assertEqual(wrapped.expr_list[0], gen_expr)
        separator = wrapped.expr_list[1]
        self.assertIsInstance(separator, SglSeparateReasoning)
        self.assertEqual(separator.model_type, "deepseek-r1")
        self.assertEqual(separator.name, "test_reasoning_content")
        # A second valid model type works the same way.
        wrapped = separate_reasoning(gen_expr, model_type="qwen3")
        self.assertIsInstance(wrapped, SglExprList)
        self.assertEqual(wrapped.expr_list[1].model_type, "qwen3")

    def test_separate_reasoning_name_processing(self):
        """separate_reasoning derives the reasoning variable name from the gen name."""
        gen_expr = gen("test_var")
        wrapped = separate_reasoning(gen_expr, model_type="deepseek-r1")
        separator = wrapped.expr_list[1]
        self.assertEqual(separator.name, "test_var_reasoning_content")
        # The name-derivation helper is also usable directly.
        self.assertEqual(
            separator.process_name_for_reasoning("another_var"),
            "another_var_reasoning_content",
        )

    def test_separate_reasoning_repr(self):
        """The string representation includes model type and derived name."""
        gen_expr = gen("test_var")
        wrapped = separate_reasoning(gen_expr, model_type="deepseek-r1")
        separator = wrapped.expr_list[1]
        self.assertEqual(
            repr(separator),
            "SeparateReasoning(model_type=deepseek-r1, name=test_var_reasoning_content)",
        )

    def test_separate_reasoning_with_invalid_model_type(self):
        """Any model type is accepted at creation time; validation happens later
        in the ReasoningParser constructor."""
        gen_expr = gen("test")
        wrapped = separate_reasoning(gen_expr, model_type="invalid-model")
        self.assertIsInstance(wrapped, SglExprList)
        self.assertEqual(wrapped.expr_list[1].model_type, "invalid-model")


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,195 @@
"""
Tests for the execution of separate_reasoning functionality in sglang.
Usage:
python3 -m unittest test/lang/test_separate_reasoning_execution.py
"""
import threading
import time
import unittest
from unittest.mock import MagicMock, patch
from sglang import assistant, gen, separate_reasoning, user
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglGen, SglSeparateReasoning
from sglang.test.test_utils import CustomTestCase
# Helper used by the tests below to create threading.Event objects; the test
# fixtures track and set these during tearDown so no waiter stays blocked.
def create_daemon_event():
    # NOTE(review): a plain Event never blocks interpreter exit by itself;
    # the "daemon" in the name refers to the tearDown-based cleanup.
    return threading.Event()
class MockReasoningParser:
    """Stand-in for ReasoningParser that tags its input instead of parsing it.

    Each parse method returns a (reasoning, normal_text) pair derived from the
    input, and records that it was invoked so tests can assert on the flags.
    """

    def __init__(self, model_type):
        self.model_type = model_type
        self.parse_non_stream_called = False
        self.parse_stream_chunk_called = False

    def _tag(self, text):
        # Both parse paths produce the same shape of output.
        return (
            f"[REASONING from {self.model_type}]: {text}",
            f"[NORMAL from {self.model_type}]: {text}",
        )

    def parse_non_stream(self, full_text):
        self.parse_non_stream_called = True
        return self._tag(full_text)

    def parse_stream_chunk(self, chunk_text):
        self.parse_stream_chunk_called = True
        return self._tag(chunk_text)
class TestSeparateReasoningExecution(CustomTestCase):
    """Execution-level tests for separate_reasoning using a mocked ReasoningParser."""

    def setUp(self):
        """Set up for the test."""
        super().setUp()
        # Track events created during the test so tearDown can release them.
        self.events = []

    def tearDown(self):
        """Wake up any threads waiting on events created during the test.

        NOTE: the original file defined tearDown twice; the first definition
        was silently shadowed dead code.  Both did the same thing and are
        merged into this single implementation.
        """
        super().tearDown()
        for event in self.events:
            event.set()

    @patch("sglang.srt.parser.reasoning_parser.ReasoningParser")
    def test_execute_separate_reasoning(self, mock_parser_class):
        """Test that _execute_separate_reasoning correctly calls the ReasoningParser."""
        # Setup mock parser
        mock_parser = MockReasoningParser("deepseek-r1")
        mock_parser_class.return_value = mock_parser
        # Create a mock backend to avoid AttributeError in __del__
        mock_backend = MagicMock()
        # Create a StreamExecutor with necessary setup
        executor = StreamExecutor(
            backend=mock_backend,
            arguments={},
            default_sampling_para={},
            chat_template={
                "role_map": {"user": "user", "assistant": "assistant"}
            },  # Simple chat template
            stream=False,
            use_thread=False,
        )
        # Set up the executor with a variable and its value
        var_name = "test_var"
        reasoning_name = f"{var_name}_reasoning_content"
        var_value = "Test content"
        executor.variables = {var_name: var_value}
        # Create events and track them for cleanup
        var_event = create_daemon_event()
        reasoning_event = create_daemon_event()
        self.events.extend([var_event, reasoning_event])
        executor.variable_event = {var_name: var_event, reasoning_name: reasoning_event}
        executor.variable_event[var_name].set()  # Mark as ready
        # Set up the current role
        executor.cur_role = "assistant"
        executor.cur_role_begin_pos = 0
        executor.text_ = var_value
        # Create a gen expression and a separate_reasoning expression
        gen_expr = SglGen(var_name)
        expr = SglSeparateReasoning("deepseek-r1", expr=gen_expr)
        # Execute separate_reasoning
        executor._execute_separate_reasoning(expr)
        # Verify that the parser was created with the correct model type
        mock_parser_class.assert_called_once_with("deepseek-r1")
        # Verify that parse_non_stream was called
        self.assertTrue(mock_parser.parse_non_stream_called)
        # Verify that the variables were updated correctly
        self.assertIn(reasoning_name, executor.variables)
        self.assertEqual(
            executor.variables[reasoning_name],
            f"[REASONING from deepseek-r1]: {var_value}",
        )
        self.assertEqual(
            executor.variables[var_name], f"[NORMAL from deepseek-r1]: {var_value}"
        )
        # Verify that the variable event was set
        self.assertIn(reasoning_name, executor.variable_event)
        self.assertTrue(executor.variable_event[reasoning_name].is_set())
        # Verify that the text was updated
        self.assertEqual(executor.text_, f"[NORMAL from deepseek-r1]: {var_value}")

    @patch("sglang.srt.parser.reasoning_parser.ReasoningParser")
    def test_reasoning_parser_integration(self, mock_parser_class):
        """Test the integration between separate_reasoning and ReasoningParser."""
        # Setup mock parsers for different model types
        deepseek_parser = MockReasoningParser("deepseek-r1")
        qwen_parser = MockReasoningParser("qwen3")

        # Configure the mock to return different parsers based on model type
        def get_parser(model_type):
            if model_type == "deepseek-r1":
                return deepseek_parser
            elif model_type == "qwen3":
                return qwen_parser
            else:
                raise ValueError(f"Unsupported model type: {model_type}")

        mock_parser_class.side_effect = get_parser
        # Test with DeepSeek-R1 model
        test_text = "This is a test"
        reasoning, normal_text = deepseek_parser.parse_non_stream(test_text)
        self.assertEqual(reasoning, f"[REASONING from deepseek-r1]: {test_text}")
        self.assertEqual(normal_text, f"[NORMAL from deepseek-r1]: {test_text}")
        # Test with Qwen3 model
        reasoning, normal_text = qwen_parser.parse_non_stream(test_text)
        self.assertEqual(reasoning, f"[REASONING from qwen3]: {test_text}")
        self.assertEqual(normal_text, f"[NORMAL from qwen3]: {test_text}")

    @patch("sglang.srt.parser.reasoning_parser.ReasoningParser")
    def test_reasoning_parser_invalid_model(self, mock_parser_class):
        """Test that ReasoningParser raises an error for invalid model types."""

        # Configure the mock to raise an error for invalid model types
        def get_parser(model_type):
            if model_type in ["deepseek-r1", "qwen3"]:
                return MockReasoningParser(model_type)
            elif model_type is None:
                raise ValueError("Model type must be specified")
            else:
                raise ValueError(f"Unsupported model type: {model_type}")

        mock_parser_class.side_effect = get_parser
        with self.assertRaises(ValueError) as context:
            mock_parser_class("invalid-model")
        self.assertIn("Unsupported model type", str(context.exception))
        with self.assertRaises(ValueError) as context:
            mock_parser_class(None)
        self.assertIn("Model type must be specified", str(context.exception))


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,86 @@
"""
Usage:
python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens
python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select
"""
import unittest
import sglang as sgl
from sglang.test.test_programs import (
test_decode_int,
test_decode_json_regex,
test_dtype_gen,
test_expert_answer,
test_few_shot_qa,
test_gen_min_new_tokens,
test_hellaswag_select,
test_mt_bench,
test_parallel_decoding,
test_regex,
test_select,
test_stream,
test_tool_use,
)
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase
class TestSRTBackend(CustomTestCase):
    """Runs the shared frontend test programs against a local SRT runtime."""

    backend = None

    @classmethod
    def setUpClass(cls):
        # A small cuda-graph batch size keeps server start-up light for CI.
        cls.backend = sgl.Runtime(
            model_path=DEFAULT_MODEL_NAME_FOR_TEST, cuda_graph_max_bs=4
        )
        sgl.set_default_backend(cls.backend)

    @classmethod
    def tearDownClass(cls):
        cls.backend.shutdown()

    def test_few_shot_qa(self):
        test_few_shot_qa()

    def test_mt_bench(self):
        test_mt_bench()

    def test_select(self):
        test_select(check_answer=False)

    def test_decode_int(self):
        test_decode_int()

    def test_decode_json_regex(self):
        test_decode_json_regex()

    def test_expert_answer(self):
        test_expert_answer()

    def test_tool_use(self):
        test_tool_use()

    def test_parallel_decoding(self):
        test_parallel_decoding()

    def test_stream(self):
        test_stream()

    def test_regex(self):
        test_regex()

    def test_dtype_gen(self):
        test_dtype_gen()

    def test_hellaswag_select(self):
        # Run twice to capture more bugs
        for _attempt in range(2):
            accuracy, _latency = test_hellaswag_select()
            self.assertGreater(accuracy, 0.60)

    def test_gen_min_new_tokens(self):
        test_gen_min_new_tokens()


if __name__ == "__main__":
    unittest.main()

129
test/lang/test_tracing.py Normal file
View File

@@ -0,0 +1,129 @@
import unittest
import sglang as sgl
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.test.test_utils import CustomTestCase
class TestTracing(CustomTestCase):
    """Tests for tracing/compiling sgl programs without executing a backend
    (except for the compiled-run cases, which use OpenAI)."""

    def test_few_shot_qa(self):
        @sgl.function
        def few_shot_qa(s, question):
            s += "The following are questions with answers.\n\n"
            s += "Q: What is the capital of France?\n"
            s += "A: Paris\n"
            s += "Q: " + question + "\n"
            s += "A:" + sgl.gen("answer", stop="\n")

        tracer = few_shot_qa.trace()
        # print(tracer.last_node.print_graph_dfs() + "\n")

    def test_select(self):
        @sgl.function
        def capital(s):
            s += "The capital of France is"
            s += sgl.select("capital", ["Paris. ", "London. "])
            s += "It is a city" + sgl.gen("description", stop=".")

        tracer = capital.trace()
        # print(tracer.last_node.print_graph_dfs() + "\n")

    def test_raise_warning(self):
        # Tracing a function with an unbound required argument must fail.
        @sgl.function
        def wrong(s, question):
            s += f"I want to ask {question}"

        try:
            tracer = wrong.trace()
            raised = False
        except TypeError:
            raised = True
        assert raised

    def test_multi_function(self):
        @sgl.function
        def expand(s, tip):
            s += (
                "Please expand the following tip into a detailed paragraph:"
                + tip
                + "\n"
            )
            s += sgl.gen("detailed_tip")

        @sgl.function
        def tip_suggestion(s, topic):
            s += "Here are 2 tips for " + topic + ".\n"
            s += "1." + sgl.gen("tip_1", stop=["\n", ":", "."]) + "\n"
            s += "2." + sgl.gen("tip_2", stop=["\n", ":", "."]) + "\n"
            # Expand each tip in its own branch, then merge the results.
            branch1 = expand(tip=s["tip_1"])
            branch2 = expand(tip=s["tip_2"])
            s += "Tip 1: " + branch1["detailed_tip"] + "\n"
            s += "Tip 2: " + branch2["detailed_tip"] + "\n"
            s += "In summary" + sgl.gen("summary")

        compiled = tip_suggestion.compile()
        # compiled.print_graph()

        sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
        state = compiled.run(topic="staying healthy")
        # print(state.text() + "\n")

        states = compiled.run_batch(
            [
                {"topic": "staying healthy"},
                {"topic": "staying happy"},
                {"topic": "earning money"},
            ],
            temperature=0,
        )
        # for s in states:
        #     print(s.text() + "\n")

    def test_role(self):
        @sgl.function
        def multi_turn_chat(s):
            s += sgl.user("Who are you?")
            s += sgl.assistant(sgl.gen("answer_1"))
            s += sgl.user("Who created you?")
            s += sgl.assistant(sgl.gen("answer_2"))

        backend = BaseBackend()
        backend.chat_template = get_chat_template("llama-2-chat")
        compiled = multi_turn_chat.compile(backend=backend)
        # compiled.print_graph()

    def test_fork(self):
        @sgl.function
        def tip_suggestion(s):
            s += (
                "Here are three tips for staying healthy: "
                "1. Balanced Diet; "
                "2. Regular Exercise; "
                "3. Adequate Sleep\n"
            )
            forks = s.fork(3)
            for i in range(3):
                forks[i] += f"Now, expand tip {i+1} into a paragraph:\n"
                forks[i] += sgl.gen(f"detailed_tip")
            s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
            s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
            s += "Tip 3:" + forks[2]["detailed_tip"] + "\n"
            s += "In summary" + sgl.gen("summary")

        tracer = tip_suggestion.trace()
        # print(tracer.last_node.print_graph_dfs())

        a = tip_suggestion.run(backend=sgl.OpenAI("gpt-3.5-turbo-instruct"))
        # print(a.text())


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,53 @@
import unittest
from sglang import VertexAI, set_default_backend
from sglang.test.test_programs import (
test_expert_answer,
test_few_shot_qa,
test_image_qa,
test_mt_bench,
test_parallel_decoding,
test_parallel_encoding,
test_stream,
)
from sglang.test.test_utils import CustomTestCase
class TestVertexAIBackend(CustomTestCase):
    """Runs the shared frontend test programs against the VertexAI backend."""

    backend = None

    @classmethod
    def setUpClass(cls):
        cls.backend = VertexAI("gemini-1.5-pro-001")

    def _activate(self):
        # The backend is (re)selected per test so the suite stays order-independent.
        set_default_backend(self.backend)

    def test_few_shot_qa(self):
        self._activate()
        test_few_shot_qa()

    def test_mt_bench(self):
        self._activate()
        test_mt_bench()

    def test_expert_answer(self):
        self._activate()
        test_expert_answer(check_answer=False)

    def test_parallel_decoding(self):
        self._activate()
        test_parallel_decoding()

    def test_parallel_encoding(self):
        self._activate()
        test_parallel_encoding()

    def test_image_qa(self):
        self._activate()
        test_image_qa()

    def test_stream(self):
        self._activate()
        test_stream()


if __name__ == "__main__":
    unittest.main()

2
test/pytest.ini Normal file
View File

@@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto

View File

@@ -0,0 +1,95 @@
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
# Expected quality/performance bounds per model under the Ascend backend.
TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 150,
        "output_throughput": 30,
    },
}


class TestAscendGraphTp1Bf16(CustomTestCase):
    """GSM8K accuracy and offline-throughput checks, TP=1, bf16, graph mode.

    Tests are named test_a_* / test_b_* so accuracy runs before throughput
    under unittest's default alphabetical method ordering.
    """

    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
        ]

    def test_a_gsm8k(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")
                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[*self.common_args],
                )
                try:
                    eval_args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )
                    metrics = run_eval_few_shot_gsm8k(eval_args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    # Always tear the server down, even if the eval failed.
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")
                output_throughput = run_bench_offline_throughput(
                    model, [*self.common_args]
                )
                print(f"##=== {model} throughput: {output_throughput} ===##")
                # Only enforce the bound in CI, where hardware is known.
                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,97 @@
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
# Expected quality/performance bounds per model under the Ascend backend.
TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 180,
        "output_throughput": 20,
    },
}


class TestAscendGraphTp2Bf16(CustomTestCase):
    """GSM8K accuracy and offline-throughput checks, TP=2, bf16, graph mode.

    Tests are named test_a_* / test_b_* so accuracy runs before throughput
    under unittest's default alphabetical method ordering.
    """

    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
            "--tp-size",
            2,
        ]

    def test_a_gsm8k(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")
                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[*self.common_args],
                )
                try:
                    eval_args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )
                    metrics = run_eval_few_shot_gsm8k(eval_args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    # Always tear the server down, even if the eval failed.
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")
                output_throughput = run_bench_offline_throughput(
                    model, [*self.common_args]
                )
                print(f"##=== {model} throughput: {output_throughput} ===##")
                # Only enforce the bound in CI, where hardware is known.
                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,103 @@
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
# Expected quality/performance bounds per model under the Ascend backend.
TEST_MODEL_MATRIX = {
    "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": {
        "accuracy": 0.34,
        "latency": 1000,
        "output_throughput": 6,
    },
}


class TestAscendMlaW8A8Int8(CustomTestCase):
    """GSM8K accuracy and offline-throughput checks for DeepSeek-V2-Lite
    (MLA, w8a8_int8 quantization, TP=2) on the Ascend backend."""

    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--disable-cuda-graph",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
            "--quantization",
            "w8a8_int8",
            "--tp-size",
            2,
            "--disable-radix-cache",
        ]

    def test_a_gsm8k(self):
        # Fix: the original set this env var and never restored it, leaking it
        # into every later test in the same process (e.g. test_b_throughput).
        os.environ["ASCEND_USE_FIA"] = "true"
        self.addCleanup(os.environ.pop, "ASCEND_USE_FIA", None)
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")
                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        *self.common_args,
                    ],
                )
                try:
                    args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )
                    metrics = run_eval_few_shot_gsm8k(args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    # Always tear the server down, even if the eval failed.
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")
                output_throughput = run_bench_offline_throughput(
                    model,
                    [
                        *self.common_args,
                    ],
                )
                print(f"##=== {model} throughput: {output_throughput} ===##")
                # Only enforce the bound in CI, where hardware is known.
                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,101 @@
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
# Expected quality/performance bounds per model under the Ascend backend.
TEST_MODEL_MATRIX = {
    "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": {
        "accuracy": 0.34,
        "latency": 1000,
        "output_throughput": 6,
    },
}


class TestAscendMlaW8A8Int8(CustomTestCase):
    """GSM8K accuracy and offline-throughput checks for DeepSeek-V2-Lite
    (MLA, w8a8_int8 quantization, TP=4) on the Ascend backend."""

    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--disable-cuda-graph",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
            "--quantization",
            "w8a8_int8",
            "--tp-size",
            4,
            "--disable-radix-cache",
        ]

    def test_a_gsm8k(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")
                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[*self.common_args],
                )
                try:
                    eval_args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )
                    metrics = run_eval_few_shot_gsm8k(eval_args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    # Always tear the server down, even if the eval failed.
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")
                output_throughput = run_bench_offline_throughput(
                    model, [*self.common_args]
                )
                print(f"##=== {model} throughput: {output_throughput} ===##")
                # Only enforce the bound in CI, where hardware is known.
                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,96 @@
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
# Per-model expected floors: "accuracy" is the minimum few-shot GSM8K accuracy
# asserted below; "output_throughput" (tokens/s) is asserted only in CI.
# "latency" is not read by any test in this file.
TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.84,
        "latency": 150,
        "output_throughput": 30,
    },
}
class TestAscendTp1Bf16(CustomTestCase):
    """Single-device (TP=1) bf16 accuracy and throughput checks on Ascend."""

    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--disable-cuda-graph",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
        ]

    def _gsm8k_args(self):
        """Few-shot GSM8K evaluation arguments targeting the launched server."""
        return SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=1319,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{self.url.hostname}",
            port=int(self.url.port),
        )

    def test_a_gsm8k(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")
                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=list(self.common_args),
                )
                try:
                    metrics = run_eval_few_shot_gsm8k(self._gsm8k_args())
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")
                tput = run_bench_offline_throughput(model, list(self.common_args))
                print(f"##=== {model} throughput: {tput} ===##")
                if is_in_ci():
                    self.assertGreater(
                        tput, TEST_MODEL_MATRIX[model]["output_throughput"]
                    )
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,98 @@
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
# Per-model expected floors: "accuracy" is the minimum few-shot GSM8K accuracy
# asserted below; "output_throughput" (tokens/s) is asserted only in CI.
# "latency" is not read by any test in this file.
TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 180,
        "output_throughput": 20,
    },
}
class TestAscendTp2Bf16(CustomTestCase):
    """Two-way tensor-parallel bf16 accuracy and throughput checks on Ascend."""

    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--disable-cuda-graph",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
            "--tp-size",
            2,
        ]

    def _run_gsm8k(self, model):
        """Launch a server for *model*, evaluate few-shot GSM8K, enforce floor."""
        print(f"##=== Testing accuracy: {model} ===##")
        server = popen_launch_server(
            model,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=list(self.common_args),
        )
        try:
            metrics = run_eval_few_shot_gsm8k(
                SimpleNamespace(
                    num_shots=5,
                    data_path=None,
                    num_questions=1319,
                    max_new_tokens=512,
                    parallel=128,
                    host=f"http://{self.url.hostname}",
                    port=int(self.url.port),
                )
            )
            self.assertGreaterEqual(
                metrics["accuracy"], TEST_MODEL_MATRIX[model]["accuracy"]
            )
        finally:
            kill_process_tree(server.pid)

    def test_a_gsm8k(self):
        for model in self.models:
            with self.subTest(model=model):
                self._run_gsm8k(model)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")
                tput = run_bench_offline_throughput(model, list(self.common_args))
                print(f"##=== {model} throughput: {tput} ===##")
                if is_in_ci():
                    self.assertGreater(
                        tput, TEST_MODEL_MATRIX[model]["output_throughput"]
                    )
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,101 @@
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_offline_throughput,
)
# Per-model expected floors: "accuracy" is the minimum few-shot GSM8K accuracy
# asserted below; "output_throughput" (tokens/s) is asserted only in CI.
# "latency" is not read by any test in this file.
TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 180,
        "output_throughput": 20,
    },
}
class TestAscendTp2Bf16(CustomTestCase):
    """TP=2 bf16 accuracy/throughput tests on Ascend with the FIA attention path.

    ASCEND_USE_FIA is enabled for the whole class and the previous value is
    restored afterwards, so other test modules running in the same process are
    not affected and both tests see the same configuration regardless of
    execution order.
    """

    @classmethod
    def setUpClass(cls):
        # Previously ASCEND_USE_FIA was set inside test_a_gsm8k only, with no
        # cleanup: the variable leaked to every later test in the process and
        # test_b_throughput's configuration depended on alphabetical ordering.
        # Enable it here for the whole class and restore on class teardown.
        prev_fia = os.environ.get("ASCEND_USE_FIA")

        def _restore_fia():
            if prev_fia is None:
                os.environ.pop("ASCEND_USE_FIA", None)
            else:
                os.environ["ASCEND_USE_FIA"] = prev_fia

        cls.addClassCleanup(_restore_fia)
        os.environ["ASCEND_USE_FIA"] = "true"

        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--disable-cuda-graph",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
            "--tp-size",
            2,
            "--disable-radix-cache",
        ]

    def test_a_gsm8k(self):
        """Accuracy floor check: few-shot GSM8K against a freshly launched server."""
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")
                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        *self.common_args,
                    ],
                )
                try:
                    args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )
                    metrics = run_eval_few_shot_gsm8k(args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        """Offline throughput floor check (asserted only in CI)."""
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")
                output_throughput = run_bench_offline_throughput(
                    model,
                    [
                        *self.common_args,
                    ],
                )
                print(f"##=== {model} throughput: {output_throughput} ===##")
                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,104 @@
"""
Usage:
python3 -m unittest test_ascend_w8a8_quantization.TestAscendW8A8.test_gsm8k
"""
import os
import time
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
# Fall back to the first two NPUs when the caller has not pinned devices.
if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ:
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
# Derive a per-device-group base port from the first visible device id so
# concurrent runs on different device groups do not collide.
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
    7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100
)
# NOTE(review): intentionally shadows the DEFAULT_URL_FOR_TEST imported above
# so this module gets its own port — confirm this is the intended behavior.
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
class TestAscendW8A8(CustomTestCase):
    """Smoke tests for the w8a8_int8 quantized Qwen2.5-0.5B model on Ascend NPU."""

    @classmethod
    def setUpClass(cls):
        cls.model = "vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--disable-cuda-graph",
                "--device",
                "npu",
                "--attention-backend",
                "ascend",
                "--quantization",
                "w8a8_int8",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K accuracy and eval throughput floors."""
        parsed = urlparse(DEFAULT_URL_FOR_TEST)
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{parsed.hostname}",
            port=int(parsed.port),
        )
        metrics = run_eval(eval_args)
        print(metrics)
        self.assertGreaterEqual(metrics["accuracy"], 0.25)
        self.assertGreaterEqual(metrics["output_throughput"], 1000)

    def run_decode(self, max_new_tokens):
        """POST a single greedy /generate request and return the decoded JSON."""
        # NOTE(review): "ignore_eos" sits outside "sampling_params" here;
        # confirm the /generate endpoint honors it at the top level.
        payload = {
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": max_new_tokens,
            },
            "ignore_eos": True,
        }
        response = requests.post(self.base_url + "/generate", json=payload)
        return response.json()

    def test_throughput(self):
        """Wall-clock decode throughput floor for a single 256-token request."""
        max_tokens = 256
        start = time.perf_counter()
        res = self.run_decode(max_tokens)
        elapsed = time.perf_counter() - start
        print(res["text"])
        throughput = max_tokens / elapsed
        print(f"Throughput: {throughput} tokens/s")
        if is_in_ci():
            self.assertGreaterEqual(throughput, 25)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,28 @@
tasks:
- name: sglang-8192-1024-concurrency1
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 1 --num-prompts 5 --output-file deepseek_v3_results.jsonl
- name: sglang-8192-1024-concurrency2
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 2 --num-prompts 10 --output-file deepseek_v3_results.jsonl
- name: sglang-8192-1024-concurrency4
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 4 --num-prompts 20 --output-file deepseek_v3_results.jsonl
- name: sglang-8192-1024-concurrency8
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 8 --num-prompts 32 --output-file deepseek_v3_results.jsonl
- name: sglang-8192-1024-concurrency16
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 16 --num-prompts 48 --output-file deepseek_v3_results.jsonl
- name: sglang-8192-1024-concurrency24
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 24 --num-prompts 72 --output-file deepseek_v3_results.jsonl
- name: sglang-8192-1024-concurrency32
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 32 --num-prompts 96 --output-file deepseek_v3_results.jsonl

View File

@@ -0,0 +1,28 @@
tasks:
- name: sglang-32000-100-concurrency1
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 1 --num-prompts 5 --output-file deepseek_v3_long_context_results.jsonl
- name: sglang-32000-100-concurrency2
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 2 --num-prompts 10 --output-file deepseek_v3_long_context_results.jsonl
- name: sglang-32000-100-concurrency4
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 4 --num-prompts 20 --output-file deepseek_v3_long_context_results.jsonl
- name: sglang-32000-100-concurrency8
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 8 --num-prompts 32 --output-file deepseek_v3_long_context_results.jsonl
- name: sglang-32000-100-concurrency16
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 16 --num-prompts 48 --output-file deepseek_v3_long_context_results.jsonl
- name: sglang-32000-100-concurrency24
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 24 --num-prompts 72 --output-file deepseek_v3_long_context_results.jsonl
- name: sglang-32000-100-concurrency32
  server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 32 --num-prompts 96 --output-file deepseek_v3_long_context_results.jsonl

View File

@@ -0,0 +1,28 @@
tasks:
- name: sglang-8192-1024-concurrency1
  server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 1 --num-prompts 5 --output-file llama_405b_results.jsonl
- name: sglang-8192-1024-concurrency2
  server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 2 --num-prompts 10 --output-file llama_405b_results.jsonl
- name: sglang-8192-1024-concurrency4
  server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 4 --num-prompts 20 --output-file llama_405b_results.jsonl
- name: sglang-8192-1024-concurrency8
  server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 8 --num-prompts 32 --output-file llama_405b_results.jsonl
- name: sglang-8192-1024-concurrency16
  server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 16 --num-prompts 48 --output-file llama_405b_results.jsonl
- name: sglang-8192-1024-concurrency24
  server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 24 --num-prompts 72 --output-file llama_405b_results.jsonl
- name: sglang-8192-1024-concurrency32
  server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
  client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 32 --num-prompts 96 --output-file llama_405b_results.jsonl

View File

@@ -0,0 +1,25 @@
tasks:
- name: sglang-128-4
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440
- name: vllm-128-4
  server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
  client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440
- name: sglang-2000-100
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120
- name: vllm-2000-100
  server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
  client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120
- name: sglang-4000-200
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480
- name: vllm-4000-200
  server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
  client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480
- name: sglang-32000-100
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60
- name: vllm-32000-100
  server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
  client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60

View File

@@ -0,0 +1,25 @@
tasks:
- name: sglang-128-4
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440
- name: sglang-triton-128-4
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440
- name: sglang-2000-100
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120
- name: sglang-triton-2000-100
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120
- name: sglang-4000-200
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480
- name: sglang-triton-4000-200
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480
- name: sglang-32000-100
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60
- name: sglang-triton-32000-100
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60

View File

@@ -0,0 +1,7 @@
tasks:
- name: sglang-benchmark
  server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
  client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 16
- name: vllm-benchmark
  server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
  client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --request-rate 16

View File

@@ -0,0 +1,35 @@
import itertools
import unittest
import sgl_kernel
import torch
import torch.nn.functional as F
from utils import SiluAndMul, precision
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestActivation(CustomTestCase):
    """Compare the fused CPU silu_and_mul kernel against the reference impl."""

    M = [128, 129, 257]
    N = [22016, 22018]
    dtype = [torch.float16, torch.bfloat16]

    def _activation_test(self, m, n, dtype):
        """Run one (m, n, dtype) case and compare kernel vs reference output."""
        inp = torch.randn([m, n], dtype=dtype)
        got = torch.ops.sgl_kernel.silu_and_mul_cpu(inp)
        want = SiluAndMul(inp)
        tol = precision[want.dtype]
        torch.testing.assert_close(want, got, atol=tol, rtol=tol)

    def test_activation(self):
        for m, n, dt in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=m, n=n, dtype=dt):
                self._activation_test(m, n, dt)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,28 @@
import re
import unittest
import sgl_kernel
import torch
kernel = torch.ops.sgl_kernel
from sglang.test.test_utils import CustomTestCase
class TestGemm(CustomTestCase):
    """Check that init_cpu_threads_env pins one OMP thread per requested core."""

    def test_binding(self):
        start_id = 1
        n_cpu = 6
        expected_cores = [str(c) for c in range(start_id, start_id + n_cpu)]
        output = kernel.init_cpu_threads_env(",".join(expected_cores))
        # The helper reports its bindings as "OMP tid: <t>, core <c>" lines.
        bindings = re.findall(r"OMP tid: \d+, core (\d+)", output)
        self.assertEqual(len(bindings), n_cpu)
        self.assertEqual(bindings, expected_cores)
if __name__ == "__main__":
unittest.main()

170
test/srt/cpu/test_decode.py Normal file
View File

@@ -0,0 +1,170 @@
import unittest
import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestDecodeAttention(CustomTestCase):
    """Correctness tests for the fused CPU decode-attention kernel.

    The kernel output is compared against a pure-PyTorch reference built on
    torch.nn.functional.scaled_dot_product_attention.
    """

    def _run_sdpa_forward_decode(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Reference decode attention: one new query token per request.

        Writes per-request attention results into ``output`` in place and also
        returns it.  ``req_to_token[pool_idx, pos]`` maps a request's position
        to its KV-cache slot.
        """
        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)
        start_q, start_kv = 0, 0
        for seq_idx in range(seq_lens.shape[0]):
            # Decode step: exactly one query token per sequence.
            seq_len_q = 1
            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv
            per_req_query = query[:, start_q:end_q, :]
            # get key and value from cache. per_req_tokens contains the kv cache
            # index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_out = (
                scaled_dot_product_attention(
                    per_req_query.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = per_req_out
            start_q, start_kv = end_q, end_kv
        return output

    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
        """Compare decode_attention_cpu against the SDPA reference for one config.

        B: batch size; H_Q/H_KV: query/KV head counts (H_Q != H_KV exercises
        GQA); D/D_V: head dims for Q-K and V respectively.
        """
        dtype = torch.bfloat16
        # This represents the number of tokens already in the sequence
        seq_len = 1024
        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0
        # Sizes the attn_logits scratch buffer handed to the kernel.
        num_kv_splits = 8
        enable_gqa = H_Q != H_KV
        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype, device=device)
        # k_buffer and v_buffer represent all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
        v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device)
        key = torch.randn(B, H_KV, D, dtype=dtype)
        value = torch.randn(B, H_KV, D_V, dtype=dtype)
        # Cache slots where the new key/value are stored (see assignment below).
        loc = torch.randint(0, 10, (B,)).to(torch.int64)
        # set kv cache
        k_buffer[loc] = key
        v_buffer[loc] = value
        # o will have the same shape as q
        o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
        # Sequences laid out back to back: request i owns rows [i*seq_len, (i+1)*seq_len).
        req_to_token = (
            torch.arange(total_tokens, device=device)
            .reshape(B, seq_len)
            .to(torch.int32)
        )
        b_req_idx = torch.arange(B, device=device).to(torch.int64)
        b_seq_len = torch.full((B,), seq_len, device=device).to(torch.int64)
        # Scratch buffer consumed by the kernel — presumably per-split partial
        # results plus one accumulator column; TODO confirm kernel contract.
        attn_logits = torch.empty(
            (B, H_Q, num_kv_splits, D_V + 1),
            dtype=torch.float32,
            device=device,
        )
        # k_buffer, v_buffer, query, key and value supports non-contiguous tensors
        k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        q = q.transpose(0, 1).contiguous().transpose(0, 1)
        key = key.transpose(0, 1).contiguous().transpose(0, 1)
        value = value.transpose(0, 1).contiguous().transpose(0, 1)
        torch.ops.sgl_kernel.decode_attention_cpu(
            q,
            k_buffer,
            v_buffer,
            o,
            key,
            value,
            loc,
            attn_logits,
            req_to_token,
            b_req_idx,
            b_seq_len,
            sm_scale,
            logit_cap,
        )
        self._run_sdpa_forward_decode(
            q,
            o_grouped,
            k_buffer,
            v_buffer,
            req_to_token,
            b_req_idx,
            b_seq_len,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
        )
        # Cosine similarity first (scale-insensitive), then elementwise bounds.
        cos_sim = torch.nn.functional.cosine_similarity(
            o.flatten(), o_grouped.flatten(), dim=0
        )
        self.assertGreater(cos_sim.item(), 0.99)
        torch.testing.assert_close(o, o_grouped, atol=3e-2, rtol=1e-6)

    def _test_grouped_decode_attention(self, device="cuda"):
        """Sweep (B, H_Q, H_KV, D, D_V) configs.

        NOTE(review): the default device stays "cuda", presumably for parity
        with GPU variants of this test — the public test below passes "cpu".
        """
        configs = [
            (2, 16, 16, 64, 64),
            (2, 16, 1, 16, 16),
            (2, 32, 8, 33, 55),
            (2, 16, 1, 64, 64),
            (2, 64, 1, 13, 13),
            (2, 128, 1, 80, 80),
            (2, 128, 2, 512, 512),
            (1, 16, 1, 576, 512),
            (1, 16, 16, 576, 512),
            (1, 22, 1, 576, 512),
            (1, 40, 8, 128, 128),
        ]
        for B, H_Q, H_KV, D, D_V in configs:
            self._test_grouped_decode_attention_once(
                B, H_Q, H_KV, D, D_V, device=device
            )

    def test_grouped_decode_attention(self):
        self._test_grouped_decode_attention("cpu")
if __name__ == "__main__":
unittest.main()

190
test/srt/cpu/test_extend.py Normal file
View File

@@ -0,0 +1,190 @@
import unittest
import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestExtendAttention(CustomTestCase):
def _run_sdpa_forward_extend(
self,
query: torch.Tensor,
output: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
extend_prefix_lens: torch.Tensor,
extend_seq_lens: torch.Tensor,
scaling=None,
enable_gqa=False,
causal=False,
):
assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
assert seq_lens.shape[0] == extend_seq_lens.shape[0]
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
query = query.movedim(0, query.dim() - 2)
start_q, start_kv = 0, 0
for seq_idx in range(seq_lens.shape[0]):
extend_seq_len_q = extend_seq_lens[seq_idx]
prefill_seq_len_q = extend_prefix_lens[seq_idx]
seq_len_kv = seq_lens[seq_idx]
end_q = start_q + extend_seq_len_q
end_kv = start_kv + seq_len_kv
per_req_query = query[:, start_q:end_q, :]
per_req_query_redudant = torch.empty(
(per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
dtype=per_req_query.dtype,
device=per_req_query.device,
)
per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query
# get key and value from cache. per_req_tokens contains the kv cache
# index for each token in the sequence.
req_pool_idx = req_pool_indices[seq_idx]
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_out_redudant = (
scaled_dot_product_attention(
per_req_query_redudant.unsqueeze(0),
per_req_key.unsqueeze(0),
per_req_value.unsqueeze(0),
enable_gqa=enable_gqa,
scale=scaling,
is_causal=causal,
)
.squeeze(0)
.movedim(query.dim() - 2, 0)
)
output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
start_q, start_kv = end_q, end_kv
return output
def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D, DV, mla=False):
dtype = torch.bfloat16
b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
if mla:
b_seq_len_prefix.zero_()
b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
b_seq_len = b_seq_len_prefix + b_seq_len_extend
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
b_req_idx = torch.arange(B, dtype=torch.int32)
req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32)
b_start_loc = torch.zeros((B,), dtype=torch.int32)
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
for i in range(B):
req_to_tokens[i, : b_seq_len[i]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len[i]
)
total_token_num = torch.sum(b_seq_len).item()
extend_token_num = torch.sum(b_seq_len_extend).item()
H_BUF = 1 if mla else H_KV
k_buffer = torch.randn((total_token_num, H_BUF, D), dtype=dtype)
v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype)
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype)
v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype)
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype)
for i in range(B):
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
extend_start = b_start_loc_extend[i]
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
k_extend[extend_start:extend_end] = k_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
v_extend[extend_start:extend_end] = v_buffer[
extend_start_in_buffer:extend_end_in_buffer
]
q_extend[extend_start:extend_end] = torch.randn(
(b_seq_len_extend[i], H_Q, D), dtype=dtype
)
# q_extend, k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1)
k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
b_seq_len_extend = b_seq_len - b_seq_len_prefix
b_start_loc_extend = torch.zeros_like(b_seq_len)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
sm_scale = 1.0 / (D**0.5)
logit_cap = 0.0
# handle index type
b_req_idx = b_req_idx.to(torch.int64)
b_seq_len = b_seq_len.to(torch.int64)
enable_gqa = H_Q != H_KV
o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
self._run_sdpa_forward_extend(
q_extend,
o_ref,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_seq_len,
b_seq_len_prefix,
b_seq_len_extend,
scaling=sm_scale,
enable_gqa=enable_gqa,
causal=True,
)
o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
torch.ops.sgl_kernel.extend_attention_cpu(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_seq_len,
b_seq_len_extend,
b_start_loc_extend,
max_len_extend,
sm_scale,
logit_cap,
)
torch.testing.assert_close(o_ref, o_extend, atol=1e-2, rtol=1e-2)
def test_extend_attention(self):
    """Run the extend-attention correctness check over a few shapes,
    for both the MLA layout and the standard layout."""
    # NOTE(review): positional args presumably map to
    # (B, seq_len, H_Q, H_KV, D, DV, is_mla) — confirm against
    # _test_extend_attention_once's signature.
    for is_mla in [True, False]:
        self._test_extend_attention_once(1, 123, 1, 1, 128, 96, is_mla)
        self._test_extend_attention_once(1, 123, 16, 1, 128, 96, is_mla)
        self._test_extend_attention_once(4, 1230, 16, 4, 128, 96, is_mla)
if __name__ == "__main__":
unittest.main()

189
test/srt/cpu/test_gemm.py Normal file
View File

@@ -0,0 +1,189 @@
import itertools
import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch
import torch.nn as nn
from utils import (
convert_weight,
native_w8a8_per_token_matmul,
per_token_quant_int8,
precision,
)
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class Mod(nn.Module):
    """Minimal module wrapping a single dense layer; used as the weight
    source for the fp8 GEMM test."""

    def __init__(self, input_channel, output_channel, has_bias):
        super().__init__()
        # A plain linear projection; ``has_bias`` toggles the bias term.
        self.linear = nn.Linear(input_channel, output_channel, bias=has_bias)

    def forward(self, x):
        out = self.linear(x)
        return out
class TestGemm(CustomTestCase):
    """Correctness tests for the CPU GEMM kernels in ``sgl_kernel``:
    bf16, W8A8 per-token int8, and block-quantized fp8 variants, each
    checked against a float reference built with ``torch.matmul``."""

    # Shape grid for the bf16 tests.
    M = [1, 101]
    N = [16, 32 * 13]
    K = [32 * 16]
    has_bias = [False, True]

    # Shape grid for the int8 tests.
    M_int8 = [2, 128]
    N_int8 = [32 * 12]
    K_int8 = [32 * 17]

    # Shape grid for the fp8 tests.
    M_fp8 = [1, 11]
    N_fp8 = [128, 224]
    K_fp8 = [512, 576]

    def _bf16_gemm(self, M, N, K, has_bias):
        """Check ``weight_packed_linear`` against an fp32 matmul, both with
        the plain weight and with a pre-packed weight."""
        mat1 = torch.randn(M, K, dtype=torch.bfloat16)
        mat2 = torch.randn(N, K, dtype=torch.bfloat16)
        # Reference computed in fp32, then rounded back to bf16.
        ref = torch.matmul(mat1.float(), mat2.float().t())
        if has_bias:
            bias = torch.randn(N, dtype=torch.float32)
            ref.add_(bias.bfloat16())
        ref = ref.bfloat16()
        # Last argument flags whether the weight is already packed.
        out = torch.ops.sgl_kernel.weight_packed_linear(
            mat1, mat2, bias if has_bias else None, False
        )
        packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
        out2 = torch.ops.sgl_kernel.weight_packed_linear(
            mat1, packed_mat2, bias if has_bias else None, True
        )
        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, out, atol=atol, rtol=rtol)
        torch.testing.assert_close(ref, out2, atol=atol, rtol=rtol)

    def test_bf16_gemm(self):
        """Sweep the bf16 shape grid."""
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.has_bias,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                has_bias=params[3],
            ):
                self._bf16_gemm(*params)

    def _int8_gemm(self, M, N, K, has_bias):
        """Check the W8A8 int8 GEMM (separate and fused-quant entry points)
        against the native per-token reference."""
        dtype = torch.bfloat16
        A = torch.randn((M, K), dtype=dtype) / 10
        Aq, As = per_token_quant_int8(A)

        factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Random int8 weight with per-row (per-output-channel) scales.
        B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2
        Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
        Bs = torch.rand(N) * factor_for_scale

        bias = torch.randn(N) if has_bias else None
        ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype)
        atol = rtol = precision[ref_out.dtype]

        # Kernel path: quantize activations with the CPU kernel, then GEMM.
        Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
        out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
            Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
        )
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

        # test the fused version
        fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
            A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
        )
        torch.testing.assert_close(ref_out, fused_out, atol=atol, rtol=rtol)

    def test_int8_gemm(self):
        """Sweep the int8 shape grid."""
        for params in itertools.product(
            self.M_int8,
            self.N_int8,
            self.K_int8,
            self.has_bias,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                has_bias=params[3],
            ):
                self._int8_gemm(*params)

    def _fp8_gemm(self, M, N, K, has_bias):
        """Check the block-quantized fp8 GEMM against a matmul with the
        dequantized weight."""
        prepack = True
        chunk = False
        scale_block_size_N = 64
        scale_block_size_K = 128
        assert scale_block_size_N <= N
        assert scale_block_size_K <= K
        A_dtype = torch.bfloat16
        model = Mod(K, N, has_bias).eval()
        if chunk:
            # Non-contiguous activation: allocate wider, then narrow to K.
            data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K)
        else:
            data = torch.randn(M, K, dtype=A_dtype)
        weight = model.linear.weight  # (N, K)
        if has_bias:
            bias = model.linear.bias
        # convert_weight returns (fp8 weight, block scales, dequantized copy).
        fp8_weight, scales, dq_weight = convert_weight(
            weight, [scale_block_size_N, scale_block_size_K], A_dtype
        )
        if has_bias:
            ref = torch.matmul(data.to(A_dtype), dq_weight.T) + bias.to(A_dtype)
        else:
            ref = torch.matmul(data.to(A_dtype), dq_weight.T)
        if prepack:
            fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)
        opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
            data,
            fp8_weight,
            scales,
            [scale_block_size_N, scale_block_size_K],
            bias if has_bias else None,
            data.dtype,
            prepack,
        )
        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, opt, atol=atol, rtol=rtol)

    def test_fp8_gemm(self):
        """Sweep the fp8 shape grid."""
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.has_bias,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                has_bias=params[3],
            ):
                self._fp8_gemm(*params)
if __name__ == "__main__":
unittest.main()

157
test/srt/cpu/test_mla.py Normal file
View File

@@ -0,0 +1,157 @@
import itertools
import unittest
import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention
from utils import precision
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestMLA(CustomTestCase):
    """Tests the CPU MLA decode-attention kernel against a per-request
    SDPA reference, including the fused KV-cache update."""

    def _run_sdpa_forward_decode(
        self,
        query: torch.Tensor,
        output: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        key: torch.Tensor,
        loc: torch.Tensor,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        scaling=None,
        enable_gqa=False,
        causal=False,
    ):
        """Reference decode attention: writes ``key`` into the cache at
        ``loc``, then runs SDPA one request at a time (one query token per
        request) and fills ``output`` in place.  Returns ``output``."""
        # set kv cache
        k_cache[loc] = key

        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
        query = query.movedim(0, query.dim() - 2)

        start_q, start_kv = 0, 0
        for seq_idx in range(seq_lens.shape[0]):
            # Decode: exactly one new query token per request.
            seq_len_q = 1
            seq_len_kv = seq_lens[seq_idx]
            end_q = start_q + seq_len_q
            end_kv = start_kv + seq_len_kv

            per_req_query = query[:, start_q:end_q, :]

            # get key and value from cache. per_req_tokens contains the kv cache
            # index for each token in the sequence.
            req_pool_idx = req_pool_indices[seq_idx]
            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)

            per_req_out = (
                scaled_dot_product_attention(
                    per_req_query.unsqueeze(0),
                    per_req_key.unsqueeze(0),
                    per_req_value.unsqueeze(0),
                    enable_gqa=enable_gqa,
                    scale=scaling,
                    is_causal=causal,
                )
                .squeeze(0)
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = per_req_out
            start_q, start_kv = end_q, end_kv

        return output

    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len):
        """Compare decode_attention_cpu with the SDPA reference for one
        (batch, heads, head-dim) configuration.  In the MLA layout the value
        cache is a narrowed view of the key cache (shared storage)."""
        dtype = torch.bfloat16
        total_tokens = B * seq_len
        sm_scale = 1.0 / (D**0.5)
        logit_cap = 0.0
        num_kv_splits = 8
        enable_gqa = H_Q != H_KV

        # q represents the new token being generated, one per batch
        q = torch.randn(B, H_Q, D, dtype=dtype)

        # k_buffer and v_buffer represent all previous tokens
        k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype)
        # v shares storage with k: first D_V channels of the latent cache.
        v_buffer = k_buffer.narrow(2, 0, D_V)
        key = torch.randn(B, H_KV, D, dtype=dtype)
        value = key.narrow(2, 0, D_V)
        # make sure no duplicates in loc
        loc = torch.randperm(total_tokens)[:B].to(torch.int64)

        # Second copy for the kernel, so both paths start from the same cache.
        k_buffer2 = k_buffer.clone()
        v_buffer2 = k_buffer2.narrow(2, 0, D_V)

        # o will have the same shape as q
        o = torch.zeros(B, H_Q, D_V, dtype=dtype)
        o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype)

        req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32)
        b_req_idx = torch.arange(B).to(torch.int64)
        b_seq_len = torch.full((B,), seq_len).to(torch.int64)

        attn_logits = torch.empty(
            (B, H_Q, num_kv_splits, D_V + 1),
            dtype=torch.float32,
        )

        # Kernel under test: fuses the cache update (key/value at loc) with
        # the decode attention.
        torch.ops.sgl_kernel.decode_attention_cpu(
            q,
            k_buffer2,
            v_buffer2,
            o,
            key,
            value,
            loc,
            attn_logits,
            req_to_token,
            b_req_idx,
            b_seq_len,
            sm_scale,
            logit_cap,
        )

        self._run_sdpa_forward_decode(
            q,
            o_grouped,
            k_buffer,
            v_buffer,
            key,
            loc,
            req_to_token,
            b_req_idx,
            b_seq_len,
            scaling=sm_scale,
            enable_gqa=enable_gqa,
        )

        cos_sim = torch.nn.functional.cosine_similarity(
            o.flatten(), o_grouped.flatten(), dim=0
        )
        atol = rtol = precision[q.dtype]
        self.assertGreater(cos_sim.item(), 0.99)
        torch.testing.assert_close(o, o_grouped, atol=atol, rtol=rtol)
        # The kernel must have applied the same cache update as the reference.
        torch.testing.assert_close(k_buffer, k_buffer2, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_buffer, v_buffer2, atol=atol, rtol=rtol)

    def test_grouped_decode_attention(self):
        """Sweep a few MLA decode configurations (B, H_Q, H_KV, D, D_V, seqlen)."""
        configs = [
            (1, 22, 1, 576, 512, 8 * 111),
            (4, 22, 1, 576, 512, 8 * 128),
            (40, 22, 1, 576, 512, 8 * 133),
        ]

        for B, H_Q, H_KV, D, D_V, seqlen in configs:
            self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen)
if __name__ == "__main__":
unittest.main()

265
test/srt/cpu/test_moe.py Normal file
View File

@@ -0,0 +1,265 @@
import itertools
import math
import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch
kernel = torch.ops.sgl_kernel
torch.manual_seed(1234)
from utils import (
BLOCK_K,
BLOCK_N,
factor_for_scale,
fp8_max,
fp8_min,
native_fp8_fused_moe,
precision,
scaled_weight,
torch_naive_fused_moe,
torch_w8a8_per_column_fused_moe,
)
from sglang.test.test_utils import CustomTestCase
def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
    """Run the CPU fused-MoE path end to end: grouped top-k routing
    followed by ``fused_experts_cpu``.

    Args:
        a: activation tensor of shape (num_tokens, hidden).
        w1: first (gate/up) expert weights, shape (E, 2N, K).
        w2: second (down) expert weights, shape (E, K, N).
        score: router logits, shape (num_tokens, E).
        topk: number of experts selected per token.
        renormalize: whether routing weights are renormalized after top-k.
        prepack: pre-pack the expert weights into the kernel's blocked layout.

    Returns:
        The fused expert output (the kernel runs in-place on ``a``).
    """
    # Single expert group and top-1 group: plain (non-grouped) top-k routing.
    G = 1
    topk_group = 1
    # NOTE: the previous version pre-allocated topk_weights/topk_ids with
    # torch.empty and immediately overwrote them here — dead code, removed.
    topk_weights, topk_ids = kernel.grouped_topk_cpu(
        a, score, topk, renormalize, G, topk_group, 0, None, None
    )
    packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
    packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
    inplace = True
    # Positional flags after topk_ids: inplace, use_int8_w8a8, use_fp8_w8a16
    # — presumably; confirm against the fused_experts_cpu schema.
    return kernel.fused_experts_cpu(
        a,
        packed_w1,
        packed_w2,
        topk_weights,
        topk_ids,
        inplace,
        False,
        False,
        None,
        None,
        None,
        None,
        None,
        prepack,
    )
class TestFusedExperts(CustomTestCase):
    """Tests for ``fused_experts_cpu`` in bf16, W8A8 int8 and block-quantized
    fp8 modes, each against a native PyTorch MoE reference."""

    # Shape grid for the bf16 tests.
    M = [2, 114]
    N = [32]
    K = [32]
    E = [4]
    topk = [2]
    renormalize = [False, True]

    # Shape grid for the int8 tests.
    M_int8 = [1, 39]
    N_int8 = [128]
    K_int8 = [256]
    E_int8 = [8]
    topk_int8 = [3]

    # Shape grid for the fp8 tests.
    M_fp8 = [2, 121]
    N_fp8 = [352, 512]
    K_fp8 = [256, 320]
    E_fp8 = [8]
    topk_fp8 = [4]

    def _bf16_moe(self, m, n, k, e, topk, renormalize):
        """bf16 fused MoE vs. the naive reference."""
        dtype = torch.bfloat16
        prepack = True
        a = torch.randn((m, k), device="cpu", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10
        score = torch.randn((m, e), device="cpu", dtype=dtype)

        torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize)
        fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)
        atol = rtol = precision[torch_output.dtype]
        torch.testing.assert_close(torch_output, fused_output, atol=atol, rtol=rtol)

    def test_bf16_moe(self):
        """Sweep the bf16 MoE grid."""
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.topk,
            self.renormalize,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                e=params[3],
                topk=params[4],
                renormalize=params[5],
            ):
                self._bf16_moe(*params)

    def _int8_moe(self, M, N, K, E, topk):
        """W8A8 per-column int8 fused MoE vs. the quantized reference."""
        dtype = torch.bfloat16
        prepack = True
        # Initialize int8 quantization parameters
        int8_factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Input tensor
        # M * K
        a = torch.randn((M, K), dtype=dtype) / math.sqrt(K)

        # Generate int8 weights
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
        w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
        w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        # Generate scale for each column (per-column quantization)
        w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale
        w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale

        # Calculate routing
        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        ref_out = torch_w8a8_per_column_fused_moe(
            a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk
        )

        inplace = True
        packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
        packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
        # Second boolean flag enables the int8 W8A8 path.
        out = kernel.fused_experts_cpu(
            a,
            packed_w1,
            packed_w2,
            topk_weight,
            topk_ids.to(torch.int32),
            inplace,
            True,
            False,
            w1_s,
            w2_s,
            None,
            None,
            None,
            prepack,
        )
        atol = rtol = precision[ref_out.dtype]
        # Increase the tolerance for large input shapes
        if M > 35:
            atol = rtol = 0.02
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_int8_moe(self):
        """Sweep the int8 MoE grid."""
        for params in itertools.product(
            self.M_int8,
            self.N_int8,
            self.K_int8,
            self.E_int8,
            self.topk_int8,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
            ):
                self._int8_moe(*params)

    def _fp8_moe(self, M, N, K, E, topk):
        """Block-quantized fp8 fused MoE vs. the native fp8 reference."""
        dtype = torch.bfloat16
        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        w1_fp32 = torch.randn(E, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
        w2_fp32 = torch.randn(E, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # One scale per (BLOCK_N, BLOCK_K) weight block.
        w1s = (
            torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
            * factor_for_scale
        )
        w2s = (
            torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
            * factor_for_scale
        )

        # Dequantized copies used by the reference.
        w1_scaled = scaled_weight(w1, w1s)
        w2_scaled = scaled_weight(w2, w2s)

        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        w1 = kernel.convert_weight_packed(w1)
        w2 = kernel.convert_weight_packed(w2)

        ref_out = native_fp8_fused_moe(
            a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk
        )

        # Third boolean flag enables the fp8 path; block sizes passed through.
        out = kernel.fused_experts_cpu(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids.to(torch.int32),
            False,
            False,
            True,
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            None,
            None,
            True,
        )

        atol = rtol = precision[dtype]
        torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol)

    def test_fp8_moe(self):
        """Sweep the fp8 MoE grid."""
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.E_fp8,
            self.topk_fp8,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
            ):
                self._fp8_moe(*params)
if __name__ == "__main__":
unittest.main()

90
test/srt/cpu/test_norm.py Normal file
View File

@@ -0,0 +1,90 @@
import itertools
import unittest
from typing import Optional, Tuple, Union
import sgl_kernel
import torch
from utils import make_non_contiguous, precision
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestNorm(CustomTestCase):
    """Tests the CPU RMSNorm kernels (plain, fused add+norm, and L2 norm)
    against a float32 native implementation."""

    # Shape and dtype grids.
    M = [4096, 1024]
    N = [4096, 4096 + 13]
    dtype = [torch.float16, torch.bfloat16]

    def _forward_native(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float = 1e-6,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Reference RMSNorm in fp32.  With ``residual`` it first adds the
        residual and returns ``(normed, updated_residual)``."""
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype) * weight
        if residual is None:
            return x
        else:
            return x, residual

    def _norm_test(self, m, n, dtype):
        """Check rmsnorm_cpu and the in-place fused_add_rmsnorm_cpu."""
        x = torch.randn([m, n], dtype=dtype)
        # Exercise the kernel's non-contiguous input path.
        x = make_non_contiguous(x)
        hidden_size = x.size(-1)
        weight = torch.randn(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

        out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon)
        ref_out = self._forward_native(x, weight, variance_epsilon)
        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

        # Fused add+norm mutates x and residual in place; compare against
        # the native version run on clones.
        ref_x = x.clone()
        residual = torch.randn([m, hidden_size], dtype=dtype)
        ref_residual = residual.clone()
        torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
            x, residual, weight, variance_epsilon
        )
        ref_x, ref_residual = self._forward_native(
            ref_x, weight, variance_epsilon, ref_residual
        )
        torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
        torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)

    def _l2norm_test(self, m, n, dtype):
        """Check l2norm_cpu: equivalent to RMSNorm with all-ones weight."""
        x = torch.randn([m, n], dtype=dtype)
        hidden_size = x.size(-1)
        fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

        out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
        ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)
        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_norm(self):
        """Sweep shapes and dtypes for both norm variants."""
        for params in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=params[0], n=params[1], dtype=params[2]):
                self._norm_test(*params)
                self._l2norm_test(*params)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,432 @@
import unittest
import sgl_kernel
import torch
from utils import (
convert_weight,
native_w8a8_per_token_matmul,
per_token_quant_int8,
precision,
)
from sglang.srt.layers.rotary_embedding import _apply_rotary_emb
from sglang.test.test_utils import CustomTestCase
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
torch.manual_seed(1234)
# constants
kv_lora_rank = 512
qk_head_dim = 192
qk_nope_head_dim = 128
qk_rope_head_dim = 64
rotary_dim = qk_rope_head_dim
num_heads = 22
q_lora_rank = 1536
hidden_size = 7168
B = 1
eps = 1e-6
def layernorm(x, weight, variance_epsilon=1e-6, residual=None):
    """Reference RMS normalization (despite the name, no mean subtraction).

    Args:
        x: input tensor; normalized over its last dimension.
        weight: per-channel scale applied after normalization.
        variance_epsilon: small constant for numerical stability.
        residual: optional tensor added to ``x`` before normalization.
            Previously this parameter was accepted but silently ignored;
            it now behaves like the fused add+norm reference elsewhere in
            these tests.

    Returns:
        The normalized tensor, or ``(normalized, updated_residual)`` when
        ``residual`` is supplied.  Existing callers pass no residual, so
        their behavior is unchanged.
    """
    orig_dtype = x.dtype
    # Compute in fp32 for accuracy, cast back at the end.
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + variance_epsilon)
    out = (x * weight).to(orig_dtype)
    if residual is None:
        return out
    return out, residual
def rotary_emb(q_pe, k_pe, pos, cos_sin_cache):
    """Apply non-neox rotary embedding to the rope slices of q and k.

    The cos/sin rows are looked up from ``cos_sin_cache`` at ``pos``
    (cache layout is [cos | sin] along the last dim).  Computation runs
    in fp32 and the results are cast back to the inputs' dtype.
    """
    in_dtype = q_pe.dtype
    cos, sin = cos_sin_cache.float()[pos].chunk(2, dim=-1)
    rotated = []
    for t in (q_pe, k_pe):
        rot = _apply_rotary_emb(t.float()[..., :rotary_dim], cos, sin, False)
        rotated.append(rot.to(in_dtype))
    return rotated[0], rotated[1]
def native_torch(
    q_input,
    hidden_states,
    q_a_proj_weight,
    norm_weight1,
    q_b_proj_weight,
    w_kc,
    kv_a_proj_weight,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    """Reference (bf16) implementation of the fused QKV-projection +
    rope path for MLA: low-rank q projection, latent kv projection,
    norms and rotary embedding.  ``q_input`` is filled in place; returns
    ``(q_input, k_input, v_input)``.  Relies on the module-level shape
    constants (num_heads, qk_head_dim, kv_lora_rank, ...)."""
    # q path: down-proj -> RMSNorm -> up-proj -> split into nope/rope parts.
    q = torch.matmul(hidden_states, q_a_proj_weight.t())
    q = layernorm(q, norm_weight1)
    q = torch.matmul(q, q_b_proj_weight.t()).view(-1, num_heads, qk_head_dim)
    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    # Absorb the nope part through w_kc into the latent space.
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)
    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
    # kv path: latent projection; first kv_lora_rank channels are normed
    # and serve as v, the tail is the rope part of k.
    latent_cache = torch.matmul(hidden_states, kv_a_proj_weight.t())
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input
    k_pe = k_input[..., kv_lora_rank:]
    # Rotate the rope slices of q and k, then write them back in place.
    q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
    q_input[..., kv_lora_rank:] = q_pe
    k_input[..., kv_lora_rank:] = k_pe
    return q_input, k_input, v_input
def native_torch_int8(
    q_input,
    hidden_states,
    w1_q,
    w1_s,
    norm_weight1,
    w2_q,
    w2_s,
    w_kc,
    w3_q,
    w3_s,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    """Int8 W8A8 variant of :func:`native_torch`: every matmul is replaced
    by per-token activation quantization + the int8 reference matmul.
    ``q_input`` is filled in place; returns ``(q_input, k_input, v_input)``."""
    # q down-projection (int8) -> RMSNorm.
    a_q, a_s = per_token_quant_int8(hidden_states)
    q = native_w8a8_per_token_matmul(a_q, w1_q, a_s, w1_s, None, torch.bfloat16)
    q = layernorm(q, norm_weight1)
    # q up-projection (int8), then split into nope/rope parts.
    a_q, a_s = per_token_quant_int8(q)
    q = native_w8a8_per_token_matmul(a_q, w2_q, a_s, w2_s, None, torch.bfloat16).view(
        -1, num_heads, qk_head_dim
    )
    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)
    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
    # Latent kv projection (int8).
    a_q, a_s = per_token_quant_int8(hidden_states)
    latent_cache = native_w8a8_per_token_matmul(
        a_q, w3_q, a_s, w3_s, None, torch.bfloat16
    )
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input
    k_pe = k_input[..., kv_lora_rank:]
    q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
    q_input[..., kv_lora_rank:] = q_pe
    k_input[..., kv_lora_rank:] = k_pe
    return q_input, k_input, v_input
class TestQKVProjWithROPE(CustomTestCase):
    """Tests the fused ``qkv_proj_with_rope`` kernels (bf16, int8, fp8) and
    their fused-weight variants against the native references above.
    Shapes come from the module-level MLA constants."""

    def test_bf16_qkv_proj_with_rope(self):
        """bf16 path: separate weights and fused (q_a + kv_a) weight."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        # Fused variant concatenates the two down-projections row-wise.
        fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)

        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight,
            norm_weight1,
            q_b_proj_weight,
            w_kc.transpose(1, 2),
            kv_a_proj_weight,
            norm_weight2,
            pos,
            cos_sin_cache,
        )

        qa_packed = convert_weight_packed(q_a_proj_weight)
        qb_packed = convert_weight_packed(q_b_proj_weight)
        kva_packed = convert_weight_packed(kv_a_proj_weight)
        wkc_packed = convert_weight_packed(w_kc)
        fused_weight_packed = convert_weight_packed(fused_weight)

        # NOTE(review): the boolean flags after ``eps`` presumably select
        # int8/fp8 quantization modes — confirm against the kernel schema.
        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            qa_packed,
            qb_packed,
            kva_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            False,
            None,
            None,
            None,
            True,
            None,
        )
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            fused_weight_packed,
            qb_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            False,
            None,
            None,
            True,
            None,
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        # Fused-weight kernel must match the separate-weight kernel exactly.
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)

    def test_int8_qkv_proj_with_rope(self):
        """W8A8 int8 path: separate and fused quantized weights."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)

        # Per-row int8 quantization of the three projection weights.
        w1_q, w1_s = per_token_quant_int8(q_a_proj_weight)
        w2_q, w2_s = per_token_quant_int8(q_b_proj_weight)
        w3_q, w3_s = per_token_quant_int8(kv_a_proj_weight)

        q_ref, k_ref, v_ref = native_torch_int8(
            q_input,
            hidden_states,
            w1_q,
            w1_s,
            norm_weight1,
            w2_q,
            w2_s,
            w_kc.transpose(1, 2),
            w3_q,
            w3_s,
            norm_weight2,
            pos,
            cos_sin_cache,
        )

        w1_q_packed = convert_weight_packed(w1_q)
        w2_q_packed = convert_weight_packed(w2_q)
        w3_q_packed = convert_weight_packed(w3_q)
        wkc_packed = convert_weight_packed(w_kc)

        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            w1_q_packed,
            w2_q_packed,
            w3_q_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            True,
            False,
            w1_s,
            w2_s,
            w3_s,
            True,
            None,
        )
        # Fused variant: concatenated quantized weights and scales.
        fused_weight = torch.cat([w1_q, w3_q], dim=0)
        fused_weight_s = torch.cat([w1_s, w3_s], dim=0)
        w_fused_q_packed = convert_weight_packed(fused_weight)
        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            w_fused_q_packed,
            w2_q_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            True,
            False,
            fused_weight_s,
            w2_s,
            True,
            None,
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)

    def test_fp8_qkv_proj_with_rope(self):
        """Block-quantized fp8 path, checked against the bf16 reference
        run on the dequantized weights."""
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)

        scale_block_size_N = 128
        scale_block_size_K = 128

        # convert_weight returns (fp8 weight, block scales, dequantized copy).
        fp8_q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_proj_weight_dq = (
            convert_weight(
                q_a_proj_weight,
                [scale_block_size_N, scale_block_size_K],
                torch.bfloat16,
            )
        )
        fp8_q_b_proj_weight, q_b_proj_weight_scale_inv, q_b_proj_weight_dq = (
            convert_weight(
                q_b_proj_weight,
                [scale_block_size_N, scale_block_size_K],
                torch.bfloat16,
            )
        )
        (
            fp8_kv_a_proj_with_mqa_weight,
            kv_a_proj_with_mqa_weight_scale_inv,
            kv_a_proj_with_mqa_weight_dq,
        ) = convert_weight(
            kv_a_proj_weight, [scale_block_size_N, scale_block_size_K], torch.bfloat16
        )

        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight_dq,
            norm_weight1,
            q_b_proj_weight_dq,
            w_kc.transpose(1, 2),
            kv_a_proj_with_mqa_weight_dq,
            norm_weight2,
            pos,
            cos_sin_cache,
        )

        fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight)
        fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight)
        fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed(
            fp8_kv_a_proj_with_mqa_weight
        )
        w_kc = convert_weight_packed(w_kc)

        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            fp8_q_a_proj_weight_packed,
            fp8_q_b_proj_weight_packed,
            fp8_kv_a_proj_with_mqa_weight_packed,
            w_kc,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            True,
            q_a_proj_weight_scale_inv.float(),
            q_b_proj_weight_scale_inv.float(),
            kv_a_proj_with_mqa_weight_scale_inv.float(),
            True,
            [scale_block_size_N, scale_block_size_K],
        )

        fused_weight = torch.cat(
            [fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0
        )
        fused_weight_s = torch.cat(
            [q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0
        )
        fused_weight_packed = convert_weight_packed(fused_weight)

        fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
            hidden_states,
            fused_weight_packed,
            fp8_q_b_proj_weight_packed,
            w_kc,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            True,
            fused_weight_s.float(),
            q_b_proj_weight_scale_inv.float(),
            True,
            [scale_block_size_N, scale_block_size_K],
            q_lora_rank,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        atol = rtol = precision[q_ref.dtype]
        # Due to the change in multiplication order, the error is amplified.
        # In the model, with fewer layers, this doesn't cause issues, but in
        # tests with more layers, we need to enlarge the tolerance to pass the tests.
        torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1)
        torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
        torch.testing.assert_close(fused_q_out, q_out)
        torch.testing.assert_close(fused_k_out, k_out)
        torch.testing.assert_close(fused_v_out, v_out)
if __name__ == "__main__":
unittest.main()

178
test/srt/cpu/test_rope.py Normal file
View File

@@ -0,0 +1,178 @@
import unittest
import sgl_kernel
import torch
from utils import precision
from sglang.srt.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding,
RotaryEmbedding,
)
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestROPE(CustomTestCase):
    """Tests the CPU fused rotary-embedding kernel against the native
    implementations of DeepSeek-scaling and standard rotary embedding."""

    def test_deepseek_v2_rope(self):
        """Compare rotary_embedding_cpu with DeepseekScalingRotaryEmbedding
        on the rope slices of MLA-shaped q/k."""
        num_head = 16
        seq_len = 1024
        q_head_dim = 192
        qk_nope_head_dim = 128
        qk_rope_head_dim = 64
        max_pos = 256
        k_dim = 576
        rotary_dim = 64
        is_neox_style = False

        # Create cos_sin_cache
        freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
        cos = freqs.cos() * 0.7
        sin = freqs.sin() * 0.7
        cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
        positions = torch.randint(0, max_pos, (seq_len,))

        rope = DeepseekScalingRotaryEmbedding(
            qk_rope_head_dim,
            rotary_dim,
            max_pos,
            16,  # not used since cos_sin_cache is provided
            is_neox_style,
            1.0,
            torch.bfloat16,
            device="cpu",
        )
        # Override the module's cache with the handcrafted one above.
        rope.register_buffer("cos_sin_cache", cos_sin_cache)

        for dtype in [torch.bfloat16]:
            enable_autocast = True

            with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast):
                q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
                q_clone = q.clone()
                k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
                k_clone = k.clone()
                # Only the rope tail of q/k is rotated.
                _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
                _, q_pe_clone = q_clone.split(
                    [qk_nope_head_dim, qk_rope_head_dim], dim=-1
                )
                k_pe = k[:, :, k_dim - qk_rope_head_dim :]
                k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]

                # ref kernel
                q_pe, k_pe = rope.forward_native(
                    query=q_pe,
                    key=k_pe,
                    positions=positions,
                )

                # fused rope kernel
                q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
                    positions,
                    q_pe_clone,
                    k_pe_clone,
                    rope.head_size,
                    cos_sin_cache,
                    False,
                )

                atol = rtol = precision[q_pe.dtype]
                torch.testing.assert_close(q_pe, q_pe_clone, atol=atol, rtol=rtol)
                torch.testing.assert_close(k_pe, k_pe_clone, atol=atol, rtol=rtol)
                # NOTE(review): this second, zero-tolerance check of k_pe is
                # stricter than the tolerant one above and looks redundant —
                # confirm whether exact bitwise equality is intended here.
                torch.testing.assert_close(k_pe, k_pe_clone)

    def test_origin_rope(self):
        """Compare rotary_embedding_cpu with RotaryEmbedding.forward_native
        over a grid of head sizes, rotary dims and neox styles."""

        def single_test(
            head_size: int,
            rotary_dim: int,
            max_position_embeddings: int,
            base: int,
            is_neox_style: bool,
            dtype: torch.dtype,
            device: str,
            batch_size: int,
            seq_len: int,
            num_q_heads: int,
            num_kv_heads: int,
        ):
            # Fixed seed so each config is deterministic on its own.
            torch.manual_seed(100)
            rope_ref = RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position_embeddings,
                base,
                is_neox_style,
                dtype,
            ).to(device)
            pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
            query = torch.randn(
                batch_size * seq_len,
                num_q_heads * head_size,
                dtype=dtype,
                device=device,
            )
            key = torch.randn(
                batch_size * seq_len,
                num_kv_heads * head_size,
                dtype=dtype,
                device=device,
            )

            query_ref, key_ref = query.clone(), key.clone()
            query_cpu, key_cpu = query.clone(), key.clone()

            query_ref_out, key_ref_out = rope_ref.forward_native(
                pos_ids, query_ref, key_ref
            )
            query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
                pos_ids,
                query_cpu,
                key_cpu,
                rope_ref.head_size,
                rope_ref.cos_sin_cache.to(query.dtype),
                rope_ref.is_neox_style,
            )
            torch.testing.assert_close(
                query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
            )
            torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)

        # (head_size, rotary_dim, max_pos, base, neox, dtype, device,
        #  batch, seq_len, q_heads, kv_heads)
        test_config = [
            (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
            (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
            (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
            (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
        ]

        for (
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
            device,
            batch_size,
            seq_len,
            num_q_heads,
            num_kv_heads,
        ) in test_config:
            single_test(
                head_size,
                rotary_dim,
                max_position_embeddings,
                base,
                is_neox_style,
                dtype,
                device,
                batch_size,
                seq_len,
                num_q_heads,
                num_kv_heads,
            )
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,223 @@
import itertools
import math
import unittest
# TODO: use interface in cpu.py
import sgl_kernel
import torch
import torch.nn as nn
from utils import (
BLOCK_K,
BLOCK_N,
SiluAndMul,
factor_for_scale,
fp8_max,
fp8_min,
per_token_quant_int8,
precision,
scaled_weight,
torch_naive_moe,
torch_w8a8_per_column_moe,
)
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
class TestSharedExpert(CustomTestCase):
    """Compare torch.ops.sgl_kernel.shared_expert_cpu against native references
    for bf16, per-token int8, and block-quantized fp8 weights.

    Results are compared with per-dtype tolerances from `precision`.
    NOTE(review): the boolean/None positional arguments to shared_expert_cpu
    are the kernel's flag/scale slots; their exact meaning is defined by the
    sgl_kernel op schema — confirm against the kernel source.
    """

    # Problem sizes swept by the bf16/int8 tests.
    M = [2, 121]
    N = [32, 32 * 4]
    K = [32, 32 * 2]
    routed_scaling_factor = [16]
    # Smaller sweep for fp8; N/K must be multiples of BLOCK_N/BLOCK_K.
    M_fp8 = [2, 12]
    N_fp8 = [512]
    K_fp8 = [256]

    def _bf16_shared_expert(self, m, n, k, routed_scaling_factor):
        """bf16 path: kernel output must match the fp32 naive reference."""
        dtype = torch.bfloat16
        # NOTE(review): `prepack` is assigned but never used in this method.
        prepack = True
        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k
        # fused moe mutates content in hs
        # NOTE(review): `hidden_states2` is unused in this bf16 method.
        hidden_states2 = hidden_states.clone()
        # bfloat16
        ref = torch_naive_moe(
            hidden_states.float(),
            w1.float(),
            w2.float(),
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)
        res = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            w1,
            w2,
            fused_output,
            routed_scaling_factor,
            True,
            False,
            False,
            None,
            None,
            None,
            None,
            None,
            False,
        )
        atol = rtol = precision[ref.dtype]
        torch.testing.assert_close(ref, res, atol=atol, rtol=rtol)

    def test_bf16_shared_expert(self):
        """Sweep (M, N, K, routed_scaling_factor) for the bf16 path."""
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                routed_scaling_factor=params[3],
            ):
                self._bf16_shared_expert(*params)

    def _int8_shared_expert(self, m, n, k, routed_scaling_factor):
        """w8a8 path: per-token int8 weights + scales vs the int8 reference."""
        dtype = torch.bfloat16
        # NOTE(review): `prepack` is assigned but never used in this method.
        prepack = True
        hidden_states = torch.randn(m, k, dtype=dtype) / k
        w1 = torch.randn(2 * n, k, dtype=dtype)
        w2 = torch.randn(k, n, dtype=dtype)
        fused_output = torch.randn(m, k, dtype=dtype) / k
        # fused moe mutates content in hs
        hidden_states2 = hidden_states.clone()
        w1_q, w1_s = per_token_quant_int8(w1)
        w2_q, w2_s = per_token_quant_int8(w2)
        ref2 = torch_w8a8_per_column_moe(
            hidden_states2.float(),
            w1_q,
            w2_q,
            w1_s,
            w2_s,
            fused_output.float(),
            routed_scaling_factor,
        ).to(dtype=dtype)
        res2 = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states2,
            w1_q,
            w2_q,
            fused_output,
            routed_scaling_factor,
            True,
            True,
            False,
            w1_s,
            w2_s,
            None,
            None,
            None,
            False,
        )
        atol = rtol = precision[ref2.dtype]
        torch.testing.assert_close(ref2, res2, atol=atol, rtol=rtol)

    def test_int8_shared_expert(self):
        """Sweep (M, N, K, routed_scaling_factor) for the int8 path."""
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                routed_scaling_factor=params[3],
            ):
                self._int8_shared_expert(*params)

    def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
        """fp8 path: block-quantized weights; reference uses the dequantized
        (scaled) weights so both sides see identical effective values."""
        dtype = torch.bfloat16
        # NOTE(review): `prepack` is assigned but never used in this method.
        prepack = True
        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
        w1_fp32 = torch.randn(1, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
        w2_fp32 = torch.randn(1, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
        # One random scale per (BLOCK_N, BLOCK_K) tile, kept small for stability.
        w1s = torch.randn(1, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
        w2s = torch.randn(1, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
        w1_scaled = scaled_weight(w1, w1s).view(2 * N, K)
        w2_scaled = scaled_weight(w2, w2s).view(K, N)
        # change back to 2D
        w1, w2 = w1.squeeze(0), w2.squeeze(0)
        w1s, w2s = w1s.squeeze(0), w2s.squeeze(0)
        w1_scaled, w2_scaled = w1_scaled.squeeze(0), w2_scaled.squeeze(0)
        fused_out = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
        a2 = a.clone()
        # ref
        ic0 = torch.matmul(a.float(), w1_scaled.transpose(0, 1))
        ic1 = SiluAndMul(ic0)
        shared_out = torch.matmul(ic1, w2_scaled.transpose(0, 1))
        ref_out = shared_out + fused_out.float() * routed_scaling_factor
        ref_out = ref_out.to(dtype=dtype)
        # Kernel consumes the VNNI/packed weight layout.
        w1 = torch.ops.sgl_kernel.convert_weight_packed(w1)  # [2N, K]
        w2 = torch.ops.sgl_kernel.convert_weight_packed(w2)  # [K, N]
        out = torch.ops.sgl_kernel.shared_expert_cpu(
            a2,
            w1,
            w2,
            fused_out,
            routed_scaling_factor,
            True,
            False,
            True,
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            None,
            None,
            True,
        )
        atol = rtol = precision[ref_out.dtype]
        torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)

    def test_fp8_shared_expert(self):
        """Sweep (M, N, K, routed_scaling_factor) for the fp8 path."""
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.routed_scaling_factor,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                routed_scaling_factor=params[3],
            ):
                self._fp8_shared_expert(*params)
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()

199
test/srt/cpu/test_topk.py Normal file
View File

@@ -0,0 +1,199 @@
import itertools
import unittest
import sgl_kernel
import torch
from utils import precision
from sglang.srt.layers.moe.topk import (
biased_grouped_topk_impl as native_biased_grouped_topk,
)
from sglang.srt.layers.moe.topk import fused_topk_torch_native as native_fused_topk
from sglang.srt.layers.moe.topk import grouped_topk_gpu as native_grouped_topk
from sglang.srt.models.llama4 import Llama4MoE
from sglang.test.test_utils import CustomTestCase
torch.manual_seed(1234)
# This is used by the Deepseek-V2 model
class TestGroupedTopK(CustomTestCase):
    """Compare torch.ops.sgl_kernel.grouped_topk_cpu with the native grouped
    top-k router implementation (DeepSeek-V2 style routing)."""

    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        """Check one (tokens=M, experts=E, groups=G) routing configuration."""
        torch.manual_seed(1234)
        # expand gating_output by M, otherwise bfloat16 collapses nearby
        # logits onto the same value after truncation
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        ref_topk_weights, ref_topk_ids = native_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )
        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            G,
            topk_group,
            0,
            None,
            None,
        )
        # Scatter (id, weight) pairs into dense (M, E) maps so the comparison
        # is insensitive to the ordering of the returned top-k entries.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_grouped_topk(self):
        """Sweep expert/group counts with and without renormalization."""
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16)
# DeepSeek V2/V3/R1 uses biased_grouped_topk
class TestBiasedGroupedTopK(CustomTestCase):
    """Compare torch.ops.sgl_kernel.biased_grouped_topk_cpu with the native
    implementation, including the expert-score correction bias."""

    def _run_single_test(
        self, M, E, G, topk, topk_group, renormalize, dtype, bias_dtype
    ):
        """Check one configuration; correction_bias is cast to bias_dtype."""
        torch.manual_seed(1234)
        # expand gating_output by M, otherwise bfloat16 collapses nearby
        # logits onto the same value after truncation
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        correction_bias = torch.randn(E, dtype=bias_dtype)
        ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            correction_bias.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )
        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            G,
            topk_group,
            0,
            None,
            None,
        )
        # Order-insensitive comparison via dense scatter by expert id.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_biased_grouped_topk(self):
        """DeepSeek-R1-like configuration, both bias dtypes."""
        for renormalize in [True, False]:
            for bias_dtype in [torch.float32, torch.bfloat16]:
                self._run_single_test(
                    122, 256, 8, 8, 2, renormalize, torch.bfloat16, bias_dtype
                )
class TestTopK(CustomTestCase):
    """Compare torch.ops.sgl_kernel.topk_softmax_cpu with the native
    softmax + top-k implementation."""

    def _run_single_test(self, M, E, topk, renormalize, dtype):
        """Check one (tokens=M, experts=E, topk) configuration.

        Results are scattered into dense (M, E) maps keyed by expert id so the
        comparison is insensitive to the ordering of the returned top-k.
        """
        torch.manual_seed(1998)
        # expand gating_output by M, otherwise bfloat16 collapses nearby
        # logits onto the same value after truncation
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        ref_topk_weights, ref_topk_ids = native_fused_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )
        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
            hidden_states, gating_output, topk, renormalize
        )
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_topk(self):
        """Sweep expert counts with and without renormalization."""
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 3, renormalize, torch.bfloat16)
            # NOTE: the (123, 32, 3) case used to appear twice; the duplicate
            # was removed (possibly it was meant to be M=1123 as in
            # TestGroupedTopK — add that case if large-M coverage is wanted).
            self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 6, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 4, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 6, renormalize, torch.bfloat16)
class TestCustomTopK(CustomTestCase):
    """Compare model-specific routing functions (e.g. Llama4's sigmoid
    routing) against their fused CPU kernel counterparts."""

    def _run_single_test(
        self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
    ):
        """Check one configuration for a (native, fused) routing-function pair."""
        torch.manual_seed(16)
        # expand gating_output by M, otherwise bfloat16 collapses nearby
        # logits onto the same value after truncation
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        ref_topk_weights, ref_topk_ids = native_custom_f(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )
        # fused version
        topk_weights, topk_ids = fused_custom_f(
            hidden_states, gating_output, topk, renormalize
        )
        # Order-insensitive comparison via dense scatter by expert id.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_custom_topk(self):
        """Each tuple pairs a native routing function with its fused kernel."""
        test_custom_functions = [
            (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
        ]
        for native_custom_f, fused_custom_f in test_custom_functions:
            self._run_single_test(
                123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
            )
            self._run_single_test(
                123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
            )
            self._run_single_test(
                123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
            )
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main()

269
test/srt/cpu/utils.py Normal file
View File

@@ -0,0 +1,269 @@
import math
import torch
import torch.nn.functional as F
# Per-dtype tolerance (used as both atol and rtol) for assert_close checks.
precision = {
    torch.bfloat16: 1e-2,
    torch.float16: 1e-3,
    torch.float32: 1e-5,
}
# Block-quantization tile sizes along the N (output) and K (input) dims.
BLOCK_N, BLOCK_K = 64, 128
# Shrinks random per-block scales so dequantized values stay numerically tame.
factor_for_scale = 1e-3
# Clamp range used when casting random weights to float8_e4m3fn in tests
# (slightly tighter than the true e4m3 max of 448).
fp8_max, fp8_min = 400, -400
def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
    """Gated SiLU activation: split the last dim in half, silu(gate) * up."""
    half = x.shape[-1] // 2
    gate, up = x[..., :half], x[..., half:]
    return F.silu(gate) * up
def per_token_quant_int8(x):
    """Symmetric per-token int8 quantization.

    Returns (x_q, scale) with one scale per row (reduction over the last dim)
    such that x ≈ x_q.float() * scale and x_q values lie in [-127, 127].
    """
    x_f = x.float()
    absmax = x_f.abs().max(dim=-1, keepdim=True).values.clamp_min(1e-10)
    scale = absmax / 127
    quantized = torch.round(x_f.mul(127 / absmax)).to(torch.int8)
    return quantized, scale
def convert_weight(weight, scale_block_size, A_dtype):
    """Block-quantize a 2D weight to float8_e4m3fn.

    Pads (N, K) up to whole multiples of the per-block scale sizes, derives
    one scale per block from its abs-max, and returns:
      q_fp8_reshape: quantized weight, cropped back to (N, K)
      scales:        per-block scales, shape (ceil(N/bn), ceil(K/bk))
      w_dq:          dequantized reference weight in A_dtype, shape (N, K)
    """
    N, K = weight.size()
    # NOTE(review): intentionally shadows the module-level fp8_max (400);
    # 448.0 is the actual float8_e4m3fn maximum used for scale computation.
    fp8_max = 448.0
    scale_block_size_N, scale_block_size_K = scale_block_size  # (128, 128)
    # Pad both dims up to a whole number of blocks.
    pad_N = (scale_block_size_N - (N % scale_block_size_N)) % scale_block_size_N
    pad_K = (scale_block_size_K - (K % scale_block_size_K)) % scale_block_size_K
    if pad_N > 0 or pad_K > 0:
        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
    # Step 1: regroup into (nN, nK) blocks of (bn, bk) each.
    weight_blocks = weight.view(
        math.ceil(N / scale_block_size_N),
        scale_block_size_N,
        math.ceil(K / scale_block_size_K),
        scale_block_size_K,
    )  # (8, 128, 8, 128)
    weight_blocks = weight_blocks.permute(0, 2, 1, 3).contiguous()  # (8, 8, 128, 128)
    # Step 2: compute per-block max abs values → scale
    abs_max = weight_blocks.abs().amax(dim=(-2, -1), keepdim=True)  # (8, 8, 1, 1)
    scales = abs_max / fp8_max
    scales = torch.where(
        scales == 0, torch.ones_like(scales), scales
    )  # avoid division by zero
    q_fp8 = (weight_blocks / scales).to(torch.float8_e4m3fn)
    # Undo the block permutation and crop any padding back off.
    q_fp8_reshape = q_fp8.permute(0, 2, 1, 3).contiguous()
    if pad_N > 0 or pad_K > 0:
        q_fp8_reshape = q_fp8_reshape.view(N + pad_N, K + pad_K)
        q_fp8_reshape = q_fp8_reshape[:N, :K].contiguous()
    else:
        q_fp8_reshape = q_fp8_reshape.view(N, K)
    # Dequantized copy, used as a bit-accurate reference for comparisons.
    dq_weight = q_fp8.float() * scales
    dq_weight = dq_weight.permute(0, 2, 1, 3).contiguous()  # (8, 128, 8, 128)
    if pad_N > 0 or pad_K > 0:
        w_dq = dq_weight.view(N + pad_N, K + pad_K).to(A_dtype)
        w_dq = w_dq[:N, :K].contiguous()
    else:
        w_dq = dq_weight.view(N, K).to(A_dtype)
    scales = scales.view(
        math.ceil(N / scale_block_size_N), math.ceil(K / scale_block_size_K)
    )
    return q_fp8_reshape, scales, w_dq
def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16):
    """Reference w8a8 GEMM with per-token activation scales and per-column
    weight scales.

    A: (..., K_in) quantized activations; B: (N_out, K_in) quantized weight;
    As: (M, 1) per-token scales; Bs: (N_out,) per-output-column scales.
    Returns As * (A @ B.T) * Bs (+ bias) reshaped to A's batch shape,
    cast to output_dtype.
    """
    A_f = A.to(torch.float32)
    B_f = B.to(torch.float32)
    assert A_f.shape[-1] == B_f.shape[-1], "Dimension mismatch"
    assert B_f.ndim == 2 and B_f.is_contiguous(), "B must be a 2D contiguous tensor"
    # Flatten leading dims to a single token axis.
    in_dim = A_f.shape[-1]
    num_tokens = A_f.numel() // in_dim
    out_shape = A_f.shape[:-1] + (B_f.shape[0],)
    acc = torch.matmul(A_f.reshape(num_tokens, in_dim), B_f.t())  # [M, N_out]
    # Apply per-token then per-column scales (broadcast over columns).
    acc = As * acc * Bs.view(1, -1)
    if bias is not None:
        acc.add_(bias.view(1, -1))
    return acc.reshape(out_shape).to(output_dtype)
def torch_naive_moe(a, w1, w2, b, routed_scaling_factor):
    """Reference shared-expert MLP: SiLU-gated w1 projection, w2 projection,
    then add the fused-expert output `b` scaled by routed_scaling_factor."""
    gate_up = torch.matmul(a, w1.transpose(0, 1))
    activated = SiluAndMul(gate_up)
    expert_out = torch.matmul(activated, w2.transpose(0, 1))
    return expert_out + b * routed_scaling_factor
def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor):
    """Reference int8 shared expert: per-token quantize the activations before
    each GEMM, then add the scaled fused-expert output `b`."""
    # Quantize the input activations per token.
    act_q, act_s = per_token_quant_int8(a)
    gate_up = native_w8a8_per_token_matmul(
        act_q, w1_q, act_s, w1_s, bias=None, output_dtype=torch.float32
    )
    hidden = SiluAndMul(gate_up)
    # Re-quantize the intermediate activations per token.
    hid_q, hid_s = per_token_quant_int8(hidden)
    expert_out = native_w8a8_per_token_matmul(
        hid_q, w2_q, hid_s, w2_s, bias=None, output_dtype=torch.float32
    )
    return expert_out + b * routed_scaling_factor
def scaled_weight(weight, scales):
    """Dequantize block-quantized expert weights.

    weight: (E, N, K) quantized values; scales: per-(BLOCK_N, BLOCK_K)-tile
    scales of shape (E, ceil(N/BLOCK_N), ceil(K/BLOCK_K)). Returns the
    float32 weight of shape (E, N, K) with each tile multiplied by its scale.
    """
    E, N, K = weight.shape
    # Pad N/K up to whole tiles so the blocked view below is exact.
    pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
    pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K
    if pad_N > 0 or pad_K > 0:
        weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
    # Regroup into (E, nN, nK, BLOCK_N, BLOCK_K) tiles.
    weight_block = (
        weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
        .permute(0, 1, 3, 2, 4)
        .float()
        .contiguous()
    )
    # Broadcast one scale over each tile, then restore the (E, N, K) layout.
    weight_scaled = (
        (
            weight_block
            * scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1)
        )
        .permute(0, 1, 3, 2, 4)
        .contiguous()
    )
    # Crop any padding back off.
    if pad_N > 0 or pad_K > 0:
        weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K)
        weight_scaled = weight_scaled[..., :N, :K].contiguous()
    else:
        weight_scaled = weight_scaled.view(E, N, K)
    return weight_scaled
def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
    """Reference fused MoE: softmax-route each token to its top-k experts,
    run each expert's SiLU-gated MLP, and sum the weighted expert outputs."""
    num_tokens, hidden = a.shape
    # Routing probabilities and top-k selection (optionally renormalized).
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(probs, topk)
    if renormalize:
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
    flat_weight = topk_weight.view(-1)
    flat_ids = topk_ids.view(-1)
    # One row per (token, expert-slot) pair.
    tokens = a.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden)
    out = torch.zeros(num_tokens * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    for expert in range(w1.shape[0]):
        sel = flat_ids == expert
        if sel.sum():
            out[sel] = SiluAndMul(tokens[sel] @ w1[expert].transpose(0, 1)) @ w2[
                expert
            ].transpose(0, 1)
    weighted = out.view(num_tokens, -1, w2.shape[1]) * flat_weight.view(
        num_tokens, -1, 1
    ).to(out.dtype)
    return weighted.sum(dim=1)
def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk):
    """Reference fused MoE with per-column int8 quantization in native torch.

    a: (B, D) activations; w1/w2: (E, ...) quantized expert weights with
    per-column scales w1_s/w2_s; topk_weight/topk_ids: (B, topk) routing.
    Returns the weighted sum of expert outputs, cast back to a.dtype.
    """
    B, D = a.shape
    # Perform per-token quantization
    a_q, a_s = per_token_quant_int8(a)
    # Repeat tokens to match topk
    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    # Also repeat the scale
    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]
    out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)
    # Calculate routing
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    # Process each expert
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(
                a_q[mask],
                w1[i],
                a_s[mask],
                w1_s[i],
                bias=None,
                output_dtype=torch.float32,
            )
            # Activation function
            act_out = SiluAndMul(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = per_token_quant_int8(act_out)
            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q,
                w2[i],
                act_out_s,
                w2_s[i],
                bias=None,
                output_dtype=torch.float32,
            )
    # Apply routing weights and sum
    return (
        (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
        .sum(dim=1)
        .to(a.dtype)
    )
def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
    """Reference fused MoE computed entirely in fp32 (expert weights are
    assumed already dequantized); routing is given, not recomputed."""
    num_tokens, hidden = a.shape
    # One fp32 row per (token, expert-slot) pair.
    tokens = a.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden).float()
    out = torch.zeros(num_tokens * topk, w2.shape[1], dtype=torch.float32, device=a.device)
    # Calculate routing
    flat_weight = topk_weight.view(-1)
    flat_ids = topk_ids.view(-1)
    for expert in range(w1.shape[0]):
        sel = flat_ids == expert
        if sel.sum():
            gate_up = torch.matmul(tokens[sel], w1[expert].transpose(0, 1))
            out[sel] = torch.matmul(SiluAndMul(gate_up), w2[expert].transpose(0, 1))
    weighted = out.view(num_tokens, -1, w2.shape[1]) * flat_weight.view(
        num_tokens, -1, 1
    ).to(out.dtype)
    return weighted.sum(dim=1).to(a.dtype)
def make_non_contiguous(x: torch.Tensor) -> torch.Tensor:
    """
    Return a non-contiguous view of *x* by slicing off the second half of its
    last dimension; tensors that are already non-contiguous pass through.
    """
    if not x.is_contiguous():
        return x
    half = x.shape[-1] // 2
    return x[..., :half]

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,206 @@
"""
Integration test for abort_request functionality with a SGLang server.
Run with:
python -m unittest sglang.test.srt.entrypoints.http_server.test_abort_request -v
"""
import threading
import time
import unittest
from typing import Optional
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestAbortRequest(CustomTestCase):
    """Integration tests for /abort_request against a live SGLang server.

    A server is launched once per class. Each test starts generation requests
    in background threads, aborts one by its request id ("rid"), and verifies
    the aborted request finishes with finish_reason type "abort" while
    unrelated requests complete normally.
    """

    model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST

    @classmethod
    def setUpClass(cls):
        """Launch the server once for all tests in this class."""
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--disable-cuda-graph"],
        )
        cls.completion_url = f"{cls.base_url}/generate"
        cls.abort_url = f"{cls.base_url}/abort_request"
        cls.health_url = f"{cls.base_url}/health"
        print(f"Server started at {cls.base_url}")

    @classmethod
    def tearDownClass(cls):
        """Clean up the server."""
        kill_process_tree(cls.process.pid)

    def _send_completion_request(
        self,
        text: str,
        request_id: str,
        max_tokens: int = 50,
        temperature: float = 0.8,
        stream: bool = True,
    ) -> requests.Response:
        """Send a /generate request tagged with the given request id."""
        payload = {
            "text": text,
            "sampling_params": {
                "max_new_tokens": max_tokens,
                "temperature": temperature,
            },
            "stream": stream,
            "rid": request_id,
        }
        response = requests.post(
            self.completion_url,
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=30,
            stream=stream,
        )
        return response

    def _send_abort_request(self, request_id: str) -> requests.Response:
        """Ask the server to abort the request with the given id."""
        payload = {"rid": request_id}
        return requests.post(self.abort_url, json=payload, timeout=10)

    def _check_server_health(self) -> bool:
        """Return True iff the /health endpoint answers 200."""
        try:
            response = requests.get(self.health_url, timeout=5)
            return response.status_code == 200
        except requests.RequestException:
            # Was a bare `except:`, which also swallowed KeyboardInterrupt /
            # SystemExit; only request failures should mean "unhealthy".
            return False

    def test_abort_during_non_streaming_generation(self):
        """Test aborting a non-streaming request during generation."""
        self.assertTrue(self._check_server_health(), "Server should be healthy")
        request_id = "test_abort_non_streaming"
        completion_result = {}

        def run_completion():
            # Long generation so the abort lands while it is still running.
            response = self._send_completion_request(
                "Write a detailed essay about artificial intelligence",
                max_tokens=500,
                temperature=1,
                request_id=request_id,
                stream=False,
            )
            if response.status_code == 200:
                result = response.json()
                completion_result["text"] = result.get("text", "")
                completion_result["finish_reason"] = result.get("meta_info", {}).get(
                    "finish_reason"
                )

        completion_thread = threading.Thread(target=run_completion)
        completion_thread.start()
        time.sleep(0.1)  # let the request reach the scheduler before aborting
        abort_response = self._send_abort_request(request_id)
        completion_thread.join()
        self.assertEqual(abort_response.status_code, 200)
        # The dict is only populated on a 200 response; the previous
        # assertIsNotNone was vacuous (the dict object is never None).
        self.assertTrue(completion_result, "Should have completion result")
        if completion_result:
            finish_reason_obj = completion_result.get("finish_reason")
            self.assertIsNotNone(finish_reason_obj, "Should have finish_reason")
            if finish_reason_obj:
                self.assertEqual(
                    finish_reason_obj.get("type"), "abort", "Should be aborted"
                )

    def test_batch_requests_with_selective_abort(self):
        """Test multiple concurrent requests with selective abort of one request."""
        self.assertTrue(self._check_server_health(), "Server should be healthy")
        request_ids = ["batch_test_0", "batch_test_1", "batch_test_2"]
        abort_target_id = "batch_test_1"
        completion_results = {}
        threads = []

        def run_completion(req_id, prompt):
            response = self._send_completion_request(
                f"Write a story about {prompt}",
                max_tokens=100,
                temperature=0.8,
                request_id=req_id,
                stream=False,
            )
            if response.status_code == 200:
                result = response.json()
                completion_results[req_id] = {
                    "text": result.get("text", ""),
                    "finish_reason": result.get("meta_info", {}).get("finish_reason"),
                }

        # Start all requests
        prompts = ["a knight's adventure", "a space discovery", "a chef's restaurant"]
        for i, req_id in enumerate(request_ids):
            thread = threading.Thread(target=run_completion, args=(req_id, prompts[i]))
            threads.append(thread)
            thread.start()
        # Abort one request
        time.sleep(0.1)
        abort_response = self._send_abort_request(abort_target_id)
        # Wait for completion
        for thread in threads:
            thread.join(timeout=30)
        # Verify results
        self.assertEqual(abort_response.status_code, 200)
        # Check aborted request
        aborted_result = completion_results.get(abort_target_id)
        self.assertIsNotNone(
            aborted_result, f"Aborted request {abort_target_id} should have result"
        )
        if aborted_result:
            aborted_finish_reason = aborted_result.get("finish_reason")
            self.assertIsNotNone(
                aborted_finish_reason, "Aborted request should have finish_reason"
            )
            if aborted_finish_reason:
                self.assertEqual(aborted_finish_reason.get("type"), "abort")
        # Check other requests completed normally
        normal_completions = 0
        for req_id in request_ids:
            if req_id != abort_target_id and req_id in completion_results:
                result = completion_results[req_id]
                if result:
                    finish_reason = result.get("finish_reason")
                    if finish_reason and finish_reason.get("type") == "length":
                        normal_completions += 1
        self.assertEqual(
            normal_completions, 2, "Other 2 requests should complete normally"
        )
# Allow running this test file directly.
if __name__ == "__main__":
    unittest.main(verbosity=2, warnings="ignore")

View File

@@ -0,0 +1,445 @@
# Copy from deepseek-ai/DeepEP/tests/test_internode.py
import os
import time
# noinspection PyUnresolvedReferences
import deep_ep
# Test compatibility with low latency functions
import test_deepep_low_latency
import torch
import torch.distributed as dist
from sglang.test.test_deepep_utils import (
bench,
calc_diff,
create_grouped_scores,
init_dist,
inplace_unique,
per_token_cast_back,
per_token_cast_to_fp8,
)
def test_main(
    num_sms: int,
    local_rank: int,
    num_local_ranks: int,
    num_ranks: int,
    num_nodes: int,
    rank: int,
    buffer: deep_ep.Buffer,
    group: dist.ProcessGroup,
):
    """Exercise DeepEP internode dispatch/combine end to end.

    Phases: (1) build reference routing/layout on the host and cross-check
    buffer.get_dispatch_layout; (2) correctness sweep of dispatch + combine
    over {sync, async} x {BF16 random, BF16 rank-constant, FP8} x
    {with/without top-k} x {with/without previous-event chaining};
    (3) chunk-size tuning for dispatch and combine throughput.
    NOTE(review): requires CUDA, an initialized process group, and exactly
    8 local ranks per node (asserted below).
    """
    # Settings
    num_tokens, hidden, num_topk_groups, num_topk, num_experts = (
        4096,
        7168,
        min(num_nodes, 4),
        8,
        (256 // num_ranks) * num_ranks,
    )
    assert num_experts % num_ranks == 0 and num_local_ranks == 8
    if local_rank == 0:
        print(
            f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
            flush=True,
        )
    # Random data
    # `x` is constant per rank so received rows can be attributed to their sender.
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    x_e4m3 = per_token_cast_to_fp8(x)
    scores = (
        torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
        + 1
    )
    # Grouped (per-node) top-k routing, DeepSeek style.
    group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
    group_idx = torch.topk(
        group_scores, k=num_topk_groups, dim=-1, sorted=False
    ).indices
    masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
    topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[
        1
    ]
    topk_weights = (
        torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
    )
    topk_weights_pure_rand = torch.randn(
        (num_tokens, num_topk), dtype=torch.float32, device="cuda"
    )
    # Map expert indices to destination (NVL) ranks and RDMA (node) ranks.
    rank_idx = topk_idx // (num_experts // num_ranks)
    rank_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rank_idx, num_ranks)
    rdma_rank_idx = rank_idx // num_local_ranks
    rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
    inplace_unique(rdma_rank_idx, num_nodes)
    # RDMA dispatch counts
    rdma_idx = topk_idx // (num_experts // num_nodes)
    rdma_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rdma_idx, num_nodes)
    num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
    # Expert meta
    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
    for i in range(num_experts):
        num_tokens_per_expert[i] = (topk_idx == i).sum()
    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
    dist.all_reduce(gbl_num_tokens_per_expert, group=group)
    # Rank layout meta
    num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
    num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
    token_idx_in_rank = torch.full(
        (num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
    )
    for i in range(num_ranks):
        num_tokens_per_rank[i] = (rank_idx == i).sum()
        token_sel = (rank_idx == i).max(dim=-1)[0]
        count = token_sel.sum().item()
        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
        tokens[:count] = torch.sort(tokens[:count])[0]
        token_idx_in_rank[i][tokens[:count]] = torch.arange(
            count, dtype=torch.long, device="cuda"
        )
    for i in range(num_nodes):
        num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
    is_token_in_rank = token_idx_in_rank >= 0
    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
    dist.all_reduce(gbl_num_tokens_per_rank, group=group)
    # Cross-check the hand-computed layout against the kernel's implementation.
    (
        ref_num_tokens_per_rank,
        ref_num_tokens_per_rdma_rank,
        ref_num_tokens_per_expert,
        ref_is_token_in_rank,
        _,
    ) = buffer.get_dispatch_layout(topk_idx, num_experts)
    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
    assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
    if local_rank == 0:
        print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
        print("", flush=True)
    group.barrier()
    time.sleep(1)
    # Config
    rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
    config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)

    # Test dispatch
    # noinspection PyShadowingNames
    def check_data(check_x, recv_gbl_rank_prefix_sum):
        # Rows received from rank i must all equal i (x was filled with the rank id).
        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
        check_start = 0
        for i in range(num_ranks):
            check_end = recv_gbl_rank_prefix_sum[i].item()
            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
            check_start = check_end

    for previous_mode in (False, True):
        for async_mode in (False, True):
            for current_x in (x_pure_rand, x, x_e4m3):
                for with_topk in (False, True):
                    if local_rank == 0:
                        print(
                            f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
                            flush=True,
                            end="",
                        )
                    dispatch_args = {
                        "x": current_x,
                        "num_tokens_per_rank": num_tokens_per_rank,
                        "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
                        "is_token_in_rank": is_token_in_rank,
                        "num_tokens_per_expert": num_tokens_per_expert,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        dispatch_args.update(
                            {
                                "topk_idx": topk_idx,
                                "topk_weights": (
                                    topk_weights_pure_rand
                                    if current_x is x_pure_rand
                                    else topk_weights
                                ),
                            }
                        )
                    if previous_mode:
                        dispatch_args.update({"previous_event": buffer.capture()})
                    (
                        recv_x,
                        recv_topk_idx,
                        recv_topk_weights,
                        recv_num_tokens_per_expert_list,
                        handle,
                        event,
                    ) = buffer.dispatch(**dispatch_args)
                    # In async mode, wait on the returned event before reading results.
                    event.current_stream_wait() if async_mode else ()
                    recv_x = (
                        per_token_cast_back(*recv_x)
                        if isinstance(recv_x, tuple)
                        else recv_x
                    )
                    # Checks
                    recv_gbl_rank_prefix_sum = handle[-4]
                    assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
                        0
                    ), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
                    assert (
                        gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
                        == recv_num_tokens_per_expert_list
                    )
                    if current_x is not x_pure_rand:
                        check_data(recv_x, recv_gbl_rank_prefix_sum)
                    if with_topk:
                        # Check `topk_idx`
                        assert (
                            recv_topk_idx.eq(-1)
                            | (
                                (recv_topk_idx >= 0)
                                & (recv_topk_idx < (num_experts // num_ranks))
                            )
                        ).sum().item() == recv_topk_idx.numel()
                        for i, count in enumerate(recv_num_tokens_per_expert_list):
                            assert recv_topk_idx.eq(i).sum().item() == count
                        # Check `topk_weights`
                        if current_x is not x_pure_rand:
                            # Fill the -1 slots so check_data's min==max test holds.
                            recv_topk_weights[recv_topk_idx.eq(-1)] = (
                                recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
                                    recv_topk_weights
                                )[recv_topk_idx.eq(-1)]
                            )
                            check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
                    # Test cached dispatch (must without top-k staffs)
                    if not with_topk:
                        dispatch_args = {
                            "x": current_x,
                            "handle": handle,
                            "config": config,
                            "async_finish": async_mode,
                        }
                        if previous_mode:
                            dispatch_args.update({"previous_event": buffer.capture()})
                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
                        event.current_stream_wait() if async_mode else ()
                        recv_x = (
                            per_token_cast_back(*recv_x)
                            if isinstance(recv_x, tuple)
                            else recv_x
                        )
                        if current_x is not x_pure_rand:
                            check_data(recv_x, recv_gbl_rank_prefix_sum)
                    # Test combine
                    combine_args = {
                        "x": recv_x,
                        "handle": handle,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        combine_args.update({"topk_weights": recv_topk_weights})
                    if previous_mode:
                        # NOTE(review): updates dispatch_args, not combine_args —
                        # looks like a copy/paste slip inherited from upstream; confirm.
                        dispatch_args.update({"previous_event": buffer.capture()})
                    combined_x, combined_topk_weights, event = buffer.combine(
                        **combine_args
                    )
                    event.current_stream_wait() if async_mode else ()
                    # Combine sums one copy per destination rank; divide it back out.
                    check_x = combined_x.float() / is_token_in_rank.sum(
                        dim=1
                    ).unsqueeze(1)
                    ref_x = x_pure_rand if current_x is x_pure_rand else x
                    assert calc_diff(check_x, ref_x) < 5e-6
                    if with_topk:
                        check_topk_weights = (
                            combined_topk_weights
                            if (current_x is x_pure_rand)
                            else (
                                combined_topk_weights
                                / is_token_in_rank.sum(dim=1).unsqueeze(1)
                            )
                        )
                        ref_topk_weights = (
                            topk_weights_pure_rand
                            if current_x is x_pure_rand
                            else topk_weights
                        )
                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
                    # For later tuning
                    # NOTE(review): these byte counts are read after the sweep and
                    # reflect the last loop iteration's shapes.
                    dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
                    combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
                    if local_rank == 0:
                        print(" passed", flush=True)
    if local_rank == 0:
        print("", flush=True)

    # Tune dispatch performance
    best_dispatch_results = None
    # FP8 payload ≈ (1 byte + 4/128 bytes of scale) vs 2 bytes for BF16.
    fp8_factor = (1 + 4 / 128) / 2
    for current_x in (x_e4m3, x):
        best_time, best_results = 1e10, None
        rdma_send_bytes = (
            (dispatch_bf16_rdma_send_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_rdma_send_bytes
        )
        nvl_recv_bytes = (
            (dispatch_bf16_nvl_recv_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_nvl_recv_bytes
        )
        for nvl_chunk_size in range(4, 33, 4):
            for rdma_chunk_size in range(4, 33, 4):
                config = deep_ep.Config(
                    num_sms,
                    nvl_chunk_size,
                    nvl_buffer_size,
                    rdma_chunk_size,
                    rdma_buffer_size,
                )
                tune_args = {"x": current_x, "handle": handle, "config": config}
                t = bench(lambda: buffer.dispatch(**tune_args))[0]
                if t < best_time:
                    best_time, best_results = t, (
                        num_sms,
                        nvl_chunk_size,
                        rdma_chunk_size,
                    )
                if local_rank == 0:
                    print(
                        f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                        flush=True,
                    )
        if local_rank == 0:
            print(
                f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
                flush=True,
            )
            print("", flush=True)
        if isinstance(current_x, tuple):
            # Gather FP8 the best config from rank 0
            best_dispatch_results = torch.tensor(
                [best_results[0], best_results[1], best_results[2]],
                dtype=torch.int32,
                device="cuda",
            )
            all_best_fp8_results_list = [
                torch.zeros_like(best_dispatch_results)
                for _ in range(torch.distributed.get_world_size())
            ]
            dist.all_gather(
                all_best_fp8_results_list, best_dispatch_results, group=group
            )
            best_dispatch_results = all_best_fp8_results_list[0].tolist()
    # Re-dispatch with the tuned config to obtain a handle for combine tuning.
    dispatch_config = deep_ep.Config(
        best_dispatch_results[0],
        best_dispatch_results[1],
        nvl_buffer_size,
        best_dispatch_results[2],
        rdma_buffer_size,
    )
    dispatch_args = {
        "x": x,
        "num_tokens_per_rank": num_tokens_per_rank,
        "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
        "is_token_in_rank": is_token_in_rank,
        "num_tokens_per_expert": num_tokens_per_expert,
        "config": dispatch_config if dispatch_config is not None else config,
    }
    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
    # Tune combine performance
    best_time, best_results = 1e10, None
    for nvl_chunk_size in range(1, 5, 1):
        for rdma_chunk_size in range(8, 33, 4):
            config = deep_ep.Config(
                num_sms,
                nvl_chunk_size,
                nvl_buffer_size,
                rdma_chunk_size,
                rdma_buffer_size,
            )
            tune_args = {"x": recv_x, "handle": handle, "config": config}
            t = bench(lambda: buffer.combine(**tune_args))[0]
            if local_rank == 0:
                print(
                    f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                    flush=True,
                )
            if t < best_time:
                best_time, best_results = t, (
                    num_sms,
                    nvl_chunk_size,
                    rdma_chunk_size,
                )
    if local_rank == 0:
        print(
            f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
            flush=True,
        )
        print("", flush=True)
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    """Per-process entry point: set up distributed state, allocate the shared
    DeepEP buffer, and run the internode dispatch/combine test suite."""
    num_nodes = int(os.getenv("WORLD_SIZE", 1))
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    test_ll_compatibility = False
    if test_ll_compatibility:
        ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9

    qps_per_rank = ll_num_experts // num_ranks if test_ll_compatibility else 1
    buffer = deep_ep.Buffer(
        group,
        int(1e9),
        int(1e9),
        low_latency_mode=test_ll_compatibility,
        num_qps_per_rank=qps_per_rank,
    )
    # Internode test: exactly 8 local ranks, spanning more than one node.
    assert num_local_ranks == 8 and num_ranks > 8
    torch.manual_seed(rank)

    for sm_count in (24,):
        test_main(
            sm_count, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group
        )
        if local_rank == 0:
            print("", flush=True)

    # Optionally re-use the same buffer to smoke-test the low-latency kernels
    if test_ll_compatibility:
        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
        test_deepep_low_latency.test_main(
            ll_num_tokens,
            ll_hidden,
            ll_num_experts,
            ll_num_topk,
            rank,
            num_ranks,
            group,
            buffer,
            seed=1,
        )


if __name__ == "__main__":
    num_processes = 8
    torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)

View File

@@ -0,0 +1,379 @@
# Copy from deepseek-ai/DeepEP/tests/test_intranode.py
import os
import time
# noinspection PyUnresolvedReferences
import deep_ep
# Test compatibility with low latency functions
import test_deepep_low_latency
import torch
import torch.distributed as dist
from sglang.test.test_deepep_utils import (
bench,
calc_diff,
init_dist,
inplace_unique,
per_token_cast_back,
per_token_cast_to_fp8,
)
def test_main(
    num_sms: int,
    local_rank: int,
    num_ranks: int,
    rank: int,
    buffer: deep_ep.Buffer,
    group: dist.ProcessGroup,
):
    """Correctness and performance test for intranode (NVLink-only) dispatch/combine.

    Copied from deepseek-ai/DeepEP tests/test_intranode.py. Builds a reference
    token-to-rank layout on the host, validates ``buffer.get_dispatch_layout``
    against it, then exercises dispatch and combine over every combination of
    async mode, previous-event capture, FP8/BF16 payload, and top-k metadata.
    Finally it sweeps NVL chunk sizes to tune dispatch and combine bandwidth.

    Fix vs. the upstream copy: the combine-phase ``previous_event`` was being
    written into ``dispatch_args`` (a dead dict at that point) instead of
    ``combine_args``, so previous-event mode was never actually exercised for
    combine.

    Args:
        num_sms: number of SMs for the DeepEP config.
        local_rank: rank within the node (rank 0 prints progress).
        num_ranks: world size of ``group``.
        rank: global rank of this process.
        buffer: pre-allocated DeepEP buffer shared across tests.
        group: process group used for collectives.
    """
    # Settings
    num_tokens, hidden, num_topk, num_experts = (
        4096,
        7168,
        8,
        (256 // num_ranks) * num_ranks,
    )
    assert num_experts % num_ranks == 0
    if local_rank == 0:
        print(
            f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}",
            flush=True,
        )
    # Random data: `x` carries the sender's rank so receivers can verify origin
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
    x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    x_e4m3 = per_token_cast_to_fp8(x)
    scores = (
        torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
        + 1
    )
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
    topk_weights = (
        torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
    )
    topk_weights_pure_rand = torch.randn(
        (num_tokens, num_topk), dtype=torch.float32, device="cuda"
    )
    rank_idx = topk_idx // (num_experts // num_ranks)
    rank_idx.masked_fill_(topk_idx == -1, -1)
    inplace_unique(rank_idx, num_ranks)
    # Expert meta: per-expert token counts, locally and globally
    num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
    for i in range(num_experts):
        num_tokens_per_expert[i] = (topk_idx == i).sum()
    gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
    dist.all_reduce(gbl_num_tokens_per_expert, group=group)
    # Rank layout meta: reference layout computed on the host
    num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
    token_idx_in_rank = torch.full(
        (num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
    )
    for i in range(num_ranks):
        num_tokens_per_rank[i] = (rank_idx == i).sum()
        token_sel = (rank_idx == i).max(dim=-1)[0]
        count = token_sel.sum().item()
        tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
        tokens[:count] = torch.sort(tokens[:count])[0]
        token_idx_in_rank[i][tokens[:count]] = torch.arange(
            count, dtype=torch.long, device="cuda"
        )
    token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
    is_token_in_rank = token_idx_in_rank >= 0
    gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
    dist.all_reduce(gbl_num_tokens_per_rank, group=group)
    # Validate the kernel layout against the host-computed reference
    ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = (
        buffer.get_dispatch_layout(topk_idx, num_experts)
    )
    assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
    assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
    assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
    t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
    if local_rank == 0:
        print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
        print("", flush=True)
    group.barrier()
    time.sleep(1)
    # Config
    nvl_buffer_size = 256
    config = deep_ep.Config(num_sms, 8, nvl_buffer_size)

    # Test dispatch
    # noinspection PyShadowingNames
    def check_data(check_x, rank_prefix_matrix):
        # Each row must be constant (all entries equal the sender's rank), and
        # rows must be grouped by source rank per the prefix matrix.
        assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
        check_start = 0
        for i in range(num_ranks):
            check_end = rank_prefix_matrix[i][rank].item()
            assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
            check_start = check_end

    for previous_mode in (False, True):
        for async_mode in (False, True):
            for current_x in (x_pure_rand, x, x_e4m3):
                for with_topk in (False, True):
                    if local_rank == 0:
                        print(
                            f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
                            flush=True,
                            end="",
                        )
                    dispatch_args = {
                        "x": current_x,
                        "num_tokens_per_rank": num_tokens_per_rank,
                        "is_token_in_rank": is_token_in_rank,
                        "num_tokens_per_expert": num_tokens_per_expert,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        dispatch_args.update(
                            {
                                "topk_idx": topk_idx,
                                "topk_weights": (
                                    topk_weights_pure_rand
                                    if current_x is x_pure_rand
                                    else topk_weights
                                ),
                            }
                        )
                    if previous_mode:
                        dispatch_args.update({"previous_event": buffer.capture()})
                    (
                        recv_x,
                        recv_topk_idx,
                        recv_topk_weights,
                        recv_num_tokens_per_expert_list,
                        handle,
                        event,
                    ) = buffer.dispatch(**dispatch_args)
                    event.current_stream_wait() if async_mode else ()
                    recv_x = (
                        per_token_cast_back(*recv_x)
                        if isinstance(recv_x, tuple)
                        else recv_x
                    )
                    # Checks
                    rank_prefix_matrix = handle[0]
                    assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
                        0
                    ), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
                    assert (
                        gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
                        == recv_num_tokens_per_expert_list
                    )
                    if current_x is not x_pure_rand:
                        check_data(recv_x, rank_prefix_matrix)
                    if with_topk:
                        # Check `topk_idx`: every entry is -1 or a valid local expert
                        assert (
                            recv_topk_idx.eq(-1)
                            | (
                                (recv_topk_idx >= 0)
                                & (recv_topk_idx < (num_experts // num_ranks))
                            )
                        ).sum().item() == recv_topk_idx.numel()
                        for i, count in enumerate(recv_num_tokens_per_expert_list):
                            assert recv_topk_idx.eq(i).sum().item() == count
                        # Check `topk_weights` (fill masked slots so check_data's
                        # row-constant invariant holds)
                        if current_x is not x_pure_rand:
                            recv_topk_weights[recv_topk_idx.eq(-1)] = (
                                recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
                                    recv_topk_weights
                                )[recv_topk_idx.eq(-1)]
                            )
                            check_data(recv_topk_weights, rank_prefix_matrix)
                    # Test cached dispatch (must without top-k staffs)
                    if not with_topk:
                        dispatch_args = {
                            "x": current_x,
                            "handle": handle,
                            "config": config,
                            "async_finish": async_mode,
                        }
                        if previous_mode:
                            dispatch_args.update({"previous_event": buffer.capture()})
                        recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
                        event.current_stream_wait() if async_mode else ()
                        recv_x = (
                            per_token_cast_back(*recv_x)
                            if isinstance(recv_x, tuple)
                            else recv_x
                        )
                        if current_x is not x_pure_rand:
                            check_data(recv_x, rank_prefix_matrix)
                    # Test combine
                    combine_args = {
                        "x": recv_x,
                        "handle": handle,
                        "config": config,
                        "async_finish": async_mode,
                    }
                    if with_topk:
                        combine_args.update({"topk_weights": recv_topk_weights})
                    if previous_mode:
                        # Bug fix: the event must go into `combine_args` (used by
                        # the combine call below), not the stale `dispatch_args`.
                        combine_args.update({"previous_event": buffer.capture()})
                    combined_x, combined_topk_weights, event = buffer.combine(
                        **combine_args
                    )
                    event.current_stream_wait() if async_mode else ()
                    # Combine sums the per-rank copies; divide by the multiplicity
                    check_x = combined_x.float() / is_token_in_rank.sum(
                        dim=1
                    ).unsqueeze(1)
                    ref_x = x_pure_rand if current_x is x_pure_rand else x
                    assert calc_diff(check_x, ref_x) < 5e-6
                    if with_topk:
                        check_topk_weights = (
                            combined_topk_weights
                            if (current_x is x_pure_rand)
                            else (
                                combined_topk_weights
                                / is_token_in_rank.sum(dim=1).unsqueeze(1)
                            )
                        )
                        ref_topk_weights = (
                            topk_weights_pure_rand
                            if current_x is x_pure_rand
                            else topk_weights
                        )
                        assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
                    # For later tuning
                    dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
                    combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
                    if local_rank == 0:
                        print(" passed", flush=True)
    if local_rank == 0:
        print("", flush=True)
    # Tune dispatch performance
    best_dispatch_results = None
    fp8_factor = (1 + 4 / 128) / 2
    for current_x in (x_e4m3, x):
        best_time, best_results = 1e10, None
        nvl_recv_bytes = (
            (dispatch_bf16_nvl_recv_bytes * fp8_factor)
            if isinstance(current_x, tuple)
            else dispatch_bf16_nvl_recv_bytes
        )
        for nvl_chunk_size in range(4, 33, 4):
            config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
            tune_args = {"x": current_x, "handle": handle, "config": config}
            t = bench(lambda: buffer.dispatch(**tune_args))[0]
            if t < best_time:
                best_time, best_results = t, (num_sms, nvl_chunk_size)
            if local_rank == 0:
                print(
                    f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                    flush=True,
                )
        if local_rank == 0:
            print(
                f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
                flush=True,
            )
            print("", flush=True)
        if isinstance(current_x, tuple):
            # Gather FP8 the best config from rank 0 (FP8 runs first, so
            # `best_dispatch_results` is always set before it is used below)
            best_dispatch_results = torch.tensor(
                [best_results[0], best_results[1]], dtype=torch.int32, device="cuda"
            )
            all_best_fp8_results_list = [
                torch.zeros_like(best_dispatch_results)
                for _ in range(torch.distributed.get_world_size())
            ]
            dist.all_gather(
                all_best_fp8_results_list, best_dispatch_results, group=group
            )
            best_dispatch_results = all_best_fp8_results_list[0].tolist()
    dispatch_config = deep_ep.Config(
        best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size
    )
    dispatch_args = {
        "x": x,
        "num_tokens_per_rank": num_tokens_per_rank,
        "is_token_in_rank": is_token_in_rank,
        "num_tokens_per_expert": num_tokens_per_expert,
        "config": dispatch_config if dispatch_config is not None else config,
    }
    recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
    # Tune combine performance
    best_time, best_results = 1e10, None
    for nvl_chunk_size in range(1, 7, 1):
        config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
        tune_args = {"x": recv_x, "handle": handle, "config": config}
        t = bench(lambda: buffer.combine(**tune_args))[0]
        if local_rank == 0:
            print(
                f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
                flush=True,
            )
        if t < best_time:
            best_time, best_results = t, (num_sms, nvl_chunk_size)
    if local_rank == 0:
        print(
            f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
            flush=True,
        )
        print("", flush=True)
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    """Per-process entry point: initialize the process group, allocate the
    shared DeepEP buffer, and run the intranode test for each SM count."""
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    test_ll_compatibility, num_rdma_bytes = False, 0
    if test_ll_compatibility:
        ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
            ll_num_tokens, ll_hidden, num_ranks, ll_num_experts
        )

    qps_per_rank = ll_num_experts // num_ranks if test_ll_compatibility else 1
    buffer = deep_ep.Buffer(
        group,
        int(1e9),
        num_rdma_bytes,
        low_latency_mode=test_ll_compatibility,
        num_qps_per_rank=qps_per_rank,
    )
    torch.manual_seed(rank)

    for sm_count in (24,):
        test_main(sm_count, local_rank, num_ranks, rank, buffer, group)
        if local_rank == 0:
            print("", flush=True)

    # Optionally re-use the same buffer to smoke-test the low-latency kernels
    if test_ll_compatibility:
        buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
        test_deepep_low_latency.test_main(
            ll_num_tokens,
            ll_hidden,
            ll_num_experts,
            ll_num_topk,
            rank,
            num_ranks,
            group,
            buffer,
            seed=1,
        )


if __name__ == "__main__":
    num_processes = 8
    torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)

View File

@@ -0,0 +1,149 @@
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestDeepseek(CustomTestCase):
    """GSM8K accuracy check for DeepSeek served with DeepEP, DP attention,
    two-batch overlap, and EPLB redundant experts."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--trust-remote-code",
            "--tp",
            "8",
            "--enable-dp-attention",
            "--dp",
            "8",
            "--moe-dense-tp-size",
            "1",
            "--enable-dp-lm-head",
            "--moe-a2a-backend",
            "deepep",
            "--enable-two-batch-overlap",
            "--ep-num-redundant-experts",
            "32",
            "--ep-dispatch-algorithm",
            "dynamic",
            "--eplb-algorithm",
            "deepseek",
            "--cuda-graph-bs",
            "256",
            "--max-running-requests",
            "2048",
            "--disable-radix-cache",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K must reach > 0.92 accuracy."""
        server_port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=1200,
            parallel=1200,
            max_new_tokens=512,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(f"Eval accuracy of GSM8K: {metrics=}")
        self.assertGreater(metrics["accuracy"], 0.92)
class TestDeepseekMTP(CustomTestCase):
    """GSM8K accuracy plus speculative acceptance-length check for DeepSeek
    with MTP (EAGLE) speculative decoding on a DeepEP + TBO server."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--trust-remote-code",
            "--tp",
            "8",
            "--enable-dp-attention",
            "--dp",
            "8",
            "--moe-dense-tp-size",
            "1",
            "--enable-dp-lm-head",
            "--moe-a2a-backend",
            "deepep",
            "--enable-two-batch-overlap",
            "--ep-num-redundant-experts",
            "32",
            "--ep-dispatch-algorithm",
            "dynamic",
            "--eplb-algorithm",
            "deepseek",
            "--cuda-graph-bs",
            "64",  # TODO: increase it to 128 when TBO is supported in draft_extend
            "--max-running-requests",
            "512",
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-num-steps",
            "1",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "2",
            "--disable-radix-cache",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K must reach > 0.92 accuracy and the average
        speculative accept length must exceed 1.85."""
        server_port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=1200,
            parallel=1200,
            max_new_tokens=512,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(f"Eval accuracy of GSM8K: {metrics=}")
        self.assertGreater(metrics["accuracy"], 0.92)

        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(
            f"###test_gsm8k:\n"
            f"accuracy={metrics['accuracy']=:.3f}\n"
            f"{avg_spec_accept_length=:.3f}\n"
        )
        self.assertGreater(avg_spec_accept_length, 1.85)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,325 @@
# Copy from deepseek-ai/DeepEP/tests/test_low_latency.py
import random
from functools import partial
import deep_ep
import torch
import torch.distributed as dist
from sglang.test.test_deepep_utils import (
bench,
bench_kineto,
calc_diff,
hash_tensor,
init_dist,
per_token_cast_back,
)
def test_main(
    num_tokens: int,
    hidden: int,
    num_experts: int,
    num_topk: int,
    rank: int,
    num_ranks: int,
    group: dist.ProcessGroup,
    buffer: deep_ep.Buffer,
    seed: int = 0,
):
    """Correctness and performance test for DeepEP low-latency dispatch/combine.

    Copied from deepseek-ai/DeepEP tests/test_low_latency.py. Encodes the
    sender rank and token index into `x`, runs low-latency dispatch in FP8 and
    BF16 with and without recv hooks, validates the received layout against an
    all-gathered `topk_idx`, checks combine output, then benchmarks bandwidth.

    Returns:
        An integer hash accumulated over received/combined tensors, used by
        the caller's pressure test to detect nondeterminism.
    """
    torch.manual_seed(seed + rank)
    random.seed(seed + rank)
    assert num_experts % num_ranks == 0
    num_local_experts = num_experts // num_ranks
    # NOTES: the integers greater than 256 exceeds the BF16 precision limit
    rank_offset = 128
    assert (
        num_ranks - rank_offset < 257
    ), "Too many ranks (exceeding test precision limit)"
    # First hidden-128 columns encode the sender rank; last 128 columns encode
    # the token index, so receivers can verify both origin and identity.
    x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (
        rank - rank_offset
    )
    x[:, -128:] = torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
    scores = (
        torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
        + 1
    )
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
    topk_weights = torch.randn(
        (num_tokens, num_topk), dtype=torch.float32, device="cuda"
    ).abs()
    # Randomly mask some positions
    for i in range(10):
        topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = (
            -1
        )
    # Check dispatch correctness
    do_check = True
    hash_value, num_times = 0, 0
    for return_recv_hook in (False, True):
        for dispatch_use_fp8 in (False, True):
            num_times += 1
            for i in range((num_times % 2) + 1):
                packed_recv_x, packed_recv_count, handle, event, hook = (
                    buffer.low_latency_dispatch(
                        x,
                        topk_idx,
                        num_tokens,
                        num_experts,
                        use_fp8=dispatch_use_fp8,
                        async_finish=not return_recv_hook,
                        return_recv_hook=return_recv_hook,
                    )
                )
                hook() if return_recv_hook else event.current_stream_wait()
                packed_recv_x = (
                    (packed_recv_x[0], packed_recv_x[1].contiguous())
                    if dispatch_use_fp8
                    else packed_recv_x
                )
                # Dequantized copy standing in for the expert GEMM output
                simulated_gemm_x = (
                    per_token_cast_back(
                        packed_recv_x[0].view(-1, hidden),
                        packed_recv_x[1].view(-1, hidden // 128),
                    ).view(packed_recv_x[0].shape)
                    if dispatch_use_fp8
                    else packed_recv_x.clone()
                )
                all_topk_idx = torch.empty(
                    (num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda"
                )
                dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
                for i in range(num_local_experts if do_check else 0):
                    expert_id = rank * num_local_experts + i
                    recv_x = (
                        per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i])
                        if dispatch_use_fp8
                        else packed_recv_x[i]
                    )
                    recv_count, recv_src_info, recv_layout_range = (
                        packed_recv_count[i],
                        handle[0][i],
                        handle[1][i],
                    )
                    # Check expert indices
                    # Layout range packs (begin_idx << 32) | count per source rank
                    int_mask = (2**32) - 1
                    num_valid_tokens = recv_count.item()
                    assert (
                        num_valid_tokens == (recv_layout_range & int_mask).sum().item()
                    ), f"{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()"
                    assert (
                        num_valid_tokens == (all_topk_idx == expert_id).sum().item()
                    ), f"{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}"
                    # Check received data
                    recv_x = recv_x[:num_valid_tokens]
                    recv_x_amin = recv_x[:, :-128].amin(dim=-1)
                    recv_src_info = recv_src_info[:num_valid_tokens]
                    assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
                    assert (
                        recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens
                    ).sum().item() == 0
                    for j in range(num_ranks):
                        begin_idx, count = (recv_layout_range[j] >> 32).item(), (
                            recv_layout_range[j] & int_mask
                        ).item()
                        assert (recv_x_amin == j - rank_offset).sum().item() == (
                            all_topk_idx[j] == expert_id
                        ).sum().item()
                        # NOTE(review): '[:-128]' slices ROWS (earlier checks use
                        # '[:, :-128]' for columns), and payload values are
                        # (j - rank_offset), not j — so this sum is only trivially
                        # zero for small counts. Verify against upstream DeepEP.
                        assert (
                            recv_x[begin_idx : begin_idx + count][:-128] - j
                        ).sum().item() == 0
                    if dispatch_use_fp8:
                        hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
                        hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
                    else:
                        hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
                # Check combine correctness
                for zero_copy in (False, True):
                    if zero_copy:
                        buffer.get_next_low_latency_combine_buffer(handle)[
                            :, :, :
                        ] = simulated_gemm_x
                    out = torch.empty(
                        (num_tokens, hidden), dtype=torch.bfloat16, device="cuda"
                    )
                    combined_x, event, hook = buffer.low_latency_combine(
                        simulated_gemm_x,
                        topk_idx,
                        topk_weights,
                        handle,
                        async_finish=not return_recv_hook,
                        zero_copy=zero_copy,
                        return_recv_hook=return_recv_hook,
                        out=out,
                    )
                    hook() if return_recv_hook else event.current_stream_wait()
                    if do_check:
                        # Reference: x scaled by the sum of unmasked top-k weights
                        diff = calc_diff(
                            x
                            * topk_weights.masked_fill(topk_idx == -1, 0)
                            .sum(dim=1)
                            .view(-1, 1),
                            combined_x,
                        )
                        assert torch.isnan(combined_x).sum().item() == 0
                        assert diff < 1e-5, f"Error: {diff=}, {zero_copy=}"
                        hash_value ^= hash_tensor(combined_x)

    # NOTE(review): defined but never called in this function — presumably kept
    # for manual FP8-outlier experiments; confirm before removing.
    def create_test_cast_with_outliers(num_outliers):
        tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
        tmp /= tmp.abs().amax(dim=1).view(-1, 1)
        assert tmp.abs().amax().item() <= 1
        # Create some amax outliers
        for i in range(num_outliers):
            tmp[random.randint(0, num_tokens - 1)] *= 1e3
        return tmp

    # Overlap a large GEMM with the recv hook to simulate real compute
    # noinspection PyShadowingNames
    def large_gemm_with_hook(hook):
        mat_0 = torch.randn((8192, 8192), dtype=torch.float)
        mat_1 = torch.randn((8192, 8192), dtype=torch.float)
        mat_0 @ mat_1
        hook()

    # One full dispatch + combine round trip, used by the benchmarks below
    # noinspection PyShadowingNames
    def test_func(zero_copy: bool, return_recv_hook: bool):
        recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
            x,
            topk_idx,
            num_tokens,
            num_experts,
            async_finish=False,
            return_recv_hook=return_recv_hook,
        )
        large_gemm_with_hook(hook) if return_recv_hook else None
        if zero_copy:
            buffer.get_next_low_latency_combine_buffer(handle)[
                :, :, :
            ] = simulated_gemm_x
        combined_x, event, hook = buffer.low_latency_combine(
            simulated_gemm_x,
            topk_idx,
            topk_weights,
            handle,
            zero_copy=zero_copy,
            return_recv_hook=return_recv_hook,
        )
        large_gemm_with_hook(hook) if return_recv_hook else None

    # Calculate bandwidth
    num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
    num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
    for i in range(num_tokens):
        num_selections = (topk_idx[i] != -1).sum().item()
        num_dispatch_comm_bytes += num_fp8_bytes * num_selections
        num_combine_comm_bytes += num_bf16_bytes * num_selections
    # Dispatch + combine testing
    avg_t, min_t, max_t = bench(
        partial(test_func, zero_copy=False, return_recv_hook=False)
    )
    print(
        f"[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, "
        f"avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us",
        flush=True,
    )
    # Separate profiling
    for return_recv_hook in (False, True):
        group.barrier()
        dispatch_t, combine_t = bench_kineto(
            partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
            kernel_names=("dispatch", "combine"),
            barrier_comm_profiling=True,
            suppress_kineto_output=True,
        )
        if not return_recv_hook:
            print(
                f"[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | "
                f"Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us",
                flush=True,
            )
        else:
            print(
                f"[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | "
                f"Combine send/recv time: {combine_t * 2 * 1e6:.2f} us",
                flush=True,
            )
    return hash_value
# noinspection PyUnboundLocalVariable
def test_loop(local_rank: int, num_local_ranks: int):
    """Per-process entry point for the low-latency kernel test."""
    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
    num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        num_tokens, hidden, num_ranks, num_experts
    )
    if local_rank == 0:
        print(f"Allocating buffer size: {num_rdma_bytes / 1e6} MB ...", flush=True)
    buffer = deep_ep.Buffer(
        group,
        num_rdma_bytes=num_rdma_bytes,
        low_latency_mode=True,
        num_qps_per_rank=num_experts // num_ranks,
    )

    # All test_main invocations below share everything except the seed
    run_once = partial(
        test_main,
        num_tokens,
        hidden,
        num_experts,
        num_topk,
        rank,
        num_ranks,
        group,
        buffer,
    )
    run_once(seed=1)

    # Optional long-running determinism check: re-run per seed and compare hashes
    do_pressure_test = False
    for seed in range(int(1e9) if do_pressure_test else 0):
        if local_rank == 0:
            print(f"Testing with seed {seed} ...", flush=True)
        ref_hash = run_once(seed=seed)
        for _ in range(20):
            assert run_once(seed=seed) == ref_hash, f"Error: seed={seed}"


if __name__ == "__main__":
    # TODO: you may modify NUMA binding for less CPU overhead
    num_processes = 8
    torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)

View File

@@ -0,0 +1,389 @@
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestPureDP(CustomTestCase):
    """GSM8K check with pure data parallelism (dp == tp) over DeepEP."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--trust-remote-code",
            "--tp",
            "4",
            "--enable-dp-attention",
            "--dp",
            "4",
            "--moe-a2a-backend",
            "deepep",
            "--cuda-graph-max-bs",
            "128",
            "--max-running-requests",
            "512",
            "--mem-fraction-static",
            "0.5",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K must exceed 0.60 accuracy."""
        server_port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)
class TestHybridDPTP(CustomTestCase):
    """GSM8K check with hybrid parallelism (dp=2 inside tp=4) over DeepEP."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--trust-remote-code",
            "--tp",
            "4",
            "--enable-dp-attention",
            "--dp",
            "2",
            "--moe-a2a-backend",
            "deepep",
            "--cuda-graph-max-bs",
            "128",
            "--max-running-requests",
            "256",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K must exceed 0.60 accuracy."""
        server_port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)
class TestTP(CustomTestCase):
    """GSM8K check with plain tensor parallelism (no DP attention) over DeepEP."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--trust-remote-code",
            "--tp",
            "4",
            "--moe-a2a-backend",
            "deepep",
            "--cuda-graph-max-bs",
            "128",
            "--max-running-requests",
            "128",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K must exceed 0.60 accuracy."""
        server_port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)
@unittest.skip("covered in test_deepep_large.py")
class TestNoGatherdBuffer(CustomTestCase):
    """GSM8K check with dense-MoE TP size 1 and DP LM head (no gathered buffer)."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--trust-remote-code",
            "--tp",
            "4",
            "--enable-dp-attention",
            "--dp",
            "4",
            "--moe-dense-tp-size",
            "1",
            "--enable-dp-lm-head",
            "--moe-a2a-backend",
            "deepep",
            "--cuda-graph-max-bs",
            "32",
            "--max-running-requests",
            "512",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K must exceed 0.60 accuracy."""
        server_port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)
class TestTBO(CustomTestCase):
    """GSM8K check with two-batch overlap enabled on a DP-attention server."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--trust-remote-code",
            "--tp",
            "4",
            "--enable-dp-attention",
            "--dp",
            "4",
            "--moe-dense-tp-size",
            "1",
            "--moe-a2a-backend",
            "deepep",
            "--enable-two-batch-overlap",
            "--cuda-graph-max-bs",
            "128",
            "--max-running-requests",
            "512",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K must exceed 0.60 accuracy."""
        server_port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval_few_shot_gsm8k(eval_args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)
@unittest.skip("covered in TestMTPWithTBO")
class TestMTP(CustomTestCase):
    """GSM8K accuracy and speculative acceptance-length check with MTP (EAGLE)
    speculative decoding on a hybrid DP/TP DeepEP server (no two-batch overlap).
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "4",
                "--enable-dp-attention",
                "--dp",
                "2",
                "--enable-dp-lm-head",
                "--moe-a2a-backend",
                "deepep",
                "--speculative-algo",
                "EAGLE",
                "--speculative-draft",
                DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
                "--speculative-num-steps",
                "2",
                "--speculative-eagle-topk",
                "3",
                "--speculative-num-draft-tokens",
                "3",
                "--cuda-graph-max-bs",
                "32",
                "--max-running-requests",
                "64",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K must exceed 0.60 accuracy; average speculative
        accept length must exceed 2.1."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)
        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        # Log label fixed: this class does NOT enable two-batch overlap, so the
        # previous "+ tbo" tag (copied from TestMTPWithTBO) was misleading.
        print(
            f"###test_gsm8k (deepseek-v3 mtp + dp):\n"
            f"accuracy={metrics['accuracy']=:.3f}\n"
            f"{avg_spec_accept_length=:.3f}\n"
        )
        self.assertGreater(avg_spec_accept_length, 2.1)
class TestMTPWithTBO(CustomTestCase):
    """GSM8K accuracy and speculative acceptance-length check with MTP (EAGLE)
    speculative decoding plus two-batch overlap on a pure-DP DeepEP server."""

    @classmethod
    def setUpClass(cls):
        # Fix: removed an unused function-local `import os` left over here.
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--tp-size",
                "4",
                "--enable-dp-attention",
                "--dp-size",
                "4",
                "--enable-two-batch-overlap",
                "--moe-a2a-backend",
                "deepep",
                "--trust-remote-code",
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-num-steps",
                "2",
                "--speculative-eagle-topk",
                "3",
                "--speculative-num-draft-tokens",
                "3",
                "--speculative-draft",
                DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
                "--chunked-prefill-size",
                "256",
                "--cuda-graph-max-bs",
                "32",
                "--max-running-requests",
                "128",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K must exceed 0.60 accuracy; average speculative
        accept length must exceed 2.1."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(metrics)
        self.assertGreater(metrics["accuracy"], 0.60)
        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(
            f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n"
            f"accuracy={metrics['accuracy']=:.3f}\n"
            f"{avg_spec_accept_length=:.3f}\n"
        )
        self.assertGreater(avg_spec_accept_length, 2.1)


if __name__ == "__main__":
    unittest.main()

155
test/srt/ep/test_eplb.py Executable file
View File

@@ -0,0 +1,155 @@
import os
import tempfile
import unittest
from pathlib import Path
from types import SimpleNamespace
import sglang as sgl
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class _BaseTestDynamicEPLB(CustomTestCase):
    """Shared harness: launch a server with dynamic EPLB enabled, run MMLU.

    Subclasses customize the launch via ``extra_args``.
    """

    extra_args = []

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_args = [
            "--trust-remote-code",
            "--tp",
            "2",
            "--dp",
            "2",
            "--enable-dp-attention",
            "--moe-a2a-backend",
            "deepep",
            "--deepep-mode",
            "normal",
            "--disable-cuda-graph",
            "--enable-eplb",
            "--ep-num-redundant-experts",
            "4",
            "--eplb-rebalance-num-iterations",
            "50",
            "--expert-distribution-recorder-buffer-size",
            "50",
            # TODO pr-chain: enable later
            # "--enable-expert-distribution-metrics",
            # TODO auto determine these flags
            "--expert-distribution-recorder-mode",
            "stat",
            "--ep-dispatch-algorithm",
            "static",
        ]
        launch_args.extend(cls.extra_args)
        server_env = {
            "SGL_ENABLE_JIT_DEEPGEMM": "0",
            "SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1",
            **os.environ,
        }
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
            env=server_env,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """MMLU smoke test: score must exceed 0.5."""
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(eval_args)
        self.assertGreater(metrics["score"], 0.5)
class TestDynamicEPLBSimple(_BaseTestDynamicEPLB):
    # Runs the base dynamic-EPLB MMLU test with the default server flags.
    pass
class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB):
    # Same eval, but rebalances one layer per chunk to exercise chunked
    # expert-location updates.
    extra_args = ["--eplb-rebalance-layers-per-chunk", "1"]
class TestStaticEPLB(CustomTestCase):
    """Record an expert-distribution snapshot, then restart the engine with
    that snapshot as the static initial expert location."""
    def test_save_expert_distribution_and_init_expert_location(self):
        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Engine settings shared by the recording and the replay phase.
            engine_kwargs = dict(
                model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST,
                trust_remote_code=True,
                ep_num_redundant_experts=4,
                enable_dp_attention=True,
                moe_a2a_backend="deepep",
                disable_cuda_graph=True,
                expert_distribution_recorder_mode="stat",
                tp_size=2,
                dp_size=2,
                log_level="info",
                # TODO pr-chain: enable later
                # enable_expert_distribution_metrics=True,
            )
            print(f"Action: start engine")
            # The recorder writes its snapshot files into this directory.
            os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
            engine = sgl.Engine(
                **engine_kwargs,
                disable_overlap_schedule=True,
            )
            engine.start_expert_distribution_record()
            self._assert_engine_generate_correct(engine)
            print(f"Action: dump_expert_distribution_record")
            engine.dump_expert_distribution_record()
            # Exactly one .pt snapshot is expected in tmp_dir after the dump.
            snapshot_path = list(Path(tmp_dir).glob("*.pt"))[0]
            assert snapshot_path is not None
            print(f"{snapshot_path=}")
            print(f"Action: shutdown engine")
            engine.shutdown()
            del engine
            print(f"Action: start engine with init_expert_location")
            engine = sgl.Engine(
                **engine_kwargs,
                init_expert_location=str(snapshot_path),
                port=21000,
                # TODO auto determine these flags
                ep_dispatch_algorithm="static",
            )
            self._assert_engine_generate_correct(engine)
            print(f"Action: shutdown engine")
            engine.shutdown()
            del engine
    def _assert_engine_generate_correct(self, engine: sgl.Engine):
        """Greedy-decode two prompts and compare against known continuations.

        temperature=0.0 makes the expected outputs deterministic for this model.
        """
        output = engine.generate(
            prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"],
            sampling_params=dict(max_new_tokens=8, temperature=0.0),
        )
        print(f"engine.generate {output=}")
        self.assertEqual(
            [x["text"] for x in output],
            [", 4+4=8,", ", four plus four is eight, eight"],
        )
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,119 @@
import json
import os
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestPureTP(CustomTestCase):
    """MMLU smoke test for the DeepEP MoE backend under pure tensor
    parallelism (tp=2, no DP attention)."""
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--moe-a2a-backend",
                "deepep",
                # CUDA graph disabled for this configuration -- presumably
                # incompatible with DeepEP here; TODO confirm.
                "--disable-cuda-graph",
            ],
        )
    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree (workers included).
        kill_process_tree(cls.process.pid)
    def test_mmlu(self):
        """Run a 64-example MMLU eval; require score above 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.5)
class TestDPAttn(CustomTestCase):
    """MMLU test for DeepEP "normal" mode with DP attention and a custom
    DeepEP chunk-size configuration passed via --deepep-config.

    Derives from CustomTestCase (already imported and used by TestPureTP in
    this file) instead of raw unittest.TestCase, for consistency with every
    other server test class in the suite.
    """
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--dp",
                "2",
                "--enable-dp-attention",
                "--moe-a2a-backend",
                "deepep",
                "--deepep-mode",
                "normal",
                "--disable-cuda-graph",
                # Test custom config
                "--deepep-config",
                json.dumps(
                    {
                        "normal_dispatch": {
                            "num_sms": 20,
                            "num_max_nvl_chunked_send_tokens": 16,
                            "num_max_nvl_chunked_recv_tokens": 256,
                            "num_max_rdma_chunked_send_tokens": 6,
                            "num_max_rdma_chunked_recv_tokens": 128,
                        },
                        "normal_combine": {
                            "num_sms": 20,
                            "num_max_nvl_chunked_send_tokens": 6,
                            "num_max_nvl_chunked_recv_tokens": 256,
                            "num_max_rdma_chunked_send_tokens": 6,
                            "num_max_rdma_chunked_recv_tokens": 128,
                        },
                    }
                ),
            ],
            env={
                "SGL_ENABLE_JIT_DEEPGEMM": "0",
                **os.environ,
            },
        )
    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree (workers included).
        kill_process_tree(cls.process.pid)
    def test_mmlu(self):
        """Run a 64-example MMLU eval; require score above 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.5)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,75 @@
"""
Usage:
python -m unittest test_moe_deepep_eval_accuracy_large.TestMoEDeepEPEvalAccuracyLarge.test_mmlu
"""
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestMoEDeepEPEvalAccuracyLarge(CustomTestCase):
    """Large-scale accuracy evals (GSM8K, MMLU) for the DeepEP MoE backend
    on an 8-way tensor-parallel server."""
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "8",
                "--moe-a2a-backend",
                "deepep",
                "--cuda-graph-max-bs",
                "128",
            ],
        )
    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree (workers included).
        kill_process_tree(cls.process.pid)
    def test_gsm8k(self):
        """8-shot GSM8K over 200 questions; require accuracy above 0.93."""
        args = SimpleNamespace(
            num_shots=8,
            data_path=None,
            num_questions=200,
            parallel=64,
            max_new_tokens=512,
            host="http://127.0.0.1",
            # The eval client takes host/port separately; reuse base_url's port.
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Eval accuracy of GSM8K: {metrics=}")
        self.assertGreater(metrics["accuracy"], 0.93)
    def test_mmlu(self):
        """64-example MMLU eval; require score above 0.87."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        print(f"Eval accuracy of MMLU: {metrics=}")
        self.assertGreater(metrics["score"], 0.87)
if __name__ == "__main__":
unittest.main()

112
test/srt/ep/test_moe_ep.py Normal file
View File

@@ -0,0 +1,112 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestEpMoE(CustomTestCase):
    """MMLU and MGSM-EN evals with expert parallelism (tp=2, ep=2)."""
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--ep-size",
                "2",
            ],
        )
    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree (workers included).
        kill_process_tree(cls.process.pid)
    def test_mmlu(self):
        """64-example MMLU eval; require score of at least 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.5)
    def test_mgsm_en(self):
        """Full MGSM-EN eval; require score of at least 0.8."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.8)
class TestEpMoEFP8(CustomTestCase):
    """Same evals as TestEpMoE, but with FP8 quantization enabled."""
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--ep-size",
                "2",
                "--quantization",
                "fp8",
            ],
        )
    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree (workers included).
        kill_process_tree(cls.process.pid)
    def test_mmlu(self):
        """64-example MMLU eval; require score of at least 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.5)
    def test_mgsm_en(self):
        """Full MGSM-EN eval; require score of at least 0.8."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.8)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,365 @@
import argparse
import logging
import os
import queue
import re
import subprocess
import threading
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import psutil
import requests
import yaml
@dataclass
class ServerConfig:
    """Static description of a supported inference-server type."""
    # Substring that identifies this server's launch command.
    command: str
    # Process-name fragments matched when sweeping for stray processes.
    process_names: List[str]
    # Port assumed when the launch command does not pass --port.
    default_port: int
@dataclass
class TaskConfig:
    """One experiment entry from the YAML config."""
    # Shell command that launches the server under test.
    server_cmd: str
    # Shell command that runs the benchmark client against the server.
    client_cmd: str
    # Human-readable task name; load_config defaults it to "task-<idx>".
    name: Optional[str] = None
    # "sglang" / "vllm"; auto-detected from server_cmd when omitted.
    server_type: Optional[str] = None
@dataclass
class TaskResult:
    """Outcome of a single experiment task."""
    # Task name copied from TaskConfig.
    name: str
    # True iff the client exited with return code 0.
    success: bool
    # Parsed key output on success, or the error message on failure.
    output: str
    # Wall-clock runtime of the task in seconds.
    runtime: float
    # ISO-8601 timestamp of when the task finished.
    timestamp: str
# Known server types with their launch-command markers, cleanup process
# names, and default ports (used when --port is absent from the command).
SERVER_DEFAULTS = {
    "sglang": ServerConfig(
        command="sglang.launch_server",
        process_names=["sglang.launch_server"],
        default_port=30000,
    ),
    "vllm": ServerConfig(
        command="vllm.entrypoints.openai.api_server",
        process_names=["vllm.entrypoints.openai.api_server"],
        default_port=8000,
    ),
}
def parse_key_info(output: str) -> str:
    """Extract and format the key sections from a benchmark client's output.

    Pulls out, when present: the argparse ``Namespace(...)`` dump, the
    ``#Input tokens: N`` / ``#Output tokens: N`` lines, and the
    "Serving Benchmark Result" table; joins them with blank lines.

    Args:
        output: Full captured stdout text of the client process.

    Returns:
        The matched sections joined by blank lines ("" if nothing matched).
    """
    key_info = []
    # Extract Args namespace (non-greedy: stops at the first ')').
    args_match = re.search(r"Namespace\(.*?\)", output, re.DOTALL)
    if args_match:
        key_info.append(args_match.group(0))
    # Extract input/output token counts.
    # Bug fix: the group must be non-capturing -- with a capturing group,
    # re.findall returns only "Input"/"Output" instead of the whole line.
    token_matches = re.findall(r"#(?:Input|Output) tokens: \d+", output)
    key_info.extend(token_matches)
    # Extract benchmark result section (header through the closing '=' rule).
    result_match = re.search(
        r"============ Serving Benchmark Result ============.*?={50,}",
        output,
        re.DOTALL,
    )
    if result_match:
        key_info.append(result_match.group(0))
    return "\n\n".join(key_info)
def extract_port_from_command(cmd: str, server_type: str) -> int:
    """Return the port passed via ``--port`` in *cmd*, else the server-type
    default (8000 for unknown server types)."""
    match = re.search(r"--port[= ](\d+)", cmd)
    if match is None:
        fallback = SERVER_DEFAULTS.get(server_type, ServerConfig("", [], 8000))
        return fallback.default_port
    return int(match.group(1))
def detect_server_type(cmd: str) -> str:
    """Infer the server type ("sglang"/"vllm"/"unknown") from a launch command."""
    return next(
        (name for name, cfg in SERVER_DEFAULTS.items() if cfg.command in cmd),
        "unknown",
    )
def stream_output(
    process: subprocess.Popen, prefix: str, logger: logging.Logger
) -> Tuple[queue.Queue, Tuple[threading.Thread, threading.Thread]]:
    """Stream a process's stdout/stderr to *logger* on background threads.

    Lines from CLIENT processes are additionally collected into the returned
    queue so the caller can parse the client's full output afterwards.

    Args:
        process: Running subprocess whose pipes are consumed.
        prefix: Log-line prefix; the literal "CLIENT" enables queue capture.
        logger: Destination logger (lines are logged at DEBUG level).

    Returns:
        ``(output_queue, (stdout_thread, stderr_thread))``. The original
        annotation claimed only a queue was returned, which was wrong --
        callers unpack both values. The threads are daemons, so they never
        block interpreter shutdown.
    """
    output_queue = queue.Queue()
    def stream_pipe(pipe, prefix):
        # iter(..., "") terminates when the pipe closes (readline -> "").
        for line in iter(pipe.readline, ""):
            if prefix == "CLIENT":
                output_queue.put(line.rstrip())
            logger.debug(f"{prefix} | {line.rstrip()}")
    stdout_thread = threading.Thread(
        target=stream_pipe, args=(process.stdout, prefix), daemon=True
    )
    stderr_thread = threading.Thread(
        target=stream_pipe, args=(process.stderr, prefix), daemon=True
    )
    stdout_thread.start()
    stderr_thread.start()
    return output_queue, (stdout_thread, stderr_thread)
class ProcessManager:
    """Owns the server and client subprocesses and their (best-effort) cleanup."""
    def __init__(self):
        self.server_process: Optional[subprocess.Popen] = None
        self.client_process: Optional[subprocess.Popen] = None
        self.logger = logging.getLogger(__name__)
    def start_process(
        self, command: str, prefix: str
    ) -> Tuple[subprocess.Popen, queue.Queue]:
        """Start *command* via the shell and stream its output.

        Returns (process, output_queue, reader_threads) -- NOTE(review): the
        declared return type is missing the third element; callers unpack
        three values. Kept as-is here.
        """
        process = subprocess.Popen(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,  # line-buffered so the reader threads see output promptly
        )
        output_queue, threads = stream_output(process, prefix, self.logger)
        return process, output_queue, threads
    def kill_process_tree(self, process: subprocess.Popen):
        """Kill *process* and all of its descendants, ignoring already-dead ones."""
        try:
            parent = psutil.Process(process.pid)
            # Collect children before killing the parent; otherwise they get
            # reparented and escape the sweep.
            children = parent.children(recursive=True)
            for child in children:
                try:
                    child.kill()
                except psutil.NoSuchProcess:
                    pass
            parent.kill()
            # Reap, then force-kill anything that survived the first pass.
            gone, alive = psutil.wait_procs(children + [parent], timeout=3)
            for p in alive:
                try:
                    p.kill()
                except psutil.NoSuchProcess:
                    pass
        except psutil.NoSuchProcess:
            pass
    def cleanup(self, process_names: List[str]):
        """Kill tracked processes, then sweep for strays matching *process_names*."""
        if self.client_process:
            self.kill_process_tree(self.client_process)
            self.client_process = None
        if self.server_process:
            self.kill_process_tree(self.server_process)
            self.server_process = None
        # System-wide sweep for leftovers from earlier (possibly crashed) runs.
        for proc in psutil.process_iter(["pid", "name", "cmdline"]):
            try:
                cmdline = " ".join(proc.cmdline())
                if any(name in cmdline for name in process_names):
                    proc.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
class ExperimentRunner:
    """Runs one benchmark task end to end: start the server, wait for its
    health endpoint, run the client, collect and parse the client output."""
    def __init__(self):
        self.process_manager = ProcessManager()
        self.logger = logging.getLogger(__name__)
    def wait_for_server(self, port: int, timeout: int = 300) -> bool:
        """Poll /health until the server answers 200 or *timeout* seconds pass."""
        start_time = time.perf_counter()
        while time.perf_counter() - start_time < timeout:
            try:
                # Bound each probe so one hung connection cannot consume the
                # whole startup budget.
                response = requests.get(f"http://localhost:{port}/health", timeout=5)
                if response.status_code == 200:
                    self.logger.debug(f"Server ready on port {port}")
                    return True
            except requests.RequestException:
                pass
            # Fix: sleep after *every* failed probe. The original slept only
            # on connection errors and busy-looped on non-200 responses.
            time.sleep(2)
        return False
    def run_task(self, config: TaskConfig) -> TaskResult:
        """Execute one server/client pair and return a TaskResult.

        Never raises: failures are reported via TaskResult.success/output.
        Server and client process trees are always cleaned up in the
        finally block.
        """
        start_time = time.perf_counter()
        client_output = []
        try:
            if not config.server_type:
                config.server_type = detect_server_type(config.server_cmd)
            server_config = SERVER_DEFAULTS.get(config.server_type)
            if not server_config:
                raise ValueError(f"Unknown server type: {config.server_type}")
            port = extract_port_from_command(config.server_cmd, config.server_type)
            # Kill any leftover servers from a previous (possibly crashed) run.
            self.process_manager.cleanup(server_config.process_names)
            self.logger.debug(f"Starting server: {config.name}")
            self.process_manager.server_process, _, server_threads = (
                self.process_manager.start_process(config.server_cmd, "SERVER")
            )
            if not self.wait_for_server(port):
                raise TimeoutError("Server startup timeout")
            # Grace period after /health goes green before sending load.
            time.sleep(10)
            self.logger.debug("Starting client")
            self.process_manager.client_process, output_queue, client_threads = (
                self.process_manager.start_process(config.client_cmd, "CLIENT")
            )
            returncode = self.process_manager.client_process.wait()
            # Drain everything the reader threads queued up.
            while True:
                try:
                    line = output_queue.get_nowait()
                    client_output.append(line)
                except queue.Empty:
                    break
            if returncode != 0:
                raise RuntimeError(f"Client failed with code {returncode}")
            # Parse and format the output
            full_output = "\n".join(client_output)
            formatted_output = parse_key_info(full_output)
            return TaskResult(
                name=config.name,
                success=True,
                output=formatted_output,
                runtime=time.perf_counter() - start_time,
                timestamp=datetime.now().isoformat(),
            )
        except Exception as e:
            # Report the failure instead of propagating so later tasks still run.
            return TaskResult(
                name=config.name,
                success=False,
                output=str(e),
                runtime=time.perf_counter() - start_time,
                timestamp=datetime.now().isoformat(),
            )
        finally:
            if config.server_type in SERVER_DEFAULTS:
                self.process_manager.cleanup(
                    SERVER_DEFAULTS[config.server_type].process_names
                )
            # Let ports/GPU memory be released before the next task starts.
            time.sleep(10)
def load_config(config_path: str) -> List[TaskConfig]:
    """Parse the YAML experiment file into a list of TaskConfig entries.

    Raises:
        ValueError: if a task entry is not a mapping or lacks a server/client
            command.
    """
    with open(config_path, "r") as f:
        raw = yaml.safe_load(f)
    tasks = []
    for idx, entry in enumerate(raw.get("tasks", [])):
        if not isinstance(entry, dict):
            raise ValueError(f"Invalid entry at index {idx}")
        task = TaskConfig(
            server_cmd=entry.get("server_cmd"),
            client_cmd=entry.get("client_cmd"),
            name=entry.get("name", f"task-{idx+1}"),
            server_type=entry.get("server_type"),
        )
        if not task.server_cmd or not task.client_cmd:
            raise ValueError(f"Missing commands in {task.name}")
        tasks.append(task)
    return tasks
def setup_logging(debug: bool = False):
    """Configure root logging to both the console and "experiment.log",
    at DEBUG level when *debug* is set, otherwise INFO."""
    logging.basicConfig(
        level=logging.DEBUG if debug else logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler(), logging.FileHandler("experiment.log")],
    )
def format_results(results: List[TaskResult]) -> str:
    """Render experiment results as Markdown for the GitHub step summary."""
    lines = ["# Experiment Results\n"]
    for res in results:
        status = "✅ Success" if res.success else "❌ Failed"
        lines.extend(
            [
                f"## {res.name}",
                f"**Status**: {status}",
                f"**Runtime**: {res.runtime:.2f} seconds",
                f"**Timestamp**: {res.timestamp}",
                "\n**Output**:\n```",
                res.output,
                "```\n",
            ]
        )
    return "\n".join(lines)
def get_bool_env_var(name: str, default: str = "false") -> bool:
    """Read env var *name* and interpret "true"/"1" (case-insensitive) as True."""
    raw = os.getenv(name, default).lower()
    return raw == "true" or raw == "1"
def write_in_github_step_summary(results: List[TaskResult]):
    """Write formatted results to GitHub step summary."""
    # GITHUB_STEP_SUMMARY points at the summary file provided by the GitHub
    # Actions runner; outside CI it is unset, so warn and skip.
    if not os.environ.get("GITHUB_STEP_SUMMARY"):
        logging.warning("GITHUB_STEP_SUMMARY environment variable not set")
        return
    formatted_content = format_results(results)
    # Append rather than overwrite: earlier steps may have written summaries.
    with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
        f.write(formatted_content)
def main():
    """CLI entry point: run every task in the YAML config sequentially.

    Each task's result is collected even on failure (run_task reports
    failures via TaskResult); only config/loading errors propagate.
    """
    parser = argparse.ArgumentParser(description="Experiment Runner")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to YAML config file"
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug output")
    args = parser.parse_args()
    setup_logging(args.debug)
    logger = logging.getLogger(__name__)
    results = []
    try:
        configs = load_config(args.config)
        runner = ExperimentRunner()
        for config in configs:
            logger.info(f"Running {config.name}")
            result = runner.run_task(config)
            results.append(result)
        # Publish a Markdown summary when running inside GitHub Actions CI.
        if get_bool_env_var("SGLANG_IS_IN_CI"):
            write_in_github_step_summary(results)
    except Exception as e:
        logger.error(f"Error: {e}")
        raise
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,53 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import is_hip, kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
_is_hip = is_hip()
class TestHiCache(CustomTestCase):
    """MMLU smoke test with hierarchical KV caching enabled."""
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-hierarchical-cache",
                "--mem-fraction-static",
                0.7,
                "--hicache-size",
                # Larger host cache on HIP/ROCm devices -- TODO confirm rationale.
                100 if not _is_hip else 200,
            ],
        )
    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree (workers included).
        kill_process_tree(cls.process.pid)
    def test_mmlu(self):
        """64-example MMLU eval; require score of at least 0.65."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,67 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import is_hip, kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
_is_hip = is_hip()
if _is_hip:
hicache_args = ["--hicache-size", 200]
else:
hicache_args = ["--hicache-ratio", 2]
class TestHierarchicalMLA(CustomTestCase):
    """MMLU and MGSM-EN evals for an MLA model with hierarchical caching."""
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--enable-hierarchical-cache",
            ]
            # hicache_args is platform-dependent (hicache-size on HIP,
            # hicache-ratio elsewhere; see the module-level conditional).
            + hicache_args,
        )
    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree (workers included).
        kill_process_tree(cls.process.pid)
    def test_mmlu(self):
        """64-example MMLU eval; require score above 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.5)
    def test_mgsm_en(self):
        """Full MGSM-EN eval; require score above 0.8."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,51 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestHiCachePage(CustomTestCase):
    """MMLU smoke test for hierarchical cache with paged KV (page size 32)
    and the write-back write policy."""
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-hierarchical-cache",
                "--page-size",
                32,
                "--hicache-write-policy",
                "write_back",
            ],
        )
    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree (workers included).
        kill_process_tree(cls.process.pid)
    def test_mmlu(self):
        """64-example MMLU eval; require score of at least 0.65."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,57 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import is_hip, kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
_is_hip = is_hip()
class TestHiCache(CustomTestCase):
    """MMLU smoke test for hierarchical cache backed by file storage
    (page size 64)."""
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-hierarchical-cache",
                "--mem-fraction-static",
                0.7,
                "--hicache-size",
                # Larger host cache on HIP/ROCm devices -- TODO confirm rationale.
                100 if not _is_hip else 200,
                "--page-size",
                "64",
                "--hicache-storage-backend",
                "file",
            ],
        )
    @classmethod
    def tearDownClass(cls):
        # Kill the whole server process tree (workers included).
        kill_process_tree(cls.process.pid)
    def test_mmlu(self):
        """64-example MMLU eval; require score of at least 0.65."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,42 @@
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 1,
"1": 1,
"2": 1,
"3": 1,
"4": 1,
"5": 1,
"6": 1,
"7": 1,
"8": 1,
"9": 1,
"10": 1,
"11": 1,
"12": 1,
"13": 1,
"14": 1,
"15": 1,
"16": 1,
"17": 1,
"18": 1,
"19": 1,
"20": 1,
"21": 1,
"22": 1,
"23": 1,
"24": 1,
"25": 1,
"26": 1,
"27": 1,
"28": 1,
"29": 1,
"30": 1,
"31": 1
}
}
}
}

View File

@@ -0,0 +1,42 @@
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0408,
"1": 0.0503,
"2": 0.0667,
"3": 0.0909,
"4": 0.1135,
"5": 0.127,
"6": 0.1768,
"7": 0.1488,
"8": 0.1135,
"9": 0.1203,
"10": 0.1013,
"11": 0.0842,
"12": 0.1231,
"13": 0.1096,
"14": 0.1221,
"15": 0.1013,
"16": 0.1067,
"17": 0.0952,
"18": 0.0899,
"19": 0.097,
"20": 0.087,
"21": 0.0994,
"22": 0.0904,
"23": 0.1013,
"24": 0.1019,
"25": 0.1053,
"26": 0.1,
"27": 0.0894,
"28": 0.1013,
"29": 0.1488,
"30": 0.0766,
"31": 0.0821
}
}
}
}

View File

@@ -0,0 +1,38 @@
{
"model_type": "qwen",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.9846,
"1": 0.0645,
"2": 0.0731,
"3": 0.0800,
"4": 0.0748,
"5": 0.0780,
"6": 0.0702,
"7": 0.0894,
"8": 0.0410,
"9": 0.0758,
"10": 0.0556,
"11": 0.0731,
"12": 0.0899,
"13": 0.0780,
"14": 0.1441,
"15": 0.0914,
"16": 0.5614,
"17": 0.1067,
"18": 0.0537,
"19": 0.0658,
"20": 0.0523,
"21": 0.0533,
"22": 0.0699,
"23": 0.0635,
"24": 0.0588,
"25": 0.0884,
"26": 0.0947,
"27": 0.1032
}
}
}
}

178
test/srt/lora/test_lora.py Normal file
View File

@@ -0,0 +1,178 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import os
import random
import unittest
from typing import List
import torch
from utils import (
ALL_OTHER_MULTI_LORA_MODELS,
CI_MULTI_LORA_MODELS,
TORCH_DTYPES,
LoRAModelCase,
)
from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
TEST_MULTIPLE_BATCH_PROMPTS = [
"""
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
### Question 2:
What do you know about llamas?
### Answer:
""",
"""
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
""",
"AI is a field of computer science focused on",
"Computer science is the study of",
"Write a short story.",
"What are the main components of a computer?",
]
class TestLoRA(CustomTestCase):
    """Compares SRT batched multi-LoRA inference against HuggingFace
    reference outputs using ROUGE-L similarity."""
    def _create_test_samples(
        self, lora_adapter_paths: List[str], repeated_trials: int = 3
    ):
        """Build batches mixing adapter-0, adapter-1 and base-model (None)
        requests in several orders; each pattern is repeated *repeated_trials*
        times for stability."""
        random.seed(42)  # Ensure reproducibility
        patterns = [
            [None, lora_adapter_paths[0], lora_adapter_paths[1]],
            [lora_adapter_paths[0], None, lora_adapter_paths[1]],
            [lora_adapter_paths[0], lora_adapter_paths[1], None],
            [None, lora_adapter_paths[1], None],
            [None, None, None],
        ]
        batches = [
            [random.choice(pattern) for _ in range(3)]
            for pattern in patterns
            for _ in range(repeated_trials)
        ]
        return batches
    def ensure_reproducibility(self):
        # Reset all RNGs and force deterministic kernels before each forward
        # pass so the SRT and HF runs are comparable.
        seed = 42
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.use_deterministic_algorithms(True)
    def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
        """For each (model case, dtype), run the same batches through SRT and
        HF and require ROUGE-L similarity above the case's tolerance."""
        for model_case in model_cases:
            for torch_dtype in TORCH_DTYPES:
                max_new_tokens = 32
                backend = "triton"
                base_path = model_case.base
                lora_adapter_paths = [a.name for a in model_case.adaptors]
                # The batch patterns below need at least two distinct adapters.
                assert len(lora_adapter_paths) >= 2
                print(
                    f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
                )
                # Initialize runners
                srt_runner = SRTRunner(
                    base_path,
                    torch_dtype=torch_dtype,
                    model_type="generation",
                    lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
                    max_loras_per_batch=len(lora_adapter_paths) + 1,
                    lora_backend=backend,
                    sleep_on_idle=True,  # Eliminate non-determinism by forcing all requests to be processed in one batch.
                    attention_backend="torch_native",
                )
                hf_runner = HFRunner(
                    base_path, torch_dtype=torch_dtype, model_type="generation"
                )
                batches = self._create_test_samples(lora_adapter_paths)
                with srt_runner, hf_runner:
                    for i, lora_paths in enumerate(batches, start=1):
                        prompts = [
                            random.choice(TEST_MULTIPLE_BATCH_PROMPTS) for _ in range(3)
                        ]
                        print(
                            f"\n--- Running Batch {i} --- prompts: {prompts}, lora_paths: {lora_paths}"
                        )
                        self.ensure_reproducibility()
                        srt_outputs = srt_runner.batch_forward(
                            prompts,
                            max_new_tokens=max_new_tokens,
                            lora_paths=lora_paths,
                        )
                        # Re-seed before the HF pass so both see the same RNG state.
                        self.ensure_reproducibility()
                        hf_outputs = hf_runner.forward(
                            prompts,
                            max_new_tokens=max_new_tokens,
                            lora_paths=lora_paths,
                        )
                        print("SRT outputs:", [s for s in srt_outputs.output_strs])
                        print("HF outputs:", [s for s in hf_outputs.output_strs])
                        for srt_out, hf_out in zip(
                            srt_outputs.output_strs, hf_outputs.output_strs
                        ):
                            srt_str = srt_out.strip()
                            hf_str = hf_out.strip()
                            rouge_tol = model_case.rouge_l_tolerance
                            rouge_score = calculate_rouge_l([srt_str], [hf_str])[0]
                            if rouge_score < rouge_tol:
                                raise AssertionError(
                                    f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
                                    f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
                                )
                        print(f"--- Batch {i} Comparison Passed --- ")
    def test_ci_lora_models(self):
        self._run_lora_multiple_batch_on_model_cases(CI_MULTI_LORA_MODELS)
    def test_all_lora_models(self):
        # Full model sweep, skipped in CI. ONLY_RUN=<base> restricts to one base.
        if is_in_ci():
            return
        filtered_models = []
        for model_case in ALL_OTHER_MULTI_LORA_MODELS:
            if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
                continue
            filtered_models.append(model_case)
        self._run_lora_multiple_batch_on_model_cases(filtered_models)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore")

View File

@@ -0,0 +1,76 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import os
import unittest
from typing import List
from utils import (
ALL_OTHER_LORA_MODELS,
BACKENDS,
CI_LORA_MODELS,
DEFAULT_PROMPTS,
TORCH_DTYPES,
LoRAModelCase,
run_lora_test_one_by_one,
)
from sglang.test.test_utils import CustomTestCase, is_in_ci
class TestLoRABackend(CustomTestCase):
    """Runs every (model case, dtype, backend) combination through the
    one-by-one LoRA comparison helper."""
    def _run_backend_on_model_cases(self, model_cases: List[LoRAModelCase]):
        for case in model_cases:
            # Cases flagged skip_long_prompt only see prompts under 1000 chars.
            if case.skip_long_prompt:
                prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
            else:
                prompts = DEFAULT_PROMPTS
            for dtype in TORCH_DTYPES:
                for backend in BACKENDS:
                    run_lora_test_one_by_one(
                        prompts,
                        case,
                        dtype,
                        max_new_tokens=32,
                        backend=backend,
                    )
    def test_ci_lora_models(self):
        self._run_backend_on_model_cases(CI_LORA_MODELS)
    def test_all_lora_models(self):
        if is_in_ci():
            return
        # Retain ONLY_RUN check here
        selected = [
            case
            for case in ALL_OTHER_LORA_MODELS
            if os.environ.get("ONLY_RUN", case.base) == case.base
        ]
        self._run_backend_on_model_cases(selected)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore")

View File

@@ -0,0 +1,110 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import os
import unittest
from typing import List
from utils import (
ALL_OTHER_LORA_MODELS,
CI_LORA_MODELS,
DEFAULT_PROMPTS,
TORCH_DTYPES,
LoRAModelCase,
run_lora_test_by_batch,
run_lora_test_one_by_one,
)
from sglang.test.test_utils import CustomTestCase, is_in_ci
TEST_CUDA_GRAPH_PADDING_PROMPTS = [
"AI is a field of computer science focused on",
"""
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
### Question 2:
What do you know about llamas?
### Answer:
""",
"Computer science is the study of",
]
class TestLoRACudaGraph(CustomTestCase):
    """Checks LoRA correctness with CUDA graph disabled and on padded batches."""

    def _run_without_cuda_graph_on_model_cases(self, model_cases: List[LoRAModelCase]):
        # CUDA graph is enabled by default in the other LoRA suites, so this
        # suite only needs to exercise the graph-disabled path.
        for case in model_cases:
            if case.skip_long_prompt:
                # Drop prompts of 1000+ characters for cases that cannot take them.
                selected_prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
            else:
                selected_prompts = DEFAULT_PROMPTS
            for dtype in TORCH_DTYPES:
                run_lora_test_one_by_one(
                    selected_prompts,
                    case,
                    dtype,
                    max_new_tokens=32,
                    backend="triton",
                    disable_cuda_graph=True,
                    test_tag="without_cuda_graph",
                )

    def _run_cuda_graph_padding_on_model_cases(self, model_cases: List[LoRAModelCase]):
        # A batch size of 3 is not captured by CUDA graph and therefore needs padding.
        for case in model_cases:
            for dtype in TORCH_DTYPES:
                run_lora_test_by_batch(
                    TEST_CUDA_GRAPH_PADDING_PROMPTS,
                    case,
                    dtype,
                    max_new_tokens=32,
                    backend="triton",
                    disable_cuda_graph=False,
                    test_tag="cuda_graph_padding",
                )

    def test_ci_lora_models(self):
        """Run both scenarios on the small CI model set."""
        self._run_without_cuda_graph_on_model_cases(CI_LORA_MODELS)
        self._run_cuda_graph_padding_on_model_cases(CI_LORA_MODELS)

    def test_all_lora_models(self):
        """Run both scenarios on the full model set (skipped in CI, filterable via ONLY_RUN)."""
        if is_in_ci():
            return
        only_run = os.environ.get("ONLY_RUN")
        filtered = [
            case
            for case in ALL_OTHER_LORA_MODELS
            if only_run is None or only_run == case.base
        ]
        self._run_without_cuda_graph_on_model_cases(filtered)
        self._run_cuda_graph_padding_on_model_cases(filtered)
if __name__ == "__main__":
    try:
        # Use the "spawn" start method for worker processes; presumably required
        # because the forked default does not play well with GPU runtimes — TODO confirm.
        mp.set_start_method("spawn")
    except RuntimeError:
        # Raised when a start method has already been set; keep the existing one.
        pass
    unittest.main(warnings="ignore")

View File

@@ -0,0 +1,146 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import contextlib
import multiprocessing as mp
import unittest
from typing import Dict, List, Tuple
import torch
from sglang.test.runners import SRTRunner
from sglang.test.test_utils import CustomTestCase
# Prompts run against every adapter; (adapter, prompt) outputs are compared
# across repeated runs to detect eviction-induced corruption.
PROMPTS = [
    "AI is a field of computer science focused on",
    """
### Instruction:
Compose a SQL query that uses the following table: users, and returns the user_id and name of all users whose name that does not have a duplicate in the table.
### Response:
SELECT user_id, name FROM users WHERE name LIKE 'A%';
""",
]
# Two adapters with different LoRA target-module sets, so alternating between
# them forces real weight evictions/reloads.
ADAPTERS = [
    "faridlazuarda/valadapt-llama-3.1-8B-it-chinese",  # target_modules = q, v
    "philschmid/code-llama-3-1-8b-text-to-sql-lora",  # target_modules = q, k, v, o, gate, up, down
]
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
@contextlib.contextmanager
def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str):
    """Load a LoRA adapter on *runner* for the duration of the ``with`` block.

    The adapter is registered before entering the try/finally so that a failed
    load does not trigger an unload of an adapter that was never registered
    (such an unload could raise and mask the original load error).  On normal
    exit — or if the block body raises — the adapter is always unloaded.
    """
    runner.load_lora_adapter(lora_name=lora_name, lora_path=lora_path)
    try:
        yield
    finally:
        runner.unload_lora_adapter(lora_name=lora_name)
class TestLoRAEviction(CustomTestCase):
    # Forces LoRA adapter eviction (max_loras_per_batch=1 below) and checks
    # that generations stay identical across evict/reload cycles.

    def test_lora_eviction_with_different_target_modules(self):
        """
        Test LoRA eviction with different target modules.
        This test runs inference against two LoRA adapters in different orders to force eviction behavior, and ensures
        that the outputs of the same (adapter, prompt) pair are consistent across runs.
        """
        output_history = {}
        self._run_test(ADAPTERS, output_history, reverse=False)
        self._run_test(ADAPTERS, output_history, reverse=True)

    def test_lora_eviction_with_reused_lora_name(self):
        """
        Test LoRA eviction with reused LoRA names.
        This test runs inference against two LoRA adapters with the same name to ensure that the eviction behavior
        works correctly when reusing LoRA names.
        """
        output_history = {}
        self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1)
        self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1)

    def _run_test(
        self,
        lora_paths: List[str],
        output_history: Dict[Tuple[str, str], str],
        reverse: bool = False,
        repeat: int = 2,
        reuse_lora_name: bool = False,
    ):
        # output_history maps (adapter_path, prompt) -> first observed output;
        # later runs must reproduce it exactly.
        REUSED_LORA_NAME = "lora"
        max_new_tokens = 256
        backend = "triton"
        torch_dtype = torch.float16
        base_path = BASE_MODEL
        assert len(lora_paths) >= 2
        # When reusing a single name, adapters are loaded/unloaded dynamically
        # instead of being registered at server start-up.
        initial_lora_paths = lora_paths if not reuse_lora_name else None
        # Initialize runners
        with SRTRunner(
            base_path,
            torch_dtype=torch_dtype,
            model_type="generation",
            lora_paths=initial_lora_paths,
            # Only one adapter slot, so switching adapters always evicts.
            max_loras_per_batch=1,
            lora_backend=backend,
            enable_lora=True,
            max_lora_rank=256,
            lora_target_modules=["all"],
        ) as srt_runner:
            adapter_sequence = lora_paths if not reverse else lora_paths[::-1]
            for i in range(repeat):
                for j, lora_path in enumerate(adapter_sequence):
                    print(
                        f"\n========== Testing LoRA eviction with adapter '{lora_path}' (#{j + 1}/{len(adapter_sequence)}), reuse_lora_name: {reuse_lora_name}, reversed: {reverse}, repeat: {i + 1}/{repeat} ---"
                    )
                    lora_name = REUSED_LORA_NAME if reuse_lora_name else lora_path
                    # Dynamic load/unload only in the reused-name mode; otherwise
                    # the adapters were preloaded and no context is needed.
                    context = (
                        dynamically_loaded_adapter(srt_runner, lora_path, lora_name)
                        if reuse_lora_name
                        else contextlib.nullcontext()
                    )
                    with context:
                        for prompt in PROMPTS:
                            print("\nprompt:\n", prompt)
                            srt_outputs = srt_runner.forward(
                                [prompt],
                                max_new_tokens=max_new_tokens,
                                lora_paths=[lora_name],
                            )
                            output = srt_outputs.output_strs[0].strip()
                            print("\noutput:\n", output)
                            # Compare against the first run of this (adapter, prompt)
                            # pair; record it if this is the first run.
                            prev_output = output_history.get((lora_path, prompt))
                            if prev_output is not None:
                                self.assertEqual(
                                    prev_output,
                                    output,
                                    f"Output mismatch for adapter {lora_path} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
                                )
                            else:
                                output_history[(lora_path, prompt)] = output
if __name__ == "__main__":
    try:
        # Use the "spawn" start method for worker processes; presumably required
        # because the forked default does not play well with GPU runtimes — TODO confirm.
        mp.set_start_method("spawn")
    except RuntimeError:
        # Raised when a start method has already been set; keep the existing one.
        pass
    unittest.main(warnings="ignore")

View File

@@ -0,0 +1,208 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import os
import random
import unittest
from typing import List
from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
# Qwen3 base model with two adapters; max_loras_per_batch=2 lets both be
# active in the same batch.
LORA_MODELS_QWEN3 = [
    LoRAModelCase(
        base="Qwen/Qwen3-4B",
        adaptors=[
            LoRAAdaptor(
                name="nissenj/Qwen3-4B-lora-v2",
                prefill_tolerance=3e-1,
            ),
            LoRAAdaptor(
                name="y9760210/Qwen3-4B-lora_model",
                prefill_tolerance=3e-1,
            ),
        ],
        max_loras_per_batch=2,
    ),
]
# Prompt pool sampled (with replacement) to build each 3-request batch.
TEST_MULTIPLE_BATCH_PROMPTS = [
    """
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
### Question 2:
What do you know about llamas?
### Answer:
""",
    """
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
### Response:
The Transformers are large language models,
They're used to make predictions on text.
""",
    # "AI is a field of computer science focused on", TODO: Add it back after fixing its bug
    "Computer science is the study of",
    "Write a short story.",
    "What are the main components of a computer?",
]
class TestLoRA(CustomTestCase):
    """Compares SRT and HF generations over batches mixing LoRA and base requests."""

    def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
        """For each model case and dtype, run several mixed batches and compare
        SRT vs HF outputs against the case's ROUGE-L tolerance.
        """
        for model_case in model_cases:
            for torch_dtype in TORCH_DTYPES:
                max_new_tokens = 10
                backend = "triton"
                base_path = model_case.base
                lora_adapter_paths = [a.name for a in model_case.adaptors]
                assert len(lora_adapter_paths) >= 2

                def sample_prompts():
                    # Three independently drawn prompts per batch.
                    return [
                        random.choice(TEST_MULTIPLE_BATCH_PROMPTS) for _ in range(3)
                    ]

                # Adapter layouts per batch: the no-LoRA (None) request is moved
                # through every position, plus a single-adapter and an all-None batch.
                adapter_layouts = [
                    [None, lora_adapter_paths[0], lora_adapter_paths[1]],
                    [lora_adapter_paths[0], None, lora_adapter_paths[1]],
                    [lora_adapter_paths[0], lora_adapter_paths[1], None],
                    [None, lora_adapter_paths[1], None],
                    [None, None, None],
                ]
                batches = [(sample_prompts(), layout) for layout in adapter_layouts]
                print(
                    f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
                )
                # Initialize runners
                srt_runner = SRTRunner(
                    base_path,
                    torch_dtype=torch_dtype,
                    model_type="generation",
                    lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
                    max_loras_per_batch=len(lora_adapter_paths) + 1,
                    lora_backend=backend,
                )
                hf_runner = HFRunner(
                    base_path,
                    torch_dtype=torch_dtype,
                    model_type="generation",
                    patch_model_do_sample_false=True,
                )
                with srt_runner, hf_runner:
                    for i, (prompts, lora_paths) in enumerate(batches):
                        print(
                            f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}"
                        )
                        srt_outputs = srt_runner.batch_forward(
                            prompts,
                            max_new_tokens=max_new_tokens,
                            lora_paths=lora_paths,
                        )
                        hf_outputs = hf_runner.forward(
                            prompts,
                            max_new_tokens=max_new_tokens,
                            lora_paths=lora_paths,
                        )
                        print("SRT outputs:", [s for s in srt_outputs.output_strs])
                        print("HF outputs:", [s for s in hf_outputs.output_strs])
                        # Per-request ROUGE-L comparison between the two engines.
                        for srt_out, hf_out in zip(
                            srt_outputs.output_strs, hf_outputs.output_strs
                        ):
                            srt_str = srt_out.strip()
                            hf_str = hf_out.strip()
                            rouge_tol = model_case.rouge_l_tolerance
                            rouge_score = calculate_rouge_l([srt_str], [hf_str])[0]
                            if rouge_score < rouge_tol:
                                raise AssertionError(
                                    f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
                                    f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
                                )
                        print(f"--- Batch {i+1} Comparison Passed --- ")

    def test_ci_lora_models(self):
        """CI entry point: the curated Qwen3 model case."""
        self._run_lora_multiple_batch_on_model_cases(LORA_MODELS_QWEN3)

    def test_all_lora_models(self):
        """Full set; skipped in CI, filterable via the ONLY_RUN env var."""
        if is_in_ci():
            return
        qwen_filtered_models = []
        for model_case in LORA_MODELS_QWEN3:
            if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
                continue
            qwen_filtered_models.append(model_case)
        self._run_lora_multiple_batch_on_model_cases(qwen_filtered_models)
if __name__ == "__main__":
    try:
        # Use the "spawn" start method for worker processes; presumably required
        # because the forked default does not play well with GPU runtimes — TODO confirm.
        mp.set_start_method("spawn")
    except RuntimeError:
        # Raised when a start method has already been set; keep the existing one.
        pass
    unittest.main(warnings="ignore")

View File

@@ -0,0 +1,83 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import random
import unittest
import torch
from utils import CI_MULTI_LORA_MODELS, DEFAULT_PROMPTS, run_lora_test_one_by_one
from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase
# One short and one longer prompt; the radix-cache test runs them with and
# without the cache and expects matching behavior.
PROMPTS = [
    "AI is a field of computer science focused on",
    """
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids.
### Question:
What do you know about llamas?
### Answer:
""",
]
class TestLoRARadixCache(CustomTestCase):
    """Verifies LoRA correctness both with and without the radix cache."""

    def test_lora_radix_cache(self):
        # A model case with multiple adaptors is required to exercise the
        # radix cache correctly across adapters.
        model_case = CI_MULTI_LORA_MODELS[0]
        torch_dtype = torch.float16
        max_new_tokens = 32
        backend = "triton"
        if model_case.skip_long_prompt:
            # Keep only prompts shorter than 1000 characters for this case.
            batch_prompts = [p for p in PROMPTS if len(p) < 1000]
        else:
            batch_prompts = PROMPTS
        # Run the identical comparison once per radix-cache setting.
        for disable_radix_cache, tag in (
            (False, "lora-with-radix-cache"),
            (True, "lora-without-radix-cache"),
        ):
            run_lora_test_one_by_one(
                batch_prompts,
                model_case,
                torch_dtype,
                max_new_tokens=max_new_tokens,
                backend=backend,
                disable_radix_cache=disable_radix_cache,
                test_tag=tag,
            )
if __name__ == "__main__":
    try:
        # Use the "spawn" start method for worker processes; presumably required
        # because the forked default does not play well with GPU runtimes — TODO confirm.
        mp.set_start_method("spawn")
    except RuntimeError:
        # Raised when a start method has already been set; keep the existing one.
        pass
    unittest.main(warnings="ignore")

View File

@@ -0,0 +1,78 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import os
import unittest
from typing import List
from utils import (
ALL_OTHER_LORA_MODELS,
CI_LORA_MODELS,
DEFAULT_PROMPTS,
TORCH_DTYPES,
LoRAModelCase,
run_lora_test_one_by_one,
)
from sglang.test.test_utils import CustomTestCase, is_in_ci
class TestLoRATP(CustomTestCase):
    """Runs the one-by-one LoRA comparison under tensor parallelism."""

    def _run_tp_on_model_cases(self, model_cases: List[LoRAModelCase]):
        tp_list = [2]  # Define TP sizes to iterate over
        for model_case in model_cases:
            # If skip_long_prompt is True, filter out prompts longer than 1000 characters
            prompts = (
                DEFAULT_PROMPTS
                if not model_case.skip_long_prompt
                else [p for p in DEFAULT_PROMPTS if len(p) < 1000]
            )
            # model_case objects are shared module-level constants (e.g. the
            # entries of CI_LORA_MODELS). Restore tp_size when done so the
            # mutation does not leak into other tests in the same process.
            original_tp_size = model_case.tp_size
            try:
                for tp_size in tp_list:
                    model_case.tp_size = tp_size
                    for torch_dtype in TORCH_DTYPES:
                        run_lora_test_one_by_one(
                            prompts,
                            model_case,
                            torch_dtype,
                            max_new_tokens=32,
                            backend="triton",
                            test_tag=f"tp={tp_size}",
                        )
            finally:
                model_case.tp_size = original_tp_size

    def test_ci_lora_models(self):
        """CI entry point: the small curated model set."""
        self._run_tp_on_model_cases(CI_LORA_MODELS)

    def test_all_lora_models(self):
        """Full set; skipped in CI, filterable via the ONLY_RUN env var."""
        if is_in_ci():
            return
        # Retain ONLY_RUN check here
        filtered_models = []
        for model_case in ALL_OTHER_LORA_MODELS:
            if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
                continue
            filtered_models.append(model_case)
        self._run_tp_on_model_cases(filtered_models)
if __name__ == "__main__":
    try:
        # Use the "spawn" start method for worker processes; presumably required
        # because the forked default does not play well with GPU runtimes — TODO confirm.
        mp.set_start_method("spawn")
    except RuntimeError:
        # Raised when a start method has already been set; keep the existing one.
        pass
    unittest.main(warnings="ignore")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,90 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import os
import unittest
from typing import List
from utils import (
ALL_OTHER_MULTI_LORA_MODELS,
BACKENDS,
CI_MULTI_LORA_MODELS,
TORCH_DTYPES,
LoRAModelCase,
run_lora_test_one_by_one,
)
from sglang.test.test_utils import CustomTestCase, is_in_ci
# All prompts are used at once in a batch.
# All prompts are used at once in a batch; adaptors are assigned round-robin
# by run_lora_test_one_by_one.
PROMPTS = [
    "AI is a field of computer science focused on",
    """
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids.
### Question:
What do you know about llamas?
### Answer:
""",
]
class TestMultiLoRABackend(CustomTestCase):
    """Runs the one-by-one multi-LoRA comparison across every dtype/backend pair."""

    def _run_multi_lora_test_on_model_cases(self, model_cases: List[LoRAModelCase]):
        for case in model_cases:
            if case.skip_long_prompt:
                # Keep only prompts shorter than 1000 characters for this case.
                selected = [p for p in PROMPTS if len(p) < 1000]
            else:
                selected = PROMPTS
            for dtype in TORCH_DTYPES:
                for lora_backend in BACKENDS:
                    run_lora_test_one_by_one(
                        selected,
                        case,
                        dtype,
                        max_new_tokens=32,
                        backend=lora_backend,
                        test_tag="multi-lora-backend",
                    )

    def test_ci_lora_models(self):
        """CI entry point: the curated multi-LoRA model set."""
        self._run_multi_lora_test_on_model_cases(CI_MULTI_LORA_MODELS)

    def test_all_lora_models(self):
        """Full set; skipped in CI, filterable via the ONLY_RUN env var."""
        if is_in_ci():
            return
        only_run = os.environ.get("ONLY_RUN")
        selected_cases = [
            case
            for case in ALL_OTHER_MULTI_LORA_MODELS
            if only_run is None or only_run == case.base
        ]
        self._run_multi_lora_test_on_model_cases(selected_cases)
if __name__ == "__main__":
    try:
        # Use the "spawn" start method for worker processes; presumably required
        # because the forked default does not play well with GPU runtimes — TODO confirm.
        mp.set_start_method("spawn")
    except RuntimeError:
        # Raised when a start method has already been set; keep the existing one.
        pass
    unittest.main(warnings="ignore")

388
test/srt/lora/utils.py Normal file
View File

@@ -0,0 +1,388 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import dataclasses
from typing import List, Optional

import torch

from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l
@dataclasses.dataclass
class LoRAAdaptor:
    """Configuration for a single LoRA adapter under test.

    Tolerance fields left as None fall back to the defaults of the enclosing
    LoRAModelCase (see run_lora_test_one_by_one).
    """

    name: str  # HF hub id or local path of the adapter
    # Annotated Optional: these default to None ("use the model-case default").
    prefill_tolerance: Optional[float] = None
    decode_tolerance: Optional[float] = None
    rouge_l_tolerance: Optional[float] = None
@dataclasses.dataclass
class LoRAModelCase:
    """A base model plus the set of LoRA adaptors to exercise against it."""

    base: str  # HF hub id of the base model
    adaptors: List[LoRAAdaptor]  # adaptors are assigned to prompts round-robin
    tp_size: int = 1  # tensor-parallel size for the SRT server
    prefill_tolerance: float = 1e-1  # default max |logprob diff| at prefill
    decode_tolerance: float = 1e-1  # default max |logprob diff| at decode
    rouge_l_tolerance: float = 1.0  # default minimum ROUGE-L vs HF output
    max_loras_per_batch: int = 1  # adapter slots available per batch
    skip_long_prompt: bool = False  # drop prompts of 1000+ chars when True

    def __post_init__(self):
        # All adaptors may appear in a single batch, so the server must be
        # able to hold every one of them at once.
        if len(self.adaptors) > self.max_loras_per_batch:
            raise ValueError(
                f"For base '{self.base}', number of adaptors ({len(self.adaptors)}) "
                f"must be <= max_loras_per_batch ({self.max_loras_per_batch})"
            )
# Dtypes exercised by every LoRA test suite.
TORCH_DTYPES = [torch.float16]
# LoRA kernel backends under test.
BACKENDS = ["triton"]
# One short and one long prompt; suites may drop the long one via skip_long_prompt.
DEFAULT_PROMPTS = [
    "AI is a field of computer science focused on",
    """
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
### Question 2:
What do you know about llamas?
### Answer:
""",
]
# Single-adapter cases run on every commit in CI.
CI_LORA_MODELS = [
    LoRAModelCase(
        base="meta-llama/Llama-3.1-8B-Instruct",
        adaptors=[
            LoRAAdaptor(
                name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
            ),
        ],
        max_loras_per_batch=1,
    ),
]
# Additional single-adapter cases, only run outside CI.
ALL_OTHER_LORA_MODELS = [
    LoRAModelCase(
        base="meta-llama/Llama-3.1-8B-Instruct",
        adaptors=[
            LoRAAdaptor(
                name="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                prefill_tolerance=1e-1,
            ),
        ],
        max_loras_per_batch=1,
    ),
    LoRAModelCase(
        base="meta-llama/Llama-2-7b-hf",
        adaptors=[LoRAAdaptor(name="winddude/wizardLM-LlaMA-LoRA-7B")],
        max_loras_per_batch=2,
    ),
]
# Multi-adapter cases run on every commit in CI.
CI_MULTI_LORA_MODELS = [
    # multi-rank case
    LoRAModelCase(
        base="meta-llama/Llama-2-7b-hf",
        adaptors=[
            LoRAAdaptor(
                name="winddude/wizardLM-LlaMA-LoRA-7B",
                prefill_tolerance=1e-1,
            ),
            LoRAAdaptor(
                name="RuterNorway/Llama-2-7b-chat-norwegian-LoRa",
                prefill_tolerance=3e-1,
            ),
        ],
        max_loras_per_batch=2,
    ),
]
# Additional multi-adapter cases, only run outside CI.
ALL_OTHER_MULTI_LORA_MODELS = [
    LoRAModelCase(
        base="meta-llama/Llama-3.1-8B-Instruct",
        adaptors=[
            LoRAAdaptor(
                name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
                prefill_tolerance=1e-1,
            ),
            LoRAAdaptor(
                name="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                prefill_tolerance=1e-1,
            ),
        ],
        max_loras_per_batch=2,
    ),
]
def run_lora_test_one_by_one(
    prompts: List[str],
    model_case: LoRAModelCase,
    torch_dtype: torch.dtype,
    max_new_tokens: int,
    backend: str,
    disable_cuda_graph: bool = False,
    disable_radix_cache: bool = False,
    mem_fraction_static: float = 0.88,
    test_tag: str = "",
):
    """
    Input a batch of prompts, and run lora tests one by one with several generate requests
    (each request will have bs=1).
    For prompt0, prompt1, ..., promptN,
    we will use adaptor0, adaptor1, ..., adaptorN included in model case,
    We will then compare the outputs of HF and SRT with and without LoRA.
    If number of prompts is larger than number of adaptors,
    the prompt i will use adaptor i % (number of adaptors).
    Args:
        prompts (List[str]): The batch of prompts to test.
        model_case (LoRAModelCase): The model case to test.
        torch_dtype (torch.dtype): The torch dtype to use.
        max_new_tokens (int): The maximum number of new tokens to generate.
        backend (str): The lora backend to use.
        disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
        disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to False.
        mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
        test_tag (str, optional): The tag to use for the test. Defaults to "".
    """
    base_path = model_case.base
    # Assign adaptors to prompts round-robin.
    adaptors = [
        model_case.adaptors[idx % len(model_case.adaptors)]
        for idx in range(len(prompts))
    ]
    adaptor_names = [adaptor.name for adaptor in adaptors]
    print(
        f"\n========== Testing {test_tag} on base '{model_case.base}' with backend={backend}, dtype={torch_dtype} --- "
        f"Using prompts {[p[:50] for p in prompts]} with adaptors: {adaptor_names} ---"
    )
    # SRT with LoRA enabled.
    with SRTRunner(
        base_path,
        torch_dtype=torch_dtype,
        model_type="generation",
        tp_size=model_case.tp_size,
        lora_paths=[
            adaptor.name for adaptor in model_case.adaptors if adaptor.name is not None
        ],
        max_loras_per_batch=model_case.max_loras_per_batch,
        lora_backend=backend,
        disable_cuda_graph=disable_cuda_graph,
        disable_radix_cache=disable_radix_cache,
        mem_fraction_static=mem_fraction_static,
    ) as srt_runner:
        srt_outputs = srt_runner.forward(
            prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
        )
    # SRT baseline without LoRA.
    with SRTRunner(
        base_path,
        torch_dtype=torch_dtype,
        model_type="generation",
        tp_size=model_case.tp_size,
        mem_fraction_static=mem_fraction_static,
    ) as srt_runner:
        srt_no_lora_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
    # HF reference, both with and without LoRA, in a single session.
    with HFRunner(
        base_path, torch_dtype=torch_dtype, model_type="generation"
    ) as hf_runner:
        hf_outputs = hf_runner.forward(
            prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
        )
        hf_no_lora_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
    for i in range(len(prompts)):
        adaptor = adaptors[i]
        # Use individual adaptor tolerances if set, otherwise use model defaults
        prefill_tol = (
            adaptor.prefill_tolerance
            if adaptor.prefill_tolerance is not None
            else model_case.prefill_tolerance
        )
        decode_tol = (
            adaptor.decode_tolerance
            if adaptor.decode_tolerance is not None
            else model_case.decode_tolerance
        )
        rouge_tol = (
            adaptor.rouge_l_tolerance
            if adaptor.rouge_l_tolerance is not None
            else model_case.rouge_l_tolerance
        )
        # Compare prefill stage logprobs (HF vs SRTRunner with LoRA)
        hf_prefill = torch.tensor(hf_outputs.top_input_logprobs[i])
        srt_prefill = torch.tensor(srt_outputs.top_input_logprobs[i])
        max_prefill_diff = torch.max(torch.abs(hf_prefill - srt_prefill))
        print("Max prefill diff (HF vs SRT):", max_prefill_diff)
        # Compare decode stage logprobs
        hf_decode = torch.tensor(hf_outputs.top_output_logprobs[i])
        srt_decode = torch.tensor(srt_outputs.top_output_logprobs[i])
        max_decode_diff = torch.max(torch.abs(hf_decode - srt_decode))
        print("Max decode diff (HF vs SRT):", max_decode_diff)
        srt_output_str = srt_outputs.output_strs[i].strip()
        hf_output_str = hf_outputs.output_strs[i].strip()
        rouge_score = calculate_rouge_l([srt_output_str], [hf_output_str])[0]
        print("ROUGE-L score:", rouge_score)
        print("SRT output:", srt_output_str)
        print("HF output:", hf_output_str)
        # Additional: compare prefill outputs between base model (no LoRA) and LoRA model for reference
        hf_no_lora_prefill = torch.tensor(hf_no_lora_outputs.top_input_logprobs[i])
        srt_no_lora_prefill = torch.tensor(srt_no_lora_outputs.top_input_logprobs[i])
        print(
            "Max diff (SRT base vs SRT LoRA prefill):",
            torch.max(torch.abs(srt_no_lora_prefill - srt_prefill)),
        )
        print(
            "Max diff (HF base vs HF LoRA prefill):",
            torch.max(torch.abs(hf_no_lora_prefill - hf_prefill)),
        )
        # Exact logprob comparisons are only applied to short sequences —
        # presumably longer ones accumulate too much numeric noise; TODO confirm.
        # NOTE: failure messages now report the prompt that actually failed
        # (prompts[i]) instead of always reporting prompts[0].
        if hf_prefill.shape[0] <= 100:
            assert torch.all(torch.abs(hf_prefill - srt_prefill) < prefill_tol), (
                f"Prefill logprobs mismatch for base '{base_path}', adaptor '{adaptor_names}', "
                f"backend '{backend}', prompt: '{prompts[i][:50]}...'"
            )
        if hf_decode.shape[0] <= 100:
            assert torch.all(torch.abs(hf_decode - srt_decode) < decode_tol), (
                f"Decode logprobs mismatch for base '{base_path}', adaptor '{adaptor_names}', "
                f"backend '{backend}', prompt: '{prompts[i][:50]}...'"
            )
        if rouge_score < rouge_tol:
            raise AssertionError(
                f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
                f"for base '{base_path}', adaptor '{adaptor_names}', backend '{backend}', prompt: '{prompts[i][:50]}...'"
            )
def run_lora_test_by_batch(
    prompts: List[str],
    model_case: LoRAModelCase,
    torch_dtype: torch.dtype,
    max_new_tokens: int,
    backend: str,
    disable_cuda_graph: bool = False,
    disable_radix_cache: bool = False,
    mem_fraction_static: float = 0.88,
    test_tag: str = "",
):
    """
    Run lora tests as a batch.
    For prompt0, prompt1, ..., promptN,
    we will use adaptor0, adaptor1, ..., adaptorN included in model case,
    We will then compare the outputs of HF and SRT with LoRA.
    If number of prompts is larger than number of adaptors,
    the prompt i will use adaptor i % (number of adaptors).
    Args:
        prompts (List[str]): The batch of prompts to test.
        model_case (LoRAModelCase): The model case to test.
        torch_dtype (torch.dtype): The torch dtype to use.
        max_new_tokens (int): The maximum number of new tokens to generate.
        backend (str): The lora backend to use.
        disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
        disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to False.
        mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
        test_tag (str, optional): The tag to use for the test. Defaults to "".
    """
    base_path = model_case.base
    # Assign adaptors to prompts round-robin.
    adaptors = [
        model_case.adaptors[idx % len(model_case.adaptors)]
        for idx in range(len(prompts))
    ]
    adaptor_names = [adaptor.name for adaptor in adaptors]
    print(
        f"\n========== Testing {test_tag} on base '{model_case.base}' with backend={backend}, dtype={torch_dtype} --- "
        f"Using prompts {[p[:50] for p in prompts]} with adaptors: {adaptor_names} ---"
    )
    # SRT with LoRA enabled.
    with SRTRunner(
        base_path,
        torch_dtype=torch_dtype,
        model_type="generation",
        tp_size=model_case.tp_size,
        lora_paths=[
            adaptor.name for adaptor in model_case.adaptors if adaptor.name is not None
        ],
        max_loras_per_batch=model_case.max_loras_per_batch,
        lora_backend=backend,
        disable_cuda_graph=disable_cuda_graph,
        disable_radix_cache=disable_radix_cache,
        mem_fraction_static=mem_fraction_static,
    ) as srt_runner:
        srt_outputs = srt_runner.batch_forward(
            prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
        )
    # SRT baseline without LoRA.
    with SRTRunner(
        base_path,
        torch_dtype=torch_dtype,
        model_type="generation",
        tp_size=model_case.tp_size,
        mem_fraction_static=mem_fraction_static,
    ) as srt_runner:
        srt_no_lora_outputs = srt_runner.batch_forward(
            prompts, max_new_tokens=max_new_tokens
        )
    # HF reference, both with and without LoRA, in a single session (loading
    # the HF model once instead of twice; matches run_lora_test_one_by_one).
    with HFRunner(
        base_path, torch_dtype=torch_dtype, model_type="generation"
    ) as hf_runner:
        hf_outputs = hf_runner.forward(
            prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
        )
        hf_no_lora_outputs = hf_runner.forward(
            prompts,
            max_new_tokens=max_new_tokens,
        )
    for i in range(len(prompts)):
        srt_output_str = srt_outputs.output_strs[i].strip()
        hf_output_str = hf_outputs.output_strs[i].strip()
        rouge_score = calculate_rouge_l([srt_output_str], [hf_output_str])[0]
        print("ROUGE-L score:", rouge_score)
        print("SRT output:", srt_output_str)
        print("HF output:", hf_output_str)
        print("SRT no lora output:", srt_no_lora_outputs.output_strs[i].strip())
        print("HF no lora output:", hf_no_lora_outputs.output_strs[i].strip())
        # The batch path requires exact string equality (only space-stripped),
        # both with and without LoRA.
        assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i].strip(
            " "
        ), (
            srt_outputs.output_strs[i].strip(" "),
            hf_outputs.output_strs[i].strip(" "),
        )
        assert srt_no_lora_outputs.output_strs[i].strip(
            " "
        ) == hf_no_lora_outputs.output_strs[i].strip(" "), (
            srt_no_lora_outputs.output_strs[i].strip(" "),
            hf_no_lora_outputs.output_strs[i].strip(" "),
        )

View File

@@ -0,0 +1,52 @@
"""
used for debug using tensor comparison
dump {name: tensor} into "log_hf.jsonl" and "log_srt.jsonl"
use the same name for two tensors that supposed to be close
recommend name like: "layer 2 after mlp"
"""
import json
import sys
import torch
# With an extra "base" CLI argument, compare the base-model (no-LoRA) dumps;
# otherwise compare the default LoRA dumps.
if len(sys.argv) > 1:
    assert sys.argv[1] == "base"
    hf_log = "base_log_hf.jsonl"
    srt_log = "base_log_srt.jsonl"
else:
    hf_log = "log_hf.jsonl"
    srt_log = "log_srt.jsonl"
def load_data(filepath):
    """Load a JSONL dump of ``{name: nested_list}`` records into ``{name: torch.Tensor}``.

    Each line is an independent JSON object; when the same name appears on
    several lines, the last occurrence wins.
    """
    tensors = {}
    with open(filepath, "r") as f:
        # Stream line-by-line instead of materializing f.readlines().
        for line in f:
            data = json.loads(line)
            for k, v in data.items():
                tensors[k] = torch.tensor(v)
    return tensors
# Load both dumps; the keys present in the SRT dump drive the comparison below.
hf_tensors = load_data(hf_log)
srt_tensors = load_data(srt_log)
def get_diff(t1, t2):
    """Return (L2 distance, max absolute elementwise difference) of two tensors.

    ``t1`` is reshaped to ``t2``'s shape first, so the two logs may record the
    same tensor with different layouts as long as the element counts match.

    Args:
        t1: Reference tensor (e.g. from the HF log).
        t2: Tensor to compare against (e.g. from the SRT log).

    Returns:
        Tuple of 0-dim tensors ``(l2_dis, max_diff)``.
    """
    t1 = t1.reshape(t2.shape)
    # The original reshaped t1 a second time when computing max_diff; once is
    # enough since t1 is already reshaped above.
    max_diff = (t1 - t2).abs().max()
    l2_dis = torch.dist(t1, t2, p=2)
    return l2_dis, max_diff
# Compare every tensor recorded by SRT against its HF counterpart.
# NOTE(review): assumes every key in the SRT log also exists in the HF log;
# a missing key raises KeyError.
for k, _ in srt_tensors.items():
    l2_dis, max_diff = get_diff(hf_tensors[k], srt_tensors[k])
    print(f"{k} {l2_dis=} {max_diff=}")
    # Ad-hoc deep dives for specific layers under investigation.
    if k == "layer 1 attn":
        print(hf_tensors[k])
        print(srt_tensors[k])
    if k == "layer 0 prefill k":
        print(srt_tensors[k].shape)
        print(hf_tensors[k].shape)

View File

@@ -0,0 +1,80 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import unittest
import torch
from transformers import AutoProcessor
from sglang.srt.utils import load_image
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities
TEXTS = "two Subway Series sandwiches with meats, cheese, lettuce, tomatoes, and onions on a black background, accompanied by the Subway Series logo, highlighting a new sandwich series."
IMAGES = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"
MODELS = [
("openai/clip-vit-large-patch14-336", 1e-5),
]
TORCH_DTYPES = [torch.float16]
class TestClipModels(unittest.TestCase):
    """Compares CLIP text/image embeddings from SRT against the HF reference."""

    @classmethod
    def setUpClass(cls):
        # Runner subprocesses require the spawn start method.
        mp.set_start_method("spawn", force=True)

    def assert_close_embeddings(self, model, prefill_tolerance, torch_dtype):
        """Embed TEXTS and IMAGES with both backends and compare similarities."""
        with HFRunner(
            model,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as hf_runner:
            hf_text_embeds = hf_runner.forward(prompts=TEXTS)
            hf_image_embeds = hf_runner.forward(image_data=IMAGES)

        with SRTRunner(
            model,
            tp_size=1,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as srt_runner:
            text_embeds = srt_runner.forward(prompts=TEXTS)
            # SRT needs a (dummy) prompt alongside image-only inputs.
            image_embeds = srt_runner.forward(prompts="padding", image_data=IMAGES)

        # Cosine similarity against the HF reference, keyed by modality.
        sims = {
            "text": get_similarities(
                text_embeds.embed_logits[0], hf_text_embeds.embed_logits[0]
            ),
            "image": get_similarities(
                image_embeds.embed_logits[0], hf_image_embeds.embed_logits[0]
            ),
        }
        for label, similarity in sims.items():
            print(f"{label} similarity diff", abs(similarity - 1))
        for similarity in sims.values():
            assert torch.all(
                abs(similarity - 1) < prefill_tolerance
            ), "embeddings are not all close"

    def test_accuracy(self):
        for model, prefill_tolerance in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_embeddings(model, prefill_tolerance, torch_dtype)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,46 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestCompressedTensorsLlama3FP8(CustomTestCase):
    """GSM8K accuracy smoke test for an FP8 compressed-tensors Llama-3.1."""

    @classmethod
    def setUpClass(cls):
        # One shared server for the whole class.
        cls.model = "RedHatAI/Meta-Llama-3.1-8B-FP8"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run few-shot GSM8K against the shared server and check accuracy."""
        server_port = int(self.base_url.split(":")[-1])
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(eval_args)
        print(f"{metrics=}")
        self.assertGreaterEqual(metrics["accuracy"], 0.45)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,91 @@
import multiprocessing as mp
import random
import unittest
import torch
from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import TEST_RERANK_QUERY_DOCS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, is_in_ci
MODELS = [
("cross-encoder/ms-marco-MiniLM-L6-v2", 1, 1e-2),
("BAAI/bge-reranker-v2-m3", 1, 1e-2),
]
ATTENTION_BACKEND = ["torch_native", "triton"]
TORCH_DTYPES = [torch.float32]
class TestCrossEncoderModels(CustomTestCase):
    """Compares SRT cross-encoder rerank scores against the HF reference."""

    @classmethod
    def setUpClass(cls):
        # Runner subprocesses require the spawn start method.
        mp.set_start_method("spawn", force=True)

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        score_tolerance,
        attention_backend,
    ) -> None:
        """Score `prompts` with both backends and require per-pair agreement."""
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="cross_encoder",
        ) as hf_runner:
            hf_scores = hf_runner.forward(prompts).scores

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="cross_encoder",
            attention_backend=attention_backend,
            chunked_prefill_size=-1,
            disable_radix_cache=True,
        ) as srt_runner:
            srt_scores = srt_runner.forward(prompts).scores

        for hf_score, srt_score in zip(hf_scores, srt_scores):
            assert (
                abs(hf_score - srt_score) < score_tolerance
            ), "cross encoder scores are not all close"

    def preprocess_prompts(self, prompt):
        """Expand a {"query", "documents"} record into [query, document] pairs."""
        query = prompt["query"]
        return [[query, document] for document in prompt["documents"]]

    def test_prefill_logits(self):
        # In CI, test a single randomly chosen model to save time.
        models_to_test = [random.choice(MODELS)] if is_in_ci() else MODELS
        for model, tp_size, prefill_tolerance in models_to_test:
            for attention_backend in ATTENTION_BACKEND:
                for query_docs in TEST_RERANK_QUERY_DOCS:
                    prompts = self.preprocess_prompts(query_docs)
                    for torch_dtype in TORCH_DTYPES:
                        self.assert_close_prefill_logits(
                            prompts,
                            model,
                            tp_size,
                            torch_dtype,
                            prefill_tolerance,
                            attention_backend,
                        )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,34 @@
import unittest
from sglang.test.test_utils import CustomTestCase, is_in_ci, run_bench_one_batch
class TestDummyGrok1(CustomTestCase):
    """Smoke-benchmark a dummy-weight, 2-layer Grok-1 via bench_one_batch."""

    def test_dummy_grok_1(self):
        # Dummy weights and 2 layers keep the run cheap; accuracy is meaningless,
        # so only throughput is checked.
        bench_args = [
            "--model",
            "/dummy-grok",
            "--tokenizer-path",
            "Xenova/grok-1-tokenizer",
            "--batch-size",
            "2",
            "--tp",
            "2",
            "--quantization",
            "fp8",
            "--load-format",
            "dummy",
            "--json-model-override-args",
            '{"num_hidden_layers": 2}',
        ]
        _, output_throughput, _ = run_bench_one_batch(None, bench_args)

        if is_in_ci():
            self.assertGreater(output_throughput, 0)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,111 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import random
import unittest
import torch
from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
MODELS = [
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
("marco/mcdse-2b-v1", 1, 1e-5),
("Qwen/Qwen3-Embedding-8B", 1, 1e-5),
# Temporarily disable before this model is fixed
# ("jason9693/Qwen2.5-1.5B-apeach", 1, 1e-5),
]
TORCH_DTYPES = [torch.float16]
class TestEmbeddingModels(CustomTestCase):
    """Compares SRT embedding outputs against HF for several embedding models."""

    @classmethod
    def setUpClass(cls):
        # Runner subprocesses require the spawn start method.
        mp.set_start_method("spawn", force=True)

    def _truncate_prompts(self, prompts, model_path):
        """Clip each prompt (token-wise) to the model's positional limit."""
        config = AutoConfig.from_pretrained(model_path)
        max_length = getattr(config, "max_position_embeddings", 2048)
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        truncated_prompts = []
        for prompt in prompts:
            tokens = tokenizer(prompt, return_tensors="pt", truncation=False)
            if len(tokens.input_ids[0]) > max_length:
                # Decode back to text after dropping tokens past the limit.
                truncated_text = tokenizer.decode(
                    tokens.input_ids[0][: max_length - 1], skip_special_tokens=True
                )
                truncated_prompts.append(truncated_text)
            else:
                truncated_prompts.append(prompt)
        return truncated_prompts

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        prefill_tolerance,
    ) -> None:
        """Embed prompts with HF and SRT; require cosine similarity near 1."""
        truncated_prompts = self._truncate_prompts(prompts, model_path)
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(truncated_prompts)
        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as srt_runner:
            srt_outputs = srt_runner.forward(truncated_prompts)
        for i in range(len(prompts)):
            hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
            srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
            similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
            print("similarity diff", abs(similarity - 1))
            # NOTE(review): the length gate below checks the ORIGINAL prompt,
            # while the embeddings were computed on the truncated one; the
            # encoder-embedding test checks the truncated prompt instead —
            # confirm which is intended.
            if len(prompts[i]) <= 1000:
                assert torch.all(
                    abs(similarity - 1) < prefill_tolerance
                ), "embeddings are not all close"

    def test_prefill_logits(self):
        # In CI, test a single randomly chosen model to save time.
        models_to_test = MODELS
        if is_in_ci():
            models_to_test = [random.choice(MODELS)]
        for model, tp_size, prefill_tolerance in models_to_test:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_prefill_logits(
                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
                )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,162 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# python -m unittest test_encoder_embedding_models.TestEncoderEmbeddingModels.test_prefill_logits
import multiprocessing as mp
import random
import time
import unittest
import torch
from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]
ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"]
BATCH_SIZE = [1, 2]
TORCH_DTYPES = [torch.float32, torch.float16]
sgl_to_st_ratio = []
class TestEncoderEmbeddingModels(CustomTestCase):
    """Compares encoder-style embedding models between SRT and HF.

    Also records a rough SRT/HF wall-clock ratio per configuration in the
    module-level ``sgl_to_st_ratio`` list and prints a summary at the end.
    """

    @classmethod
    def setUpClass(cls):
        # Runner subprocesses require the spawn start method.
        mp.set_start_method("spawn", force=True)

    def _truncate_prompts(self, prompts, model_path):
        """Clip each prompt (token-wise) below the model's positional limit."""
        config = AutoConfig.from_pretrained(model_path)
        # Keep a 20-token safety margin under max_position_embeddings.
        max_length = getattr(config, "max_position_embeddings", 512) - 20
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        truncated_prompts = []
        for prompt in prompts:
            tokens = tokenizer(prompt, return_tensors="pt", truncation=False)
            if len(tokens.input_ids[0]) > max_length:
                # Decode back to text after dropping tokens past the limit.
                truncated_text = tokenizer.decode(
                    tokens.input_ids[0][: max_length - 1], skip_special_tokens=True
                )
                truncated_prompts.append(truncated_text)
            else:
                truncated_prompts.append(prompt)
        return truncated_prompts

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        prefill_tolerance,
        attention_backend,
        batch_size,
    ) -> None:
        """Embed the (repeated) prompts with both backends, compare, and time."""
        truncated_prompts = self._truncate_prompts(prompts, model_path)
        truncated_prompts = truncated_prompts * batch_size
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as hf_runner:
            # warm up
            hf_outputs = hf_runner.forward(truncated_prompts)
            st_start_time = time.perf_counter()
            hf_outputs = hf_runner.forward(truncated_prompts)
            st_end_time = time.perf_counter()
        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="embedding",
            attention_backend=attention_backend,
            chunked_prefill_size=-1,
            disable_radix_cache=True,
        ) as srt_runner:
            # warm up
            srt_outputs = srt_runner.forward(truncated_prompts)
            sgl_start_time = time.perf_counter()
            srt_outputs = srt_runner.forward(truncated_prompts)
            sgl_end_time = time.perf_counter()
        transformer_time = st_end_time - st_start_time
        sgl_time = sgl_end_time - sgl_start_time
        sgl_to_st_ratio.append(sgl_time / transformer_time)
        for i in range(len(truncated_prompts)):
            hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
            srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
            similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
            # If something is wrong, uncomment this to observe similarity.
            # print("similarity diff", abs(similarity - 1))
            if len(truncated_prompts[i]) <= 1000:
                assert torch.all(
                    abs(similarity - 1) < prefill_tolerance
                ), "embeddings are not all close"

    def test_prefill_logits(self):
        # In CI, test a single randomly chosen model to save time.
        models_to_test = MODELS
        if is_in_ci():
            models_to_test = [random.choice(MODELS)]
        for model, tp_size, prefill_tolerance in models_to_test:
            for attention_backend in ATTENTION_BACKEND:
                for batch_size in BATCH_SIZE:
                    for torch_dtype in TORCH_DTYPES:
                        # NOTE: FlashInfer currently has limitations with head_dim = 32 or
                        # other dimensions.
                        # The FlashInfer head_dim limitation itself is tracked here:
                        # https://github.com/flashinfer-ai/flashinfer/issues/1048
                        #
                        # Flashinfer does not support torch.float32 for dtype_q, so skip it
                        if attention_backend == "flashinfer":
                            if (
                                model == "BAAI/bge-small-en"
                                or torch_dtype == torch.float32
                            ):
                                continue
                        self.assert_close_prefill_logits(
                            DEFAULT_PROMPTS,
                            model,
                            tp_size,
                            torch_dtype,
                            prefill_tolerance,
                            attention_backend,
                            batch_size,
                        )
        # Print a coarse speed summary. (Fixed "bacth" -> "batch" typo.)
        # NOTE(review): sgl_to_st_ratio gains one entry per
        # (model, backend, batch_size, dtype) combination, but this loop only
        # reads the first len(BATCH_SIZE) entries and labels them with
        # BATCH_SIZE[i] * 5 (assuming 5 default prompts) — confirm intent.
        for i in range(len(BATCH_SIZE)):
            print(
                "batch size: ",
                BATCH_SIZE[i] * 5,
                "sgl_time/st_time",
                round(sgl_to_st_ratio[i], 3),
            )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,180 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Usage:
To test a specific model locally:
1. Add it to ALL_MODELS, for example, `ModelCase("Qwen/Qwen2-1.5B")`
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels`
"""
import dataclasses
import multiprocessing as mp
import os
import random
import unittest
from typing import List
import torch
from sglang.test.runners import (
DEFAULT_PROMPTS,
HFRunner,
SRTRunner,
check_close_model_outputs,
)
from sglang.test.test_utils import CustomTestCase, is_in_ci
@dataclasses.dataclass
class ModelCase:
    """Per-model configuration for the HF-vs-SRT generation comparison."""

    # HF Hub id (or local path) of the model under test.
    model_path: str
    # Tensor-parallel world size used when launching SRT.
    tp_size: int = 1
    # Max allowed prefill logprob deviation between HF and SRT.
    prefill_tolerance: float = 5e-2
    decode_tolerance: float = 6e-2  # Increased to fix numerical error in issue #8614.
    # Max allowed ROUGE-L distance between decoded outputs.
    rouge_l_tolerance: float = 1
    # Drop prompts of >= 1000 chars for models without a long context.
    skip_long_prompt: bool = False
    trust_remote_code: bool = False
# Popular models that run on the CI
CI_MODELS = [
    ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
    ModelCase("google/gemma-2-2b"),
]
# the complete set of models to test sglang's generation model
ALL_MODELS = [
    *CI_MODELS,
    ModelCase("Qwen/Qwen2-1.5B"),
    ModelCase("Qwen/Qwen2.5-14B-Instruct"),
    ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
    ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
    ModelCase(
        "THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
    ),
    ModelCase("openai-community/gpt2"),
    ModelCase("microsoft/phi-1_5", trust_remote_code=True),
    ModelCase("adept/persimmon-8b-chat"),
    ModelCase("inclusionAI/Ling-lite", trust_remote_code=True),
    ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
    ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
    ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
    ModelCase(
        "microsoft/Phi-3.5-MoE-instruct",
        tp_size=2,
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
    ModelCase(
        "nvidia/Llama-3_3-Nemotron-Super-49B-v1_5",
        tp_size=2,
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
    ModelCase(
        "nvidia/Llama-3_1-Nemotron-Ultra-253B-v1",
        tp_size=8,
        trust_remote_code=True,
        skip_long_prompt=True,
    ),
]
TORCH_DTYPES = [torch.float16]
class TestGenerationModels(CustomTestCase):
    """End-to-end logits/output-string comparison between HF and SRT."""

    @classmethod
    def setUpClass(cls):
        # Runner subprocesses require the spawn start method.
        mp.set_start_method("spawn", force=True)

    def assert_close_logits_and_output_strs(
        self,
        prompts: List[str],
        model_case: ModelCase,
        torch_dtype: torch.dtype,
    ) -> None:
        """Generate with both backends and compare logits and decoded text.

        Tolerances are taken from ``model_case``; see ModelCase for details.
        """
        model_path = model_case.model_path
        max_new_tokens = 32
        # (Removed a dead tuple-unpack of the three tolerances: the locals were
        # never used — check_close_model_outputs reads model_case directly.)
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="generation",
            trust_remote_code=model_case.trust_remote_code,
        ) as hf_runner:
            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=torch_dtype,
            model_type="generation",
            trust_remote_code=model_case.trust_remote_code,
        ) as srt_runner:
            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
        check_close_model_outputs(
            hf_outputs=hf_outputs,
            srt_outputs=srt_outputs,
            prefill_tolerance=model_case.prefill_tolerance,
            decode_tolerance=model_case.decode_tolerance,
            rouge_l_tolerance=model_case.rouge_l_tolerance,
            debug_text=f"model_path={model_path} prompts={prompts}",
        )

    @unittest.skipIf(not is_in_ci(), "Local test should run all models")
    def test_ci_models(self):
        """Run the small CI model set over all dtypes."""
        for model_case in CI_MODELS:
            for torch_dtype in TORCH_DTYPES:
                prompts = DEFAULT_PROMPTS
                # Skip long prompts for models that do not have a long context
                if model_case.skip_long_prompt:
                    prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
                # Assert the logits and output strs are close
                self.assert_close_logits_and_output_strs(
                    prompts, model_case, torch_dtype
                )

    @unittest.skipIf(is_in_ci(), "CI only runs selected models for simplicity")
    def test_all_models(self):
        """Run every ALL_MODELS entry; honor the ONLY_RUN env filter."""
        for model_case in ALL_MODELS:
            for torch_dtype in TORCH_DTYPES:
                if (
                    "ONLY_RUN" in os.environ
                    and os.environ["ONLY_RUN"] != model_case.model_path
                ):
                    continue
                # Skip long prompts for models that do not have a long context
                prompts = DEFAULT_PROMPTS
                if model_case.skip_long_prompt:
                    prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
                # Assert the logits and output strs are close
                self.assert_close_logits_and_output_strs(
                    prompts, model_case, torch_dtype
                )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,85 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import unittest
import torch
from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase, get_similarities
TEXTS = "two Subway Series sandwiches with meats, cheese, lettuce, tomatoes, and onions on a black background, accompanied by the Subway Series logo, highlighting a new sandwich series."
IMAGES = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"
MODELS = [
("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct", 1e-3),
]
TORCH_DTYPES = [torch.float16]
class TestQmeQwenModels(CustomTestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
def assert_close_embeddings(self, model, prefill_tolerance, torch_dtype):
prompts_no_image = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{TEXTS}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
prompts_with_image = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
with HFRunner(
model,
torch_dtype=torch_dtype,
model_type="embedding",
) as hf_runner:
hf_text_embeddings = hf_runner.forward(prompts=[prompts_no_image])
hf_image_embeddings = hf_runner.forward(
prompts=[prompts_with_image], image_data=[IMAGES]
)
with SRTRunner(
model,
tp_size=1,
torch_dtype=torch_dtype,
model_type="embedding",
) as srt_runner:
srt_text_embeddings = srt_runner.forward(prompts=prompts_no_image)
srt_image_embeddings = srt_runner.forward(
prompts=prompts_with_image, image_data=IMAGES
)
similarity = get_similarities(
hf_text_embeddings.embed_logits[0], srt_text_embeddings.embed_logits[0]
)
print("texts similarity diff", abs(similarity - 1))
assert torch.all(
abs(similarity - 1) < prefill_tolerance
), "embeddings are not all close"
similarity = get_similarities(
hf_image_embeddings.embed_logits[0], srt_image_embeddings.embed_logits[0]
)
print("images similarity diff", abs(similarity - 1))
assert torch.all(
abs(similarity - 1) < prefill_tolerance
), "embeddings are not all close"
def test_accuracy(self):
for model, prefill_tolerance in MODELS:
for torch_dtype in TORCH_DTYPES:
self.assert_close_embeddings(model, prefill_tolerance, torch_dtype)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,53 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestGrok(CustomTestCase):
    """GSM8K throughput smoke test for a dummy-weight, 2-layer Grok-1 server."""

    @classmethod
    def setUpClass(cls):
        cls.model = "lmzheng/grok-1"
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Dummy weights + 2 layers keep the launch cheap.
        launch_args = [
            "--load-format",
            "dummy",
            "--json-model-override-args",
            '{"num_hidden_layers": 2}',
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=64,
            max_new_tokens=256,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(eval_args)
        print(f"{metrics=}")
        # It is dummy weights so we only assert the output throughput instead of accuracy.
        self.assertGreater(metrics["output_throughput"], 1000)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,74 @@
import random
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
# Each entry: model id, expected GSM8K accuracy floor, and TP world size.
MODELS = [
    SimpleNamespace(
        model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        accuracy=0.9,
        tp_size=4,
    ),
]
class TestLlama4(CustomTestCase):
    """GSM8K accuracy tests for Llama-4 models (one server launch per model)."""

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST

    def test_gsm8k(self):
        """Launch a server per model, run few-shot GSM8K, always clean up."""
        for model in MODELS:
            # Bug fix: initialize before the try block so the finally clause
            # cannot hit an unbound name (NameError masking the real error)
            # when popen_launch_server itself raises.
            process = None
            try:
                process = popen_launch_server(
                    model.model,
                    self.base_url,
                    timeout=3 * DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        "--chat-template",
                        "llama-4",
                        "--tp-size",
                        str(model.tp_size),
                        "--mem-fraction-static",
                        "0.8",
                        "--context-length",
                        "8192",
                    ],
                )
                args = SimpleNamespace(
                    num_shots=5,
                    data_path=None,
                    num_questions=200,
                    max_new_tokens=512,
                    parallel=128,
                    host="http://127.0.0.1",
                    port=int(self.base_url.split(":")[-1]),
                )
                metrics = run_eval(args)
                print(f"{metrics=}")
                self.assertGreaterEqual(metrics["accuracy"], model.accuracy)
            except Exception as e:
                print(f"Error testing {model.model}: {e}")
                self.fail(f"Test failed for {model.model}: {e}")
            finally:
                # Ensure process cleanup happens regardless of success/failure
                if process is not None and process.poll() is None:
                    print(f"Cleaning up process {process.pid}")
                    try:
                        kill_process_tree(process.pid)
                    except Exception as e:
                        print(f"Error killing process: {e}")


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,58 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestMiMoMTP(CustomTestCase):
    """GSM8K accuracy test for MiMo-7B-RL with EAGLE speculative decoding."""

    @classmethod
    def setUpClass(cls):
        cls.model = "XiaomiMiMo/MiMo-7B-RL"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                # Minimal EAGLE setup: 1 draft step, top-1, 2 draft tokens.
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-num-steps",
                "1",
                "--speculative-eagle-topk",
                "1",
                "--speculative-num-draft-tokens",
                "2",
                "--mem-fraction-static",
                "0.5",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run few-shot GSM8K against the speculative-decoding server."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.7)


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,77 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestQwen2(CustomTestCase):
    """GSM8K accuracy test for Qwen2-7B-Instruct.

    The FP8 variant below reuses this class's server launch and evaluation by
    overriding only the ``model`` attribute — the two original classes were
    duplicated line-for-line except for the checkpoint name.
    """

    # Checkpoint to launch; overridden by subclasses.
    model = "Qwen/Qwen2-7B-Instruct"
    # Minimum acceptable GSM8K accuracy.
    accuracy_threshold = 0.78

    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run few-shot GSM8K against the shared server and check accuracy."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], self.accuracy_threshold)


class TestQwen2FP8(TestQwen2):
    """Same GSM8K check against the FP8-quantized Qwen2-7B-Instruct."""

    model = "neuralmagic/Qwen2-7B-Instruct-FP8"


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,92 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import unittest
import torch
from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase
MODELS = [
("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2),
("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 4e-2),
]
TORCH_DTYPES = [torch.float16]
# PROMPT = "Jane has 12 apples. She gives 4 apples to her friend Mark, then buys 1 more apple, and finally splits all her apples equally among herself and her 2 siblings. How many apples does each person get?"
# RESPONSE1 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among herself and her 2 siblings (3 people in total). 9 ÷ 3 = 3 apples each. Each person gets 3 apples."
# RESPONSE2 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among her 2 siblings (2 people in total). 9 ÷ 2 = 4.5 apples each. Each person gets 4 apples."
# RESPONSE1 is deliberately wrong (a sigmoid outputs values in (0, 1)), so a
# good reward model should score RESPONSE2 higher.
PROMPT = (
    "What is the range of the numeric output of a sigmoid node in a neural network?"
)
RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
CONVS = [
    [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE1}],
    [{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE2}],
]
class TestRewardModels(CustomTestCase):
    """Checks SRT reward scores against HF for conversation-scoring models."""

    @classmethod
    def setUpClass(cls):
        # Runner subprocesses require the spawn start method.
        mp.set_start_method("spawn", force=True)

    def assert_close_reward_scores(
        self,
        convs,
        model_path,
        tp_size,
        torch_dtype,
        tolerance,
    ) -> None:
        """Score `convs` with both backends and require elementwise closeness."""
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="reward",
        ) as hf_runner:
            hf_outputs = hf_runner.forward(convs)
        with SRTRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="reward",
        ) as srt_runner:
            # SRT consumes flat prompts, so render the chat template first.
            prompts = srt_runner.tokenizer.apply_chat_template(convs, tokenize=False)
            srt_outputs = srt_runner.forward(prompts)
        hf_scores = torch.tensor(hf_outputs.scores)
        srt_scores = torch.tensor(srt_outputs.scores)
        print(f"{hf_scores=}")
        print(f"{srt_scores=}")
        assert torch.all(
            abs(hf_scores - srt_scores) < tolerance
        ), "reward scores are not all close"

    def test_reward_scores(self):
        # NOTE(review): tp_size is accepted but never forwarded to SRTRunner —
        # confirm whether multi-GPU reward tests were intended.
        for model, tp_size, tolerance in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_reward_scores(
                    CONVS, model, tp_size, torch_dtype, tolerance
                )


if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,181 @@
import dataclasses
import multiprocessing as mp
import unittest
from types import SimpleNamespace
from typing import List
import torch
from sglang.srt.utils import kill_process_tree
from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner, check_close_model_outputs
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
class TestTransformersFallbackEndpoint(CustomTestCase):
    """Accuracy checks for a server using the `--model-impl transformers` fallback."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--model-impl", "transformers"],
        )
        # Accuracy floors; subclasses may override these.
        cls.mmlu_lower_bound = 0.65
        cls.gsm8k_lower_bound = 0.65

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """Run a small MMLU eval against the server and check the score."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        # Imported locally: this run_eval differs from few_shot_gsm8k.run_eval.
        from sglang.test.run_eval import run_eval

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], self.mmlu_lower_bound)

    def test_gsm8k(self):
        """Run few-shot GSM8K against the server and check accuracy."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        # Imported locally to avoid clashing with sglang.test.run_eval.run_eval.
        from sglang.test.few_shot_gsm8k import run_eval

        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], self.gsm8k_lower_bound)
class TestTransformersFallbackTorchAO(TestTransformersFallbackEndpoint):
    """Same accuracy gates as the base class, with int4 torchao quantization on."""

    @classmethod
    def setUpClass(cls):
        cls.mmlu_lower_bound = 0.65
        cls.gsm8k_lower_bound = 0.65
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Enable weight-only int4 quantization on top of the transformers impl.
        server_args = [
            "--model-impl",
            "transformers",
            "--torchao-config",
            "int4wo-128",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_args,
        )
@dataclasses.dataclass
class ModelCase:
    """Configuration for one transformers-vs-native comparison run.

    Tolerances bound the allowed divergence between the two implementations:
    ``prefill_tolerance``/``decode_tolerance`` cap the elementwise logprob gap,
    and ``rouge_l_tolerance`` caps the ROUGE-L distance of generated text.
    """

    model_path: str
    tp_size: int = 1
    prefill_tolerance: float = 5e-2
    decode_tolerance: float = 5e-2
    rouge_l_tolerance: float = 1.0
    # Skip prompts >= 1000 chars for short-context models.
    skip_long_prompt: bool = False
    trust_remote_code: bool = False
    # torchao quantization spec (e.g. "int4wo-128"); None disables quantization.
    torchao_config: Optional[str] = None
    torch_dtype: torch.dtype = torch.float16
# Popular models that run on the CI
CI_MODELS = [
    ModelCase(DEFAULT_MODEL_NAME_FOR_TEST),
]
# Extra cases exercised only outside CI (see test_others).
ALL_OTHER_MODELS = [
    ModelCase(DEFAULT_MODEL_NAME_FOR_TEST, tp_size=2),
]
class TestTransformersFallbackEngine(CustomTestCase):
    """Compare SRT engine outputs between the transformers impl and the native impl."""

    @classmethod
    def setUpClass(cls):
        # "spawn" gives each engine subprocess a fresh interpreter state.
        mp.set_start_method("spawn", force=True)

    def _prompts_for(self, model_case: ModelCase) -> List[str]:
        """Return the prompt set for a case, dropping long prompts when the
        model does not support a long context."""
        if model_case.skip_long_prompt:
            return [p for p in DEFAULT_PROMPTS if len(p) < 1000]
        return DEFAULT_PROMPTS

    def assert_close_logits_and_output_strs(
        self,
        prompts: List[str],
        model_case: ModelCase,
    ) -> None:
        """Run both implementations on ``prompts`` and assert their logits and
        generated strings stay within the case's tolerances."""
        model_path = model_case.model_path
        max_new_tokens = 32

        # Run with the transformers fallback implementation.
        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=model_case.torch_dtype,
            model_type="generation",
            model_impl="transformers",
            trust_remote_code=model_case.trust_remote_code,
            torchao_config=model_case.torchao_config,
        ) as srt_runner:
            transformers_impl_outputs = srt_runner.forward(
                prompts, max_new_tokens=max_new_tokens
            )

        # Run with the default (native) implementation as the reference.
        with SRTRunner(
            model_path,
            tp_size=model_case.tp_size,
            torch_dtype=model_case.torch_dtype,
            model_type="generation",
            trust_remote_code=model_case.trust_remote_code,
            torchao_config=model_case.torchao_config,
        ) as srt_runner:
            native_impl_outputs = srt_runner.forward(
                prompts, max_new_tokens=max_new_tokens
            )

        # Keep the original reference/candidate mapping: native run is the
        # "hf" side of the comparison, transformers run is the "srt" side.
        check_close_model_outputs(
            hf_outputs=native_impl_outputs,
            srt_outputs=transformers_impl_outputs,
            prefill_tolerance=model_case.prefill_tolerance,
            decode_tolerance=model_case.decode_tolerance,
            rouge_l_tolerance=model_case.rouge_l_tolerance,
            debug_text=f"model_path={model_path} prompts={prompts}",
        )

    def test_ci_models(self):
        """Run the comparison for every CI model case."""
        for model_case in CI_MODELS:
            self.assert_close_logits_and_output_strs(
                self._prompts_for(model_case), model_case
            )

    def test_others(self):
        """Run the comparison for the non-CI model cases (skipped in CI)."""
        if is_in_ci():
            return
        for model_case in ALL_OTHER_MODELS:
            # Recompute the prompt set per case. The original filtered a shared
            # `prompts` list in place, so a skip_long_prompt case leaked its
            # filtered prompts into every subsequent case.
            self.assert_close_logits_and_output_strs(
                self._prompts_for(model_case), model_case
            )
# Allow running this test file directly with `python3 <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,213 @@
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestUnslothPhi4(CustomTestCase):
    """gsm8k accuracy gate for unsloth/phi-4 served with default settings."""

    @classmethod
    def setUpClass(cls):
        cls.model = "unsloth/phi-4"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot gsm8k accuracy must exceed 0.78."""
        server_port = int(self.base_url.split(":")[-1])
        eval_settings = dict(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(SimpleNamespace(**eval_settings))
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.78)
class TestUnslothPhi4Bnb4bit(CustomTestCase):
    """gsm8k accuracy gate for the bnb-4bit phi-4 checkpoint."""

    @classmethod
    def setUpClass(cls):
        cls.model = "unsloth/phi-4-bnb-4bit"
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Pre-quantized checkpoint requires the bitsandbytes load format.
        load_args = ["--load-format", "bitsandbytes"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=load_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot gsm8k accuracy must exceed 0.75."""
        server_port = int(self.base_url.split(":")[-1])
        eval_settings = dict(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(SimpleNamespace(**eval_settings))
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.75)
class TestUnslothPhi4UnslothBnb4bit(CustomTestCase):
    """gsm8k accuracy gate for the unsloth-bnb-4bit phi-4 checkpoint."""

    @classmethod
    def setUpClass(cls):
        cls.model = "unsloth/phi-4-unsloth-bnb-4bit"
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Pre-quantized checkpoint requires the bitsandbytes load format.
        load_args = ["--load-format", "bitsandbytes"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=load_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot gsm8k accuracy must exceed 0.75."""
        server_port = int(self.base_url.split(":")[-1])
        eval_settings = dict(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(SimpleNamespace(**eval_settings))
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.75)
class TestUnslothPhi4MiniInstruct(CustomTestCase):
    """gsm8k accuracy gate for unsloth/Phi-4-mini-instruct with default settings."""

    @classmethod
    def setUpClass(cls):
        cls.model = "unsloth/Phi-4-mini-instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot gsm8k accuracy must exceed 0.65."""
        server_port = int(self.base_url.split(":")[-1])
        eval_settings = dict(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(SimpleNamespace(**eval_settings))
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.65)
class TestUnslothPhi4MiniBnb4bit(CustomTestCase):
    """gsm8k accuracy gate for the bnb-4bit Phi-4-mini-instruct checkpoint."""

    @classmethod
    def setUpClass(cls):
        cls.model = "unsloth/Phi-4-mini-instruct-bnb-4bit"
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Pre-quantized checkpoint requires the bitsandbytes load format.
        load_args = ["--load-format", "bitsandbytes"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=load_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot gsm8k accuracy must exceed 0.6."""
        server_port = int(self.base_url.split(":")[-1])
        eval_settings = dict(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(SimpleNamespace(**eval_settings))
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.6)
class TestUnslothPhi4MiniUnslothBnb4bit(CustomTestCase):
    """gsm8k accuracy gate for the unsloth-bnb-4bit Phi-4-mini-instruct checkpoint."""

    @classmethod
    def setUpClass(cls):
        cls.model = "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit"
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Pre-quantized checkpoint requires the bitsandbytes load format.
        load_args = ["--load-format", "bitsandbytes"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=load_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot gsm8k accuracy must exceed 0.6."""
        server_port = int(self.base_url.split(":")[-1])
        eval_settings = dict(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        metrics = run_eval(SimpleNamespace(**eval_settings))
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.6)
# Allow running this test file directly with `python3 <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,315 @@
import argparse
import glob
import json
import os
import random
import subprocess
import sys
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
# VLM models for testing.
# Each entry pairs a checkpoint with its minimum acceptable MMMU-val accuracy.
MODELS = [
    SimpleNamespace(model="google/gemma-3-27b-it", mmmu_accuracy=0.45),
    SimpleNamespace(
        model="Qwen/Qwen2.5-VL-3B-Instruct",
        mmmu_accuracy=0.4,
    ),
    SimpleNamespace(model="openbmb/MiniCPM-V-2_6", mmmu_accuracy=0.4),
]
class TestVLMModels(CustomTestCase):
    """MMMU-val accuracy tests for VLM models served through the OpenAI-compatible API.

    Launches an sglang server per model, evaluates it with lmms-eval, and
    (optionally) checks that embedding-cache eviction fires under a tiny cache.
    """

    # Populated from the CLI in the __main__ block before unittest runs.
    parsed_args = None  # Class variable to store args

    @classmethod
    def setUpClass(cls):
        # Removed argument parsing from here
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        # Set OpenAI API key and base URL environment variables. Needed for lmm-evals to work.
        os.environ["OPENAI_API_KEY"] = cls.api_key
        os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"

    def _detect_eviction_in_logs(self, log_output):
        """Detect if eviction events occurred in the log output.

        Returns a (detected, count) tuple based on keyword matches per line.
        """
        eviction_keywords = ["Cache eviction: evicted"]
        eviction_detected = False
        eviction_count = 0
        for line in log_output.split("\n"):
            if any(keyword in line for keyword in eviction_keywords):
                eviction_detected = True
                eviction_count += 1
                print(f"Eviction detected: {line.strip()}")
        return eviction_detected, eviction_count

    def run_mmmu_eval(
        self,
        model_version: str,
        output_path: str,
        *,
        env: dict | None = None,
    ):
        """
        Evaluate a VLM on the MMMU validation set with lmms-eval.

        Only ``model_version`` (the checkpoint) varies between calls; every
        other setting is fixed below. We focus on the validation set only due
        to resource constraints.

        NOTE(review): the ``env`` parameter is accepted but never passed to
        ``subprocess.run`` — confirm whether it should be.
        """
        # -------- fixed settings --------
        model = "openai_compatible"
        tp = 1
        tasks = "mmmu_val"
        batch_size = 2
        log_suffix = "openai_compatible"
        os.makedirs(output_path, exist_ok=True)
        # -------- compose --model_args --------
        model_args = f'model_version="{model_version}",' f"tp={tp}"
        # -------- build command list --------
        cmd = [
            "python3",
            "-m",
            "lmms_eval",
            "--model",
            model,
            "--model_args",
            model_args,
            "--tasks",
            tasks,
            "--batch_size",
            str(batch_size),
            "--log_samples",
            "--log_samples_suffix",
            log_suffix,
            "--output_path",
            str(output_path),
        ]
        # Fail loudly on non-zero exit; cap the eval at one hour.
        subprocess.run(
            cmd,
            check=True,
            timeout=3600,
        )

    def _run_vlm_mmmu_test(
        self,
        model,
        output_path,
        test_name="",
        custom_env=None,
        log_level="info",
        capture_output=False,
    ):
        """
        Common method to run VLM MMMU benchmark test.
        Args:
            model: Model to test
            output_path: Path for output logs
            test_name: Optional test name for logging
            custom_env: Optional custom environment variables
            log_level: Log level for server (default: "info")
            capture_output: Whether to capture server stdout/stderr
        Returns:
            The captured server output (empty string unless capture_output).
        """
        print(f"\nTesting model: {model.model}{test_name}")
        process = None
        mmmu_accuracy = 0  # Initialize to handle potential exceptions
        server_output = ""
        try:
            # Prepare environment variables
            process_env = os.environ.copy()
            if custom_env:
                process_env.update(custom_env)
            # Prepare stdout/stderr redirection if needed
            stdout_file = None
            stderr_file = None
            if capture_output:
                stdout_file = open("/tmp/server_stdout.log", "w")
                stderr_file = open("/tmp/server_stderr.log", "w")
            # Launch server for testing
            process = popen_launch_server(
                model.model,
                base_url=self.base_url,
                timeout=self.time_out,
                api_key=self.api_key,
                other_args=[
                    "--trust-remote-code",
                    "--cuda-graph-max-bs",
                    "32",
                    "--enable-multimodal",
                    "--mem-fraction-static",
                    str(self.parsed_args.mem_fraction_static),  # Use class variable
                    "--log-level",
                    log_level,
                ],
                env=process_env,
                return_stdout_stderr=(
                    (stdout_file, stderr_file) if capture_output else None
                ),
            )
            # Run evaluation
            self.run_mmmu_eval(model.model, output_path)
            # Get the result file (assumes exactly one JSON result per run dir)
            result_file_path = glob.glob(f"{output_path}/*.json")[0]
            with open(result_file_path, "r") as f:
                result = json.load(f)
            print(f"Result{test_name}\n: {result}")
            # Process the result
            mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
            print(
                f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}"
            )
            # Capture server output if requested
            if capture_output and process:
                server_output = self._read_output_from_files()
            # Assert performance meets expected threshold
            self.assertGreaterEqual(
                mmmu_accuracy,
                model.mmmu_accuracy,
                f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}",
            )
            return server_output
        except Exception as e:
            print(f"Error testing {model.model}{test_name}: {e}")
            self.fail(f"Test failed for {model.model}{test_name}: {e}")
        finally:
            # Ensure process cleanup happens regardless of success/failure
            if process is not None and process.poll() is None:
                print(f"Cleaning up process {process.pid}")
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    print(f"Error killing process: {e}")
            # clean up temporary files
            if capture_output:
                if stdout_file:
                    stdout_file.close()
                if stderr_file:
                    stderr_file.close()
                for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]:
                    try:
                        if os.path.exists(filename):
                            os.remove(filename)
                    except Exception as e:
                        # NOTE(review): this message likely meant to include the
                        # file name being removed (e.g. {filename}) — confirm.
                        print(f"Error removing (unknown): {e}")

    def _read_output_from_files(self):
        """Read back the captured server stdout/stderr logs, tagging each line."""
        output_lines = []
        log_files = [
            ("/tmp/server_stdout.log", "[STDOUT]"),
            ("/tmp/server_stderr.log", "[STDERR]"),
        ]
        for filename, tag in log_files:
            try:
                if os.path.exists(filename):
                    with open(filename, "r") as f:
                        for line in f:
                            output_lines.append(f"{tag} {line.rstrip()}")
            except Exception as e:
                print(f"Error reading {tag.lower()} file: {e}")
        return "\n".join(output_lines)

    def test_vlm_mmmu_benchmark(self):
        """Test VLM models against MMMU benchmark."""
        models_to_test = MODELS
        # In CI, test a single randomly chosen model to bound runtime.
        if is_in_ci():
            models_to_test = [random.choice(MODELS)]
        for model in models_to_test:
            self._run_vlm_mmmu_test(model, "./logs")

    def test_vlm_mmmu_benchmark_with_small_cache(self):
        """Test VLM models against MMMU benchmark with a small embedding cache to force eviction."""
        models_to_test = MODELS
        if is_in_ci():
            models_to_test = [random.choice(MODELS)]
        for model in models_to_test:
            # 5 MB cache is deliberately undersized to trigger eviction.
            custom_env = {"SGLANG_VLM_CACHE_SIZE_MB": "5"}
            # Run the test with output capture
            server_output = self._run_vlm_mmmu_test(
                model,
                "./logs_small_cache",
                test_name=" with small embedding cache (evict test)",
                custom_env=custom_env,
                log_level="debug",  # Enable debug logging for eviction detection
                capture_output=True,  # Capture server output
            )
            # Print server output for debugging
            print("Server output:\n", server_output)
            # Analyze server output for eviction events
            eviction_detected, eviction_count = self._detect_eviction_in_logs(
                server_output
            )
            # Assert that eviction was detected (since we're using small cache)
            self.assertTrue(
                eviction_detected,
                f"Expected eviction events to be detected with small cache (5MB), but none found. "
                f"Cache size may be too large for the workload or eviction logic may not be working. "
                f"Total log content length: {len(server_output)} characters",
            )
            print(
                f"Eviction detection summary: {eviction_count} eviction events detected"
            )
            # Additional assertion: if eviction was detected, the test passed
            if eviction_detected:
                print("✅ Eviction logic successfully triggered and detected!")
if __name__ == "__main__":
    # Parse our CLI flags before handing control to unittest, which would
    # otherwise reject the unknown --mem-fraction-static option.
    cli_parser = argparse.ArgumentParser(description="Test VLM models")
    cli_parser.add_argument(
        "--mem-fraction-static",
        type=float,
        default=0.8,
        help="Static memory fraction for the model",
    )
    # Stash the parsed args on the test class for use in server launches.
    TestVLMModels.parsed_args = cli_parser.parse_args()
    # Strip our flags so unittest sees only the program name.
    unittest.main(argv=[sys.argv[0]])

View File

View File

View File

@@ -0,0 +1,97 @@
import unittest
import openai
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestOpenAIEmbedding(CustomTestCase):
    """Tests for the OpenAI-compatible /v1/embeddings endpoint."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # Configure embedding-specific args
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=["--is-embedding", "--enable-metrics"],
        )
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def _client(self):
        """Build an OpenAI client pointed at the local test server."""
        return openai.Client(api_key=self.api_key, base_url=self.base_url)

    def _assert_embeddings(self, response, expected_count):
        """Assert the response has the expected number of non-empty embeddings."""
        self.assertEqual(len(response.data), expected_count)
        for item in response.data:
            self.assertTrue(len(item.embedding) > 0)

    def test_embedding_single(self):
        """A single string input yields one non-empty embedding."""
        response = self._client().embeddings.create(
            model=self.model, input="Hello world"
        )
        self._assert_embeddings(response, 1)

    def test_embedding_batch(self):
        """A list of strings yields one embedding per input."""
        response = self._client().embeddings.create(
            model=self.model, input=["Hello world", "Test text"]
        )
        self._assert_embeddings(response, 2)

    def test_embedding_single_batch_str(self):
        """A one-element List[str] input yields exactly one embedding."""
        response = self._client().embeddings.create(
            model=self.model, input=["Hello world"]
        )
        self._assert_embeddings(response, 1)

    def test_embedding_single_int_list(self):
        """Token-id inputs (List[int] and List[List[int]]) are accepted."""
        token_ids = [15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]
        response = self._client().embeddings.create(
            model=self.model,
            input=[token_ids],
        )
        self._assert_embeddings(response, 1)
        response = self._client().embeddings.create(
            model=self.model,
            input=token_ids,
        )
        self._assert_embeddings(response, 1)

    def test_empty_string_embedding(self):
        """An empty-string input is rejected with HTTP 400."""
        # Expect a BadRequestError for empty input
        with self.assertRaises(openai.BadRequestError) as cm:
            self._client().embeddings.create(
                model=self.model,
                input="",
            )
        # check the status code
        self.assertEqual(cm.exception.status_code, 400)
# Allow running this test file directly with `python3 <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,669 @@
"""
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion_stream
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion_stream
"""
import json
import re
import unittest
import numpy as np
import openai
import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
from sglang.test.test_utils import (
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestOpenAIServer(CustomTestCase):
    """Tests for the OpenAI-compatible /v1/completions and /v1/chat/completions APIs."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_completion(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Exercise a non-streaming /v1/completions call for one flag combination.

        Checks choice count, echo prefix, logprob structure, and token usage.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        # Prompt can be supplied as a string or as pre-tokenized ids.
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))
        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1
        response = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            n=parallel_sample_num,
        )
        assert len(response.choices) == num_choices * parallel_sample_num
        if echo:
            text = response.choices[0].text
            assert text.startswith(prompt)
        if logprobs:
            assert response.choices[0].logprobs
            assert isinstance(response.choices[0].logprobs.tokens[0], str)
            assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
            ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
            # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
            # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
            assert ret_num_top_logprobs > 0
            # when echo=True and request.logprobs>0, logprob_start_len is 0, so the first token's logprob would be None.
            if not echo:
                assert response.choices[0].logprobs.token_logprobs[0]
        assert response.id
        assert response.created
        assert (
            response.usage.prompt_tokens == num_prompt_tokens
        ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_completion_stream(
        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
    ):
        """Exercise a streaming /v1/completions call for one flag combination.

        Tracks the first chunk per choice index to validate echo behavior, and
        checks the final usage chunk emitted via include_usage.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        if token_input:
            prompt_input = self.tokenizer.encode(prompt)
            num_prompt_tokens = len(prompt_input)
        else:
            prompt_input = prompt
            num_prompt_tokens = len(self.tokenizer.encode(prompt))
        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_choices = len(prompt_arg)
            num_prompt_tokens *= 2
        else:
            prompt_arg = prompt_input
            num_choices = 1
        generator = client.completions.create(
            model=self.model,
            prompt=prompt_arg,
            temperature=0,
            max_tokens=32,
            echo=echo,
            logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )
        # Maps choice index -> whether its first chunk is still pending.
        is_firsts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                # The usage-only chunk has no choices; validate and skip.
                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
                continue
            index = response.choices[0].index
            is_first = is_firsts.get(index, True)
            if logprobs:
                assert response.choices[0].logprobs, f"no logprobs in response"
                assert isinstance(
                    response.choices[0].logprobs.tokens[0], str
                ), f"{response.choices[0].logprobs.tokens[0]} is not a string"
                # The first echoed chunk has no logprob for the first token.
                if not (is_first and echo):
                    assert isinstance(
                        response.choices[0].logprobs.top_logprobs[0], dict
                    ), f"top_logprobs was not a dictionary"
                    ret_num_top_logprobs = len(
                        response.choices[0].logprobs.top_logprobs[0]
                    )
                    # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
                    # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                    assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
            if is_first:
                if echo:
                    assert response.choices[0].text.startswith(
                        prompt
                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
                is_firsts[index] = False
            assert response.id, f"no id in response"
            assert response.created, f"no created in response"
        # Every expected choice index must have produced at least one chunk.
        for index in [i for i in range(parallel_sample_num * num_choices)]:
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"

    def run_chat_completion(self, logprobs, parallel_sample_num):
        """Exercise a non-streaming /v1/chat/completions call and validate it."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "What is the capital of France? Answer in a few words.",
                },
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            n=parallel_sample_num,
        )
        if logprobs:
            assert isinstance(
                response.choices[0].logprobs.content[0].top_logprobs[0].token, str
            )
            ret_num_top_logprobs = len(
                response.choices[0].logprobs.content[0].top_logprobs
            )
            assert (
                ret_num_top_logprobs == logprobs
            ), f"{ret_num_top_logprobs} vs {logprobs}"
        assert len(response.choices) == parallel_sample_num
        assert response.choices[0].message.role == "assistant"
        assert isinstance(response.choices[0].message.content, str)
        assert response.id
        assert response.created
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0
        assert response.usage.total_tokens > 0

    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
        """Exercise a streaming /v1/chat/completions call.

        Verifies the role-bearing first chunk, per-chunk logprob structure,
        the usage chunk, and that each choice index finishes exactly once.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        generator = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "What is the capital of France?"},
            ],
            temperature=0,
            logprobs=logprobs is not None and logprobs > 0,
            top_logprobs=logprobs,
            stream=True,
            stream_options={"include_usage": True},
            n=parallel_sample_num,
        )
        # Per-choice-index bookkeeping across the stream.
        is_firsts = {}
        is_finished = {}
        finish_reason_counts = {}
        for response in generator:
            usage = response.usage
            if usage is not None:
                assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
                assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
                assert usage.total_tokens > 0, f"usage.total_tokens was zero"
                continue
            index = response.choices[0].index
            finish_reason = response.choices[0].finish_reason
            if finish_reason is not None:
                is_finished[index] = True
                finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1
            data = response.choices[0].delta
            if is_firsts.get(index, True):
                # The first chunk of each choice only carries the role.
                assert (
                    data.role == "assistant"
                ), f"data.role was not 'assistant' for first chunk"
                is_firsts[index] = False
                continue
            if logprobs and not is_finished.get(index, False):
                assert response.choices[0].logprobs, f"logprobs was not returned"
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs[0].token, str
                ), f"top_logprobs token was not a string"
                assert isinstance(
                    response.choices[0].logprobs.content[0].top_logprobs, list
                ), f"top_logprobs was not a list"
                ret_num_top_logprobs = len(
                    response.choices[0].logprobs.content[0].top_logprobs
                )
                assert (
                    ret_num_top_logprobs == logprobs
                ), f"{ret_num_top_logprobs} vs {logprobs}"
            # Every non-first chunk must carry content, reasoning, tool calls,
            # or a finish reason.
            assert (
                isinstance(data.content, str)
                or isinstance(data.reasoning_content, str)
                or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
                or response.choices[0].finish_reason
            )
            assert response.id
            assert response.created
        for index in [i for i in range(parallel_sample_num)]:
            assert not is_firsts.get(
                index, True
            ), f"index {index} is not found in the response"
        # Verify that each choice gets exactly one finish_reason chunk
        for index in range(parallel_sample_num):
            assert (
                index in finish_reason_counts
            ), f"No finish_reason found for index {index}"
            assert (
                finish_reason_counts[index] == 1
            ), f"Expected 1 finish_reason chunk for index {index}, got {finish_reason_counts[index]}"

    def test_completion(self):
        """Sweep all flag combinations for non-streaming completions."""
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_completion_stream(self):
        """Sweep echo/logprobs/list-input/parallel-sampling/token-input
        combinations in streaming mode."""
        for echo in [False, True]:
            for logprobs in [None, 5]:
                for use_list_input in [True, False]:
                    for parallel_sample_num in [1, 2]:
                        for token_input in [False, True]:
                            self.run_completion_stream(
                                echo,
                                logprobs,
                                use_list_input,
                                parallel_sample_num,
                                token_input,
                            )

    def test_chat_completion(self):
        """Sweep logprobs and parallel sampling for chat completions."""
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion(logprobs, parallel_sample_num)

    def test_chat_completion_stream(self):
        """Sweep logprobs and parallel sampling for streaming chat completions."""
        for logprobs in [None, 5]:
            for parallel_sample_num in [1, 2]:
                self.run_chat_completion_stream(logprobs, parallel_sample_num)

    def test_regex(self):
        """Constrained decoding via the `regex` extra body produces matching JSON."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        regex = (
            r"""\{\n"""
            + r"""    "name": "[\w]+",\n"""
            + r"""    "population": [\d]+\n"""
            + r"""\}"""
        )
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"regex": regex},
        )
        text = response.choices[0].message.content
        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)

    def test_penalty(self):
        """A frequency_penalty request still returns a string completion."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=32,
            frequency_penalty=1.0,
        )
        text = response.choices[0].message.content
        assert isinstance(text, str)

    def test_response_prefill(self):
        """continue_final_message resumes generation from the assistant prefix."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        # NOTE(review): this hardcodes a model name instead of self.model —
        # confirm whether it should use the model launched in setUpClass.
        response = client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": """
Extract the name, size, price, and color from this product description as a JSON object:
<description>
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
</description>
""",
                },
                {
                    "role": "assistant",
                    "content": "{\n",
                },
            ],
            temperature=0,
            extra_body={"continue_final_message": True},
        )
        assert (
            response.choices[0]
            .message.content.strip()
            .startswith('"name": "SmartHome Mini",')
        )

    def test_model_list(self):
        """The server exposes exactly one model with a max_model_len field."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        models = list(client.models.list())
        assert len(models) == 1
        assert isinstance(getattr(models[0], "max_model_len", None), int)

    def test_retrieve_model(self):
        """Retrieving a served model works; unknown models raise NotFoundError."""
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        # Test retrieving an existing model
        retrieved_model = client.models.retrieve(self.model)
        self.assertEqual(retrieved_model.id, self.model)
        self.assertEqual(retrieved_model.root, self.model)
        # Test retrieving a non-existent model
        with self.assertRaises(openai.NotFoundError):
            client.models.retrieve("non-existent-model")
class TestOpenAIV1Rerank(CustomTestCase):
    """Tests for the /v1/rerank endpoint backed by a small cross-encoder model."""

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.score_tolerance = 1e-2
        # Configure embedding-specific args
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--is-embedding",
                "--enable-metrics",
                "--disable-radix-cache",
                "--chunked-prefill-size",
                "-1",
                "--attention-backend",
                "torch_native",
            ],
        )
        cls.base_url += "/v1/rerank"

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_rerank(self, query, docs):
        """POST a rerank request and return the decoded JSON body."""
        request_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {"query": query, "documents": docs}
        resp = requests.post(self.base_url, headers=request_headers, json=payload)
        return resp.json()

    def _assert_rerank_entries(self, response, expected_count):
        """Assert entry count and the score/document/index field types."""
        self.assertEqual(len(response), expected_count)
        for entry in response[:expected_count]:
            self.assertTrue(isinstance(entry["score"], float))
            self.assertTrue(isinstance(entry["document"], str))
            self.assertTrue(isinstance(entry["index"], int))

    def test_rerank_single(self):
        """A single-document rerank returns one well-formed entry."""
        case = TEST_RERANK_QUERY_DOCS[0]
        response = self.run_rerank(case["query"], case["documents"])
        self._assert_rerank_entries(response, 1)

    def test_rerank_batch(self):
        """A two-document rerank returns two well-formed entries."""
        case = TEST_RERANK_QUERY_DOCS[1]
        response = self.run_rerank(case["query"], case["documents"])
        self._assert_rerank_entries(response, 2)
class TestOpenAIV1Score(CustomTestCase):
    """End-to-end tests for the /v1/score endpoint.

    Launches a small model server once per class and exercises scoring with
    text input, token-id input, and invalid-token error handling.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1/score"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_score(
        self, query, items, label_token_ids, apply_softmax=False, item_first=False
    ):
        """POST a score request and return the decoded JSON body."""
        response = requests.post(
            self.base_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": self.model,
                "query": query,
                "items": items,
                "label_token_ids": label_token_ids,
                "apply_softmax": apply_softmax,
                "item_first": item_first,
            },
        )
        return response.json()

    def _verify_score_response(self, response, num_items, label_token_ids):
        """Shared verification for successful softmax score responses.

        Checks the 'scores' field shape and that each per-item distribution
        over label_token_ids consists of floats summing to 1. Previously this
        logic was duplicated verbatim in both success tests.
        """
        if response.get("type") == "BadRequestError":
            self.fail(f"Score request failed with error: {response['message']}")
        self.assertIn("scores", response, "Response should have a 'scores' field")
        self.assertIsInstance(response["scores"], list, "scores should be a list")
        self.assertEqual(
            len(response["scores"]),
            num_items,
            "Number of scores should match number of items",
        )
        # Each score should be a list of floats in the order of label_token_ids
        for i, score_list in enumerate(response["scores"]):
            self.assertIsInstance(score_list, list, f"Score {i} should be a list")
            self.assertEqual(
                len(score_list),
                len(label_token_ids),
                f"Score {i} length should match label_token_ids",
            )
            self.assertTrue(
                all(isinstance(v, float) for v in score_list),
                f"Score {i} values should be floats",
            )
            self.assertAlmostEqual(
                sum(score_list),
                1.0,
                places=6,
                msg=f"Score {i} probabilities should sum to 1",
            )

    def test_score_text_input(self):
        """Test scoring with text input"""
        query = "The capital of France is"
        items = ["Paris", "London", "Berlin"]
        # Use each item's first token id as its scoring label.
        label_token_ids = []
        for item in items:
            token_ids = self.tokenizer.encode(item, add_special_tokens=False)
            if not token_ids:
                self.fail(f"Failed to encode item: {item}")
            label_token_ids.append(token_ids[0])
        response = self.run_score(query, items, label_token_ids, apply_softmax=True)
        self._verify_score_response(response, len(items), label_token_ids)

    def test_score_token_input(self):
        """Test scoring with token IDs input"""
        query = "The capital of France is"
        items = ["Paris", "London", "Berlin"]
        query_ids = self.tokenizer.encode(query, add_special_tokens=False)
        item_ids = [
            self.tokenizer.encode(item, add_special_tokens=False) for item in items
        ]
        # First token id of each item serves as its label.
        label_token_ids = [ids[0] for ids in item_ids if ids]
        response = self.run_score(
            query_ids, item_ids, label_token_ids, apply_softmax=True
        )
        self._verify_score_response(response, len(items), label_token_ids)

    def test_score_error_handling(self):
        """Test error handling for invalid inputs"""
        query = "The capital of France is"
        items = ["Paris", "London", "Berlin"]
        # A token id far outside the small model's vocabulary must be rejected.
        response = requests.post(
            self.base_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": self.model,
                "query": query,
                "items": items,
                "label_token_ids": [999999],  # Invalid token ID
                "apply_softmax": True,
            },
        )
        self.assertEqual(response.status_code, 400)
        error_response = response.json()
        self.assertEqual(error_response["type"], "BadRequestError")
        self.assertIn("Token ID 999999 is out of vocabulary", error_response["message"])
# Allow running this test module directly with `python <file>.py`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,368 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for OpenAI API protocol models"""
import json
import time
import unittest
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, ValidationError
from sglang.srt.entrypoints.openai.protocol import (
BatchRequest,
BatchResponse,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentTextPart,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatCompletionTokenLogprob,
ChatMessage,
ChoiceLogprobs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
DeltaMessage,
EmbeddingObject,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
FileDeleteResponse,
FileRequest,
FileResponse,
Function,
FunctionResponse,
JsonSchemaResponseFormat,
LogProbs,
ModelCard,
ModelList,
MultimodalEmbeddingInput,
ResponseFormat,
ScoringRequest,
ScoringResponse,
StreamOptions,
StructuralTagResponseFormat,
Tool,
ToolCall,
ToolChoice,
TopLogprob,
UsageInfo,
)
class TestModelCard(unittest.TestCase):
    """Test ModelCard protocol model"""

    def test_model_card_serialization(self):
        """model_dump() exposes id, the fixed 'model' object tag, and max_model_len."""
        model_card = ModelCard(id="test-model", max_model_len=4096)
        dumped = model_card.model_dump()
        self.assertEqual(dumped["object"], "model")
        self.assertEqual(dumped["id"], "test-model")
        self.assertEqual(dumped["max_model_len"], 4096)
class TestModelList(unittest.TestCase):
    """Test ModelList protocol model"""

    def test_empty_model_list(self):
        """A freshly constructed ModelList is an empty 'list' object."""
        empty = ModelList()
        self.assertEqual(empty.object, "list")
        self.assertEqual(len(empty.data), 0)

    def test_model_list_with_cards(self):
        """A ModelList keeps its cards in insertion order."""
        listed = ModelList(
            data=[
                ModelCard(id="model-1"),
                ModelCard(id="model-2", max_model_len=2048),
            ]
        )
        self.assertEqual(len(listed.data), 2)
        self.assertEqual(listed.data[0].id, "model-1")
        self.assertEqual(listed.data[1].id, "model-2")
class TestCompletionRequest(unittest.TestCase):
    """Test CompletionRequest protocol model"""

    def test_basic_completion_request(self):
        """Defaults: max_tokens=16, temperature=1.0, n=1, no stream, no echo."""
        req = CompletionRequest(model="test-model", prompt="Hello world")
        self.assertEqual(req.model, "test-model")
        self.assertEqual(req.prompt, "Hello world")
        self.assertEqual(req.max_tokens, 16)
        self.assertEqual(req.temperature, 1.0)
        self.assertEqual(req.n, 1)
        self.assertFalse(req.stream)
        self.assertFalse(req.echo)

    def test_completion_request_sglang_extensions(self):
        """SGLang-specific fields are stored verbatim on the request."""
        req = CompletionRequest(
            model="test-model",
            prompt="Hello",
            top_k=50,
            min_p=0.1,
            repetition_penalty=1.1,
            regex=r"\d+",
            json_schema='{"type": "object"}',
            lora_path="/path/to/lora",
        )
        self.assertEqual(req.top_k, 50)
        self.assertEqual(req.min_p, 0.1)
        self.assertEqual(req.repetition_penalty, 1.1)
        self.assertEqual(req.regex, r"\d+")
        self.assertEqual(req.json_schema, '{"type": "object"}')
        self.assertEqual(req.lora_path, "/path/to/lora")

    def test_completion_request_validation_errors(self):
        """Both model and prompt are required fields."""
        with self.assertRaises(ValidationError):
            CompletionRequest()  # missing required fields
        with self.assertRaises(ValidationError):
            CompletionRequest(model="test-model")  # missing prompt
class TestChatCompletionRequest(unittest.TestCase):
    """Test ChatCompletionRequest protocol model"""

    def test_basic_chat_completion_request(self):
        """Defaults: temperature 0.7, no streaming, tool_choice 'none' without tools."""
        messages = [{"role": "user", "content": "Hello"}]
        request = ChatCompletionRequest(model="test-model", messages=messages)
        self.assertEqual(request.model, "test-model")
        self.assertEqual(len(request.messages), 1)
        self.assertEqual(request.messages[0].role, "user")
        self.assertEqual(request.messages[0].content, "Hello")
        self.assertEqual(request.temperature, 0.7)  # default
        self.assertFalse(request.stream)  # default
        self.assertEqual(request.tool_choice, "none")  # default when no tools

    def test_chat_completion_tool_choice_validation(self):
        """tool_choice defaults to 'none' without tools and 'auto' with tools."""
        messages = [{"role": "user", "content": "Hello"}]
        # No tools, tool_choice should default to "none"
        request1 = ChatCompletionRequest(model="test-model", messages=messages)
        self.assertEqual(request1.tool_choice, "none")
        # With tools, tool_choice should default to "auto"
        tools = [
            {
                "type": "function",
                "function": {"name": "test_func", "description": "Test function"},
            }
        ]
        request2 = ChatCompletionRequest(
            model="test-model", messages=messages, tools=tools
        )
        self.assertEqual(request2.tool_choice, "auto")

    def test_chat_completion_sglang_extensions(self):
        """SGLang extension fields are stored verbatim on the request."""
        messages = [{"role": "user", "content": "Hello"}]
        request = ChatCompletionRequest(
            model="test-model",
            messages=messages,
            top_k=40,
            min_p=0.05,
            separate_reasoning=False,
            stream_reasoning=False,
            chat_template_kwargs={"custom_param": "value"},
        )
        self.assertEqual(request.top_k, 40)
        self.assertEqual(request.min_p, 0.05)
        self.assertFalse(request.separate_reasoning)
        self.assertFalse(request.stream_reasoning)
        self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})

    def test_chat_completion_reasoning_effort(self):
        """The 'reasoning' block maps to reasoning_effort and thinking kwargs."""
        messages = [{"role": "user", "content": "Hello"}]
        request = ChatCompletionRequest(
            model="test-model",
            messages=messages,
            reasoning={
                "enabled": True,
                "reasoning_effort": "high",
            },
        )
        self.assertEqual(request.reasoning_effort, "high")
        self.assertEqual(request.chat_template_kwargs, {"thinking": True})

    def test_chat_completion_json_format(self):
        """Both response_format schema spellings normalize to the same JsonSchemaResponseFormat."""
        # FIX: previously written as one assignment followed by three bare
        # string-literal statements; implicit concatenation does not cross
        # statements, so only the first line was ever assigned and the rest
        # were dead no-op expressions. Parenthesize into a single string.
        transcript = (
            "Good morning! It's 7:00 AM, and I'm just waking up. Today is going to be a busy day, "
            "so let's get started. First, I need to make a quick breakfast. I think I'll have some "
            "scrambled eggs and toast with a cup of coffee. While I'm cooking, I'll also check my "
            "emails to see if there's anything urgent."
        )
        messages = [
            {
                "role": "system",
                "content": "The following is a voice message transcript. Only answer in JSON.",
            },
            {
                "role": "user",
                "content": transcript,
            },
        ]

        class VoiceNote(BaseModel):
            title: str = Field(description="A title for the voice note")
            summary: str = Field(
                description="A short one sentence summary of the voice note."
            )
            # Deliberately named "strict": the protocol must not confuse this
            # schema property with the json_schema "strict" flag.
            strict: Optional[bool] = True
            actionItems: List[str] = Field(
                description="A list of action items from the voice note"
            )

        # Variant 1: inline "schema" key (SGLang shorthand).
        request = ChatCompletionRequest(
            model="test-model",
            messages=messages,
            top_k=40,
            min_p=0.05,
            separate_reasoning=False,
            stream_reasoning=False,
            chat_template_kwargs={"custom_param": "value"},
            response_format={
                "type": "json_schema",
                "schema": VoiceNote.model_json_schema(),
            },
        )
        res_format = request.response_format
        json_format = res_format.json_schema
        name = json_format.name
        schema = json_format.schema_
        strict = json_format.strict
        self.assertEqual(name, "VoiceNote")
        self.assertEqual(strict, True)
        self.assertNotIn("strict", schema["properties"])

        # Variant 2: OpenAI-style nested "json_schema" wrapper.
        request = ChatCompletionRequest(
            model="test-model",
            messages=messages,
            top_k=40,
            min_p=0.05,
            separate_reasoning=False,
            stream_reasoning=False,
            chat_template_kwargs={"custom_param": "value"},
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "VoiceNote",
                    "schema": VoiceNote.model_json_schema(),
                    "strict": True,
                },
            },
        )
        res_format = request.response_format
        json_format = res_format.json_schema
        name = json_format.name
        schema = json_format.schema_
        strict = json_format.strict
        self.assertEqual(name, "VoiceNote")
        self.assertEqual(strict, True)
class TestModelSerialization(unittest.TestCase):
    """Test model serialization with hidden states"""

    @staticmethod
    def _response_with_hidden_states(hidden_states):
        """Build a one-choice ChatCompletionResponse carrying the given hidden_states."""
        return ChatCompletionResponse(
            id="test-id",
            model="test-model",
            choices=[
                ChatCompletionResponseChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content="Hello"),
                    finish_reason="stop",
                    hidden_states=hidden_states,
                )
            ],
            usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
        )

    def test_hidden_states_excluded_when_none(self):
        """exclude_none=True drops a None hidden_states from the payload."""
        payload = self._response_with_hidden_states(None).model_dump(exclude_none=True)
        self.assertNotIn("hidden_states", payload["choices"][0])

    def test_hidden_states_included_when_not_none(self):
        """exclude_none=True keeps a populated hidden_states list."""
        payload = self._response_with_hidden_states([0.1, 0.2, 0.3]).model_dump(
            exclude_none=True
        )
        self.assertIn("hidden_states", payload["choices"][0])
        self.assertEqual(payload["choices"][0]["hidden_states"], [0.1, 0.2, 0.3])
class TestValidationEdgeCases(unittest.TestCase):
    """Test edge cases and validation scenarios"""

    def test_invalid_tool_choice_type(self):
        """An integer tool_choice is rejected by validation."""
        with self.assertRaises(ValidationError):
            ChatCompletionRequest(
                model="test-model",
                messages=[{"role": "user", "content": "Hello"}],
                tool_choice=123,
            )

    def test_negative_token_limits(self):
        """A negative max_tokens is rejected by validation."""
        with self.assertRaises(ValidationError):
            CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)

    def test_model_serialization_roundtrip(self):
        """model_dump() output reconstructs an equivalent request."""
        original = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.7,
            max_tokens=100,
        )
        restored = ChatCompletionRequest(**original.model_dump())
        self.assertEqual(restored.model, original.model)
        self.assertEqual(restored.temperature, original.temperature)
        self.assertEqual(restored.max_tokens, original.max_tokens)
        self.assertEqual(len(restored.messages), len(original.messages))
# Allow running this test module directly with `python <file>.py`.
if __name__ == "__main__":
    unittest.main(verbosity=2)

View File

@@ -0,0 +1,426 @@
"""
Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'.
Run with either:
python tests/test_serving_chat_unit.py -v
or
python -m unittest discover -s tests -p "test_*unit.py" -v
"""
import asyncio
import json
import unittest
import uuid
from typing import Optional
from unittest.mock import Mock, patch
from fastapi import Request
from sglang.srt.entrypoints.openai.protocol import (
ChatCompletionRequest,
MessageProcessingResult,
)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.managers.io_struct import GenerateReqInput
class _MockTokenizerManager:
    """Minimal mock that satisfies OpenAIServingChat.

    Provides only the attributes the serving handler reads: model/server
    config mocks, a tokenizer stub with fixed encode/decode output, and a
    canned async generator for ``generate_request``.
    """

    def __init__(self):
        # Non-multimodal model so the text-only code path is exercised.
        self.model_config = Mock(is_multimodal=False)
        self.server_args = Mock(
            enable_cache_report=False,
            tool_call_parser="hermes",
            reasoning_parser=None,
        )
        self.chat_template_name: Optional[str] = "llama-3"
        # tokenizer stub
        self.tokenizer = Mock()
        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        self.tokenizer.decode.return_value = "Test response"
        self.tokenizer.chat_template = None
        self.tokenizer.bos_token_id = 1
        # async generator stub for generate_request
        async def _mock_generate():
            # Shape mirrors the scheduler's per-chunk generation payload.
            yield {
                "text": "Test response",
                "meta_info": {
                    "id": f"chatcmpl-{uuid.uuid4()}",
                    "prompt_tokens": 10,
                    "completion_tokens": 5,
                    "cached_tokens": 0,
                    "finish_reason": {"type": "stop", "matched": None},
                    "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
                    "output_top_logprobs": None,
                },
                "index": 0,
            }
        # NOTE(review): return_value binds ONE generator instance, so a second
        # call to generate_request would receive the already-exhausted
        # generator. Fine for single-call tests; use side_effect if a test
        # ever needs repeated calls.
        self.generate_request = Mock(return_value=_mock_generate())
        self.create_abort_task = Mock()
class _MockTemplateManager:
"""Minimal mock for TemplateManager."""
def __init__(self):
self.chat_template_name: Optional[str] = "llama-3"
self.jinja_template_content_format: Optional[str] = None
self.completion_template_name: Optional[str] = None
class ServingChatTestCase(unittest.TestCase):
    """Unit tests for OpenAIServingChat: request conversion, sampling params,
    streaming tool-call bookkeeping, and kimi_k2 tool_call_id formatting.

    FIX: three tool-args tests were declared ``async def`` inside a plain
    unittest.TestCase; the returned coroutines were never awaited, so those
    tests silently passed without running. Their bodies contain no ``await``,
    so they are now regular synchronous methods. The streaming kimi_k2 test
    also now uses ``asyncio.run`` instead of the deprecated
    ``asyncio.get_event_loop().run_until_complete``.
    """

    # ------------- common fixtures -------------
    def setUp(self):
        self.tm = _MockTokenizerManager()
        self.template_manager = _MockTemplateManager()
        self.chat = OpenAIServingChat(self.tm, self.template_manager)
        # frequently reused requests
        self.basic_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=False,
        )
        self.stream_req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            temperature=0.7,
            max_tokens=100,
            stream=True,
        )
        self.fastapi_request = Mock(spec=Request)
        self.fastapi_request.headers = {}

    # ------------- conversion tests -------------
    def test_convert_to_internal_request_single(self):
        """A non-streaming request converts to a single GenerateReqInput."""
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = ["</s>"]
            conv_mock.return_value = conv_ins
            proc_mock.return_value = MessageProcessingResult(
                "Test prompt",
                [1, 2, 3],
                None,
                None,
                [],
                ["</s>"],
                None,
            )
            adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
            self.assertIsInstance(adapted, GenerateReqInput)
            self.assertFalse(adapted.stream)
            self.assertEqual(processed, self.basic_req)

    def test_stop_str_isolation_between_requests(self):
        """Test that stop strings from one request don't affect subsequent requests.

        This tests the fix for the bug where conv.stop_str was being mutated
        globally, causing stop strings from one request to persist in
        subsequent requests.
        """
        # Mock conversation template with initial stop_str
        initial_stop_str = ["\n"]
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
        ) as conv_mock:
            # Mock conversation object returned by generate_chat_conv
            conv_ins = Mock()
            conv_ins.get_prompt.return_value = "Test prompt"
            conv_ins.image_data = None
            conv_ins.audio_data = None
            conv_ins.modalities = []
            conv_ins.stop_str = initial_stop_str.copy()  # Template's default stop strings
            conv_mock.return_value = conv_ins

            # First request with an additional custom stop string
            req1 = ChatCompletionRequest(
                model="x",
                messages=[{"role": "user", "content": "First request"}],
                stop=["CUSTOM_STOP"],
            )
            # Call the real _apply_conversation_template (not mocked)
            result1 = self.chat._apply_conversation_template(req1, is_multimodal=False)
            # First request sees both stop strings...
            expected_stop1 = initial_stop_str + ["CUSTOM_STOP"]
            self.assertEqual(result1.stop, expected_stop1)
            # ...and the template's own stop_str must not have been mutated.
            self.assertEqual(conv_ins.stop_str, initial_stop_str)

            # Second request without any custom stop string
            req2 = ChatCompletionRequest(
                model="x",
                messages=[{"role": "user", "content": "Second request"}],
            )
            result2 = self.chat._apply_conversation_template(req2, is_multimodal=False)
            # Second request sees only the template defaults (no CUSTOM_STOP leak).
            self.assertEqual(result2.stop, initial_stop_str)
            self.assertNotIn("CUSTOM_STOP", result2.stop)
            self.assertEqual(conv_ins.stop_str, initial_stop_str)

    # ------------- sampling-params -------------
    def test_sampling_param_build(self):
        """Request fields map onto the internal sampling-params dict."""
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi"}],
            temperature=0.8,
            max_tokens=150,
            min_tokens=5,
            top_p=0.9,
            stop=["</s>"],
        )
        with patch.object(
            self.chat,
            "_process_messages",
            return_value=("Prompt", [1], None, None, [], ["</s>"], None),
        ):
            params = self.chat._build_sampling_params(req, ["</s>"], None)
            self.assertEqual(params["temperature"], 0.8)
            self.assertEqual(params["max_new_tokens"], 150)
            self.assertEqual(params["min_new_tokens"], 5)
            self.assertEqual(params["stop"], ["</s>"])

    def test_unstreamed_tool_args_completion(self):
        """Remaining tool call arguments are sent when generation finishes.

        (Was async and never awaited by unittest; body is synchronous.)
        """
        # Parser whose detector holds partially-streamed tool call data
        mock_parser = Mock()
        mock_detector = Mock()
        mock_detector.prev_tool_call_arr = [
            {
                "name": "get_weather",
                "arguments": {"location": "San Francisco", "unit": "celsius"},
            }
        ]
        mock_detector.streamed_args_for_tool = [
            '{"location": "San Francisco"'  # Partial arguments streamed so far
        ]
        mock_parser.detector = mock_detector
        content = {
            "meta_info": {
                "id": "chatcmpl-test123",
            }
        }
        request = ChatCompletionRequest(
            model="test",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
        )
        result = self.chat._check_for_unstreamed_tool_args(
            parser=mock_parser,
            content=content,
            request=request,
            finish_reason_type="stop",
            index=0,
        )
        # A chunk carrying the un-streamed remainder of the arguments is expected.
        self.assertIsNotNone(result, "Should return chunk with remaining arguments")
        self.assertIn('"arguments":', result, "Should contain arguments field")
        self.assertIn(
            ', "unit": "celsius"}', result, "Should contain remaining arguments"
        )
        self.assertIn(
            '"finish_reason":null',
            result,
            "Should not include finish_reason in completion chunk",
        )

    def test_unstreamed_tool_args_no_completion_needed(self):
        """No completion chunk when all arguments were already streamed.

        (Was async and never awaited by unittest; body is synchronous.)
        """
        mock_parser = Mock()
        mock_detector = Mock()
        mock_detector.prev_tool_call_arr = [
            {"name": "get_weather", "arguments": {"location": "San Francisco"}}
        ]
        mock_detector.streamed_args_for_tool = [
            '{"location": "San Francisco"}'  # All arguments already streamed
        ]
        mock_parser.detector = mock_detector
        content = {
            "meta_info": {
                "id": "chatcmpl-test123",
            }
        }
        request = ChatCompletionRequest(
            model="test",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
        )
        result = self.chat._check_for_unstreamed_tool_args(
            parser=mock_parser,
            content=content,
            request=request,
            finish_reason_type="stop",
            index=0,
        )
        self.assertIsNone(result, "Should return None when no completion is needed")

    def test_unstreamed_tool_args_no_parser_data(self):
        """No completion chunk when the parser has no tool call data.

        (Was async and never awaited by unittest; body is synchronous.)
        """
        mock_parser = Mock()
        mock_detector = Mock()
        mock_detector.prev_tool_call_arr = []
        mock_detector.streamed_args_for_tool = []
        mock_parser.detector = mock_detector
        content = {
            "meta_info": {
                "id": "chatcmpl-test123",
            }
        }
        request = ChatCompletionRequest(
            model="test",
            messages=[{"role": "user", "content": "What's the weather?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
        )
        result = self.chat._check_for_unstreamed_tool_args(
            parser=mock_parser,
            content=content,
            request=request,
            finish_reason_type="stop",
            index=0,
        )
        self.assertIsNone(
            result, "Should return None when parser has no tool call data"
        )

    # ------------- kimi_k2 tool_call_id formatting -------------
    def test_kimi_k2_non_streaming_tool_call_id_format(self):
        """Ensure non-streaming tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
        # Force kimi_k2 parser
        self.tm.server_args.tool_call_parser = "kimi_k2"
        # Mock FunctionCallParser.parse_non_stream to return one tool call
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value
            # Build a mock ToolCallItem-like object
            call_info = Mock()
            call_info.name = "get_weather"
            call_info.parameters = '{"city":"Paris"}'
            call_info.tool_index = 0
            parser_instance.has_tool_call.return_value = True
            parser_instance.parse_non_stream.return_value = ("", [call_info])
            finish_reason = {"type": "stop", "matched": None}
            tools = [
                {"type": "function", "function": {"name": "get_weather"}},
            ]
            tool_calls, remaining_text, _ = self.chat._process_tool_calls(
                text="<|tool_calls_section_begin|>...",
                tools=tools,
                tool_call_parser="kimi_k2",
                finish_reason=finish_reason,
            )
            self.assertIsNotNone(tool_calls)
            self.assertEqual(len(tool_calls), 1)
            self.assertEqual(tool_calls[0].id, "functions.get_weather:0")
            self.assertEqual(tool_calls[0].function.name, "get_weather")

    def test_kimi_k2_streaming_tool_call_id_format(self):
        """Ensure streaming first chunk tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
        # Force kimi_k2 parser
        self.tm.server_args.tool_call_parser = "kimi_k2"
        # Prepare request with tools
        req = ChatCompletionRequest(
            model="x",
            messages=[{"role": "user", "content": "Hi?"}],
            tools=[{"type": "function", "function": {"name": "get_weather"}}],
            stream=True,
        )
        # Patch FunctionCallParser used inside _process_tool_call_stream
        with patch(
            "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
        ) as ParserMock:
            parser_instance = ParserMock.return_value
            # First call returns one ToolCallItem-like chunk (with name)
            first_chunk_call = Mock()
            first_chunk_call.tool_index = 0
            first_chunk_call.name = "get_weather"
            first_chunk_call.parameters = ""
            parser_instance.parse_stream_chunk.side_effect = [
                ("", [first_chunk_call]),
                ("", []),
            ]

            async def collect_first_tool_chunk():
                gen = self.chat._process_tool_call_stream(
                    index=0,
                    delta="irrelevant",
                    parser_dict={},
                    content={"meta_info": {"id": "chatcmpl-test"}},
                    request=req,
                    has_tool_calls={},
                )
                # Get first yielded SSE line
                line = None
                async for emitted in gen:
                    line = emitted
                    break
                return line

            # asyncio.run replaces the deprecated get_event_loop()/run_until_complete.
            line = asyncio.run(collect_first_tool_chunk())
            self.assertIsNotNone(line)
            self.assertTrue(line.startswith("data: "))
            payload = json.loads(line[len("data: ") :])
            tool_calls = payload["choices"][0]["delta"]["tool_calls"]
            self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
# Allow running this test module directly with `python <file>.py`.
if __name__ == "__main__":
    unittest.main(verbosity=2)

View File

@@ -0,0 +1,157 @@
"""
Unit-tests for the refactored completions-serving handler (no pytest).
Run with:
python -m unittest tests.test_serving_completions_unit -v
"""
import unittest
from typing import Optional
from unittest.mock import AsyncMock, Mock, patch
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
from sglang.srt.managers.tokenizer_manager import TokenizerManager
class _MockTemplateManager:
"""Minimal mock for TemplateManager."""
def __init__(self):
self.chat_template_name: Optional[str] = None
self.jinja_template_content_format: Optional[str] = None
self.completion_template_name: Optional[str] = (
None # Set to None to avoid template processing
)
class ServingCompletionTestCase(unittest.TestCase):
"""Bundle all prompt/echo tests in one TestCase."""
# ---------- shared test fixtures ----------
def setUp(self):
# build the mock TokenizerManager once for every test
tm = Mock(spec=TokenizerManager)
tm.tokenizer = Mock()
tm.tokenizer.encode.return_value = [1, 2, 3, 4]
tm.tokenizer.decode.return_value = "decoded text"
tm.tokenizer.bos_token_id = 1
tm.model_config = Mock(is_multimodal=False)
tm.server_args = Mock(enable_cache_report=False)
tm.generate_request = AsyncMock()
tm.create_abort_task = Mock()
self.template_manager = _MockTemplateManager()
self.sc = OpenAIServingCompletion(tm, self.template_manager)
# ---------- prompt-handling ----------
def test_single_string_prompt(self):
req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.text, "Hello world")
def test_single_token_ids_prompt(self):
req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
internal, _ = self.sc._convert_to_internal_request(req)
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
# ---------- echo-handling ----------
def test_echo_with_string_prompt_streaming(self):
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")
def test_echo_with_list_of_strings_streaming(self):
req = CompletionRequest(
model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
)
self.assertEqual(self.sc._get_echo_text(req, 0), "A")
self.assertEqual(self.sc._get_echo_text(req, 1), "B")
def test_echo_with_token_ids_streaming(self):
req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")
def test_echo_with_multiple_token_ids_streaming(self):
req = CompletionRequest(
model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
)
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")
def test_prepare_echo_prompts_non_streaming(self):
# single string
req = CompletionRequest(model="x", prompt="Hi", echo=True)
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])
# list of strings
req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])
# token IDs
req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
# ---------- response_format handling ----------
def test_response_format_json_object(self):
"""Test that response_format json_object is correctly processed in sampling params."""
req = CompletionRequest(
model="x",
prompt="Generate a JSON object:",
max_tokens=100,
response_format={"type": "json_object"},
)
sampling_params = self.sc._build_sampling_params(req)
self.assertEqual(sampling_params["json_schema"], '{"type": "object"}')
def test_response_format_json_schema(self):
    """response_format=json_schema is serialized into the sampling params."""
    person_schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    }
    request = CompletionRequest(
        model="x",
        prompt="Generate a JSON object:",
        max_tokens=100,
        response_format={
            "type": "json_schema",
            "json_schema": {"name": "person", "schema": person_schema},
        },
    )
    params = self.sc._build_sampling_params(request)
    # convert_json_schema_to_str should have turned the dict into a string.
    self.assertIn("json_schema", params)
    self.assertIsInstance(params["json_schema"], str)
def test_response_format_structural_tag(self):
    """response_format=structural_tag is serialized into the sampling params."""
    request = CompletionRequest(
        model="x",
        prompt="Generate structured output:",
        max_tokens=100,
        response_format={
            "type": "structural_tag",
            "structures": [{"begin": "<data>", "end": "</data>"}],
            "triggers": ["<data>"],
        },
    )
    params = self.sc._build_sampling_params(request)
    # The structural tag spec must be flattened to its string form.
    self.assertIn("structural_tag", params)
    self.assertIsInstance(params["structural_tag"], str)
def test_response_format_none(self):
    """Omitting response_format must not inject a structural_tag constraint."""
    request = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100)
    params = self.sc._build_sampling_params(request)
    # A json_schema may still come from the legacy json_schema field, but no
    # structural_tag should appear without an explicit response_format.
    self.assertIsNone(params.get("structural_tag"))
# Allow running this test module directly: `python3 <this_file>.py`.
if __name__ == "__main__":
    unittest.main(verbosity=2)

View File

@@ -0,0 +1,145 @@
"""
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
"""
import unittest
import uuid
from unittest.mock import Mock
from fastapi import Request
from sglang.srt.entrypoints.openai.protocol import (
EmbeddingRequest,
EmbeddingResponse,
MultimodalEmbeddingInput,
)
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from sglang.srt.managers.io_struct import EmbeddingReqInput
# Mock TokenizerManager for embedding tests
class _MockTokenizerManager:
def __init__(self):
self.model_config = Mock()
self.model_config.is_multimodal = False
self.server_args = Mock()
self.server_args.enable_cache_report = False
self.model_path = "test-model"
# Mock tokenizer
self.tokenizer = Mock()
self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
self.tokenizer.decode = Mock(return_value="Test embedding input")
self.tokenizer.chat_template = None
self.tokenizer.bos_token_id = 1
# Mock generate_request method for embeddings
async def mock_generate_embedding():
yield {
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20, # 100-dim embedding
"meta_info": {
"id": f"embd-{uuid.uuid4()}",
"prompt_tokens": 5,
},
}
self.generate_request = Mock(return_value=mock_generate_embedding())
# Mock TemplateManager for embedding tests
class _MockTemplateManager:
def __init__(self):
self.chat_template_name = None # None for embeddings usually
self.jinja_template_content_format = None
self.completion_template_name = None
class ServingEmbeddingTestCase(unittest.TestCase):
    """Unit tests for OpenAIServingEmbedding._convert_to_internal_request.

    Covers every supported OpenAI embedding input form — single string,
    list of strings, token IDs, and multimodal inputs — and checks each is
    mapped into the internal EmbeddingReqInput structure.
    """

    def setUp(self):
        """Set up test fixtures."""
        self.tokenizer_manager = _MockTokenizerManager()
        self.template_manager = _MockTemplateManager()
        self.serving_embedding = OpenAIServingEmbedding(
            self.tokenizer_manager, self.template_manager
        )
        # Fake FastAPI request object; only `.headers` is ever read.
        self.request = Mock(spec=Request)
        self.request.headers = {}
        # One request fixture per supported input form.
        self.basic_req = EmbeddingRequest(
            model="test-model",
            input="Hello, how are you?",
            encoding_format="float",
        )
        self.list_req = EmbeddingRequest(
            model="test-model",
            input=["Hello, how are you?", "I am fine, thank you!"],
            encoding_format="float",
        )
        self.multimodal_req = EmbeddingRequest(
            model="test-model",
            input=[
                MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
                MultimodalEmbeddingInput(text="World", image=None),
            ],
            encoding_format="float",
        )
        self.token_ids_req = EmbeddingRequest(
            model="test-model",
            input=[1, 2, 3, 4, 5],
            encoding_format="float",
        )

    def test_convert_single_string_request(self):
        """Test converting single string request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.basic_req)
        )
        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(adapted_request.text, "Hello, how are you?")
        # self.assertEqual(adapted_request.rid, "test-id")
        self.assertEqual(processed_request, self.basic_req)

    def test_convert_list_string_request(self):
        """Test converting list of strings request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.list_req)
        )
        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        self.assertEqual(
            adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
        )
        # self.assertEqual(adapted_request.rid, "test-id")
        self.assertEqual(processed_request, self.list_req)

    def test_convert_token_ids_request(self):
        """Test converting token IDs request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.token_ids_req)
        )
        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        # Token-id prompts land in input_ids, not text.
        self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
        # self.assertEqual(adapted_request.rid, "test-id")
        self.assertEqual(processed_request, self.token_ids_req)

    def test_convert_multimodal_request(self):
        """Test converting multimodal request to internal format."""
        adapted_request, processed_request = (
            self.serving_embedding._convert_to_internal_request(self.multimodal_req)
        )
        self.assertIsInstance(adapted_request, EmbeddingReqInput)
        # Should extract text and images separately
        self.assertEqual(len(adapted_request.text), 2)
        self.assertIn("Hello", adapted_request.text)
        self.assertIn("World", adapted_request.text)
        # Image slots stay aligned with their originating inputs.
        self.assertEqual(adapted_request.image_data[0], "base64_image_data")
        self.assertIsNone(adapted_request.image_data[1])
        # self.assertEqual(adapted_request.rid, "test-id")
# Allow running this test module directly: `python3 <this_file>.py`.
if __name__ == "__main__":
    unittest.main(verbosity=2)

View File

@@ -0,0 +1,212 @@
import asyncio
import unittest
import openai
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestCacheReport(CustomTestCase):
    """End-to-end checks of the server's --enable-cache-report feature.

    Verifies that ``cached_tokens`` counters are reported accurately through
    both the native /generate endpoint and the OpenAI-compatible chat API,
    including prompts split across several chunked-prefill requests.
    """

    @classmethod
    def setUpClass(cls):
        # Small chunked-prefill size forces long prompts to be split across
        # several prefill requests; cache reporting is switched on.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.min_cached = 5
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=[
                "--chunked-prefill-size=40",
                "--enable-cache-report",
            ],
        )
        cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
        cls.aclient = openai.AsyncClient(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
        # Estimate the chat-template overhead from a one-token user message.
        usage = cls.run_openai(cls, "1").usage
        # we can assume that our request is of size 1, plus the total template size
        # ideally we would like to know the begin size / end size of the template to be more precise
        total_template_size = usage.prompt_tokens - 1
        print(f"template size: {total_template_size}")
        usage2 = cls.run_openai(cls, "2").usage
        assert usage2.prompt_tokens_details.cached_tokens <= total_template_size
        # Upper bound on tokens a fresh request may legitimately find cached
        # (template prefix); used as a tolerance by the assertions below.
        cls.min_cached = max(
            usage2.prompt_tokens_details.cached_tokens,
            total_template_size - usage2.prompt_tokens_details.cached_tokens,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
        """POST a raw /generate request and return the HTTP response."""
        response = requests.post(
            self.base_url + "/generate",
            # we use an uncommon start to minimise the chance that the cache is hit by chance
            json={
                "text": "_ The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        return response

    def run_openai(self, message):
        """Send one chat completion through the OpenAI-compatible client."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                # {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": message},
            ],
            temperature=0,
            max_tokens=100,
        )
        return response

    async def run_openai_async(self, message):
        """Async variant of run_openai, used by the (disabled) batch test."""
        response = await self.aclient.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "user", "content": message},
            ],
            temperature=0,
            max_tokens=100,
        )
        return response

    def cache_report_openai(self, message):
        """Send *message* twice; assert the second reports prompt tokens cached."""
        response = self.run_openai(message)
        print(
            f"openai first request cached_tokens: {int(response.usage.prompt_tokens_details.cached_tokens)}"
        )
        first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        # assert int(response.usage.cached_tokens) == 0
        assert first_cached_tokens <= self.min_cached
        response = self.run_openai(message)
        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        print(f"openai second request cached_tokens: {cached_tokens}")
        assert cached_tokens > 0
        # All but the final prompt token should be served from cache.
        assert cached_tokens == int(response.usage.prompt_tokens) - 1
        return first_cached_tokens

    async def cache_report_openai_async(self, message):
        """Return (cached_tokens, prompt_tokens) for one async chat request."""
        response = await self.run_openai_async(message)
        cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
        prompt_tokens = int(response.usage.prompt_tokens)
        return cached_tokens, prompt_tokens

    def test_generate(self):
        """Cache counters on the native /generate endpoint grow on repeat."""
        print("=" * 100)
        response = self.run_decode()
        # print(response.json())
        cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
        print(f"sglang first request cached_tokens: {cached_tokens}")
        print(
            f"sglang first request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
        )
        # can't assure to be 0: depends on the initialisation request / if a template is used with the model
        assert cached_tokens < self.min_cached
        response = self.run_decode()
        cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
        print(f"sglang second request cached_tokens: {cached_tokens}")
        print(
            f"sglang second request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
        )
        assert cached_tokens == int(response.json()["meta_info"]["prompt_tokens"]) - 1

    def test_cache_split_prefill_openai(self):
        """A prompt longer than chunked-prefill-size still reports correctly."""
        print("=" * 100)
        self.cache_report_openai(
            "€ This is a very long and unique text that should not be already cached, the twist is"
            " that it should be longer than the chunked-prefill-size, so it should be split among"
            " several prefill requests. Still, it shouldn't be cached"
        )

    def test_cache_report_openai(self):
        """Distinct prompts share only the template-sized cached prefix."""
        print("=" * 100)
        # warm up the cache, for the template
        self.run_openai("Introduce the capital of France.")
        first_cached_tokens_1 = self.run_openai(
            "How many sparrow do you need to lift a coconut?"
        ).usage.prompt_tokens_details.cached_tokens
        usage_2 = self.run_openai("* sing something about cats").usage
        first_cached_tokens_2 = usage_2.prompt_tokens_details.cached_tokens
        # first request may not have 0 cached tokens, but if they only have the template in common they
        # should be the same once the cache is warmed up
        assert first_cached_tokens_1 == first_cached_tokens_2
        resp = self.run_openai("* sing something about cats and dogs")
        print(resp.usage)
        resp = self.run_openai("* sing something about cats, please")
        print(resp.usage)
        assert (
            resp.usage.prompt_tokens_details.cached_tokens
            >= usage_2.prompt_tokens - self.min_cached
        )

    # TODO: flaky test
    # def test_cache_report_openai_async(self):
    #     print("=" * 100)
    #     async def run_test():
    #         task0 = asyncio.create_task(
    #             self.cache_report_openai_async(
    #                 "first request, to start the inference and let the next two request be started in the same batch"
    #             )
    #         )
    #         await asyncio.sleep(1)  # to force the first request to be started first
    #         task1 = asyncio.create_task(
    #             self.cache_report_openai_async(
    #                 "> can the same batch parallel request use the cache?"
    #             )
    #         )
    #         task2 = asyncio.create_task(
    #             self.cache_report_openai_async(
    #                 "> can the same batch parallel request use the cache?"
    #             )
    #         )
    #         result0, result1, result2 = await asyncio.gather(task0, task1, task2)
    #         cached_tokens0, prompt_tokens0 = result0
    #         cached_tokens1, prompt_tokens1 = result1
    #         cached_tokens2, prompt_tokens2 = result2
    #         print(
    #             f"Async request 0 - Cached tokens: {cached_tokens0}, Prompt tokens: {prompt_tokens0}"
    #         )
    #         print(
    #             f"Async request 1 - Cached tokens: {cached_tokens1}, Prompt tokens: {prompt_tokens1}"
    #         )
    #         print(
    #             f"Async request 2 - Cached tokens: {cached_tokens2}, Prompt tokens: {prompt_tokens2}"
    #         )
    #         # Assert that no requests used the cache (because first is alone, and the next two are in the same batch)
    #         # If a new optimisation limiting starting request with same prefix at the same time was added
    #         # to maximise the cache hit, this would not be true
    #         assert cached_tokens1 == cached_tokens2 == cached_tokens0
    #     asyncio.run(run_test())
# Allow running this test module directly: `python3 <this_file>.py`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,243 @@
"""
Usage:
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_with_reasoning
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_without_reasoning
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_with_reasoning
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_without_reasoning
"""
import asyncio
import json
import os
import sys
import time
import unittest
import openai
import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestEnableThinking(CustomTestCase):
    """Tests for ``chat_template_kwargs["enable_thinking"]`` with the qwen3
    reasoning parser.

    ``reasoning_content`` must be present in (streaming and non-streaming)
    responses exactly when thinking is enabled, and absent otherwise.
    """

    @classmethod
    def setUpClass(cls):
        # Launch a server with the qwen3 reasoning parser so the response is
        # split into reasoning_content and content.
        cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-1234"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--reasoning-parser",
                "qwen3",
            ],
        )
        # Extra request kwargs; subclasses may add e.g. tool definitions.
        cls.additional_chat_kwargs = {}

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_chat_completion_with_reasoning(self):
        """Non-streaming with enable_thinking=True yields reasoning_content."""
        # Test non-streaming with "enable_thinking": True, reasoning_content should not be empty
        client = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0,
                "separate_reasoning": True,
                "chat_template_kwargs": {"enable_thinking": True},
                **self.additional_chat_kwargs,
            },
        )
        self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
        data = client.json()
        self.assertIn("choices", data)
        self.assertTrue(len(data["choices"]) > 0)
        self.assertIn("message", data["choices"][0])
        self.assertIn("reasoning_content", data["choices"][0]["message"])
        self.assertIsNotNone(data["choices"][0]["message"]["reasoning_content"])

    def test_chat_completion_without_reasoning(self):
        """Non-streaming with enable_thinking=False yields no reasoning_content."""
        # Test non-streaming with "enable_thinking": False, reasoning_content should be empty
        client = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0,
                "separate_reasoning": True,
                "chat_template_kwargs": {"enable_thinking": False},
                **self.additional_chat_kwargs,
            },
        )
        self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
        data = client.json()
        self.assertIn("choices", data)
        self.assertTrue(len(data["choices"]) > 0)
        self.assertIn("message", data["choices"][0])
        # The key may be absent entirely; if present it must be None.
        if "reasoning_content" in data["choices"][0]["message"]:
            self.assertIsNone(data["choices"][0]["message"]["reasoning_content"])

    def test_stream_chat_completion_with_reasoning(self):
        """Streaming with enable_thinking=True emits reasoning + content deltas."""
        # Test streaming with "enable_thinking": True, reasoning_content should not be empty
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0,
                "separate_reasoning": True,
                "stream": True,
                "chat_template_kwargs": {"enable_thinking": True},
                **self.additional_chat_kwargs,
            },
            stream=True,
        )
        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
        has_reasoning = False
        has_content = False
        print("\n=== Stream With Reasoning ===")
        # Scan every SSE delta for reasoning_content / content fields.
        for line in response.iter_lines():
            if line:
                line = line.decode("utf-8")
                if line.startswith("data:") and not line.startswith("data: [DONE]"):
                    data = json.loads(line[6:])
                    if "choices" in data and len(data["choices"]) > 0:
                        delta = data["choices"][0].get("delta", {})
                        if "reasoning_content" in delta and delta["reasoning_content"]:
                            has_reasoning = True
                        if "content" in delta and delta["content"]:
                            has_content = True
        self.assertTrue(
            has_reasoning,
            "The reasoning content is not included in the stream response",
        )
        self.assertTrue(
            has_content, "The stream response does not contain normal content"
        )

    def test_stream_chat_completion_without_reasoning(self):
        """Streaming with enable_thinking=False emits content deltas only."""
        # Test streaming with "enable_thinking": False, reasoning_content should be empty
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0,
                "separate_reasoning": True,
                "stream": True,
                "chat_template_kwargs": {"enable_thinking": False},
                **self.additional_chat_kwargs,
            },
            stream=True,
        )
        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
        has_reasoning = False
        has_content = False
        print("\n=== Stream Without Reasoning ===")
        # Scan every SSE delta for reasoning_content / content fields.
        for line in response.iter_lines():
            if line:
                line = line.decode("utf-8")
                if line.startswith("data:") and not line.startswith("data: [DONE]"):
                    data = json.loads(line[6:])
                    if "choices" in data and len(data["choices"]) > 0:
                        delta = data["choices"][0].get("delta", {})
                        if "reasoning_content" in delta and delta["reasoning_content"]:
                            has_reasoning = True
                        if "content" in delta and delta["content"]:
                            has_content = True
        self.assertFalse(
            has_reasoning,
            "The reasoning content should not be included in the stream response",
        )
        self.assertTrue(
            has_content, "The stream response does not contain normal content"
        )
# Skip for ci test
# class TestGLM45EnableThinking(TestEnableThinking):
# @classmethod
# def setUpClass(cls):
# # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# cls.model = "THUDM/GLM-4.5"
# cls.base_url = DEFAULT_URL_FOR_TEST
# cls.api_key = "sk-1234"
# cls.process = popen_launch_server(
# cls.model,
# cls.base_url,
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
# api_key=cls.api_key,
# other_args=[
# "--tool-call-parser",
# "glm45",
# "--reasoning-parser",
# "glm45",
# "--tp-size",
# "8"
# ],
# )
# # Validate whether enable-thinking conflict with tool_calls
# cls.additional_chat_kwargs = {
# "tools": [
# {
# "type": "function",
# "function": {
# "name": "add",
# "description": "Compute the sum of two numbers",
# "parameters": {
# "type": "object",
# "properties": {
# "a": {
# "type": "int",
# "description": "A number",
# },
# "b": {
# "type": "int",
# "description": "A number",
# },
# },
# "required": ["a", "b"],
# },
# },
# }
# ]
# }
# Allow running this test module directly: `python3 <this_file>.py`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,153 @@
"""
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
"""
import json
import unittest
from concurrent.futures import ThreadPoolExecutor
import openai
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
def setup_class(cls, backend: str):
    """Attach shared fixtures to *cls* and launch a server using *backend*.

    Stores the model name, base URL, and the JSON schema that the
    constrained-decoding tests ask the server to enforce, then starts a
    server process with the requested grammar backend.
    """
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST
    # Schema used by every test in the suite: a name/population object.
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
        "additionalProperties": False,
    }
    cls.json_schema = json.dumps(schema)
    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=["--max-running-requests", "10", "--grammar-backend", backend],
    )
class TestJSONConstrainedOutlinesBackend(CustomTestCase):
    """JSON-constrained generation tests using the outlines grammar backend.

    Subclasses re-run the same tests against other grammar backends by
    overriding setUpClass only.
    """

    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="outlines")

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
        """POST /generate with *json_schema* and validate the JSON output."""
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                    "json_schema": json_schema,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        ret = response.json()
        print(json.dumps(ret))
        print("=" * 100)
        # Unconstrained or deliberately invalid schema: nothing to validate.
        if not json_schema or json_schema == "INVALID":
            return
        # Make sure the json output is valid
        try:
            js_obj = json.loads(ret["text"])
        except (TypeError, json.decoder.JSONDecodeError):
            raise
        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)

    def test_json_generate(self):
        """Output constrained by the fixture schema parses and matches types."""
        self.run_decode(json_schema=self.json_schema)

    def test_json_invalid(self):
        """An invalid schema string must not crash the server."""
        self.run_decode(json_schema="INVALID")

    def test_json_openai(self):
        """The OpenAI response_format=json_schema path yields valid JSON."""
        client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "Introduce the capital of France. Return in a JSON format.",
                },
            ],
            temperature=0,
            max_tokens=128,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
            },
        )
        text = response.choices[0].message.content
        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)

    def test_mix_json_and_other(self):
        """Constrained and unconstrained requests can run concurrently."""
        json_schemas = [None, None, self.json_schema, self.json_schema] * 10
        with ThreadPoolExecutor(len(json_schemas)) as executor:
            list(executor.map(self.run_decode, json_schemas))
class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
    """Same JSON-constrained tests, run against the xgrammar backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="xgrammar")
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
    """Same JSON-constrained tests, run against the llguidance backend."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="llguidance")
# Allow running this test module directly: `python3 <this_file>.py`.
if __name__ == "__main__":
    unittest.main()

Some files were not shown because too many files have changed in this diff Show More