From 7a020e0f3b36c0872db39b3cfa563ce139ef9633 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 20 Oct 2025 01:17:00 +0800 Subject: [PATCH] [Test] Add basic matched stop for beta eagle (#11833) --- python/sglang/test/kit_matched_stop.py | 157 ++++++++++++++++++ python/sglang/test/test_utils.py | 3 + .../validation/test_matched_stop.py | 145 +--------------- test/srt/test_eagle_infer_beta.py | 118 +++++-------- 4 files changed, 201 insertions(+), 222 deletions(-) create mode 100644 python/sglang/test/kit_matched_stop.py diff --git a/python/sglang/test/kit_matched_stop.py b/python/sglang/test/kit_matched_stop.py new file mode 100644 index 000000000..afccc6779 --- /dev/null +++ b/python/sglang/test/kit_matched_stop.py @@ -0,0 +1,157 @@ +import json + +import requests + +MANY_NEW_TOKENS_PROMPT = """ +Please write an extremely detailed and vivid fantasy story, set in a world full of intricate magic systems, political intrigue, and complex characters. +Ensure that you thoroughly describe every scene, character's motivations, and the environment. Include long, engaging dialogues and elaborate on the inner thoughts of the characters. +Each section should be as comprehensive as possible to create a rich and immersive experience for the reader. +The story should span multiple events, challenges, and character developments over time. Aim to make the story at least 3,000 words long. +""" + + +class MatchedStopMixin: + def _run_completions_generation( + self, + prompt=MANY_NEW_TOKENS_PROMPT, + max_tokens=1, + stop=None, + stop_regex=None, + finish_reason=None, + matched_stop=None, + ): + payload = { + "prompt": prompt, + "model": self.model, + "temperature": 0, + "top_p": 1, + "max_tokens": max_tokens, + } + + if stop is not None: + payload["stop"] = stop + + if stop_regex is not None: + payload["stop_regex"] = stop_regex + + response_completions = requests.post( + self.base_url + "/v1/completions", + json=payload, + ) + res = response_completions.json() + print(json.dumps(res)) + print("=" * 100) + + if not isinstance(matched_stop, list): + matched_stop = [matched_stop] + + assert ( + res["choices"][0]["finish_reason"] == finish_reason + ), f"Expected finish_reason: {finish_reason}, but got: {res['choices'][0]['finish_reason']}" + assert ( + res["choices"][0]["matched_stop"] in matched_stop + ), f"Expected matched_stop: {matched_stop}, but got: {res['choices'][0]['matched_stop']}" + + def _run_chat_completions_generation( + self, + prompt=MANY_NEW_TOKENS_PROMPT, + max_tokens=1, + stop=None, + stop_regex=None, + finish_reason=None, + matched_stop=None, + ): + chat_payload = { + "model": self.model, + "messages": [ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": prompt}, + ], + "temperature": 0, + "top_p": 1, + "max_tokens": max_tokens, + } + + if stop is not None: + chat_payload["stop"] = stop + + if stop_regex is not None: + chat_payload["stop_regex"] = stop_regex + + response_chat = requests.post( + self.base_url + "/v1/chat/completions", + json=chat_payload, + ) + res = response_chat.json() + print(json.dumps(res)) + print("=" * 100) + + if not isinstance(matched_stop, list): + matched_stop = [matched_stop] + + assert ( + res["choices"][0]["finish_reason"] == finish_reason + ), f"Expected finish_reason: {finish_reason}, but got: {res['choices'][0]['finish_reason']}" + assert ( + res["choices"][0]["matched_stop"] in matched_stop + ), f"Expected matched_stop: {matched_stop}, but got: {res['choices'][0]['matched_stop']}" + + def test_finish_stop_str(self): + self._run_completions_generation( + max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n" + ) + self._run_chat_completions_generation( + max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n" + ) + + def test_finish_stop_regex_str(self): + STOP_REGEX_STR = r"and|or" + self._run_completions_generation( + max_tokens=1000, + stop_regex=STOP_REGEX_STR, + finish_reason="stop", + matched_stop=STOP_REGEX_STR, + ) + self._run_chat_completions_generation( + max_tokens=1000, + stop_regex=STOP_REGEX_STR, + finish_reason="stop", + matched_stop=STOP_REGEX_STR, + ) + + # Match a complete sentence + STOP_REGEX_STR_SENTENCE = r"[.!?]\s*$" + self._run_chat_completions_generation( + max_tokens=1000, + stop_regex=STOP_REGEX_STR_SENTENCE, + finish_reason="stop", + matched_stop=STOP_REGEX_STR_SENTENCE, + ) + + def test_finish_stop_eos(self): + llama_format_prompt = """\ +<|begin_of_text|><|start_header_id|>system<|end_header_id|> +You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|> +What is 2 + 2?<|eot_id|><|start_header_id|>assistant<|end_header_id|> + """ + eos_token_ids = [128000, 128009, 2] + self._run_completions_generation( + prompt=llama_format_prompt, + max_tokens=1000, + finish_reason="stop", + matched_stop=eos_token_ids, + ) + self._run_chat_completions_generation( + prompt="What is 2 + 2?", + max_tokens=1000, + finish_reason="stop", + matched_stop=eos_token_ids, + ) + + def test_finish_length(self): + self._run_completions_generation( + max_tokens=5, finish_reason="length", matched_stop=None + ) + self._run_chat_completions_generation( + max_tokens=5, finish_reason="length", matched_stop=None + ) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 2e75909e9..a38df4962 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1622,6 +1622,9 @@ class CustomTestCase(unittest.TestCase): max_retry=max_retry, ) + def setUp(self): + print(f"[Test Method] {self._testMethodName}", flush=True) + def dump_bench_raw_result( path: str, diff --git a/test/srt/openai_server/validation/test_matched_stop.py b/test/srt/openai_server/validation/test_matched_stop.py index 5c264853a..12496dbb6 100644 --- a/test/srt/openai_server/validation/test_matched_stop.py +++ b/test/srt/openai_server/validation/test_matched_stop.py @@ -1,10 +1,8 @@ -import json import unittest -import requests - from sglang.srt.sampling.sampling_params import MAX_LEN, get_max_seq_length from sglang.srt.utils import kill_process_tree +from sglang.test.kit_matched_stop import MatchedStopMixin from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_URL_FOR_TEST, @@ -12,15 +10,8 @@ from sglang.test.test_utils import ( popen_launch_server, ) -MANY_NEW_TOKENS_PROMPT = """ -Please write an extremely detailed and vivid fantasy story, set in a world full of intricate magic systems, political intrigue, and complex characters. -Ensure that you thoroughly describe every scene, character's motivations, and the environment. Include long, engaging dialogues and elaborate on the inner thoughts of the characters. -Each section should be as comprehensive as possible to create a rich and immersive experience for the reader. -The story should span multiple events, challenges, and character developments over time. Aim to make the story at least 3,000 words long. -""" - -class TestMatchedStop(CustomTestCase): +class TestMatchedStop(CustomTestCase, MatchedStopMixin): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST @@ -36,138 +27,6 @@ class TestMatchedStop(CustomTestCase): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def run_completions_generation( - self, - prompt=MANY_NEW_TOKENS_PROMPT, - max_tokens=1, - stop=None, - stop_regex=None, - finish_reason=None, - matched_stop=None, - ): - payload = { - "prompt": prompt, - "model": self.model, - "temperature": 0, - "top_p": 1, - "max_tokens": max_tokens, - } - - if stop is not None: - payload["stop"] = stop - - if stop_regex is not None: - payload["stop_regex"] = stop_regex - - response_completions = requests.post( - self.base_url + "/v1/completions", - json=payload, - ) - print(json.dumps(response_completions.json())) - print("=" * 100) - - assert ( - response_completions.json()["choices"][0]["finish_reason"] == finish_reason - ) - assert response_completions.json()["choices"][0]["matched_stop"] == matched_stop - - def run_chat_completions_generation( - self, - prompt=MANY_NEW_TOKENS_PROMPT, - max_tokens=1, - stop=None, - stop_regex=None, - finish_reason=None, - matched_stop=None, - ): - chat_payload = { - "model": self.model, - "messages": [ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": prompt}, - ], - "temperature": 0, - "top_p": 1, - "max_tokens": max_tokens, - } - - if stop is not None: - chat_payload["stop"] = stop - - if stop_regex is not None: - chat_payload["stop_regex"] = stop_regex - - response_chat = requests.post( - self.base_url + "/v1/chat/completions", - json=chat_payload, - ) - print(json.dumps(response_chat.json())) - print("=" * 100) - - assert response_chat.json()["choices"][0]["finish_reason"] == finish_reason - assert response_chat.json()["choices"][0]["matched_stop"] == matched_stop - - def test_finish_stop_str(self): - self.run_completions_generation( - max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n" - ) - self.run_chat_completions_generation( - max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n" - ) - - def test_finish_stop_regex_str(self): - STOP_REGEX_STR = r"and|or" - self.run_completions_generation( - max_tokens=1000, - stop_regex=STOP_REGEX_STR, - finish_reason="stop", - matched_stop=STOP_REGEX_STR, - ) - self.run_chat_completions_generation( - max_tokens=1000, - stop_regex=STOP_REGEX_STR, - finish_reason="stop", - matched_stop=STOP_REGEX_STR, - ) - - # Match a complete sentence - STOP_REGEX_STR_SENTENCE = r"[.!?]\s*$" - self.run_chat_completions_generation( - max_tokens=1000, - stop_regex=STOP_REGEX_STR_SENTENCE, - finish_reason="stop", - matched_stop=STOP_REGEX_STR_SENTENCE, - ) - - def test_finish_stop_eos(self): - llama_format_prompt = """ - <|begin_of_text|><|start_header_id|>system<|end_header_id|> - You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|> - - What is 2 + 2?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - """ - eos_token_id = 128009 - self.run_completions_generation( - prompt=llama_format_prompt, - max_tokens=1000, - finish_reason="stop", - matched_stop=eos_token_id, - ) - self.run_chat_completions_generation( - prompt="What is 2 + 2?", - max_tokens=1000, - finish_reason="stop", - matched_stop=eos_token_id, - ) - - def test_finish_length(self): - self.run_completions_generation( - max_tokens=5, finish_reason="length", matched_stop=None - ) - self.run_chat_completions_generation( - max_tokens=5, finish_reason="length", matched_stop=None - ) - class TestRegexPatternMaxLength(unittest.TestCase): @classmethod diff --git a/test/srt/test_eagle_infer_beta.py b/test/srt/test_eagle_infer_beta.py index fe7f18010..f0a99df36 100644 --- a/test/srt/test_eagle_infer_beta.py +++ b/test/srt/test_eagle_infer_beta.py @@ -3,7 +3,10 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.kit_matched_stop import MatchedStopMixin from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, @@ -11,93 +14,50 @@ from sglang.test.test_utils import ( ) -class TestEagleBS1(CustomTestCase): - num_questions = 60 - - @classmethod - def setUpClass(cls): - cls.model = "meta-llama/Llama-2-7b-chat-hf" - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--attention-backend", - "triton", - "--enable-beta-spec", - "--speculative-algorithm", - "EAGLE", - "--speculative-draft-model", - "lmzheng/sglang-EAGLE-llama2-chat-7B", - "--speculative-num-steps", - "5", - "--speculative-eagle-topk", - "1", - "--speculative-num-draft-tokens", - "6", - "--max-running-requests", - "1", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=self.num_questions, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval(args) - print(f"TestEagleBS1 -- {metrics=}") - self.assertGreater( - metrics["accuracy"], 0.33 - ) # 0.3333 for 60 questions; 0.234 for 1319 questions - - -class TestEagleLargeBS(CustomTestCase): - num_questions = 10000 +class TestEagleServerBase(CustomTestCase, MatchedStopMixin): max_running_requests = 64 - other_args = [ - "--trust-remote-code", - "--attention-backend", - "triton", - "--enable-beta-spec", - "--speculative-algorithm", - "EAGLE", - "--speculative-draft-model", - "lmzheng/sglang-EAGLE-llama2-chat-7B", - "--speculative-num-steps", - "5", - "--speculative-eagle-topk", - "1", - "--speculative-num-draft-tokens", - "6", - "--mem-fraction-static", - "0.75", - "--max-running-requests", - str(max_running_requests), - "--cuda-graph-bs", - *[str(i) for i in range(1, max_running_requests + 1)], - ] + attention_backend = "triton" + spec_steps = 5 + spec_topk = 1 + spec_draft_tokens = 6 + page_size = 1 + other_launch_args = [] + model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST + draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST @classmethod def setUpClass(cls): - cls.model = "meta-llama/Llama-2-7b-chat-hf" cls.base_url = DEFAULT_URL_FOR_TEST + launch_args = [ + "--enable-beta-spec", + "--trust-remote-code", + "--attention-backend", + cls.attention_backend, + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model", + cls.draft_model, + "--speculative-num-steps", + cls.spec_steps, + "--speculative-eagle-topk", + cls.spec_topk, + "--speculative-num-draft-tokens", + cls.spec_draft_tokens, + "--page-size", + str(cls.page_size), + "--mem-fraction-static", + "0.75", + "--max-running-requests", + str(cls.max_running_requests), + "--cuda-graph-bs", + *[str(i) for i in range(1, cls.max_running_requests + 1)], + ] + launch_args.extend(cls.other_launch_args) cls.process = popen_launch_server( cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=cls.other_args, + other_args=launch_args, ) @classmethod @@ -108,7 +68,7 @@ class TestEagleLargeBS(CustomTestCase): args = SimpleNamespace( num_shots=5, data_path=None, - num_questions=self.num_questions, + num_questions=1000, max_new_tokens=512, parallel=128, host="http://127.0.0.1",