From 6222e1c2282766189adbdf1c0335725d2759f84c Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 9 Feb 2025 08:02:56 +0800 Subject: [PATCH] add disable cuda graph unit test for eagle 2 (#3412) --- test/srt/test_eagle_infer.py | 104 ++++++++++++++++++++--------------- 1 file changed, 61 insertions(+), 43 deletions(-) diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index b01c26049..b5b17dad1 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -30,51 +30,69 @@ class TestEAGLEEngine(unittest.TestCase): ref_output = ref_engine.generate(prompt, sampling_params)["text"] ref_engine.shutdown() - # Launch EAGLE engine - engine = sgl.Engine( - model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, - speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, - speculative_algorithm="EAGLE", - speculative_num_steps=5, - speculative_eagle_topk=8, - speculative_num_draft_tokens=64, - mem_fraction_static=0.7, - ) - - # Case 1: Test the output of EAGLE engine is the same as normal engine - out1 = engine.generate(prompt, sampling_params)["text"] - print(f"{out1=}, {ref_output=}") - self.assertEqual(out1, ref_output) - - # Case 2: Test the output of EAGLE engine does not contain unexpected EOS - prompt = "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like [/INST]" - sampling_params = { - "temperature": 0, - "max_new_tokens": 1024, - "skip_special_tokens": False, - } - - tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) - out2 = engine.generate(prompt, sampling_params)["text"] - print(f"{out2=}") - tokens = tokenizer.encode(out2, truncation=False) - assert tokenizer.eos_token_id not in tokens - - # Case 3: Batched prompts - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", + # Test cases with different configurations + configs = [ + # Original config + { + "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "speculative_algorithm": "EAGLE", + "speculative_num_steps": 5, + "speculative_eagle_topk": 8, + "speculative_num_draft_tokens": 64, + "mem_fraction_static": 0.7, + }, + # Config with CUDA graph disabled + { + "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "speculative_algorithm": "EAGLE", + "speculative_num_steps": 5, + "speculative_eagle_topk": 8, + "speculative_num_draft_tokens": 64, + "mem_fraction_static": 0.7, + "disable_cuda_graph": True, + }, ] - sampling_params = {"temperature": 0, "max_new_tokens": 30} - outputs = engine.generate(prompts, sampling_params) - for prompt, output in zip(prompts, outputs): - print("===============================") - print(f"Prompt: {prompt}\nGenerated text: {output['text']}") - # Shutdown the engine - engine.shutdown() + for config in configs: + # Launch EAGLE engine + engine = sgl.Engine(**config) + + # Case 1: Test the output of EAGLE engine is the same as normal engine + out1 = engine.generate(prompt, sampling_params)["text"] + print(f"{out1=}, {ref_output=}") + self.assertEqual(out1, ref_output) + + # Case 2: Test the output of EAGLE engine does not contain unexpected EOS + prompt = "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like [/INST]" + sampling_params = { + "temperature": 0, + "max_new_tokens": 1024, + "skip_special_tokens": False, + } + + tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) + out2 = engine.generate(prompt, sampling_params)["text"] + print(f"{out2=}") + tokens = tokenizer.encode(out2, truncation=False) + assert tokenizer.eos_token_id not in tokens + + # Case 3: Batched prompts + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = {"temperature": 0, "max_new_tokens": 30} + outputs = engine.generate(prompts, sampling_params) + for prompt, output in zip(prompts, outputs): + print("===============================") + print(f"Prompt: {prompt}\nGenerated text: {output['text']}") + + # Shutdown the engine + engine.shutdown() prompts = [