Feature/function calling update (#2700)
Co-authored-by: Mingyuan Ma <mamingyuan2001@berkeley.edu> Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: shuaills <shishuaiuoe@gmail.com>
This commit is contained in:
@@ -20,7 +20,7 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, UploadFile
|
||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
|
||||
generate_chat_conv,
|
||||
register_conv_template,
|
||||
)
|
||||
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
BatchRequest,
|
||||
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
|
||||
TopLogprob,
|
||||
UsageInfo,
|
||||
)
|
||||
from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
|
||||
ret,
|
||||
to_file=True,
|
||||
cache_report=tokenizer_manager.server_args.enable_cache_report,
|
||||
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
|
||||
)
|
||||
else:
|
||||
responses = v1_generate_response(
|
||||
@@ -877,9 +878,6 @@ def v1_chat_generate_request(
|
||||
tools = None
|
||||
if request.tools and request.tool_choice != "none":
|
||||
request.skip_special_tokens = False
|
||||
if request.stream:
|
||||
logger.warning("Streaming is not supported with tools.")
|
||||
request.stream = False
|
||||
if not isinstance(request.tool_choice, str):
|
||||
tools = [
|
||||
item.function.model_dump()
|
||||
@@ -908,12 +906,26 @@ def v1_chat_generate_request(
|
||||
openai_compatible_messages = openai_compatible_messages[:-1]
|
||||
else:
|
||||
assistant_prefix = None
|
||||
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
openai_compatible_messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
try:
|
||||
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
openai_compatible_messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
tools=tools,
|
||||
)
|
||||
except:
|
||||
# This except branch will be triggered when the chosen model
|
||||
# has a different tools input format that is not compatiable
|
||||
# with openAI's apply_chat_template tool_call format, like Mistral.
|
||||
tools = [t if "function" in t else {"function": t} for t in tools]
|
||||
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
openai_compatible_messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
if assistant_prefix:
|
||||
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
|
||||
stop = request.stop
|
||||
@@ -1005,7 +1017,9 @@ def v1_chat_generate_request(
|
||||
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
|
||||
|
||||
|
||||
def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||
def v1_chat_generate_response(
|
||||
request, ret, to_file=False, cache_report=False, tool_call_parser=None
|
||||
):
|
||||
choices = []
|
||||
|
||||
for idx, ret_item in enumerate(ret):
|
||||
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||
if finish_reason == "stop":
|
||||
finish_reason = "tool_calls"
|
||||
try:
|
||||
text, call_info_list = parse_tool_response(text, tools) # noqa
|
||||
parser = FunctionCallParser(tools, tool_call_parser)
|
||||
full_normal_text, call_info_list = parser.parse_non_stream(text)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=str(call_info[0]),
|
||||
id=str(call_info.tool_index),
|
||||
function=FunctionResponse(
|
||||
name=call_info[1], arguments=call_info[2]
|
||||
name=call_info.name, arguments=call_info.parameters
|
||||
),
|
||||
)
|
||||
for call_info in call_info_list
|
||||
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
|
||||
|
||||
if adapted_request.stream:
|
||||
parser_dict = {}
|
||||
|
||||
async def generate_stream_resp():
|
||||
is_firsts = {}
|
||||
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
adapted_request, raw_request
|
||||
):
|
||||
index = content.get("index", 0)
|
||||
text = content["text"]
|
||||
|
||||
is_first = is_firsts.get(index, True)
|
||||
stream_buffer = stream_buffers.get(index, "")
|
||||
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
|
||||
text = content["text"]
|
||||
delta = text[len(stream_buffer) :]
|
||||
stream_buffer = stream_buffer + delta
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta),
|
||||
finish_reason=(finish_reason["type"] if finish_reason else ""),
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
new_stream_buffer = stream_buffer + delta
|
||||
|
||||
is_firsts[index] = is_first
|
||||
stream_buffers[index] = stream_buffer
|
||||
n_prev_tokens[index] = n_prev_token
|
||||
if request.tool_choice != "none" and request.tools:
|
||||
if index not in parser_dict:
|
||||
parser_dict[index] = FunctionCallParser(
|
||||
tools=request.tools,
|
||||
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
|
||||
)
|
||||
parser = parser_dict[index]
|
||||
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
# parse_increment => returns (normal_text, calls)
|
||||
normal_text, calls = parser.parse_stream_chunk(delta)
|
||||
|
||||
# 1) if there's normal_text, output it as normal content
|
||||
if normal_text:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=normal_text),
|
||||
finish_reason=(
|
||||
finish_reason["type"] if finish_reason else ""
|
||||
),
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# 2) if we found calls, we output them as separate chunk(s)
|
||||
for call_item in calls:
|
||||
# transform call_item -> FunctionResponse + ToolCall
|
||||
|
||||
if (
|
||||
content["meta_info"]["finish_reason"]
|
||||
and content["meta_info"]["finish_reason"]["type"]
|
||||
== "stop"
|
||||
):
|
||||
latest_delta_len = 0
|
||||
if isinstance(call_item.parameters, str):
|
||||
latest_delta_len = len(call_item.parameters)
|
||||
|
||||
expected_call = json.dumps(
|
||||
parser.multi_format_parser.detectors[0]
|
||||
.prev_tool_call_arr[index]
|
||||
.get("arguments", {}),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
actual_call = parser.multi_format_parser.detectors[
|
||||
0
|
||||
].streamed_args_for_tool[index]
|
||||
if latest_delta_len > 0:
|
||||
actual_call = actual_call[:-latest_delta_len]
|
||||
remaining_call = expected_call.replace(
|
||||
actual_call, "", 1
|
||||
)
|
||||
call_item.parameters = remaining_call
|
||||
|
||||
tool_call = ToolCall(
|
||||
id=str(call_item.tool_index),
|
||||
function=FunctionResponse(
|
||||
name=call_item.name,
|
||||
arguments=call_item.parameters,
|
||||
),
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(
|
||||
role="assistant", tool_calls=[tool_call]
|
||||
),
|
||||
finish_reason="tool_call",
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
stream_buffers[index] = new_stream_buffer
|
||||
is_firsts[index] = is_first
|
||||
|
||||
else:
|
||||
# No tool calls => just treat this as normal text
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta),
|
||||
finish_reason=(
|
||||
finish_reason["type"] if finish_reason else ""
|
||||
),
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
stream_buffers[index] = new_stream_buffer
|
||||
is_firsts[index] = is_first
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
total_prompt_tokens = sum(
|
||||
tokens
|
||||
@@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
ret = [ret]
|
||||
|
||||
response = v1_chat_generate_response(
|
||||
request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
|
||||
request,
|
||||
ret,
|
||||
cache_report=tokenizer_manager.server_args.enable_cache_report,
|
||||
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -262,7 +262,7 @@ class Function(BaseModel):
|
||||
"""Function descriptions."""
|
||||
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: str
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
|
||||
|
||||
@@ -276,7 +276,7 @@ class Tool(BaseModel):
|
||||
class ToolChoiceFuncName(BaseModel):
|
||||
"""The name of tool choice function."""
|
||||
|
||||
name: str
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class ToolChoice(BaseModel):
|
||||
@@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
class FunctionResponse(BaseModel):
|
||||
"""Function response."""
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
@@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user