[feat] add small vocab table for eagle's draft model. (#3822)
Co-authored-by: Achazwl <323163497@qq.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -117,9 +117,14 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
if hasattr(config, "hot_vocab_size"):
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.hot_vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -128,6 +128,7 @@ class ServerArgs:
|
||||
speculative_num_steps: int = 5
|
||||
speculative_eagle_topk: int = 8
|
||||
speculative_num_draft_tokens: int = 64
|
||||
speculative_token_map: Optional[str] = None
|
||||
|
||||
# Double Sparsity
|
||||
enable_double_sparsity: bool = False
|
||||
@@ -751,6 +752,12 @@ class ServerArgs:
|
||||
help="The number of token sampled from draft model in Speculative Decoding.",
|
||||
default=ServerArgs.speculative_num_draft_tokens,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-token-map",
|
||||
type=str,
|
||||
help="The path of the draft model's small vocab table.",
|
||||
default=ServerArgs.speculative_token_map,
|
||||
)
|
||||
|
||||
# Double Sparsity
|
||||
parser.add_argument(
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
@@ -44,6 +46,23 @@ class EAGLEWorker(TpModelWorker):
|
||||
# We will capture it later
|
||||
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
||||
server_args.disable_cuda_graph = True
|
||||
|
||||
if server_args.speculative_token_map is not None:
|
||||
if os.path.exists(server_args.speculative_token_map):
|
||||
self.hot_token_id = torch.load(server_args.speculative_token_map)
|
||||
else:
|
||||
cache_dir = snapshot_download(
|
||||
os.path.dirname(server_args.speculative_token_map),
|
||||
ignore_patterns=["*.bin", "*.safetensors"],
|
||||
)
|
||||
file_path = os.path.join(
|
||||
cache_dir, os.path.basename(server_args.speculative_token_map)
|
||||
)
|
||||
self.hot_token_id = torch.load(file_path)
|
||||
server_args.json_model_override_args = (
|
||||
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
@@ -66,7 +85,21 @@ class EAGLEWorker(TpModelWorker):
|
||||
# Share the embedding and lm_head
|
||||
if not self.speculative_algorithm.is_nextn():
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
if server_args.speculative_token_map is not None:
|
||||
head = head.clone()
|
||||
self.hot_token_id = torch.tensor(
|
||||
self.hot_token_id, dtype=torch.int32, device=head.device
|
||||
)
|
||||
head.data = head.data[self.hot_token_id]
|
||||
else:
|
||||
self.hot_token_id = None
|
||||
self.model_runner.model.set_embed_and_head(embed, head)
|
||||
else:
|
||||
if server_args.speculative_token_map is not None:
|
||||
raise NotImplementedError(
|
||||
"NEXTN does not support speculative-token-map now"
|
||||
)
|
||||
self.hot_token_id = None
|
||||
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
@@ -223,6 +256,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
spec_info.topk_index,
|
||||
spec_info.hidden_states,
|
||||
)
|
||||
if self.hot_token_id is not None:
|
||||
topk_index = self.hot_token_id[topk_index]
|
||||
|
||||
# Return values
|
||||
score_list: List[torch.Tensor] = []
|
||||
@@ -262,6 +297,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||
if self.hot_token_id is not None:
|
||||
topk_index = self.hot_token_id[topk_index]
|
||||
hidden_states = logits_output.hidden_states
|
||||
|
||||
return score_list, token_list, parents_list
|
||||
|
||||
Reference in New Issue
Block a user