Support NextN (MTP) speculative decoding for DeepSeek-V3/R1 (#3582)

This commit is contained in:
Ke Bao
2025-02-15 05:28:34 +08:00
committed by GitHub
parent fb4c9c3a30
commit 862dd76c76
7 changed files with 437 additions and 7 deletions

View File

@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import (
fast_topk,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
logger = logging.getLogger(__name__)
@@ -57,11 +58,15 @@ class EAGLEWorker(TpModelWorker):
# Parse arguments
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.server_args = server_args
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
self.model_runner.model.set_embed_and_head(embed, head)
if not self.speculative_algorithm.is_nextn():
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
self.model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
# Create multi-step attn backends and cuda graph runners