Support NextN (MTP) speculative decoding for DeepSeek-V3/R1 (#3582)

This commit is contained in:
Ke Bao
2025-02-15 05:28:34 +08:00
committed by GitHub
parent fb4c9c3a30
commit 862dd76c76
7 changed files with 437 additions and 7 deletions

View File

@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import (
fast_topk,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
logger = logging.getLogger(__name__)
@@ -57,11 +58,15 @@ class EAGLEWorker(TpModelWorker):
# Parse arguments
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.server_args = server_args
# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
self.model_runner.model.set_embed_and_head(embed, head)
if not self.speculative_algorithm.is_nextn():
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
self.model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
# Create multi-step attn backends and cuda graph runners

View File

@@ -5,18 +5,28 @@ class SpeculativeAlgorithm(IntEnum):
# Member returned by from_string() when no algorithm name is given (name is None).
NONE = auto()
# Member returned by from_string() for the name "EAGLE" (case-insensitive).
EAGLE = auto()
# NEXTN speculative decoding is for DeepSeek V3/R1.
# It is currently implemented on top of EAGLE.
NEXTN = auto()
def is_none(self):
    """Return True when this member is ``SpeculativeAlgorithm.NONE``."""
    none_member = SpeculativeAlgorithm.NONE
    return self == none_member
def is_eagle(self):
    """Return True when this algorithm runs on the EAGLE code path.

    NEXTN is implemented on top of EAGLE, so both ``EAGLE`` and ``NEXTN``
    count as EAGLE-based. Use ``is_nextn()`` to tell the two apart.
    """
    # A single return: an earlier EAGLE-only return would shadow the NEXTN
    # case and make it unreachable.
    return self in (SpeculativeAlgorithm.EAGLE, SpeculativeAlgorithm.NEXTN)
def is_nextn(self):
    """Return True when this member is ``SpeculativeAlgorithm.NEXTN``."""
    nextn_member = SpeculativeAlgorithm.NEXTN
    return self == nextn_member
@staticmethod
def from_string(name: str):
    """Map an algorithm name to its ``SpeculativeAlgorithm`` member.

    ``None`` maps to ``SpeculativeAlgorithm.NONE``. Any other value is
    upper-cased before the lookup, so matching is case-insensitive.

    Raises:
        KeyError: if the upper-cased name is not a known algorithm.
    """
    # No name given: no speculative algorithm.
    if name is None:
        return SpeculativeAlgorithm.NONE

    members_by_name = {
        "EAGLE": SpeculativeAlgorithm.EAGLE,
        "NEXTN": SpeculativeAlgorithm.NEXTN,
    }
    lookup_key = name.upper()
    return members_by_name[lookup_key]