From 0d4e3228cfb4926ebb112ec8dc5d8574df315f5c Mon Sep 17 00:00:00 2001 From: William <323163497@qq.com> Date: Tue, 4 Mar 2025 20:26:24 +0800 Subject: [PATCH] [Feature] Add test for speculative_token_map (#4016) --- python/sglang/srt/speculative/eagle_worker.py | 28 ++++----- test/srt/test_eagle_infer.py | 61 +++++++++++++++++++ 2 files changed, 75 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 810514429..557216140 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -31,6 +31,16 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm logger = logging.getLogger(__name__) +def load_token_map(token_map_path: str) -> List[int]: + if not os.path.exists(token_map_path): + cache_dir = snapshot_download( + os.path.dirname(token_map_path), + ignore_patterns=["*.bin", "*.safetensors"], + ) + token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path)) + return torch.load(token_map_path) + + class EAGLEWorker(TpModelWorker): def __init__( @@ -48,20 +58,12 @@ class EAGLEWorker(TpModelWorker): server_args.disable_cuda_graph = True if server_args.speculative_token_map is not None: - if os.path.exists(server_args.speculative_token_map): - self.hot_token_id = torch.load(server_args.speculative_token_map) - else: - cache_dir = snapshot_download( - os.path.dirname(server_args.speculative_token_map), - ignore_patterns=["*.bin", "*.safetensors"], - ) - file_path = os.path.join( - cache_dir, os.path.basename(server_args.speculative_token_map) - ) - self.hot_token_id = torch.load(file_path) + self.hot_token_id = load_token_map(server_args.speculative_token_map) server_args.json_model_override_args = ( f'{{"hot_vocab_size": {len(self.hot_token_id)}}}' ) + else: + self.hot_token_id = None super().__init__( gpu_id=gpu_id, @@ -84,14 +86,12 @@ class EAGLEWorker(TpModelWorker): # Share the embedding and lm_head embed, head = self.target_worker.model_runner.model.get_embed_and_head() - if server_args.speculative_token_map is not None: + if self.hot_token_id is not None: head = head.clone() self.hot_token_id = torch.tensor( self.hot_token_id, dtype=torch.int32, device=head.device ) head.data = head.data[self.hot_token_id] - else: - self.hot_token_id = None self.model_runner.model.set_embed_and_head(embed, head) self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 30eaf2ab0..863da34bf 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -95,6 +95,67 @@ class TestEAGLEEngine(unittest.TestCase): print("-" * 40) +class TestEAGLEEngineTokenMap(unittest.TestCase): + BASE_CONFIG = { + "model_path": "meta-llama/Meta-Llama-3-8B-Instruct", + "speculative_draft_model_path": "lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B", + "speculative_algorithm": "EAGLE", + "speculative_num_steps": 5, + "speculative_eagle_topk": 8, + "speculative_num_draft_tokens": 64, + "mem_fraction_static": 0.7, + "cuda_graph_max_bs": 4, + "dtype": "float16", + } + + def setUp(self): + self.prompt = "Today is a sunny day and I like" + self.sampling_params = {"temperature": 0, "max_new_tokens": 8} + + ref_engine = sgl.Engine(model_path=self.BASE_CONFIG["model_path"]) + self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"] + ref_engine.shutdown() + + def test_token_map_accuracy(self): + configs = [ + self.BASE_CONFIG, + { + **self.BASE_CONFIG, + "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt", + }, + ] + + for config in configs: + print("testing config: ", config) + with self.subTest(cuda_graph="enabled"): + engine = sgl.Engine(**config) + try: + self._test_basic_generation(engine) + self._test_batch_generation(engine) + finally: + engine.shutdown() + + def _test_basic_generation(self, engine): + output = engine.generate(self.prompt, self.sampling_params)["text"] + print(f"{output=}, {self.ref_output=}") + self.assertEqual(output, self.ref_output) + + def _test_batch_generation(self, engine): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + params = {"temperature": 0, "max_new_tokens": 30} + + outputs = engine.generate(prompts, params) + for prompt, output in zip(prompts, outputs): + print(f"Prompt: {prompt}") + print(f"Generated: {output['text']}") + print("-" * 40) + + prompts = [ "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]" '[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',