[Feature] Add test for speculative_token_map (#4016)
This commit is contained in:
@@ -95,6 +95,67 @@ class TestEAGLEEngine(unittest.TestCase):
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
class TestEAGLEEngineTokenMap(unittest.TestCase):
|
||||
BASE_CONFIG = {
|
||||
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"speculative_draft_model_path": "lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B",
|
||||
"speculative_algorithm": "EAGLE",
|
||||
"speculative_num_steps": 5,
|
||||
"speculative_eagle_topk": 8,
|
||||
"speculative_num_draft_tokens": 64,
|
||||
"mem_fraction_static": 0.7,
|
||||
"cuda_graph_max_bs": 4,
|
||||
"dtype": "float16",
|
||||
}
|
||||
|
||||
def setUp(self):
|
||||
self.prompt = "Today is a sunny day and I like"
|
||||
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
||||
|
||||
ref_engine = sgl.Engine(model_path=self.BASE_CONFIG["model_path"])
|
||||
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
|
||||
ref_engine.shutdown()
|
||||
|
||||
def test_token_map_accuracy(self):
|
||||
configs = [
|
||||
self.BASE_CONFIG,
|
||||
{
|
||||
**self.BASE_CONFIG,
|
||||
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
|
||||
},
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
print("testing config: ", config)
|
||||
with self.subTest(cuda_graph="enabled"):
|
||||
engine = sgl.Engine(**config)
|
||||
try:
|
||||
self._test_basic_generation(engine)
|
||||
self._test_batch_generation(engine)
|
||||
finally:
|
||||
engine.shutdown()
|
||||
|
||||
def _test_basic_generation(self, engine):
|
||||
output = engine.generate(self.prompt, self.sampling_params)["text"]
|
||||
print(f"{output=}, {self.ref_output=}")
|
||||
self.assertEqual(output, self.ref_output)
|
||||
|
||||
def _test_batch_generation(self, engine):
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
params = {"temperature": 0, "max_new_tokens": 30}
|
||||
|
||||
outputs = engine.generate(prompts, params)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
print(f"Prompt: {prompt}")
|
||||
print(f"Generated: {output['text']}")
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
prompts = [
|
||||
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
|
||||
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
|
||||
|
||||
Reference in New Issue
Block a user