minor: cleanup test_eagle_infer (#3415)

This commit is contained in:
Yineng Zhang
2025-02-09 09:34:30 +08:00
committed by GitHub
parent 7b4e61fff3
commit 60abdb3e7c

View File

@@ -20,79 +20,78 @@ from sglang.test.test_utils import (
class TestEAGLEEngine(unittest.TestCase): class TestEAGLEEngine(unittest.TestCase):
BASE_CONFIG = {
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"mem_fraction_static": 0.7,
}
def test_eagle_accuracy(self): def setUp(self):
prompt1 = "Today is a sunny day and I like" self.prompt = "Today is a sunny day and I like"
sampling_params1 = {"temperature": 0, "max_new_tokens": 8} self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
# Get the reference output
ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
ref_output = ref_engine.generate(prompt1, sampling_params1)["text"] self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown() ref_engine.shutdown()
# Test cases with different configurations def test_eagle_accuracy(self):
configs = [ configs = [
# Original config self.BASE_CONFIG,
{ {**self.BASE_CONFIG, "disable_cuda_graph": True},
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"mem_fraction_static": 0.7,
},
# Config with CUDA graph disabled
{
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"mem_fraction_static": 0.7,
"disable_cuda_graph": True,
},
] ]
for config in configs: for config in configs:
# Launch EAGLE engine with self.subTest(
cuda_graph=(
"enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
)
):
engine = sgl.Engine(**config) engine = sgl.Engine(**config)
try:
self._test_basic_generation(engine)
self._test_eos_token(engine)
self._test_batch_generation(engine)
finally:
engine.shutdown()
# Case 1: Test the output of EAGLE engine is the same as normal engine def _test_basic_generation(self, engine):
out1 = engine.generate(prompt1, sampling_params1)["text"] output = engine.generate(self.prompt, self.sampling_params)["text"]
print(f"{out1=}, {ref_output=}") print(f"{output=}, {self.ref_output=}")
self.assertEqual(out1, ref_output) self.assertEqual(output, self.ref_output)
# Case 2: Test the output of EAGLE engine does not contain unexpected EOS def _test_eos_token(self, engine):
prompt2 = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]" prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
sampling_params2 = { params = {
"temperature": 0, "temperature": 0,
"max_new_tokens": 1024, "max_new_tokens": 1024,
"skip_special_tokens": False, "skip_special_tokens": False,
} }
tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
out2 = engine.generate(prompt2, sampling_params2)["text"] output = engine.generate(prompt, params)["text"]
print(f"{out2=}") print(f"{output=}")
tokens = tokenizer.encode(out2, truncation=False)
assert tokenizer.eos_token_id not in tokens
# Case 3: Batched prompts tokens = tokenizer.encode(output, truncation=False)
self.assertNotIn(tokenizer.eos_token_id, tokens)
def _test_batch_generation(self, engine):
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
"The capital of France is", "The capital of France is",
"The future of AI is", "The future of AI is",
] ]
sampling_params3 = {"temperature": 0, "max_new_tokens": 30} params = {"temperature": 0, "max_new_tokens": 30}
outputs = engine.generate(prompts, sampling_params3)
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
# Shutdown the engine outputs = engine.generate(prompts, params)
engine.shutdown() for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)
prompts = [ prompts = [