[Feature] Support EAGLE 3 (#4247)
This commit is contained in:
@@ -30,6 +30,7 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
fast_topk,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
|
||||
|
||||
if is_cuda_available():
|
||||
@@ -66,6 +67,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.gpu_id = gpu_id
|
||||
self.device = server_args.device
|
||||
self.target_worker = target_worker
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
|
||||
# Override context length with target model's context length
|
||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||
@@ -81,7 +85,13 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
|
||||
# Load hot token ids
|
||||
if server_args.speculative_token_map is not None:
|
||||
if self.speculative_algorithm.is_eagle3():
|
||||
if server_args.speculative_token_map is not None:
|
||||
logger.warning(
|
||||
"Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
|
||||
)
|
||||
self.hot_token_id = None
|
||||
elif server_args.speculative_token_map is not None:
|
||||
self.hot_token_id = load_token_map(server_args.speculative_token_map)
|
||||
server_args.json_model_override_args = (
|
||||
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
|
||||
@@ -102,13 +112,24 @@ class EAGLEWorker(TpModelWorker):
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
|
||||
# Share the embedding and lm_head
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
if self.hot_token_id is not None:
|
||||
head = head.clone()
|
||||
self.hot_token_id = self.hot_token_id.to(head.device)
|
||||
head.data = head.data[self.hot_token_id]
|
||||
self.draft_model_runner.model.set_embed_and_head(embed, head)
|
||||
|
||||
if self.speculative_algorithm.is_eagle3():
|
||||
# EAGLE3 models don't share lm_head
|
||||
self.draft_model_runner.model.set_embed(embed)
|
||||
|
||||
# Grab the hot token ids from the draft model (EAGLE3 models ship their own token map)
|
||||
self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
|
||||
embed.device
|
||||
)
|
||||
else:
|
||||
if self.hot_token_id is not None:
|
||||
head = head.clone()
|
||||
self.hot_token_id = self.hot_token_id.to(head.device)
|
||||
head.data = head.data[self.hot_token_id]
|
||||
|
||||
# Share the embedding and lm_head
|
||||
self.draft_model_runner.model.set_embed_and_head(embed, head)
|
||||
|
||||
# Init attention backend and cuda graphs
|
||||
self.draft_model_runner.server_args.disable_cuda_graph = (
|
||||
|
||||
Reference in New Issue
Block a user