From 0d4e3228cfb4926ebb112ec8dc5d8574df315f5c Mon Sep 17 00:00:00 2001
From: William <323163497@qq.com>
Date: Tue, 4 Mar 2025 20:26:24 +0800
Subject: [PATCH] [Feature] Add test for speculative_token_map (#4016)
---
python/sglang/srt/speculative/eagle_worker.py | 28 ++++-----
test/srt/test_eagle_infer.py | 61 +++++++++++++++++++
2 files changed, 75 insertions(+), 14 deletions(-)
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 810514429..557216140 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -31,6 +31,16 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
logger = logging.getLogger(__name__)
+def load_token_map(token_map_path: str) -> List[int]:
+ if not os.path.exists(token_map_path):
+ cache_dir = snapshot_download(
+ os.path.dirname(token_map_path),
+ ignore_patterns=["*.bin", "*.safetensors"],
+ )
+ token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
+ return torch.load(token_map_path)
+
+
class EAGLEWorker(TpModelWorker):
def __init__(
@@ -48,20 +58,12 @@ class EAGLEWorker(TpModelWorker):
server_args.disable_cuda_graph = True
if server_args.speculative_token_map is not None:
- if os.path.exists(server_args.speculative_token_map):
- self.hot_token_id = torch.load(server_args.speculative_token_map)
- else:
- cache_dir = snapshot_download(
- os.path.dirname(server_args.speculative_token_map),
- ignore_patterns=["*.bin", "*.safetensors"],
- )
- file_path = os.path.join(
- cache_dir, os.path.basename(server_args.speculative_token_map)
- )
- self.hot_token_id = torch.load(file_path)
+ self.hot_token_id = load_token_map(server_args.speculative_token_map)
server_args.json_model_override_args = (
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
)
+ else:
+ self.hot_token_id = None
super().__init__(
gpu_id=gpu_id,
@@ -84,14 +86,12 @@ class EAGLEWorker(TpModelWorker):
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
- if server_args.speculative_token_map is not None:
+ if self.hot_token_id is not None:
head = head.clone()
self.hot_token_id = torch.tensor(
self.hot_token_id, dtype=torch.int32, device=head.device
)
head.data = head.data[self.hot_token_id]
- else:
- self.hot_token_id = None
self.model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 30eaf2ab0..863da34bf 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -95,6 +95,67 @@ class TestEAGLEEngine(unittest.TestCase):
print("-" * 40)
+class TestEAGLEEngineTokenMap(unittest.TestCase):
+ BASE_CONFIG = {
+ "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
+ "speculative_draft_model_path": "lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B",
+ "speculative_algorithm": "EAGLE",
+ "speculative_num_steps": 5,
+ "speculative_eagle_topk": 8,
+ "speculative_num_draft_tokens": 64,
+ "mem_fraction_static": 0.7,
+ "cuda_graph_max_bs": 4,
+ "dtype": "float16",
+ }
+
+ def setUp(self):
+ self.prompt = "Today is a sunny day and I like"
+ self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
+
+ ref_engine = sgl.Engine(model_path=self.BASE_CONFIG["model_path"])
+ self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
+ ref_engine.shutdown()
+
+ def test_token_map_accuracy(self):
+ configs = [
+ self.BASE_CONFIG,
+ {
+ **self.BASE_CONFIG,
+ "speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
+ },
+ ]
+
+ for config in configs:
+ print("testing config: ", config)
+ with self.subTest(cuda_graph="enabled"):
+ engine = sgl.Engine(**config)
+ try:
+ self._test_basic_generation(engine)
+ self._test_batch_generation(engine)
+ finally:
+ engine.shutdown()
+
+ def _test_basic_generation(self, engine):
+ output = engine.generate(self.prompt, self.sampling_params)["text"]
+ print(f"{output=}, {self.ref_output=}")
+ self.assertEqual(output, self.ref_output)
+
+ def _test_batch_generation(self, engine):
+ prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+ ]
+ params = {"temperature": 0, "max_new_tokens": 30}
+
+ outputs = engine.generate(prompts, params)
+ for prompt, output in zip(prompts, outputs):
+ print(f"Prompt: {prompt}")
+ print(f"Generated: {output['text']}")
+ print("-" * 40)
+
+
prompts = [
"[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]"
'[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',