[Feature] Support EAGLE 3 (#4247)

This commit is contained in:
James Liu
2025-03-18 10:35:23 -04:00
committed by GitHub
parent 8baf9a0c18
commit 9e0186f352
11 changed files with 385 additions and 22 deletions

View File

@@ -30,6 +30,7 @@ from sglang.srt.speculative.eagle_utils import (
fast_topk,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
if is_cuda_available():
@@ -66,6 +67,9 @@ class EAGLEWorker(TpModelWorker):
self.gpu_id = gpu_id
self.device = server_args.device
self.target_worker = target_worker
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
# Override context length with target model's context length
server_args.context_length = target_worker.model_runner.model_config.context_len
@@ -81,7 +85,13 @@ class EAGLEWorker(TpModelWorker):
)
# Load hot token ids
if server_args.speculative_token_map is not None:
if self.speculative_algorithm.is_eagle3():
if server_args.speculative_token_map is not None:
logger.warning(
"Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
)
self.hot_token_id = None
elif server_args.speculative_token_map is not None:
self.hot_token_id = load_token_map(server_args.speculative_token_map)
server_args.json_model_override_args = (
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
@@ -102,13 +112,24 @@ class EAGLEWorker(TpModelWorker):
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
if self.hot_token_id is not None:
head = head.clone()
self.hot_token_id = self.hot_token_id.to(head.device)
head.data = head.data[self.hot_token_id]
self.draft_model_runner.model.set_embed_and_head(embed, head)
if self.speculative_algorithm.is_eagle3():
# EAGLE3 models don't share lm_head
self.draft_model_runner.model.set_embed(embed)
# grab hot token ids
self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
embed.device
)
else:
if self.hot_token_id is not None:
head = head.clone()
self.hot_token_id = self.hot_token_id.to(head.device)
head.data = head.data[self.hot_token_id]
# Share the embedding and lm_head
self.draft_model_runner.model.set_embed_and_head(embed, head)
# Init attention backend and cuda graphs
self.draft_model_runner.server_args.disable_cuda_graph = (