Auto set draft model path for MTP (#5793)
@@ -22,7 +22,7 @@ import random
 import tempfile
 from typing import List, Literal, Optional
 
-from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     configure_ipv6,
@@ -333,6 +333,14 @@ class ServerArgs:
                 "eagle speculative decoding."
             )
 
+            model_arch = get_model_arch(self)
+
+            # Auto set draft_model_path DeepSeek-V3/R1
+            if self.speculative_draft_model_path is None and model_arch in [
+                "DeepseekV3ForCausalLM"
+            ]:
+                self.speculative_draft_model_path = self.model_path
+
             # Auto choose parameters
             if self.speculative_num_steps is None:
                 assert (
@@ -343,7 +351,7 @@
                     self.speculative_num_steps,
                     self.speculative_eagle_topk,
                     self.speculative_num_draft_tokens,
-                ) = auto_choose_speculative_params(self)
+                ) = auto_choose_speculative_params(model_arch)
 
                 if self.page_size > 1 and self.speculative_eagle_topk > 1:
                     self.speculative_eagle_topk = 1
@@ -1367,20 +1375,22 @@ class DeprecatedAction(argparse.Action):
         raise ValueError(self.help)
 
 
-def auto_choose_speculative_params(self: ServerArgs):
+def get_model_arch(args: ServerArgs):
+    hf_config = get_config(
+        args.model_path,
+        trust_remote_code=args.trust_remote_code,
+        revision=args.revision,
+        model_override_args=json.loads(args.json_model_override_args),
+    )
+    return hf_config.architectures[0]
+
+
+def auto_choose_speculative_params(arch: str):
     """
     Automatically choose the parameters for speculative decoding.
 
     You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
     """
-    config_path = os.path.join(self.model_path, "config.json")
-    if not os.path.exists(config_path):
-        raise ValueError(f"{config_path} is not found.")
-
-    config = json.load(open(config_path))
-
-    arch = config.get("architectures", ["Unknown"])[0]
-
     if arch in ["LlamaForCausalLM"]:
         # The default value for llama
         return (5, 4, 8)
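Taken together, the commit resolves the model architecture once via the HF config and reuses it twice: to default the draft model path for DeepSeek-V3/R1, whose checkpoints ship the MTP draft head alongside the target weights, and to pick the speculative-decoding parameters. Below is a minimal, self-contained sketch of that flow (not the sglang source); the Args dataclass, the choose_params helper, and the non-llama fallback tuple are illustrative stand-ins.

    # Sketch of the flow this commit introduces: resolve the architecture
    # once, then use it both to default the draft path and to pick params.
    from dataclasses import dataclass
    from typing import Optional, Tuple


    @dataclass
    class Args:
        model_path: str
        speculative_draft_model_path: Optional[str] = None


    def choose_params(arch: str) -> Tuple[int, int, int]:
        # Returns (num_steps, eagle_topk, num_draft_tokens).
        if arch == "LlamaForCausalLM":
            return (5, 4, 8)  # the llama default from the diff above
        return (3, 1, 4)  # hypothetical fallback, for illustration only


    def apply_defaults(args: Args, arch: str) -> Args:
        # DeepSeek-V3/R1 bundle the MTP draft head with the target weights,
        # so the target checkpoint path can double as the draft path.
        if args.speculative_draft_model_path is None and arch == "DeepseekV3ForCausalLM":
            args.speculative_draft_model_path = args.model_path
        return args


    args = apply_defaults(Args("deepseek-ai/DeepSeek-V3"), "DeepseekV3ForCausalLM")
    assert args.speculative_draft_model_path == "deepseek-ai/DeepSeek-V3"
    print(choose_params("DeepseekV3ForCausalLM"))  # -> (3, 1, 4) under the toy fallback

One likely motivation visible in the diff itself: the old code opened config.json directly under model_path, which assumes a local checkout, while get_config also honors trust_remote_code, revision, and json_model_override_args.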