Fix the chat template for QWen (#83)
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Callable, Dict, List, Tuple
|
from typing import Callable, Dict, List, Tuple, Optional
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplateStyle(Enum):
|
class ChatTemplateStyle(Enum):
|
||||||
@@ -13,6 +13,7 @@ class ChatTemplate:
|
|||||||
name: str
|
name: str
|
||||||
default_system_prompt: str
|
default_system_prompt: str
|
||||||
role_prefix_and_suffix: Dict[str, Tuple[str]]
|
role_prefix_and_suffix: Dict[str, Tuple[str]]
|
||||||
|
stop_str: List[str] = ()
|
||||||
image_token: str = "<image>"
|
image_token: str = "<image>"
|
||||||
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
|
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
|
||||||
|
|
||||||
@@ -110,6 +111,7 @@ register_chat_template(
|
|||||||
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
|
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
|
||||||
},
|
},
|
||||||
style=ChatTemplateStyle.PLAIN,
|
style=ChatTemplateStyle.PLAIN,
|
||||||
|
stop_str=('<|im_end|>',)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -486,6 +486,12 @@ class StreamExecutor:
|
|||||||
if clone is None:
|
if clone is None:
|
||||||
clone = self.default_sampling_para.clone()
|
clone = self.default_sampling_para.clone()
|
||||||
setattr(clone, item, value)
|
setattr(clone, item, value)
|
||||||
|
|
||||||
|
if self.chat_template.stop_str:
|
||||||
|
if not clone:
|
||||||
|
clone = self.default_sampling_para.clone()
|
||||||
|
clone.stop += self.chat_template.stop_str
|
||||||
|
|
||||||
return clone or self.default_sampling_para
|
return clone or self.default_sampling_para
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user