Support jinja as chat template file (#1104)
This commit is contained in:
@@ -117,7 +117,7 @@ def create_streaming_error_response(
|
|||||||
return json_str
|
return json_str
|
||||||
|
|
||||||
|
|
||||||
def load_chat_template_for_openai_api(chat_template_arg):
|
def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
|
||||||
global chat_template_name
|
global chat_template_name
|
||||||
|
|
||||||
print(f"Use chat template: {chat_template_arg}")
|
print(f"Use chat template: {chat_template_arg}")
|
||||||
@@ -127,27 +127,38 @@ def load_chat_template_for_openai_api(chat_template_arg):
|
|||||||
f"Chat template {chat_template_arg} is not a built-in template name "
|
f"Chat template {chat_template_arg} is not a built-in template name "
|
||||||
"or a valid chat template file path."
|
"or a valid chat template file path."
|
||||||
)
|
)
|
||||||
with open(chat_template_arg, "r") as filep:
|
if chat_template_arg.endswith(".jinja"):
|
||||||
template = json.load(filep)
|
with open(chat_template_arg, "r") as f:
|
||||||
try:
|
chat_template = "".join(f.readlines()).strip("\n")
|
||||||
sep_style = SeparatorStyle[template["sep_style"]]
|
tokenizer_manager.tokenizer.chat_template = chat_template.replace(
|
||||||
except KeyError:
|
"\\n", "\n"
|
||||||
raise ValueError(
|
|
||||||
f"Unknown separator style: {template['sep_style']}"
|
|
||||||
) from None
|
|
||||||
register_conv_template(
|
|
||||||
Conversation(
|
|
||||||
name=template["name"],
|
|
||||||
system_template=template["system"] + "\n{system_message}",
|
|
||||||
system_message=template.get("system_message", ""),
|
|
||||||
roles=(template["user"], template["assistant"]),
|
|
||||||
sep_style=sep_style,
|
|
||||||
sep=template.get("sep", "\n"),
|
|
||||||
stop_str=template["stop_str"],
|
|
||||||
),
|
|
||||||
override=True,
|
|
||||||
)
|
)
|
||||||
chat_template_name = template["name"]
|
chat_template_name = None
|
||||||
|
else:
|
||||||
|
assert chat_template_arg.endswith(
|
||||||
|
".json"
|
||||||
|
), "unrecognized format of chat template file"
|
||||||
|
with open(chat_template_arg, "r") as filep:
|
||||||
|
template = json.load(filep)
|
||||||
|
try:
|
||||||
|
sep_style = SeparatorStyle[template["sep_style"]]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown separator style: {template['sep_style']}"
|
||||||
|
) from None
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name=template["name"],
|
||||||
|
system_template=template["system"] + "\n{system_message}",
|
||||||
|
system_message=template.get("system_message", ""),
|
||||||
|
roles=(template["user"], template["assistant"]),
|
||||||
|
sep_style=sep_style,
|
||||||
|
sep=template.get("sep", "\n"),
|
||||||
|
stop_str=template["stop_str"],
|
||||||
|
),
|
||||||
|
override=True,
|
||||||
|
)
|
||||||
|
chat_template_name = template["name"]
|
||||||
else:
|
else:
|
||||||
chat_template_name = chat_template_arg
|
chat_template_name = chat_template_arg
|
||||||
|
|
||||||
|
|||||||
@@ -288,6 +288,8 @@ def launch_server(
|
|||||||
|
|
||||||
# Launch processes
|
# Launch processes
|
||||||
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
|
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
|
||||||
|
if server_args.chat_template:
|
||||||
|
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
|
||||||
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
|
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
|
||||||
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
|
||||||
|
|
||||||
@@ -375,11 +377,6 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
|
||||||
maybe_set_triton_cache_manager()
|
maybe_set_triton_cache_manager()
|
||||||
|
|
||||||
# Set global chat template
|
|
||||||
if server_args.chat_template:
|
|
||||||
# TODO: replace this with huggingface transformers template
|
|
||||||
load_chat_template_for_openai_api(server_args.chat_template)
|
|
||||||
|
|
||||||
# Check flashinfer version
|
# Check flashinfer version
|
||||||
if not server_args.disable_flashinfer:
|
if not server_args.disable_flashinfer:
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
|
|||||||
Reference in New Issue
Block a user