[feat] add small vocab table for eagle's draft model[1]. (#3822)

Co-authored-by: Achazwl <323163497@qq.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Zhousx
2025-03-03 10:58:45 +08:00
committed by GitHub
parent b7e274f2d9
commit 7fbab730bd
5 changed files with 129 additions and 8 deletions

View File

@@ -117,9 +117,14 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
if hasattr(config, "hot_vocab_size"):
self.lm_head = ParallelLMHead(
config.hot_vocab_size, config.hidden_size, quant_config=quant_config
)
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -128,6 +128,7 @@ class ServerArgs:
speculative_num_steps: int = 5
speculative_eagle_topk: int = 8
speculative_num_draft_tokens: int = 64
speculative_token_map: Optional[str] = None
# Double Sparsity
enable_double_sparsity: bool = False
@@ -751,6 +752,12 @@ class ServerArgs:
help="The number of token sampled from draft model in Speculative Decoding.",
default=ServerArgs.speculative_num_draft_tokens,
)
parser.add_argument(
"--speculative-token-map",
type=str,
help="The path of the draft model's small vocab table.",
default=ServerArgs.speculative_token_map,
)
# Double Sparsity
parser.add_argument(

View File

@@ -1,8 +1,10 @@
import logging
import os
import time
from typing import List, Optional, Union
import torch
from huggingface_hub import snapshot_download
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -44,6 +46,23 @@ class EAGLEWorker(TpModelWorker):
# We will capture it later
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True
if server_args.speculative_token_map is not None:
if os.path.exists(server_args.speculative_token_map):
self.hot_token_id = torch.load(server_args.speculative_token_map)
else:
cache_dir = snapshot_download(
os.path.dirname(server_args.speculative_token_map),
ignore_patterns=["*.bin", "*.safetensors"],
)
file_path = os.path.join(
cache_dir, os.path.basename(server_args.speculative_token_map)
)
self.hot_token_id = torch.load(file_path)
server_args.json_model_override_args = (
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
)
super().__init__(
gpu_id=gpu_id,
tp_rank=tp_rank,
@@ -66,7 +85,21 @@ class EAGLEWorker(TpModelWorker):
# Share the embedding and lm_head
if not self.speculative_algorithm.is_nextn():
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
if server_args.speculative_token_map is not None:
head = head.clone()
self.hot_token_id = torch.tensor(
self.hot_token_id, dtype=torch.int32, device=head.device
)
head.data = head.data[self.hot_token_id]
else:
self.hot_token_id = None
self.model_runner.model.set_embed_and_head(embed, head)
else:
if server_args.speculative_token_map is not None:
raise NotImplementedError(
"NEXTN does not support speculative-token-map now"
)
self.hot_token_id = None
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
# Create multi-step attn backends and cuda graph runners
@@ -223,6 +256,8 @@ class EAGLEWorker(TpModelWorker):
spec_info.topk_index,
spec_info.hidden_states,
)
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
# Return values
score_list: List[torch.Tensor] = []
@@ -262,6 +297,8 @@ class EAGLEWorker(TpModelWorker):
)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
hidden_states = logits_output.hidden_states
return score_list, token_list, parents_list