adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
41
test/README.md
Normal file
41
test/README.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# Run Unit Tests
|
||||
|
||||
SGLang uses the built-in library [unittest](https://docs.python.org/3/library/unittest.html) as the testing framework.
|
||||
|
||||
## Test Backend Runtime
|
||||
```bash
|
||||
cd sglang/test/srt
|
||||
|
||||
# Run a single file
|
||||
python3 test_srt_endpoint.py
|
||||
|
||||
# Run a single test
|
||||
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
|
||||
|
||||
# Run a suite with multiple files
|
||||
python3 run_suite.py --suite per-commit
|
||||
```
|
||||
|
||||
## Test Frontend Language
|
||||
```bash
|
||||
cd sglang/test/lang
|
||||
|
||||
# Run a single file
|
||||
python3 test_srt_backend.py
|
||||
```
|
||||
|
||||
## Adding or Updating Tests in CI
|
||||
|
||||
- Create new test files under `test/srt` or `test/lang` depending on the type of test.
|
||||
- Ensure they are referenced in the respective `run_suite.py` (e.g., `test/srt/run_suite.py`) so they’re picked up in CI. For most small test cases, they can be added to the `per-commit` suite. Sort the test cases alphabetically.
|
||||
- The CI will run the `per-commit` and `nightly` automatically. If you need special setup or custom test groups, you may modify the workflows in [`.github/workflows/`](https://github.com/sgl-project/sglang/tree/main/.github/workflows).
|
||||
|
||||
|
||||
## Writing Elegant Test Cases
|
||||
|
||||
- Examine existing tests in [sglang/test](https://github.com/sgl-project/sglang/tree/main/test) for practical examples.
|
||||
- Keep each test function focused on a single scenario or piece of functionality.
|
||||
- Give tests descriptive names reflecting their purpose.
|
||||
- Use robust assertions (e.g., assert, unittest methods) to validate outcomes.
|
||||
- Clean up resources to avoid side effects and preserve test independence.
|
||||
- Reduce the test time by using smaller models and reusing the server for multiple test cases.
|
||||
BIN
test/lang/example_image.png
Normal file
BIN
test/lang/example_image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 56 KiB |
38
test/lang/run_suite.py
Normal file
38
test/lang/run_suite.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
from sglang.test.test_utils import TestFile, run_unittest_files
|
||||
|
||||
suites = {
|
||||
"per-commit": [
|
||||
TestFile("test_srt_backend.py"),
|
||||
# Skip this due to some OPENAI_API_KEY issues
|
||||
# "test_openai_backend.py",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
arg_parser = argparse.ArgumentParser()
|
||||
arg_parser.add_argument(
|
||||
"--timeout-per-file",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="The time limit for running one file in seconds.",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"--suite",
|
||||
type=str,
|
||||
default=list(suites.keys())[0],
|
||||
choices=list(suites.keys()) + ["all"],
|
||||
help="The suite to run",
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
if args.suite == "all":
|
||||
files = glob.glob("**/test_*.py", recursive=True)
|
||||
else:
|
||||
files = suites[args.suite]
|
||||
|
||||
exit_code = run_unittest_files(files, args.timeout_per_file)
|
||||
exit(exit_code)
|
||||
25
test/lang/test_anthropic_backend.py
Normal file
25
test/lang/test_anthropic_backend.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from sglang import Anthropic, set_default_backend
|
||||
from sglang.test.test_programs import test_mt_bench, test_stream
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestAnthropicBackend(CustomTestCase):
|
||||
backend = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.backend = Anthropic("claude-3-haiku-20240307")
|
||||
set_default_backend(cls.backend)
|
||||
|
||||
def test_mt_bench(self):
|
||||
test_mt_bench()
|
||||
|
||||
def test_stream(self):
|
||||
test_stream()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
51
test/lang/test_bind_cache.py
Normal file
51
test/lang/test_bind_cache.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import unittest
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase
|
||||
|
||||
|
||||
class TestBind(CustomTestCase):
|
||||
backend = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST)
|
||||
sgl.set_default_backend(cls.backend)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.backend.shutdown()
|
||||
|
||||
def test_bind(self):
|
||||
@sgl.function
|
||||
def few_shot_qa(s, prompt, question):
|
||||
s += prompt
|
||||
s += "Q: What is the capital of France?\n"
|
||||
s += "A: Paris\n"
|
||||
s += "Q: " + question + "\n"
|
||||
s += "A:" + sgl.gen("answer", stop="\n")
|
||||
|
||||
few_shot_qa_2 = few_shot_qa.bind(
|
||||
prompt="The following are questions with answers.\n\n"
|
||||
)
|
||||
|
||||
tracer = few_shot_qa_2.trace()
|
||||
print(tracer.last_node.print_graph_dfs() + "\n")
|
||||
|
||||
def test_cache(self):
|
||||
@sgl.function
|
||||
def few_shot_qa(s, prompt, question):
|
||||
s += prompt
|
||||
s += "Q: What is the capital of France?\n"
|
||||
s += "A: Paris\n"
|
||||
s += "Q: " + question + "\n"
|
||||
s += "A:" + sgl.gen("answer", stop="\n")
|
||||
|
||||
few_shot_qa_2 = few_shot_qa.bind(
|
||||
prompt="Answer the following questions as if you were a 5-year-old kid.\n\n"
|
||||
)
|
||||
few_shot_qa_2.cache(self.backend)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
91
test/lang/test_choices.py
Normal file
91
test/lang/test_choices.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.lang.choices import (
|
||||
greedy_token_selection,
|
||||
token_length_normalized,
|
||||
unconditional_likelihood_normalized,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
MOCK_CHOICES_INPUT_DATA = {
|
||||
"choices": [
|
||||
"organ", # ["organ"]
|
||||
"organism", # ["organ", "ism"]
|
||||
"antidisestablishmentarianism", # ["ant", "id", "is", "est", "ablish", "ment", "arian", "ism"]
|
||||
],
|
||||
"normalized_prompt_logprobs": [-0.1, -0.2, -0.05],
|
||||
"input_token_logprobs": [
|
||||
[[-0.1, 1, None]],
|
||||
[[-0.1, 1, None], [-0.3, 2, None]],
|
||||
[
|
||||
[-0.4, 3, None],
|
||||
[-0.25, 4, None],
|
||||
[-0.1, 5, None],
|
||||
[-0.01, 6, None],
|
||||
[-0.01, 7, None],
|
||||
[-0.01, 8, None],
|
||||
[-0.01, 9, None],
|
||||
[-0.01, 2, None],
|
||||
],
|
||||
],
|
||||
"output_token_logprobs": [
|
||||
[[-0.1, 10, None]],
|
||||
[[-0.1, 10, None]],
|
||||
[[-0.1, 10, None]],
|
||||
],
|
||||
"unconditional_token_logprobs": [
|
||||
[[None, 1, None]],
|
||||
[[None, 1, None], [-1.4, 2, None]],
|
||||
[
|
||||
[None, 3, None],
|
||||
[-0.25, 4, None],
|
||||
[-0.1, 5, None],
|
||||
[-0.01, 6, None],
|
||||
[-0.01, 7, None],
|
||||
[-0.01, 8, None],
|
||||
[-0.01, 9, None],
|
||||
[-0.01, 2, None],
|
||||
],
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class TestChoices(CustomTestCase):
|
||||
|
||||
def test_token_length_normalized(self):
|
||||
"""Confirm 'antidisestablishmentarianism' is selected due to high confidences for
|
||||
its later tokens resulting in highest token length normalized prompt logprob."""
|
||||
decision = token_length_normalized(**MOCK_CHOICES_INPUT_DATA)
|
||||
assert decision.decision == "antidisestablishmentarianism"
|
||||
|
||||
def test_greedy_token_selection(self):
|
||||
"""Confirm 'organ' is selected due it having the joint highest initial token
|
||||
logprob, and a higher average logprob than organism's second token."""
|
||||
decision = greedy_token_selection(**MOCK_CHOICES_INPUT_DATA)
|
||||
assert decision.decision == "organ"
|
||||
assert np.allclose(
|
||||
decision.meta_info["greedy_logprob_matrix"],
|
||||
[
|
||||
[-0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1],
|
||||
[-0.1, -0.3, -0.2, -0.2, -0.2, -0.2, -0.2, -0.2],
|
||||
[-0.4, -0.25, -0.1, -0.01, -0.01, -0.01, -0.01, -0.01],
|
||||
],
|
||||
atol=0.01,
|
||||
)
|
||||
|
||||
def test_unconditional_likelihood_normalized(self):
|
||||
"""Confirm 'organism' is selected due to it having the highest average token logprob
|
||||
once normalized by the unconditional token logprobs."""
|
||||
decision = unconditional_likelihood_normalized(**MOCK_CHOICES_INPUT_DATA)
|
||||
assert decision.decision == "organism"
|
||||
assert np.allclose(
|
||||
decision.meta_info["normalized_unconditional_prompt_logprobs"],
|
||||
[-0.1, 0.5, -0.05],
|
||||
atol=0.01,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
25
test/lang/test_litellm_backend.py
Normal file
25
test/lang/test_litellm_backend.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import json
|
||||
import unittest
|
||||
|
||||
from sglang import LiteLLM, set_default_backend
|
||||
from sglang.test.test_programs import test_mt_bench, test_stream
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestAnthropicBackend(CustomTestCase):
|
||||
chat_backend = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.chat_backend = LiteLLM("gpt-3.5-turbo")
|
||||
set_default_backend(cls.chat_backend)
|
||||
|
||||
def test_mt_bench(self):
|
||||
test_mt_bench()
|
||||
|
||||
def test_stream(self):
|
||||
test_stream()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
92
test/lang/test_openai_backend.py
Normal file
92
test/lang/test_openai_backend.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import unittest
|
||||
|
||||
from sglang import OpenAI, set_default_backend
|
||||
from sglang.test.test_programs import (
|
||||
test_chat_completion_speculative,
|
||||
test_completion_speculative,
|
||||
test_decode_int,
|
||||
test_decode_json,
|
||||
test_expert_answer,
|
||||
test_few_shot_qa,
|
||||
test_image_qa,
|
||||
test_mt_bench,
|
||||
test_parallel_decoding,
|
||||
test_parallel_encoding,
|
||||
test_react,
|
||||
test_select,
|
||||
test_stream,
|
||||
test_tool_use,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestOpenAIBackend(CustomTestCase):
|
||||
instruct_backend = None
|
||||
chat_backend = None
|
||||
chat_vision_backend = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.instruct_backend = OpenAI("gpt-3.5-turbo-instruct")
|
||||
cls.chat_backend = OpenAI("gpt-3.5-turbo")
|
||||
cls.chat_vision_backend = OpenAI("gpt-4-turbo")
|
||||
|
||||
def test_few_shot_qa(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_few_shot_qa()
|
||||
|
||||
def test_mt_bench(self):
|
||||
set_default_backend(self.chat_backend)
|
||||
test_mt_bench()
|
||||
|
||||
def test_select(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_select(check_answer=True)
|
||||
|
||||
def test_decode_int(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_decode_int()
|
||||
|
||||
def test_decode_json(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_decode_json()
|
||||
|
||||
def test_expert_answer(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_expert_answer()
|
||||
|
||||
def test_tool_use(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_tool_use()
|
||||
|
||||
def test_react(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_react()
|
||||
|
||||
def test_parallel_decoding(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_parallel_decoding()
|
||||
|
||||
def test_parallel_encoding(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_parallel_encoding()
|
||||
|
||||
def test_image_qa(self):
|
||||
set_default_backend(self.chat_vision_backend)
|
||||
test_image_qa()
|
||||
|
||||
def test_stream(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_stream()
|
||||
|
||||
def test_completion_speculative(self):
|
||||
set_default_backend(self.instruct_backend)
|
||||
test_completion_speculative()
|
||||
|
||||
def test_chat_completion_speculative(self):
|
||||
set_default_backend(self.chat_backend)
|
||||
test_chat_completion_speculative()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
68
test/lang/test_separate_reasoning.py
Normal file
68
test/lang/test_separate_reasoning.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Tests for the separate_reasoning functionality in sglang.
|
||||
|
||||
Usage:
|
||||
python3 -m unittest test/lang/test_separate_reasoning.py
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from sglang import assistant, gen, separate_reasoning, user
|
||||
from sglang.lang.ir import SglExprList, SglSeparateReasoning
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestSeparateReasoning(CustomTestCase):
|
||||
def test_separate_reasoning_creation(self):
|
||||
"""Test that SglSeparateReasoning objects are created correctly."""
|
||||
# Test with valid model type and gen expression
|
||||
test_gen = gen("test")
|
||||
expr = separate_reasoning(test_gen, model_type="deepseek-r1")
|
||||
self.assertIsInstance(expr, SglExprList)
|
||||
self.assertEqual(len(expr.expr_list), 2)
|
||||
self.assertEqual(expr.expr_list[0], test_gen)
|
||||
reasoning_expr = expr.expr_list[1]
|
||||
self.assertIsInstance(reasoning_expr, SglSeparateReasoning)
|
||||
self.assertEqual(reasoning_expr.model_type, "deepseek-r1")
|
||||
self.assertEqual(reasoning_expr.name, "test_reasoning_content")
|
||||
|
||||
# Test with another valid model type
|
||||
expr = separate_reasoning(test_gen, model_type="qwen3")
|
||||
self.assertIsInstance(expr, SglExprList)
|
||||
self.assertEqual(expr.expr_list[1].model_type, "qwen3")
|
||||
|
||||
def test_separate_reasoning_name_processing(self):
|
||||
"""Test that separate_reasoning correctly processes names."""
|
||||
test_gen = gen("test_var")
|
||||
expr = separate_reasoning(test_gen, model_type="deepseek-r1")
|
||||
reasoning_expr = expr.expr_list[1]
|
||||
self.assertEqual(reasoning_expr.name, "test_var_reasoning_content")
|
||||
|
||||
# Test the process_name_for_reasoning method
|
||||
self.assertEqual(
|
||||
reasoning_expr.process_name_for_reasoning("another_var"),
|
||||
"another_var_reasoning_content",
|
||||
)
|
||||
|
||||
def test_separate_reasoning_repr(self):
|
||||
"""Test the string representation of SglSeparateReasoning."""
|
||||
test_gen = gen("test_var")
|
||||
expr = separate_reasoning(test_gen, model_type="deepseek-r1")
|
||||
reasoning_expr = expr.expr_list[1]
|
||||
self.assertEqual(
|
||||
repr(reasoning_expr),
|
||||
"SeparateReasoning(model_type=deepseek-r1, name=test_var_reasoning_content)",
|
||||
)
|
||||
|
||||
def test_separate_reasoning_with_invalid_model_type(self):
|
||||
"""Test that separate_reasoning accepts any model type during creation."""
|
||||
# Create with invalid model type
|
||||
test_gen = gen("test")
|
||||
expr = separate_reasoning(test_gen, model_type="invalid-model")
|
||||
self.assertIsInstance(expr, SglExprList)
|
||||
self.assertEqual(expr.expr_list[1].model_type, "invalid-model")
|
||||
# The actual validation happens in the ReasoningParser constructor
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
195
test/lang/test_separate_reasoning_execution.py
Normal file
195
test/lang/test_separate_reasoning_execution.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
Tests for the execution of separate_reasoning functionality in sglang.
|
||||
|
||||
Usage:
|
||||
python3 -m unittest test/lang/test_separate_reasoning_execution.py
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from sglang import assistant, gen, separate_reasoning, user
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglGen, SglSeparateReasoning
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
# Helper function to create events that won't block program exit
|
||||
def create_daemon_event():
|
||||
event = threading.Event()
|
||||
return event
|
||||
|
||||
|
||||
class MockReasoningParser:
|
||||
def __init__(self, model_type):
|
||||
self.model_type = model_type
|
||||
self.parse_non_stream_called = False
|
||||
self.parse_stream_chunk_called = False
|
||||
|
||||
def parse_non_stream(self, full_text):
|
||||
self.parse_non_stream_called = True
|
||||
# Simulate parsing by adding a prefix to indicate reasoning
|
||||
reasoning = f"[REASONING from {self.model_type}]: {full_text}"
|
||||
normal_text = f"[NORMAL from {self.model_type}]: {full_text}"
|
||||
return reasoning, normal_text
|
||||
|
||||
def parse_stream_chunk(self, chunk_text):
|
||||
self.parse_stream_chunk_called = True
|
||||
# Simulate parsing by adding a prefix to indicate reasoning
|
||||
reasoning = f"[REASONING from {self.model_type}]: {chunk_text}"
|
||||
normal_text = f"[NORMAL from {self.model_type}]: {chunk_text}"
|
||||
return reasoning, normal_text
|
||||
|
||||
|
||||
class TestSeparateReasoningExecution(CustomTestCase):
|
||||
def setUp(self):
|
||||
"""Set up for the test."""
|
||||
super().setUp()
|
||||
# Store any events created during the test
|
||||
self.events = []
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up any threads that might have been created during the test."""
|
||||
super().tearDown()
|
||||
|
||||
# Set all events to ensure any waiting threads are released
|
||||
for event in self.events:
|
||||
event.set()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
# wake up all threads
|
||||
for ev in self.events:
|
||||
ev.set()
|
||||
|
||||
@patch("sglang.srt.parser.reasoning_parser.ReasoningParser")
|
||||
def test_execute_separate_reasoning(self, mock_parser_class):
|
||||
"""Test that _execute_separate_reasoning correctly calls the ReasoningParser."""
|
||||
# Setup mock parser
|
||||
mock_parser = MockReasoningParser("deepseek-r1")
|
||||
mock_parser_class.return_value = mock_parser
|
||||
|
||||
# Create a mock backend to avoid AttributeError in __del__
|
||||
mock_backend = MagicMock()
|
||||
|
||||
# Create a StreamExecutor with necessary setup
|
||||
executor = StreamExecutor(
|
||||
backend=mock_backend,
|
||||
arguments={},
|
||||
default_sampling_para={},
|
||||
chat_template={
|
||||
"role_map": {"user": "user", "assistant": "assistant"}
|
||||
}, # Simple chat template
|
||||
stream=False,
|
||||
use_thread=False,
|
||||
)
|
||||
|
||||
# Set up the executor with a variable and its value
|
||||
var_name = "test_var"
|
||||
reasoning_name = f"{var_name}_reasoning_content"
|
||||
var_value = "Test content"
|
||||
executor.variables = {var_name: var_value}
|
||||
|
||||
# Create events and track them for cleanup
|
||||
var_event = create_daemon_event()
|
||||
reasoning_event = create_daemon_event()
|
||||
self.events.extend([var_event, reasoning_event])
|
||||
|
||||
executor.variable_event = {var_name: var_event, reasoning_name: reasoning_event}
|
||||
executor.variable_event[var_name].set() # Mark as ready
|
||||
|
||||
# Set up the current role
|
||||
executor.cur_role = "assistant"
|
||||
executor.cur_role_begin_pos = 0
|
||||
executor.text_ = var_value
|
||||
|
||||
# Create a gen expression and a separate_reasoning expression
|
||||
gen_expr = SglGen(var_name)
|
||||
expr = SglSeparateReasoning("deepseek-r1", expr=gen_expr)
|
||||
|
||||
# Execute separate_reasoning
|
||||
executor._execute_separate_reasoning(expr)
|
||||
|
||||
# Verify that the parser was created with the correct model type
|
||||
mock_parser_class.assert_called_once_with("deepseek-r1")
|
||||
|
||||
# Verify that parse_non_stream was called
|
||||
self.assertTrue(mock_parser.parse_non_stream_called)
|
||||
|
||||
# Verify that the variables were updated correctly
|
||||
reasoning_name = f"{var_name}_reasoning_content"
|
||||
self.assertIn(reasoning_name, executor.variables)
|
||||
self.assertEqual(
|
||||
executor.variables[reasoning_name],
|
||||
f"[REASONING from deepseek-r1]: {var_value}",
|
||||
)
|
||||
self.assertEqual(
|
||||
executor.variables[var_name], f"[NORMAL from deepseek-r1]: {var_value}"
|
||||
)
|
||||
|
||||
# Verify that the variable event was set
|
||||
self.assertIn(reasoning_name, executor.variable_event)
|
||||
self.assertTrue(executor.variable_event[reasoning_name].is_set())
|
||||
|
||||
# Verify that the text was updated
|
||||
self.assertEqual(executor.text_, f"[NORMAL from deepseek-r1]: {var_value}")
|
||||
|
||||
@patch("sglang.srt.parser.reasoning_parser.ReasoningParser")
|
||||
def test_reasoning_parser_integration(self, mock_parser_class):
|
||||
"""Test the integration between separate_reasoning and ReasoningParser."""
|
||||
# Setup mock parsers for different model types
|
||||
deepseek_parser = MockReasoningParser("deepseek-r1")
|
||||
qwen_parser = MockReasoningParser("qwen3")
|
||||
|
||||
# Configure the mock to return different parsers based on model type
|
||||
def get_parser(model_type):
|
||||
if model_type == "deepseek-r1":
|
||||
return deepseek_parser
|
||||
elif model_type == "qwen3":
|
||||
return qwen_parser
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
|
||||
mock_parser_class.side_effect = get_parser
|
||||
|
||||
# Test with DeepSeek-R1 model
|
||||
test_text = "This is a test"
|
||||
reasoning, normal_text = deepseek_parser.parse_non_stream(test_text)
|
||||
|
||||
self.assertEqual(reasoning, f"[REASONING from deepseek-r1]: {test_text}")
|
||||
self.assertEqual(normal_text, f"[NORMAL from deepseek-r1]: {test_text}")
|
||||
|
||||
# Test with Qwen3 model
|
||||
reasoning, normal_text = qwen_parser.parse_non_stream(test_text)
|
||||
|
||||
self.assertEqual(reasoning, f"[REASONING from qwen3]: {test_text}")
|
||||
self.assertEqual(normal_text, f"[NORMAL from qwen3]: {test_text}")
|
||||
|
||||
@patch("sglang.srt.parser.reasoning_parser.ReasoningParser")
|
||||
def test_reasoning_parser_invalid_model(self, mock_parser_class):
|
||||
"""Test that ReasoningParser raises an error for invalid model types."""
|
||||
|
||||
# Configure the mock to raise an error for invalid model types
|
||||
def get_parser(model_type):
|
||||
if model_type in ["deepseek-r1", "qwen3"]:
|
||||
return MockReasoningParser(model_type)
|
||||
elif model_type is None:
|
||||
raise ValueError("Model type must be specified")
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {model_type}")
|
||||
|
||||
mock_parser_class.side_effect = get_parser
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
mock_parser_class("invalid-model")
|
||||
self.assertIn("Unsupported model type", str(context.exception))
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
mock_parser_class(None)
|
||||
self.assertIn("Model type must be specified", str(context.exception))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
86
test/lang/test_srt_backend.py
Normal file
86
test/lang/test_srt_backend.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens
|
||||
python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_programs import (
|
||||
test_decode_int,
|
||||
test_decode_json_regex,
|
||||
test_dtype_gen,
|
||||
test_expert_answer,
|
||||
test_few_shot_qa,
|
||||
test_gen_min_new_tokens,
|
||||
test_hellaswag_select,
|
||||
test_mt_bench,
|
||||
test_parallel_decoding,
|
||||
test_regex,
|
||||
test_select,
|
||||
test_stream,
|
||||
test_tool_use,
|
||||
)
|
||||
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase
|
||||
|
||||
|
||||
class TestSRTBackend(CustomTestCase):
|
||||
backend = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.backend = sgl.Runtime(
|
||||
model_path=DEFAULT_MODEL_NAME_FOR_TEST, cuda_graph_max_bs=4
|
||||
)
|
||||
sgl.set_default_backend(cls.backend)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.backend.shutdown()
|
||||
|
||||
def test_few_shot_qa(self):
|
||||
test_few_shot_qa()
|
||||
|
||||
def test_mt_bench(self):
|
||||
test_mt_bench()
|
||||
|
||||
def test_select(self):
|
||||
test_select(check_answer=False)
|
||||
|
||||
def test_decode_int(self):
|
||||
test_decode_int()
|
||||
|
||||
def test_decode_json_regex(self):
|
||||
test_decode_json_regex()
|
||||
|
||||
def test_expert_answer(self):
|
||||
test_expert_answer()
|
||||
|
||||
def test_tool_use(self):
|
||||
test_tool_use()
|
||||
|
||||
def test_parallel_decoding(self):
|
||||
test_parallel_decoding()
|
||||
|
||||
def test_stream(self):
|
||||
test_stream()
|
||||
|
||||
def test_regex(self):
|
||||
test_regex()
|
||||
|
||||
def test_dtype_gen(self):
|
||||
test_dtype_gen()
|
||||
|
||||
def test_hellaswag_select(self):
|
||||
# Run twice to capture more bugs
|
||||
for _ in range(2):
|
||||
accuracy, latency = test_hellaswag_select()
|
||||
self.assertGreater(accuracy, 0.60)
|
||||
|
||||
def test_gen_min_new_tokens(self):
|
||||
test_gen_min_new_tokens()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
129
test/lang/test_tracing.py
Normal file
129
test/lang/test_tracing.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import unittest
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestTracing(CustomTestCase):
|
||||
def test_few_shot_qa(self):
|
||||
@sgl.function
|
||||
def few_shot_qa(s, question):
|
||||
s += "The following are questions with answers.\n\n"
|
||||
s += "Q: What is the capital of France?\n"
|
||||
s += "A: Paris\n"
|
||||
s += "Q: " + question + "\n"
|
||||
s += "A:" + sgl.gen("answer", stop="\n")
|
||||
|
||||
tracer = few_shot_qa.trace()
|
||||
# print(tracer.last_node.print_graph_dfs() + "\n")
|
||||
|
||||
def test_select(self):
|
||||
@sgl.function
|
||||
def capital(s):
|
||||
s += "The capital of France is"
|
||||
s += sgl.select("capital", ["Paris. ", "London. "])
|
||||
s += "It is a city" + sgl.gen("description", stop=".")
|
||||
|
||||
tracer = capital.trace()
|
||||
# print(tracer.last_node.print_graph_dfs() + "\n")
|
||||
|
||||
def test_raise_warning(self):
|
||||
@sgl.function
|
||||
def wrong(s, question):
|
||||
s += f"I want to ask {question}"
|
||||
|
||||
try:
|
||||
tracer = wrong.trace()
|
||||
raised = False
|
||||
except TypeError:
|
||||
raised = True
|
||||
|
||||
assert raised
|
||||
|
||||
def test_multi_function(self):
|
||||
@sgl.function
|
||||
def expand(s, tip):
|
||||
s += (
|
||||
"Please expand the following tip into a detailed paragraph:"
|
||||
+ tip
|
||||
+ "\n"
|
||||
)
|
||||
s += sgl.gen("detailed_tip")
|
||||
|
||||
@sgl.function
|
||||
def tip_suggestion(s, topic):
|
||||
s += "Here are 2 tips for " + topic + ".\n"
|
||||
|
||||
s += "1." + sgl.gen("tip_1", stop=["\n", ":", "."]) + "\n"
|
||||
s += "2." + sgl.gen("tip_2", stop=["\n", ":", "."]) + "\n"
|
||||
|
||||
branch1 = expand(tip=s["tip_1"])
|
||||
branch2 = expand(tip=s["tip_2"])
|
||||
|
||||
s += "Tip 1: " + branch1["detailed_tip"] + "\n"
|
||||
s += "Tip 2: " + branch2["detailed_tip"] + "\n"
|
||||
s += "In summary" + sgl.gen("summary")
|
||||
|
||||
compiled = tip_suggestion.compile()
|
||||
# compiled.print_graph()
|
||||
|
||||
sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
|
||||
state = compiled.run(topic="staying healthy")
|
||||
# print(state.text() + "\n")
|
||||
|
||||
states = compiled.run_batch(
|
||||
[
|
||||
{"topic": "staying healthy"},
|
||||
{"topic": "staying happy"},
|
||||
{"topic": "earning money"},
|
||||
],
|
||||
temperature=0,
|
||||
)
|
||||
# for s in states:
|
||||
# print(s.text() + "\n")
|
||||
|
||||
def test_role(self):
|
||||
@sgl.function
|
||||
def multi_turn_chat(s):
|
||||
s += sgl.user("Who are you?")
|
||||
s += sgl.assistant(sgl.gen("answer_1"))
|
||||
s += sgl.user("Who created you?")
|
||||
s += sgl.assistant(sgl.gen("answer_2"))
|
||||
|
||||
backend = BaseBackend()
|
||||
backend.chat_template = get_chat_template("llama-2-chat")
|
||||
|
||||
compiled = multi_turn_chat.compile(backend=backend)
|
||||
# compiled.print_graph()
|
||||
|
||||
def test_fork(self):
|
||||
@sgl.function
|
||||
def tip_suggestion(s):
|
||||
s += (
|
||||
"Here are three tips for staying healthy: "
|
||||
"1. Balanced Diet; "
|
||||
"2. Regular Exercise; "
|
||||
"3. Adequate Sleep\n"
|
||||
)
|
||||
|
||||
forks = s.fork(3)
|
||||
for i in range(3):
|
||||
forks[i] += f"Now, expand tip {i+1} into a paragraph:\n"
|
||||
forks[i] += sgl.gen(f"detailed_tip")
|
||||
|
||||
s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
|
||||
s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
|
||||
s += "Tip 3:" + forks[2]["detailed_tip"] + "\n"
|
||||
s += "In summary" + sgl.gen("summary")
|
||||
|
||||
tracer = tip_suggestion.trace()
|
||||
# print(tracer.last_node.print_graph_dfs())
|
||||
|
||||
a = tip_suggestion.run(backend=sgl.OpenAI("gpt-3.5-turbo-instruct"))
|
||||
# print(a.text())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
53
test/lang/test_vertexai_backend.py
Normal file
53
test/lang/test_vertexai_backend.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import unittest
|
||||
|
||||
from sglang import VertexAI, set_default_backend
|
||||
from sglang.test.test_programs import (
|
||||
test_expert_answer,
|
||||
test_few_shot_qa,
|
||||
test_image_qa,
|
||||
test_mt_bench,
|
||||
test_parallel_decoding,
|
||||
test_parallel_encoding,
|
||||
test_stream,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestVertexAIBackend(CustomTestCase):
|
||||
backend = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.backend = VertexAI("gemini-1.5-pro-001")
|
||||
|
||||
def test_few_shot_qa(self):
|
||||
set_default_backend(self.backend)
|
||||
test_few_shot_qa()
|
||||
|
||||
def test_mt_bench(self):
|
||||
set_default_backend(self.backend)
|
||||
test_mt_bench()
|
||||
|
||||
def test_expert_answer(self):
|
||||
set_default_backend(self.backend)
|
||||
test_expert_answer(check_answer=False)
|
||||
|
||||
def test_parallel_decoding(self):
|
||||
set_default_backend(self.backend)
|
||||
test_parallel_decoding()
|
||||
|
||||
def test_parallel_encoding(self):
|
||||
set_default_backend(self.backend)
|
||||
test_parallel_encoding()
|
||||
|
||||
def test_image_qa(self):
|
||||
set_default_backend(self.backend)
|
||||
test_image_qa()
|
||||
|
||||
def test_stream(self):
|
||||
set_default_backend(self.backend)
|
||||
test_stream()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
2
test/pytest.ini
Normal file
2
test/pytest.ini
Normal file
@@ -0,0 +1,2 @@
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
95
test/srt/ascend/test_ascend_graph_tp1_bf16.py
Normal file
95
test/srt/ascend/test_ascend_graph_tp1_bf16.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_offline_throughput,
|
||||
)
|
||||
|
||||
TEST_MODEL_MATRIX = {
|
||||
"Qwen/Qwen2.5-7B-Instruct": {
|
||||
"accuracy": 0.85,
|
||||
"latency": 150,
|
||||
"output_throughput": 30,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAscendGraphTp1Bf16(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.models = TEST_MODEL_MATRIX.keys()
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||
cls.common_args = [
|
||||
"--trust-remote-code",
|
||||
"--mem-fraction-static",
|
||||
0.8,
|
||||
"--attention-backend",
|
||||
"ascend",
|
||||
]
|
||||
|
||||
def test_a_gsm8k(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing accuracy: {model} ===##")
|
||||
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1319,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{self.url.hostname}",
|
||||
port=int(self.url.port),
|
||||
)
|
||||
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
self.assertGreaterEqual(
|
||||
metrics["accuracy"],
|
||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_b_throughput(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing throughput: {model} ===##")
|
||||
|
||||
output_throughput = run_bench_offline_throughput(
|
||||
model,
|
||||
[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(
|
||||
output_throughput,
|
||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
97
test/srt/ascend/test_ascend_graph_tp2_bf16.py
Normal file
97
test/srt/ascend/test_ascend_graph_tp2_bf16.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_offline_throughput,
|
||||
)
|
||||
|
||||
TEST_MODEL_MATRIX = {
|
||||
"Qwen/Qwen2.5-7B-Instruct": {
|
||||
"accuracy": 0.85,
|
||||
"latency": 180,
|
||||
"output_throughput": 20,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAscendGraphTp2Bf16(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.models = TEST_MODEL_MATRIX.keys()
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||
cls.common_args = [
|
||||
"--trust-remote-code",
|
||||
"--mem-fraction-static",
|
||||
0.8,
|
||||
"--attention-backend",
|
||||
"ascend",
|
||||
"--tp-size",
|
||||
2,
|
||||
]
|
||||
|
||||
def test_a_gsm8k(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing accuracy: {model} ===##")
|
||||
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1319,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{self.url.hostname}",
|
||||
port=int(self.url.port),
|
||||
)
|
||||
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
self.assertGreaterEqual(
|
||||
metrics["accuracy"],
|
||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_b_throughput(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing throughput: {model} ===##")
|
||||
|
||||
output_throughput = run_bench_offline_throughput(
|
||||
model,
|
||||
[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(
|
||||
output_throughput,
|
||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
103
test/srt/ascend/test_ascend_mla_fia_w8a8int8.py
Normal file
103
test/srt/ascend/test_ascend_mla_fia_w8a8int8.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import os
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_offline_throughput,
|
||||
)
|
||||
|
||||
TEST_MODEL_MATRIX = {
|
||||
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": {
|
||||
"accuracy": 0.34,
|
||||
"latency": 1000,
|
||||
"output_throughput": 6,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAscendMlaW8A8Int8(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.models = TEST_MODEL_MATRIX.keys()
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||
cls.common_args = [
|
||||
"--trust-remote-code",
|
||||
"--disable-cuda-graph",
|
||||
"--mem-fraction-static",
|
||||
0.8,
|
||||
"--attention-backend",
|
||||
"ascend",
|
||||
"--quantization",
|
||||
"w8a8_int8",
|
||||
"--tp-size",
|
||||
2,
|
||||
"--disable-radix-cache",
|
||||
]
|
||||
|
||||
def test_a_gsm8k(self):
|
||||
os.environ["ASCEND_USE_FIA"] = "true"
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing accuracy: {model} ===##")
|
||||
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1319,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{self.url.hostname}",
|
||||
port=int(self.url.port),
|
||||
)
|
||||
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
self.assertGreaterEqual(
|
||||
metrics["accuracy"],
|
||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_b_throughput(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing throughput: {model} ===##")
|
||||
|
||||
output_throughput = run_bench_offline_throughput(
|
||||
model,
|
||||
[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(
|
||||
output_throughput,
|
||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
101
test/srt/ascend/test_ascend_mla_w8a8int8.py
Normal file
101
test/srt/ascend/test_ascend_mla_w8a8int8.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_offline_throughput,
|
||||
)
|
||||
|
||||
TEST_MODEL_MATRIX = {
|
||||
"/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": {
|
||||
"accuracy": 0.34,
|
||||
"latency": 1000,
|
||||
"output_throughput": 6,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAscendMlaW8A8Int8(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.models = TEST_MODEL_MATRIX.keys()
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||
cls.common_args = [
|
||||
"--trust-remote-code",
|
||||
"--disable-cuda-graph",
|
||||
"--mem-fraction-static",
|
||||
0.8,
|
||||
"--attention-backend",
|
||||
"ascend",
|
||||
"--quantization",
|
||||
"w8a8_int8",
|
||||
"--tp-size",
|
||||
4,
|
||||
"--disable-radix-cache",
|
||||
]
|
||||
|
||||
def test_a_gsm8k(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing accuracy: {model} ===##")
|
||||
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1319,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{self.url.hostname}",
|
||||
port=int(self.url.port),
|
||||
)
|
||||
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
self.assertGreaterEqual(
|
||||
metrics["accuracy"],
|
||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_b_throughput(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing throughput: {model} ===##")
|
||||
|
||||
output_throughput = run_bench_offline_throughput(
|
||||
model,
|
||||
[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(
|
||||
output_throughput,
|
||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
96
test/srt/ascend/test_ascend_tp1_bf16.py
Normal file
96
test/srt/ascend/test_ascend_tp1_bf16.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_offline_throughput,
|
||||
)
|
||||
|
||||
TEST_MODEL_MATRIX = {
|
||||
"Qwen/Qwen2.5-7B-Instruct": {
|
||||
"accuracy": 0.84,
|
||||
"latency": 150,
|
||||
"output_throughput": 30,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAscendTp1Bf16(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.models = TEST_MODEL_MATRIX.keys()
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||
cls.common_args = [
|
||||
"--trust-remote-code",
|
||||
"--disable-cuda-graph",
|
||||
"--mem-fraction-static",
|
||||
0.8,
|
||||
"--attention-backend",
|
||||
"ascend",
|
||||
]
|
||||
|
||||
def test_a_gsm8k(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing accuracy: {model} ===##")
|
||||
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1319,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{self.url.hostname}",
|
||||
port=int(self.url.port),
|
||||
)
|
||||
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
self.assertGreaterEqual(
|
||||
metrics["accuracy"],
|
||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_b_throughput(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing throughput: {model} ===##")
|
||||
|
||||
output_throughput = run_bench_offline_throughput(
|
||||
model,
|
||||
[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(
|
||||
output_throughput,
|
||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
98
test/srt/ascend/test_ascend_tp2_bf16.py
Normal file
98
test/srt/ascend/test_ascend_tp2_bf16.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_offline_throughput,
|
||||
)
|
||||
|
||||
TEST_MODEL_MATRIX = {
|
||||
"Qwen/Qwen2.5-7B-Instruct": {
|
||||
"accuracy": 0.85,
|
||||
"latency": 180,
|
||||
"output_throughput": 20,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAscendTp2Bf16(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.models = TEST_MODEL_MATRIX.keys()
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||
cls.common_args = [
|
||||
"--trust-remote-code",
|
||||
"--disable-cuda-graph",
|
||||
"--mem-fraction-static",
|
||||
0.8,
|
||||
"--attention-backend",
|
||||
"ascend",
|
||||
"--tp-size",
|
||||
2,
|
||||
]
|
||||
|
||||
def test_a_gsm8k(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing accuracy: {model} ===##")
|
||||
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1319,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{self.url.hostname}",
|
||||
port=int(self.url.port),
|
||||
)
|
||||
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
self.assertGreaterEqual(
|
||||
metrics["accuracy"],
|
||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_b_throughput(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing throughput: {model} ===##")
|
||||
|
||||
output_throughput = run_bench_offline_throughput(
|
||||
model,
|
||||
[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(
|
||||
output_throughput,
|
||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
101
test/srt/ascend/test_ascend_tp2_fia_bf16.py
Normal file
101
test/srt/ascend/test_ascend_tp2_fia_bf16.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import os
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
run_bench_offline_throughput,
|
||||
)
|
||||
|
||||
TEST_MODEL_MATRIX = {
|
||||
"Qwen/Qwen2.5-7B-Instruct": {
|
||||
"accuracy": 0.85,
|
||||
"latency": 180,
|
||||
"output_throughput": 20,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestAscendTp2Bf16(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.models = TEST_MODEL_MATRIX.keys()
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||
cls.common_args = [
|
||||
"--trust-remote-code",
|
||||
"--disable-cuda-graph",
|
||||
"--mem-fraction-static",
|
||||
0.8,
|
||||
"--attention-backend",
|
||||
"ascend",
|
||||
"--tp-size",
|
||||
2,
|
||||
"--disable-radix-cache",
|
||||
]
|
||||
|
||||
def test_a_gsm8k(self):
|
||||
os.environ["ASCEND_USE_FIA"] = "true"
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing accuracy: {model} ===##")
|
||||
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1319,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{self.url.hostname}",
|
||||
port=int(self.url.port),
|
||||
)
|
||||
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
self.assertGreaterEqual(
|
||||
metrics["accuracy"],
|
||||
TEST_MODEL_MATRIX[model]["accuracy"],
|
||||
)
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_b_throughput(self):
|
||||
for model in self.models:
|
||||
with self.subTest(model=model):
|
||||
print(f"##=== Testing throughput: {model} ===##")
|
||||
|
||||
output_throughput = run_bench_offline_throughput(
|
||||
model,
|
||||
[
|
||||
*self.common_args,
|
||||
],
|
||||
)
|
||||
|
||||
print(f"##=== {model} throughput: {output_throughput} ===##")
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(
|
||||
output_throughput,
|
||||
TEST_MODEL_MATRIX[model]["output_throughput"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
104
test/srt/ascend/test_ascend_w8a8_quantization.py
Normal file
104
test/srt/ascend/test_ascend_w8a8_quantization.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest test_ascend_w8a8_quantization.TestAscendW8A8.test_gsm8k
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ:
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
|
||||
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
|
||||
7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100
|
||||
)
|
||||
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
|
||||
|
||||
|
||||
class TestAscendW8A8(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--disable-cuda-graph",
|
||||
"--device",
|
||||
"npu",
|
||||
"--attention-backend",
|
||||
"ascend",
|
||||
"--quantization",
|
||||
"w8a8_int8",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
base_url = DEFAULT_URL_FOR_TEST
|
||||
url = urlparse(base_url)
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host=f"http://{url.hostname}",
|
||||
port=int(url.port),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreaterEqual(metrics["accuracy"], 0.25)
|
||||
self.assertGreaterEqual(metrics["output_throughput"], 1000)
|
||||
|
||||
def run_decode(self, max_new_tokens):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
},
|
||||
"ignore_eos": True,
|
||||
},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def test_throughput(self):
|
||||
max_tokens = 256
|
||||
|
||||
tic = time.perf_counter()
|
||||
res = self.run_decode(max_tokens)
|
||||
tok = time.perf_counter()
|
||||
print(res["text"])
|
||||
throughput = max_tokens / (tok - tic)
|
||||
print(f"Throughput: {throughput} tokens/s")
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreaterEqual(throughput, 25)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
28
test/srt/configs/deepseek_v3.yaml
Normal file
28
test/srt/configs/deepseek_v3.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
tasks:
|
||||
- name: sglang-8192-1024-concurrency1
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 1 --num-prompts 5 --output-file deepseek_v3_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency2
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 2 --num-prompts 10 --output-file deepseek_v3_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency4
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 4 --num-prompts 20 --output-file deepseek_v3_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency8
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 8 --num-prompts 32 --output-file deepseek_v3_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency16
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 16 --num-prompts 48 --output-file deepseek_v3_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency24
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 24 --num-prompts 72 --output-file deepseek_v3_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency32
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 32 --num-prompts 96 --output-file deepseek_v3_results.jsonl
|
||||
28
test/srt/configs/deepseek_v3_long_context.yaml
Normal file
28
test/srt/configs/deepseek_v3_long_context.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
tasks:
|
||||
- name: sglang-32000-100-concurrency1
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 1 --num-prompts 5 --output-file deepseek_v3_long_context_results.jsonl
|
||||
|
||||
- name: sglang-32000-100-concurrency2
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 2 --num-prompts 10 --output-file deepseek_v3_long_context_results.jsonl
|
||||
|
||||
- name: sglang-32000-100-concurrency4
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 4 --num-prompts 20 --output-file deepseek_v3_long_context_results.jsonl
|
||||
|
||||
- name: sglang-32000-100-concurrency8
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 8 --num-prompts 32 --output-file deepseek_v3_long_context_results.jsonl
|
||||
|
||||
- name: sglang-32000-100-concurrency16
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 16 --num-prompts 48 --output-file deepseek_v3_long_context_results.jsonl
|
||||
|
||||
- name: sglang-32000-100-concurrency24
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 24 --num-prompts 72 --output-file deepseek_v3_long_context_results.jsonl
|
||||
|
||||
- name: sglang-32000-100-concurrency32
|
||||
server_cmd: python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --trust-remote-code --disable-radix-cache --max-prefill-tokens 32768
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 32000 --random-output-len 100 --max-concurrency 32 --num-prompts 96 --output-file deepseek_v3_long_context_results.jsonl
|
||||
28
test/srt/configs/llama_405b.yaml
Normal file
28
test/srt/configs/llama_405b.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
tasks:
|
||||
- name: sglang-8192-1024-concurrency1
|
||||
server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 1 --num-prompts 5 --output-file llama_405b_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency2
|
||||
server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 2 --num-prompts 10 --output-file llama_405b_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency4
|
||||
server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 4 --num-prompts 20 --output-file llama_405b_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency8
|
||||
server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 8 --num-prompts 32 --output-file llama_405b_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency16
|
||||
server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 16 --num-prompts 48 --output-file llama_405b_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency24
|
||||
server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 24 --num-prompts 72 --output-file llama_405b_results.jsonl
|
||||
|
||||
- name: sglang-8192-1024-concurrency32
|
||||
server_cmd: python3 -m sglang.launch_server --model nvidia/Llama-3.1-405B-Instruct-FP8 --tp 8
|
||||
client_cmd: python3 -m sglang.bench_serving --dataset-name random --random-range-ratio 1 --random-input-len 8192 --random-output-len 1024 --max-concurrency 32 --num-prompts 96 --output-file llama_405b_results.jsonl
|
||||
25
test/srt/configs/random_config.yaml
Normal file
25
test/srt/configs/random_config.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
tasks:
|
||||
- name: sglang-128-4
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440
|
||||
- name: vllm-128-4
|
||||
server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
|
||||
client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440
|
||||
- name: sglang-2000-100
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120
|
||||
- name: vllm-2000-100
|
||||
server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
|
||||
client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120
|
||||
- name: sglang-4000-200
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480
|
||||
- name: vllm-4000-200
|
||||
server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
|
||||
client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480
|
||||
- name: sglang-32000-100
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60
|
||||
- name: vllm-32000-100
|
||||
server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
|
||||
client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60
|
||||
25
test/srt/configs/random_flashinfer_vs_triton_config.yaml
Normal file
25
test/srt/configs/random_flashinfer_vs_triton_config.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
tasks:
|
||||
- name: sglang-128-4
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440
|
||||
- name: sglang-triton-128-4
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440
|
||||
- name: sglang-2000-100
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120
|
||||
- name: sglang-triton-2000-100
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120
|
||||
- name: sglang-4000-200
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480
|
||||
- name: sglang-triton-4000-200
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480
|
||||
- name: sglang-32000-100
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60
|
||||
- name: sglang-triton-32000-100
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60
|
||||
7
test/srt/configs/sharegpt_config.yaml
Normal file
7
test/srt/configs/sharegpt_config.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
tasks:
|
||||
- name: sglang-benchmark
|
||||
server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
|
||||
client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 16
|
||||
- name: vllm-benchmark
|
||||
server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
|
||||
client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --request-rate 16
|
||||
35
test/srt/cpu/test_activation.py
Normal file
35
test/srt/cpu/test_activation.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from utils import SiluAndMul, precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestActivation(CustomTestCase):
|
||||
M = [128, 129, 257]
|
||||
N = [22016, 22018]
|
||||
dtype = [torch.float16, torch.bfloat16]
|
||||
|
||||
def _activation_test(self, m, n, dtype):
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
|
||||
out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
|
||||
ref_out = SiluAndMul(x)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
def test_activation(self):
|
||||
for params in itertools.product(self.M, self.N, self.dtype):
|
||||
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
|
||||
self._activation_test(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
28
test/srt/cpu/test_binding.py
Normal file
28
test/srt/cpu/test_binding.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import re
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
|
||||
kernel = torch.ops.sgl_kernel
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestGemm(CustomTestCase):
|
||||
def test_binding(self):
|
||||
start_id = 1
|
||||
n_cpu = 6
|
||||
|
||||
expected_cores = list(map(str, range(start_id, start_id + n_cpu)))
|
||||
cpu_ids = ",".join(expected_cores)
|
||||
output = kernel.init_cpu_threads_env(cpu_ids)
|
||||
|
||||
bindings = re.findall(r"OMP tid: \d+, core (\d+)", output)
|
||||
self.assertEqual(len(bindings), n_cpu)
|
||||
|
||||
self.assertEqual(bindings, expected_cores)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
170
test/srt/cpu/test_decode.py
Normal file
170
test/srt/cpu/test_decode.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestDecodeAttention(CustomTestCase):
|
||||
def _run_sdpa_forward_decode(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
scaling=None,
|
||||
enable_gqa=False,
|
||||
causal=False,
|
||||
):
|
||||
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
|
||||
start_q, start_kv = 0, 0
|
||||
for seq_idx in range(seq_lens.shape[0]):
|
||||
seq_len_q = 1
|
||||
seq_len_kv = seq_lens[seq_idx]
|
||||
end_q = start_q + seq_len_q
|
||||
end_kv = start_kv + seq_len_kv
|
||||
|
||||
per_req_query = query[:, start_q:end_q, :]
|
||||
|
||||
# get key and value from cache. per_req_tokens contains the kv cache
|
||||
# index for each token in the sequence.
|
||||
req_pool_idx = req_pool_indices[seq_idx]
|
||||
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
|
||||
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||
|
||||
per_req_out = (
|
||||
scaled_dot_product_attention(
|
||||
per_req_query.unsqueeze(0),
|
||||
per_req_key.unsqueeze(0),
|
||||
per_req_value.unsqueeze(0),
|
||||
enable_gqa=enable_gqa,
|
||||
scale=scaling,
|
||||
is_causal=causal,
|
||||
)
|
||||
.squeeze(0)
|
||||
.movedim(query.dim() - 2, 0)
|
||||
)
|
||||
output[start_q:end_q, :, :] = per_req_out
|
||||
start_q, start_kv = end_q, end_kv
|
||||
|
||||
return output
|
||||
|
||||
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
|
||||
dtype = torch.bfloat16
|
||||
# This represents the number of tokens already in the sequence
|
||||
seq_len = 1024
|
||||
total_tokens = B * seq_len
|
||||
sm_scale = 1.0 / (D**0.5)
|
||||
logit_cap = 0.0
|
||||
num_kv_splits = 8
|
||||
enable_gqa = H_Q != H_KV
|
||||
|
||||
# q represents the new token being generated, one per batch
|
||||
q = torch.randn(B, H_Q, D, dtype=dtype, device=device)
|
||||
|
||||
# k_buffer and v_buffer represent all previous tokens
|
||||
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
|
||||
v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device)
|
||||
|
||||
key = torch.randn(B, H_KV, D, dtype=dtype)
|
||||
value = torch.randn(B, H_KV, D_V, dtype=dtype)
|
||||
loc = torch.randint(0, 10, (B,)).to(torch.int64)
|
||||
|
||||
# set kv cache
|
||||
k_buffer[loc] = key
|
||||
v_buffer[loc] = value
|
||||
|
||||
# o will have the same shape as q
|
||||
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
|
||||
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
|
||||
|
||||
req_to_token = (
|
||||
torch.arange(total_tokens, device=device)
|
||||
.reshape(B, seq_len)
|
||||
.to(torch.int32)
|
||||
)
|
||||
b_req_idx = torch.arange(B, device=device).to(torch.int64)
|
||||
b_seq_len = torch.full((B,), seq_len, device=device).to(torch.int64)
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D_V + 1),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# k_buffer, v_buffer, query, key and value supports non-contiguous tensors
|
||||
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
q = q.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
key = key.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
value = value.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
torch.ops.sgl_kernel.decode_attention_cpu(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
key,
|
||||
value,
|
||||
loc,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
|
||||
self._run_sdpa_forward_decode(
|
||||
q,
|
||||
o_grouped,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
scaling=sm_scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
o.flatten(), o_grouped.flatten(), dim=0
|
||||
)
|
||||
self.assertGreater(cos_sim.item(), 0.99)
|
||||
torch.testing.assert_close(o, o_grouped, atol=3e-2, rtol=1e-6)
|
||||
|
||||
def _test_grouped_decode_attention(self, device="cuda"):
|
||||
configs = [
|
||||
(2, 16, 16, 64, 64),
|
||||
(2, 16, 1, 16, 16),
|
||||
(2, 32, 8, 33, 55),
|
||||
(2, 16, 1, 64, 64),
|
||||
(2, 64, 1, 13, 13),
|
||||
(2, 128, 1, 80, 80),
|
||||
(2, 128, 2, 512, 512),
|
||||
(1, 16, 1, 576, 512),
|
||||
(1, 16, 16, 576, 512),
|
||||
(1, 22, 1, 576, 512),
|
||||
(1, 40, 8, 128, 128),
|
||||
]
|
||||
|
||||
for B, H_Q, H_KV, D, D_V in configs:
|
||||
self._test_grouped_decode_attention_once(
|
||||
B, H_Q, H_KV, D, D_V, device=device
|
||||
)
|
||||
|
||||
def test_grouped_decode_attention(self):
|
||||
self._test_grouped_decode_attention("cpu")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
190
test/srt/cpu/test_extend.py
Normal file
190
test/srt/cpu/test_extend.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestExtendAttention(CustomTestCase):
|
||||
|
||||
def _run_sdpa_forward_extend(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
extend_prefix_lens: torch.Tensor,
|
||||
extend_seq_lens: torch.Tensor,
|
||||
scaling=None,
|
||||
enable_gqa=False,
|
||||
causal=False,
|
||||
):
|
||||
|
||||
assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
|
||||
assert seq_lens.shape[0] == extend_seq_lens.shape[0]
|
||||
|
||||
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
|
||||
start_q, start_kv = 0, 0
|
||||
for seq_idx in range(seq_lens.shape[0]):
|
||||
|
||||
extend_seq_len_q = extend_seq_lens[seq_idx]
|
||||
prefill_seq_len_q = extend_prefix_lens[seq_idx]
|
||||
|
||||
seq_len_kv = seq_lens[seq_idx]
|
||||
end_q = start_q + extend_seq_len_q
|
||||
end_kv = start_kv + seq_len_kv
|
||||
|
||||
per_req_query = query[:, start_q:end_q, :]
|
||||
per_req_query_redudant = torch.empty(
|
||||
(per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
|
||||
dtype=per_req_query.dtype,
|
||||
device=per_req_query.device,
|
||||
)
|
||||
|
||||
per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query
|
||||
|
||||
# get key and value from cache. per_req_tokens contains the kv cache
|
||||
# index for each token in the sequence.
|
||||
req_pool_idx = req_pool_indices[seq_idx]
|
||||
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
|
||||
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||
|
||||
per_req_out_redudant = (
|
||||
scaled_dot_product_attention(
|
||||
per_req_query_redudant.unsqueeze(0),
|
||||
per_req_key.unsqueeze(0),
|
||||
per_req_value.unsqueeze(0),
|
||||
enable_gqa=enable_gqa,
|
||||
scale=scaling,
|
||||
is_causal=causal,
|
||||
)
|
||||
.squeeze(0)
|
||||
.movedim(query.dim() - 2, 0)
|
||||
)
|
||||
output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
|
||||
start_q, start_kv = end_q, end_kv
|
||||
return output
|
||||
|
||||
def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D, DV, mla=False):
|
||||
dtype = torch.bfloat16
|
||||
|
||||
b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
|
||||
if mla:
|
||||
b_seq_len_prefix.zero_()
|
||||
b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
|
||||
b_seq_len = b_seq_len_prefix + b_seq_len_extend
|
||||
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
|
||||
|
||||
b_req_idx = torch.arange(B, dtype=torch.int32)
|
||||
req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32)
|
||||
b_start_loc = torch.zeros((B,), dtype=torch.int32)
|
||||
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
|
||||
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32)
|
||||
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||
|
||||
for i in range(B):
|
||||
req_to_tokens[i, : b_seq_len[i]] = torch.arange(
|
||||
b_start_loc[i], b_start_loc[i] + b_seq_len[i]
|
||||
)
|
||||
|
||||
total_token_num = torch.sum(b_seq_len).item()
|
||||
extend_token_num = torch.sum(b_seq_len_extend).item()
|
||||
|
||||
H_BUF = 1 if mla else H_KV
|
||||
k_buffer = torch.randn((total_token_num, H_BUF, D), dtype=dtype)
|
||||
v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype)
|
||||
|
||||
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype)
|
||||
v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype)
|
||||
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype)
|
||||
|
||||
for i in range(B):
|
||||
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
|
||||
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
|
||||
extend_start = b_start_loc_extend[i]
|
||||
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
|
||||
k_extend[extend_start:extend_end] = k_buffer[
|
||||
extend_start_in_buffer:extend_end_in_buffer
|
||||
]
|
||||
v_extend[extend_start:extend_end] = v_buffer[
|
||||
extend_start_in_buffer:extend_end_in_buffer
|
||||
]
|
||||
q_extend[extend_start:extend_end] = torch.randn(
|
||||
(b_seq_len_extend[i], H_Q, D), dtype=dtype
|
||||
)
|
||||
|
||||
# q_extend, k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
|
||||
q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
|
||||
b_seq_len_extend = b_seq_len - b_seq_len_prefix
|
||||
b_start_loc_extend = torch.zeros_like(b_seq_len)
|
||||
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
|
||||
|
||||
sm_scale = 1.0 / (D**0.5)
|
||||
logit_cap = 0.0
|
||||
|
||||
# handle index type
|
||||
b_req_idx = b_req_idx.to(torch.int64)
|
||||
b_seq_len = b_seq_len.to(torch.int64)
|
||||
|
||||
enable_gqa = H_Q != H_KV
|
||||
o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
|
||||
self._run_sdpa_forward_extend(
|
||||
q_extend,
|
||||
o_ref,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
b_seq_len_prefix,
|
||||
b_seq_len_extend,
|
||||
scaling=sm_scale,
|
||||
enable_gqa=enable_gqa,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
|
||||
torch.ops.sgl_kernel.extend_attention_cpu(
|
||||
q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
b_seq_len_extend,
|
||||
b_start_loc_extend,
|
||||
max_len_extend,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(o_ref, o_extend, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_extend_attention(self):
|
||||
for is_mla in [True, False]:
|
||||
self._test_extend_attention_once(1, 123, 1, 1, 128, 96, is_mla)
|
||||
self._test_extend_attention_once(1, 123, 16, 1, 128, 96, is_mla)
|
||||
self._test_extend_attention_once(4, 1230, 16, 4, 128, 96, is_mla)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
189
test/srt/cpu/test_gemm.py
Normal file
189
test/srt/cpu/test_gemm.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from utils import (
|
||||
convert_weight,
|
||||
native_w8a8_per_token_matmul,
|
||||
per_token_quant_int8,
|
||||
precision,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class Mod(nn.Module):
|
||||
def __init__(self, input_channel, output_channel, has_bias):
|
||||
super(Mod, self).__init__()
|
||||
self.linear = torch.nn.Linear(input_channel, output_channel, has_bias)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class TestGemm(CustomTestCase):
|
||||
M = [1, 101]
|
||||
N = [16, 32 * 13]
|
||||
K = [32 * 16]
|
||||
has_bias = [False, True]
|
||||
|
||||
M_int8 = [2, 128]
|
||||
N_int8 = [32 * 12]
|
||||
K_int8 = [32 * 17]
|
||||
|
||||
M_fp8 = [1, 11]
|
||||
N_fp8 = [128, 224]
|
||||
K_fp8 = [512, 576]
|
||||
|
||||
def _bf16_gemm(self, M, N, K, has_bias):
|
||||
|
||||
mat1 = torch.randn(M, K, dtype=torch.bfloat16)
|
||||
mat2 = torch.randn(N, K, dtype=torch.bfloat16)
|
||||
|
||||
ref = torch.matmul(mat1.float(), mat2.float().t())
|
||||
if has_bias:
|
||||
bias = torch.randn(N, dtype=torch.float32)
|
||||
ref.add_(bias.bfloat16())
|
||||
|
||||
ref = ref.bfloat16()
|
||||
|
||||
out = torch.ops.sgl_kernel.weight_packed_linear(
|
||||
mat1, mat2, bias if has_bias else None, False
|
||||
)
|
||||
|
||||
packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
|
||||
out2 = torch.ops.sgl_kernel.weight_packed_linear(
|
||||
mat1, packed_mat2, bias if has_bias else None, True
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref.dtype]
|
||||
torch.testing.assert_close(ref, out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(ref, out2, atol=atol, rtol=rtol)
|
||||
|
||||
def test_bf16_gemm(self):
|
||||
for params in itertools.product(
|
||||
self.M,
|
||||
self.N,
|
||||
self.K,
|
||||
self.has_bias,
|
||||
):
|
||||
with self.subTest(
|
||||
M=params[0],
|
||||
N=params[1],
|
||||
K=params[2],
|
||||
has_bias=params[3],
|
||||
):
|
||||
self._bf16_gemm(*params)
|
||||
|
||||
def _int8_gemm(self, M, N, K, has_bias):
|
||||
dtype = torch.bfloat16
|
||||
A = torch.randn((M, K), dtype=dtype) / 10
|
||||
Aq, As = per_token_quant_int8(A)
|
||||
|
||||
factor_for_scale = 1e-2
|
||||
int8_max = 127
|
||||
int8_min = -128
|
||||
|
||||
B = (torch.rand((N, K), dtype=torch.float32) - 0.5) * 2
|
||||
Bq = (B * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
Bs = torch.rand(N) * factor_for_scale
|
||||
|
||||
bias = torch.randn(N) if has_bias else None
|
||||
ref_out = native_w8a8_per_token_matmul(Aq, Bq, As, Bs, bias, dtype)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
|
||||
Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
|
||||
out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
|
||||
Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
|
||||
)
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
# test the fused version
|
||||
fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
|
||||
A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
|
||||
)
|
||||
torch.testing.assert_close(ref_out, fused_out, atol=atol, rtol=rtol)
|
||||
|
||||
def test_int8_gemm(self):
|
||||
for params in itertools.product(
|
||||
self.M_int8,
|
||||
self.N_int8,
|
||||
self.K_int8,
|
||||
self.has_bias,
|
||||
):
|
||||
with self.subTest(
|
||||
M=params[0],
|
||||
N=params[1],
|
||||
K=params[2],
|
||||
has_bias=params[3],
|
||||
):
|
||||
self._int8_gemm(*params)
|
||||
|
||||
def _fp8_gemm(self, M, N, K, has_bias):
|
||||
prepack = True
|
||||
chunk = False
|
||||
scale_block_size_N = 64
|
||||
scale_block_size_K = 128
|
||||
assert scale_block_size_N <= N
|
||||
assert scale_block_size_K <= K
|
||||
A_dtype = torch.bfloat16
|
||||
|
||||
model = Mod(K, N, has_bias).eval()
|
||||
if chunk:
|
||||
data = torch.randn(M, K + 6, dtype=A_dtype).narrow(1, 0, K)
|
||||
else:
|
||||
data = torch.randn(M, K, dtype=A_dtype)
|
||||
|
||||
weight = model.linear.weight # (N, K)
|
||||
|
||||
if has_bias:
|
||||
bias = model.linear.bias
|
||||
|
||||
fp8_weight, scales, dq_weight = convert_weight(
|
||||
weight, [scale_block_size_N, scale_block_size_K], A_dtype
|
||||
)
|
||||
|
||||
if has_bias:
|
||||
ref = torch.matmul(data.to(A_dtype), dq_weight.T) + bias.to(A_dtype)
|
||||
else:
|
||||
ref = torch.matmul(data.to(A_dtype), dq_weight.T)
|
||||
|
||||
if prepack:
|
||||
fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)
|
||||
|
||||
opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
|
||||
data,
|
||||
fp8_weight,
|
||||
scales,
|
||||
[scale_block_size_N, scale_block_size_K],
|
||||
bias if has_bias else None,
|
||||
data.dtype,
|
||||
prepack,
|
||||
)
|
||||
atol = rtol = precision[ref.dtype]
|
||||
torch.testing.assert_close(ref, opt, atol=atol, rtol=rtol)
|
||||
|
||||
def test_fp8_gemm(self):
|
||||
for params in itertools.product(
|
||||
self.M_fp8,
|
||||
self.N_fp8,
|
||||
self.K_fp8,
|
||||
self.has_bias,
|
||||
):
|
||||
with self.subTest(
|
||||
M=params[0],
|
||||
N=params[1],
|
||||
K=params[2],
|
||||
has_bias=params[3],
|
||||
):
|
||||
self._fp8_gemm(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
157
test/srt/cpu/test_mla.py
Normal file
157
test/srt/cpu/test_mla.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from utils import precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestMLA(CustomTestCase):
|
||||
def _run_sdpa_forward_decode(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
loc: torch.Tensor,
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
scaling=None,
|
||||
enable_gqa=False,
|
||||
causal=False,
|
||||
):
|
||||
# set kv cache
|
||||
k_cache[loc] = key
|
||||
|
||||
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
|
||||
start_q, start_kv = 0, 0
|
||||
for seq_idx in range(seq_lens.shape[0]):
|
||||
seq_len_q = 1
|
||||
seq_len_kv = seq_lens[seq_idx]
|
||||
end_q = start_q + seq_len_q
|
||||
end_kv = start_kv + seq_len_kv
|
||||
|
||||
per_req_query = query[:, start_q:end_q, :]
|
||||
|
||||
# get key and value from cache. per_req_tokens contains the kv cache
|
||||
# index for each token in the sequence.
|
||||
req_pool_idx = req_pool_indices[seq_idx]
|
||||
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
|
||||
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||
|
||||
per_req_out = (
|
||||
scaled_dot_product_attention(
|
||||
per_req_query.unsqueeze(0),
|
||||
per_req_key.unsqueeze(0),
|
||||
per_req_value.unsqueeze(0),
|
||||
enable_gqa=enable_gqa,
|
||||
scale=scaling,
|
||||
is_causal=causal,
|
||||
)
|
||||
.squeeze(0)
|
||||
.movedim(query.dim() - 2, 0)
|
||||
)
|
||||
output[start_q:end_q, :, :] = per_req_out
|
||||
start_q, start_kv = end_q, end_kv
|
||||
|
||||
return output
|
||||
|
||||
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len):
|
||||
dtype = torch.bfloat16
|
||||
|
||||
total_tokens = B * seq_len
|
||||
sm_scale = 1.0 / (D**0.5)
|
||||
logit_cap = 0.0
|
||||
num_kv_splits = 8
|
||||
enable_gqa = H_Q != H_KV
|
||||
|
||||
# q represents the new token being generated, one per batch
|
||||
q = torch.randn(B, H_Q, D, dtype=dtype)
|
||||
|
||||
# k_buffer and v_buffer represent all previous tokens
|
||||
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype)
|
||||
v_buffer = k_buffer.narrow(2, 0, D_V)
|
||||
|
||||
key = torch.randn(B, H_KV, D, dtype=dtype)
|
||||
value = key.narrow(2, 0, D_V)
|
||||
# make sure no duplicates in loc
|
||||
loc = torch.randperm(total_tokens)[:B].to(torch.int64)
|
||||
|
||||
k_buffer2 = k_buffer.clone()
|
||||
v_buffer2 = k_buffer2.narrow(2, 0, D_V)
|
||||
|
||||
# o will have the same shape as q
|
||||
o = torch.zeros(B, H_Q, D_V, dtype=dtype)
|
||||
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype)
|
||||
|
||||
req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32)
|
||||
b_req_idx = torch.arange(B).to(torch.int64)
|
||||
b_seq_len = torch.full((B,), seq_len).to(torch.int64)
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D_V + 1),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.decode_attention_cpu(
|
||||
q,
|
||||
k_buffer2,
|
||||
v_buffer2,
|
||||
o,
|
||||
key,
|
||||
value,
|
||||
loc,
|
||||
attn_logits,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
|
||||
self._run_sdpa_forward_decode(
|
||||
q,
|
||||
o_grouped,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
key,
|
||||
loc,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
scaling=sm_scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
o.flatten(), o_grouped.flatten(), dim=0
|
||||
)
|
||||
atol = rtol = precision[q.dtype]
|
||||
self.assertGreater(cos_sim.item(), 0.99)
|
||||
torch.testing.assert_close(o, o_grouped, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(k_buffer, k_buffer2, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(v_buffer, v_buffer2, atol=atol, rtol=rtol)
|
||||
|
||||
def test_grouped_decode_attention(self):
|
||||
configs = [
|
||||
(1, 22, 1, 576, 512, 8 * 111),
|
||||
(4, 22, 1, 576, 512, 8 * 128),
|
||||
(40, 22, 1, 576, 512, 8 * 133),
|
||||
]
|
||||
|
||||
for B, H_Q, H_KV, D, D_V, seqlen in configs:
|
||||
self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
265
test/srt/cpu/test_moe.py
Normal file
265
test/srt/cpu/test_moe.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import itertools
|
||||
import math
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
|
||||
kernel = torch.ops.sgl_kernel
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
from utils import (
|
||||
BLOCK_K,
|
||||
BLOCK_N,
|
||||
factor_for_scale,
|
||||
fp8_max,
|
||||
fp8_min,
|
||||
native_fp8_fused_moe,
|
||||
precision,
|
||||
scaled_weight,
|
||||
torch_naive_fused_moe,
|
||||
torch_w8a8_per_column_fused_moe,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
|
||||
|
||||
G = 1
|
||||
topk_group = 1
|
||||
|
||||
B, D = a.shape
|
||||
topk_weights = torch.empty(B, topk, dtype=torch.float32)
|
||||
topk_ids = torch.empty(B, topk, dtype=torch.int32)
|
||||
topk_weights, topk_ids = kernel.grouped_topk_cpu(
|
||||
a, score, topk, renormalize, G, topk_group, 0, None, None
|
||||
)
|
||||
|
||||
packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
|
||||
packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
|
||||
|
||||
inplace = True
|
||||
return kernel.fused_experts_cpu(
|
||||
a,
|
||||
packed_w1,
|
||||
packed_w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace,
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
prepack,
|
||||
)
|
||||
|
||||
|
||||
class TestFusedExperts(CustomTestCase):
|
||||
M = [2, 114]
|
||||
N = [32]
|
||||
K = [32]
|
||||
E = [4]
|
||||
topk = [2]
|
||||
renormalize = [False, True]
|
||||
|
||||
M_int8 = [1, 39]
|
||||
N_int8 = [128]
|
||||
K_int8 = [256]
|
||||
E_int8 = [8]
|
||||
topk_int8 = [3]
|
||||
|
||||
M_fp8 = [2, 121]
|
||||
N_fp8 = [352, 512]
|
||||
K_fp8 = [256, 320]
|
||||
E_fp8 = [8]
|
||||
topk_fp8 = [4]
|
||||
|
||||
def _bf16_moe(self, m, n, k, e, topk, renormalize):
|
||||
dtype = torch.bfloat16
|
||||
prepack = True
|
||||
|
||||
a = torch.randn((m, k), device="cpu", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cpu", dtype=dtype)
|
||||
|
||||
torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize)
|
||||
fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)
|
||||
|
||||
atol = rtol = precision[torch_output.dtype]
|
||||
torch.testing.assert_close(torch_output, fused_output, atol=atol, rtol=rtol)
|
||||
|
||||
def test_bf16_moe(self):
|
||||
for params in itertools.product(
|
||||
self.M,
|
||||
self.N,
|
||||
self.K,
|
||||
self.E,
|
||||
self.topk,
|
||||
self.renormalize,
|
||||
):
|
||||
with self.subTest(
|
||||
m=params[0],
|
||||
n=params[1],
|
||||
k=params[2],
|
||||
e=params[3],
|
||||
topk=params[4],
|
||||
renormalize=params[5],
|
||||
):
|
||||
self._bf16_moe(*params)
|
||||
|
||||
def _int8_moe(self, M, N, K, E, topk):
|
||||
dtype = torch.bfloat16
|
||||
prepack = True
|
||||
|
||||
# Initialize int8 quantization parameters
|
||||
int8_factor_for_scale = 1e-2
|
||||
int8_max = 127
|
||||
int8_min = -128
|
||||
|
||||
# Input tensor
|
||||
# M * K
|
||||
a = torch.randn((M, K), dtype=dtype) / math.sqrt(K)
|
||||
|
||||
# Generate int8 weights
|
||||
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
|
||||
w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
|
||||
w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
# Generate scale for each column (per-column quantization)
|
||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale
|
||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale
|
||||
|
||||
# Calculate routing
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
|
||||
ref_out = torch_w8a8_per_column_fused_moe(
|
||||
a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk
|
||||
)
|
||||
|
||||
inplace = True
|
||||
packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
|
||||
packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
|
||||
out = kernel.fused_experts_cpu(
|
||||
a,
|
||||
packed_w1,
|
||||
packed_w2,
|
||||
topk_weight,
|
||||
topk_ids.to(torch.int32),
|
||||
inplace,
|
||||
True,
|
||||
False,
|
||||
w1_s,
|
||||
w2_s,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
prepack,
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
# Increase the tolerance for large input shapes
|
||||
if M > 35:
|
||||
atol = rtol = 0.02
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
def test_int8_moe(self):
|
||||
for params in itertools.product(
|
||||
self.M_int8,
|
||||
self.N_int8,
|
||||
self.K_int8,
|
||||
self.E_int8,
|
||||
self.topk_int8,
|
||||
):
|
||||
with self.subTest(
|
||||
M=params[0],
|
||||
N=params[1],
|
||||
K=params[2],
|
||||
E=params[3],
|
||||
topk=params[4],
|
||||
):
|
||||
self._int8_moe(*params)
|
||||
|
||||
def _fp8_moe(self, M, N, K, E, topk):
|
||||
dtype = torch.bfloat16
|
||||
|
||||
a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
|
||||
|
||||
w1_fp32 = torch.randn(E, 2 * N, K)
|
||||
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w2_fp32 = torch.randn(E, K, N)
|
||||
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w1s = (
|
||||
torch.randn(E, math.ceil(2 * N / BLOCK_N), math.ceil(K / BLOCK_K))
|
||||
* factor_for_scale
|
||||
)
|
||||
w2s = (
|
||||
torch.randn(E, math.ceil(K / BLOCK_N), math.ceil(N / BLOCK_K))
|
||||
* factor_for_scale
|
||||
)
|
||||
|
||||
w1_scaled = scaled_weight(w1, w1s)
|
||||
w2_scaled = scaled_weight(w2, w2s)
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
|
||||
w1 = kernel.convert_weight_packed(w1)
|
||||
w2 = kernel.convert_weight_packed(w2)
|
||||
|
||||
ref_out = native_fp8_fused_moe(
|
||||
a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk
|
||||
)
|
||||
out = kernel.fused_experts_cpu(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids.to(torch.int32),
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
w1s,
|
||||
w2s,
|
||||
[BLOCK_N, BLOCK_K],
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
|
||||
atol = rtol = precision[dtype]
|
||||
torch.testing.assert_close(ref_out.bfloat16(), out, atol=atol, rtol=rtol)
|
||||
|
||||
def test_fp8_moe(self):
|
||||
for params in itertools.product(
|
||||
self.M_fp8,
|
||||
self.N_fp8,
|
||||
self.K_fp8,
|
||||
self.E_fp8,
|
||||
self.topk_fp8,
|
||||
):
|
||||
with self.subTest(
|
||||
M=params[0],
|
||||
N=params[1],
|
||||
K=params[2],
|
||||
E=params[3],
|
||||
topk=params[4],
|
||||
):
|
||||
self._fp8_moe(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
90
test/srt/cpu/test_norm.py
Normal file
90
test/srt/cpu/test_norm.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import itertools
|
||||
import unittest
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import make_non_contiguous, precision
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestNorm(CustomTestCase):
|
||||
M = [4096, 1024]
|
||||
N = [4096, 4096 + 13]
|
||||
dtype = [torch.float16, torch.bfloat16]
|
||||
|
||||
def _forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
variance_epsilon: float = 1e-6,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
if residual is not None:
|
||||
x = x + residual.to(torch.float32)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + variance_epsilon)
|
||||
x = x.to(orig_dtype) * weight
|
||||
if residual is None:
|
||||
return x
|
||||
else:
|
||||
return x, residual
|
||||
|
||||
def _norm_test(self, m, n, dtype):
|
||||
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
x = make_non_contiguous(x)
|
||||
hidden_size = x.size(-1)
|
||||
weight = torch.randn(hidden_size, dtype=dtype)
|
||||
variance_epsilon = 1e-6
|
||||
|
||||
out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon)
|
||||
ref_out = self._forward_native(x, weight, variance_epsilon)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
ref_x = x.clone()
|
||||
residual = torch.randn([m, hidden_size], dtype=dtype)
|
||||
ref_residual = residual.clone()
|
||||
|
||||
torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
|
||||
x, residual, weight, variance_epsilon
|
||||
)
|
||||
|
||||
ref_x, ref_residual = self._forward_native(
|
||||
ref_x, weight, variance_epsilon, ref_residual
|
||||
)
|
||||
|
||||
torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)
|
||||
|
||||
def _l2norm_test(self, m, n, dtype):
|
||||
|
||||
x = torch.randn([m, n], dtype=dtype)
|
||||
hidden_size = x.size(-1)
|
||||
fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
|
||||
variance_epsilon = 1e-6
|
||||
|
||||
out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
|
||||
ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
def test_norm(self):
|
||||
for params in itertools.product(self.M, self.N, self.dtype):
|
||||
with self.subTest(m=params[0], n=params[1], dtype=params[2]):
|
||||
self._norm_test(*params)
|
||||
self._l2norm_test(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
432
test/srt/cpu/test_qkv_proj_with_rope.py
Normal file
432
test/srt/cpu/test_qkv_proj_with_rope.py
Normal file
@@ -0,0 +1,432 @@
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import (
|
||||
convert_weight,
|
||||
native_w8a8_per_token_matmul,
|
||||
per_token_quant_int8,
|
||||
precision,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import _apply_rotary_emb
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
|
||||
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
|
||||
qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
|
||||
torch.manual_seed(1234)
|
||||
# constants
|
||||
kv_lora_rank = 512
|
||||
qk_head_dim = 192
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
rotary_dim = qk_rope_head_dim
|
||||
num_heads = 22
|
||||
q_lora_rank = 1536
|
||||
hidden_size = 7168
|
||||
B = 1
|
||||
eps = 1e-6
|
||||
|
||||
|
||||
def layernorm(x, weight, variance_epsilon=1e-6, residual=None):
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + variance_epsilon)
|
||||
return (x * weight).to(orig_dtype)
|
||||
|
||||
|
||||
def rotary_emb(q_pe, k_pe, pos, cos_sin_cache):
|
||||
orig_dtype = q_pe.dtype
|
||||
q_pe = q_pe.float()
|
||||
k_pe = k_pe.float()
|
||||
cos_sin_cache = cos_sin_cache.float()
|
||||
|
||||
query_rot = q_pe[..., :rotary_dim]
|
||||
key_rot = k_pe[..., :rotary_dim]
|
||||
cos_sin = cos_sin_cache[pos]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
query_rot = _apply_rotary_emb(query_rot, cos, sin, False)
|
||||
key_rot = _apply_rotary_emb(key_rot, cos, sin, False)
|
||||
return query_rot.to(orig_dtype), key_rot.to(orig_dtype)
|
||||
|
||||
|
||||
def native_torch(
|
||||
q_input,
|
||||
hidden_states,
|
||||
q_a_proj_weight,
|
||||
norm_weight1,
|
||||
q_b_proj_weight,
|
||||
w_kc,
|
||||
kv_a_proj_weight,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
):
|
||||
|
||||
q = torch.matmul(hidden_states, q_a_proj_weight.t())
|
||||
q = layernorm(q, norm_weight1)
|
||||
q = torch.matmul(q, q_b_proj_weight.t()).view(-1, num_heads, qk_head_dim)
|
||||
|
||||
q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||
q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)
|
||||
|
||||
q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
|
||||
latent_cache = torch.matmul(hidden_states, kv_a_proj_weight.t())
|
||||
v_input = latent_cache[..., :kv_lora_rank]
|
||||
v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
|
||||
k_input = latent_cache.unsqueeze(1)
|
||||
k_input[..., :kv_lora_rank] = v_input
|
||||
k_pe = k_input[..., kv_lora_rank:]
|
||||
|
||||
q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
|
||||
q_input[..., kv_lora_rank:] = q_pe
|
||||
k_input[..., kv_lora_rank:] = k_pe
|
||||
|
||||
return q_input, k_input, v_input
|
||||
|
||||
|
||||
def native_torch_int8(
|
||||
q_input,
|
||||
hidden_states,
|
||||
w1_q,
|
||||
w1_s,
|
||||
norm_weight1,
|
||||
w2_q,
|
||||
w2_s,
|
||||
w_kc,
|
||||
w3_q,
|
||||
w3_s,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
):
|
||||
|
||||
a_q, a_s = per_token_quant_int8(hidden_states)
|
||||
q = native_w8a8_per_token_matmul(a_q, w1_q, a_s, w1_s, None, torch.bfloat16)
|
||||
q = layernorm(q, norm_weight1)
|
||||
|
||||
a_q, a_s = per_token_quant_int8(q)
|
||||
q = native_w8a8_per_token_matmul(a_q, w2_q, a_s, w2_s, None, torch.bfloat16).view(
|
||||
-1, num_heads, qk_head_dim
|
||||
)
|
||||
|
||||
q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||
q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)
|
||||
|
||||
q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
|
||||
a_q, a_s = per_token_quant_int8(hidden_states)
|
||||
latent_cache = native_w8a8_per_token_matmul(
|
||||
a_q, w3_q, a_s, w3_s, None, torch.bfloat16
|
||||
)
|
||||
v_input = latent_cache[..., :kv_lora_rank]
|
||||
v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
|
||||
k_input = latent_cache.unsqueeze(1)
|
||||
k_input[..., :kv_lora_rank] = v_input
|
||||
k_pe = k_input[..., kv_lora_rank:]
|
||||
|
||||
q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
|
||||
q_input[..., kv_lora_rank:] = q_pe
|
||||
k_input[..., kv_lora_rank:] = k_pe
|
||||
|
||||
return q_input, k_input, v_input
|
||||
|
||||
|
||||
class TestQKVProjWithROPE(CustomTestCase):
|
||||
def test_bf16_qkv_proj_with_rope(self):
|
||||
dtype = torch.bfloat16
|
||||
hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
|
||||
q_input = torch.empty(
|
||||
B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
|
||||
)
|
||||
q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
|
||||
norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
|
||||
q_b_proj_weight = (
|
||||
torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
|
||||
)
|
||||
w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
|
||||
kv_a_proj_weight = (
|
||||
torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
|
||||
)
|
||||
fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
|
||||
norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
|
||||
pos = torch.randint(10, 100, (B,))
|
||||
cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
|
||||
q_ref, k_ref, v_ref = native_torch(
|
||||
q_input,
|
||||
hidden_states,
|
||||
q_a_proj_weight,
|
||||
norm_weight1,
|
||||
q_b_proj_weight,
|
||||
w_kc.transpose(1, 2),
|
||||
kv_a_proj_weight,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
)
|
||||
qa_packed = convert_weight_packed(q_a_proj_weight)
|
||||
qb_packed = convert_weight_packed(q_b_proj_weight)
|
||||
kva_packed = convert_weight_packed(kv_a_proj_weight)
|
||||
wkc_packed = convert_weight_packed(w_kc)
|
||||
fused_weight_packed = convert_weight_packed(fused_weight)
|
||||
|
||||
q_out, k_out, v_out = qkv_proj_with_rope(
|
||||
hidden_states,
|
||||
qa_packed,
|
||||
qb_packed,
|
||||
kva_packed,
|
||||
wkc_packed,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
None,
|
||||
)
|
||||
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
|
||||
hidden_states,
|
||||
fused_weight_packed,
|
||||
qb_packed,
|
||||
wkc_packed,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
None,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
)
|
||||
atol = rtol = precision[q_ref.dtype]
|
||||
torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(fused_q_out, q_out)
|
||||
torch.testing.assert_close(fused_k_out, k_out)
|
||||
torch.testing.assert_close(fused_v_out, v_out)
|
||||
|
||||
def test_int8_qkv_proj_with_rope(self):
|
||||
dtype = torch.bfloat16
|
||||
hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
|
||||
q_input = torch.empty(
|
||||
B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
|
||||
)
|
||||
q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
|
||||
norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
|
||||
q_b_proj_weight = (
|
||||
torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
|
||||
)
|
||||
w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
|
||||
kv_a_proj_weight = (
|
||||
torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
|
||||
)
|
||||
norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
|
||||
pos = torch.randint(10, 100, (B,))
|
||||
cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
|
||||
|
||||
w1_q, w1_s = per_token_quant_int8(q_a_proj_weight)
|
||||
w2_q, w2_s = per_token_quant_int8(q_b_proj_weight)
|
||||
w3_q, w3_s = per_token_quant_int8(kv_a_proj_weight)
|
||||
q_ref, k_ref, v_ref = native_torch_int8(
|
||||
q_input,
|
||||
hidden_states,
|
||||
w1_q,
|
||||
w1_s,
|
||||
norm_weight1,
|
||||
w2_q,
|
||||
w2_s,
|
||||
w_kc.transpose(1, 2),
|
||||
w3_q,
|
||||
w3_s,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
)
|
||||
w1_q_packed = convert_weight_packed(w1_q)
|
||||
w2_q_packed = convert_weight_packed(w2_q)
|
||||
w3_q_packed = convert_weight_packed(w3_q)
|
||||
wkc_packed = convert_weight_packed(w_kc)
|
||||
q_out, k_out, v_out = qkv_proj_with_rope(
|
||||
hidden_states,
|
||||
w1_q_packed,
|
||||
w2_q_packed,
|
||||
w3_q_packed,
|
||||
wkc_packed,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
True,
|
||||
False,
|
||||
w1_s,
|
||||
w2_s,
|
||||
w3_s,
|
||||
True,
|
||||
None,
|
||||
)
|
||||
fused_weight = torch.cat([w1_q, w3_q], dim=0)
|
||||
fused_weight_s = torch.cat([w1_s, w3_s], dim=0)
|
||||
w_fused_q_packed = convert_weight_packed(fused_weight)
|
||||
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
|
||||
hidden_states,
|
||||
w_fused_q_packed,
|
||||
w2_q_packed,
|
||||
wkc_packed,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
True,
|
||||
False,
|
||||
fused_weight_s,
|
||||
w2_s,
|
||||
True,
|
||||
None,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
)
|
||||
atol = rtol = precision[q_ref.dtype]
|
||||
torch.testing.assert_close(q_ref, q_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(fused_q_out, q_out)
|
||||
torch.testing.assert_close(fused_k_out, k_out)
|
||||
torch.testing.assert_close(fused_v_out, v_out)
|
||||
|
||||
def test_fp8_qkv_proj_with_rope(self):
|
||||
dtype = torch.bfloat16
|
||||
hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
|
||||
q_input = torch.empty(
|
||||
B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
|
||||
)
|
||||
q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
|
||||
norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
|
||||
q_b_proj_weight = (
|
||||
torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
|
||||
)
|
||||
w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
|
||||
kv_a_proj_weight = (
|
||||
torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
|
||||
)
|
||||
norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
|
||||
pos = torch.randint(10, 100, (B,))
|
||||
cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
|
||||
|
||||
scale_block_size_N = 128
|
||||
scale_block_size_K = 128
|
||||
fp8_q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_proj_weight_dq = (
|
||||
convert_weight(
|
||||
q_a_proj_weight,
|
||||
[scale_block_size_N, scale_block_size_K],
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
fp8_q_b_proj_weight, q_b_proj_weight_scale_inv, q_b_proj_weight_dq = (
|
||||
convert_weight(
|
||||
q_b_proj_weight,
|
||||
[scale_block_size_N, scale_block_size_K],
|
||||
torch.bfloat16,
|
||||
)
|
||||
)
|
||||
(
|
||||
fp8_kv_a_proj_with_mqa_weight,
|
||||
kv_a_proj_with_mqa_weight_scale_inv,
|
||||
kv_a_proj_with_mqa_weight_dq,
|
||||
) = convert_weight(
|
||||
kv_a_proj_weight, [scale_block_size_N, scale_block_size_K], torch.bfloat16
|
||||
)
|
||||
q_ref, k_ref, v_ref = native_torch(
|
||||
q_input,
|
||||
hidden_states,
|
||||
q_a_proj_weight_dq,
|
||||
norm_weight1,
|
||||
q_b_proj_weight_dq,
|
||||
w_kc.transpose(1, 2),
|
||||
kv_a_proj_with_mqa_weight_dq,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
)
|
||||
fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight)
|
||||
fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight)
|
||||
fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed(
|
||||
fp8_kv_a_proj_with_mqa_weight
|
||||
)
|
||||
w_kc = convert_weight_packed(w_kc)
|
||||
q_out, k_out, v_out = qkv_proj_with_rope(
|
||||
hidden_states,
|
||||
fp8_q_a_proj_weight_packed,
|
||||
fp8_q_b_proj_weight_packed,
|
||||
fp8_kv_a_proj_with_mqa_weight_packed,
|
||||
w_kc,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
False,
|
||||
True,
|
||||
q_a_proj_weight_scale_inv.float(),
|
||||
q_b_proj_weight_scale_inv.float(),
|
||||
kv_a_proj_with_mqa_weight_scale_inv.float(),
|
||||
True,
|
||||
[scale_block_size_N, scale_block_size_K],
|
||||
)
|
||||
|
||||
fused_weight = torch.cat(
|
||||
[fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0
|
||||
)
|
||||
fused_weight_s = torch.cat(
|
||||
[q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0
|
||||
)
|
||||
fused_weight_packed = convert_weight_packed(fused_weight)
|
||||
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
|
||||
hidden_states,
|
||||
fused_weight_packed,
|
||||
fp8_q_b_proj_weight_packed,
|
||||
w_kc,
|
||||
norm_weight1,
|
||||
norm_weight2,
|
||||
pos,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
False,
|
||||
True,
|
||||
fused_weight_s.float(),
|
||||
q_b_proj_weight_scale_inv.float(),
|
||||
True,
|
||||
[scale_block_size_N, scale_block_size_K],
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
)
|
||||
atol = rtol = precision[q_ref.dtype]
|
||||
# Due to the change in multiplication order, the error is amplified.
|
||||
# In the model, with fewer layers, this doesn't cause issues, but in
|
||||
# tests with more layers, we need to enlarge the tolerance to pass the tests.
|
||||
torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1)
|
||||
torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(fused_q_out, q_out)
|
||||
torch.testing.assert_close(fused_k_out, k_out)
|
||||
torch.testing.assert_close(fused_v_out, v_out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
178
test/srt/cpu/test_rope.py
Normal file
178
test/srt/cpu/test_rope.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import precision
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding,
|
||||
RotaryEmbedding,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestROPE(CustomTestCase):
|
||||
def test_deepseek_v2_rope(self):
|
||||
num_head = 16
|
||||
seq_len = 1024
|
||||
q_head_dim = 192
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
max_pos = 256
|
||||
k_dim = 576
|
||||
rotary_dim = 64
|
||||
is_neox_style = False
|
||||
|
||||
# Create cos_sin_cache
|
||||
freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
|
||||
cos = freqs.cos() * 0.7
|
||||
sin = freqs.sin() * 0.7
|
||||
cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
|
||||
positions = torch.randint(0, max_pos, (seq_len,))
|
||||
|
||||
rope = DeepseekScalingRotaryEmbedding(
|
||||
qk_rope_head_dim,
|
||||
rotary_dim,
|
||||
max_pos,
|
||||
16, # not used since cos_sin_cache is provided
|
||||
is_neox_style,
|
||||
1.0,
|
||||
torch.bfloat16,
|
||||
device="cpu",
|
||||
)
|
||||
rope.register_buffer("cos_sin_cache", cos_sin_cache)
|
||||
|
||||
for dtype in [torch.bfloat16]:
|
||||
enable_autocast = True
|
||||
|
||||
with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast):
|
||||
q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
|
||||
q_clone = q.clone()
|
||||
k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
|
||||
k_clone = k.clone()
|
||||
_, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||
_, q_pe_clone = q_clone.split(
|
||||
[qk_nope_head_dim, qk_rope_head_dim], dim=-1
|
||||
)
|
||||
k_pe = k[:, :, k_dim - qk_rope_head_dim :]
|
||||
k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]
|
||||
|
||||
# ref kernel
|
||||
q_pe, k_pe = rope.forward_native(
|
||||
query=q_pe,
|
||||
key=k_pe,
|
||||
positions=positions,
|
||||
)
|
||||
|
||||
# fused rope kernel
|
||||
q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||
positions,
|
||||
q_pe_clone,
|
||||
k_pe_clone,
|
||||
rope.head_size,
|
||||
cos_sin_cache,
|
||||
False,
|
||||
)
|
||||
|
||||
atol = rtol = precision[q_pe.dtype]
|
||||
torch.testing.assert_close(q_pe, q_pe_clone, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(k_pe, k_pe_clone, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(k_pe, k_pe_clone)
|
||||
|
||||
def test_origin_rope(self):
|
||||
def single_test(
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
torch.manual_seed(100)
|
||||
rope_ref = RotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
).to(device)
|
||||
pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
|
||||
query = torch.randn(
|
||||
batch_size * seq_len,
|
||||
num_q_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
key = torch.randn(
|
||||
batch_size * seq_len,
|
||||
num_kv_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
query_ref, key_ref = query.clone(), key.clone()
|
||||
query_cpu, key_cpu = query.clone(), key.clone()
|
||||
|
||||
query_ref_out, key_ref_out = rope_ref.forward_native(
|
||||
pos_ids, query_ref, key_ref
|
||||
)
|
||||
query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
|
||||
pos_ids,
|
||||
query_cpu,
|
||||
key_cpu,
|
||||
rope_ref.head_size,
|
||||
rope_ref.cos_sin_cache.to(query.dtype),
|
||||
rope_ref.is_neox_style,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
|
||||
)
|
||||
torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
test_config = [
|
||||
(64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
|
||||
(256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
|
||||
(512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
|
||||
(128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
|
||||
(128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
|
||||
(512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
|
||||
]
|
||||
|
||||
for (
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_q_heads,
|
||||
num_kv_heads,
|
||||
) in test_config:
|
||||
single_test(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position_embeddings,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
device,
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_q_heads,
|
||||
num_kv_heads,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
223
test/srt/cpu/test_shared_expert.py
Normal file
223
test/srt/cpu/test_shared_expert.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import itertools
|
||||
import math
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from utils import (
|
||||
BLOCK_K,
|
||||
BLOCK_N,
|
||||
SiluAndMul,
|
||||
factor_for_scale,
|
||||
fp8_max,
|
||||
fp8_min,
|
||||
per_token_quant_int8,
|
||||
precision,
|
||||
scaled_weight,
|
||||
torch_naive_moe,
|
||||
torch_w8a8_per_column_moe,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
class TestSharedExpert(CustomTestCase):
|
||||
M = [2, 121]
|
||||
N = [32, 32 * 4]
|
||||
K = [32, 32 * 2]
|
||||
routed_scaling_factor = [16]
|
||||
|
||||
M_fp8 = [2, 12]
|
||||
N_fp8 = [512]
|
||||
K_fp8 = [256]
|
||||
|
||||
def _bf16_shared_expert(self, m, n, k, routed_scaling_factor):
|
||||
dtype = torch.bfloat16
|
||||
prepack = True
|
||||
|
||||
hidden_states = torch.randn(m, k, dtype=dtype) / k
|
||||
w1 = torch.randn(2 * n, k, dtype=dtype)
|
||||
w2 = torch.randn(k, n, dtype=dtype)
|
||||
fused_output = torch.randn(m, k, dtype=dtype) / k
|
||||
|
||||
# fused moe mutates content in hs
|
||||
hidden_states2 = hidden_states.clone()
|
||||
|
||||
# bfloat16
|
||||
ref = torch_naive_moe(
|
||||
hidden_states.float(),
|
||||
w1.float(),
|
||||
w2.float(),
|
||||
fused_output.float(),
|
||||
routed_scaling_factor,
|
||||
).to(dtype=dtype)
|
||||
res = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
fused_output,
|
||||
routed_scaling_factor,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref.dtype]
|
||||
torch.testing.assert_close(ref, res, atol=atol, rtol=rtol)
|
||||
|
||||
def test_bf16_shared_expert(self):
|
||||
for params in itertools.product(
|
||||
self.M,
|
||||
self.N,
|
||||
self.K,
|
||||
self.routed_scaling_factor,
|
||||
):
|
||||
with self.subTest(
|
||||
m=params[0],
|
||||
n=params[1],
|
||||
k=params[2],
|
||||
routed_scaling_factor=params[3],
|
||||
):
|
||||
self._bf16_shared_expert(*params)
|
||||
|
||||
def _int8_shared_expert(self, m, n, k, routed_scaling_factor):
|
||||
dtype = torch.bfloat16
|
||||
prepack = True
|
||||
|
||||
hidden_states = torch.randn(m, k, dtype=dtype) / k
|
||||
w1 = torch.randn(2 * n, k, dtype=dtype)
|
||||
w2 = torch.randn(k, n, dtype=dtype)
|
||||
fused_output = torch.randn(m, k, dtype=dtype) / k
|
||||
|
||||
# fused moe mutates content in hs
|
||||
hidden_states2 = hidden_states.clone()
|
||||
|
||||
w1_q, w1_s = per_token_quant_int8(w1)
|
||||
w2_q, w2_s = per_token_quant_int8(w2)
|
||||
ref2 = torch_w8a8_per_column_moe(
|
||||
hidden_states2.float(),
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_s,
|
||||
w2_s,
|
||||
fused_output.float(),
|
||||
routed_scaling_factor,
|
||||
).to(dtype=dtype)
|
||||
res2 = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
hidden_states2,
|
||||
w1_q,
|
||||
w2_q,
|
||||
fused_output,
|
||||
routed_scaling_factor,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
w1_s,
|
||||
w2_s,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref2.dtype]
|
||||
torch.testing.assert_close(ref2, res2, atol=atol, rtol=rtol)
|
||||
|
||||
def test_int8_shared_expert(self):
|
||||
for params in itertools.product(
|
||||
self.M,
|
||||
self.N,
|
||||
self.K,
|
||||
self.routed_scaling_factor,
|
||||
):
|
||||
with self.subTest(
|
||||
m=params[0],
|
||||
n=params[1],
|
||||
k=params[2],
|
||||
routed_scaling_factor=params[3],
|
||||
):
|
||||
self._int8_shared_expert(*params)
|
||||
|
||||
def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
|
||||
dtype = torch.bfloat16
|
||||
prepack = True
|
||||
|
||||
a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
|
||||
|
||||
w1_fp32 = torch.randn(1, 2 * N, K)
|
||||
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w2_fp32 = torch.randn(1, K, N)
|
||||
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w1s = torch.randn(1, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
|
||||
w2s = torch.randn(1, K // BLOCK_N, N // BLOCK_K) * factor_for_scale
|
||||
|
||||
w1_scaled = scaled_weight(w1, w1s).view(2 * N, K)
|
||||
w2_scaled = scaled_weight(w2, w2s).view(K, N)
|
||||
|
||||
# change back to 2D
|
||||
w1, w2 = w1.squeeze(0), w2.squeeze(0)
|
||||
w1s, w2s = w1s.squeeze(0), w2s.squeeze(0)
|
||||
w1_scaled, w2_scaled = w1_scaled.squeeze(0), w2_scaled.squeeze(0)
|
||||
|
||||
fused_out = torch.randn(M, K, dtype=dtype) / math.sqrt(K)
|
||||
a2 = a.clone()
|
||||
|
||||
# ref
|
||||
ic0 = torch.matmul(a.float(), w1_scaled.transpose(0, 1))
|
||||
ic1 = SiluAndMul(ic0)
|
||||
shared_out = torch.matmul(ic1, w2_scaled.transpose(0, 1))
|
||||
ref_out = shared_out + fused_out.float() * routed_scaling_factor
|
||||
ref_out = ref_out.to(dtype=dtype)
|
||||
|
||||
w1 = torch.ops.sgl_kernel.convert_weight_packed(w1) # [2N, K]
|
||||
w2 = torch.ops.sgl_kernel.convert_weight_packed(w2) # [K, N]
|
||||
out = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
a2,
|
||||
w1,
|
||||
w2,
|
||||
fused_out,
|
||||
routed_scaling_factor,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
w1s,
|
||||
w2s,
|
||||
[BLOCK_N, BLOCK_K],
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
torch.testing.assert_close(ref_out, out, atol=atol, rtol=rtol)
|
||||
|
||||
def test_fp8_shared_expert(self):
|
||||
for params in itertools.product(
|
||||
self.M_fp8,
|
||||
self.N_fp8,
|
||||
self.K_fp8,
|
||||
self.routed_scaling_factor,
|
||||
):
|
||||
with self.subTest(
|
||||
M=params[0],
|
||||
N=params[1],
|
||||
K=params[2],
|
||||
routed_scaling_factor=params[3],
|
||||
):
|
||||
self._fp8_shared_expert(*params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
199
test/srt/cpu/test_topk.py
Normal file
199
test/srt/cpu/test_topk.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
import sgl_kernel
|
||||
import torch
|
||||
from utils import precision
|
||||
|
||||
from sglang.srt.layers.moe.topk import (
|
||||
biased_grouped_topk_impl as native_biased_grouped_topk,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import fused_topk_torch_native as native_fused_topk
|
||||
from sglang.srt.layers.moe.topk import grouped_topk_gpu as native_grouped_topk
|
||||
from sglang.srt.models.llama4 import Llama4MoE
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
torch.manual_seed(1234)
|
||||
|
||||
|
||||
# This is used by the Deepseek-V2 model
|
||||
class TestGroupedTopK(CustomTestCase):
|
||||
def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
|
||||
torch.manual_seed(1234)
|
||||
|
||||
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||
|
||||
ref_topk_weights, ref_topk_ids = native_grouped_topk(
|
||||
hidden_states.float(),
|
||||
gating_output.float(),
|
||||
topk,
|
||||
renormalize,
|
||||
G,
|
||||
topk_group,
|
||||
)
|
||||
|
||||
# fused version
|
||||
topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
|
||||
hidden_states,
|
||||
gating_output,
|
||||
topk,
|
||||
renormalize,
|
||||
G,
|
||||
topk_group,
|
||||
0,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
res = torch.zeros(M, E, dtype=torch.float)
|
||||
ref = torch.zeros(M, E, dtype=torch.float)
|
||||
res.scatter_(1, topk_ids.long(), topk_weights)
|
||||
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
|
||||
torch.testing.assert_close(res, ref)
|
||||
|
||||
def test_grouped_topk(self):
|
||||
for renormalize in [True, False]:
|
||||
self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16)
|
||||
self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16)
|
||||
|
||||
|
||||
# DeepSeek V2/V3/R1 uses biased_grouped_top
|
||||
class TestBiasedGroupedTopK(CustomTestCase):
|
||||
def _run_single_test(
|
||||
self, M, E, G, topk, topk_group, renormalize, dtype, bias_dtype
|
||||
):
|
||||
torch.manual_seed(1234)
|
||||
|
||||
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||
correction_bias = torch.randn(E, dtype=bias_dtype)
|
||||
|
||||
ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
|
||||
hidden_states.float(),
|
||||
gating_output.float(),
|
||||
correction_bias.float(),
|
||||
topk,
|
||||
renormalize,
|
||||
G,
|
||||
topk_group,
|
||||
)
|
||||
|
||||
# fused version
|
||||
topk_weights, topk_ids = torch.ops.sgl_kernel.biased_grouped_topk_cpu(
|
||||
hidden_states,
|
||||
gating_output,
|
||||
correction_bias,
|
||||
topk,
|
||||
renormalize,
|
||||
G,
|
||||
topk_group,
|
||||
0,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
res = torch.zeros(M, E, dtype=torch.float)
|
||||
ref = torch.zeros(M, E, dtype=torch.float)
|
||||
res.scatter_(1, topk_ids.long(), topk_weights)
|
||||
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
|
||||
torch.testing.assert_close(res, ref)
|
||||
|
||||
def test_biased_grouped_topk(self):
|
||||
for renormalize in [True, False]:
|
||||
for bias_dtype in [torch.float32, torch.bfloat16]:
|
||||
self._run_single_test(
|
||||
122, 256, 8, 8, 2, renormalize, torch.bfloat16, bias_dtype
|
||||
)
|
||||
|
||||
|
||||
class TestTopK(CustomTestCase):
|
||||
def _run_single_test(self, M, E, topk, renormalize, dtype):
|
||||
torch.manual_seed(1998)
|
||||
|
||||
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||
|
||||
ref_topk_weights, ref_topk_ids = native_fused_topk(
|
||||
hidden_states.float(),
|
||||
gating_output.float(),
|
||||
topk,
|
||||
renormalize,
|
||||
)
|
||||
|
||||
# fused version
|
||||
topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
|
||||
hidden_states, gating_output, topk, renormalize
|
||||
)
|
||||
|
||||
res = torch.zeros(M, E, dtype=torch.float)
|
||||
ref = torch.zeros(M, E, dtype=torch.float)
|
||||
res.scatter_(1, topk_ids.long(), topk_weights)
|
||||
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
|
||||
torch.testing.assert_close(res, ref)
|
||||
|
||||
def test_topk(self):
|
||||
for renormalize in [True, False]:
|
||||
self._run_single_test(123, 8, 2, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 16, 3, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 64, 6, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 256, 4, renormalize, torch.bfloat16)
|
||||
self._run_single_test(123, 160, 6, renormalize, torch.bfloat16)
|
||||
|
||||
|
||||
class TestCustomTopK(CustomTestCase):
|
||||
def _run_single_test(
|
||||
self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f
|
||||
):
|
||||
torch.manual_seed(16)
|
||||
|
||||
# expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
|
||||
hidden_states = torch.randn(M, 100, dtype=dtype)
|
||||
gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
|
||||
|
||||
ref_topk_weights, ref_topk_ids = native_custom_f(
|
||||
hidden_states.float(),
|
||||
gating_output.float(),
|
||||
topk,
|
||||
renormalize,
|
||||
)
|
||||
|
||||
# fused version
|
||||
topk_weights, topk_ids = fused_custom_f(
|
||||
hidden_states, gating_output, topk, renormalize
|
||||
)
|
||||
|
||||
res = torch.zeros(M, E, dtype=torch.float)
|
||||
ref = torch.zeros(M, E, dtype=torch.float)
|
||||
res.scatter_(1, topk_ids.long(), topk_weights)
|
||||
ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
|
||||
torch.testing.assert_close(res, ref)
|
||||
|
||||
def test_custom_topk(self):
|
||||
test_custom_functions = [
|
||||
(Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
|
||||
]
|
||||
for native_custom_f, fused_custom_f in test_custom_functions:
|
||||
self._run_single_test(
|
||||
123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
|
||||
)
|
||||
self._run_single_test(
|
||||
123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
|
||||
)
|
||||
self._run_single_test(
|
||||
123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
269
test/srt/cpu/utils.py
Normal file
269
test/srt/cpu/utils.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
precision = {
|
||||
torch.bfloat16: 1e-2,
|
||||
torch.float16: 1e-3,
|
||||
torch.float32: 1e-5,
|
||||
}
|
||||
|
||||
|
||||
BLOCK_N, BLOCK_K = 64, 128
|
||||
factor_for_scale = 1e-3
|
||||
fp8_max, fp8_min = 400, -400
|
||||
|
||||
|
||||
def SiluAndMul(x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
|
||||
def per_token_quant_int8(x):
|
||||
x = x.float()
|
||||
absmax = x.abs().max(dim=-1).values
|
||||
absmax = absmax.clamp_min(1e-10).unsqueeze(-1)
|
||||
scale_x = absmax / 127
|
||||
x_q = x.mul(127 / absmax)
|
||||
x_q = torch.round(x_q).to(torch.int8)
|
||||
|
||||
return x_q, scale_x
|
||||
|
||||
|
||||
def convert_weight(weight, scale_block_size, A_dtype):
|
||||
N, K = weight.size()
|
||||
fp8_max = 448.0
|
||||
scale_block_size_N, scale_block_size_K = scale_block_size # (128, 128)
|
||||
|
||||
pad_N = (scale_block_size_N - (N % scale_block_size_N)) % scale_block_size_N
|
||||
pad_K = (scale_block_size_K - (K % scale_block_size_K)) % scale_block_size_K
|
||||
|
||||
if pad_N > 0 or pad_K > 0:
|
||||
weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
|
||||
|
||||
weight_blocks = weight.view(
|
||||
math.ceil(N / scale_block_size_N),
|
||||
scale_block_size_N,
|
||||
math.ceil(K / scale_block_size_K),
|
||||
scale_block_size_K,
|
||||
) # (8, 128, 8, 128)
|
||||
weight_blocks = weight_blocks.permute(0, 2, 1, 3).contiguous() # (8, 8, 128, 128)
|
||||
|
||||
# Step 2: compute per-block max abs values → scale
|
||||
abs_max = weight_blocks.abs().amax(dim=(-2, -1), keepdim=True) # (8, 8, 1, 1)
|
||||
scales = abs_max / fp8_max
|
||||
scales = torch.where(
|
||||
scales == 0, torch.ones_like(scales), scales
|
||||
) # avoid division by zero
|
||||
|
||||
q_fp8 = (weight_blocks / scales).to(torch.float8_e4m3fn)
|
||||
q_fp8_reshape = q_fp8.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
if pad_N > 0 or pad_K > 0:
|
||||
q_fp8_reshape = q_fp8_reshape.view(N + pad_N, K + pad_K)
|
||||
q_fp8_reshape = q_fp8_reshape[:N, :K].contiguous()
|
||||
else:
|
||||
q_fp8_reshape = q_fp8_reshape.view(N, K)
|
||||
|
||||
dq_weight = q_fp8.float() * scales
|
||||
dq_weight = dq_weight.permute(0, 2, 1, 3).contiguous() # (8, 128, 8, 128)
|
||||
|
||||
if pad_N > 0 or pad_K > 0:
|
||||
w_dq = dq_weight.view(N + pad_N, K + pad_K).to(A_dtype)
|
||||
w_dq = w_dq[:N, :K].contiguous()
|
||||
else:
|
||||
w_dq = dq_weight.view(N, K).to(A_dtype)
|
||||
|
||||
scales = scales.view(
|
||||
math.ceil(N / scale_block_size_N), math.ceil(K / scale_block_size_K)
|
||||
)
|
||||
|
||||
return q_fp8_reshape, scales, w_dq
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, bias, output_dtype=torch.bfloat16):
|
||||
"""Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32)
|
||||
|
||||
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
|
||||
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
|
||||
|
||||
# Reshape input
|
||||
M = A.numel() // A.shape[-1]
|
||||
B = B.t() # Transpose weight matrix
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (K,)
|
||||
A = A.reshape(M, N)
|
||||
|
||||
# As is per-token [M, 1], Bs is per-column [1, K]
|
||||
C = torch.matmul(A, B) # [M, K]
|
||||
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
|
||||
|
||||
if bias is not None:
|
||||
C.add_(bias.view(1, -1))
|
||||
|
||||
return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
||||
|
||||
def torch_naive_moe(a, w1, w2, b, routed_scaling_factor):
|
||||
|
||||
ic1 = torch.matmul(a, w1.transpose(0, 1))
|
||||
ic2 = SiluAndMul(ic1)
|
||||
ic3 = torch.matmul(ic2, w2.transpose(0, 1))
|
||||
|
||||
return ic3 + b * routed_scaling_factor
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1_q, w2_q, w1_s, w2_s, b, routed_scaling_factor):
|
||||
|
||||
# Perform per-token quantization
|
||||
a_q, a_s = per_token_quant_int8(a)
|
||||
|
||||
ic1 = native_w8a8_per_token_matmul(
|
||||
a_q, w1_q, a_s, w1_s, bias=None, output_dtype=torch.float32
|
||||
)
|
||||
ic2 = SiluAndMul(ic1)
|
||||
|
||||
a1_q, a1_s = per_token_quant_int8(ic2)
|
||||
ic3 = native_w8a8_per_token_matmul(
|
||||
a1_q, w2_q, a1_s, w2_s, bias=None, output_dtype=torch.float32
|
||||
)
|
||||
|
||||
return ic3 + b * routed_scaling_factor
|
||||
|
||||
|
||||
def scaled_weight(weight, scales):
|
||||
E, N, K = weight.shape
|
||||
pad_N = (BLOCK_N - (N % BLOCK_N)) % BLOCK_N
|
||||
pad_K = (BLOCK_K - (K % BLOCK_K)) % BLOCK_K
|
||||
|
||||
if pad_N > 0 or pad_K > 0:
|
||||
weight = torch.nn.functional.pad(weight, (0, pad_K, 0, pad_N))
|
||||
|
||||
weight_block = (
|
||||
weight.view(E, math.ceil(N / BLOCK_N), BLOCK_N, math.ceil(K / BLOCK_K), BLOCK_K)
|
||||
.permute(0, 1, 3, 2, 4)
|
||||
.float()
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
weight_scaled = (
|
||||
(
|
||||
weight_block
|
||||
* scales.view(E, math.ceil(N / BLOCK_N), math.ceil(K / BLOCK_K), 1, 1)
|
||||
)
|
||||
.permute(0, 1, 3, 2, 4)
|
||||
.contiguous()
|
||||
)
|
||||
if pad_N > 0 or pad_K > 0:
|
||||
weight_scaled = weight_scaled.view(E, N + pad_N, K + pad_K)
|
||||
weight_scaled = weight_scaled[..., :N, :K].contiguous()
|
||||
else:
|
||||
weight_scaled = weight_scaled.view(E, N, K)
|
||||
return weight_scaled
|
||||
|
||||
|
||||
def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
|
||||
if renormalize:
|
||||
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
|
||||
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = SiluAndMul(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
|
||||
0, 1
|
||||
)
|
||||
return (
|
||||
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk):
|
||||
"""This function performs fused moe with per-column int8 quantization using native torch."""
|
||||
|
||||
B, D = a.shape
|
||||
# Perform per-token quantization
|
||||
a_q, a_s = per_token_quant_int8(a)
|
||||
# Repeat tokens to match topk
|
||||
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
# Also repeat the scale
|
||||
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
|
||||
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)
|
||||
|
||||
# Calculate routing
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
# Process each expert
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
# First MLP layer: note that a_s is now per-token
|
||||
inter_out = native_w8a8_per_token_matmul(
|
||||
a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
bias=None,
|
||||
output_dtype=torch.float32,
|
||||
)
|
||||
# Activation function
|
||||
act_out = SiluAndMul(inter_out)
|
||||
# Quantize activation output with per-token
|
||||
act_out_q, act_out_s = per_token_quant_int8(act_out)
|
||||
# Second MLP layer
|
||||
out[mask] = native_w8a8_per_token_matmul(
|
||||
act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
bias=None,
|
||||
output_dtype=torch.float32,
|
||||
)
|
||||
# Apply routing weights and sum
|
||||
return (
|
||||
(out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
|
||||
.sum(dim=1)
|
||||
.to(a.dtype)
|
||||
)
|
||||
|
||||
|
||||
def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D).float()
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)
|
||||
|
||||
# Calculate routing
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
ic0 = torch.matmul(a[mask], w1[i].transpose(0, 1))
|
||||
ic1 = SiluAndMul(ic0)
|
||||
out[mask] = torch.matmul(ic1, w2[i].transpose(0, 1))
|
||||
|
||||
return (
|
||||
(out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
|
||||
.sum(dim=1)
|
||||
.to(a.dtype)
|
||||
)
|
||||
|
||||
|
||||
def make_non_contiguous(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Make a tensor non-contiguous by slicing it via last dimension.
|
||||
"""
|
||||
last_dim = x.shape[-1]
|
||||
return x[..., : last_dim // 2] if x.is_contiguous() else x
|
||||
File diff suppressed because one or more lines are too long
206
test/srt/entrypoints/http_server/test_abort_request.py
Normal file
206
test/srt/entrypoints/http_server/test_abort_request.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Integration test for abort_request functionality with a SGLang server.
|
||||
|
||||
Run with:
|
||||
python -m unittest sglang.test.srt.entrypoints.http_server.test_abort_request -v
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestAbortRequest(CustomTestCase):
|
||||
"""Integration test class for abort request functionality."""
|
||||
|
||||
model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
base_url = DEFAULT_URL_FOR_TEST
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Launch the server."""
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--disable-cuda-graph"],
|
||||
)
|
||||
|
||||
cls.completion_url = f"{cls.base_url}/generate"
|
||||
cls.abort_url = f"{cls.base_url}/abort_request"
|
||||
cls.health_url = f"{cls.base_url}/health"
|
||||
|
||||
print(f"Server started at {cls.base_url}")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
"""Clean up the server."""
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def _send_completion_request(
|
||||
self,
|
||||
text: str,
|
||||
request_id: str,
|
||||
max_tokens: int = 50,
|
||||
temperature: float = 0.8,
|
||||
stream: bool = True,
|
||||
) -> requests.Response:
|
||||
"""Send a completion request to the server."""
|
||||
payload = {
|
||||
"text": text,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
"stream": stream,
|
||||
"rid": request_id,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
self.completion_url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _send_abort_request(self, request_id: str) -> requests.Response:
|
||||
"""Send an abort request."""
|
||||
payload = {"rid": request_id}
|
||||
return requests.post(self.abort_url, json=payload, timeout=10)
|
||||
|
||||
def _check_server_health(self) -> bool:
|
||||
"""Check if server is healthy."""
|
||||
try:
|
||||
response = requests.get(self.health_url, timeout=5)
|
||||
return response.status_code == 200
|
||||
except:
|
||||
return False
|
||||
|
||||
def test_abort_during_non_streaming_generation(self):
|
||||
"""Test aborting a non-streaming request during generation."""
|
||||
self.assertTrue(self._check_server_health(), "Server should be healthy")
|
||||
|
||||
request_id = "test_abort_non_streaming"
|
||||
completion_result = {}
|
||||
|
||||
def run_completion():
|
||||
response = self._send_completion_request(
|
||||
"Write a detailed essay about artificial intelligence",
|
||||
max_tokens=500,
|
||||
temperature=1,
|
||||
request_id=request_id,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
completion_result["text"] = result.get("text", "")
|
||||
completion_result["finish_reason"] = result.get("meta_info", {}).get(
|
||||
"finish_reason"
|
||||
)
|
||||
|
||||
completion_thread = threading.Thread(target=run_completion)
|
||||
completion_thread.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
abort_response = self._send_abort_request(request_id)
|
||||
completion_thread.join()
|
||||
|
||||
self.assertEqual(abort_response.status_code, 200)
|
||||
self.assertIsNotNone(completion_result, "Should have completion result")
|
||||
if completion_result:
|
||||
finish_reason_obj = completion_result.get("finish_reason")
|
||||
self.assertIsNotNone(finish_reason_obj, "Should have finish_reason")
|
||||
if finish_reason_obj:
|
||||
self.assertEqual(
|
||||
finish_reason_obj.get("type"), "abort", "Should be aborted"
|
||||
)
|
||||
|
||||
def test_batch_requests_with_selective_abort(self):
|
||||
"""Test multiple concurrent requests with selective abort of one request."""
|
||||
self.assertTrue(self._check_server_health(), "Server should be healthy")
|
||||
|
||||
request_ids = ["batch_test_0", "batch_test_1", "batch_test_2"]
|
||||
abort_target_id = "batch_test_1"
|
||||
completion_results = {}
|
||||
threads = []
|
||||
|
||||
def run_completion(req_id, prompt):
|
||||
response = self._send_completion_request(
|
||||
f"Write a story about {prompt}",
|
||||
max_tokens=100,
|
||||
temperature=0.8,
|
||||
request_id=req_id,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
completion_results[req_id] = {
|
||||
"text": result.get("text", ""),
|
||||
"finish_reason": result.get("meta_info", {}).get("finish_reason"),
|
||||
}
|
||||
|
||||
# Start all requests
|
||||
prompts = ["a knight's adventure", "a space discovery", "a chef's restaurant"]
|
||||
for i, req_id in enumerate(request_ids):
|
||||
thread = threading.Thread(target=run_completion, args=(req_id, prompts[i]))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Abort one request
|
||||
time.sleep(0.1)
|
||||
abort_response = self._send_abort_request(abort_target_id)
|
||||
|
||||
# Wait for completion
|
||||
for thread in threads:
|
||||
thread.join(timeout=30)
|
||||
|
||||
# Verify results
|
||||
self.assertEqual(abort_response.status_code, 200)
|
||||
|
||||
# Check aborted request
|
||||
aborted_result = completion_results.get(abort_target_id)
|
||||
self.assertIsNotNone(
|
||||
aborted_result, f"Aborted request {abort_target_id} should have result"
|
||||
)
|
||||
if aborted_result:
|
||||
aborted_finish_reason = aborted_result.get("finish_reason")
|
||||
self.assertIsNotNone(
|
||||
aborted_finish_reason, "Aborted request should have finish_reason"
|
||||
)
|
||||
if aborted_finish_reason:
|
||||
self.assertEqual(aborted_finish_reason.get("type"), "abort")
|
||||
|
||||
# Check other requests completed normally
|
||||
normal_completions = 0
|
||||
for req_id in request_ids:
|
||||
if req_id != abort_target_id and req_id in completion_results:
|
||||
result = completion_results[req_id]
|
||||
if result:
|
||||
finish_reason = result.get("finish_reason")
|
||||
if finish_reason and finish_reason.get("type") == "length":
|
||||
normal_completions += 1
|
||||
|
||||
self.assertEqual(
|
||||
normal_completions, 2, "Other 2 requests should complete normally"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2, warnings="ignore")
|
||||
445
test/srt/ep/test_deepep_internode.py
Normal file
445
test/srt/ep/test_deepep_internode.py
Normal file
@@ -0,0 +1,445 @@
|
||||
# Copy from deepseek-ai/DeepEP/tests/test_internode.py
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import deep_ep
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
import test_deepep_low_latency
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.test.test_deepep_utils import (
|
||||
bench,
|
||||
calc_diff,
|
||||
create_grouped_scores,
|
||||
init_dist,
|
||||
inplace_unique,
|
||||
per_token_cast_back,
|
||||
per_token_cast_to_fp8,
|
||||
)
|
||||
|
||||
|
||||
def test_main(
|
||||
num_sms: int,
|
||||
local_rank: int,
|
||||
num_local_ranks: int,
|
||||
num_ranks: int,
|
||||
num_nodes: int,
|
||||
rank: int,
|
||||
buffer: deep_ep.Buffer,
|
||||
group: dist.ProcessGroup,
|
||||
):
|
||||
# Settings
|
||||
num_tokens, hidden, num_topk_groups, num_topk, num_experts = (
|
||||
4096,
|
||||
7168,
|
||||
min(num_nodes, 4),
|
||||
8,
|
||||
(256 // num_ranks) * num_ranks,
|
||||
)
|
||||
assert num_experts % num_ranks == 0 and num_local_ranks == 8
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Random data
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
|
||||
x_e4m3 = per_token_cast_to_fp8(x)
|
||||
scores = (
|
||||
torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
|
||||
+ 1
|
||||
)
|
||||
group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1)
|
||||
group_idx = torch.topk(
|
||||
group_scores, k=num_topk_groups, dim=-1, sorted=False
|
||||
).indices
|
||||
masked_scores = create_grouped_scores(scores, group_idx, num_nodes)
|
||||
topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[
|
||||
1
|
||||
]
|
||||
topk_weights = (
|
||||
torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
|
||||
)
|
||||
topk_weights_pure_rand = torch.randn(
|
||||
(num_tokens, num_topk), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
rank_idx = topk_idx // (num_experts // num_ranks)
|
||||
rank_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rank_idx, num_ranks)
|
||||
rdma_rank_idx = rank_idx // num_local_ranks
|
||||
rdma_rank_idx.masked_fill_(rank_idx == -1, -1)
|
||||
inplace_unique(rdma_rank_idx, num_nodes)
|
||||
|
||||
# RDMA dispatch counts
|
||||
rdma_idx = topk_idx // (num_experts // num_nodes)
|
||||
rdma_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rdma_idx, num_nodes)
|
||||
num_rdma_token_sent = rdma_idx.ne(-1).sum().item()
|
||||
|
||||
# Expert meta
|
||||
num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
|
||||
for i in range(num_experts):
|
||||
num_tokens_per_expert[i] = (topk_idx == i).sum()
|
||||
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
|
||||
|
||||
# Rank layout meta
|
||||
num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
|
||||
num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda")
|
||||
token_idx_in_rank = torch.full(
|
||||
(num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
|
||||
)
|
||||
for i in range(num_ranks):
|
||||
num_tokens_per_rank[i] = (rank_idx == i).sum()
|
||||
token_sel = (rank_idx == i).max(dim=-1)[0]
|
||||
count = token_sel.sum().item()
|
||||
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
|
||||
tokens[:count] = torch.sort(tokens[:count])[0]
|
||||
token_idx_in_rank[i][tokens[:count]] = torch.arange(
|
||||
count, dtype=torch.long, device="cuda"
|
||||
)
|
||||
for i in range(num_nodes):
|
||||
num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum()
|
||||
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
|
||||
is_token_in_rank = token_idx_in_rank >= 0
|
||||
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
|
||||
|
||||
(
|
||||
ref_num_tokens_per_rank,
|
||||
ref_num_tokens_per_rdma_rank,
|
||||
ref_num_tokens_per_expert,
|
||||
ref_is_token_in_rank,
|
||||
_,
|
||||
) = buffer.get_dispatch_layout(topk_idx, num_experts)
|
||||
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
|
||||
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
|
||||
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
|
||||
if local_rank == 0:
|
||||
print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
|
||||
print("", flush=True)
|
||||
group.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# Config
|
||||
rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512)
|
||||
config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size)
|
||||
|
||||
# Test dispatch
|
||||
# noinspection PyShadowingNames
|
||||
def check_data(check_x, recv_gbl_rank_prefix_sum):
|
||||
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
|
||||
check_start = 0
|
||||
for i in range(num_ranks):
|
||||
check_end = recv_gbl_rank_prefix_sum[i].item()
|
||||
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
|
||||
check_start = check_end
|
||||
|
||||
for previous_mode in (False, True):
|
||||
for async_mode in (False, True):
|
||||
for current_x in (x_pure_rand, x, x_e4m3):
|
||||
for with_topk in (False, True):
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
|
||||
flush=True,
|
||||
end="",
|
||||
)
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
dispatch_args.update(
|
||||
{
|
||||
"topk_idx": topk_idx,
|
||||
"topk_weights": (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
),
|
||||
}
|
||||
)
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
recv_num_tokens_per_expert_list,
|
||||
handle,
|
||||
event,
|
||||
) = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
|
||||
# Checks
|
||||
recv_gbl_rank_prefix_sum = handle[-4]
|
||||
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
|
||||
0
|
||||
), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
|
||||
assert (
|
||||
gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
|
||||
== recv_num_tokens_per_expert_list
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
if with_topk:
|
||||
# Check `topk_idx`
|
||||
assert (
|
||||
recv_topk_idx.eq(-1)
|
||||
| (
|
||||
(recv_topk_idx >= 0)
|
||||
& (recv_topk_idx < (num_experts // num_ranks))
|
||||
)
|
||||
).sum().item() == recv_topk_idx.numel()
|
||||
for i, count in enumerate(recv_num_tokens_per_expert_list):
|
||||
assert recv_topk_idx.eq(i).sum().item() == count
|
||||
|
||||
# Check `topk_weights`
|
||||
if current_x is not x_pure_rand:
|
||||
recv_topk_weights[recv_topk_idx.eq(-1)] = (
|
||||
recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
|
||||
recv_topk_weights
|
||||
)[recv_topk_idx.eq(-1)]
|
||||
)
|
||||
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test cached dispatch (must without top-k staffs)
|
||||
if not with_topk:
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, recv_gbl_rank_prefix_sum)
|
||||
|
||||
# Test combine
|
||||
combine_args = {
|
||||
"x": recv_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
combine_args.update({"topk_weights": recv_topk_weights})
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(
|
||||
**combine_args
|
||||
)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(
|
||||
dim=1
|
||||
).unsqueeze(1)
|
||||
ref_x = x_pure_rand if current_x is x_pure_rand else x
|
||||
assert calc_diff(check_x, ref_x) < 5e-6
|
||||
if with_topk:
|
||||
check_topk_weights = (
|
||||
combined_topk_weights
|
||||
if (current_x is x_pure_rand)
|
||||
else (
|
||||
combined_topk_weights
|
||||
/ is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
)
|
||||
)
|
||||
ref_topk_weights = (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
)
|
||||
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
|
||||
|
||||
# For later tuning
|
||||
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
|
||||
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
|
||||
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
|
||||
combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes
|
||||
|
||||
if local_rank == 0:
|
||||
print(" passed", flush=True)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
# Tune dispatch performance
|
||||
best_dispatch_results = None
|
||||
fp8_factor = (1 + 4 / 128) / 2
|
||||
for current_x in (x_e4m3, x):
|
||||
best_time, best_results = 1e10, None
|
||||
rdma_send_bytes = (
|
||||
(dispatch_bf16_rdma_send_bytes * fp8_factor)
|
||||
if isinstance(current_x, tuple)
|
||||
else dispatch_bf16_rdma_send_bytes
|
||||
)
|
||||
nvl_recv_bytes = (
|
||||
(dispatch_bf16_nvl_recv_bytes * fp8_factor)
|
||||
if isinstance(current_x, tuple)
|
||||
else dispatch_bf16_nvl_recv_bytes
|
||||
)
|
||||
for nvl_chunk_size in range(4, 33, 4):
|
||||
for rdma_chunk_size in range(4, 33, 4):
|
||||
config = deep_ep.Config(
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
nvl_buffer_size,
|
||||
rdma_chunk_size,
|
||||
rdma_buffer_size,
|
||||
)
|
||||
tune_args = {"x": current_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.dispatch(**tune_args))[0]
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
rdma_chunk_size,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
|
||||
if isinstance(current_x, tuple):
|
||||
# Gather FP8 the best config from rank 0
|
||||
best_dispatch_results = torch.tensor(
|
||||
[best_results[0], best_results[1], best_results[2]],
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
all_best_fp8_results_list = [
|
||||
torch.zeros_like(best_dispatch_results)
|
||||
for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
dist.all_gather(
|
||||
all_best_fp8_results_list, best_dispatch_results, group=group
|
||||
)
|
||||
best_dispatch_results = all_best_fp8_results_list[0].tolist()
|
||||
dispatch_config = deep_ep.Config(
|
||||
best_dispatch_results[0],
|
||||
best_dispatch_results[1],
|
||||
nvl_buffer_size,
|
||||
best_dispatch_results[2],
|
||||
rdma_buffer_size,
|
||||
)
|
||||
|
||||
dispatch_args = {
|
||||
"x": x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": dispatch_config if dispatch_config is not None else config,
|
||||
}
|
||||
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
for nvl_chunk_size in range(1, 5, 1):
|
||||
for rdma_chunk_size in range(8, 33, 4):
|
||||
config = deep_ep.Config(
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
nvl_buffer_size,
|
||||
rdma_chunk_size,
|
||||
rdma_buffer_size,
|
||||
)
|
||||
tune_args = {"x": recv_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (
|
||||
num_sms,
|
||||
nvl_chunk_size,
|
||||
rdma_chunk_size,
|
||||
)
|
||||
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int):
|
||||
num_nodes = int(os.getenv("WORLD_SIZE", 1))
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
|
||||
test_ll_compatibility = False
|
||||
if test_ll_compatibility:
|
||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||
|
||||
buffer = deep_ep.Buffer(
|
||||
group,
|
||||
int(1e9),
|
||||
int(1e9),
|
||||
low_latency_mode=test_ll_compatibility,
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
|
||||
)
|
||||
assert num_local_ranks == 8 and num_ranks > 8
|
||||
torch.manual_seed(rank)
|
||||
|
||||
for i in (24,):
|
||||
test_main(
|
||||
i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group
|
||||
)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
if test_ll_compatibility:
|
||||
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
|
||||
test_deepep_low_latency.test_main(
|
||||
ll_num_tokens,
|
||||
ll_hidden,
|
||||
ll_num_experts,
|
||||
ll_num_topk,
|
||||
rank,
|
||||
num_ranks,
|
||||
group,
|
||||
buffer,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_processes = 8
|
||||
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
|
||||
379
test/srt/ep/test_deepep_intranode.py
Normal file
379
test/srt/ep/test_deepep_intranode.py
Normal file
@@ -0,0 +1,379 @@
|
||||
# Copy from deepseek-ai/DeepEP/tests/test_intranode.py
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
import deep_ep
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
import test_deepep_low_latency
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.test.test_deepep_utils import (
|
||||
bench,
|
||||
calc_diff,
|
||||
init_dist,
|
||||
inplace_unique,
|
||||
per_token_cast_back,
|
||||
per_token_cast_to_fp8,
|
||||
)
|
||||
|
||||
|
||||
def test_main(
|
||||
num_sms: int,
|
||||
local_rank: int,
|
||||
num_ranks: int,
|
||||
rank: int,
|
||||
buffer: deep_ep.Buffer,
|
||||
group: dist.ProcessGroup,
|
||||
):
|
||||
# Settings
|
||||
num_tokens, hidden, num_topk, num_experts = (
|
||||
4096,
|
||||
7168,
|
||||
8,
|
||||
(256 // num_ranks) * num_ranks,
|
||||
)
|
||||
assert num_experts % num_ranks == 0
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Random data
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank
|
||||
x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
|
||||
x_e4m3 = per_token_cast_to_fp8(x)
|
||||
scores = (
|
||||
torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
|
||||
+ 1
|
||||
)
|
||||
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
|
||||
topk_weights = (
|
||||
torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank
|
||||
)
|
||||
topk_weights_pure_rand = torch.randn(
|
||||
(num_tokens, num_topk), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
rank_idx = topk_idx // (num_experts // num_ranks)
|
||||
rank_idx.masked_fill_(topk_idx == -1, -1)
|
||||
inplace_unique(rank_idx, num_ranks)
|
||||
|
||||
# Expert meta
|
||||
num_tokens_per_expert = torch.zeros((num_experts,), dtype=torch.int, device="cuda")
|
||||
for i in range(num_experts):
|
||||
num_tokens_per_expert[i] = (topk_idx == i).sum()
|
||||
gbl_num_tokens_per_expert = num_tokens_per_expert.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_expert, group=group)
|
||||
|
||||
# Rank layout meta
|
||||
num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda")
|
||||
token_idx_in_rank = torch.full(
|
||||
(num_ranks, num_tokens), -1, dtype=torch.long, device="cuda"
|
||||
)
|
||||
for i in range(num_ranks):
|
||||
num_tokens_per_rank[i] = (rank_idx == i).sum()
|
||||
token_sel = (rank_idx == i).max(dim=-1)[0]
|
||||
count = token_sel.sum().item()
|
||||
tokens = torch.sort(token_sel.to(torch.int), descending=True)[1]
|
||||
tokens[:count] = torch.sort(tokens[:count])[0]
|
||||
token_idx_in_rank[i][tokens[:count]] = torch.arange(
|
||||
count, dtype=torch.long, device="cuda"
|
||||
)
|
||||
token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int)
|
||||
is_token_in_rank = token_idx_in_rank >= 0
|
||||
gbl_num_tokens_per_rank = num_tokens_per_rank.clone()
|
||||
dist.all_reduce(gbl_num_tokens_per_rank, group=group)
|
||||
|
||||
ref_num_tokens_per_rank, _, ref_num_tokens_per_expert, ref_is_token_in_rank, _ = (
|
||||
buffer.get_dispatch_layout(topk_idx, num_experts)
|
||||
)
|
||||
assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank)
|
||||
assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert)
|
||||
assert torch.allclose(ref_is_token_in_rank, is_token_in_rank)
|
||||
t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0]
|
||||
if local_rank == 0:
|
||||
print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True)
|
||||
print("", flush=True)
|
||||
group.barrier()
|
||||
time.sleep(1)
|
||||
|
||||
# Config
|
||||
nvl_buffer_size = 256
|
||||
config = deep_ep.Config(num_sms, 8, nvl_buffer_size)
|
||||
|
||||
# Test dispatch
|
||||
# noinspection PyShadowingNames
|
||||
def check_data(check_x, rank_prefix_matrix):
|
||||
assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1))
|
||||
check_start = 0
|
||||
for i in range(num_ranks):
|
||||
check_end = rank_prefix_matrix[i][rank].item()
|
||||
assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0
|
||||
check_start = check_end
|
||||
|
||||
for previous_mode in (False, True):
|
||||
for async_mode in (False, True):
|
||||
for current_x in (x_pure_rand, x, x_e4m3):
|
||||
for with_topk in (False, True):
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...',
|
||||
flush=True,
|
||||
end="",
|
||||
)
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
dispatch_args.update(
|
||||
{
|
||||
"topk_idx": topk_idx,
|
||||
"topk_weights": (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
),
|
||||
}
|
||||
)
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_weights,
|
||||
recv_num_tokens_per_expert_list,
|
||||
handle,
|
||||
event,
|
||||
) = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
|
||||
# Checks
|
||||
rank_prefix_matrix = handle[0]
|
||||
assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(
|
||||
0
|
||||
), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}"
|
||||
assert (
|
||||
gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist()
|
||||
== recv_num_tokens_per_expert_list
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, rank_prefix_matrix)
|
||||
if with_topk:
|
||||
# Check `topk_idx`
|
||||
assert (
|
||||
recv_topk_idx.eq(-1)
|
||||
| (
|
||||
(recv_topk_idx >= 0)
|
||||
& (recv_topk_idx < (num_experts // num_ranks))
|
||||
)
|
||||
).sum().item() == recv_topk_idx.numel()
|
||||
for i, count in enumerate(recv_num_tokens_per_expert_list):
|
||||
assert recv_topk_idx.eq(i).sum().item() == count
|
||||
|
||||
# Check `topk_weights`
|
||||
if current_x is not x_pure_rand:
|
||||
recv_topk_weights[recv_topk_idx.eq(-1)] = (
|
||||
recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
|
||||
recv_topk_weights
|
||||
)[recv_topk_idx.eq(-1)]
|
||||
)
|
||||
check_data(recv_topk_weights, rank_prefix_matrix)
|
||||
|
||||
# Test cached dispatch (must without top-k staffs)
|
||||
if not with_topk:
|
||||
dispatch_args = {
|
||||
"x": current_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
recv_x = (
|
||||
per_token_cast_back(*recv_x)
|
||||
if isinstance(recv_x, tuple)
|
||||
else recv_x
|
||||
)
|
||||
if current_x is not x_pure_rand:
|
||||
check_data(recv_x, rank_prefix_matrix)
|
||||
|
||||
# Test combine
|
||||
combine_args = {
|
||||
"x": recv_x,
|
||||
"handle": handle,
|
||||
"config": config,
|
||||
"async_finish": async_mode,
|
||||
}
|
||||
if with_topk:
|
||||
combine_args.update({"topk_weights": recv_topk_weights})
|
||||
if previous_mode:
|
||||
dispatch_args.update({"previous_event": buffer.capture()})
|
||||
combined_x, combined_topk_weights, event = buffer.combine(
|
||||
**combine_args
|
||||
)
|
||||
event.current_stream_wait() if async_mode else ()
|
||||
check_x = combined_x.float() / is_token_in_rank.sum(
|
||||
dim=1
|
||||
).unsqueeze(1)
|
||||
ref_x = x_pure_rand if current_x is x_pure_rand else x
|
||||
assert calc_diff(check_x, ref_x) < 5e-6
|
||||
if with_topk:
|
||||
check_topk_weights = (
|
||||
combined_topk_weights
|
||||
if (current_x is x_pure_rand)
|
||||
else (
|
||||
combined_topk_weights
|
||||
/ is_token_in_rank.sum(dim=1).unsqueeze(1)
|
||||
)
|
||||
)
|
||||
ref_topk_weights = (
|
||||
topk_weights_pure_rand
|
||||
if current_x is x_pure_rand
|
||||
else topk_weights
|
||||
)
|
||||
assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9
|
||||
|
||||
# For later tuning
|
||||
dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2
|
||||
combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes
|
||||
|
||||
if local_rank == 0:
|
||||
print(" passed", flush=True)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
# Tune dispatch performance
|
||||
best_dispatch_results = None
|
||||
fp8_factor = (1 + 4 / 128) / 2
|
||||
for current_x in (x_e4m3, x):
|
||||
best_time, best_results = 1e10, None
|
||||
nvl_recv_bytes = (
|
||||
(dispatch_bf16_nvl_recv_bytes * fp8_factor)
|
||||
if isinstance(current_x, tuple)
|
||||
else dispatch_bf16_nvl_recv_bytes
|
||||
)
|
||||
for nvl_chunk_size in range(4, 33, 4):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
tune_args = {"x": current_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.dispatch(**tune_args))[0]
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)',
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
|
||||
if isinstance(current_x, tuple):
|
||||
# Gather FP8 the best config from rank 0
|
||||
best_dispatch_results = torch.tensor(
|
||||
[best_results[0], best_results[1]], dtype=torch.int32, device="cuda"
|
||||
)
|
||||
all_best_fp8_results_list = [
|
||||
torch.zeros_like(best_dispatch_results)
|
||||
for _ in range(torch.distributed.get_world_size())
|
||||
]
|
||||
dist.all_gather(
|
||||
all_best_fp8_results_list, best_dispatch_results, group=group
|
||||
)
|
||||
best_dispatch_results = all_best_fp8_results_list[0].tolist()
|
||||
dispatch_config = deep_ep.Config(
|
||||
best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size
|
||||
)
|
||||
|
||||
dispatch_args = {
|
||||
"x": x,
|
||||
"num_tokens_per_rank": num_tokens_per_rank,
|
||||
"is_token_in_rank": is_token_in_rank,
|
||||
"num_tokens_per_expert": num_tokens_per_expert,
|
||||
"config": dispatch_config if dispatch_config is not None else config,
|
||||
}
|
||||
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
|
||||
|
||||
# Tune combine performance
|
||||
best_time, best_results = 1e10, None
|
||||
for nvl_chunk_size in range(1, 7, 1):
|
||||
config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size)
|
||||
tune_args = {"x": recv_x, "handle": handle, "config": config}
|
||||
t = bench(lambda: buffer.combine(**tune_args))[0]
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ",
|
||||
flush=True,
|
||||
)
|
||||
if t < best_time:
|
||||
best_time, best_results = t, (num_sms, nvl_chunk_size)
|
||||
|
||||
if local_rank == 0:
|
||||
print(
|
||||
f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
|
||||
flush=True,
|
||||
)
|
||||
print("", flush=True)
|
||||
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int):
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
|
||||
test_ll_compatibility, num_rdma_bytes = False, 0
|
||||
if test_ll_compatibility:
|
||||
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
ll_num_tokens, ll_hidden, num_ranks, ll_num_experts
|
||||
)
|
||||
|
||||
buffer = deep_ep.Buffer(
|
||||
group,
|
||||
int(1e9),
|
||||
num_rdma_bytes,
|
||||
low_latency_mode=test_ll_compatibility,
|
||||
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
|
||||
)
|
||||
torch.manual_seed(rank)
|
||||
|
||||
for i in (24,):
|
||||
test_main(i, local_rank, num_ranks, rank, buffer, group)
|
||||
if local_rank == 0:
|
||||
print("", flush=True)
|
||||
|
||||
# Test compatibility with low latency functions
|
||||
if test_ll_compatibility:
|
||||
buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts)
|
||||
test_deepep_low_latency.test_main(
|
||||
ll_num_tokens,
|
||||
ll_hidden,
|
||||
ll_num_experts,
|
||||
ll_num_topk,
|
||||
rank,
|
||||
num_ranks,
|
||||
group,
|
||||
buffer,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_processes = 8
|
||||
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
|
||||
149
test/srt/ep/test_deepep_large.py
Normal file
149
test/srt/ep/test_deepep_large.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestDeepseek(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"8",
|
||||
"--enable-dp-attention",
|
||||
"--dp",
|
||||
"8",
|
||||
"--moe-dense-tp-size",
|
||||
"1",
|
||||
"--enable-dp-lm-head",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--enable-two-batch-overlap",
|
||||
"--ep-num-redundant-experts",
|
||||
"32",
|
||||
"--ep-dispatch-algorithm",
|
||||
"dynamic",
|
||||
"--eplb-algorithm",
|
||||
"deepseek",
|
||||
"--cuda-graph-bs",
|
||||
"256",
|
||||
"--max-running-requests",
|
||||
"2048",
|
||||
"--disable-radix-cache",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1200,
|
||||
parallel=1200,
|
||||
max_new_tokens=512,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(f"Eval accuracy of GSM8K: {metrics=}")
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.92)
|
||||
|
||||
|
||||
class TestDeepseekMTP(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"8",
|
||||
"--enable-dp-attention",
|
||||
"--dp",
|
||||
"8",
|
||||
"--moe-dense-tp-size",
|
||||
"1",
|
||||
"--enable-dp-lm-head",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--enable-two-batch-overlap",
|
||||
"--ep-num-redundant-experts",
|
||||
"32",
|
||||
"--ep-dispatch-algorithm",
|
||||
"dynamic",
|
||||
"--eplb-algorithm",
|
||||
"deepseek",
|
||||
"--cuda-graph-bs",
|
||||
"64", # TODO: increase it to 128 when TBO is supported in draft_extend
|
||||
"--max-running-requests",
|
||||
"512",
|
||||
"--speculative-algorithm",
|
||||
"EAGLE",
|
||||
"--speculative-num-steps",
|
||||
"1",
|
||||
"--speculative-eagle-topk",
|
||||
"1",
|
||||
"--speculative-num-draft-tokens",
|
||||
"2",
|
||||
"--disable-radix-cache",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=1200,
|
||||
parallel=1200,
|
||||
max_new_tokens=512,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(f"Eval accuracy of GSM8K: {metrics=}")
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.92)
|
||||
|
||||
server_info = requests.get(self.base_url + "/get_server_info")
|
||||
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||
"avg_spec_accept_length"
|
||||
]
|
||||
print(
|
||||
f"###test_gsm8k:\n"
|
||||
f"accuracy={metrics['accuracy']=:.3f}\n"
|
||||
f"{avg_spec_accept_length=:.3f}\n"
|
||||
)
|
||||
self.assertGreater(avg_spec_accept_length, 1.85)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
325
test/srt/ep/test_deepep_low_latency.py
Normal file
325
test/srt/ep/test_deepep_low_latency.py
Normal file
@@ -0,0 +1,325 @@
|
||||
# Copy from deepseek-ai/DeepEP/tests/test_low_latency.py
|
||||
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import deep_ep
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.test.test_deepep_utils import (
|
||||
bench,
|
||||
bench_kineto,
|
||||
calc_diff,
|
||||
hash_tensor,
|
||||
init_dist,
|
||||
per_token_cast_back,
|
||||
)
|
||||
|
||||
|
||||
def test_main(
|
||||
num_tokens: int,
|
||||
hidden: int,
|
||||
num_experts: int,
|
||||
num_topk: int,
|
||||
rank: int,
|
||||
num_ranks: int,
|
||||
group: dist.ProcessGroup,
|
||||
buffer: deep_ep.Buffer,
|
||||
seed: int = 0,
|
||||
):
|
||||
torch.manual_seed(seed + rank)
|
||||
random.seed(seed + rank)
|
||||
|
||||
assert num_experts % num_ranks == 0
|
||||
num_local_experts = num_experts // num_ranks
|
||||
|
||||
# NOTES: the integers greater than 256 exceeds the BF16 precision limit
|
||||
rank_offset = 128
|
||||
assert (
|
||||
num_ranks - rank_offset < 257
|
||||
), "Too many ranks (exceeding test precision limit)"
|
||||
|
||||
x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * (
|
||||
rank - rank_offset
|
||||
)
|
||||
x[:, -128:] = torch.arange(num_tokens, device="cuda").to(torch.bfloat16).view(-1, 1)
|
||||
scores = (
|
||||
torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs()
|
||||
+ 1
|
||||
)
|
||||
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=True)[1]
|
||||
topk_weights = torch.randn(
|
||||
(num_tokens, num_topk), dtype=torch.float32, device="cuda"
|
||||
).abs()
|
||||
|
||||
# Randomly mask some positions
|
||||
for i in range(10):
|
||||
topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = (
|
||||
-1
|
||||
)
|
||||
|
||||
# Check dispatch correctness
|
||||
do_check = True
|
||||
hash_value, num_times = 0, 0
|
||||
for return_recv_hook in (False, True):
|
||||
for dispatch_use_fp8 in (False, True):
|
||||
num_times += 1
|
||||
for i in range((num_times % 2) + 1):
|
||||
packed_recv_x, packed_recv_count, handle, event, hook = (
|
||||
buffer.low_latency_dispatch(
|
||||
x,
|
||||
topk_idx,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
use_fp8=dispatch_use_fp8,
|
||||
async_finish=not return_recv_hook,
|
||||
return_recv_hook=return_recv_hook,
|
||||
)
|
||||
)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
packed_recv_x = (
|
||||
(packed_recv_x[0], packed_recv_x[1].contiguous())
|
||||
if dispatch_use_fp8
|
||||
else packed_recv_x
|
||||
)
|
||||
simulated_gemm_x = (
|
||||
per_token_cast_back(
|
||||
packed_recv_x[0].view(-1, hidden),
|
||||
packed_recv_x[1].view(-1, hidden // 128),
|
||||
).view(packed_recv_x[0].shape)
|
||||
if dispatch_use_fp8
|
||||
else packed_recv_x.clone()
|
||||
)
|
||||
all_topk_idx = torch.empty(
|
||||
(num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device="cuda"
|
||||
)
|
||||
dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group)
|
||||
for i in range(num_local_experts if do_check else 0):
|
||||
expert_id = rank * num_local_experts + i
|
||||
recv_x = (
|
||||
per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i])
|
||||
if dispatch_use_fp8
|
||||
else packed_recv_x[i]
|
||||
)
|
||||
recv_count, recv_src_info, recv_layout_range = (
|
||||
packed_recv_count[i],
|
||||
handle[0][i],
|
||||
handle[1][i],
|
||||
)
|
||||
|
||||
# Check expert indices
|
||||
int_mask = (2**32) - 1
|
||||
num_valid_tokens = recv_count.item()
|
||||
assert (
|
||||
num_valid_tokens == (recv_layout_range & int_mask).sum().item()
|
||||
), f"{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()"
|
||||
assert (
|
||||
num_valid_tokens == (all_topk_idx == expert_id).sum().item()
|
||||
), f"{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}"
|
||||
|
||||
# Check received data
|
||||
recv_x = recv_x[:num_valid_tokens]
|
||||
recv_x_amin = recv_x[:, :-128].amin(dim=-1)
|
||||
recv_src_info = recv_src_info[:num_valid_tokens]
|
||||
assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1))
|
||||
assert (
|
||||
recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens
|
||||
).sum().item() == 0
|
||||
for j in range(num_ranks):
|
||||
begin_idx, count = (recv_layout_range[j] >> 32).item(), (
|
||||
recv_layout_range[j] & int_mask
|
||||
).item()
|
||||
assert (recv_x_amin == j - rank_offset).sum().item() == (
|
||||
all_topk_idx[j] == expert_id
|
||||
).sum().item()
|
||||
assert (
|
||||
recv_x[begin_idx : begin_idx + count][:-128] - j
|
||||
).sum().item() == 0
|
||||
if dispatch_use_fp8:
|
||||
hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens])
|
||||
hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens])
|
||||
else:
|
||||
hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens])
|
||||
|
||||
# Check combine correctness
|
||||
for zero_copy in (False, True):
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[
|
||||
:, :, :
|
||||
] = simulated_gemm_x
|
||||
out = torch.empty(
|
||||
(num_tokens, hidden), dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
combined_x, event, hook = buffer.low_latency_combine(
|
||||
simulated_gemm_x,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
handle,
|
||||
async_finish=not return_recv_hook,
|
||||
zero_copy=zero_copy,
|
||||
return_recv_hook=return_recv_hook,
|
||||
out=out,
|
||||
)
|
||||
hook() if return_recv_hook else event.current_stream_wait()
|
||||
if do_check:
|
||||
diff = calc_diff(
|
||||
x
|
||||
* topk_weights.masked_fill(topk_idx == -1, 0)
|
||||
.sum(dim=1)
|
||||
.view(-1, 1),
|
||||
combined_x,
|
||||
)
|
||||
assert torch.isnan(combined_x).sum().item() == 0
|
||||
assert diff < 1e-5, f"Error: {diff=}, {zero_copy=}"
|
||||
hash_value ^= hash_tensor(combined_x)
|
||||
|
||||
def create_test_cast_with_outliers(num_outliers):
|
||||
tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
|
||||
tmp /= tmp.abs().amax(dim=1).view(-1, 1)
|
||||
assert tmp.abs().amax().item() <= 1
|
||||
|
||||
# Create some amax outliers
|
||||
for i in range(num_outliers):
|
||||
tmp[random.randint(0, num_tokens - 1)] *= 1e3
|
||||
return tmp
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def large_gemm_with_hook(hook):
|
||||
mat_0 = torch.randn((8192, 8192), dtype=torch.float)
|
||||
mat_1 = torch.randn((8192, 8192), dtype=torch.float)
|
||||
mat_0 @ mat_1
|
||||
hook()
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func(zero_copy: bool, return_recv_hook: bool):
|
||||
recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
|
||||
x,
|
||||
topk_idx,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
async_finish=False,
|
||||
return_recv_hook=return_recv_hook,
|
||||
)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
if zero_copy:
|
||||
buffer.get_next_low_latency_combine_buffer(handle)[
|
||||
:, :, :
|
||||
] = simulated_gemm_x
|
||||
combined_x, event, hook = buffer.low_latency_combine(
|
||||
simulated_gemm_x,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
handle,
|
||||
zero_copy=zero_copy,
|
||||
return_recv_hook=return_recv_hook,
|
||||
)
|
||||
large_gemm_with_hook(hook) if return_recv_hook else None
|
||||
|
||||
# Calculate bandwidth
|
||||
num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
|
||||
num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0
|
||||
for i in range(num_tokens):
|
||||
num_selections = (topk_idx[i] != -1).sum().item()
|
||||
num_dispatch_comm_bytes += num_fp8_bytes * num_selections
|
||||
num_combine_comm_bytes += num_bf16_bytes * num_selections
|
||||
|
||||
# Dispatch + combine testing
|
||||
avg_t, min_t, max_t = bench(
|
||||
partial(test_func, zero_copy=False, return_recv_hook=False)
|
||||
)
|
||||
print(
|
||||
f"[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, "
|
||||
f"avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Separate profiling
|
||||
for return_recv_hook in (False, True):
|
||||
group.barrier()
|
||||
dispatch_t, combine_t = bench_kineto(
|
||||
partial(test_func, zero_copy=True, return_recv_hook=return_recv_hook),
|
||||
kernel_names=("dispatch", "combine"),
|
||||
barrier_comm_profiling=True,
|
||||
suppress_kineto_output=True,
|
||||
)
|
||||
if not return_recv_hook:
|
||||
print(
|
||||
f"[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | "
|
||||
f"Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | "
|
||||
f"Combine send/recv time: {combine_t * 2 * 1e6:.2f} us",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
return hash_value
|
||||
|
||||
|
||||
# noinspection PyUnboundLocalVariable
|
||||
def test_loop(local_rank: int, num_local_ranks: int):
|
||||
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
|
||||
num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288
|
||||
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_tokens, hidden, num_ranks, num_experts
|
||||
)
|
||||
if local_rank == 0:
|
||||
print(f"Allocating buffer size: {num_rdma_bytes / 1e6} MB ...", flush=True)
|
||||
buffer = deep_ep.Buffer(
|
||||
group,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_experts // num_ranks,
|
||||
)
|
||||
test_main(
|
||||
num_tokens,
|
||||
hidden,
|
||||
num_experts,
|
||||
num_topk,
|
||||
rank,
|
||||
num_ranks,
|
||||
group,
|
||||
buffer,
|
||||
seed=1,
|
||||
)
|
||||
|
||||
do_pressure_test = False
|
||||
for seed in range(int(1e9) if do_pressure_test else 0):
|
||||
if local_rank == 0:
|
||||
print(f"Testing with seed {seed} ...", flush=True)
|
||||
ref_hash = test_main(
|
||||
num_tokens,
|
||||
hidden,
|
||||
num_experts,
|
||||
num_topk,
|
||||
rank,
|
||||
num_ranks,
|
||||
group,
|
||||
buffer,
|
||||
seed=seed,
|
||||
)
|
||||
for i in range(20):
|
||||
assert (
|
||||
test_main(
|
||||
num_tokens,
|
||||
hidden,
|
||||
num_experts,
|
||||
num_topk,
|
||||
rank,
|
||||
num_ranks,
|
||||
group,
|
||||
buffer,
|
||||
seed=seed,
|
||||
)
|
||||
== ref_hash
|
||||
), f"Error: seed={seed}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO: you may modify NUMA binding for less CPU overhead
|
||||
num_processes = 8
|
||||
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes)
|
||||
389
test/srt/ep/test_deepep_small.py
Normal file
389
test/srt/ep/test_deepep_small.py
Normal file
@@ -0,0 +1,389 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestPureDP(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"4",
|
||||
"--enable-dp-attention",
|
||||
"--dp",
|
||||
"4",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--cuda-graph-max-bs",
|
||||
"128",
|
||||
"--max-running-requests",
|
||||
"512",
|
||||
"--mem-fraction-static",
|
||||
"0.5",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.60)
|
||||
|
||||
|
||||
class TestHybridDPTP(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"4",
|
||||
"--enable-dp-attention",
|
||||
"--dp",
|
||||
"2",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--cuda-graph-max-bs",
|
||||
"128",
|
||||
"--max-running-requests",
|
||||
"256",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.60)
|
||||
|
||||
|
||||
class TestTP(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"4",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--cuda-graph-max-bs",
|
||||
"128",
|
||||
"--max-running-requests",
|
||||
"128",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.60)
|
||||
|
||||
|
||||
@unittest.skip("covered in test_deepep_large.py")
|
||||
class TestNoGatherdBuffer(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"4",
|
||||
"--enable-dp-attention",
|
||||
"--dp",
|
||||
"4",
|
||||
"--moe-dense-tp-size",
|
||||
"1",
|
||||
"--enable-dp-lm-head",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--cuda-graph-max-bs",
|
||||
"32",
|
||||
"--max-running-requests",
|
||||
"512",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.60)
|
||||
|
||||
|
||||
class TestTBO(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"4",
|
||||
"--enable-dp-attention",
|
||||
"--dp",
|
||||
"4",
|
||||
"--moe-dense-tp-size",
|
||||
"1",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--enable-two-batch-overlap",
|
||||
"--cuda-graph-max-bs",
|
||||
"128",
|
||||
"--max-running-requests",
|
||||
"512",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.60)
|
||||
|
||||
|
||||
@unittest.skip("covered in TestMTPWithTBO")
|
||||
class TestMTP(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"4",
|
||||
"--enable-dp-attention",
|
||||
"--dp",
|
||||
"2",
|
||||
"--enable-dp-lm-head",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--speculative-algo",
|
||||
"EAGLE",
|
||||
"--speculative-draft",
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
|
||||
"--speculative-num-steps",
|
||||
"2",
|
||||
"--speculative-eagle-topk",
|
||||
"3",
|
||||
"--speculative-num-draft-tokens",
|
||||
"3",
|
||||
"--cuda-graph-max-bs",
|
||||
"32",
|
||||
"--max-running-requests",
|
||||
"64",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.60)
|
||||
|
||||
server_info = requests.get(self.base_url + "/get_server_info")
|
||||
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||
"avg_spec_accept_length"
|
||||
]
|
||||
print(
|
||||
f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n"
|
||||
f"accuracy={metrics['accuracy']=:.3f}\n"
|
||||
f"{avg_spec_accept_length=:.3f}\n"
|
||||
)
|
||||
self.assertGreater(avg_spec_accept_length, 2.1)
|
||||
|
||||
|
||||
class TestMTPWithTBO(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
import os
|
||||
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--tp-size",
|
||||
"4",
|
||||
"--enable-dp-attention",
|
||||
"--dp-size",
|
||||
"4",
|
||||
"--enable-two-batch-overlap",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--trust-remote-code",
|
||||
"--speculative-algorithm",
|
||||
"EAGLE",
|
||||
"--speculative-num-steps",
|
||||
"2",
|
||||
"--speculative-eagle-topk",
|
||||
"3",
|
||||
"--speculative-num-draft-tokens",
|
||||
"3",
|
||||
"--speculative-draft",
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
|
||||
"--chunked-prefill-size",
|
||||
"256",
|
||||
"--cuda-graph-max-bs",
|
||||
"32",
|
||||
"--max-running-requests",
|
||||
"128",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.60)
|
||||
|
||||
server_info = requests.get(self.base_url + "/get_server_info")
|
||||
avg_spec_accept_length = server_info.json()["internal_states"][0][
|
||||
"avg_spec_accept_length"
|
||||
]
|
||||
print(
|
||||
f"###test_gsm8k (deepseek-v3 mtp + dp + tbo):\n"
|
||||
f"accuracy={metrics['accuracy']=:.3f}\n"
|
||||
f"{avg_spec_accept_length=:.3f}\n"
|
||||
)
|
||||
self.assertGreater(avg_spec_accept_length, 2.1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
155
test/srt/ep/test_eplb.py
Executable file
155
test/srt/ep/test_eplb.py
Executable file
@@ -0,0 +1,155 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class _BaseTestDynamicEPLB(CustomTestCase):
|
||||
extra_args = []
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"2",
|
||||
"--dp",
|
||||
"2",
|
||||
"--enable-dp-attention",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--deepep-mode",
|
||||
"normal",
|
||||
"--disable-cuda-graph",
|
||||
"--enable-eplb",
|
||||
"--ep-num-redundant-experts",
|
||||
"4",
|
||||
"--eplb-rebalance-num-iterations",
|
||||
"50",
|
||||
"--expert-distribution-recorder-buffer-size",
|
||||
"50",
|
||||
# TODO pr-chain: enable later
|
||||
# "--enable-expert-distribution-metrics",
|
||||
# TODO auto determine these flags
|
||||
"--expert-distribution-recorder-mode",
|
||||
"stat",
|
||||
"--ep-dispatch-algorithm",
|
||||
"static",
|
||||
*cls.extra_args,
|
||||
],
|
||||
env={
|
||||
"SGL_ENABLE_JIT_DEEPGEMM": "0",
|
||||
"SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1",
|
||||
**os.environ,
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreater(metrics["score"], 0.5)
|
||||
|
||||
|
||||
class TestDynamicEPLBSimple(_BaseTestDynamicEPLB):
|
||||
pass
|
||||
|
||||
|
||||
class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB):
|
||||
extra_args = ["--eplb-rebalance-layers-per-chunk", "1"]
|
||||
|
||||
|
||||
class TestStaticEPLB(CustomTestCase):
|
||||
def test_save_expert_distribution_and_init_expert_location(self):
|
||||
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
engine_kwargs = dict(
|
||||
model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||
trust_remote_code=True,
|
||||
ep_num_redundant_experts=4,
|
||||
enable_dp_attention=True,
|
||||
moe_a2a_backend="deepep",
|
||||
disable_cuda_graph=True,
|
||||
expert_distribution_recorder_mode="stat",
|
||||
tp_size=2,
|
||||
dp_size=2,
|
||||
log_level="info",
|
||||
# TODO pr-chain: enable later
|
||||
# enable_expert_distribution_metrics=True,
|
||||
)
|
||||
|
||||
print(f"Action: start engine")
|
||||
os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
|
||||
engine = sgl.Engine(
|
||||
**engine_kwargs,
|
||||
disable_overlap_schedule=True,
|
||||
)
|
||||
engine.start_expert_distribution_record()
|
||||
self._assert_engine_generate_correct(engine)
|
||||
|
||||
print(f"Action: dump_expert_distribution_record")
|
||||
engine.dump_expert_distribution_record()
|
||||
snapshot_path = list(Path(tmp_dir).glob("*.pt"))[0]
|
||||
assert snapshot_path is not None
|
||||
print(f"{snapshot_path=}")
|
||||
|
||||
print(f"Action: shutdown engine")
|
||||
engine.shutdown()
|
||||
del engine
|
||||
|
||||
print(f"Action: start engine with init_expert_location")
|
||||
engine = sgl.Engine(
|
||||
**engine_kwargs,
|
||||
init_expert_location=str(snapshot_path),
|
||||
port=21000,
|
||||
# TODO auto determine these flags
|
||||
ep_dispatch_algorithm="static",
|
||||
)
|
||||
self._assert_engine_generate_correct(engine)
|
||||
print(f"Action: shutdown engine")
|
||||
engine.shutdown()
|
||||
del engine
|
||||
|
||||
def _assert_engine_generate_correct(self, engine: sgl.Engine):
|
||||
output = engine.generate(
|
||||
prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"],
|
||||
sampling_params=dict(max_new_tokens=8, temperature=0.0),
|
||||
)
|
||||
print(f"engine.generate {output=}")
|
||||
self.assertEqual(
|
||||
[x["text"] for x in output],
|
||||
[", 4+4=8,", ", four plus four is eight, eight"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
2728
test/srt/ep/test_hybrid_dp_ep_tp_mtp.py
Normal file
2728
test/srt/ep/test_hybrid_dp_ep_tp_mtp.py
Normal file
File diff suppressed because it is too large
Load Diff
119
test/srt/ep/test_moe_deepep.py
Normal file
119
test/srt/ep/test_moe_deepep.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestPureTP(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"2",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--disable-cuda-graph",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreater(metrics["score"], 0.5)
|
||||
|
||||
|
||||
class TestDPAttn(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"2",
|
||||
"--dp",
|
||||
"2",
|
||||
"--enable-dp-attention",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--deepep-mode",
|
||||
"normal",
|
||||
"--disable-cuda-graph",
|
||||
# Test custom config
|
||||
"--deepep-config",
|
||||
json.dumps(
|
||||
{
|
||||
"normal_dispatch": {
|
||||
"num_sms": 20,
|
||||
"num_max_nvl_chunked_send_tokens": 16,
|
||||
"num_max_nvl_chunked_recv_tokens": 256,
|
||||
"num_max_rdma_chunked_send_tokens": 6,
|
||||
"num_max_rdma_chunked_recv_tokens": 128,
|
||||
},
|
||||
"normal_combine": {
|
||||
"num_sms": 20,
|
||||
"num_max_nvl_chunked_send_tokens": 6,
|
||||
"num_max_nvl_chunked_recv_tokens": 256,
|
||||
"num_max_rdma_chunked_send_tokens": 6,
|
||||
"num_max_rdma_chunked_recv_tokens": 128,
|
||||
},
|
||||
}
|
||||
),
|
||||
],
|
||||
env={
|
||||
"SGL_ENABLE_JIT_DEEPGEMM": "0",
|
||||
**os.environ,
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreater(metrics["score"], 0.5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
75
test/srt/ep/test_moe_deepep_eval_accuracy_large.py
Normal file
75
test/srt/ep/test_moe_deepep_eval_accuracy_large.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Usage:
|
||||
python -m unittest test_moe_deepep_eval_accuracy_large.TestMoEDeepEPEvalAccuracyLarge.test_mmlu
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestMoEDeepEPEvalAccuracyLarge(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_DEEPPEP_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"8",
|
||||
"--moe-a2a-backend",
|
||||
"deepep",
|
||||
"--cuda-graph-max-bs",
|
||||
"128",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=8,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
parallel=64,
|
||||
max_new_tokens=512,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(f"Eval accuracy of GSM8K: {metrics=}")
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.93)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
print(f"Eval accuracy of MMLU: {metrics=}")
|
||||
self.assertGreater(metrics["score"], 0.87)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
112
test/srt/ep/test_moe_ep.py
Normal file
112
test/srt/ep/test_moe_ep.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestEpMoE(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"2",
|
||||
"--ep-size",
|
||||
"2",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.5)
|
||||
|
||||
def test_mgsm_en(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mgsm_en",
|
||||
num_examples=None,
|
||||
num_threads=1024,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.8)
|
||||
|
||||
|
||||
class TestEpMoEFP8(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--tp",
|
||||
"2",
|
||||
"--ep-size",
|
||||
"2",
|
||||
"--quantization",
|
||||
"fp8",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.5)
|
||||
|
||||
def test_mgsm_en(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mgsm_en",
|
||||
num_examples=None,
|
||||
num_threads=1024,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
365
test/srt/experiment_runner.py
Normal file
365
test/srt/experiment_runner.py
Normal file
@@ -0,0 +1,365 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import psutil
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerConfig:
|
||||
command: str
|
||||
process_names: List[str]
|
||||
default_port: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskConfig:
|
||||
server_cmd: str
|
||||
client_cmd: str
|
||||
name: Optional[str] = None
|
||||
server_type: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
name: str
|
||||
success: bool
|
||||
output: str
|
||||
runtime: float
|
||||
timestamp: str
|
||||
|
||||
|
||||
SERVER_DEFAULTS = {
|
||||
"sglang": ServerConfig(
|
||||
command="sglang.launch_server",
|
||||
process_names=["sglang.launch_server"],
|
||||
default_port=30000,
|
||||
),
|
||||
"vllm": ServerConfig(
|
||||
command="vllm.entrypoints.openai.api_server",
|
||||
process_names=["vllm.entrypoints.openai.api_server"],
|
||||
default_port=8000,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def parse_key_info(output: str) -> str:
|
||||
"""Extract and format key information from the output"""
|
||||
key_info = []
|
||||
|
||||
# Extract Args namespace
|
||||
args_match = re.search(r"Namespace\(.*?\)", output, re.DOTALL)
|
||||
if args_match:
|
||||
key_info.append(args_match.group(0))
|
||||
|
||||
# Extract input/output token counts
|
||||
token_matches = re.findall(r"#(Input|Output) tokens: \d+", output)
|
||||
key_info.extend(token_matches)
|
||||
|
||||
# Extract benchmark result section
|
||||
result_match = re.search(
|
||||
r"============ Serving Benchmark Result ============.*?={50,}",
|
||||
output,
|
||||
re.DOTALL,
|
||||
)
|
||||
if result_match:
|
||||
key_info.append(result_match.group(0))
|
||||
|
||||
return "\n\n".join(key_info)
|
||||
|
||||
|
||||
def extract_port_from_command(cmd: str, server_type: str) -> int:
|
||||
port_match = re.search(r"--port[= ](\d+)", cmd)
|
||||
if port_match:
|
||||
return int(port_match.group(1))
|
||||
return SERVER_DEFAULTS.get(server_type, ServerConfig("", [], 8000)).default_port
|
||||
|
||||
|
||||
def detect_server_type(cmd: str) -> str:
|
||||
for server_type, config in SERVER_DEFAULTS.items():
|
||||
if config.command in cmd:
|
||||
return server_type
|
||||
return "unknown"
|
||||
|
||||
|
||||
def stream_output(
|
||||
process: subprocess.Popen, prefix: str, logger: logging.Logger
|
||||
) -> queue.Queue:
|
||||
output_queue = queue.Queue()
|
||||
|
||||
def stream_pipe(pipe, prefix):
|
||||
for line in iter(pipe.readline, ""):
|
||||
if prefix == "CLIENT":
|
||||
output_queue.put(line.rstrip())
|
||||
logger.debug(f"{prefix} | {line.rstrip()}")
|
||||
|
||||
stdout_thread = threading.Thread(
|
||||
target=stream_pipe, args=(process.stdout, prefix), daemon=True
|
||||
)
|
||||
stderr_thread = threading.Thread(
|
||||
target=stream_pipe, args=(process.stderr, prefix), daemon=True
|
||||
)
|
||||
|
||||
stdout_thread.start()
|
||||
stderr_thread.start()
|
||||
return output_queue, (stdout_thread, stderr_thread)
|
||||
|
||||
|
||||
class ProcessManager:
|
||||
def __init__(self):
|
||||
self.server_process: Optional[subprocess.Popen] = None
|
||||
self.client_process: Optional[subprocess.Popen] = None
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def start_process(
|
||||
self, command: str, prefix: str
|
||||
) -> Tuple[subprocess.Popen, queue.Queue]:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
)
|
||||
|
||||
output_queue, threads = stream_output(process, prefix, self.logger)
|
||||
return process, output_queue, threads
|
||||
|
||||
def kill_process_tree(self, process: subprocess.Popen):
|
||||
try:
|
||||
parent = psutil.Process(process.pid)
|
||||
children = parent.children(recursive=True)
|
||||
|
||||
for child in children:
|
||||
try:
|
||||
child.kill()
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
parent.kill()
|
||||
gone, alive = psutil.wait_procs(children + [parent], timeout=3)
|
||||
|
||||
for p in alive:
|
||||
try:
|
||||
p.kill()
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
def cleanup(self, process_names: List[str]):
|
||||
if self.client_process:
|
||||
self.kill_process_tree(self.client_process)
|
||||
self.client_process = None
|
||||
|
||||
if self.server_process:
|
||||
self.kill_process_tree(self.server_process)
|
||||
self.server_process = None
|
||||
|
||||
for proc in psutil.process_iter(["pid", "name", "cmdline"]):
|
||||
try:
|
||||
cmdline = " ".join(proc.cmdline())
|
||||
if any(name in cmdline for name in process_names):
|
||||
proc.kill()
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
continue
|
||||
|
||||
|
||||
class ExperimentRunner:
|
||||
def __init__(self):
|
||||
self.process_manager = ProcessManager()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def wait_for_server(self, port: int, timeout: int = 300) -> bool:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
while time.perf_counter() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(f"http://localhost:{port}/health")
|
||||
if response.status_code == 200:
|
||||
self.logger.debug(f"Server ready on port {port}")
|
||||
return True
|
||||
except requests.RequestException:
|
||||
time.sleep(2)
|
||||
return False
|
||||
|
||||
def run_task(self, config: TaskConfig) -> TaskResult:
|
||||
start_time = time.perf_counter()
|
||||
client_output = []
|
||||
|
||||
try:
|
||||
if not config.server_type:
|
||||
config.server_type = detect_server_type(config.server_cmd)
|
||||
|
||||
server_config = SERVER_DEFAULTS.get(config.server_type)
|
||||
if not server_config:
|
||||
raise ValueError(f"Unknown server type: {config.server_type}")
|
||||
|
||||
port = extract_port_from_command(config.server_cmd, config.server_type)
|
||||
|
||||
self.process_manager.cleanup(server_config.process_names)
|
||||
|
||||
self.logger.debug(f"Starting server: {config.name}")
|
||||
self.process_manager.server_process, _, server_threads = (
|
||||
self.process_manager.start_process(config.server_cmd, "SERVER")
|
||||
)
|
||||
|
||||
if not self.wait_for_server(port):
|
||||
raise TimeoutError("Server startup timeout")
|
||||
|
||||
time.sleep(10)
|
||||
|
||||
self.logger.debug("Starting client")
|
||||
self.process_manager.client_process, output_queue, client_threads = (
|
||||
self.process_manager.start_process(config.client_cmd, "CLIENT")
|
||||
)
|
||||
|
||||
returncode = self.process_manager.client_process.wait()
|
||||
|
||||
while True:
|
||||
try:
|
||||
line = output_queue.get_nowait()
|
||||
client_output.append(line)
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
if returncode != 0:
|
||||
raise RuntimeError(f"Client failed with code {returncode}")
|
||||
|
||||
# Parse and format the output
|
||||
full_output = "\n".join(client_output)
|
||||
formatted_output = parse_key_info(full_output)
|
||||
|
||||
return TaskResult(
|
||||
name=config.name,
|
||||
success=True,
|
||||
output=formatted_output,
|
||||
runtime=time.perf_counter() - start_time,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return TaskResult(
|
||||
name=config.name,
|
||||
success=False,
|
||||
output=str(e),
|
||||
runtime=time.perf_counter() - start_time,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
|
||||
finally:
|
||||
if config.server_type in SERVER_DEFAULTS:
|
||||
self.process_manager.cleanup(
|
||||
SERVER_DEFAULTS[config.server_type].process_names
|
||||
)
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
def load_config(config_path: str) -> List[TaskConfig]:
|
||||
with open(config_path, "r") as f:
|
||||
config_data = yaml.safe_load(f)
|
||||
|
||||
configs = []
|
||||
for idx, entry in enumerate(config_data.get("tasks", [])):
|
||||
if not isinstance(entry, dict):
|
||||
raise ValueError(f"Invalid entry at index {idx}")
|
||||
|
||||
config = TaskConfig(
|
||||
server_cmd=entry.get("server_cmd"),
|
||||
client_cmd=entry.get("client_cmd"),
|
||||
name=entry.get("name", f"task-{idx+1}"),
|
||||
server_type=entry.get("server_type"),
|
||||
)
|
||||
|
||||
if not config.server_cmd or not config.client_cmd:
|
||||
raise ValueError(f"Missing commands in {config.name}")
|
||||
|
||||
configs.append(config)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def setup_logging(debug: bool = False):
|
||||
level = logging.DEBUG if debug else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler(), logging.FileHandler("experiment.log")],
|
||||
)
|
||||
|
||||
|
||||
def format_results(results: List[TaskResult]) -> str:
|
||||
"""Format experiment results in Markdown for GitHub step summary."""
|
||||
output = ["# Experiment Results\n"]
|
||||
|
||||
for result in results:
|
||||
output.append(f"## {result.name}")
|
||||
output.append(f"**Status**: {'✅ Success' if result.success else '❌ Failed'}")
|
||||
output.append(f"**Runtime**: {result.runtime:.2f} seconds")
|
||||
output.append(f"**Timestamp**: {result.timestamp}")
|
||||
output.append("\n**Output**:\n```")
|
||||
output.append(result.output)
|
||||
output.append("```\n")
|
||||
|
||||
return "\n".join(output)
|
||||
|
||||
|
||||
def get_bool_env_var(name: str, default: str = "false") -> bool:
|
||||
value = os.getenv(name, default)
|
||||
return value.lower() in ("true", "1")
|
||||
|
||||
|
||||
def write_in_github_step_summary(results: List[TaskResult]):
|
||||
"""Write formatted results to GitHub step summary."""
|
||||
if not os.environ.get("GITHUB_STEP_SUMMARY"):
|
||||
logging.warning("GITHUB_STEP_SUMMARY environment variable not set")
|
||||
return
|
||||
|
||||
formatted_content = format_results(results)
|
||||
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
|
||||
f.write(formatted_content)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Experiment Runner")
|
||||
parser.add_argument(
|
||||
"--config", type=str, required=True, help="Path to YAML config file"
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="Enable debug output")
|
||||
args = parser.parse_args()
|
||||
|
||||
setup_logging(args.debug)
|
||||
logger = logging.getLogger(__name__)
|
||||
results = []
|
||||
|
||||
try:
|
||||
configs = load_config(args.config)
|
||||
runner = ExperimentRunner()
|
||||
|
||||
for config in configs:
|
||||
logger.info(f"Running {config.name}")
|
||||
result = runner.run_task(config)
|
||||
results.append(result)
|
||||
|
||||
if get_bool_env_var("SGLANG_IS_IN_CI"):
|
||||
write_in_github_step_summary(results)
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
53
test/srt/hicache/test_hicache.py
Normal file
53
test/srt/hicache/test_hicache.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import is_hip, kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
class TestHiCache(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--enable-hierarchical-cache",
|
||||
"--mem-fraction-static",
|
||||
0.7,
|
||||
"--hicache-size",
|
||||
100 if not _is_hip else 200,
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
67
test/srt/hicache/test_hicache_mla.py
Normal file
67
test/srt/hicache/test_hicache_mla.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import is_hip, kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
if _is_hip:
|
||||
hicache_args = ["--hicache-size", 200]
|
||||
else:
|
||||
hicache_args = ["--hicache-ratio", 2]
|
||||
|
||||
|
||||
class TestHierarchicalMLA(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--enable-hierarchical-cache",
|
||||
]
|
||||
+ hicache_args,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreater(metrics["score"], 0.5)
|
||||
|
||||
def test_mgsm_en(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mgsm_en",
|
||||
num_examples=None,
|
||||
num_threads=1024,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreater(metrics["score"], 0.8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
51
test/srt/hicache/test_hicache_page.py
Normal file
51
test/srt/hicache/test_hicache_page.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestHiCachePage(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--enable-hierarchical-cache",
|
||||
"--page-size",
|
||||
32,
|
||||
"--hicache-write-policy",
|
||||
"write_back",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
57
test/srt/hicache/test_hicache_storage.py
Normal file
57
test/srt/hicache/test_hicache_storage.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import is_hip, kill_process_tree
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
class TestHiCache(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--enable-hierarchical-cache",
|
||||
"--mem-fraction-static",
|
||||
0.7,
|
||||
"--hicache-size",
|
||||
100 if not _is_hip else 200,
|
||||
"--page-size",
|
||||
"64",
|
||||
"--hicache-storage-backend",
|
||||
"file",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
42
test/srt/kv_cache_scales_llama3_1_8b.json
Normal file
42
test/srt/kv_cache_scales_llama3_1_8b.json
Normal file
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"model_type": "llama",
|
||||
"kv_cache": {
|
||||
"dtype": "float8_e4m3fn",
|
||||
"scaling_factor": {
|
||||
"0": {
|
||||
"0": 1,
|
||||
"1": 1,
|
||||
"2": 1,
|
||||
"3": 1,
|
||||
"4": 1,
|
||||
"5": 1,
|
||||
"6": 1,
|
||||
"7": 1,
|
||||
"8": 1,
|
||||
"9": 1,
|
||||
"10": 1,
|
||||
"11": 1,
|
||||
"12": 1,
|
||||
"13": 1,
|
||||
"14": 1,
|
||||
"15": 1,
|
||||
"16": 1,
|
||||
"17": 1,
|
||||
"18": 1,
|
||||
"19": 1,
|
||||
"20": 1,
|
||||
"21": 1,
|
||||
"22": 1,
|
||||
"23": 1,
|
||||
"24": 1,
|
||||
"25": 1,
|
||||
"26": 1,
|
||||
"27": 1,
|
||||
"28": 1,
|
||||
"29": 1,
|
||||
"30": 1,
|
||||
"31": 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
42
test/srt/kv_cache_scales_llama3_8b.json
Normal file
42
test/srt/kv_cache_scales_llama3_8b.json
Normal file
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"model_type": "llama",
|
||||
"kv_cache": {
|
||||
"dtype": "float8_e4m3fn",
|
||||
"scaling_factor": {
|
||||
"0": {
|
||||
"0": 0.0408,
|
||||
"1": 0.0503,
|
||||
"2": 0.0667,
|
||||
"3": 0.0909,
|
||||
"4": 0.1135,
|
||||
"5": 0.127,
|
||||
"6": 0.1768,
|
||||
"7": 0.1488,
|
||||
"8": 0.1135,
|
||||
"9": 0.1203,
|
||||
"10": 0.1013,
|
||||
"11": 0.0842,
|
||||
"12": 0.1231,
|
||||
"13": 0.1096,
|
||||
"14": 0.1221,
|
||||
"15": 0.1013,
|
||||
"16": 0.1067,
|
||||
"17": 0.0952,
|
||||
"18": 0.0899,
|
||||
"19": 0.097,
|
||||
"20": 0.087,
|
||||
"21": 0.0994,
|
||||
"22": 0.0904,
|
||||
"23": 0.1013,
|
||||
"24": 0.1019,
|
||||
"25": 0.1053,
|
||||
"26": 0.1,
|
||||
"27": 0.0894,
|
||||
"28": 0.1013,
|
||||
"29": 0.1488,
|
||||
"30": 0.0766,
|
||||
"31": 0.0821
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
38
test/srt/kv_cache_scales_qwen2_1_5b.json
Normal file
38
test/srt/kv_cache_scales_qwen2_1_5b.json
Normal file
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"model_type": "qwen",
|
||||
"kv_cache": {
|
||||
"dtype": "float8_e4m3fn",
|
||||
"scaling_factor": {
|
||||
"0": {
|
||||
"0": 0.9846,
|
||||
"1": 0.0645,
|
||||
"2": 0.0731,
|
||||
"3": 0.0800,
|
||||
"4": 0.0748,
|
||||
"5": 0.0780,
|
||||
"6": 0.0702,
|
||||
"7": 0.0894,
|
||||
"8": 0.0410,
|
||||
"9": 0.0758,
|
||||
"10": 0.0556,
|
||||
"11": 0.0731,
|
||||
"12": 0.0899,
|
||||
"13": 0.0780,
|
||||
"14": 0.1441,
|
||||
"15": 0.0914,
|
||||
"16": 0.5614,
|
||||
"17": 0.1067,
|
||||
"18": 0.0537,
|
||||
"19": 0.0658,
|
||||
"20": 0.0523,
|
||||
"21": 0.0533,
|
||||
"22": 0.0699,
|
||||
"23": 0.0635,
|
||||
"24": 0.0588,
|
||||
"25": 0.0884,
|
||||
"26": 0.0947,
|
||||
"27": 0.1032
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
178
test/srt/lora/test_lora.py
Normal file
178
test/srt/lora/test_lora.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import random
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from utils import (
|
||||
ALL_OTHER_MULTI_LORA_MODELS,
|
||||
CI_MULTI_LORA_MODELS,
|
||||
TORCH_DTYPES,
|
||||
LoRAModelCase,
|
||||
)
|
||||
|
||||
from sglang.test.runners import HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
|
||||
|
||||
TEST_MULTIPLE_BATCH_PROMPTS = [
|
||||
"""
|
||||
### Instruction:
|
||||
Tell me about llamas and alpacas
|
||||
### Response:
|
||||
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
|
||||
### Question 2:
|
||||
What do you know about llamas?
|
||||
### Answer:
|
||||
""",
|
||||
"""
|
||||
### Instruction:
|
||||
Write a poem about the transformers Python library.
|
||||
Mention the word "large language models" in that poem.
|
||||
### Response:
|
||||
The Transformers are large language models,
|
||||
They're used to make predictions on text.
|
||||
""",
|
||||
"AI is a field of computer science focused on",
|
||||
"Computer science is the study of",
|
||||
"Write a short story.",
|
||||
"What are the main components of a computer?",
|
||||
]
|
||||
|
||||
|
||||
class TestLoRA(CustomTestCase):
|
||||
def _create_test_samples(
|
||||
self, lora_adapter_paths: List[str], repeated_trials: int = 3
|
||||
):
|
||||
random.seed(42) # Ensure reproducibility
|
||||
|
||||
patterns = [
|
||||
[None, lora_adapter_paths[0], lora_adapter_paths[1]],
|
||||
[lora_adapter_paths[0], None, lora_adapter_paths[1]],
|
||||
[lora_adapter_paths[0], lora_adapter_paths[1], None],
|
||||
[None, lora_adapter_paths[1], None],
|
||||
[None, None, None],
|
||||
]
|
||||
|
||||
batches = [
|
||||
[random.choice(pattern) for _ in range(3)]
|
||||
for pattern in patterns
|
||||
for _ in range(repeated_trials)
|
||||
]
|
||||
|
||||
return batches
|
||||
|
||||
def ensure_reproducibility(self):
|
||||
seed = 42
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
||||
for model_case in model_cases:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
max_new_tokens = 32
|
||||
backend = "triton"
|
||||
base_path = model_case.base
|
||||
lora_adapter_paths = [a.name for a in model_case.adaptors]
|
||||
assert len(lora_adapter_paths) >= 2
|
||||
|
||||
print(
|
||||
f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
|
||||
)
|
||||
|
||||
# Initialize runners
|
||||
srt_runner = SRTRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
|
||||
max_loras_per_batch=len(lora_adapter_paths) + 1,
|
||||
lora_backend=backend,
|
||||
sleep_on_idle=True, # Eliminate non-determinism by forcing all requests to be processed in one batch.
|
||||
attention_backend="torch_native",
|
||||
)
|
||||
hf_runner = HFRunner(
|
||||
base_path, torch_dtype=torch_dtype, model_type="generation"
|
||||
)
|
||||
|
||||
batches = self._create_test_samples(lora_adapter_paths)
|
||||
with srt_runner, hf_runner:
|
||||
for i, lora_paths in enumerate(batches, start=1):
|
||||
prompts = [
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS) for _ in range(3)
|
||||
]
|
||||
print(
|
||||
f"\n--- Running Batch {i} --- prompts: {prompts}, lora_paths: {lora_paths}"
|
||||
)
|
||||
|
||||
self.ensure_reproducibility()
|
||||
srt_outputs = srt_runner.batch_forward(
|
||||
prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
)
|
||||
|
||||
self.ensure_reproducibility()
|
||||
hf_outputs = hf_runner.forward(
|
||||
prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
)
|
||||
|
||||
print("SRT outputs:", [s for s in srt_outputs.output_strs])
|
||||
print("HF outputs:", [s for s in hf_outputs.output_strs])
|
||||
|
||||
for srt_out, hf_out in zip(
|
||||
srt_outputs.output_strs, hf_outputs.output_strs
|
||||
):
|
||||
srt_str = srt_out.strip()
|
||||
hf_str = hf_out.strip()
|
||||
rouge_tol = model_case.rouge_l_tolerance
|
||||
rouge_score = calculate_rouge_l([srt_str], [hf_str])[0]
|
||||
if rouge_score < rouge_tol:
|
||||
raise AssertionError(
|
||||
f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
|
||||
f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
|
||||
)
|
||||
|
||||
print(f"--- Batch {i} Comparison Passed --- ")
|
||||
|
||||
def test_ci_lora_models(self):
|
||||
self._run_lora_multiple_batch_on_model_cases(CI_MULTI_LORA_MODELS)
|
||||
|
||||
def test_all_lora_models(self):
|
||||
if is_in_ci():
|
||||
return
|
||||
|
||||
filtered_models = []
|
||||
for model_case in ALL_OTHER_MULTI_LORA_MODELS:
|
||||
if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
|
||||
continue
|
||||
filtered_models.append(model_case)
|
||||
|
||||
self._run_lora_multiple_batch_on_model_cases(filtered_models)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
unittest.main(warnings="ignore")
|
||||
76
test/srt/lora/test_lora_backend.py
Normal file
76
test/srt/lora/test_lora_backend.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from utils import (
|
||||
ALL_OTHER_LORA_MODELS,
|
||||
BACKENDS,
|
||||
CI_LORA_MODELS,
|
||||
DEFAULT_PROMPTS,
|
||||
TORCH_DTYPES,
|
||||
LoRAModelCase,
|
||||
run_lora_test_one_by_one,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase, is_in_ci
|
||||
|
||||
|
||||
class TestLoRABackend(CustomTestCase):
|
||||
|
||||
def _run_backend_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
||||
for model_case in model_cases:
|
||||
# If skip_long_prompt is True, filter out prompts longer than 1000 characters
|
||||
prompts = (
|
||||
DEFAULT_PROMPTS
|
||||
if not model_case.skip_long_prompt
|
||||
else [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||
)
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
for backend in BACKENDS:
|
||||
run_lora_test_one_by_one(
|
||||
prompts,
|
||||
model_case,
|
||||
torch_dtype,
|
||||
max_new_tokens=32,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
def test_ci_lora_models(self):
|
||||
self._run_backend_on_model_cases(CI_LORA_MODELS)
|
||||
|
||||
def test_all_lora_models(self):
|
||||
if is_in_ci():
|
||||
return
|
||||
|
||||
# Retain ONLY_RUN check here
|
||||
filtered_models = []
|
||||
for model_case in ALL_OTHER_LORA_MODELS:
|
||||
if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
|
||||
continue
|
||||
filtered_models.append(model_case)
|
||||
|
||||
self._run_backend_on_model_cases(filtered_models)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
unittest.main(warnings="ignore")
|
||||
110
test/srt/lora/test_lora_cuda_graph.py
Normal file
110
test/srt/lora/test_lora_cuda_graph.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from utils import (
|
||||
ALL_OTHER_LORA_MODELS,
|
||||
CI_LORA_MODELS,
|
||||
DEFAULT_PROMPTS,
|
||||
TORCH_DTYPES,
|
||||
LoRAModelCase,
|
||||
run_lora_test_by_batch,
|
||||
run_lora_test_one_by_one,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase, is_in_ci
|
||||
|
||||
TEST_CUDA_GRAPH_PADDING_PROMPTS = [
|
||||
"AI is a field of computer science focused on",
|
||||
"""
|
||||
### Instruction:
|
||||
Tell me about llamas and alpacas
|
||||
### Response:
|
||||
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
|
||||
### Question 2:
|
||||
What do you know about llamas?
|
||||
### Answer:
|
||||
""",
|
||||
"Computer science is the study of",
|
||||
]
|
||||
|
||||
|
||||
class TestLoRACudaGraph(CustomTestCase):
|
||||
|
||||
def _run_without_cuda_graph_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
||||
# Since we have already enabled CUDA graph by default in other lora tests,
|
||||
# we only need to run lora tests without CUDA graph here.
|
||||
for model_case in model_cases:
|
||||
# If skip_long_prompt is True, filter out prompts longer than 1000 characters
|
||||
prompts = (
|
||||
DEFAULT_PROMPTS
|
||||
if not model_case.skip_long_prompt
|
||||
else [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||
)
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
run_lora_test_one_by_one(
|
||||
prompts,
|
||||
model_case,
|
||||
torch_dtype,
|
||||
max_new_tokens=32,
|
||||
backend="triton",
|
||||
disable_cuda_graph=True,
|
||||
test_tag="without_cuda_graph",
|
||||
)
|
||||
|
||||
def _run_cuda_graph_padding_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
||||
for model_case in model_cases:
|
||||
# Run a batch size of 3, which will not be captured by CUDA graph and need padding
|
||||
prompts = TEST_CUDA_GRAPH_PADDING_PROMPTS
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
run_lora_test_by_batch(
|
||||
prompts,
|
||||
model_case,
|
||||
torch_dtype,
|
||||
max_new_tokens=32,
|
||||
backend="triton",
|
||||
disable_cuda_graph=False,
|
||||
test_tag="cuda_graph_padding",
|
||||
)
|
||||
|
||||
def test_ci_lora_models(self):
|
||||
self._run_without_cuda_graph_on_model_cases(CI_LORA_MODELS)
|
||||
self._run_cuda_graph_padding_on_model_cases(CI_LORA_MODELS)
|
||||
|
||||
def test_all_lora_models(self):
|
||||
if is_in_ci():
|
||||
return
|
||||
|
||||
# Retain ONLY_RUN check here
|
||||
filtered_models = []
|
||||
for model_case in ALL_OTHER_LORA_MODELS:
|
||||
if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
|
||||
continue
|
||||
filtered_models.append(model_case)
|
||||
|
||||
self._run_without_cuda_graph_on_model_cases(filtered_models)
|
||||
self._run_cuda_graph_padding_on_model_cases(filtered_models)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
unittest.main(warnings="ignore")
|
||||
146
test/srt/lora/test_lora_eviction.py
Normal file
146
test/srt/lora/test_lora_eviction.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import contextlib
|
||||
import multiprocessing as mp
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.test.runners import SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
PROMPTS = [
|
||||
"AI is a field of computer science focused on",
|
||||
"""
|
||||
### Instruction:
|
||||
Compose a SQL query that uses the following table: users, and returns the user_id and name of all users whose name that does not have a duplicate in the table.
|
||||
### Response:
|
||||
SELECT user_id, name FROM users WHERE name LIKE 'A%';
|
||||
""",
|
||||
]
|
||||
|
||||
ADAPTERS = [
|
||||
"faridlazuarda/valadapt-llama-3.1-8B-it-chinese", # target_modules = q, v
|
||||
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # target_modules = q, k, v, o, gate, up, down
|
||||
]
|
||||
|
||||
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str):
|
||||
"""A context manager to load and automatically unload a LoRA adapter."""
|
||||
try:
|
||||
runner.load_lora_adapter(lora_name=lora_name, lora_path=lora_path)
|
||||
yield
|
||||
finally:
|
||||
runner.unload_lora_adapter(lora_name=lora_name)
|
||||
|
||||
|
||||
class TestLoRAEviction(CustomTestCase):
|
||||
def test_lora_eviction_with_different_target_modules(self):
|
||||
"""
|
||||
Test LoRA eviction with different target modules.
|
||||
|
||||
This test runs inference against two LoRA adapters in different orders to force eviction behavior, and ensures
|
||||
that the outputs of the same (adapter, prompt) pair are consistent across runs.
|
||||
"""
|
||||
output_history = {}
|
||||
self._run_test(ADAPTERS, output_history, reverse=False)
|
||||
self._run_test(ADAPTERS, output_history, reverse=True)
|
||||
|
||||
def test_lora_eviction_with_reused_lora_name(self):
|
||||
"""
|
||||
Test LoRA eviction with reused LoRA names.
|
||||
|
||||
This test runs inference against two LoRA adapters with the same name to ensure that the eviction behavior
|
||||
works correctly when reusing LoRA names.
|
||||
"""
|
||||
output_history = {}
|
||||
self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1)
|
||||
self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1)
|
||||
|
||||
def _run_test(
|
||||
self,
|
||||
lora_paths: List[str],
|
||||
output_history: Dict[Tuple[str, str], str],
|
||||
reverse: bool = False,
|
||||
repeat: int = 2,
|
||||
reuse_lora_name: bool = False,
|
||||
):
|
||||
REUSED_LORA_NAME = "lora"
|
||||
max_new_tokens = 256
|
||||
backend = "triton"
|
||||
torch_dtype = torch.float16
|
||||
base_path = BASE_MODEL
|
||||
assert len(lora_paths) >= 2
|
||||
|
||||
initial_lora_paths = lora_paths if not reuse_lora_name else None
|
||||
# Initialize runners
|
||||
with SRTRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
lora_paths=initial_lora_paths,
|
||||
max_loras_per_batch=1,
|
||||
lora_backend=backend,
|
||||
enable_lora=True,
|
||||
max_lora_rank=256,
|
||||
lora_target_modules=["all"],
|
||||
) as srt_runner:
|
||||
adapter_sequence = lora_paths if not reverse else lora_paths[::-1]
|
||||
|
||||
for i in range(repeat):
|
||||
for j, lora_path in enumerate(adapter_sequence):
|
||||
print(
|
||||
f"\n========== Testing LoRA eviction with adapter '{lora_path}' (#{j + 1}/{len(adapter_sequence)}), reuse_lora_name: {reuse_lora_name}, reversed: {reverse}, repeat: {i + 1}/{repeat} ---"
|
||||
)
|
||||
|
||||
lora_name = REUSED_LORA_NAME if reuse_lora_name else lora_path
|
||||
context = (
|
||||
dynamically_loaded_adapter(srt_runner, lora_path, lora_name)
|
||||
if reuse_lora_name
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
with context:
|
||||
for prompt in PROMPTS:
|
||||
print("\nprompt:\n", prompt)
|
||||
srt_outputs = srt_runner.forward(
|
||||
[prompt],
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=[lora_name],
|
||||
)
|
||||
output = srt_outputs.output_strs[0].strip()
|
||||
print("\noutput:\n", output)
|
||||
|
||||
prev_output = output_history.get((lora_path, prompt))
|
||||
if prev_output is not None:
|
||||
self.assertEqual(
|
||||
prev_output,
|
||||
output,
|
||||
f"Output mismatch for adapter {lora_path} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
|
||||
)
|
||||
else:
|
||||
output_history[(lora_path, prompt)] = output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
unittest.main(warnings="ignore")
|
||||
208
test/srt/lora/test_lora_qwen3.py
Normal file
208
test/srt/lora/test_lora_qwen3.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# Copyright 2023-2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import random
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
|
||||
|
||||
from sglang.test.runners import HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
|
||||
|
||||
LORA_MODELS_QWEN3 = [
|
||||
LoRAModelCase(
|
||||
base="Qwen/Qwen3-4B",
|
||||
adaptors=[
|
||||
LoRAAdaptor(
|
||||
name="nissenj/Qwen3-4B-lora-v2",
|
||||
prefill_tolerance=3e-1,
|
||||
),
|
||||
LoRAAdaptor(
|
||||
name="y9760210/Qwen3-4B-lora_model",
|
||||
prefill_tolerance=3e-1,
|
||||
),
|
||||
],
|
||||
max_loras_per_batch=2,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
TEST_MULTIPLE_BATCH_PROMPTS = [
|
||||
"""
|
||||
### Instruction:
|
||||
Tell me about llamas and alpacas
|
||||
### Response:
|
||||
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
|
||||
### Question 2:
|
||||
What do you know about llamas?
|
||||
### Answer:
|
||||
""",
|
||||
"""
|
||||
### Instruction:
|
||||
Write a poem about the transformers Python library.
|
||||
Mention the word "large language models" in that poem.
|
||||
### Response:
|
||||
The Transformers are large language models,
|
||||
They're used to make predictions on text.
|
||||
""",
|
||||
# "AI is a field of computer science focused on", TODO: Add it back after fixing its bug
|
||||
"Computer science is the study of",
|
||||
"Write a short story.",
|
||||
"What are the main components of a computer?",
|
||||
]
|
||||
|
||||
|
||||
class TestLoRA(CustomTestCase):
|
||||
|
||||
def _run_lora_multiple_batch_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
||||
for model_case in model_cases:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
max_new_tokens = 10
|
||||
backend = "triton"
|
||||
base_path = model_case.base
|
||||
lora_adapter_paths = [a.name for a in model_case.adaptors]
|
||||
assert len(lora_adapter_paths) >= 2
|
||||
|
||||
batches = [
|
||||
(
|
||||
[
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
],
|
||||
[
|
||||
None,
|
||||
lora_adapter_paths[0],
|
||||
lora_adapter_paths[1],
|
||||
],
|
||||
),
|
||||
(
|
||||
[
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
],
|
||||
[
|
||||
lora_adapter_paths[0],
|
||||
None,
|
||||
lora_adapter_paths[1],
|
||||
],
|
||||
),
|
||||
(
|
||||
[
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
],
|
||||
[lora_adapter_paths[0], lora_adapter_paths[1], None],
|
||||
),
|
||||
(
|
||||
[
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
],
|
||||
[None, lora_adapter_paths[1], None],
|
||||
),
|
||||
(
|
||||
[
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
random.choice(TEST_MULTIPLE_BATCH_PROMPTS),
|
||||
],
|
||||
[None, None, None],
|
||||
),
|
||||
]
|
||||
|
||||
print(
|
||||
f"\n========== Testing multiple batches on base '{base_path}' with backend={backend}, dtype={torch_dtype} ---"
|
||||
)
|
||||
|
||||
# Initialize runners
|
||||
srt_runner = SRTRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
lora_paths=[lora_adapter_paths[0], lora_adapter_paths[1]],
|
||||
max_loras_per_batch=len(lora_adapter_paths) + 1,
|
||||
lora_backend=backend,
|
||||
)
|
||||
hf_runner = HFRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
patch_model_do_sample_false=True,
|
||||
)
|
||||
|
||||
with srt_runner, hf_runner:
|
||||
for i, (prompts, lora_paths) in enumerate(batches):
|
||||
print(
|
||||
f"\n--- Running Batch {i+1} --- prompts: {prompts}, lora_paths: {lora_paths}"
|
||||
)
|
||||
|
||||
srt_outputs = srt_runner.batch_forward(
|
||||
prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
)
|
||||
|
||||
hf_outputs = hf_runner.forward(
|
||||
prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
lora_paths=lora_paths,
|
||||
)
|
||||
|
||||
print("SRT outputs:", [s for s in srt_outputs.output_strs])
|
||||
print("HF outputs:", [s for s in hf_outputs.output_strs])
|
||||
|
||||
for srt_out, hf_out in zip(
|
||||
srt_outputs.output_strs, hf_outputs.output_strs
|
||||
):
|
||||
srt_str = srt_out.strip()
|
||||
hf_str = hf_out.strip()
|
||||
rouge_tol = model_case.rouge_l_tolerance
|
||||
rouge_score = calculate_rouge_l([srt_str], [hf_str])[0]
|
||||
if rouge_score < rouge_tol:
|
||||
raise AssertionError(
|
||||
f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
|
||||
f"for base '{base_path}', adaptor '{lora_paths}', backend '{backend}', prompt: '{prompts}...'"
|
||||
)
|
||||
|
||||
print(f"--- Batch {i+1} Comparison Passed --- ")
|
||||
|
||||
def test_ci_lora_models(self):
|
||||
self._run_lora_multiple_batch_on_model_cases(LORA_MODELS_QWEN3)
|
||||
|
||||
def test_all_lora_models(self):
|
||||
if is_in_ci():
|
||||
return
|
||||
qwen_filtered_models = []
|
||||
for model_case in LORA_MODELS_QWEN3:
|
||||
if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
|
||||
continue
|
||||
qwen_filtered_models.append(model_case)
|
||||
|
||||
self._run_lora_multiple_batch_on_model_cases(qwen_filtered_models)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
unittest.main(warnings="ignore")
|
||||
83
test/srt/lora/test_lora_radix_cache.py
Normal file
83
test/srt/lora/test_lora_radix_cache.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from utils import CI_MULTI_LORA_MODELS, DEFAULT_PROMPTS, run_lora_test_one_by_one
|
||||
|
||||
from sglang.test.runners import HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
PROMPTS = [
|
||||
"AI is a field of computer science focused on",
|
||||
"""
|
||||
### Instruction:
|
||||
Tell me about llamas and alpacas
|
||||
### Response:
|
||||
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids.
|
||||
### Question:
|
||||
What do you know about llamas?
|
||||
### Answer:
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
class TestLoRARadixCache(CustomTestCase):
|
||||
|
||||
def test_lora_radix_cache(self):
|
||||
# Here we need a model case with multiple adaptors for testing correctness of radix cache
|
||||
model_case = CI_MULTI_LORA_MODELS[0]
|
||||
|
||||
torch_dtype = torch.float16
|
||||
max_new_tokens = 32
|
||||
backend = "triton"
|
||||
batch_prompts = (
|
||||
PROMPTS
|
||||
if not model_case.skip_long_prompt
|
||||
else [p for p in PROMPTS if len(p) < 1000]
|
||||
)
|
||||
|
||||
# Test lora with radix cache
|
||||
run_lora_test_one_by_one(
|
||||
batch_prompts,
|
||||
model_case,
|
||||
torch_dtype,
|
||||
max_new_tokens=max_new_tokens,
|
||||
backend=backend,
|
||||
disable_radix_cache=False,
|
||||
test_tag="lora-with-radix-cache",
|
||||
)
|
||||
|
||||
# Test lora without radix cache
|
||||
run_lora_test_one_by_one(
|
||||
batch_prompts,
|
||||
model_case,
|
||||
torch_dtype,
|
||||
max_new_tokens=max_new_tokens,
|
||||
backend=backend,
|
||||
disable_radix_cache=True,
|
||||
test_tag="lora-without-radix-cache",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
unittest.main(warnings="ignore")
|
||||
78
test/srt/lora/test_lora_tp.py
Normal file
78
test/srt/lora/test_lora_tp.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from utils import (
|
||||
ALL_OTHER_LORA_MODELS,
|
||||
CI_LORA_MODELS,
|
||||
DEFAULT_PROMPTS,
|
||||
TORCH_DTYPES,
|
||||
LoRAModelCase,
|
||||
run_lora_test_one_by_one,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase, is_in_ci
|
||||
|
||||
|
||||
class TestLoRATP(CustomTestCase):
|
||||
|
||||
def _run_tp_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
||||
tp_list = [2] # Define TP sizes to iterate over
|
||||
for model_case in model_cases:
|
||||
# If skip_long_prompt is True, filter out prompts longer than 1000 characters
|
||||
prompts = (
|
||||
DEFAULT_PROMPTS
|
||||
if not model_case.skip_long_prompt
|
||||
else [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||
)
|
||||
for tp_size in tp_list:
|
||||
model_case.tp_size = tp_size
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
run_lora_test_one_by_one(
|
||||
prompts,
|
||||
model_case,
|
||||
torch_dtype,
|
||||
max_new_tokens=32,
|
||||
backend="triton",
|
||||
test_tag=f"tp={tp_size}",
|
||||
)
|
||||
|
||||
def test_ci_lora_models(self):
|
||||
self._run_tp_on_model_cases(CI_LORA_MODELS)
|
||||
|
||||
def test_all_lora_models(self):
|
||||
if is_in_ci():
|
||||
return
|
||||
|
||||
# Retain ONLY_RUN check here
|
||||
filtered_models = []
|
||||
for model_case in ALL_OTHER_LORA_MODELS:
|
||||
if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
|
||||
continue
|
||||
filtered_models.append(model_case)
|
||||
|
||||
self._run_tp_on_model_cases(filtered_models)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
unittest.main(warnings="ignore")
|
||||
1306
test/srt/lora/test_lora_update.py
Normal file
1306
test/srt/lora/test_lora_update.py
Normal file
File diff suppressed because it is too large
Load Diff
90
test/srt/lora/test_multi_lora_backend.py
Normal file
90
test/srt/lora/test_multi_lora_backend.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from utils import (
|
||||
ALL_OTHER_MULTI_LORA_MODELS,
|
||||
BACKENDS,
|
||||
CI_MULTI_LORA_MODELS,
|
||||
TORCH_DTYPES,
|
||||
LoRAModelCase,
|
||||
run_lora_test_one_by_one,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase, is_in_ci
|
||||
|
||||
# All prompts are used at once in a batch.
|
||||
PROMPTS = [
|
||||
"AI is a field of computer science focused on",
|
||||
"""
|
||||
### Instruction:
|
||||
Tell me about llamas and alpacas
|
||||
### Response:
|
||||
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids.
|
||||
### Question:
|
||||
What do you know about llamas?
|
||||
### Answer:
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
class TestMultiLoRABackend(CustomTestCase):
|
||||
|
||||
def _run_multi_lora_test_on_model_cases(self, model_cases: List[LoRAModelCase]):
|
||||
for model_case in model_cases:
|
||||
# If skip_long_prompt is True, filter out prompts longer than 1000 characters.
|
||||
batch_prompts = (
|
||||
PROMPTS
|
||||
if not model_case.skip_long_prompt
|
||||
else [p for p in PROMPTS if len(p) < 1000]
|
||||
)
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
for backend in BACKENDS:
|
||||
run_lora_test_one_by_one(
|
||||
batch_prompts,
|
||||
model_case,
|
||||
torch_dtype,
|
||||
max_new_tokens=32,
|
||||
backend=backend,
|
||||
test_tag="multi-lora-backend",
|
||||
)
|
||||
|
||||
def test_ci_lora_models(self):
|
||||
self._run_multi_lora_test_on_model_cases(CI_MULTI_LORA_MODELS)
|
||||
|
||||
def test_all_lora_models(self):
|
||||
if is_in_ci():
|
||||
return
|
||||
|
||||
# Retain ONLY_RUN check here
|
||||
filtered_models = []
|
||||
for model_case in ALL_OTHER_MULTI_LORA_MODELS:
|
||||
if "ONLY_RUN" in os.environ and os.environ["ONLY_RUN"] != model_case.base:
|
||||
continue
|
||||
filtered_models.append(model_case)
|
||||
|
||||
self._run_multi_lora_test_on_model_cases(filtered_models)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
mp.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
unittest.main(warnings="ignore")
|
||||
388
test/srt/lora/utils.py
Normal file
388
test/srt/lora/utils.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import dataclasses
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.test.runners import HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import calculate_rouge_l
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LoRAAdaptor:
|
||||
name: str
|
||||
prefill_tolerance: float = None
|
||||
decode_tolerance: float = None
|
||||
rouge_l_tolerance: float = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LoRAModelCase:
|
||||
base: str
|
||||
adaptors: List[LoRAAdaptor]
|
||||
tp_size: int = 1
|
||||
prefill_tolerance: float = 1e-1
|
||||
decode_tolerance: float = 1e-1
|
||||
rouge_l_tolerance: float = 1.0
|
||||
max_loras_per_batch: int = 1
|
||||
skip_long_prompt: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if len(self.adaptors) > self.max_loras_per_batch:
|
||||
raise ValueError(
|
||||
f"For base '{self.base}', number of adaptors ({len(self.adaptors)}) "
|
||||
f"must be <= max_loras_per_batch ({self.max_loras_per_batch})"
|
||||
)
|
||||
|
||||
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
BACKENDS = ["triton"]
|
||||
DEFAULT_PROMPTS = [
|
||||
"AI is a field of computer science focused on",
|
||||
"""
|
||||
### Instruction:
|
||||
Tell me about llamas and alpacas
|
||||
### Response:
|
||||
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
|
||||
### Question 2:
|
||||
What do you know about llamas?
|
||||
### Answer:
|
||||
""",
|
||||
]
|
||||
|
||||
CI_LORA_MODELS = [
|
||||
LoRAModelCase(
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
adaptors=[
|
||||
LoRAAdaptor(
|
||||
name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
||||
),
|
||||
],
|
||||
max_loras_per_batch=1,
|
||||
),
|
||||
]
|
||||
|
||||
ALL_OTHER_LORA_MODELS = [
|
||||
LoRAModelCase(
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
adaptors=[
|
||||
LoRAAdaptor(
|
||||
name="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
prefill_tolerance=1e-1,
|
||||
),
|
||||
],
|
||||
max_loras_per_batch=1,
|
||||
),
|
||||
LoRAModelCase(
|
||||
base="meta-llama/Llama-2-7b-hf",
|
||||
adaptors=[LoRAAdaptor(name="winddude/wizardLM-LlaMA-LoRA-7B")],
|
||||
max_loras_per_batch=2,
|
||||
),
|
||||
]
|
||||
|
||||
CI_MULTI_LORA_MODELS = [
|
||||
# multi-rank case
|
||||
LoRAModelCase(
|
||||
base="meta-llama/Llama-2-7b-hf",
|
||||
adaptors=[
|
||||
LoRAAdaptor(
|
||||
name="winddude/wizardLM-LlaMA-LoRA-7B",
|
||||
prefill_tolerance=1e-1,
|
||||
),
|
||||
LoRAAdaptor(
|
||||
name="RuterNorway/Llama-2-7b-chat-norwegian-LoRa",
|
||||
prefill_tolerance=3e-1,
|
||||
),
|
||||
],
|
||||
max_loras_per_batch=2,
|
||||
),
|
||||
]
|
||||
|
||||
ALL_OTHER_MULTI_LORA_MODELS = [
|
||||
LoRAModelCase(
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
adaptors=[
|
||||
LoRAAdaptor(
|
||||
name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
||||
prefill_tolerance=1e-1,
|
||||
),
|
||||
LoRAAdaptor(
|
||||
name="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
prefill_tolerance=1e-1,
|
||||
),
|
||||
],
|
||||
max_loras_per_batch=2,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def run_lora_test_one_by_one(
|
||||
prompts: List[str],
|
||||
model_case: LoRAModelCase,
|
||||
torch_dtype: torch.dtype,
|
||||
max_new_tokens: int,
|
||||
backend: str,
|
||||
disable_cuda_graph: bool = False,
|
||||
disable_radix_cache: bool = False,
|
||||
mem_fraction_static: float = 0.88,
|
||||
test_tag: str = "",
|
||||
):
|
||||
"""
|
||||
Input a batch of prompts, and run lora tests one by one with several generate requests
|
||||
(each request will have bs=1).
|
||||
For prompt0, prompt1, ..., promptN,
|
||||
we will use adaptor0, adaptor1, ..., adaptorN included in model case,
|
||||
We will then compare the outputs of HF and SRT with and without LoRA.
|
||||
If number of prompts is larger than number of adaptors,
|
||||
the prompt i will use adaptor i % (number of adaptors).
|
||||
|
||||
Args:
|
||||
prompts (List[str]): The batch of prompts to test.
|
||||
model_case (LoRAModelCase): The model case to test.
|
||||
torch_dtype (torch.dtype): The torch dtype to use.
|
||||
max_new_tokens (int): The maximum number of new tokens to generate.
|
||||
backend (str): The lora backend to use.
|
||||
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
|
||||
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to False.
|
||||
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
|
||||
test_tag (str, optional): The tag to use for the test. Defaults to "".
|
||||
"""
|
||||
base_path = model_case.base
|
||||
|
||||
# Create used adaptors for each prompt in batch
|
||||
i, adaptors = 0, []
|
||||
for _ in range(len(prompts)):
|
||||
adaptors.append(model_case.adaptors[i])
|
||||
i = (i + 1) % len(model_case.adaptors)
|
||||
adaptor_names = [adaptor.name for adaptor in adaptors]
|
||||
|
||||
print(
|
||||
f"\n========== Testing {test_tag} on base '{model_case.base}' with backend={backend}, dtype={torch_dtype} --- "
|
||||
f"Using prompts {[p[:50] for p in prompts]} with adaptors: {adaptor_names} ---"
|
||||
)
|
||||
with SRTRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
tp_size=model_case.tp_size,
|
||||
lora_paths=[
|
||||
adaptor.name for adaptor in model_case.adaptors if adaptor.name is not None
|
||||
],
|
||||
max_loras_per_batch=model_case.max_loras_per_batch,
|
||||
lora_backend=backend,
|
||||
disable_cuda_graph=disable_cuda_graph,
|
||||
disable_radix_cache=disable_radix_cache,
|
||||
mem_fraction_static=mem_fraction_static,
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.forward(
|
||||
prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
|
||||
)
|
||||
|
||||
with SRTRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
tp_size=model_case.tp_size,
|
||||
mem_fraction_static=mem_fraction_static,
|
||||
) as srt_runner:
|
||||
srt_no_lora_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||
|
||||
with HFRunner(
|
||||
base_path, torch_dtype=torch_dtype, model_type="generation"
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(
|
||||
prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
|
||||
)
|
||||
hf_no_lora_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||
|
||||
# Compare prefill stage logprobs (HF vs SRTRunner with LoRA)
|
||||
for i in range(len(prompts)):
|
||||
adaptor = adaptors[i]
|
||||
# Use individual adaptor tolerances if set, otherwise use model defaults
|
||||
prefill_tol = (
|
||||
adaptor.prefill_tolerance
|
||||
if adaptor.prefill_tolerance is not None
|
||||
else model_case.prefill_tolerance
|
||||
)
|
||||
decode_tol = (
|
||||
adaptor.decode_tolerance
|
||||
if adaptor.decode_tolerance is not None
|
||||
else model_case.decode_tolerance
|
||||
)
|
||||
rouge_tol = (
|
||||
adaptor.rouge_l_tolerance
|
||||
if adaptor.rouge_l_tolerance is not None
|
||||
else model_case.rouge_l_tolerance
|
||||
)
|
||||
# Compare prefill stage logprobs (HF vs SRTRunner with LoRA)
|
||||
hf_prefill = torch.tensor(hf_outputs.top_input_logprobs[i])
|
||||
srt_prefill = torch.tensor(srt_outputs.top_input_logprobs[i])
|
||||
max_prefill_diff = torch.max(torch.abs(hf_prefill - srt_prefill))
|
||||
print("Max prefill diff (HF vs SRT):", max_prefill_diff)
|
||||
|
||||
# Compare decode stage logprobs
|
||||
hf_decode = torch.tensor(hf_outputs.top_output_logprobs[i])
|
||||
srt_decode = torch.tensor(srt_outputs.top_output_logprobs[i])
|
||||
max_decode_diff = torch.max(torch.abs(hf_decode - srt_decode))
|
||||
print("Max decode diff (HF vs SRT):", max_decode_diff)
|
||||
|
||||
srt_output_str = srt_outputs.output_strs[i].strip()
|
||||
hf_output_str = hf_outputs.output_strs[i].strip()
|
||||
rouge_score = calculate_rouge_l([srt_output_str], [hf_output_str])[0]
|
||||
print("ROUGE-L score:", rouge_score)
|
||||
print("SRT output:", srt_output_str)
|
||||
print("HF output:", hf_output_str)
|
||||
|
||||
# Additional: compare prefill outputs between base model (no LoRA) and LoRA model for reference
|
||||
hf_no_lora_prefill = torch.tensor(hf_no_lora_outputs.top_input_logprobs[i])
|
||||
srt_no_lora_prefill = torch.tensor(srt_no_lora_outputs.top_input_logprobs[i])
|
||||
print(
|
||||
"Max diff (SRT base vs SRT LoRA prefill):",
|
||||
torch.max(torch.abs(srt_no_lora_prefill - srt_prefill)),
|
||||
)
|
||||
print(
|
||||
"Max diff (HF base vs HF LoRA prefill):",
|
||||
torch.max(torch.abs(hf_no_lora_prefill - hf_prefill)),
|
||||
)
|
||||
|
||||
if hf_prefill.shape[0] <= 100:
|
||||
assert torch.all(torch.abs(hf_prefill - srt_prefill) < prefill_tol), (
|
||||
f"Prefill logprobs mismatch for base '{base_path}', adaptor '{adaptor_names}', "
|
||||
f"backend '{backend}', prompt: '{prompts[0][:50]}...'"
|
||||
)
|
||||
|
||||
if hf_decode.shape[0] <= 100:
|
||||
assert torch.all(torch.abs(hf_decode - srt_decode) < decode_tol), (
|
||||
f"Decode logprobs mismatch for base '{base_path}', adaptor '{adaptor_names}', "
|
||||
f"backend '{backend}', prompt: '{prompts[0][:50]}...'"
|
||||
)
|
||||
|
||||
if rouge_score < rouge_tol:
|
||||
raise AssertionError(
|
||||
f"ROUGE-L score {rouge_score} below tolerance {rouge_tol} "
|
||||
f"for base '{base_path}', adaptor '{adaptor_names}', backend '{backend}', prompt: '{prompts[0][:50]}...'"
|
||||
)
|
||||
|
||||
|
||||
def run_lora_test_by_batch(
|
||||
prompts: List[str],
|
||||
model_case: LoRAModelCase,
|
||||
torch_dtype: torch.dtype,
|
||||
max_new_tokens: int,
|
||||
backend: str,
|
||||
disable_cuda_graph: bool = False,
|
||||
disable_radix_cache: bool = False,
|
||||
mem_fraction_static: float = 0.88,
|
||||
test_tag: str = "",
|
||||
):
|
||||
"""
|
||||
Run lora tests as a batch.
|
||||
For prompt0, prompt1, ..., promptN,
|
||||
we will use adaptor0, adaptor1, ..., adaptorN included in model case,
|
||||
We will then compare the outputs of HF and SRT with LoRA.
|
||||
If number of prompts is larger than number of adaptors,
|
||||
the prompt i will use adaptor i % (number of adaptors).
|
||||
|
||||
Args:
|
||||
prompts (List[str]): The batch of prompts to test.
|
||||
model_case (LoRAModelCase): The model case to test.
|
||||
torch_dtype (torch.dtype): The torch dtype to use.
|
||||
max_new_tokens (int): The maximum number of new tokens to generate.
|
||||
backend (str): The lora backend to use.
|
||||
disable_cuda_graph (bool, optional): Whether to disable CUDA graph. Defaults to False.
|
||||
disable_radix_cache (bool, optional): Whether to disable radix cache. Defaults to False.
|
||||
mem_fraction_static (float, optional): The fraction of memory to use. Defaults to 0.88.
|
||||
test_tag (str, optional): The tag to use for the test. Defaults to "".
|
||||
"""
|
||||
base_path = model_case.base
|
||||
|
||||
# Create used adaptors for each prompt in batch
|
||||
i, adaptors = 0, []
|
||||
for _ in range(len(prompts)):
|
||||
adaptors.append(model_case.adaptors[i])
|
||||
i = (i + 1) % len(model_case.adaptors)
|
||||
adaptor_names = [adaptor.name for adaptor in adaptors]
|
||||
|
||||
print(
|
||||
f"\n========== Testing {test_tag} on base '{model_case.base}' with backend={backend}, dtype={torch_dtype} --- "
|
||||
f"Using prompts {[p[:50] for p in prompts]} with adaptors: {adaptor_names} ---"
|
||||
)
|
||||
with SRTRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
tp_size=model_case.tp_size,
|
||||
lora_paths=[
|
||||
adaptor.name for adaptor in model_case.adaptors if adaptor.name is not None
|
||||
],
|
||||
max_loras_per_batch=model_case.max_loras_per_batch,
|
||||
lora_backend=backend,
|
||||
disable_cuda_graph=disable_cuda_graph,
|
||||
disable_radix_cache=disable_radix_cache,
|
||||
mem_fraction_static=mem_fraction_static,
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.batch_forward(
|
||||
prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
|
||||
)
|
||||
|
||||
with SRTRunner(
|
||||
base_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
tp_size=model_case.tp_size,
|
||||
mem_fraction_static=mem_fraction_static,
|
||||
) as srt_runner:
|
||||
srt_no_lora_outputs = srt_runner.batch_forward(
|
||||
prompts, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
with HFRunner(
|
||||
base_path, torch_dtype=torch_dtype, model_type="generation"
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(
|
||||
prompts, max_new_tokens=max_new_tokens, lora_paths=adaptor_names
|
||||
)
|
||||
|
||||
with HFRunner(
|
||||
base_path, torch_dtype=torch_dtype, model_type="generation"
|
||||
) as hf_runner:
|
||||
hf_no_lora_outputs = hf_runner.forward(
|
||||
prompts,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
for i in range(len(prompts)):
|
||||
|
||||
srt_output_str = srt_outputs.output_strs[i].strip()
|
||||
hf_output_str = hf_outputs.output_strs[i].strip()
|
||||
rouge_score = calculate_rouge_l([srt_output_str], [hf_output_str])[0]
|
||||
print("ROUGE-L score:", rouge_score)
|
||||
print("SRT output:", srt_output_str)
|
||||
print("HF output:", hf_output_str)
|
||||
print("SRT no lora output:", srt_no_lora_outputs.output_strs[i].strip())
|
||||
print("HF no lora output:", hf_no_lora_outputs.output_strs[i].strip())
|
||||
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i].strip(
|
||||
" "
|
||||
), (
|
||||
srt_outputs.output_strs[i].strip(" "),
|
||||
hf_outputs.output_strs[i].strip(" "),
|
||||
)
|
||||
assert srt_no_lora_outputs.output_strs[i].strip(
|
||||
" "
|
||||
) == hf_no_lora_outputs.output_strs[i].strip(" "), (
|
||||
srt_no_lora_outputs.output_strs[i].strip(" "),
|
||||
hf_no_lora_outputs.output_strs[i].strip(" "),
|
||||
)
|
||||
52
test/srt/models/compare.py
Normal file
52
test/srt/models/compare.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
used for debug using tensor comparison
|
||||
dump {name: tensor} into "log_hf.jsonl" and "log_srt.jsonl"
|
||||
use the same name for two tensors that supposed to be close
|
||||
recommend name like: "layer 2 after mlp"
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
assert sys.argv[1] == "base"
|
||||
hf_log = "base_log_hf.jsonl"
|
||||
srt_log = "base_log_srt.jsonl"
|
||||
else:
|
||||
hf_log = "log_hf.jsonl"
|
||||
srt_log = "log_srt.jsonl"
|
||||
|
||||
|
||||
def load_data(filepath):
|
||||
tensors = {}
|
||||
with open(filepath, "r") as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
data = json.loads(line)
|
||||
for k, v in data.items():
|
||||
tensors[k] = torch.tensor(v)
|
||||
return tensors
|
||||
|
||||
|
||||
hf_tensors = load_data(hf_log)
|
||||
srt_tensors = load_data(srt_log)
|
||||
|
||||
|
||||
def get_diff(t1, t2):
|
||||
t1 = t1.reshape(t2.shape)
|
||||
max_diff = torch.max(abs(t1.reshape(t2.shape) - t2))
|
||||
l2_dis = torch.dist(t1, t2, p=2)
|
||||
return l2_dis, max_diff
|
||||
|
||||
|
||||
for k, _ in srt_tensors.items():
|
||||
l2_dis, max_diff = get_diff(hf_tensors[k], srt_tensors[k])
|
||||
print(f"{k} {l2_dis=} {max_diff=}")
|
||||
if k == "layer 1 attn":
|
||||
print(hf_tensors[k])
|
||||
print(srt_tensors[k])
|
||||
if k == "layer 0 prefill k":
|
||||
print(srt_tensors[k].shape)
|
||||
print(hf_tensors[k].shape)
|
||||
80
test/srt/models/test_clip_models.py
Normal file
80
test/srt/models/test_clip_models.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import get_similarities
|
||||
|
||||
TEXTS = "two Subway Series sandwiches with meats, cheese, lettuce, tomatoes, and onions on a black background, accompanied by the Subway Series logo, highlighting a new sandwich series."
|
||||
IMAGES = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"
|
||||
MODELS = [
|
||||
("openai/clip-vit-large-patch14-336", 1e-5),
|
||||
]
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
|
||||
|
||||
class TestClipModels(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def assert_close_embeddings(self, model, prefill_tolerance, torch_dtype):
|
||||
|
||||
with HFRunner(
|
||||
model,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="embedding",
|
||||
) as hf_runner:
|
||||
hf_text_embeds = hf_runner.forward(prompts=TEXTS)
|
||||
hf_image_embeds = hf_runner.forward(image_data=IMAGES)
|
||||
|
||||
with SRTRunner(
|
||||
model,
|
||||
tp_size=1,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="embedding",
|
||||
) as srt_runner:
|
||||
text_embeds = srt_runner.forward(prompts=TEXTS)
|
||||
image_embeds = srt_runner.forward(prompts="padding", image_data=IMAGES)
|
||||
|
||||
text_similarity = get_similarities(
|
||||
text_embeds.embed_logits[0], hf_text_embeds.embed_logits[0]
|
||||
)
|
||||
image_similarity = get_similarities(
|
||||
image_embeds.embed_logits[0], hf_image_embeds.embed_logits[0]
|
||||
)
|
||||
print("text similarity diff", abs(text_similarity - 1))
|
||||
print("image similarity diff", abs(image_similarity - 1))
|
||||
assert torch.all(
|
||||
abs(text_similarity - 1) < prefill_tolerance
|
||||
), "embeddings are not all close"
|
||||
assert torch.all(
|
||||
abs(image_similarity - 1) < prefill_tolerance
|
||||
), "embeddings are not all close"
|
||||
|
||||
def test_accuracy(self):
|
||||
for model, prefill_tolerance in MODELS:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
self.assert_close_embeddings(model, prefill_tolerance, torch_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
46
test/srt/models/test_compressed_tensors_models.py
Normal file
46
test/srt/models/test_compressed_tensors_models.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestCompressedTensorsLlama3FP8(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "RedHatAI/Meta-Llama-3.1-8B-FP8"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreaterEqual(metrics["accuracy"], 0.45)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
91
test/srt/models/test_cross_encoder_models.py
Normal file
91
test/srt/models/test_cross_encoder_models.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from sglang.test.runners import TEST_RERANK_QUERY_DOCS, HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase, is_in_ci
|
||||
|
||||
MODELS = [
|
||||
("cross-encoder/ms-marco-MiniLM-L6-v2", 1, 1e-2),
|
||||
("BAAI/bge-reranker-v2-m3", 1, 1e-2),
|
||||
]
|
||||
ATTENTION_BACKEND = ["torch_native", "triton"]
|
||||
|
||||
TORCH_DTYPES = [torch.float32]
|
||||
|
||||
|
||||
class TestCrossEncoderModels(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def assert_close_prefill_logits(
|
||||
self,
|
||||
prompts,
|
||||
model_path,
|
||||
tp_size,
|
||||
torch_dtype,
|
||||
score_tolerance,
|
||||
attention_backend,
|
||||
) -> None:
|
||||
with HFRunner(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="cross_encoder",
|
||||
) as hf_runner:
|
||||
hf_scores = hf_runner.forward(prompts).scores
|
||||
|
||||
with SRTRunner(
|
||||
model_path,
|
||||
tp_size=tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="cross_encoder",
|
||||
attention_backend=attention_backend,
|
||||
chunked_prefill_size=-1,
|
||||
disable_radix_cache=True,
|
||||
) as srt_runner:
|
||||
srt_scores = srt_runner.forward(prompts).scores
|
||||
|
||||
for i in range(len(srt_scores)):
|
||||
score_difference = abs(hf_scores[i] - srt_scores[i])
|
||||
|
||||
assert (
|
||||
score_difference < score_tolerance
|
||||
), "cross encoder scores are not all close"
|
||||
|
||||
def preprocess_prompts(self, prompt):
|
||||
processed_prompts = []
|
||||
query = prompt["query"]
|
||||
documents = prompt["documents"]
|
||||
for document in documents:
|
||||
processed_prompts.append([query, document])
|
||||
|
||||
return processed_prompts
|
||||
|
||||
def test_prefill_logits(self):
|
||||
models_to_test = MODELS
|
||||
|
||||
if is_in_ci():
|
||||
models_to_test = [random.choice(MODELS)]
|
||||
|
||||
for model, tp_size, prefill_tolerance in models_to_test:
|
||||
for attention_backend in ATTENTION_BACKEND:
|
||||
for queryDocs in TEST_RERANK_QUERY_DOCS:
|
||||
prompts = self.preprocess_prompts(queryDocs)
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
self.assert_close_prefill_logits(
|
||||
prompts,
|
||||
model,
|
||||
tp_size,
|
||||
torch_dtype,
|
||||
prefill_tolerance,
|
||||
attention_backend,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
34
test/srt/models/test_dummy_grok_models.py
Normal file
34
test/srt/models/test_dummy_grok_models.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import unittest
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase, is_in_ci, run_bench_one_batch
|
||||
|
||||
|
||||
class TestDummyGrok1(CustomTestCase):
|
||||
|
||||
def test_dummy_grok_1(self):
|
||||
_, output_throughput, _ = run_bench_one_batch(
|
||||
None,
|
||||
[
|
||||
"--model",
|
||||
"/dummy-grok",
|
||||
"--tokenizer-path",
|
||||
"Xenova/grok-1-tokenizer",
|
||||
"--batch-size",
|
||||
"2",
|
||||
"--tp",
|
||||
"2",
|
||||
"--quantization",
|
||||
"fp8",
|
||||
"--load-format",
|
||||
"dummy",
|
||||
"--json-model-override-args",
|
||||
'{"num_hidden_layers": 2}',
|
||||
],
|
||||
)
|
||||
|
||||
if is_in_ci():
|
||||
self.assertGreater(output_throughput, 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
111
test/srt/models/test_embedding_models.py
Normal file
111
test/srt/models/test_embedding_models.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
|
||||
|
||||
MODELS = [
|
||||
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
|
||||
("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
|
||||
("marco/mcdse-2b-v1", 1, 1e-5),
|
||||
("Qwen/Qwen3-Embedding-8B", 1, 1e-5),
|
||||
# Temporarily disable before this model is fixed
|
||||
# ("jason9693/Qwen2.5-1.5B-apeach", 1, 1e-5),
|
||||
]
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
|
||||
|
||||
class TestEmbeddingModels(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def _truncate_prompts(self, prompts, model_path):
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
max_length = getattr(config, "max_position_embeddings", 2048)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
truncated_prompts = []
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, return_tensors="pt", truncation=False)
|
||||
if len(tokens.input_ids[0]) > max_length:
|
||||
truncated_text = tokenizer.decode(
|
||||
tokens.input_ids[0][: max_length - 1], skip_special_tokens=True
|
||||
)
|
||||
truncated_prompts.append(truncated_text)
|
||||
else:
|
||||
truncated_prompts.append(prompt)
|
||||
return truncated_prompts
|
||||
|
||||
def assert_close_prefill_logits(
|
||||
self,
|
||||
prompts,
|
||||
model_path,
|
||||
tp_size,
|
||||
torch_dtype,
|
||||
prefill_tolerance,
|
||||
) -> None:
|
||||
truncated_prompts = self._truncate_prompts(prompts, model_path)
|
||||
|
||||
with HFRunner(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="embedding",
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(truncated_prompts)
|
||||
|
||||
with SRTRunner(
|
||||
model_path,
|
||||
tp_size=tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="embedding",
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.forward(truncated_prompts)
|
||||
|
||||
for i in range(len(prompts)):
|
||||
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
|
||||
srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
|
||||
|
||||
similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
|
||||
print("similarity diff", abs(similarity - 1))
|
||||
|
||||
if len(prompts[i]) <= 1000:
|
||||
assert torch.all(
|
||||
abs(similarity - 1) < prefill_tolerance
|
||||
), "embeddings are not all close"
|
||||
|
||||
def test_prefill_logits(self):
|
||||
models_to_test = MODELS
|
||||
|
||||
if is_in_ci():
|
||||
models_to_test = [random.choice(MODELS)]
|
||||
|
||||
for model, tp_size, prefill_tolerance in models_to_test:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
self.assert_close_prefill_logits(
|
||||
DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
162
test/srt/models/test_encoder_embedding_models.py
Normal file
162
test/srt/models/test_encoder_embedding_models.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# python -m unittest test_encoder_embedding_models.TestEncoderEmbeddingModels.test_prefill_logits
|
||||
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
|
||||
|
||||
MODELS = [("BAAI/bge-small-en", 1, 1e-5), ("BAAI/bge-m3", 1, 1e-5)]
|
||||
|
||||
ATTENTION_BACKEND = ["torch_native", "triton", "flashinfer"]
|
||||
BATCH_SIZE = [1, 2]
|
||||
TORCH_DTYPES = [torch.float32, torch.float16]
|
||||
sgl_to_st_ratio = []
|
||||
|
||||
|
||||
class TestEncoderEmbeddingModels(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def _truncate_prompts(self, prompts, model_path):
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
max_length = getattr(config, "max_position_embeddings", 512) - 20
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
truncated_prompts = []
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, return_tensors="pt", truncation=False)
|
||||
if len(tokens.input_ids[0]) > max_length:
|
||||
truncated_text = tokenizer.decode(
|
||||
tokens.input_ids[0][: max_length - 1], skip_special_tokens=True
|
||||
)
|
||||
truncated_prompts.append(truncated_text)
|
||||
else:
|
||||
truncated_prompts.append(prompt)
|
||||
|
||||
return truncated_prompts
|
||||
|
||||
def assert_close_prefill_logits(
|
||||
self,
|
||||
prompts,
|
||||
model_path,
|
||||
tp_size,
|
||||
torch_dtype,
|
||||
prefill_tolerance,
|
||||
attention_backend,
|
||||
batch_size,
|
||||
) -> None:
|
||||
truncated_prompts = self._truncate_prompts(prompts, model_path)
|
||||
truncated_prompts = truncated_prompts * batch_size
|
||||
|
||||
with HFRunner(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="embedding",
|
||||
) as hf_runner:
|
||||
# warm up
|
||||
hf_outputs = hf_runner.forward(truncated_prompts)
|
||||
|
||||
st_start_time = time.perf_counter()
|
||||
hf_outputs = hf_runner.forward(truncated_prompts)
|
||||
st_end_time = time.perf_counter()
|
||||
|
||||
with SRTRunner(
|
||||
model_path,
|
||||
tp_size=tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="embedding",
|
||||
attention_backend=attention_backend,
|
||||
chunked_prefill_size=-1,
|
||||
disable_radix_cache=True,
|
||||
) as srt_runner:
|
||||
# warm up
|
||||
srt_outputs = srt_runner.forward(truncated_prompts)
|
||||
|
||||
sgl_start_time = time.perf_counter()
|
||||
srt_outputs = srt_runner.forward(truncated_prompts)
|
||||
sgl_end_time = time.perf_counter()
|
||||
|
||||
transformer_time = st_end_time - st_start_time
|
||||
sgl_time = sgl_end_time - sgl_start_time
|
||||
sgl_to_st_ratio.append(sgl_time / transformer_time)
|
||||
|
||||
for i in range(len(truncated_prompts)):
|
||||
hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
|
||||
srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
|
||||
|
||||
similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
|
||||
# If something is wrong, uncomment this to observe similarity.
|
||||
# print("similarity diff", abs(similarity - 1))
|
||||
|
||||
if len(truncated_prompts[i]) <= 1000:
|
||||
assert torch.all(
|
||||
abs(similarity - 1) < prefill_tolerance
|
||||
), "embeddings are not all close"
|
||||
|
||||
def test_prefill_logits(self):
|
||||
models_to_test = MODELS
|
||||
|
||||
if is_in_ci():
|
||||
models_to_test = [random.choice(MODELS)]
|
||||
|
||||
for model, tp_size, prefill_tolerance in models_to_test:
|
||||
for attention_backend in ATTENTION_BACKEND:
|
||||
for batch_size in BATCH_SIZE:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
# NOTE: FlashInfer currently has limitations with head_dim = 32 or
|
||||
# other dimensions.
|
||||
# The FlashInfer head_dim limitation itself is tracked here:
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/1048
|
||||
#
|
||||
# Flashinfer does not support torch.float32 for dtype_q, so skip it
|
||||
if attention_backend == "flashinfer":
|
||||
if (
|
||||
model == "BAAI/bge-small-en"
|
||||
or torch_dtype == torch.float32
|
||||
):
|
||||
continue
|
||||
|
||||
self.assert_close_prefill_logits(
|
||||
DEFAULT_PROMPTS,
|
||||
model,
|
||||
tp_size,
|
||||
torch_dtype,
|
||||
prefill_tolerance,
|
||||
attention_backend,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
for i in range(len(BATCH_SIZE)):
|
||||
print(
|
||||
"bacth size: ",
|
||||
BATCH_SIZE[i] * 5,
|
||||
"sgl_time/st_time",
|
||||
round(sgl_to_st_ratio[i], 3),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
180
test/srt/models/test_generation_models.py
Normal file
180
test/srt/models/test_generation_models.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Usage:
|
||||
|
||||
To test a specific model locally:
|
||||
1. Add it to ALL_MODELS, for example, `ModelCase("Qwen/Qwen2-1.5B")`
|
||||
2. Run `ONLY_RUN=Qwen/Qwen2-1.5B python3 -m unittest test_generation_models.TestGenerationModels`
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import random
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.test.runners import (
|
||||
DEFAULT_PROMPTS,
|
||||
HFRunner,
|
||||
SRTRunner,
|
||||
check_close_model_outputs,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase, is_in_ci
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelCase:
|
||||
model_path: str
|
||||
tp_size: int = 1
|
||||
prefill_tolerance: float = 5e-2
|
||||
decode_tolerance: float = 6e-2 # Increased to fix numerical error in issue #8614.
|
||||
rouge_l_tolerance: float = 1
|
||||
skip_long_prompt: bool = False
|
||||
trust_remote_code: bool = False
|
||||
|
||||
|
||||
# Popular models that run on the CI
|
||||
CI_MODELS = [
|
||||
ModelCase("meta-llama/Llama-3.1-8B-Instruct"),
|
||||
ModelCase("google/gemma-2-2b"),
|
||||
]
|
||||
|
||||
# the complete set of models to test sglang's generation model
|
||||
ALL_MODELS = [
|
||||
*CI_MODELS,
|
||||
ModelCase("Qwen/Qwen2-1.5B"),
|
||||
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
|
||||
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
|
||||
ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
|
||||
ModelCase(
|
||||
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
|
||||
),
|
||||
ModelCase("openai-community/gpt2"),
|
||||
ModelCase("microsoft/phi-1_5", trust_remote_code=True),
|
||||
ModelCase("adept/persimmon-8b-chat"),
|
||||
ModelCase("inclusionAI/Ling-lite", trust_remote_code=True),
|
||||
ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
|
||||
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
|
||||
ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
|
||||
ModelCase(
|
||||
"microsoft/Phi-3.5-MoE-instruct",
|
||||
tp_size=2,
|
||||
trust_remote_code=True,
|
||||
skip_long_prompt=True,
|
||||
),
|
||||
ModelCase(
|
||||
"nvidia/Llama-3_3-Nemotron-Super-49B-v1_5",
|
||||
tp_size=2,
|
||||
trust_remote_code=True,
|
||||
skip_long_prompt=True,
|
||||
),
|
||||
ModelCase(
|
||||
"nvidia/Llama-3_1-Nemotron-Ultra-253B-v1",
|
||||
tp_size=8,
|
||||
trust_remote_code=True,
|
||||
skip_long_prompt=True,
|
||||
),
|
||||
]
|
||||
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
|
||||
|
||||
class TestGenerationModels(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def assert_close_logits_and_output_strs(
|
||||
self,
|
||||
prompts: List[str],
|
||||
model_case: ModelCase,
|
||||
torch_dtype: torch.dtype,
|
||||
) -> None:
|
||||
model_path = model_case.model_path
|
||||
prefill_tolerance, decode_tolerance, rouge_l_tolerance = (
|
||||
model_case.prefill_tolerance,
|
||||
model_case.decode_tolerance,
|
||||
model_case.rouge_l_tolerance,
|
||||
)
|
||||
max_new_tokens = 32
|
||||
|
||||
with HFRunner(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
trust_remote_code=model_case.trust_remote_code,
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||
|
||||
with SRTRunner(
|
||||
model_path,
|
||||
tp_size=model_case.tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="generation",
|
||||
trust_remote_code=model_case.trust_remote_code,
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||
|
||||
check_close_model_outputs(
|
||||
hf_outputs=hf_outputs,
|
||||
srt_outputs=srt_outputs,
|
||||
prefill_tolerance=model_case.prefill_tolerance,
|
||||
decode_tolerance=model_case.decode_tolerance,
|
||||
rouge_l_tolerance=model_case.rouge_l_tolerance,
|
||||
debug_text=f"model_path={model_path} prompts={prompts}",
|
||||
)
|
||||
|
||||
@unittest.skipIf(not is_in_ci(), "Local test should run all models")
|
||||
def test_ci_models(self):
|
||||
for model_case in CI_MODELS:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
prompts = DEFAULT_PROMPTS
|
||||
|
||||
# Skip long prompts for models that do not have a long context
|
||||
if model_case.skip_long_prompt:
|
||||
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||
|
||||
# Assert the logits and output strs are close
|
||||
self.assert_close_logits_and_output_strs(
|
||||
prompts, model_case, torch_dtype
|
||||
)
|
||||
|
||||
@unittest.skipIf(is_in_ci(), "CI only runs selected models for simplicity")
|
||||
def test_all_models(self):
|
||||
for model_case in ALL_MODELS:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
if (
|
||||
"ONLY_RUN" in os.environ
|
||||
and os.environ["ONLY_RUN"] != model_case.model_path
|
||||
):
|
||||
continue
|
||||
|
||||
# Skip long prompts for models that do not have a long context
|
||||
prompts = DEFAULT_PROMPTS
|
||||
if model_case.skip_long_prompt:
|
||||
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||
|
||||
# Assert the logits and output strs are close
|
||||
self.assert_close_logits_and_output_strs(
|
||||
prompts, model_case, torch_dtype
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
85
test/srt/models/test_gme_qwen_models.py
Normal file
85
test/srt/models/test_gme_qwen_models.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
import multiprocessing as mp
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.test.runners import HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase, get_similarities
|
||||
|
||||
TEXTS = "two Subway Series sandwiches with meats, cheese, lettuce, tomatoes, and onions on a black background, accompanied by the Subway Series logo, highlighting a new sandwich series."
|
||||
IMAGES = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"
|
||||
|
||||
|
||||
MODELS = [
|
||||
("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct", 1e-3),
|
||||
]
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
|
||||
|
||||
class TestQmeQwenModels(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def assert_close_embeddings(self, model, prefill_tolerance, torch_dtype):
|
||||
|
||||
prompts_no_image = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{TEXTS}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
|
||||
prompts_with_image = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
|
||||
with HFRunner(
|
||||
model,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="embedding",
|
||||
) as hf_runner:
|
||||
hf_text_embeddings = hf_runner.forward(prompts=[prompts_no_image])
|
||||
hf_image_embeddings = hf_runner.forward(
|
||||
prompts=[prompts_with_image], image_data=[IMAGES]
|
||||
)
|
||||
with SRTRunner(
|
||||
model,
|
||||
tp_size=1,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="embedding",
|
||||
) as srt_runner:
|
||||
srt_text_embeddings = srt_runner.forward(prompts=prompts_no_image)
|
||||
srt_image_embeddings = srt_runner.forward(
|
||||
prompts=prompts_with_image, image_data=IMAGES
|
||||
)
|
||||
|
||||
similarity = get_similarities(
|
||||
hf_text_embeddings.embed_logits[0], srt_text_embeddings.embed_logits[0]
|
||||
)
|
||||
print("texts similarity diff", abs(similarity - 1))
|
||||
assert torch.all(
|
||||
abs(similarity - 1) < prefill_tolerance
|
||||
), "embeddings are not all close"
|
||||
similarity = get_similarities(
|
||||
hf_image_embeddings.embed_logits[0], srt_image_embeddings.embed_logits[0]
|
||||
)
|
||||
print("images similarity diff", abs(similarity - 1))
|
||||
assert torch.all(
|
||||
abs(similarity - 1) < prefill_tolerance
|
||||
), "embeddings are not all close"
|
||||
|
||||
def test_accuracy(self):
|
||||
for model, prefill_tolerance in MODELS:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
self.assert_close_embeddings(model, prefill_tolerance, torch_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
53
test/srt/models/test_grok_models.py
Normal file
53
test/srt/models/test_grok_models.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestGrok(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "lmzheng/grok-1"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--load-format",
|
||||
"dummy",
|
||||
"--json-model-override-args",
|
||||
'{"num_hidden_layers": 2}',
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=64,
|
||||
max_new_tokens=256,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
|
||||
# It is dummy weights so we only assert the output throughput instead of accuracy.
|
||||
self.assertGreater(metrics["output_throughput"], 1000)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
74
test/srt/models/test_llama4_models.py
Normal file
74
test/srt/models/test_llama4_models.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import random
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
MODELS = [
|
||||
SimpleNamespace(
|
||||
model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
accuracy=0.9,
|
||||
tp_size=4,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TestLlama4(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
|
||||
def test_gsm8k(self):
|
||||
|
||||
for model in MODELS:
|
||||
try:
|
||||
process = popen_launch_server(
|
||||
model.model,
|
||||
self.base_url,
|
||||
timeout=3 * DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--chat-template",
|
||||
"llama-4",
|
||||
"--tp-size",
|
||||
str(model.tp_size),
|
||||
"--mem-fraction-static",
|
||||
"0.8",
|
||||
"--context-length",
|
||||
"8192",
|
||||
],
|
||||
)
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreaterEqual(metrics["accuracy"], model.accuracy)
|
||||
except Exception as e:
|
||||
print(f"Error testing {model.model}: {e}")
|
||||
self.fail(f"Test failed for {model.model}: {e}")
|
||||
|
||||
finally:
|
||||
# Ensure process cleanup happens regardless of success/failure
|
||||
if process is not None and process.poll() is None:
|
||||
print(f"Cleaning up process {process.pid}")
|
||||
try:
|
||||
kill_process_tree(process.pid)
|
||||
except Exception as e:
|
||||
print(f"Error killing process: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
58
test/srt/models/test_mtp_models.py
Normal file
58
test/srt/models/test_mtp_models.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestMiMoMTP(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "XiaomiMiMo/MiMo-7B-RL"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--speculative-algorithm",
|
||||
"EAGLE",
|
||||
"--speculative-num-steps",
|
||||
"1",
|
||||
"--speculative-eagle-topk",
|
||||
"1",
|
||||
"--speculative-num-draft-tokens",
|
||||
"2",
|
||||
"--mem-fraction-static",
|
||||
"0.5",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.7)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
77
test/srt/models/test_qwen_models.py
Normal file
77
test/srt/models/test_qwen_models.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestQwen2(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "Qwen/Qwen2-7B-Instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.78)
|
||||
|
||||
|
||||
class TestQwen2FP8(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "neuralmagic/Qwen2-7B-Instruct-FP8"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.78)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
92
test/srt/models/test_reward_models.py
Normal file
92
test/srt/models/test_reward_models.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import multiprocessing as mp
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.test.runners import HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
MODELS = [
|
||||
("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2),
|
||||
("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 4e-2),
|
||||
]
|
||||
TORCH_DTYPES = [torch.float16]
|
||||
|
||||
# PROMPT = "Jane has 12 apples. She gives 4 apples to her friend Mark, then buys 1 more apple, and finally splits all her apples equally among herself and her 2 siblings. How many apples does each person get?"
|
||||
# RESPONSE1 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among herself and her 2 siblings (3 people in total). 9 ÷ 3 = 3 apples each. Each person gets 3 apples."
|
||||
# RESPONSE2 = "1. Jane starts with 12 apples and gives 4 to Mark. 12 - 4 = 8. Jane now has 8 apples.\n2. Jane buys 1 more apple. 8 + 1 = 9. Jane now has 9 apples.\n3. Jane splits the 9 apples equally among her 2 siblings (2 people in total). 9 ÷ 2 = 4.5 apples each. Each person gets 4 apples."
|
||||
|
||||
PROMPT = (
|
||||
"What is the range of the numeric output of a sigmoid node in a neural network?"
|
||||
)
|
||||
RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
|
||||
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
|
||||
|
||||
CONVS = [
|
||||
[{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE1}],
|
||||
[{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE2}],
|
||||
]
|
||||
|
||||
|
||||
class TestRewardModels(CustomTestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def assert_close_reward_scores(
|
||||
self,
|
||||
convs,
|
||||
model_path,
|
||||
tp_size,
|
||||
torch_dtype,
|
||||
tolerance,
|
||||
) -> None:
|
||||
with HFRunner(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="reward",
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(convs)
|
||||
|
||||
with SRTRunner(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="reward",
|
||||
) as srt_runner:
|
||||
prompts = srt_runner.tokenizer.apply_chat_template(convs, tokenize=False)
|
||||
srt_outputs = srt_runner.forward(prompts)
|
||||
|
||||
hf_scores = torch.tensor(hf_outputs.scores)
|
||||
srt_scores = torch.tensor(srt_outputs.scores)
|
||||
print(f"{hf_scores=}")
|
||||
print(f"{srt_scores=}")
|
||||
|
||||
assert torch.all(
|
||||
abs(hf_scores - srt_scores) < tolerance
|
||||
), "reward scores are not all close"
|
||||
|
||||
def test_reward_scores(self):
|
||||
for model, tp_size, tolerance in MODELS:
|
||||
for torch_dtype in TORCH_DTYPES:
|
||||
self.assert_close_reward_scores(
|
||||
CONVS, model, tp_size, torch_dtype, tolerance
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
181
test/srt/models/test_transformers_models.py
Normal file
181
test/srt/models/test_transformers_models.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import dataclasses
|
||||
import multiprocessing as mp
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner, check_close_model_outputs
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestTransformersFallbackEndpoint(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--model-impl", "transformers"],
|
||||
)
|
||||
cls.mmlu_lower_bound = 0.65
|
||||
cls.gsm8k_lower_bound = 0.65
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
model=self.model,
|
||||
eval_name="mmlu",
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
from sglang.test.run_eval import run_eval
|
||||
|
||||
metrics = run_eval(args)
|
||||
self.assertGreaterEqual(metrics["score"], self.mmlu_lower_bound)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
from sglang.test.few_shot_gsm8k import run_eval
|
||||
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], self.gsm8k_lower_bound)
|
||||
|
||||
|
||||
class TestTransformersFallbackTorchAO(TestTransformersFallbackEndpoint):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--model-impl",
|
||||
"transformers",
|
||||
"--torchao-config",
|
||||
"int4wo-128",
|
||||
],
|
||||
)
|
||||
cls.mmlu_lower_bound = 0.65
|
||||
cls.gsm8k_lower_bound = 0.65
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelCase:
|
||||
model_path: str
|
||||
tp_size: int = 1
|
||||
prefill_tolerance: float = 5e-2
|
||||
decode_tolerance: float = 5e-2
|
||||
rouge_l_tolerance: float = 1
|
||||
skip_long_prompt: bool = False
|
||||
trust_remote_code: bool = False
|
||||
torchao_config: str = None
|
||||
torch_dtype: torch.dtype = torch.float16
|
||||
|
||||
|
||||
# Popular models that run on the CI
|
||||
CI_MODELS = [
|
||||
ModelCase(DEFAULT_MODEL_NAME_FOR_TEST),
|
||||
]
|
||||
|
||||
ALL_OTHER_MODELS = [
|
||||
ModelCase(DEFAULT_MODEL_NAME_FOR_TEST, tp_size=2),
|
||||
]
|
||||
|
||||
|
||||
class TestTransformersFallbackEngine(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
def assert_close_logits_and_output_strs(
|
||||
self,
|
||||
prompts: List[str],
|
||||
model_case: ModelCase,
|
||||
) -> None:
|
||||
model_path = model_case.model_path
|
||||
max_new_tokens = 32
|
||||
# force to use transformers impl
|
||||
with SRTRunner(
|
||||
model_path,
|
||||
tp_size=model_case.tp_size,
|
||||
torch_dtype=model_case.torch_dtype,
|
||||
model_type="generation",
|
||||
model_impl="transformers",
|
||||
trust_remote_code=model_case.trust_remote_code,
|
||||
torchao_config=model_case.torchao_config,
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
|
||||
|
||||
with SRTRunner(
|
||||
model_path,
|
||||
tp_size=model_case.tp_size,
|
||||
torch_dtype=model_case.torch_dtype,
|
||||
model_type="generation",
|
||||
trust_remote_code=model_case.trust_remote_code,
|
||||
torchao_config=model_case.torchao_config,
|
||||
) as srt_runner:
|
||||
srt_transformers_outputs = srt_runner.forward(
|
||||
prompts, max_new_tokens=max_new_tokens
|
||||
)
|
||||
|
||||
check_close_model_outputs(
|
||||
hf_outputs=srt_transformers_outputs,
|
||||
srt_outputs=srt_outputs,
|
||||
prefill_tolerance=model_case.prefill_tolerance,
|
||||
decode_tolerance=model_case.decode_tolerance,
|
||||
rouge_l_tolerance=model_case.rouge_l_tolerance,
|
||||
debug_text=f"model_path={model_path} prompts={prompts}",
|
||||
)
|
||||
|
||||
def test_ci_models(self):
|
||||
for model_case in CI_MODELS:
|
||||
# Skip long prompts for models that do not have a long context
|
||||
prompts = DEFAULT_PROMPTS
|
||||
if model_case.skip_long_prompt:
|
||||
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||
# Assert the logits and output strs are close
|
||||
self.assert_close_logits_and_output_strs(prompts, model_case)
|
||||
|
||||
def test_others(self):
|
||||
if is_in_ci():
|
||||
return
|
||||
|
||||
# Skip long prompts for models that do not have a long context
|
||||
prompts = DEFAULT_PROMPTS
|
||||
for model_case in ALL_OTHER_MODELS:
|
||||
if model_case.skip_long_prompt:
|
||||
prompts = [p for p in DEFAULT_PROMPTS if len(p) < 1000]
|
||||
|
||||
# Assert the logits and output strs are close
|
||||
self.assert_close_logits_and_output_strs(prompts, model_case)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
213
test/srt/models/test_unsloth_models.py
Normal file
213
test/srt/models/test_unsloth_models.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.few_shot_gsm8k import run_eval
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestUnslothPhi4(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "unsloth/phi-4"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.78)
|
||||
|
||||
|
||||
class TestUnslothPhi4Bnb4bit(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "unsloth/phi-4-bnb-4bit"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--load-format",
|
||||
"bitsandbytes",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.75)
|
||||
|
||||
|
||||
class TestUnslothPhi4UnslothBnb4bit(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "unsloth/phi-4-unsloth-bnb-4bit"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--load-format",
|
||||
"bitsandbytes",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.75)
|
||||
|
||||
|
||||
class TestUnslothPhi4MiniInstruct(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "unsloth/Phi-4-mini-instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.65)
|
||||
|
||||
|
||||
class TestUnslothPhi4MiniBnb4bit(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "unsloth/Phi-4-mini-instruct-bnb-4bit"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--load-format",
|
||||
"bitsandbytes",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.6)
|
||||
|
||||
|
||||
class TestUnslothPhi4MiniUnslothBnb4bit(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--load-format",
|
||||
"bitsandbytes",
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
315
test/srt/models/test_vlm_models.py
Normal file
315
test/srt/models/test_vlm_models.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
# VLM models for testing
|
||||
MODELS = [
|
||||
SimpleNamespace(model="google/gemma-3-27b-it", mmmu_accuracy=0.45),
|
||||
SimpleNamespace(
|
||||
model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
mmmu_accuracy=0.4,
|
||||
),
|
||||
SimpleNamespace(model="openbmb/MiniCPM-V-2_6", mmmu_accuracy=0.4),
|
||||
]
|
||||
|
||||
|
||||
class TestVLMModels(CustomTestCase):
|
||||
parsed_args = None # Class variable to store args
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Removed argument parsing from here
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||
|
||||
# Set OpenAI API key and base URL environment variables. Needed for lmm-evals to work.
|
||||
os.environ["OPENAI_API_KEY"] = cls.api_key
|
||||
os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"
|
||||
|
||||
def _detect_eviction_in_logs(self, log_output):
|
||||
"""Detect if eviction events occurred in the log output."""
|
||||
eviction_keywords = ["Cache eviction: evicted"]
|
||||
|
||||
eviction_detected = False
|
||||
eviction_count = 0
|
||||
|
||||
for line in log_output.split("\n"):
|
||||
if any(keyword in line for keyword in eviction_keywords):
|
||||
eviction_detected = True
|
||||
eviction_count += 1
|
||||
print(f"Eviction detected: {line.strip()}")
|
||||
|
||||
return eviction_detected, eviction_count
|
||||
|
||||
def run_mmmu_eval(
|
||||
self,
|
||||
model_version: str,
|
||||
output_path: str,
|
||||
*,
|
||||
env: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a VLM on the MMMU validation set with lmms‑eval.
|
||||
Only `model_version` (checkpoint) and `chat_template` vary;
|
||||
We are focusing only on the validation set due to resource constraints.
|
||||
"""
|
||||
# -------- fixed settings --------
|
||||
model = "openai_compatible"
|
||||
tp = 1
|
||||
tasks = "mmmu_val"
|
||||
batch_size = 2
|
||||
log_suffix = "openai_compatible"
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
# -------- compose --model_args --------
|
||||
model_args = f'model_version="{model_version}",' f"tp={tp}"
|
||||
|
||||
# -------- build command list --------
|
||||
cmd = [
|
||||
"python3",
|
||||
"-m",
|
||||
"lmms_eval",
|
||||
"--model",
|
||||
model,
|
||||
"--model_args",
|
||||
model_args,
|
||||
"--tasks",
|
||||
tasks,
|
||||
"--batch_size",
|
||||
str(batch_size),
|
||||
"--log_samples",
|
||||
"--log_samples_suffix",
|
||||
log_suffix,
|
||||
"--output_path",
|
||||
str(output_path),
|
||||
]
|
||||
|
||||
subprocess.run(
|
||||
cmd,
|
||||
check=True,
|
||||
timeout=3600,
|
||||
)
|
||||
|
||||
def _run_vlm_mmmu_test(
|
||||
self,
|
||||
model,
|
||||
output_path,
|
||||
test_name="",
|
||||
custom_env=None,
|
||||
log_level="info",
|
||||
capture_output=False,
|
||||
):
|
||||
"""
|
||||
Common method to run VLM MMMU benchmark test.
|
||||
|
||||
Args:
|
||||
model: Model to test
|
||||
output_path: Path for output logs
|
||||
test_name: Optional test name for logging
|
||||
custom_env: Optional custom environment variables
|
||||
log_level: Log level for server (default: "info")
|
||||
capture_output: Whether to capture server stdout/stderr
|
||||
"""
|
||||
print(f"\nTesting model: {model.model}{test_name}")
|
||||
|
||||
process = None
|
||||
mmmu_accuracy = 0 # Initialize to handle potential exceptions
|
||||
server_output = ""
|
||||
|
||||
try:
|
||||
# Prepare environment variables
|
||||
process_env = os.environ.copy()
|
||||
if custom_env:
|
||||
process_env.update(custom_env)
|
||||
|
||||
# Prepare stdout/stderr redirection if needed
|
||||
stdout_file = None
|
||||
stderr_file = None
|
||||
if capture_output:
|
||||
stdout_file = open("/tmp/server_stdout.log", "w")
|
||||
stderr_file = open("/tmp/server_stderr.log", "w")
|
||||
|
||||
# Launch server for testing
|
||||
process = popen_launch_server(
|
||||
model.model,
|
||||
base_url=self.base_url,
|
||||
timeout=self.time_out,
|
||||
api_key=self.api_key,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--cuda-graph-max-bs",
|
||||
"32",
|
||||
"--enable-multimodal",
|
||||
"--mem-fraction-static",
|
||||
str(self.parsed_args.mem_fraction_static), # Use class variable
|
||||
"--log-level",
|
||||
log_level,
|
||||
],
|
||||
env=process_env,
|
||||
return_stdout_stderr=(
|
||||
(stdout_file, stderr_file) if capture_output else None
|
||||
),
|
||||
)
|
||||
|
||||
# Run evaluation
|
||||
self.run_mmmu_eval(model.model, output_path)
|
||||
|
||||
# Get the result file
|
||||
result_file_path = glob.glob(f"{output_path}/*.json")[0]
|
||||
|
||||
with open(result_file_path, "r") as f:
|
||||
result = json.load(f)
|
||||
print(f"Result{test_name}\n: {result}")
|
||||
|
||||
# Process the result
|
||||
mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
|
||||
print(
|
||||
f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}"
|
||||
)
|
||||
|
||||
# Capture server output if requested
|
||||
if capture_output and process:
|
||||
server_output = self._read_output_from_files()
|
||||
|
||||
# Assert performance meets expected threshold
|
||||
self.assertGreaterEqual(
|
||||
mmmu_accuracy,
|
||||
model.mmmu_accuracy,
|
||||
f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}",
|
||||
)
|
||||
|
||||
return server_output
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error testing {model.model}{test_name}: {e}")
|
||||
self.fail(f"Test failed for {model.model}{test_name}: {e}")
|
||||
|
||||
finally:
|
||||
# Ensure process cleanup happens regardless of success/failure
|
||||
if process is not None and process.poll() is None:
|
||||
print(f"Cleaning up process {process.pid}")
|
||||
try:
|
||||
kill_process_tree(process.pid)
|
||||
except Exception as e:
|
||||
print(f"Error killing process: {e}")
|
||||
|
||||
# clean up temporary files
|
||||
if capture_output:
|
||||
if stdout_file:
|
||||
stdout_file.close()
|
||||
if stderr_file:
|
||||
stderr_file.close()
|
||||
for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]:
|
||||
try:
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
except Exception as e:
|
||||
print(f"Error removing {filename}: {e}")
|
||||
|
||||
def _read_output_from_files(self):
|
||||
output_lines = []
|
||||
|
||||
log_files = [
|
||||
("/tmp/server_stdout.log", "[STDOUT]"),
|
||||
("/tmp/server_stderr.log", "[STDERR]"),
|
||||
]
|
||||
for filename, tag in log_files:
|
||||
try:
|
||||
if os.path.exists(filename):
|
||||
with open(filename, "r") as f:
|
||||
for line in f:
|
||||
output_lines.append(f"{tag} {line.rstrip()}")
|
||||
except Exception as e:
|
||||
print(f"Error reading {tag.lower()} file: {e}")
|
||||
|
||||
return "\n".join(output_lines)
|
||||
|
||||
def test_vlm_mmmu_benchmark(self):
|
||||
"""Test VLM models against MMMU benchmark."""
|
||||
models_to_test = MODELS
|
||||
|
||||
if is_in_ci():
|
||||
models_to_test = [random.choice(MODELS)]
|
||||
|
||||
for model in models_to_test:
|
||||
self._run_vlm_mmmu_test(model, "./logs")
|
||||
|
||||
def test_vlm_mmmu_benchmark_with_small_cache(self):
|
||||
"""Test VLM models against MMMU benchmark with a small embedding cache to force eviction."""
|
||||
models_to_test = MODELS
|
||||
|
||||
if is_in_ci():
|
||||
models_to_test = [random.choice(MODELS)]
|
||||
|
||||
for model in models_to_test:
|
||||
custom_env = {"SGLANG_VLM_CACHE_SIZE_MB": "5"}
|
||||
|
||||
# Run the test with output capture
|
||||
server_output = self._run_vlm_mmmu_test(
|
||||
model,
|
||||
"./logs_small_cache",
|
||||
test_name=" with small embedding cache (evict test)",
|
||||
custom_env=custom_env,
|
||||
log_level="debug", # Enable debug logging for eviction detection
|
||||
capture_output=True, # Capture server output
|
||||
)
|
||||
|
||||
# Print server output for debugging
|
||||
print("Server output:\n", server_output)
|
||||
|
||||
# Analyze server output for eviction events
|
||||
eviction_detected, eviction_count = self._detect_eviction_in_logs(
|
||||
server_output
|
||||
)
|
||||
|
||||
# Assert that eviction was detected (since we're using small cache)
|
||||
self.assertTrue(
|
||||
eviction_detected,
|
||||
f"Expected eviction events to be detected with small cache (5MB), but none found. "
|
||||
f"Cache size may be too large for the workload or eviction logic may not be working. "
|
||||
f"Total log content length: {len(server_output)} characters",
|
||||
)
|
||||
|
||||
print(
|
||||
f"Eviction detection summary: {eviction_count} eviction events detected"
|
||||
)
|
||||
|
||||
# Additional assertion: if eviction was detected, the test passed
|
||||
if eviction_detected:
|
||||
print("✅ Eviction logic successfully triggered and detected!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Define and parse arguments here, before unittest.main
|
||||
parser = argparse.ArgumentParser(description="Test VLM models")
|
||||
parser.add_argument(
|
||||
"--mem-fraction-static",
|
||||
type=float,
|
||||
help="Static memory fraction for the model",
|
||||
default=0.8,
|
||||
)
|
||||
|
||||
# Parse args intended for unittest
|
||||
args = parser.parse_args()
|
||||
|
||||
# Store the parsed args object on the class
|
||||
TestVLMModels.parsed_args = args
|
||||
|
||||
# Pass args to unittest
|
||||
unittest.main(argv=[sys.argv[0]])
|
||||
0
test/srt/openai_server/__init__.py
Normal file
0
test/srt/openai_server/__init__.py
Normal file
0
test/srt/openai_server/basic/__init__.py
Normal file
0
test/srt/openai_server/basic/__init__.py
Normal file
97
test/srt/openai_server/basic/test_openai_embedding.py
Normal file
97
test/srt/openai_server/basic/test_openai_embedding.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIEmbedding(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
# Configure embedding-specific args
|
||||
other_args = ["--is-embedding", "--enable-metrics"]
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=other_args,
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_embedding_single(self):
|
||||
"""Test single embedding request"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(model=self.model, input="Hello world")
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_embedding_batch(self):
|
||||
"""Test batch embedding request"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model, input=["Hello world", "Test text"]
|
||||
)
|
||||
self.assertEqual(len(response.data), 2)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
self.assertTrue(len(response.data[1].embedding) > 0)
|
||||
|
||||
def test_embedding_single_batch_str(self):
|
||||
"""Test embedding with a List[str] and length equals to 1"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(model=self.model, input=["Hello world"])
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_embedding_single_int_list(self):
|
||||
"""Test embedding with a List[int] or List[List[int]]]"""
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model,
|
||||
input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
|
||||
)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.embeddings.create(
|
||||
model=self.model,
|
||||
input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
|
||||
)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
||||
|
||||
def test_empty_string_embedding(self):
|
||||
"""Test embedding an empty string."""
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
# Text embedding example with empty string
|
||||
text = ""
|
||||
# Expect a BadRequestError for empty input
|
||||
with self.assertRaises(openai.BadRequestError) as cm:
|
||||
client.embeddings.create(
|
||||
model=self.model,
|
||||
input=text,
|
||||
)
|
||||
# check the status code
|
||||
self.assertEqual(cm.exception.status_code, 400)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
669
test/srt/openai_server/basic/test_openai_server.py
Normal file
669
test/srt/openai_server/basic/test_openai_server.py
Normal file
@@ -0,0 +1,669 @@
|
||||
"""
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion_stream
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion
|
||||
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion_stream
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenAIServer(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_completion(
|
||||
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
if token_input:
|
||||
prompt_input = self.tokenizer.encode(prompt)
|
||||
num_prompt_tokens = len(prompt_input)
|
||||
else:
|
||||
prompt_input = prompt
|
||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
num_prompt_tokens *= 2
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
response = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
echo=echo,
|
||||
logprobs=logprobs,
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
assert len(response.choices) == num_choices * parallel_sample_num
|
||||
|
||||
if echo:
|
||||
text = response.choices[0].text
|
||||
assert text.startswith(prompt)
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs
|
||||
assert isinstance(response.choices[0].logprobs.tokens[0], str)
|
||||
assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
|
||||
ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
|
||||
|
||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
assert ret_num_top_logprobs > 0
|
||||
|
||||
# when echo=True and request.logprobs>0, logprob_start_len is 0, so the first token's logprob would be None.
|
||||
if not echo:
|
||||
assert response.choices[0].logprobs.token_logprobs[0]
|
||||
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert (
|
||||
response.usage.prompt_tokens == num_prompt_tokens
|
||||
), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
|
||||
assert response.usage.completion_tokens > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def run_completion_stream(
|
||||
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
if token_input:
|
||||
prompt_input = self.tokenizer.encode(prompt)
|
||||
num_prompt_tokens = len(prompt_input)
|
||||
else:
|
||||
prompt_input = prompt
|
||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
num_prompt_tokens *= 2
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
generator = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
echo=echo,
|
||||
logprobs=logprobs,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
is_first = is_firsts.get(index, True)
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs, f"no logprobs in response"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.tokens[0], str
|
||||
), f"{response.choices[0].logprobs.tokens[0]} is not a string"
|
||||
if not (is_first and echo):
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.top_logprobs[0], dict
|
||||
), f"top_logprobs was not a dictionary"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.top_logprobs[0]
|
||||
)
|
||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
|
||||
|
||||
if is_first:
|
||||
if echo:
|
||||
assert response.choices[0].text.startswith(
|
||||
prompt
|
||||
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
|
||||
is_firsts[index] = False
|
||||
assert response.id, f"no id in response"
|
||||
assert response.created, f"no created in response"
|
||||
|
||||
for index in [i for i in range(parallel_sample_num * num_choices)]:
|
||||
assert not is_firsts.get(
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
def run_chat_completion(self, logprobs, parallel_sample_num):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France? Answer in a few words.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
if logprobs:
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
)
|
||||
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.content[0].top_logprobs
|
||||
)
|
||||
assert (
|
||||
ret_num_top_logprobs == logprobs
|
||||
), f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
|
||||
assert len(response.choices) == parallel_sample_num
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert isinstance(response.choices[0].message.content, str)
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.usage.prompt_tokens > 0
|
||||
assert response.usage.completion_tokens > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
generator = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
],
|
||||
temperature=0,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
is_firsts = {}
|
||||
is_finished = {}
|
||||
finish_reason_counts = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
finish_reason = response.choices[0].finish_reason
|
||||
if finish_reason is not None:
|
||||
is_finished[index] = True
|
||||
finish_reason_counts[index] = finish_reason_counts.get(index, 0) + 1
|
||||
|
||||
data = response.choices[0].delta
|
||||
|
||||
if is_firsts.get(index, True):
|
||||
assert (
|
||||
data.role == "assistant"
|
||||
), f"data.role was not 'assistant' for first chunk"
|
||||
is_firsts[index] = False
|
||||
continue
|
||||
|
||||
if logprobs and not is_finished.get(index, False):
|
||||
assert response.choices[0].logprobs, f"logprobs was not returned"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
), f"top_logprobs token was not a string"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs, list
|
||||
), f"top_logprobs was not a list"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.content[0].top_logprobs
|
||||
)
|
||||
assert (
|
||||
ret_num_top_logprobs == logprobs
|
||||
), f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
|
||||
assert (
|
||||
isinstance(data.content, str)
|
||||
or isinstance(data.reasoning_content, str)
|
||||
or (isinstance(data.tool_calls, list) and len(data.tool_calls) > 0)
|
||||
or response.choices[0].finish_reason
|
||||
)
|
||||
assert response.id
|
||||
assert response.created
|
||||
|
||||
for index in [i for i in range(parallel_sample_num)]:
|
||||
assert not is_firsts.get(
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
# Verify that each choice gets exactly one finish_reason chunk
|
||||
for index in range(parallel_sample_num):
|
||||
assert (
|
||||
index in finish_reason_counts
|
||||
), f"No finish_reason found for index {index}"
|
||||
assert (
|
||||
finish_reason_counts[index] == 1
|
||||
), f"Expected 1 finish_reason chunk for index {index}, got {finish_reason_counts[index]}"
|
||||
|
||||
def test_completion(self):
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
for use_list_input in [True, False]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
for token_input in [False, True]:
|
||||
self.run_completion(
|
||||
echo,
|
||||
logprobs,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
token_input,
|
||||
)
|
||||
|
||||
def test_completion_stream(self):
|
||||
# parallel sampling and list input are not supported in streaming mode
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
for use_list_input in [True, False]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
for token_input in [False, True]:
|
||||
self.run_completion_stream(
|
||||
echo,
|
||||
logprobs,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
token_input,
|
||||
)
|
||||
|
||||
def test_chat_completion(self):
|
||||
for logprobs in [None, 5]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
self.run_chat_completion(logprobs, parallel_sample_num)
|
||||
|
||||
def test_chat_completion_stream(self):
|
||||
for logprobs in [None, 5]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
self.run_chat_completion_stream(logprobs, parallel_sample_num)
|
||||
|
||||
def test_regex(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "name": "[\w]+",\n"""
|
||||
+ r""" "population": [\d]+\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Introduce the capital of France."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
extra_body={"regex": regex},
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
print("JSONDecodeError", text)
|
||||
raise
|
||||
assert isinstance(js_obj["name"], str)
|
||||
assert isinstance(js_obj["population"], int)
|
||||
|
||||
def test_penalty(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Introduce the capital of France."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
frequency_penalty=1.0,
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
assert isinstance(text, str)
|
||||
|
||||
def test_response_prefill(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": """
|
||||
Extract the name, size, price, and color from this product description as a JSON object:
|
||||
|
||||
<description>
|
||||
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
|
||||
</description>
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{\n",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
extra_body={"continue_final_message": True},
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0]
|
||||
.message.content.strip()
|
||||
.startswith('"name": "SmartHome Mini",')
|
||||
)
|
||||
|
||||
def test_model_list(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
models = list(client.models.list())
|
||||
assert len(models) == 1
|
||||
assert isinstance(getattr(models[0], "max_model_len", None), int)
|
||||
|
||||
def test_retrieve_model(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
# Test retrieving an existing model
|
||||
retrieved_model = client.models.retrieve(self.model)
|
||||
self.assertEqual(retrieved_model.id, self.model)
|
||||
self.assertEqual(retrieved_model.root, self.model)
|
||||
|
||||
# Test retrieving a non-existent model
|
||||
with self.assertRaises(openai.NotFoundError):
|
||||
client.models.retrieve("non-existent-model")
|
||||
|
||||
|
||||
class TestOpenAIV1Rerank(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.score_tolerance = 1e-2
|
||||
|
||||
# Configure embedding-specific args
|
||||
other_args = [
|
||||
"--is-embedding",
|
||||
"--enable-metrics",
|
||||
"--disable-radix-cache",
|
||||
"--chunked-prefill-size",
|
||||
"-1",
|
||||
"--attention-backend",
|
||||
"torch_native",
|
||||
]
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=other_args,
|
||||
)
|
||||
cls.base_url += "/v1/rerank"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_rerank(self, query, docs):
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"query": query, "documents": docs},
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
def test_rerank_single(self):
|
||||
"""Test single rerank request"""
|
||||
query = TEST_RERANK_QUERY_DOCS[0]["query"]
|
||||
docs = TEST_RERANK_QUERY_DOCS[0]["documents"]
|
||||
|
||||
response = self.run_rerank(query, docs)
|
||||
|
||||
self.assertEqual(len(response), 1)
|
||||
self.assertTrue(isinstance(response[0]["score"], float))
|
||||
self.assertTrue(isinstance(response[0]["document"], str))
|
||||
self.assertTrue(isinstance(response[0]["index"], int))
|
||||
|
||||
def test_rerank_batch(self):
|
||||
"""Test batch rerank request"""
|
||||
query = TEST_RERANK_QUERY_DOCS[1]["query"]
|
||||
docs = TEST_RERANK_QUERY_DOCS[1]["documents"]
|
||||
|
||||
response = self.run_rerank(query, docs)
|
||||
|
||||
self.assertEqual(len(response), 2)
|
||||
self.assertTrue(isinstance(response[0]["score"], float))
|
||||
self.assertTrue(isinstance(response[1]["score"], float))
|
||||
self.assertTrue(isinstance(response[0]["document"], str))
|
||||
self.assertTrue(isinstance(response[1]["document"], str))
|
||||
self.assertTrue(isinstance(response[0]["index"], int))
|
||||
self.assertTrue(isinstance(response[1]["index"], int))
|
||||
|
||||
|
||||
class TestOpenAIV1Score(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
)
|
||||
cls.base_url += "/v1/score"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_score(
|
||||
self, query, items, label_token_ids, apply_softmax=False, item_first=False
|
||||
):
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"items": items,
|
||||
"label_token_ids": label_token_ids,
|
||||
"apply_softmax": apply_softmax,
|
||||
"item_first": item_first,
|
||||
},
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def test_score_text_input(self):
|
||||
"""Test scoring with text input"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Get valid token IDs from the tokenizer
|
||||
label_token_ids = []
|
||||
for item in items:
|
||||
token_ids = self.tokenizer.encode(item, add_special_tokens=False)
|
||||
if not token_ids:
|
||||
self.fail(f"Failed to encode item: {item}")
|
||||
label_token_ids.append(token_ids[0])
|
||||
|
||||
response = self.run_score(query, items, label_token_ids, apply_softmax=True)
|
||||
|
||||
# Handle error responses
|
||||
if response.get("type") == "BadRequestError":
|
||||
self.fail(f"Score request failed with error: {response['message']}")
|
||||
|
||||
# Verify response structure
|
||||
self.assertIn("scores", response, "Response should have a 'scores' field")
|
||||
self.assertIsInstance(response["scores"], list, "scores should be a list")
|
||||
self.assertEqual(
|
||||
len(response["scores"]),
|
||||
len(items),
|
||||
"Number of scores should match number of items",
|
||||
)
|
||||
|
||||
# Each score should be a list of floats in the order of label_token_ids
|
||||
for i, score_list in enumerate(response["scores"]):
|
||||
self.assertIsInstance(score_list, list, f"Score {i} should be a list")
|
||||
self.assertEqual(
|
||||
len(score_list),
|
||||
len(label_token_ids),
|
||||
f"Score {i} length should match label_token_ids",
|
||||
)
|
||||
self.assertTrue(
|
||||
all(isinstance(v, float) for v in score_list),
|
||||
f"Score {i} values should be floats",
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
sum(score_list),
|
||||
1.0,
|
||||
places=6,
|
||||
msg=f"Score {i} probabilities should sum to 1",
|
||||
)
|
||||
|
||||
def test_score_token_input(self):
|
||||
"""Test scoring with token IDs input"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Get valid token IDs
|
||||
query_ids = self.tokenizer.encode(query, add_special_tokens=False)
|
||||
item_ids = [
|
||||
self.tokenizer.encode(item, add_special_tokens=False) for item in items
|
||||
]
|
||||
label_token_ids = [
|
||||
ids[0] for ids in item_ids if ids
|
||||
] # Get first token ID of each item
|
||||
|
||||
response = self.run_score(
|
||||
query_ids, item_ids, label_token_ids, apply_softmax=True
|
||||
)
|
||||
|
||||
# Handle error responses
|
||||
if response.get("type") == "BadRequestError":
|
||||
self.fail(f"Score request failed with error: {response['message']}")
|
||||
|
||||
# Verify response structure
|
||||
self.assertIn("scores", response, "Response should have a 'scores' field")
|
||||
self.assertIsInstance(response["scores"], list, "scores should be a list")
|
||||
self.assertEqual(
|
||||
len(response["scores"]),
|
||||
len(items),
|
||||
"Number of scores should match number of items",
|
||||
)
|
||||
|
||||
# Each score should be a list of floats in the order of label_token_ids
|
||||
for i, score_list in enumerate(response["scores"]):
|
||||
self.assertIsInstance(score_list, list, f"Score {i} should be a list")
|
||||
self.assertEqual(
|
||||
len(score_list),
|
||||
len(label_token_ids),
|
||||
f"Score {i} length should match label_token_ids",
|
||||
)
|
||||
self.assertTrue(
|
||||
all(isinstance(v, float) for v in score_list),
|
||||
f"Score {i} values should be floats",
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
sum(score_list),
|
||||
1.0,
|
||||
places=6,
|
||||
msg=f"Score {i} probabilities should sum to 1",
|
||||
)
|
||||
|
||||
def test_score_error_handling(self):
|
||||
"""Test error handling for invalid inputs"""
|
||||
query = "The capital of France is"
|
||||
items = ["Paris", "London", "Berlin"]
|
||||
|
||||
# Test with invalid token ID
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"items": items,
|
||||
"label_token_ids": [999999], # Invalid token ID
|
||||
"apply_softmax": True,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
error_response = response.json()
|
||||
self.assertEqual(error_response["type"], "BadRequestError")
|
||||
self.assertIn("Token ID 999999 is out of vocabulary", error_response["message"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
368
test/srt/openai_server/basic/test_protocol.py
Normal file
368
test/srt/openai_server/basic/test_protocol.py
Normal file
@@ -0,0 +1,368 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for OpenAI API protocol models"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
BatchRequest,
|
||||
BatchResponse,
|
||||
ChatCompletionMessageContentImagePart,
|
||||
ChatCompletionMessageContentTextPart,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatCompletionTokenLogprob,
|
||||
ChatMessage,
|
||||
ChoiceLogprobs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
DeltaMessage,
|
||||
EmbeddingObject,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ErrorResponse,
|
||||
FileDeleteResponse,
|
||||
FileRequest,
|
||||
FileResponse,
|
||||
Function,
|
||||
FunctionResponse,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbs,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
MultimodalEmbeddingInput,
|
||||
ResponseFormat,
|
||||
ScoringRequest,
|
||||
ScoringResponse,
|
||||
StreamOptions,
|
||||
StructuralTagResponseFormat,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolChoice,
|
||||
TopLogprob,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
|
||||
class TestModelCard(unittest.TestCase):
|
||||
"""Test ModelCard protocol model"""
|
||||
|
||||
def test_model_card_serialization(self):
|
||||
"""Test model card JSON serialization"""
|
||||
card = ModelCard(id="test-model", max_model_len=4096)
|
||||
data = card.model_dump()
|
||||
self.assertEqual(data["id"], "test-model")
|
||||
self.assertEqual(data["object"], "model")
|
||||
self.assertEqual(data["max_model_len"], 4096)
|
||||
|
||||
|
||||
class TestModelList(unittest.TestCase):
|
||||
"""Test ModelList protocol model"""
|
||||
|
||||
def test_empty_model_list(self):
|
||||
"""Test empty model list creation"""
|
||||
model_list = ModelList()
|
||||
self.assertEqual(model_list.object, "list")
|
||||
self.assertEqual(len(model_list.data), 0)
|
||||
|
||||
def test_model_list_with_cards(self):
|
||||
"""Test model list with model cards"""
|
||||
cards = [
|
||||
ModelCard(id="model-1"),
|
||||
ModelCard(id="model-2", max_model_len=2048),
|
||||
]
|
||||
model_list = ModelList(data=cards)
|
||||
self.assertEqual(len(model_list.data), 2)
|
||||
self.assertEqual(model_list.data[0].id, "model-1")
|
||||
self.assertEqual(model_list.data[1].id, "model-2")
|
||||
|
||||
|
||||
class TestCompletionRequest(unittest.TestCase):
|
||||
"""Test CompletionRequest protocol model"""
|
||||
|
||||
def test_basic_completion_request(self):
|
||||
"""Test basic completion request"""
|
||||
request = CompletionRequest(model="test-model", prompt="Hello world")
|
||||
self.assertEqual(request.model, "test-model")
|
||||
self.assertEqual(request.prompt, "Hello world")
|
||||
self.assertEqual(request.max_tokens, 16) # default
|
||||
self.assertEqual(request.temperature, 1.0) # default
|
||||
self.assertEqual(request.n, 1) # default
|
||||
self.assertFalse(request.stream) # default
|
||||
self.assertFalse(request.echo) # default
|
||||
|
||||
def test_completion_request_sglang_extensions(self):
|
||||
"""Test completion request with SGLang-specific extensions"""
|
||||
request = CompletionRequest(
|
||||
model="test-model",
|
||||
prompt="Hello",
|
||||
top_k=50,
|
||||
min_p=0.1,
|
||||
repetition_penalty=1.1,
|
||||
regex=r"\d+",
|
||||
json_schema='{"type": "object"}',
|
||||
lora_path="/path/to/lora",
|
||||
)
|
||||
self.assertEqual(request.top_k, 50)
|
||||
self.assertEqual(request.min_p, 0.1)
|
||||
self.assertEqual(request.repetition_penalty, 1.1)
|
||||
self.assertEqual(request.regex, r"\d+")
|
||||
self.assertEqual(request.json_schema, '{"type": "object"}')
|
||||
self.assertEqual(request.lora_path, "/path/to/lora")
|
||||
|
||||
def test_completion_request_validation_errors(self):
|
||||
"""Test completion request validation errors"""
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest() # missing required fields
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest(model="test-model") # missing prompt
|
||||
|
||||
|
||||
class TestChatCompletionRequest(unittest.TestCase):
|
||||
"""Test ChatCompletionRequest protocol model"""
|
||||
|
||||
def test_basic_chat_completion_request(self):
|
||||
"""Test basic chat completion request"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
request = ChatCompletionRequest(model="test-model", messages=messages)
|
||||
self.assertEqual(request.model, "test-model")
|
||||
self.assertEqual(len(request.messages), 1)
|
||||
self.assertEqual(request.messages[0].role, "user")
|
||||
self.assertEqual(request.messages[0].content, "Hello")
|
||||
self.assertEqual(request.temperature, 0.7) # default
|
||||
self.assertFalse(request.stream) # default
|
||||
self.assertEqual(request.tool_choice, "none") # default when no tools
|
||||
|
||||
def test_chat_completion_tool_choice_validation(self):
|
||||
"""Test tool choice validation logic"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
# No tools, tool_choice should default to "none"
|
||||
request1 = ChatCompletionRequest(model="test-model", messages=messages)
|
||||
self.assertEqual(request1.tool_choice, "none")
|
||||
|
||||
# With tools, tool_choice should default to "auto"
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "test_func", "description": "Test function"},
|
||||
}
|
||||
]
|
||||
request2 = ChatCompletionRequest(
|
||||
model="test-model", messages=messages, tools=tools
|
||||
)
|
||||
self.assertEqual(request2.tool_choice, "auto")
|
||||
|
||||
def test_chat_completion_sglang_extensions(self):
|
||||
"""Test chat completion with SGLang extensions"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
top_k=40,
|
||||
min_p=0.05,
|
||||
separate_reasoning=False,
|
||||
stream_reasoning=False,
|
||||
chat_template_kwargs={"custom_param": "value"},
|
||||
)
|
||||
self.assertEqual(request.top_k, 40)
|
||||
self.assertEqual(request.min_p, 0.05)
|
||||
self.assertFalse(request.separate_reasoning)
|
||||
self.assertFalse(request.stream_reasoning)
|
||||
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
|
||||
|
||||
def test_chat_completion_reasoning_effort(self):
|
||||
"""Test chat completion with reasoning effort"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
reasoning={
|
||||
"enabled": True,
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
)
|
||||
self.assertEqual(request.reasoning_effort, "high")
|
||||
self.assertEqual(request.chat_template_kwargs, {"thinking": True})
|
||||
|
||||
def test_chat_completion_json_format(self):
|
||||
"""Test chat completion json format"""
|
||||
transcript = "Good morning! It's 7:00 AM, and I'm just waking up. Today is going to be a busy day, "
|
||||
"so let's get started. First, I need to make a quick breakfast. I think I'll have some "
|
||||
"scrambled eggs and toast with a cup of coffee. While I'm cooking, I'll also check my "
|
||||
"emails to see if there's anything urgent."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "The following is a voice message transcript. Only answer in JSON.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": transcript,
|
||||
},
|
||||
]
|
||||
|
||||
class VoiceNote(BaseModel):
|
||||
title: str = Field(description="A title for the voice note")
|
||||
summary: str = Field(
|
||||
description="A short one sentence summary of the voice note."
|
||||
)
|
||||
strict: Optional[bool] = True
|
||||
actionItems: List[str] = Field(
|
||||
description="A list of action items from the voice note"
|
||||
)
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
top_k=40,
|
||||
min_p=0.05,
|
||||
separate_reasoning=False,
|
||||
stream_reasoning=False,
|
||||
chat_template_kwargs={"custom_param": "value"},
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"schema": VoiceNote.model_json_schema(),
|
||||
},
|
||||
)
|
||||
res_format = request.response_format
|
||||
json_format = res_format.json_schema
|
||||
name = json_format.name
|
||||
schema = json_format.schema_
|
||||
strict = json_format.strict
|
||||
self.assertEqual(name, "VoiceNote")
|
||||
self.assertEqual(strict, True)
|
||||
self.assertNotIn("strict", schema["properties"])
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=messages,
|
||||
top_k=40,
|
||||
min_p=0.05,
|
||||
separate_reasoning=False,
|
||||
stream_reasoning=False,
|
||||
chat_template_kwargs={"custom_param": "value"},
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "VoiceNote",
|
||||
"schema": VoiceNote.model_json_schema(),
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
res_format = request.response_format
|
||||
json_format = res_format.json_schema
|
||||
name = json_format.name
|
||||
schema = json_format.schema_
|
||||
strict = json_format.strict
|
||||
self.assertEqual(name, "VoiceNote")
|
||||
self.assertEqual(strict, True)
|
||||
|
||||
|
||||
class TestModelSerialization(unittest.TestCase):
|
||||
"""Test model serialization with hidden states"""
|
||||
|
||||
def test_hidden_states_excluded_when_none(self):
|
||||
"""Test that None hidden_states are excluded with exclude_none=True"""
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
hidden_states=None,
|
||||
)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id="test-id",
|
||||
model="test-model",
|
||||
choices=[choice],
|
||||
usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
|
||||
)
|
||||
|
||||
# Test exclude_none serialization (should exclude None hidden_states)
|
||||
data = response.model_dump(exclude_none=True)
|
||||
self.assertNotIn("hidden_states", data["choices"][0])
|
||||
|
||||
def test_hidden_states_included_when_not_none(self):
|
||||
"""Test that non-None hidden_states are included"""
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
hidden_states=[0.1, 0.2, 0.3],
|
||||
)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id="test-id",
|
||||
model="test-model",
|
||||
choices=[choice],
|
||||
usage=UsageInfo(prompt_tokens=5, completion_tokens=1, total_tokens=6),
|
||||
)
|
||||
|
||||
# Test exclude_none serialization (should include non-None hidden_states)
|
||||
data = response.model_dump(exclude_none=True)
|
||||
self.assertIn("hidden_states", data["choices"][0])
|
||||
self.assertEqual(data["choices"][0]["hidden_states"], [0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
class TestValidationEdgeCases(unittest.TestCase):
|
||||
"""Test edge cases and validation scenarios"""
|
||||
|
||||
def test_invalid_tool_choice_type(self):
|
||||
"""Test invalid tool choice type"""
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
with self.assertRaises(ValidationError):
|
||||
ChatCompletionRequest(
|
||||
model="test-model", messages=messages, tool_choice=123
|
||||
)
|
||||
|
||||
def test_negative_token_limits(self):
|
||||
"""Test negative token limits"""
|
||||
with self.assertRaises(ValidationError):
|
||||
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
|
||||
|
||||
def test_model_serialization_roundtrip(self):
|
||||
"""Test that models can be serialized and deserialized"""
|
||||
original_request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
# Serialize to dict
|
||||
data = original_request.model_dump()
|
||||
|
||||
# Deserialize back
|
||||
restored_request = ChatCompletionRequest(**data)
|
||||
|
||||
self.assertEqual(restored_request.model, original_request.model)
|
||||
self.assertEqual(restored_request.temperature, original_request.temperature)
|
||||
self.assertEqual(restored_request.max_tokens, original_request.max_tokens)
|
||||
self.assertEqual(len(restored_request.messages), len(original_request.messages))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
426
test/srt/openai_server/basic/test_serving_chat.py
Normal file
426
test/srt/openai_server/basic/test_serving_chat.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""
|
||||
Unit-tests for OpenAIServingChat — rewritten to use only the std-lib 'unittest'.
|
||||
Run with either:
|
||||
python tests/test_serving_chat_unit.py -v
|
||||
or
|
||||
python -m unittest discover -s tests -p "test_*unit.py" -v
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
MessageProcessingResult,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
|
||||
|
||||
class _MockTokenizerManager:
|
||||
"""Minimal mock that satisfies OpenAIServingChat."""
|
||||
|
||||
def __init__(self):
|
||||
self.model_config = Mock(is_multimodal=False)
|
||||
self.server_args = Mock(
|
||||
enable_cache_report=False,
|
||||
tool_call_parser="hermes",
|
||||
reasoning_parser=None,
|
||||
)
|
||||
self.chat_template_name: Optional[str] = "llama-3"
|
||||
|
||||
# tokenizer stub
|
||||
self.tokenizer = Mock()
|
||||
self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
|
||||
self.tokenizer.decode.return_value = "Test response"
|
||||
self.tokenizer.chat_template = None
|
||||
self.tokenizer.bos_token_id = 1
|
||||
|
||||
# async generator stub for generate_request
|
||||
async def _mock_generate():
|
||||
yield {
|
||||
"text": "Test response",
|
||||
"meta_info": {
|
||||
"id": f"chatcmpl-{uuid.uuid4()}",
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"cached_tokens": 0,
|
||||
"finish_reason": {"type": "stop", "matched": None},
|
||||
"output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")],
|
||||
"output_top_logprobs": None,
|
||||
},
|
||||
"index": 0,
|
||||
}
|
||||
|
||||
self.generate_request = Mock(return_value=_mock_generate())
|
||||
self.create_abort_task = Mock()
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = "llama-3"
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = None
|
||||
|
||||
|
||||
class ServingChatTestCase(unittest.TestCase):
|
||||
# ------------- common fixtures -------------
|
||||
def setUp(self):
|
||||
self.tm = _MockTokenizerManager()
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.chat = OpenAIServingChat(self.tm, self.template_manager)
|
||||
|
||||
# frequently reused requests
|
||||
self.basic_req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
stream=False,
|
||||
)
|
||||
self.stream_req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi?"}],
|
||||
temperature=0.7,
|
||||
max_tokens=100,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
self.fastapi_request = Mock(spec=Request)
|
||||
self.fastapi_request.headers = {}
|
||||
|
||||
# ------------- conversion tests -------------
|
||||
def test_convert_to_internal_request_single(self):
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
|
||||
) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock:
|
||||
conv_ins = Mock()
|
||||
conv_ins.get_prompt.return_value = "Test prompt"
|
||||
conv_ins.image_data = conv_ins.audio_data = None
|
||||
conv_ins.modalities = []
|
||||
conv_ins.stop_str = ["</s>"]
|
||||
conv_mock.return_value = conv_ins
|
||||
|
||||
proc_mock.return_value = MessageProcessingResult(
|
||||
"Test prompt",
|
||||
[1, 2, 3],
|
||||
None,
|
||||
None,
|
||||
[],
|
||||
["</s>"],
|
||||
None,
|
||||
)
|
||||
|
||||
adapted, processed = self.chat._convert_to_internal_request(self.basic_req)
|
||||
self.assertIsInstance(adapted, GenerateReqInput)
|
||||
self.assertFalse(adapted.stream)
|
||||
self.assertEqual(processed, self.basic_req)
|
||||
|
||||
def test_stop_str_isolation_between_requests(self):
|
||||
"""Test that stop strings from one request don't affect subsequent requests.
|
||||
|
||||
This tests the fix for the bug where conv.stop_str was being mutated globally,
|
||||
causing stop strings from one request to persist in subsequent requests.
|
||||
"""
|
||||
# Mock conversation template with initial stop_str
|
||||
initial_stop_str = ["\n"]
|
||||
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
|
||||
) as conv_mock:
|
||||
# Create a mock conversation object that will be returned by generate_chat_conv
|
||||
conv_ins = Mock()
|
||||
conv_ins.get_prompt.return_value = "Test prompt"
|
||||
conv_ins.image_data = None
|
||||
conv_ins.audio_data = None
|
||||
conv_ins.modalities = []
|
||||
conv_ins.stop_str = (
|
||||
initial_stop_str.copy()
|
||||
) # Template's default stop strings
|
||||
conv_mock.return_value = conv_ins
|
||||
|
||||
# First request with additional stop string
|
||||
req1 = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "First request"}],
|
||||
stop=["CUSTOM_STOP"],
|
||||
)
|
||||
|
||||
# Call the actual _apply_conversation_template method (not mocked)
|
||||
result1 = self.chat._apply_conversation_template(req1, is_multimodal=False)
|
||||
|
||||
# Verify first request has both stop strings
|
||||
expected_stop1 = initial_stop_str + ["CUSTOM_STOP"]
|
||||
self.assertEqual(result1.stop, expected_stop1)
|
||||
|
||||
# Verify the original template's stop_str wasn't mutated after first request
|
||||
self.assertEqual(conv_ins.stop_str, initial_stop_str)
|
||||
|
||||
# Second request without additional stop string
|
||||
req2 = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Second request"}],
|
||||
# No custom stop strings
|
||||
)
|
||||
result2 = self.chat._apply_conversation_template(req2, is_multimodal=False)
|
||||
|
||||
# Verify second request only has original stop strings (no CUSTOM_STOP from req1)
|
||||
self.assertEqual(result2.stop, initial_stop_str)
|
||||
self.assertNotIn("CUSTOM_STOP", result2.stop)
|
||||
self.assertEqual(conv_ins.stop_str, initial_stop_str)
|
||||
|
||||
# ------------- sampling-params -------------
|
||||
def test_sampling_param_build(self):
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
temperature=0.8,
|
||||
max_tokens=150,
|
||||
min_tokens=5,
|
||||
top_p=0.9,
|
||||
stop=["</s>"],
|
||||
)
|
||||
with patch.object(
|
||||
self.chat,
|
||||
"_process_messages",
|
||||
return_value=("Prompt", [1], None, None, [], ["</s>"], None),
|
||||
):
|
||||
params = self.chat._build_sampling_params(req, ["</s>"], None)
|
||||
self.assertEqual(params["temperature"], 0.8)
|
||||
self.assertEqual(params["max_new_tokens"], 150)
|
||||
self.assertEqual(params["min_new_tokens"], 5)
|
||||
self.assertEqual(params["stop"], ["</s>"])
|
||||
|
||||
async def test_unstreamed_tool_args_completion(self):
|
||||
"""Test that remaining tool call arguments are sent when generation finishes."""
|
||||
|
||||
# Mock FunctionCallParser with detector that has partial tool call data
|
||||
mock_parser = Mock()
|
||||
mock_detector = Mock()
|
||||
|
||||
# Simulate a tool call that was partially streamed
|
||||
mock_detector.prev_tool_call_arr = [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"arguments": {"location": "San Francisco", "unit": "celsius"},
|
||||
}
|
||||
]
|
||||
mock_detector.streamed_args_for_tool = [
|
||||
'{"location": "San Francisco"' # Partial arguments streamed so far
|
||||
]
|
||||
mock_parser.detector = mock_detector
|
||||
|
||||
content = {
|
||||
"meta_info": {
|
||||
"id": "chatcmpl-test123",
|
||||
}
|
||||
}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test",
|
||||
messages=[{"role": "user", "content": "What's the weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
)
|
||||
|
||||
# Test the completion method
|
||||
result = self.chat._check_for_unstreamed_tool_args(
|
||||
parser=mock_parser,
|
||||
content=content,
|
||||
request=request,
|
||||
finish_reason_type="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
# Should return a chunk with remaining arguments
|
||||
self.assertIsNotNone(result, "Should return chunk with remaining arguments")
|
||||
self.assertIn('"arguments":', result, "Should contain arguments field")
|
||||
self.assertIn(
|
||||
', "unit": "celsius"}', result, "Should contain remaining arguments"
|
||||
)
|
||||
self.assertIn(
|
||||
'"finish_reason":null',
|
||||
result,
|
||||
"Should not include finish_reason in completion chunk",
|
||||
)
|
||||
|
||||
async def test_unstreamed_tool_args_no_completion_needed(self):
|
||||
"""Test that no completion chunk is sent when all arguments were already streamed."""
|
||||
|
||||
# Mock FunctionCallParser with detector that has complete tool call data
|
||||
mock_parser = Mock()
|
||||
mock_detector = Mock()
|
||||
|
||||
# Simulate a tool call that was completely streamed
|
||||
mock_detector.prev_tool_call_arr = [
|
||||
{"name": "get_weather", "arguments": {"location": "San Francisco"}}
|
||||
]
|
||||
mock_detector.streamed_args_for_tool = [
|
||||
'{"location": "San Francisco"}' # All arguments already streamed
|
||||
]
|
||||
mock_parser.detector = mock_detector
|
||||
|
||||
content = {
|
||||
"meta_info": {
|
||||
"id": "chatcmpl-test123",
|
||||
}
|
||||
}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test",
|
||||
messages=[{"role": "user", "content": "What's the weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
)
|
||||
|
||||
# Test the completion method
|
||||
result = self.chat._check_for_unstreamed_tool_args(
|
||||
parser=mock_parser,
|
||||
content=content,
|
||||
request=request,
|
||||
finish_reason_type="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
# Should return None since no completion is needed
|
||||
self.assertIsNone(result, "Should return None when no completion is needed")
|
||||
|
||||
async def test_unstreamed_tool_args_no_parser_data(self):
|
||||
"""Test that no completion chunk is sent when parser has no tool call data."""
|
||||
|
||||
# Mock FunctionCallParser with empty detector
|
||||
mock_parser = Mock()
|
||||
mock_detector = Mock()
|
||||
mock_detector.prev_tool_call_arr = []
|
||||
mock_detector.streamed_args_for_tool = []
|
||||
mock_parser.detector = mock_detector
|
||||
|
||||
content = {
|
||||
"meta_info": {
|
||||
"id": "chatcmpl-test123",
|
||||
}
|
||||
}
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test",
|
||||
messages=[{"role": "user", "content": "What's the weather?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
)
|
||||
|
||||
# Test the completion method
|
||||
result = self.chat._check_for_unstreamed_tool_args(
|
||||
parser=mock_parser,
|
||||
content=content,
|
||||
request=request,
|
||||
finish_reason_type="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
# Should return None since there's no parser data
|
||||
self.assertIsNone(
|
||||
result, "Should return None when parser has no tool call data"
|
||||
)
|
||||
|
||||
# ------------- kimi_k2 tool_call_id formatting -------------
|
||||
def test_kimi_k2_non_streaming_tool_call_id_format(self):
|
||||
"""Ensure non-streaming tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
|
||||
|
||||
# Force kimi_k2 parser
|
||||
self.tm.server_args.tool_call_parser = "kimi_k2"
|
||||
|
||||
# Mock FunctionCallParser.parse_non_stream to return one tool call
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||
) as ParserMock:
|
||||
parser_instance = ParserMock.return_value
|
||||
|
||||
# Build a mock ToolCallItem-like object
|
||||
call_info = Mock()
|
||||
call_info.name = "get_weather"
|
||||
call_info.parameters = '{"city":"Paris"}'
|
||||
call_info.tool_index = 0
|
||||
|
||||
parser_instance.has_tool_call.return_value = True
|
||||
parser_instance.parse_non_stream.return_value = ("", [call_info])
|
||||
|
||||
finish_reason = {"type": "stop", "matched": None}
|
||||
tools = [
|
||||
{"type": "function", "function": {"name": "get_weather"}},
|
||||
]
|
||||
|
||||
tool_calls, remaining_text, _ = self.chat._process_tool_calls(
|
||||
text="<|tool_calls_section_begin|>...",
|
||||
tools=tools,
|
||||
tool_call_parser="kimi_k2",
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(tool_calls)
|
||||
self.assertEqual(len(tool_calls), 1)
|
||||
self.assertEqual(tool_calls[0].id, "functions.get_weather:0")
|
||||
self.assertEqual(tool_calls[0].function.name, "get_weather")
|
||||
|
||||
def test_kimi_k2_streaming_tool_call_id_format(self):
|
||||
"""Ensure streaming first chunk tool_call.id matches functions.{name}:{index} for kimi_k2 parser."""
|
||||
|
||||
# Force kimi_k2 parser
|
||||
self.tm.server_args.tool_call_parser = "kimi_k2"
|
||||
|
||||
# Prepare request with tools
|
||||
req = ChatCompletionRequest(
|
||||
model="x",
|
||||
messages=[{"role": "user", "content": "Hi?"}],
|
||||
tools=[{"type": "function", "function": {"name": "get_weather"}}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Patch FunctionCallParser used inside _process_tool_call_stream
|
||||
with patch(
|
||||
"sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser"
|
||||
) as ParserMock:
|
||||
parser_instance = ParserMock.return_value
|
||||
|
||||
# First call returns one ToolCallItem-like chunk (with name)
|
||||
first_chunk_call = Mock()
|
||||
first_chunk_call.tool_index = 0
|
||||
first_chunk_call.name = "get_weather"
|
||||
first_chunk_call.parameters = ""
|
||||
parser_instance.parse_stream_chunk.side_effect = [
|
||||
("", [first_chunk_call]),
|
||||
("", []),
|
||||
]
|
||||
|
||||
async def collect_first_tool_chunk():
|
||||
gen = self.chat._process_tool_call_stream(
|
||||
index=0,
|
||||
delta="irrelevant",
|
||||
parser_dict={},
|
||||
content={"meta_info": {"id": "chatcmpl-test"}},
|
||||
request=req,
|
||||
has_tool_calls={},
|
||||
)
|
||||
# Get first yielded SSE line
|
||||
line = None
|
||||
async for emitted in gen:
|
||||
line = emitted
|
||||
break
|
||||
return line
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
line = loop.run_until_complete(collect_first_tool_chunk())
|
||||
self.assertIsNotNone(line)
|
||||
self.assertTrue(line.startswith("data: "))
|
||||
|
||||
payload = json.loads(line[len("data: ") :])
|
||||
tool_calls = payload["choices"][0]["delta"]["tool_calls"]
|
||||
self.assertEqual(tool_calls[0]["id"], "functions.get_weather:0")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
157
test/srt/openai_server/basic/test_serving_completions.py
Normal file
157
test/srt/openai_server/basic/test_serving_completions.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Unit-tests for the refactored completions-serving handler (no pytest).
|
||||
Run with:
|
||||
python -m unittest tests.test_serving_completions_unit -v
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
|
||||
from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
|
||||
|
||||
class _MockTemplateManager:
|
||||
"""Minimal mock for TemplateManager."""
|
||||
|
||||
def __init__(self):
|
||||
self.chat_template_name: Optional[str] = None
|
||||
self.jinja_template_content_format: Optional[str] = None
|
||||
self.completion_template_name: Optional[str] = (
|
||||
None # Set to None to avoid template processing
|
||||
)
|
||||
|
||||
|
||||
class ServingCompletionTestCase(unittest.TestCase):
|
||||
"""Bundle all prompt/echo tests in one TestCase."""
|
||||
|
||||
# ---------- shared test fixtures ----------
|
||||
def setUp(self):
|
||||
# build the mock TokenizerManager once for every test
|
||||
tm = Mock(spec=TokenizerManager)
|
||||
|
||||
tm.tokenizer = Mock()
|
||||
tm.tokenizer.encode.return_value = [1, 2, 3, 4]
|
||||
tm.tokenizer.decode.return_value = "decoded text"
|
||||
tm.tokenizer.bos_token_id = 1
|
||||
|
||||
tm.model_config = Mock(is_multimodal=False)
|
||||
tm.server_args = Mock(enable_cache_report=False)
|
||||
|
||||
tm.generate_request = AsyncMock()
|
||||
tm.create_abort_task = Mock()
|
||||
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.sc = OpenAIServingCompletion(tm, self.template_manager)
|
||||
|
||||
# ---------- prompt-handling ----------
|
||||
def test_single_string_prompt(self):
|
||||
req = CompletionRequest(model="x", prompt="Hello world", max_tokens=100)
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.text, "Hello world")
|
||||
|
||||
def test_single_token_ids_prompt(self):
|
||||
req = CompletionRequest(model="x", prompt=[1, 2, 3, 4], max_tokens=100)
|
||||
internal, _ = self.sc._convert_to_internal_request(req)
|
||||
self.assertEqual(internal.input_ids, [1, 2, 3, 4])
|
||||
|
||||
# ---------- echo-handling ----------
|
||||
def test_echo_with_string_prompt_streaming(self):
|
||||
req = CompletionRequest(model="x", prompt="Hello", max_tokens=1, echo=True)
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "Hello")
|
||||
|
||||
def test_echo_with_list_of_strings_streaming(self):
|
||||
req = CompletionRequest(
|
||||
model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1
|
||||
)
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "A")
|
||||
self.assertEqual(self.sc._get_echo_text(req, 1), "B")
|
||||
|
||||
def test_echo_with_token_ids_streaming(self):
|
||||
req = CompletionRequest(model="x", prompt=[1, 2, 3], max_tokens=1, echo=True)
|
||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded_prompt"
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt")
|
||||
|
||||
def test_echo_with_multiple_token_ids_streaming(self):
|
||||
req = CompletionRequest(
|
||||
model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1
|
||||
)
|
||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
||||
self.assertEqual(self.sc._get_echo_text(req, 0), "decoded")
|
||||
|
||||
def test_prepare_echo_prompts_non_streaming(self):
|
||||
# single string
|
||||
req = CompletionRequest(model="x", prompt="Hi", echo=True)
|
||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi"])
|
||||
|
||||
# list of strings
|
||||
req = CompletionRequest(model="x", prompt=["Hi", "Yo"], echo=True)
|
||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["Hi", "Yo"])
|
||||
|
||||
# token IDs
|
||||
req = CompletionRequest(model="x", prompt=[1, 2, 3], echo=True)
|
||||
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
|
||||
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
|
||||
|
||||
# ---------- response_format handling ----------
|
||||
def test_response_format_json_object(self):
|
||||
"""Test that response_format json_object is correctly processed in sampling params."""
|
||||
req = CompletionRequest(
|
||||
model="x",
|
||||
prompt="Generate a JSON object:",
|
||||
max_tokens=100,
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
self.assertEqual(sampling_params["json_schema"], '{"type": "object"}')
|
||||
|
||||
def test_response_format_json_schema(self):
|
||||
"""Test that response_format json_schema is correctly processed in sampling params."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
|
||||
}
|
||||
req = CompletionRequest(
|
||||
model="x",
|
||||
prompt="Generate a JSON object:",
|
||||
max_tokens=100,
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {"name": "person", "schema": schema},
|
||||
},
|
||||
)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
# The schema should be converted to string by convert_json_schema_to_str
|
||||
self.assertIn("json_schema", sampling_params)
|
||||
self.assertIsInstance(sampling_params["json_schema"], str)
|
||||
|
||||
def test_response_format_structural_tag(self):
|
||||
"""Test that response_format structural_tag is correctly processed in sampling params."""
|
||||
req = CompletionRequest(
|
||||
model="x",
|
||||
prompt="Generate structured output:",
|
||||
max_tokens=100,
|
||||
response_format={
|
||||
"type": "structural_tag",
|
||||
"structures": [{"begin": "<data>", "end": "</data>"}],
|
||||
"triggers": ["<data>"],
|
||||
},
|
||||
)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
# The structural_tag should be processed
|
||||
self.assertIn("structural_tag", sampling_params)
|
||||
self.assertIsInstance(sampling_params["structural_tag"], str)
|
||||
|
||||
def test_response_format_none(self):
|
||||
"""Test that no response_format doesn't add extra constraints."""
|
||||
req = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100)
|
||||
sampling_params = self.sc._build_sampling_params(req)
|
||||
# Should not have json_schema or structural_tag from response_format
|
||||
# (but might have json_schema from the legacy json_schema field)
|
||||
self.assertIsNone(sampling_params.get("structural_tag"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
145
test/srt/openai_server/basic/test_serving_embedding.py
Normal file
145
test/srt/openai_server/basic/test_serving_embedding.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import uuid
|
||||
from unittest.mock import Mock
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from sglang.srt.entrypoints.openai.protocol import (
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
MultimodalEmbeddingInput,
|
||||
)
|
||||
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput
|
||||
|
||||
|
||||
# Mock TokenizerManager for embedding tests
|
||||
class _MockTokenizerManager:
|
||||
def __init__(self):
|
||||
self.model_config = Mock()
|
||||
self.model_config.is_multimodal = False
|
||||
self.server_args = Mock()
|
||||
self.server_args.enable_cache_report = False
|
||||
self.model_path = "test-model"
|
||||
|
||||
# Mock tokenizer
|
||||
self.tokenizer = Mock()
|
||||
self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
|
||||
self.tokenizer.decode = Mock(return_value="Test embedding input")
|
||||
self.tokenizer.chat_template = None
|
||||
self.tokenizer.bos_token_id = 1
|
||||
|
||||
# Mock generate_request method for embeddings
|
||||
async def mock_generate_embedding():
|
||||
yield {
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20, # 100-dim embedding
|
||||
"meta_info": {
|
||||
"id": f"embd-{uuid.uuid4()}",
|
||||
"prompt_tokens": 5,
|
||||
},
|
||||
}
|
||||
|
||||
self.generate_request = Mock(return_value=mock_generate_embedding())
|
||||
|
||||
|
||||
# Mock TemplateManager for embedding tests
|
||||
class _MockTemplateManager:
|
||||
def __init__(self):
|
||||
self.chat_template_name = None # None for embeddings usually
|
||||
self.jinja_template_content_format = None
|
||||
self.completion_template_name = None
|
||||
|
||||
|
||||
class ServingEmbeddingTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.tokenizer_manager = _MockTokenizerManager()
|
||||
self.template_manager = _MockTemplateManager()
|
||||
self.serving_embedding = OpenAIServingEmbedding(
|
||||
self.tokenizer_manager, self.template_manager
|
||||
)
|
||||
|
||||
self.request = Mock(spec=Request)
|
||||
self.request.headers = {}
|
||||
|
||||
self.basic_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input="Hello, how are you?",
|
||||
encoding_format="float",
|
||||
)
|
||||
self.list_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=["Hello, how are you?", "I am fine, thank you!"],
|
||||
encoding_format="float",
|
||||
)
|
||||
self.multimodal_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=[
|
||||
MultimodalEmbeddingInput(text="Hello", image="base64_image_data"),
|
||||
MultimodalEmbeddingInput(text="World", image=None),
|
||||
],
|
||||
encoding_format="float",
|
||||
)
|
||||
self.token_ids_req = EmbeddingRequest(
|
||||
model="test-model",
|
||||
input=[1, 2, 3, 4, 5],
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
def test_convert_single_string_request(self):
|
||||
"""Test converting single string request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.basic_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(adapted_request.text, "Hello, how are you?")
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.basic_req)
|
||||
|
||||
def test_convert_list_string_request(self):
|
||||
"""Test converting list of strings request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.list_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(
|
||||
adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
|
||||
)
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.list_req)
|
||||
|
||||
def test_convert_token_ids_request(self):
|
||||
"""Test converting token IDs request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.token_ids_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
self.assertEqual(processed_request, self.token_ids_req)
|
||||
|
||||
def test_convert_multimodal_request(self):
|
||||
"""Test converting multimodal request to internal format."""
|
||||
adapted_request, processed_request = (
|
||||
self.serving_embedding._convert_to_internal_request(self.multimodal_req)
|
||||
)
|
||||
|
||||
self.assertIsInstance(adapted_request, EmbeddingReqInput)
|
||||
# Should extract text and images separately
|
||||
self.assertEqual(len(adapted_request.text), 2)
|
||||
self.assertIn("Hello", adapted_request.text)
|
||||
self.assertIn("World", adapted_request.text)
|
||||
self.assertEqual(adapted_request.image_data[0], "base64_image_data")
|
||||
self.assertIsNone(adapted_request.image_data[1])
|
||||
# self.assertEqual(adapted_request.rid, "test-id")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
0
test/srt/openai_server/features/__init__.py
Normal file
0
test/srt/openai_server/features/__init__.py
Normal file
212
test/srt/openai_server/features/test_cache_report.py
Normal file
212
test/srt/openai_server/features/test_cache_report.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestCacheReport(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.min_cached = 5
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=300,
|
||||
other_args=[
|
||||
"--chunked-prefill-size=40",
|
||||
"--enable-cache-report",
|
||||
],
|
||||
)
|
||||
cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
|
||||
cls.aclient = openai.AsyncClient(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
|
||||
|
||||
usage = cls.run_openai(cls, "1").usage
|
||||
# we can assume that our request is of size 1, plus the total template size
|
||||
# ideally we would like to know the begin size / end size of the template to be more precise
|
||||
total_template_size = usage.prompt_tokens - 1
|
||||
print(f"template size: {total_template_size}")
|
||||
usage2 = cls.run_openai(cls, "2").usage
|
||||
assert usage2.prompt_tokens_details.cached_tokens <= total_template_size
|
||||
cls.min_cached = max(
|
||||
usage2.prompt_tokens_details.cached_tokens,
|
||||
total_template_size - usage2.prompt_tokens_details.cached_tokens,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
# we use an uncommon start to minimise the chance that the cache is hit by chance
|
||||
json={
|
||||
"text": "_ The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": 128,
|
||||
"n": n,
|
||||
"stop_token_ids": [119690],
|
||||
},
|
||||
"stream": False,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
def run_openai(self, message):
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
# {"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": message},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
)
|
||||
return response
|
||||
|
||||
async def run_openai_async(self, message):
|
||||
response = await self.aclient.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "user", "content": message},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=100,
|
||||
)
|
||||
return response
|
||||
|
||||
def cache_report_openai(self, message):
|
||||
response = self.run_openai(message)
|
||||
print(
|
||||
f"openai first request cached_tokens: {int(response.usage.prompt_tokens_details.cached_tokens)}"
|
||||
)
|
||||
first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||
# assert int(response.usage.cached_tokens) == 0
|
||||
assert first_cached_tokens <= self.min_cached
|
||||
response = self.run_openai(message)
|
||||
cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||
print(f"openai second request cached_tokens: {cached_tokens}")
|
||||
assert cached_tokens > 0
|
||||
assert cached_tokens == int(response.usage.prompt_tokens) - 1
|
||||
return first_cached_tokens
|
||||
|
||||
async def cache_report_openai_async(self, message):
|
||||
response = await self.run_openai_async(message)
|
||||
cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||
prompt_tokens = int(response.usage.prompt_tokens)
|
||||
return cached_tokens, prompt_tokens
|
||||
|
||||
def test_generate(self):
|
||||
print("=" * 100)
|
||||
response = self.run_decode()
|
||||
# print(response.json())
|
||||
cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
|
||||
print(f"sglang first request cached_tokens: {cached_tokens}")
|
||||
print(
|
||||
f"sglang first request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
|
||||
)
|
||||
# can't assure to be 0: depends on the initialisation request / if a template is used with the model
|
||||
assert cached_tokens < self.min_cached
|
||||
response = self.run_decode()
|
||||
cached_tokens = int(response.json()["meta_info"]["cached_tokens"])
|
||||
print(f"sglang second request cached_tokens: {cached_tokens}")
|
||||
print(
|
||||
f"sglang second request prompt_tokens: {int(response.json()['meta_info']['prompt_tokens'])}"
|
||||
)
|
||||
assert cached_tokens == int(response.json()["meta_info"]["prompt_tokens"]) - 1
|
||||
|
||||
def test_cache_split_prefill_openai(self):
|
||||
print("=" * 100)
|
||||
self.cache_report_openai(
|
||||
"€ This is a very long and unique text that should not be already cached, the twist is"
|
||||
" that it should be longer than the chunked-prefill-size, so it should be split among"
|
||||
" several prefill requests. Still, it shouldn't be cached"
|
||||
)
|
||||
|
||||
def test_cache_report_openai(self):
|
||||
print("=" * 100)
|
||||
# warm up the cache, for the template
|
||||
self.run_openai("Introduce the capital of France.")
|
||||
|
||||
first_cached_tokens_1 = self.run_openai(
|
||||
"How many sparrow do you need to lift a coconut?"
|
||||
).usage.prompt_tokens_details.cached_tokens
|
||||
|
||||
usage_2 = self.run_openai("* sing something about cats").usage
|
||||
first_cached_tokens_2 = usage_2.prompt_tokens_details.cached_tokens
|
||||
# first request may not have 0 cached tokens, but if they only have the template in common they
|
||||
# should be the same once the cache is warmed up
|
||||
assert first_cached_tokens_1 == first_cached_tokens_2
|
||||
|
||||
resp = self.run_openai("* sing something about cats and dogs")
|
||||
print(resp.usage)
|
||||
|
||||
resp = self.run_openai("* sing something about cats, please")
|
||||
print(resp.usage)
|
||||
assert (
|
||||
resp.usage.prompt_tokens_details.cached_tokens
|
||||
>= usage_2.prompt_tokens - self.min_cached
|
||||
)
|
||||
|
||||
# TODO: flaky test
|
||||
# def test_cache_report_openai_async(self):
|
||||
# print("=" * 100)
|
||||
|
||||
# async def run_test():
|
||||
# task0 = asyncio.create_task(
|
||||
# self.cache_report_openai_async(
|
||||
# "first request, to start the inference and let the next two request be started in the same batch"
|
||||
# )
|
||||
# )
|
||||
# await asyncio.sleep(1) # to force the first request to be started first
|
||||
# task1 = asyncio.create_task(
|
||||
# self.cache_report_openai_async(
|
||||
# "> can the same batch parallel request use the cache?"
|
||||
# )
|
||||
# )
|
||||
# task2 = asyncio.create_task(
|
||||
# self.cache_report_openai_async(
|
||||
# "> can the same batch parallel request use the cache?"
|
||||
# )
|
||||
# )
|
||||
# result0, result1, result2 = await asyncio.gather(task0, task1, task2)
|
||||
|
||||
# cached_tokens0, prompt_tokens0 = result0
|
||||
# cached_tokens1, prompt_tokens1 = result1
|
||||
# cached_tokens2, prompt_tokens2 = result2
|
||||
|
||||
# print(
|
||||
# f"Async request 0 - Cached tokens: {cached_tokens0}, Prompt tokens: {prompt_tokens0}"
|
||||
# )
|
||||
# print(
|
||||
# f"Async request 1 - Cached tokens: {cached_tokens1}, Prompt tokens: {prompt_tokens1}"
|
||||
# )
|
||||
# print(
|
||||
# f"Async request 2 - Cached tokens: {cached_tokens2}, Prompt tokens: {prompt_tokens2}"
|
||||
# )
|
||||
|
||||
# # Assert that no requests used the cache (because first is alone, and the next two are in the same batch)
|
||||
# # If a new optimisation limiting starting request with same prefix at the same time was added
|
||||
# # to maximise the cache hit, this would not be true
|
||||
# assert cached_tokens1 == cached_tokens2 == cached_tokens0
|
||||
|
||||
# asyncio.run(run_test())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
243
test/srt/openai_server/features/test_enable_thinking.py
Normal file
243
test/srt/openai_server/features/test_enable_thinking.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_with_reasoning
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_without_reasoning
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_with_reasoning
|
||||
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_without_reasoning
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestEnableThinking(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-1234"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[
|
||||
"--reasoning-parser",
|
||||
"qwen3",
|
||||
],
|
||||
)
|
||||
cls.additional_chat_kwargs = {}
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_chat_completion_with_reasoning(self):
|
||||
# Test non-streaming with "enable_thinking": True, reasoning_content should not be empty
|
||||
client = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"chat_template_kwargs": {"enable_thinking": True},
|
||||
**self.additional_chat_kwargs,
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
|
||||
data = client.json()
|
||||
|
||||
self.assertIn("choices", data)
|
||||
self.assertTrue(len(data["choices"]) > 0)
|
||||
self.assertIn("message", data["choices"][0])
|
||||
self.assertIn("reasoning_content", data["choices"][0]["message"])
|
||||
self.assertIsNotNone(data["choices"][0]["message"]["reasoning_content"])
|
||||
|
||||
def test_chat_completion_without_reasoning(self):
|
||||
# Test non-streaming with "enable_thinking": False, reasoning_content should be empty
|
||||
client = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
**self.additional_chat_kwargs,
|
||||
},
|
||||
)
|
||||
|
||||
self.assertEqual(client.status_code, 200, f"Failed with: {client.text}")
|
||||
data = client.json()
|
||||
|
||||
self.assertIn("choices", data)
|
||||
self.assertTrue(len(data["choices"]) > 0)
|
||||
self.assertIn("message", data["choices"][0])
|
||||
|
||||
if "reasoning_content" in data["choices"][0]["message"]:
|
||||
self.assertIsNone(data["choices"][0]["message"]["reasoning_content"])
|
||||
|
||||
def test_stream_chat_completion_with_reasoning(self):
|
||||
# Test streaming with "enable_thinking": True, reasoning_content should not be empty
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"stream": True,
|
||||
"chat_template_kwargs": {"enable_thinking": True},
|
||||
**self.additional_chat_kwargs,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
|
||||
|
||||
has_reasoning = False
|
||||
has_content = False
|
||||
|
||||
print("\n=== Stream With Reasoning ===")
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.decode("utf-8")
|
||||
if line.startswith("data:") and not line.startswith("data: [DONE]"):
|
||||
data = json.loads(line[6:])
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
delta = data["choices"][0].get("delta", {})
|
||||
|
||||
if "reasoning_content" in delta and delta["reasoning_content"]:
|
||||
has_reasoning = True
|
||||
|
||||
if "content" in delta and delta["content"]:
|
||||
has_content = True
|
||||
|
||||
self.assertTrue(
|
||||
has_reasoning,
|
||||
"The reasoning content is not included in the stream response",
|
||||
)
|
||||
self.assertTrue(
|
||||
has_content, "The stream response does not contain normal content"
|
||||
)
|
||||
|
||||
def test_stream_chat_completion_without_reasoning(self):
|
||||
# Test streaming with "enable_thinking": False, reasoning_content should be empty
|
||||
response = requests.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"temperature": 0,
|
||||
"separate_reasoning": True,
|
||||
"stream": True,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
**self.additional_chat_kwargs,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
|
||||
|
||||
has_reasoning = False
|
||||
has_content = False
|
||||
|
||||
print("\n=== Stream Without Reasoning ===")
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.decode("utf-8")
|
||||
if line.startswith("data:") and not line.startswith("data: [DONE]"):
|
||||
data = json.loads(line[6:])
|
||||
if "choices" in data and len(data["choices"]) > 0:
|
||||
delta = data["choices"][0].get("delta", {})
|
||||
|
||||
if "reasoning_content" in delta and delta["reasoning_content"]:
|
||||
has_reasoning = True
|
||||
|
||||
if "content" in delta and delta["content"]:
|
||||
has_content = True
|
||||
|
||||
self.assertFalse(
|
||||
has_reasoning,
|
||||
"The reasoning content should not be included in the stream response",
|
||||
)
|
||||
self.assertTrue(
|
||||
has_content, "The stream response does not contain normal content"
|
||||
)
|
||||
|
||||
|
||||
# Skip for ci test
|
||||
# class TestGLM45EnableThinking(TestEnableThinking):
|
||||
# @classmethod
|
||||
# def setUpClass(cls):
|
||||
# # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
# cls.model = "THUDM/GLM-4.5"
|
||||
# cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
# cls.api_key = "sk-1234"
|
||||
# cls.process = popen_launch_server(
|
||||
# cls.model,
|
||||
# cls.base_url,
|
||||
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
# api_key=cls.api_key,
|
||||
# other_args=[
|
||||
# "--tool-call-parser",
|
||||
# "glm45",
|
||||
# "--reasoning-parser",
|
||||
# "glm45",
|
||||
# "--tp-size",
|
||||
# "8"
|
||||
# ],
|
||||
# )
|
||||
|
||||
# # Validate whether enable-thinking conflict with tool_calls
|
||||
# cls.additional_chat_kwargs = {
|
||||
# "tools": [
|
||||
# {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "add",
|
||||
# "description": "Compute the sum of two numbers",
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "a": {
|
||||
# "type": "int",
|
||||
# "description": "A number",
|
||||
# },
|
||||
# "b": {
|
||||
# "type": "int",
|
||||
# "description": "A number",
|
||||
# },
|
||||
# },
|
||||
# "required": ["a", "b"],
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
153
test/srt/openai_server/features/test_json_constrained.py
Normal file
153
test/srt/openai_server/features/test_json_constrained.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
|
||||
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
|
||||
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def setup_class(cls, backend: str):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.json_schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "pattern": "^[\\w]+$"},
|
||||
"population": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
)
|
||||
|
||||
other_args = [
|
||||
"--max-running-requests",
|
||||
"10",
|
||||
"--grammar-backend",
|
||||
backend,
|
||||
]
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
)
|
||||
|
||||
|
||||
class TestJSONConstrainedOutlinesBackend(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, backend="outlines")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": 128,
|
||||
"n": n,
|
||||
"stop_token_ids": [119690],
|
||||
"json_schema": json_schema,
|
||||
},
|
||||
"stream": False,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
ret = response.json()
|
||||
print(json.dumps(ret))
|
||||
print("=" * 100)
|
||||
|
||||
if not json_schema or json_schema == "INVALID":
|
||||
return
|
||||
|
||||
# Make sure the json output is valid
|
||||
try:
|
||||
js_obj = json.loads(ret["text"])
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
raise
|
||||
|
||||
self.assertIsInstance(js_obj["name"], str)
|
||||
self.assertIsInstance(js_obj["population"], int)
|
||||
|
||||
def test_json_generate(self):
|
||||
self.run_decode(json_schema=self.json_schema)
|
||||
|
||||
def test_json_invalid(self):
|
||||
self.run_decode(json_schema="INVALID")
|
||||
|
||||
def test_json_openai(self):
|
||||
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Introduce the capital of France. Return in a JSON format.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
|
||||
},
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
print("JSONDecodeError", text)
|
||||
raise
|
||||
|
||||
self.assertIsInstance(js_obj["name"], str)
|
||||
self.assertIsInstance(js_obj["population"], int)
|
||||
|
||||
def test_mix_json_and_other(self):
|
||||
json_schemas = [None, None, self.json_schema, self.json_schema] * 10
|
||||
|
||||
with ThreadPoolExecutor(len(json_schemas)) as executor:
|
||||
list(executor.map(self.run_decode, json_schemas))
|
||||
|
||||
|
||||
class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, backend="xgrammar")
|
||||
|
||||
|
||||
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, backend="llguidance")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user