Auto set draft model path for MTP (#5793)

This commit is contained in:
Ke Bao
2025-04-30 07:25:40 +08:00
committed by GitHub
parent 9419e75d60
commit dd408ee481
6 changed files with 115 additions and 287 deletions

View File

@@ -22,7 +22,7 @@ import random
import tempfile
from typing import List, Literal, Optional
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import (
configure_ipv6,
@@ -333,6 +333,14 @@ class ServerArgs:
"eagle speculative decoding."
)
model_arch = get_model_arch(self)
# Auto set draft_model_path for DeepSeek-V3/R1
if self.speculative_draft_model_path is None and model_arch in [
"DeepseekV3ForCausalLM"
]:
self.speculative_draft_model_path = self.model_path
# Auto choose parameters
if self.speculative_num_steps is None:
assert (
@@ -343,7 +351,7 @@ class ServerArgs:
self.speculative_num_steps,
self.speculative_eagle_topk,
self.speculative_num_draft_tokens,
) = auto_choose_speculative_params(self)
) = auto_choose_speculative_params(model_arch)
if self.page_size > 1 and self.speculative_eagle_topk > 1:
self.speculative_eagle_topk = 1
@@ -1367,20 +1375,22 @@ class DeprecatedAction(argparse.Action):
raise ValueError(self.help)
def auto_choose_speculative_params(self: ServerArgs):
def get_model_arch(args: ServerArgs):
    """Return the primary architecture name declared in the model's config.

    The config is loaded through ``get_config`` using ``args.model_path``,
    ``args.trust_remote_code``, ``args.revision`` and the JSON-decoded
    ``args.json_model_override_args``.

    Args:
        args: Server arguments providing the model location and load options.

    Returns:
        The first entry of the config's ``architectures`` list, or
        ``"Unknown"`` when the config declares no architectures.
    """
    hf_config = get_config(
        args.model_path,
        trust_remote_code=args.trust_remote_code,
        revision=args.revision,
        model_override_args=json.loads(args.json_model_override_args),
    )
    # ``architectures`` may be absent, None, or empty; indexing it blindly
    # would raise TypeError/IndexError, so report a placeholder name instead.
    architectures = getattr(hf_config, "architectures", None)
    if not architectures:
        return "Unknown"
    return architectures[0]
def auto_choose_speculative_params(arch: str):
"""
Automatically choose the parameters for speculative decoding.
You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
"""
config_path = os.path.join(self.model_path, "config.json")
if not os.path.exists(config_path):
raise ValueError(f"{config_path} is not found.")
config = json.load(open(config_path))
arch = config.get("architectures", ["Unknown"])[0]
if arch in ["LlamaForCausalLM"]:
# The default value for llama
return (5, 4, 8)