[New Model] Devstral support (#6547)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from enum import IntEnum, auto
|
|||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||||
|
from sglang.srt.utils import read_system_prompt_from_file
|
||||||
|
|
||||||
|
|
||||||
class SeparatorStyle(IntEnum):
|
class SeparatorStyle(IntEnum):
|
||||||
@@ -648,6 +649,20 @@ register_conv_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Chat template for Devstral (Mistral's coding model). Uses Mistral-style
# [INST]/[SYSTEM_PROMPT] markers with the LLAMA2 separator style; the default
# system message is read from the locally cached HF repo's SYSTEM_PROMPT.txt
# (empty string if the file is absent).
register_conv_template(
    Conversation(
        name="devstral",
        system_template="[SYSTEM_PROMPT]\n{system_message}\n[/SYSTEM_PROMPT]\n\n",
        system_message=read_system_prompt_from_file("mistralai/Devstral-Small-2505"),
        roles=("[INST]", "[/INST]"),
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        stop_str=["[INST]", "[/INST]", "[SYSTEM_PROMPT]", "[/SYSTEM_PROMPT]"],
        image_token="[IMG]",
    )
)
|
||||||
|
|
||||||
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
|
||||||
register_conv_template(
|
register_conv_template(
|
||||||
Conversation(
|
Conversation(
|
||||||
@@ -961,6 +976,12 @@ def match_moonshot_kimivl(model_path: str):
|
|||||||
return "kimi-vl"
|
return "kimi-vl"
|
||||||
|
|
||||||
|
|
||||||
|
@register_conv_template_matching_function
def match_devstral(model_path: str):
    """Select the "devstral" chat template for any model path that contains
    "devstral" (case-insensitive); returns None otherwise so other matchers
    can be tried."""
    matched = re.search(r"devstral", model_path, re.IGNORECASE)
    if matched is not None:
        return "devstral"
|
||||||
|
|
||||||
|
|
||||||
@register_conv_template_matching_function
|
@register_conv_template_matching_function
|
||||||
def match_phi_4_mm(model_path: str):
|
def match_phi_4_mm(model_path: str):
|
||||||
if "phi-4-multimodal" in model_path.lower():
|
if "phi-4-multimodal" in model_path.lower():
|
||||||
|
|||||||
@@ -203,6 +203,10 @@ def get_tokenizer(
|
|||||||
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||||
kwargs["use_fast"] = False
|
kwargs["use_fast"] = False
|
||||||
|
|
||||||
|
# TODO(Xinyuan): Remove this once we have a proper tokenizer for Devstral
|
||||||
|
if tokenizer_name == "mistralai/Devstral-Small-2505":
|
||||||
|
tokenizer_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||||
|
|
||||||
is_gguf = check_gguf_file(tokenizer_name)
|
is_gguf = check_gguf_file(tokenizer_name)
|
||||||
if is_gguf:
|
if is_gguf:
|
||||||
kwargs["gguf_file"] = tokenizer_name
|
kwargs["gguf_file"] = tokenizer_name
|
||||||
|
|||||||
@@ -2169,3 +2169,51 @@ class Withable(Generic[T]):
|
|||||||
finally:
|
finally:
|
||||||
assert self._value is new_value
|
assert self._value is new_value
|
||||||
self._value = None
|
self._value = None
|
||||||
|
|
||||||
|
|
||||||
|
def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
    """Locate a model repo's snapshot directory in the local HuggingFace cache.

    Args:
        repo_id: Repo identifier such as "org/model".
        revision: Optional revision (commit hash); when omitted, the revision
            recorded in the cache's ``refs/main`` file is used.

    Returns:
        Absolute path to the cached snapshot directory, or None when the repo
        (or requested revision) is not present in the cache.
    """
    import huggingface_hub as hf

    # Cache layout: <HF_HUB_CACHE>/models--org--model/{refs,snapshots}/...
    repo_cache = os.path.join(
        hf.constants.HF_HUB_CACHE,
        hf.constants.REPO_ID_SEPARATOR.join(["models", *repo_id.split("/")]),
    )

    # No explicit revision: resolve it from the "main" ref file, if present.
    if not revision:
        main_ref = os.path.join(repo_cache, "refs", "main")
        if os.path.isfile(main_ref):
            with open(main_ref) as ref_file:
                revision = ref_file.read().strip()

    if not revision:
        return None

    snapshot_dir = os.path.join(repo_cache, "snapshots", revision)
    return snapshot_dir if os.path.isdir(snapshot_dir) else None
|
||||||
|
|
||||||
|
|
||||||
|
def read_system_prompt_from_file(model_name: str) -> str:
    """Read system prompt from a file in the HuggingFace cache directory.

    Args:
        model_name: The model name to construct the file path

    Returns:
        The system prompt content from the file, or empty string if file not found
    """
    try:
        repo_dir = find_local_repo_dir(model_name)
        if not repo_dir:
            return ""
        prompt_path = os.path.join(repo_dir, "SYSTEM_PROMPT.txt")
        if not os.path.exists(prompt_path):
            return ""
        with open(prompt_path, "r", encoding="utf-8") as prompt_file:
            return prompt_file.read()
    except Exception:
        # Deliberate best-effort: any failure (missing cache, unreadable
        # file, ...) falls back to an empty system prompt.
        return ""
|
||||||
|
|||||||
Reference in New Issue
Block a user