Fix the chat template for QWen (#83)

This commit is contained in:
Lianmin Zheng
2024-01-22 21:46:47 -08:00
committed by GitHub
parent 94e05770db
commit 959c4174b2
2 changed files with 10 additions and 2 deletions

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass from dataclasses import dataclass, field
from enum import Enum, auto from enum import Enum, auto
from typing import Callable, Dict, List, Tuple from typing import Callable, Dict, List, Tuple, Optional
class ChatTemplateStyle(Enum): class ChatTemplateStyle(Enum):
@@ -13,6 +13,7 @@ class ChatTemplate:
name: str name: str
default_system_prompt: str default_system_prompt: str
role_prefix_and_suffix: Dict[str, Tuple[str]] role_prefix_and_suffix: Dict[str, Tuple[str]]
stop_str: List[str] = ()
image_token: str = "<image>" image_token: str = "<image>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
@@ -110,6 +111,7 @@ register_chat_template(
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"), "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
}, },
style=ChatTemplateStyle.PLAIN, style=ChatTemplateStyle.PLAIN,
stop_str=('<|im_end|>',)
) )
) )

View File

@@ -486,6 +486,12 @@ class StreamExecutor:
if clone is None: if clone is None:
clone = self.default_sampling_para.clone() clone = self.default_sampling_para.clone()
setattr(clone, item, value) setattr(clone, item, value)
if self.chat_template.stop_str:
if not clone:
clone = self.default_sampling_para.clone()
clone.stop += self.chat_template.stop_str
return clone or self.default_sampling_para return clone or self.default_sampling_para
def __del__(self): def __del__(self):