minor: cleanup test_eagle_infer (#3415)

This commit is contained in:
Yineng Zhang
2025-02-09 09:34:30 +08:00
committed by GitHub
parent 7b4e61fff3
commit 60abdb3e7c

View File

@@ -20,79 +20,78 @@ from sglang.test.test_utils import (
class TestEAGLEEngine(unittest.TestCase): class TestEAGLEEngine(unittest.TestCase):
BASE_CONFIG = {
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"mem_fraction_static": 0.7,
}
def test_eagle_accuracy(self): def setUp(self):
prompt1 = "Today is a sunny day and I like" self.prompt = "Today is a sunny day and I like"
sampling_params1 = {"temperature": 0, "max_new_tokens": 8} self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
# Get the reference output
ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
ref_output = ref_engine.generate(prompt1, sampling_params1)["text"] self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown() ref_engine.shutdown()
# Test cases with different configurations def test_eagle_accuracy(self):
configs = [ configs = [
# Original config self.BASE_CONFIG,
{ {**self.BASE_CONFIG, "disable_cuda_graph": True},
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"mem_fraction_static": 0.7,
},
# Config with CUDA graph disabled
{
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"mem_fraction_static": 0.7,
"disable_cuda_graph": True,
},
] ]
for config in configs: for config in configs:
# Launch EAGLE engine with self.subTest(
engine = sgl.Engine(**config) cuda_graph=(
"enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
)
):
engine = sgl.Engine(**config)
try:
self._test_basic_generation(engine)
self._test_eos_token(engine)
self._test_batch_generation(engine)
finally:
engine.shutdown()
# Case 1: Test the output of EAGLE engine is the same as normal engine def _test_basic_generation(self, engine):
out1 = engine.generate(prompt1, sampling_params1)["text"] output = engine.generate(self.prompt, self.sampling_params)["text"]
print(f"{out1=}, {ref_output=}") print(f"{output=}, {self.ref_output=}")
self.assertEqual(out1, ref_output) self.assertEqual(output, self.ref_output)
# Case 2: Test the output of EAGLE engine does not contain unexpected EOS def _test_eos_token(self, engine):
prompt2 = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]" prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
sampling_params2 = { params = {
"temperature": 0, "temperature": 0,
"max_new_tokens": 1024, "max_new_tokens": 1024,
"skip_special_tokens": False, "skip_special_tokens": False,
} }
tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
out2 = engine.generate(prompt2, sampling_params2)["text"] output = engine.generate(prompt, params)["text"]
print(f"{out2=}") print(f"{output=}")
tokens = tokenizer.encode(out2, truncation=False)
assert tokenizer.eos_token_id not in tokens
# Case 3: Batched prompts tokens = tokenizer.encode(output, truncation=False)
prompts = [ self.assertNotIn(tokenizer.eos_token_id, tokens)
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params3 = {"temperature": 0, "max_new_tokens": 30}
outputs = engine.generate(prompts, sampling_params3)
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
# Shutdown the engine def _test_batch_generation(self, engine):
engine.shutdown() prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
params = {"temperature": 0, "max_new_tokens": 30}
outputs = engine.generate(prompts, params)
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)
prompts = [ prompts = [