add disable cuda graph unit test for eagle 2 (#3412)
@@ -30,51 +30,69 @@ class TestEAGLEEngine(unittest.TestCase):
         ref_output = ref_engine.generate(prompt, sampling_params)["text"]
         ref_engine.shutdown()
 
-        # Launch EAGLE engine
-        engine = sgl.Engine(
-            model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-            speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-            speculative_algorithm="EAGLE",
-            speculative_num_steps=5,
-            speculative_eagle_topk=8,
-            speculative_num_draft_tokens=64,
-            mem_fraction_static=0.7,
-        )
-
-        # Case 1: Test the output of EAGLE engine is the same as normal engine
-        out1 = engine.generate(prompt, sampling_params)["text"]
-        print(f"{out1=}, {ref_output=}")
-        self.assertEqual(out1, ref_output)
-
-        # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-        prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
-        sampling_params = {
-            "temperature": 0,
-            "max_new_tokens": 1024,
-            "skip_special_tokens": False,
-        }
-
-        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        out2 = engine.generate(prompt, sampling_params)["text"]
-        print(f"{out2=}")
-        tokens = tokenizer.encode(out2, truncation=False)
-        assert tokenizer.eos_token_id not in tokens
-
-        # Case 3: Batched prompts
-        prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        sampling_params = {"temperature": 0, "max_new_tokens": 30}
-        outputs = engine.generate(prompts, sampling_params)
-        for prompt, output in zip(prompts, outputs):
-            print("===============================")
-            print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
-
-        # Shutdown the engine
-        engine.shutdown()
+        # Test cases with different configurations
+        configs = [
+            # Original config
+            {
+                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "speculative_algorithm": "EAGLE",
+                "speculative_num_steps": 5,
+                "speculative_eagle_topk": 8,
+                "speculative_num_draft_tokens": 64,
+                "mem_fraction_static": 0.7,
+            },
+            # Config with CUDA graph disabled
+            {
+                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "speculative_algorithm": "EAGLE",
+                "speculative_num_steps": 5,
+                "speculative_eagle_topk": 8,
+                "speculative_num_draft_tokens": 64,
+                "mem_fraction_static": 0.7,
+                "disable_cuda_graph": True,
+            },
+        ]
+
+        for config in configs:
+            # Launch EAGLE engine
+            engine = sgl.Engine(**config)
+
+            # Case 1: Test the output of EAGLE engine is the same as normal engine
+            out1 = engine.generate(prompt, sampling_params)["text"]
+            print(f"{out1=}, {ref_output=}")
+            self.assertEqual(out1, ref_output)
+
+            # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
+            prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
+            sampling_params = {
+                "temperature": 0,
+                "max_new_tokens": 1024,
+                "skip_special_tokens": False,
+            }
+
+            tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+            out2 = engine.generate(prompt, sampling_params)["text"]
+            print(f"{out2=}")
+            tokens = tokenizer.encode(out2, truncation=False)
+            assert tokenizer.eos_token_id not in tokens
+
+            # Case 3: Batched prompts
+            prompts = [
+                "Hello, my name is",
+                "The president of the United States is",
+                "The capital of France is",
+                "The future of AI is",
+            ]
+            sampling_params = {"temperature": 0, "max_new_tokens": 30}
+            outputs = engine.generate(prompts, sampling_params)
+            for prompt, output in zip(prompts, outputs):
+                print("===============================")
+                print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
+
+            # Shutdown the engine
+            engine.shutdown()
 
 
 prompts = [
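For a quick check outside the unittest harness, the new case boils down to launching the engine with disable_cuda_graph=True and generating once. Below is a minimal standalone sketch, not part of the commit, assuming the DEFAULT_EAGLE_* constants are importable from sglang.test.test_utils as in the rest of the test suite:

import sglang as sgl
from sglang.test.test_utils import (
    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
)

# Same EAGLE speculative-decoding setup as the test, but with CUDA graphs
# disabled: the configuration the new test case exercises.
engine = sgl.Engine(
    model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
    speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    speculative_algorithm="EAGLE",
    speculative_num_steps=5,
    speculative_eagle_topk=8,
    speculative_num_draft_tokens=64,
    mem_fraction_static=0.7,
    disable_cuda_graph=True,
)

# Greedy sampling (temperature 0) keeps the output deterministic, so it can
# be compared against a reference engine's output, as the test does.
out = engine.generate(
    "The capital of France is", {"temperature": 0, "max_new_tokens": 30}
)["text"]
print(out)
engine.shutdown()

Running the same three assertions both with and without CUDA graphs guards against the speculative-decoding path silently depending on graph capture.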