[Feature] Function Calling (#2544)

Co-authored-by: Haoyu Wang <120358163+HaoyuWang4188@users.noreply.github.com>
This commit is contained in:
Tanjiro
2024-12-29 11:28:52 +05:30
committed by GitHub
parent fd28640dc5
commit 8ee9a8501a
5 changed files with 368 additions and 2 deletions

View File

@@ -65,10 +65,13 @@ from sglang.srt.openai_api.protocol import (
FileDeleteResponse,
FileRequest,
FileResponse,
FunctionResponse,
LogProbs,
ToolCall,
TopLogprob,
UsageInfo,
)
from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -879,6 +882,21 @@ def v1_chat_generate_request(
# None skips any image processing in GenerateReqInput.
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if request.stream:
logger.warning("Streaming is not supported with tools.")
request.stream = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
for item in request.tools
if item.function.name == request.tool_choice.function.name
]
else:
tools = [item.function.model_dump() for item in request.tools]
if chat_template_name is None:
openai_compatible_messages = []
for message in request.messages:
@@ -902,6 +920,7 @@ def v1_chat_generate_request(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
)
if assistant_prefix:
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
@@ -1041,11 +1060,46 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
finish_reason = ret_item["meta_info"]["finish_reason"]
tool_calls = None
text = ret_item["text"]
if isinstance(request, list):
tool_choice = request[idx].tool_choice
tools = request[idx].tools
else:
tool_choice = request.tool_choice
tools = request.tools
if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
if finish_reason == "stop":
finish_reason = "tool_calls"
try:
text, call_info_list = parse_tool_response(text, tools) # noqa
tool_calls = [
ToolCall(
id=str(call_info[0]),
function=FunctionResponse(
name=call_info[1], arguments=call_info[2]
),
)
for call_info in call_info_list
]
except Exception as e:
logger.error(f"Exception: {e}")
return create_error_response(
HTTPStatus.BAD_REQUEST,
"Failed to parse fc related info to json format!",
)
if to_file:
# to make the choice data json serializable
choice_data = {
"index": 0,
"message": {"role": "assistant", "content": ret_item["text"]},
"message": {
"role": "assistant",
"content": ret_item["text"] if tool_calls is None else None,
"tool_calls": tool_calls,
},
"logprobs": choice_logprobs,
"finish_reason": (finish_reason["type"] if finish_reason else ""),
"matched_stop": (
@@ -1057,7 +1111,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
else:
choice_data = ChatCompletionResponseChoice(
index=idx,
message=ChatMessage(role="assistant", content=ret_item["text"]),
message=ChatMessage(
role="assistant",
content=ret_item["text"] if tool_calls is None else None,
tool_calls=tool_calls,
),
logprobs=choice_logprobs,
finish_reason=(finish_reason["type"] if finish_reason else ""),
matched_stop=(

View File

@@ -257,6 +257,34 @@ class ResponseFormat(BaseModel):
json_schema: Optional[JsonSchemaResponseFormat] = None
class Function(BaseModel):
    """Function descriptions.

    One entry of the OpenAI-compatible ``tools`` request field; wrapped by
    :class:`Tool` below.
    """
    # Optional natural-language description of what the function does.
    description: Optional[str] = Field(default=None, examples=[None])
    # Name the model must emit to call this function.
    name: str
    # Argument specification; typed as an opaque object here — presumably a
    # JSON-schema dict per the OpenAI convention (not enforced in this model).
    parameters: Optional[object] = None
class Tool(BaseModel):
    """Function wrapper.

    OpenAI-compatible ``tools`` list item: a typed envelope around a
    :class:`Function` definition.
    """
    # Only "function" tools are modeled, hence the default.
    type: str = Field(default="function", examples=["function"])
    function: Function
class ToolChoiceFuncName(BaseModel):
    """The name of tool choice function.

    Inner object of :class:`ToolChoice`, identifying the single function the
    caller forces the model to use.
    """
    name: str
class ToolChoice(BaseModel):
    """The tool choice definition.

    Object form of the ``tool_choice`` request field (the string forms
    "auto" / "required" / "none" are handled by the request model's Union).
    """
    function: ToolChoiceFuncName
    # Literal pins the only supported discriminator value.
    type: Literal["function"] = Field(default="function", examples=["function"])
class ChatCompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
@@ -277,6 +305,10 @@ class ChatCompletionRequest(BaseModel):
temperature: float = 0.7
top_p: float = 1.0
user: Optional[str] = None
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
default="auto", examples=["none"]
) # noqa
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
@@ -292,9 +324,25 @@ class ChatCompletionRequest(BaseModel):
ebnf: Optional[str] = None
class FunctionResponse(BaseModel):
    """Function response.

    The ``function`` payload inside a :class:`ToolCall` returned to clients.
    """
    name: str
    # Arguments are returned as a JSON-encoded string, matching the OpenAI
    # response shape (not a parsed object).
    arguments: str
class ToolCall(BaseModel):
    """Tool call response.

    One entry of ``ChatMessage.tool_calls`` in an assistant reply.
    """
    id: str
    # Only function-type tool calls are produced.
    type: Literal["function"] = "function"
    function: FunctionResponse
class ChatMessage(BaseModel):
    """A single chat message in an OpenAI-compatible response choice."""
    role: Optional[str] = None
    # Content is None when the assistant reply consists of tool calls only.
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionResponseChoice(BaseModel):

View File

@@ -1273,3 +1273,65 @@ def dataclass_to_string_truncated(data, max_length=2048):
)
else:
return str(data)
# Markers that indicate a tool/function call in a model response, one per
# supported model family (internlm2, llama3.1, qwen2.5, llama3.2).
TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"]


def parse_tool_response(text, tools, **kwargs):
    """Parse model response containing tool information.

    Args:
        text(str): model response in string format
        tools(List): tools from user request; each item exposes
            ``item.function.name``

    Returns:
        tuple: ``(text, call_info_list)`` where ``text`` is the response with
        the recognized tool markup removed (qwen2.5/internlm2 formats; the
        llama formats keep the original text unchanged) and
        ``call_info_list`` is a list of ``(tool_index, name, arguments_json)``
        tuples.

    Raises:
        RuntimeError: if no known tool-call markup is present in ``text``.
        ValueError: if a called function name is not found in ``tools``
            (from ``list.index``), or the markup is malformed.
    """
    if "<|plugin|>" in text:  # internlm2
        # maxsplit=1 so a repeated marker cannot break the 2-tuple unpack.
        text, action = text.split("<|action_start|><|plugin|>", 1)
        # Fix: the original applied .strip() to the delimiter literal (a
        # no-op) instead of the split result; slicing from "{" below makes
        # stripping unnecessary anyway.
        action = action.split("<|action_end|>", 1)[0]
        action = json.loads(action[action.find("{") :])
        # Different checkpoints use "parameters" or "arguments" for the args.
        parameters = json.dumps(
            action.get("parameters", action.get("arguments", {})),
            ensure_ascii=False,
        )
        call_info_list = [(action["name"], parameters)]
    elif "<function=" in text:  # llama3.1
        # Indexed split with maxsplit=1 instead of exact-arity unpacking:
        # a second call / trailing text no longer raises ValueError.
        action = text.split("</function>", 1)[0]
        parameters = action[action.find("{") :]
        name = action.split("<function=", 1)[1].split(">{", 1)[0]
        call_info_list = [(name, parameters)]
    elif "<tool_call>" in text and "</tool_call>" in text:  # qwen2.5
        # Collect every <tool_call>...</tool_call> block.
        pattern = r"<tool_call>(.*?)</tool_call>"
        match_result_list = re.findall(pattern, text, re.DOTALL)
        call_info_list = []
        for match_result in match_result_list:
            action = json.loads(match_result)
            call_info_list.append(
                (action["name"], json.dumps(action["arguments"], ensure_ascii=False))
            )
        # Keep only the text outside the tags (prefix, suffix, or nothing).
        if not text.startswith("<tool_call>"):
            text = text[: text.find("<tool_call>")]
        elif not text.endswith("</tool_call>"):
            text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
        else:
            text = ""
    elif "<|python_tag|>" in text:  # llama3.2
        # NOTE(review): the original text (tag included) is returned
        # unchanged here — preserved for backward compatibility.
        action = json.loads(text.split("<|python_tag|>", 1)[1])
        parameters = json.dumps(
            action.get("parameters", action.get("arguments", {})),
            ensure_ascii=False,
        )
        call_info_list = [(action["name"], parameters)]
    else:
        raise RuntimeError(f"Unexpected model response: {text}")

    # Resolve each called name to its index in the user-supplied tool list;
    # an unknown name raises ValueError, which the caller reports as a
    # parse failure.
    call_info_list = [
        (
            [tool.function.name for tool in tools].index(call_info[0]),
            call_info[0],
            call_info[1],
        )
        for call_info in call_info_list
    ]
    return text, call_info_list