diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index b04b13211..4a6170320 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -20,79 +20,78 @@ from sglang.test.test_utils import (
class TestEAGLEEngine(unittest.TestCase):
+ BASE_CONFIG = {
+ "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+ "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+ "speculative_algorithm": "EAGLE",
+ "speculative_num_steps": 5,
+ "speculative_eagle_topk": 8,
+ "speculative_num_draft_tokens": 64,
+ "mem_fraction_static": 0.7,
+ }
- def test_eagle_accuracy(self):
- prompt1 = "Today is a sunny day and I like"
- sampling_params1 = {"temperature": 0, "max_new_tokens": 8}
+ def setUp(self):
+ self.prompt = "Today is a sunny day and I like"
+ self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
- # Get the reference output
ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
- ref_output = ref_engine.generate(prompt1, sampling_params1)["text"]
+ self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown()
- # Test cases with different configurations
+ def test_eagle_accuracy(self):
configs = [
- # Original config
- {
- "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
- "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
- "speculative_algorithm": "EAGLE",
- "speculative_num_steps": 5,
- "speculative_eagle_topk": 8,
- "speculative_num_draft_tokens": 64,
- "mem_fraction_static": 0.7,
- },
- # Config with CUDA graph disabled
- {
- "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
- "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
- "speculative_algorithm": "EAGLE",
- "speculative_num_steps": 5,
- "speculative_eagle_topk": 8,
- "speculative_num_draft_tokens": 64,
- "mem_fraction_static": 0.7,
- "disable_cuda_graph": True,
- },
+ self.BASE_CONFIG,
+ {**self.BASE_CONFIG, "disable_cuda_graph": True},
]
for config in configs:
- # Launch EAGLE engine
- engine = sgl.Engine(**config)
+ with self.subTest(
+ cuda_graph=(
+ "enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
+ )
+ ):
+ engine = sgl.Engine(**config)
+ try:
+ self._test_basic_generation(engine)
+ self._test_eos_token(engine)
+ self._test_batch_generation(engine)
+ finally:
+ engine.shutdown()
- # Case 1: Test the output of EAGLE engine is the same as normal engine
- out1 = engine.generate(prompt1, sampling_params1)["text"]
- print(f"{out1=}, {ref_output=}")
- self.assertEqual(out1, ref_output)
+ def _test_basic_generation(self, engine):
+ output = engine.generate(self.prompt, self.sampling_params)["text"]
+ print(f"{output=}, {self.ref_output=}")
+ self.assertEqual(output, self.ref_output)
- # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
- prompt2 = "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like [/INST]"
- sampling_params2 = {
- "temperature": 0,
- "max_new_tokens": 1024,
- "skip_special_tokens": False,
- }
+ def _test_eos_token(self, engine):
+ prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]"
+ params = {
+ "temperature": 0,
+ "max_new_tokens": 1024,
+ "skip_special_tokens": False,
+ }
- tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
- out2 = engine.generate(prompt2, sampling_params2)["text"]
- print(f"{out2=}")
- tokens = tokenizer.encode(out2, truncation=False)
- assert tokenizer.eos_token_id not in tokens
+ tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+ output = engine.generate(prompt, params)["text"]
+ print(f"{output=}")
- # Case 3: Batched prompts
- prompts = [
- "Hello, my name is",
- "The president of the United States is",
- "The capital of France is",
- "The future of AI is",
- ]
- sampling_params3 = {"temperature": 0, "max_new_tokens": 30}
- outputs = engine.generate(prompts, sampling_params3)
- for prompt, output in zip(prompts, outputs):
- print("===============================")
- print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
+ tokens = tokenizer.encode(output, truncation=False)
+ self.assertNotIn(tokenizer.eos_token_id, tokens)
- # Shutdown the engine
- engine.shutdown()
+ def _test_batch_generation(self, engine):
+ prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+ ]
+ params = {"temperature": 0, "max_new_tokens": 30}
+
+ outputs = engine.generate(prompts, params)
+ for prompt, output in zip(prompts, outputs):
+ print(f"Prompt: {prompt}")
+ print(f"Generated: {output['text']}")
+ print("-" * 40)
prompts = [
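
Note on the config refactor above: the two hand-copied dicts are replaced by a shared `BASE_CONFIG` plus a `{**base, ...}` override. A minimal sketch of that merge semantics (the keys below are taken from the diff; the variable names are illustrative):

```python
# Dict-unpacking merge: later keys win, and the base dict is not mutated.
base = {"speculative_num_steps": 5, "mem_fraction_static": 0.7}
variant = {**base, "disable_cuda_graph": True}

assert base == {"speculative_num_steps": 5, "mem_fraction_static": 0.7}
assert variant["disable_cuda_graph"] is True
assert len(variant) == len(base) + 1  # why the subTest label can key off len()
```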
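The per-config loop pairs `unittest.TestCase.subTest` (so a failure under one config is reported separately instead of aborting the whole test) with `try/finally` (so the engine is shut down even when an assertion fails). A self-contained sketch of the same pattern, using a hypothetical `FakeEngine` stand-in rather than the real `sgl.Engine`:

```python
import unittest


class FakeEngine:
    """Stand-in for sgl.Engine; only here to illustrate the pattern."""

    def __init__(self, **config):
        self.config = config

    def generate(self, prompt, params):
        return {"text": prompt + " ..."}

    def shutdown(self):
        pass


class TestPattern(unittest.TestCase):
    BASE_CONFIG = {"speculative_num_steps": 5}

    def test_all_configs(self):
        configs = [
            self.BASE_CONFIG,
            {**self.BASE_CONFIG, "disable_cuda_graph": True},
        ]
        for config in configs:
            # subTest: each config gets its own pass/fail in the report.
            with self.subTest(config=config):
                engine = FakeEngine(**config)
                try:
                    out = engine.generate("hi", {"max_new_tokens": 8})["text"]
                    self.assertTrue(out.startswith("hi"))
                finally:
                    # Runs even if the assertion above fails.
                    engine.shutdown()


if __name__ == "__main__":
    unittest.main()
```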