diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 83e1b92c1..6eb175cb1 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -31,7 +31,8 @@ suites = {
TestFile("test_block_int8.py", 22),
TestFile("test_create_kvindices.py", 2),
TestFile("test_chunked_prefill.py", 313),
- TestFile("test_eagle_infer.py", 619),
+ TestFile("test_eagle_infer_a.py", 300),
+ TestFile("test_eagle_infer_b.py", 300),
TestFile("test_ebnf_constrained.py", 108),
TestFile("test_enable_thinking.py", 70),
TestFile("test_embedding_openai_server.py", 141),
diff --git a/test/srt/test_eagle_infer_a.py b/test/srt/test_eagle_infer_a.py
new file mode 100644
index 000000000..298f1073e
--- /dev/null
+++ b/test/srt/test_eagle_infer_a.py
@@ -0,0 +1,333 @@
+import unittest
+
+import requests
+import torch
+
+import sglang as sgl
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+ DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+ DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+ DEFAULT_MODEL_NAME_FOR_TEST_MLA,
+ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
+ is_in_ci,
+ popen_launch_server,
+)
+
+torch_dtype = torch.float16
+prefill_tolerance = 5e-2
+decode_tolerance: float = 5e-2
+
+
+class TestEAGLEEngine(CustomTestCase):
+ BASE_CONFIG = {
+ "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+ "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+ "speculative_algorithm": "EAGLE",
+ "speculative_num_steps": 5,
+ "speculative_eagle_topk": 4,
+ "speculative_num_draft_tokens": 8,
+ "mem_fraction_static": 0.7,
+ "cuda_graph_max_bs": 5,
+ }
+ NUM_CONFIGS = 2
+
+ def setUp(self):
+ self.prompt = "Today is a sunny day and I like"
+ self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+ ref_engine = sgl.Engine(
+ model_path=self.BASE_CONFIG["model_path"], cuda_graph_max_bs=1
+ )
+ self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
+ ref_engine.shutdown()
+
+ def test_correctness(self):
+ configs = [
+ # Basic config
+ self.BASE_CONFIG,
+ # Chunked prefill
+ {**self.BASE_CONFIG, "chunked_prefill_size": 4},
+ ]
+
+ for i, config in enumerate(configs[: self.NUM_CONFIGS]):
+ with self.subTest(i=i):
+ print(f"{config=}")
+ engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
+ try:
+ self._test_single_generation(engine)
+ self._test_batch_generation(engine)
+ self._test_eos_token(engine)
+ self._test_acc_length(engine)
+ finally:
+ engine.shutdown()
+ print("=" * 100)
+
+ def _test_single_generation(self, engine):
+ output = engine.generate(self.prompt, self.sampling_params)["text"]
+ print(f"{output=}, {self.ref_output=}")
+ self.assertEqual(output, self.ref_output)
+
+ def _test_batch_generation(self, engine):
+ prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+ ]
+ params = {"temperature": 0, "max_new_tokens": 50}
+
+ outputs = engine.generate(prompts, params)
+ for prompt, output in zip(prompts, outputs):
+ print(f"Prompt: {prompt}")
+ print(f"Generated: {output['text']}")
+ print("-" * 40)
+
+ print(f"{engine.get_server_info()=}")
+
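+ # avg_spec_accept_length: server-wide running average of tokens accepted
+ # per verification step; values well above 1.0 mean drafts are being accepted.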
+ avg_spec_accept_length = engine.get_server_info()["internal_states"][0][
+ "avg_spec_accept_length"
+ ]
+ print(f"{avg_spec_accept_length=}")
+ self.assertGreater(avg_spec_accept_length, 1.9)
+
+ def _test_eos_token(self, engine):
+ prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]"
+ params = {
+ "temperature": 0.1,
+ "max_new_tokens": 1024,
+ "skip_special_tokens": False,
+ }
+
+ tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+ output = engine.generate(prompt, params)["text"]
+ print(f"{output=}")
+
+ tokens = tokenizer.encode(output, truncation=False)
+ self.assertNotIn(tokenizer.eos_token_id, tokens)
+
+ def _test_acc_length(self, engine):
+ prompt = [
+ "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
+ ] * 5 # test batched generation
+ sampling_params = {"temperature": 0, "max_new_tokens": 512}
+ output = engine.generate(prompt, sampling_params)
+ output = output[0]
+
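+ # Per-request acceptance length: completion_tokens / verification count,
+ # defaulting to 1.0 when the server reports no speculative verification.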
+ if "spec_verify_ct" in output["meta_info"]:
+ acc_length = (
+ output["meta_info"]["completion_tokens"]
+ / output["meta_info"]["spec_verify_ct"]
+ )
+ else:
+ acc_length = 1.0
+
+ speed = (
+ output["meta_info"]["completion_tokens"]
+ / output["meta_info"]["e2e_latency"]
+ )
+ print(f"{acc_length=}")
+
+ if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
+ self.assertGreater(acc_length, 3.6)
+ else:
+ self.assertGreater(acc_length, 2.5)
+
+
+class TestEAGLEEngineTokenMap(TestEAGLEEngine):
+ BASE_CONFIG = {
+ "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
+ "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
+ "speculative_algorithm": "EAGLE",
+ "speculative_num_steps": 5,
+ "speculative_eagle_topk": 4,
+ "speculative_num_draft_tokens": 8,
+ "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
+ "mem_fraction_static": 0.7,
+ "cuda_graph_max_bs": 5,
+ "dtype": "float16",
+ }
+ NUM_CONFIGS = 1
+
+
+class TestEAGLE3Engine(TestEAGLEEngine):
+ BASE_CONFIG = {
+ "model_path": "meta-llama/Llama-3.1-8B-Instruct",
+ "speculative_draft_model_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
+ "speculative_algorithm": "EAGLE3",
+ "speculative_num_steps": 5,
+ "speculative_eagle_topk": 16,
+ "speculative_num_draft_tokens": 64,
+ "mem_fraction_static": 0.7,
+ "cuda_graph_max_bs": 5,
+ "dtype": "float16",
+ }
+ NUM_CONFIGS = 1
+
+
+@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
+class TestEAGLEDraftExtend(CustomTestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ cls.process = popen_launch_server(
+ DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=[
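+ # num_steps=1 / topk=1 / draft_tokens=2 degenerates the draft tree to a
+ # single chain, targeting the draft-extend path (here on the fa3 backend).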
+ "--speculative-algorithm",
+ "EAGLE",
+ "--speculative-draft-model-path",
+ DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+ "--speculative-num-steps",
+ 1,
+ "--speculative-eagle-topk",
+ 1,
+ "--speculative-num-draft-tokens",
+ 2,
+ "--max-running-requests",
+ 4,
+ "--attention-backend",
+ "fa3",
+ ],
+ )
+ cls.accept_len_threshold = 1.50
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_process_tree(cls.process.pid)
+
+ def test_one_batch_accept_length(self):
+ resp = requests.get(self.base_url + "/flush_cache")
+ self.assertEqual(resp.status_code, 200)
+
+ prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+ ]
+ url = self.base_url + "/generate"
+ data = {
+ "text": prompts,
+ "sampling_params": {
+ "temperature": 0,
+ "max_new_tokens": 512,
+ },
+ }
+ response = requests.post(url, json=data)
+ self.assertEqual(response.status_code, 200)
+ outputs = response.json()
+ for i in range(len(prompts)):
+ output = outputs[i]
+ if "spec_verify_ct" in output["meta_info"]:
+ acc_length = (
+ output["meta_info"]["completion_tokens"]
+ / output["meta_info"]["spec_verify_ct"]
+ )
+ else:
+ acc_length = 1.0
+
+ print(f"{acc_length=}")
+ self.assertGreater(acc_length, self.accept_len_threshold)
+
+
+class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
+ @classmethod
+ def setUpClass(cls):
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ cls.process = popen_launch_server(
+ DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=[
+ "--speculative-algorithm",
+ "EAGLE",
+ "--speculative-draft-model-path",
+ DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+ "--speculative-num-steps",
+ 1,
+ "--speculative-eagle-topk",
+ 1,
+ "--speculative-num-draft-tokens",
+ 2,
+ "--max-running-requests",
+ 4,
+ "--attention-backend",
+ "flashinfer",
+ ],
+ )
+ cls.accept_len_threshold = 1.50
+
+
+@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
+class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
+ @classmethod
+ def setUpClass(cls):
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ cls.process = popen_launch_server(
+ DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=[
+ "--speculative-algorithm",
+ "EAGLE",
+ "--speculative-draft-model-path",
+ DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+ "--speculative-num-steps",
+ 1,
+ "--speculative-eagle-topk",
+ 1,
+ "--speculative-num-draft-tokens",
+ 2,
+ "--max-running-requests",
+ 4,
+ "--attention-backend",
+ "triton",
+ ],
+ )
+ cls.accept_len_threshold = 1.50
+
+
+@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
+class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
+ @classmethod
+ def setUpClass(cls):
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ cls.process = popen_launch_server(
+ DEFAULT_MODEL_NAME_FOR_TEST_MLA,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=[
+ "--speculative-algorithm",
+ "EAGLE",
+ "--speculative-num-steps",
+ 1,
+ "--speculative-eagle-topk",
+ 1,
+ "--speculative-num-draft-tokens",
+ 2,
+ "--max-running-requests",
+ 4,
+ "--attention-backend",
+ "flashinfer",
+ ],
+ )
+ cls.accept_len_threshold = 1.85
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer_b.py
similarity index 58%
rename from test/srt/test_eagle_infer.py
rename to test/srt/test_eagle_infer_b.py
index f6dc2cf1f..72a69864f 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer_b.py
@@ -12,18 +12,14 @@ import numpy as np
import requests
import torch
-import sglang as sgl
-from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
- DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
- is_in_ci,
popen_launch_server,
run_logprob_check,
)
@@ -33,152 +29,6 @@ prefill_tolerance = 5e-2
decode_tolerance: float = 5e-2


-class TestEAGLEEngine(CustomTestCase):
- BASE_CONFIG = {
- "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
- "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
- "speculative_algorithm": "EAGLE",
- "speculative_num_steps": 5,
- "speculative_eagle_topk": 4,
- "speculative_num_draft_tokens": 8,
- "mem_fraction_static": 0.7,
- "cuda_graph_max_bs": 5,
- }
- NUM_CONFIGS = 2
-
- def setUp(self):
- self.prompt = "Today is a sunny day and I like"
- self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
-
- ref_engine = sgl.Engine(
- model_path=self.BASE_CONFIG["model_path"], cuda_graph_max_bs=1
- )
- self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
- ref_engine.shutdown()
-
- def test_correctness(self):
- configs = [
- # Basic config
- self.BASE_CONFIG,
- # Chunked prefill
- {**self.BASE_CONFIG, "chunked_prefill_size": 4},
- ]
-
- for i, config in enumerate(configs[: self.NUM_CONFIGS]):
- with self.subTest(i=i):
- print(f"{config=}")
- engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
- try:
- self._test_single_generation(engine)
- self._test_batch_generation(engine)
- self._test_eos_token(engine)
- self._test_acc_length(engine)
- finally:
- engine.shutdown()
- print("=" * 100)
-
- def _test_single_generation(self, engine):
- output = engine.generate(self.prompt, self.sampling_params)["text"]
- print(f"{output=}, {self.ref_output=}")
- self.assertEqual(output, self.ref_output)
-
- def _test_batch_generation(self, engine):
- prompts = [
- "Hello, my name is",
- "The president of the United States is",
- "The capital of France is",
- "The future of AI is",
- ]
- params = {"temperature": 0, "max_new_tokens": 50}
-
- outputs = engine.generate(prompts, params)
- for prompt, output in zip(prompts, outputs):
- print(f"Prompt: {prompt}")
- print(f"Generated: {output['text']}")
- print("-" * 40)
-
- print(f"{engine.get_server_info()=}")
-
- avg_spec_accept_length = engine.get_server_info()["internal_states"][0][
- "avg_spec_accept_length"
- ]
- print(f"{avg_spec_accept_length=}")
- self.assertGreater(avg_spec_accept_length, 1.9)
-
- def _test_eos_token(self, engine):
- prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]"
- params = {
- "temperature": 0.1,
- "max_new_tokens": 1024,
- "skip_special_tokens": False,
- }
-
- tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
- output = engine.generate(prompt, params)["text"]
- print(f"{output=}")
-
- tokens = tokenizer.encode(output, truncation=False)
- self.assertNotIn(tokenizer.eos_token_id, tokens)
-
- def _test_acc_length(self, engine):
- prompt = [
- "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
- ] * 5 # test batched generation
- sampling_params = {"temperature": 0, "max_new_tokens": 512}
- output = engine.generate(prompt, sampling_params)
- output = output[0]
-
- if "spec_verify_ct" in output["meta_info"]:
- acc_length = (
- output["meta_info"]["completion_tokens"]
- / output["meta_info"]["spec_verify_ct"]
- )
- else:
- acc_length = 1.0
-
- speed = (
- output["meta_info"]["completion_tokens"]
- / output["meta_info"]["e2e_latency"]
- )
- print(f"{acc_length=}")
-
- if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
- self.assertGreater(acc_length, 3.6)
- else:
- self.assertGreater(acc_length, 2.5)
-
-
-class TestEAGLEEngineTokenMap(TestEAGLEEngine):
- BASE_CONFIG = {
- "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
- "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
- "speculative_algorithm": "EAGLE",
- "speculative_num_steps": 5,
- "speculative_eagle_topk": 4,
- "speculative_num_draft_tokens": 8,
- "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
- "mem_fraction_static": 0.7,
- "cuda_graph_max_bs": 5,
- "dtype": "float16",
- }
- NUM_CONFIGS = 1
-
-
-class TestEAGLE3Engine(TestEAGLEEngine):
- BASE_CONFIG = {
- "model_path": "meta-llama/Llama-3.1-8B-Instruct",
- "speculative_draft_model_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
- "speculative_algorithm": "EAGLE3",
- "speculative_num_steps": 5,
- "speculative_eagle_topk": 16,
- "speculative_num_draft_tokens": 64,
- "mem_fraction_static": 0.7,
- "cuda_graph_max_bs": 5,
- "dtype": "float16",
- }
- NUM_CONFIGS = 1
-
-
class TestEAGLEServer(CustomTestCase):
PROMPTS = [
"[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]"
@@ -579,156 +429,5 @@ class TestEAGLEServerTriton(TestEAGLEServer):
)


-@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
-class TestEAGLEDraftExtend(CustomTestCase):
- @classmethod
- def setUpClass(cls):
- cls.base_url = DEFAULT_URL_FOR_TEST
- cls.process = popen_launch_server(
- DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
- cls.base_url,
- timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
- other_args=[
- "--speculative-algorithm",
- "EAGLE",
- "--speculative-draft-model-path",
- DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
- "--speculative-num-steps",
- 1,
- "--speculative-eagle-topk",
- 1,
- "--speculative-num-draft-tokens",
- 2,
- "--max-running-requests",
- 4,
- "--attention-backend",
- "fa3",
- ],
- )
- cls.accept_len_threshold = 1.50
-
- @classmethod
- def tearDownClass(cls):
- kill_process_tree(cls.process.pid)
-
- def test_one_batch_accept_length(self):
- resp = requests.get(self.base_url + "/flush_cache")
- self.assertEqual(resp.status_code, 200)
-
- prompts = [
- "Hello, my name is",
- "The president of the United States is",
- "The capital of France is",
- "The future of AI is",
- ]
- url = self.base_url + "/generate"
- data = {
- "text": prompts,
- "sampling_params": {
- "temperature": 0,
- "max_new_tokens": 512,
- },
- }
- response = requests.post(url, json=data)
- self.assertEqual(response.status_code, 200)
- outputs = response.json()
- for i in range(len(prompts)):
- output = outputs[i]
- if "spec_verify_ct" in output["meta_info"]:
- acc_length = (
- output["meta_info"]["completion_tokens"]
- / output["meta_info"]["spec_verify_ct"]
- )
- else:
- acc_length = 1.0
-
- print(f"{acc_length=}")
- self.assertGreater(acc_length, self.accept_len_threshold)
-
-
-class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
- @classmethod
- def setUpClass(cls):
- cls.base_url = DEFAULT_URL_FOR_TEST
- cls.process = popen_launch_server(
- DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
- cls.base_url,
- timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
- other_args=[
- "--speculative-algorithm",
- "EAGLE",
- "--speculative-draft-model-path",
- DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
- "--speculative-num-steps",
- 1,
- "--speculative-eagle-topk",
- 1,
- "--speculative-num-draft-tokens",
- 2,
- "--max-running-requests",
- 4,
- "--attention-backend",
- "flashinfer",
- ],
- )
- cls.accept_len_threshold = 1.50
-
-
-@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
-class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
- @classmethod
- def setUpClass(cls):
- cls.base_url = DEFAULT_URL_FOR_TEST
- cls.process = popen_launch_server(
- DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
- cls.base_url,
- timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
- other_args=[
- "--speculative-algorithm",
- "EAGLE",
- "--speculative-draft-model-path",
- DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
- "--speculative-num-steps",
- 1,
- "--speculative-eagle-topk",
- 1,
- "--speculative-num-draft-tokens",
- 2,
- "--max-running-requests",
- 4,
- "--attention-backend",
- "triton",
- ],
- )
- cls.accept_len_threshold = 1.50
-
-
-@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
-class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
- @classmethod
- def setUpClass(cls):
- cls.base_url = DEFAULT_URL_FOR_TEST
- cls.process = popen_launch_server(
- DEFAULT_MODEL_NAME_FOR_TEST_MLA,
- cls.base_url,
- timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
- other_args=[
- "--speculative-algorithm",
- "EAGLE",
- "--speculative-num-steps",
- 1,
- "--speculative-eagle-topk",
- 1,
- "--speculative-num-draft-tokens",
- 2,
- "--max-running-requests",
- 4,
- "--attention-backend",
- "flashinfer",
- ],
- )
- cls.accept_len_threshold = 1.85
-
-
if __name__ == "__main__":
unittest.main()