[Feature] Function Calling (#2544)

Co-authored-by: Haoyu Wang <120358163+HaoyuWang4188@users.noreply.github.com>
This commit is contained in:
Tanjiro
2024-12-29 11:28:52 +05:30
committed by GitHub
parent fd28640dc5
commit 8ee9a8501a
5 changed files with 368 additions and 2 deletions

View File

@@ -65,10 +65,13 @@ from sglang.srt.openai_api.protocol import (
FileDeleteResponse,
FileRequest,
FileResponse,
FunctionResponse,
LogProbs,
ToolCall,
TopLogprob,
UsageInfo,
)
from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -879,6 +882,21 @@ def v1_chat_generate_request(
# None skips any image processing in GenerateReqInput.
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if request.stream:
logger.warning("Streaming is not supported with tools.")
request.stream = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
for item in request.tools
if item.function.name == request.tool_choice.function.name
]
else:
tools = [item.function.model_dump() for item in request.tools]
if chat_template_name is None:
openai_compatible_messages = []
for message in request.messages:
@@ -902,6 +920,7 @@ def v1_chat_generate_request(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
)
if assistant_prefix:
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
@@ -1041,11 +1060,46 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
finish_reason = ret_item["meta_info"]["finish_reason"]
tool_calls = None
text = ret_item["text"]
if isinstance(request, list):
tool_choice = request[idx].tool_choice
tools = request[idx].tools
else:
tool_choice = request.tool_choice
tools = request.tools
if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
if finish_reason == "stop":
finish_reason = "tool_calls"
try:
text, call_info_list = parse_tool_response(text, tools) # noqa
tool_calls = [
ToolCall(
id=str(call_info[0]),
function=FunctionResponse(
name=call_info[1], arguments=call_info[2]
),
)
for call_info in call_info_list
]
except Exception as e:
logger.error(f"Exception: {e}")
return create_error_response(
HTTPStatus.BAD_REQUEST,
"Failed to parse fc related info to json format!",
)
if to_file:
# to make the choice data json serializable
choice_data = {
"index": 0,
"message": {"role": "assistant", "content": ret_item["text"]},
"message": {
"role": "assistant",
"content": ret_item["text"] if tool_calls is None else None,
"tool_calls": tool_calls,
},
"logprobs": choice_logprobs,
"finish_reason": (finish_reason["type"] if finish_reason else ""),
"matched_stop": (
@@ -1057,7 +1111,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
else:
choice_data = ChatCompletionResponseChoice(
index=idx,
message=ChatMessage(role="assistant", content=ret_item["text"]),
message=ChatMessage(
role="assistant",
content=ret_item["text"] if tool_calls is None else None,
tool_calls=tool_calls,
),
logprobs=choice_logprobs,
finish_reason=(finish_reason["type"] if finish_reason else ""),
matched_stop=(