Support NextN (MTP) speculative decoding for DeepSeek-V3/R1 (#3582)
This commit is contained in:
@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
fast_topk,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,11 +58,15 @@ class EAGLEWorker(TpModelWorker):
|
||||
# Parse arguments
|
||||
self.topk = server_args.speculative_eagle_topk
|
||||
self.speculative_num_steps = server_args.speculative_num_steps
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
self.server_args = server_args
|
||||
|
||||
# Share the embedding and lm_head
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
self.model_runner.model.set_embed_and_head(embed, head)
|
||||
if not self.speculative_algorithm.is_nextn():
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
self.model_runner.model.set_embed_and_head(embed, head)
|
||||
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
|
||||
@@ -5,18 +5,28 @@ class SpeculativeAlgorithm(IntEnum):
|
||||
NONE = auto()
|
||||
EAGLE = auto()
|
||||
|
||||
# NEXTN spec decoding is for DeepSeek V3/R1
|
||||
# currently it's implemented based on EAGLE
|
||||
NEXTN = auto()
|
||||
|
||||
def is_none(self):
|
||||
return self == SpeculativeAlgorithm.NONE
|
||||
|
||||
def is_eagle(self):
|
||||
return self == SpeculativeAlgorithm.EAGLE
|
||||
return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.NEXTN
|
||||
|
||||
def is_nextn(self):
|
||||
return self == SpeculativeAlgorithm.NEXTN
|
||||
|
||||
@staticmethod
|
||||
def from_string(name: str):
|
||||
name_map = {
|
||||
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
||||
"NEXTN": SpeculativeAlgorithm.NEXTN,
|
||||
None: SpeculativeAlgorithm.NONE,
|
||||
}
|
||||
if name is not None:
|
||||
name = name.upper()
|
||||
return name_map[name]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user