[Fix] Fix eagle with disable cuda graph (#3411)

Author: Ying Sheng
Date: 2025-02-08 16:40:00 -08:00
Committed by: GitHub
Parent: 6222e1c228
Commit: 7b4e61fff3
2 changed files with 29 additions and 18 deletions


@@ -22,12 +22,12 @@ from sglang.test.test_utils import (
 class TestEAGLEEngine(unittest.TestCase):
     def test_eagle_accuracy(self):
-        prompt = "Today is a sunny day and I like"
-        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+        prompt1 = "Today is a sunny day and I like"
+        sampling_params1 = {"temperature": 0, "max_new_tokens": 8}

         # Get the reference output
         ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        ref_output = ref_engine.generate(prompt, sampling_params)["text"]
+        ref_output = ref_engine.generate(prompt1, sampling_params1)["text"]
         ref_engine.shutdown()

         # Test cases with different configurations
@@ -60,20 +60,20 @@ class TestEAGLEEngine(unittest.TestCase):
             engine = sgl.Engine(**config)

             # Case 1: Test the output of EAGLE engine is the same as normal engine
-            out1 = engine.generate(prompt, sampling_params)["text"]
+            out1 = engine.generate(prompt1, sampling_params1)["text"]
             print(f"{out1=}, {ref_output=}")
             self.assertEqual(out1, ref_output)

             # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-            prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
-            sampling_params = {
+            prompt2 = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
+            sampling_params2 = {
                 "temperature": 0,
                 "max_new_tokens": 1024,
                 "skip_special_tokens": False,
             }

             tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-            out2 = engine.generate(prompt, sampling_params)["text"]
+            out2 = engine.generate(prompt2, sampling_params2)["text"]
             print(f"{out2=}")
             tokens = tokenizer.encode(out2, truncation=False)
             assert tokenizer.eos_token_id not in tokens
@@ -85,8 +85,8 @@ class TestEAGLEEngine(unittest.TestCase):
                 "The capital of France is",
                 "The future of AI is",
             ]
-            sampling_params = {"temperature": 0, "max_new_tokens": 30}
-            outputs = engine.generate(prompts, sampling_params)
+            sampling_params3 = {"temperature": 0, "max_new_tokens": 30}
+            outputs = engine.generate(prompts, sampling_params3)
             for prompt, output in zip(prompts, outputs):
                 print("===============================")
                 print(f"Prompt: {prompt}\nGenerated text: {output['text']}")