Support NextN (MTP) speculative decoding for DeepSeek-V3/R1 (#3582)
This commit is contained in:
@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
fast_topk,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,11 +58,15 @@ class EAGLEWorker(TpModelWorker):
|
||||
# Parse arguments
|
||||
self.topk = server_args.speculative_eagle_topk
|
||||
self.speculative_num_steps = server_args.speculative_num_steps
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
self.server_args = server_args
|
||||
|
||||
# Share the embedding and lm_head
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
self.model_runner.model.set_embed_and_head(embed, head)
|
||||
if not self.speculative_algorithm.is_nextn():
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
self.model_runner.model.set_embed_and_head(embed, head)
|
||||
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
|
||||
Reference in New Issue
Block a user