diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 5ea9786b8..9b7347fd3 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum, auto -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, List, Tuple, Optional class ChatTemplateStyle(Enum): @@ -13,6 +13,7 @@ class ChatTemplate: name: str default_system_prompt: str role_prefix_and_suffix: Dict[str, Tuple[str]] + stop_str: List[str] = () image_token: str = "" style: ChatTemplateStyle = ChatTemplateStyle.PLAIN @@ -110,6 +111,7 @@ register_chat_template( "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"), }, style=ChatTemplateStyle.PLAIN, + stop_str=('<|im_end|>',) ) ) diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 2803eedae..23b8ca4bc 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -486,6 +486,12 @@ class StreamExecutor: if clone is None: clone = self.default_sampling_para.clone() setattr(clone, item, value) + + if self.chat_template.stop_str: + if not clone: + clone = self.default_sampling_para.clone() + clone.stop += self.chat_template.stop_str + return clone or self.default_sampling_para def __del__(self):