2024-07-28 23:07:12 +10:00
"""
Copyright 2023 - 2024 SGLang Team
Licensed under the Apache License , Version 2.0 ( the " License " ) ;
you may not use this file except in compliance with the License .
You may obtain a copy of the License at
http : / / www . apache . org / licenses / LICENSE - 2.0
Unless required by applicable law or agreed to in writing , software
distributed under the License is distributed on an " AS IS " BASIS ,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied .
See the License for the specific language governing permissions and
limitations under the License .
"""
2024-09-10 15:15:08 -07:00
""" Conversation chat templates. """
2024-06-12 21:48:40 -07:00
2024-01-18 23:43:09 -08:00
# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
2024-01-18 23:51:19 -08:00
from enum import IntEnum , auto
2024-01-30 23:12:33 +09:00
from typing import Dict , List , Optional , Tuple , Union
2024-01-18 23:43:09 -08:00
2024-07-19 23:20:54 -07:00
from sglang . srt . openai_api . protocol import ChatCompletionRequest
2024-01-18 23:51:19 -08:00
2024-01-18 23:43:09 -08:00
class SeparatorStyle(IntEnum):
    """Separator styles.

    Each member selects one branch of ``Conversation.get_prompt`` and
    therefore one way of joining roles, messages, and separators into a
    single prompt string. Values are assigned sequentially by ``auto()``,
    so member order must not change.
    """

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    LLAMA3 = auto()
    CHATGLM = auto()
    CHATML = auto()
    CHATINTERN = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()
    FALCON_CHAT = auto()
    CHATGLM3 = auto()
    DEEPSEEK_CHAT = auto()
    METAMATH = auto()
@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history.

    ``get_prompt`` renders the accumulated history into a single prompt
    string according to ``sep_style``; the remaining methods mutate or
    export the history.
    """

    # The name of this template.
    name: str
    # The template of the system prompt.
    system_template: str = "{system_message}"
    # The system message.
    system_message: str = ""
    # The names of the two roles.
    roles: Tuple[str] = ("USER", "ASSISTANT")
    # All messages. Each item is [role, message].
    messages: List[List[str]] = ()
    # The number of few-shot examples.
    offset: int = 0
    # The separator style and configurations.
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = "\n"
    sep2: str = None
    # Stop criteria (the default one is the EOS token).
    stop_str: Union[str, List[str]] = None
    # The string that represents an image token in the prompt.
    image_token: str = "<image>"
    # Image payloads (e.g. URLs) attached to the conversation, if any.
    image_data: Optional[List[str]] = None
    # Modalities recorded for multimodal content parts, if any.
    modalities: Optional[List[str]] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        style = self.sep_style

        if style == SeparatorStyle.ADD_COLON_SINGLE:
            out = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    out += role + ": " + message + self.sep
                else:
                    out += role + ":"
            return out
        elif style == SeparatorStyle.ADD_COLON_TWO:
            sep_pair = [self.sep, self.sep2]
            out = system_prompt + sep_pair[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    out += role + ": " + message + sep_pair[i % 2]
                else:
                    out += role + ":"
            return out
        elif style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            out = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    out += role + ": " + message + self.sep
                else:
                    out += role + ": "  # must end with a space
            return out
        elif style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            out = "" if system_prompt == "" else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    out += role + "\n" + message + self.sep
                else:
                    out += role + "\n"
            return out
        elif style == SeparatorStyle.NO_COLON_SINGLE:
            out = system_prompt
            for role, message in self.messages:
                if message:
                    out += role + message + self.sep
                else:
                    out += role
            return out
        elif style == SeparatorStyle.NO_COLON_TWO:
            sep_pair = [self.sep, self.sep2]
            out = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    out += role + message + sep_pair[i % 2]
                else:
                    out += role
            return out
        elif style == SeparatorStyle.RWKV:
            out = system_prompt
            for role, message in self.messages:
                if message:
                    # Normalize newlines before appending the turn.
                    out += role + ": " + message.replace("\r\n", "\n").replace(
                        "\n\n", "\n"
                    )
                    out += "\n\n"
                else:
                    out += role + ":"
            return out
        elif style == SeparatorStyle.LLAMA3:
            out = "<|begin_of_text|>"
            if self.system_message:
                out += system_prompt
            for role, message in self.messages:
                # Every turn opens with a header; completed turns add the
                # stripped message plus the end-of-turn token.
                out += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
                if message:
                    out += f"{message.strip()}<|eot_id|>"
            return out
        elif style == SeparatorStyle.LLAMA2:
            sep_pair = [self.sep, self.sep2]
            out = system_prompt if self.system_message else "[INST] "
            for i, (role, message) in enumerate(self.messages):
                tag = self.roles[i % 2]
                if message:
                    if i == 0:
                        # The first user message rides inside the opening [INST].
                        out += message + " "
                    else:
                        out += tag + " " + message + sep_pair[i % 2]
                else:
                    out += tag
            return out
        elif style == SeparatorStyle.CHATGLM:
            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
            # chatglm2 counts rounds starting from 1.
            round_add_n = 1 if self.name == "chatglm2" else 0
            out = system_prompt + self.sep if system_prompt else ""
            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    out += f"[Round {i // 2 + round_add_n}]{self.sep}"
                if message:
                    out += f"{role}:{message}{self.sep}"
                else:
                    out += f"{role}:"
            return out
        elif style == SeparatorStyle.CHATML:
            out = "" if system_prompt == "" else system_prompt + self.sep + "\n"
            for role, message in self.messages:
                if message:
                    out += role + "\n" + message + self.sep + "\n"
                else:
                    out += role + "\n"
            return out
        elif style == SeparatorStyle.CHATGLM3:
            out = ""
            if self.system_message:
                out += system_prompt
            for role, message in self.messages:
                if message:
                    out += role + "\n" + message
                else:
                    out += role
            return out
        elif style == SeparatorStyle.CHATINTERN:
            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
            sep_pair = [self.sep, self.sep2]
            out = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    out += "<s>"
                if message:
                    out += role + ":" + message + sep_pair[i % 2] + "\n"
                else:
                    out += role + ":"
            return out
        elif style == SeparatorStyle.DOLLY:
            sep_pair = [self.sep, self.sep2]
            out = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    out += role + ":\n" + message + sep_pair[i % 2]
                    if i % 2 == 1:
                        out += "\n\n"
                else:
                    out += role + ":\n"
            return out
        elif style == SeparatorStyle.PHOENIX:
            out = system_prompt
            for role, message in self.messages:
                if message:
                    out += role + ": " + "<s>" + message + "</s>"
                else:
                    out += role + ": " + "<s>"
            return out
        elif style == SeparatorStyle.ROBIN:
            out = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    out += role + ":\n" + message + self.sep
                else:
                    out += role + ":\n"
            return out
        elif style == SeparatorStyle.FALCON_CHAT:
            out = ""
            if self.system_message:
                out += system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    out += role + ": " + message + self.sep
                else:
                    out += role + ":"
            return out
        elif style == SeparatorStyle.METAMATH:
            out = "" if system_prompt == "" else system_prompt + self.sep
            for i, (role, message) in enumerate(self.messages):
                # For MetaMath, sep2 is used to prefix the message.
                starting_sep = ":\n" if i % 2 == 0 else ": " + self.sep2
                ending_sep = self.sep if i % 2 == 0 else ""
                if message:
                    out += role + starting_sep + message + ending_sep
                else:
                    out += role + starting_sep
            return out
        elif style == SeparatorStyle.DEEPSEEK_CHAT:
            sep_pair = [self.sep, self.sep2]
            out = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    out += role + ": " + message + sep_pair[i % 2]
                else:
                    out += role + ":"
            return out
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new [role, message] pair to the history."""
        self.messages.append([role, message])

    def append_image(self, image: str):
        """Record an image payload (e.g. a URL) on the conversation."""
        self.image_data.append(image)

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to None when constructing the
        prompt, so it must be updated in place after getting the response
        from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        out = []
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                out.append([msg, None])
            else:
                out[-1][-1] = msg
        return out

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        if self.system_message == "":
            out = []
        else:
            out = [{"role": "system", "content": self.system_message}]
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                out.append({"role": "user", "content": msg})
            elif msg is not None:
                out.append({"role": "assistant", "content": msg})
        return out

    def copy(self):
        """Return a copy safe to mutate (messages list is duplicated).

        NOTE(review): image_data and modalities are deliberately not
        carried over here; callers supply fresh lists — confirm against
        generate_chat_conv.
        """
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            image_token=self.image_token,
        )

    def dict(self):
        """Export the template identity and history as a plain dict."""
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }
# A global registry for all conversation templates.
chat_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template.

    Unless `override` is True, registering a duplicate name is an error.
    """
    if not override:
        assert (
            template.name not in chat_templates
        ), f"{template.name} has been registered."
    chat_templates[template.name] = template


def chat_template_exists(template_name: str) -> bool:
    """Return True if a template with this name has been registered."""
    return template_name in chat_templates
2024-01-18 23:51:19 -08:00
def generate_chat_conv(
    request: ChatCompletionRequest, template_name: str
) -> Conversation:
    """Build a Conversation for `request` from the registered template.

    The registry entry is copied (never mutated); the request's system,
    user, and assistant messages are folded into the copy, and a blank
    assistant turn is appended for the model to fill in.
    """
    base = chat_templates[template_name].copy()
    conv = Conversation(
        name=base.name,
        system_template=base.system_template,
        system_message=base.system_message,
        roles=base.roles,
        messages=list(base.messages),  # prevent in-place modification
        offset=base.offset,
        sep_style=SeparatorStyle(base.sep_style),
        sep=base.sep,
        sep2=base.sep2,
        stop_str=base.stop_str,
        image_data=[],
        modalities=[],
        image_token=base.image_token,
    )

    if isinstance(request.messages, str):
        raise ValueError("The messages should be a list of dict.")
    for message in request.messages:
        msg_role = message.role
        if msg_role == "system":
            if isinstance(message.content, str):
                conv.system_message = message.content
            elif isinstance(message.content, list):
                # A list-form system message must be exactly one text part.
                if (
                    len(message.content) != 1
                    or getattr(message.content[0], "type", None) != "text"
                ):
                    raise ValueError("The system message should be a single text.")
                else:
                    conv.system_message = getattr(message.content[0], "text", "")
        elif msg_role == "user":
            # Handle the various types of Chat Request content types here.
            if isinstance(message.content, str):
                conv.append_message(conv.roles[0], message.content)
            else:
                # First pass: count image_url parts and record modalities.
                num_image_url = 0
                for content in message.content:
                    if content.type == "image_url":
                        num_image_url += 1
                        conv.modalities.append(content.modalities)
                # With a single image, a trailing newline is appended to the
                # image token (except for qwen2-vl, whose token stands alone).
                if num_image_url > 1:
                    image_token = conv.image_token
                else:
                    image_token = (
                        conv.image_token + "\n"
                        if conv.name != "qwen2-vl"
                        else conv.image_token
                    )
                # Second pass: stitch text and image tokens together.
                real_content = ""
                for content in message.content:
                    if content.type == "text":
                        if num_image_url > 16:
                            real_content += "\n"  # for video
                        real_content += content.text
                    elif content.type == "image_url":
                        # NOTE: Only works for llava
                        real_content += image_token
                        conv.append_image(content.image_url.url)
                conv.append_message(conv.roles[0], real_content)
        elif msg_role == "assistant":
            parsed_content = ""
            if isinstance(message.content, str):
                parsed_content = message.content
            elif isinstance(message.content, list):
                # A list-form assistant message must be exactly one text part.
                if (
                    len(message.content) != 1
                    or getattr(message.content[0], "type", None) != "text"
                ):
                    raise ValueError("The assistant's response should be a single text.")
                else:
                    parsed_content = getattr(message.content[0], "text", "")
            conv.append_message(conv.roles[1], parsed_content)
        else:
            raise ValueError(f"Unknown role: {msg_role}")

    # Add a blank message for the assistant.
    conv.append_message(conv.roles[1], None)
    return conv
# llama2 template
# reference: https://huggingface.co/blog/codellama#conversational-instructions
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
register_conv_template(
    Conversation(
        name="llama-2",
        system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
        roles=("[INST]", "[/INST]"),
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        stop_str=["[INST]", "[/INST]", "<<SYS>>", "<</SYS>>"],
    )
)
# ChatML template (used by many OpenAI-style chat models).
register_conv_template(
    Conversation(
        name="chatml",
        system_template="<|im_start|>system\n{system_message}",
        system_message="You are a helpful assistant.",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep_style=SeparatorStyle.CHATML,
        sep="<|im_end|>",
        stop_str=["<|endoftext|>", "<|im_end|>"],
    )
)
# ChatML variant used by llava-style vision models (same shape as "chatml").
register_conv_template(
    Conversation(
        name="chatml-llava",
        system_template="<|im_start|>system\n{system_message}",
        system_message="You are a helpful assistant.",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep_style=SeparatorStyle.CHATML,
        sep="<|im_end|>",
        stop_str=["<|endoftext|>", "<|im_end|>"],
    )
)
# Vicuna v1.1 template (two-separator colon style).
register_conv_template(
    Conversation(
        name="vicuna_v1.1",
        system_message="A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions.",
        roles=("USER", "ASSISTANT"),
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="</s>",
    )
)
# Llama 3 vision template (header/eot token style, <|image|> placeholder).
register_conv_template(
    Conversation(
        name="llama_3_vision",
        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
        roles=("user", "assistant"),
        sep_style=SeparatorStyle.LLAMA3,
        sep="",
        stop_str=["<|end_of_text|>", "<|eot_id|>"],
        image_token="<|image|>",
    )
)
# llava on Llama 3 — same prompt shape as llama_3_vision but with the
# default <image> token.
register_conv_template(
    Conversation(
        name="llava_llama_3",
        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
        roles=("user", "assistant"),
        sep_style=SeparatorStyle.LLAMA3,
        sep="",
        stop_str=["<|end_of_text|>", "<|eot_id|>"],
    )
)
# Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442
register_conv_template(
    Conversation(
        name="internlm2-chat",
        system_template="<|im_start|>system\n{system_message}",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        # NOTE(review): no sep_style is given, so this falls back to the
        # dataclass default (ADD_COLON_SINGLE) — confirm CHATML was not intended.
        sep="\n",
        stop_str=["<|im_end|>", "<|action_end|>"],
    )
)
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_conv_template(
    Conversation(
        name="qwen2-vl",
        system_message="You are a helpful assistant.",
        system_template="<|im_start|>system\n{system_message}",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep="<|im_end|>\n",
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        stop_str=["<|im_end|>"],
        image_token="<|vision_start|><|image_pad|><|vision_end|>",
    )
)