[Feature] Function Calling (#2544)
Co-authored-by: Haoyu Wang <120358163+HaoyuWang4188@users.noreply.github.com>
This commit is contained in:
146
docs/backend/function_calling.ipynb
Normal file
146
docs/backend/function_calling.ipynb
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Function Calling\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook provides a quick-start guide to function calling (tool use) with the SGLang chat completions API.\n",
|
||||||
|
"\n",
|
||||||
|
"## Supported Models\n",
|
||||||
|
"\n",
|
||||||
|
"Tool calling is currently supported for the following models:\n",
|
||||||
|
" - Llama 3.2 models\n",
|
||||||
|
" - Llama 3.1 models\n",
|
||||||
|
" - Qwen 2.5 models\n",
|
||||||
|
" - InternLM Models"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Usage\n",
|
||||||
|
"\n",
|
||||||
|
"### Launch a server\n",
|
||||||
|
"\n",
|
||||||
|
"This code block is equivalent to executing\n",
|
||||||
|
"\n",
|
||||||
|
"`python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||||
|
"--port 30000 --host 0.0.0.0`\n",
|
||||||
|
"in your terminal and waiting for the server to be ready. Once the server is running, you can send test requests using curl or requests. The server implements the OpenAI-compatible APIs."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from sglang.utils import (\n",
|
||||||
|
" execute_shell_command,\n",
|
||||||
|
" wait_for_server,\n",
|
||||||
|
" terminate_process,\n",
|
||||||
|
" print_highlight,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"server_process = execute_shell_command(\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"wait_for_server(\"http://localhost:30000\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Single Round Invocation"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from openai import OpenAI\n",
|
||||||
|
"\n",
|
||||||
|
"tools = [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"type\": \"function\",\n",
|
||||||
|
" \"function\": {\n",
|
||||||
|
" \"name\": \"get_current_weather\",\n",
|
||||||
|
" \"description\": \"Get the current weather in a given location\",\n",
|
||||||
|
" \"parameters\": {\n",
|
||||||
|
" \"type\": \"object\",\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"location\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
|
||||||
|
" },\n",
|
||||||
|
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
|
||||||
|
" },\n",
|
||||||
|
" \"required\": [\"location\"],\n",
|
||||||
|
" },\n",
|
||||||
|
" },\n",
|
||||||
|
" }\n",
|
||||||
|
"]\n",
|
||||||
|
"messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n",
|
||||||
|
"\n",
|
||||||
|
"client = OpenAI(api_key=\"YOUR_API_KEY\", base_url=\"http://0.0.0.0:30000/v1\")\n",
|
||||||
|
"model_name = client.models.list().data[0].id\n",
|
||||||
|
"response = client.chat.completions.create(\n",
|
||||||
|
" model=model_name,\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
" temperature=0.8,\n",
|
||||||
|
" top_p=0.8,\n",
|
||||||
|
" stream=False,\n",
|
||||||
|
" tools=tools,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"print(response)\n",
|
||||||
|
"\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"ChatCompletion(id='d6f620e1767e490d85b5ce45c15151cf', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, \n",
|
||||||
|
"role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"a\": \"3\", \"b\": \"5\"}', name='add'), type='function')]), \n",
|
||||||
|
"matched_stop=128008)], created=1735411703, model='meta-llama/Llama-3.2-1B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, \n",
|
||||||
|
"usage=CompletionUsage(completion_tokens=23, prompt_tokens=198, total_tokens=221, completion_tokens_details=None, prompt_tokens_details=None))\n",
|
||||||
|
"\n",
|
||||||
|
"\"\"\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"terminate_process(server_process)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## How to support a new model?\n",
|
||||||
|
"\n",
|
||||||
|
"To add support for additional models:\n",
|
||||||
|
" 1. Update the `TOOLS_TAG_LIST` in `sglang/srt/utils.py` with the tool tag used by the model.\n",
|
||||||
|
" 2. Add a branch to the `parse_tool_response` function in `sglang/srt/utils.py` that converts the model's output into tool calls.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
||||||
@@ -65,10 +65,13 @@ from sglang.srt.openai_api.protocol import (
|
|||||||
FileDeleteResponse,
|
FileDeleteResponse,
|
||||||
FileRequest,
|
FileRequest,
|
||||||
FileResponse,
|
FileResponse,
|
||||||
|
FunctionResponse,
|
||||||
LogProbs,
|
LogProbs,
|
||||||
|
ToolCall,
|
||||||
TopLogprob,
|
TopLogprob,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -879,6 +882,21 @@ def v1_chat_generate_request(
|
|||||||
# None skips any image processing in GenerateReqInput.
|
# None skips any image processing in GenerateReqInput.
|
||||||
if not isinstance(request.messages, str):
|
if not isinstance(request.messages, str):
|
||||||
# Apply chat template and its stop strings.
|
# Apply chat template and its stop strings.
|
||||||
|
tools = None
|
||||||
|
if request.tools and request.tool_choice != "none":
|
||||||
|
request.skip_special_tokens = False
|
||||||
|
if request.stream:
|
||||||
|
logger.warning("Streaming is not supported with tools.")
|
||||||
|
request.stream = False
|
||||||
|
if not isinstance(request.tool_choice, str):
|
||||||
|
tools = [
|
||||||
|
item.function.model_dump()
|
||||||
|
for item in request.tools
|
||||||
|
if item.function.name == request.tool_choice.function.name
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
tools = [item.function.model_dump() for item in request.tools]
|
||||||
|
|
||||||
if chat_template_name is None:
|
if chat_template_name is None:
|
||||||
openai_compatible_messages = []
|
openai_compatible_messages = []
|
||||||
for message in request.messages:
|
for message in request.messages:
|
||||||
@@ -902,6 +920,7 @@ def v1_chat_generate_request(
|
|||||||
openai_compatible_messages,
|
openai_compatible_messages,
|
||||||
tokenize=True,
|
tokenize=True,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
|
tools=tools,
|
||||||
)
|
)
|
||||||
if assistant_prefix:
|
if assistant_prefix:
|
||||||
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
|
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
|
||||||
@@ -1041,11 +1060,46 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
|||||||
|
|
||||||
finish_reason = ret_item["meta_info"]["finish_reason"]
|
finish_reason = ret_item["meta_info"]["finish_reason"]
|
||||||
|
|
||||||
|
tool_calls = None
|
||||||
|
text = ret_item["text"]
|
||||||
|
|
||||||
|
if isinstance(request, list):
|
||||||
|
tool_choice = request[idx].tool_choice
|
||||||
|
tools = request[idx].tools
|
||||||
|
else:
|
||||||
|
tool_choice = request.tool_choice
|
||||||
|
tools = request.tools
|
||||||
|
|
||||||
|
if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
|
||||||
|
if finish_reason == "stop":
|
||||||
|
finish_reason = "tool_calls"
|
||||||
|
try:
|
||||||
|
text, call_info_list = parse_tool_response(text, tools) # noqa
|
||||||
|
tool_calls = [
|
||||||
|
ToolCall(
|
||||||
|
id=str(call_info[0]),
|
||||||
|
function=FunctionResponse(
|
||||||
|
name=call_info[1], arguments=call_info[2]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for call_info in call_info_list
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception: {e}")
|
||||||
|
return create_error_response(
|
||||||
|
HTTPStatus.BAD_REQUEST,
|
||||||
|
"Failed to parse fc related info to json format!",
|
||||||
|
)
|
||||||
|
|
||||||
if to_file:
|
if to_file:
|
||||||
# to make the choice data json serializable
|
# to make the choice data json serializable
|
||||||
choice_data = {
|
choice_data = {
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"message": {"role": "assistant", "content": ret_item["text"]},
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": ret_item["text"] if tool_calls is None else None,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
},
|
||||||
"logprobs": choice_logprobs,
|
"logprobs": choice_logprobs,
|
||||||
"finish_reason": (finish_reason["type"] if finish_reason else ""),
|
"finish_reason": (finish_reason["type"] if finish_reason else ""),
|
||||||
"matched_stop": (
|
"matched_stop": (
|
||||||
@@ -1057,7 +1111,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
|||||||
else:
|
else:
|
||||||
choice_data = ChatCompletionResponseChoice(
|
choice_data = ChatCompletionResponseChoice(
|
||||||
index=idx,
|
index=idx,
|
||||||
message=ChatMessage(role="assistant", content=ret_item["text"]),
|
message=ChatMessage(
|
||||||
|
role="assistant",
|
||||||
|
content=ret_item["text"] if tool_calls is None else None,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
),
|
||||||
logprobs=choice_logprobs,
|
logprobs=choice_logprobs,
|
||||||
finish_reason=(finish_reason["type"] if finish_reason else ""),
|
finish_reason=(finish_reason["type"] if finish_reason else ""),
|
||||||
matched_stop=(
|
matched_stop=(
|
||||||
|
|||||||
@@ -257,6 +257,34 @@ class ResponseFormat(BaseModel):
|
|||||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Function(BaseModel):
    """Description of a single function that a client offers to the model."""

    # Optional free-text explanation of what the function does.
    description: Optional[str] = Field(default=None, examples=[None])
    # Name of the function; required.
    name: str
    # Optional parameter specification; accepted as an arbitrary object.
    parameters: Optional[object] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Tool(BaseModel):
    """A tool entry wrapping one ``Function`` description."""

    # Kind of tool; defaults to "function".
    type: str = Field(default="function", examples=["function"])
    # The function description carried by this tool.
    function: Function
|
||||||
|
|
||||||
|
|
||||||
|
class ToolChoiceFuncName(BaseModel):
    """Holder for the function name referenced by a ``ToolChoice``."""

    # Name of the selected function.
    name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolChoice(BaseModel):
    """Explicit tool selection naming one function."""

    # The function being selected, identified by name.
    function: ToolChoiceFuncName
    # Only "function" is accepted; it is also the default.
    type: Literal["function"] = Field(default="function", examples=["function"])
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
class ChatCompletionRequest(BaseModel):
|
||||||
# Ordered by official OpenAI API documentation
|
# Ordered by official OpenAI API documentation
|
||||||
# https://platform.openai.com/docs/api-reference/chat/create
|
# https://platform.openai.com/docs/api-reference/chat/create
|
||||||
@@ -277,6 +305,10 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
top_p: float = 1.0
|
top_p: float = 1.0
|
||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
|
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
|
||||||
|
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
|
||||||
|
default="auto", examples=["none"]
|
||||||
|
) # noqa
|
||||||
|
|
||||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||||
top_k: int = -1
|
top_k: int = -1
|
||||||
@@ -292,9 +324,25 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
ebnf: Optional[str] = None
|
ebnf: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionResponse(BaseModel):
    """Function part of a tool call returned in a response."""

    # Name of the function being called.
    name: str
    # Call arguments, carried as a string.
    arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(BaseModel):
    """One tool call attached to an assistant message."""

    # Identifier of this call, as a string.
    id: str
    # Only "function" is accepted; it is also the default.
    type: Literal["function"] = "function"
    # Name and arguments of the function being called.
    function: FunctionResponse
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
    """A chat message, optionally carrying tool calls."""

    # Speaker role of the message; optional.
    role: Optional[str] = None
    # Text content of the message; optional.
    content: Optional[str] = None
    # Tool calls attached to the message, if any.
    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponseChoice(BaseModel):
|
class ChatCompletionResponseChoice(BaseModel):
|
||||||
|
|||||||
@@ -1273,3 +1273,65 @@ def dataclass_to_string_truncated(data, max_length=2048):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return str(data)
|
return str(data)
|
||||||
|
|
||||||
|
|
||||||
|
# Marker substrings that indicate a tool call in generated text. Each one is
# handled by a matching branch of parse_tool_response() (internlm2, llama3.1,
# qwen2.5 and llama3.2, in that order).
TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_response(text, tools, **kwargs):
    """Parse a model response containing tool-call markup.

    The markup dialect is chosen from the marker found in ``text``
    (InternLM2, Llama 3.1, Qwen 2.5 or Llama 3.2, checked in that order).

    Args:
        text(str): model response in string format
        tools(List): tools from user request; each item must expose
            ``item.function.name``
        **kwargs: accepted for interface compatibility and ignored

    Returns:
        A ``(text, call_info_list)`` tuple. ``text`` is the response text
        left for the caller (the part before the first action for InternLM2,
        the text outside the tags for Qwen 2.5, and the unchanged input for
        the Llama dialects). ``call_info_list`` holds one
        ``(tool_index, name, arguments)`` tuple per call, where ``tool_index``
        is the position of the first tool in ``tools`` with that name and
        ``arguments`` is a string.

    Raises:
        RuntimeError: if ``text`` contains no supported tool-call markup.
        ValueError: if the markup is malformed (including invalid JSON) or
            the called name is not among ``tools``.
        KeyError: if a JSON call object lacks a required key.
    """
    if "<|plugin|>" in text:  # internlm2
        # Split on every action marker instead of tuple-unpacking a single
        # split, so a response with several actions no longer crashes.
        text, *actions = text.split("<|action_start|><|plugin|>")
        if not actions:
            raise ValueError("Found <|plugin|> without a preceding <|action_start|>")
        call_info_list = []
        for action in actions:
            action = action.split("<|action_end|>")[0]
            action = json.loads(action[action.find("{") :])
            call_info_list.append(
                (
                    action["name"],
                    json.dumps(
                        action.get("parameters", action.get("arguments", {})),
                        ensure_ascii=False,
                    ),
                )
            )
    elif "<function=" in text:  # llama3.1
        # Every "</function>"-terminated segment holding an opening tag is one
        # call; the original unpacking raised when more than one was present.
        call_info_list = []
        for action in text.split("</function>"):
            if "<function=" not in action:
                continue
            parameters = action[action.find("{") :]
            name = action.split("<function=")[1].split(">{")[0]
            call_info_list.append((name, parameters))
    elif "<tool_call>" in text and "</tool_call>" in text:  # qwen2.5
        # get tool_call in text
        pattern = r"<tool_call>(.*?)</tool_call>"
        call_info_list = []
        for match_result in re.findall(pattern, text, re.DOTALL):
            action = json.loads(match_result)
            call_info_list.append(
                (action["name"], json.dumps(action["arguments"], ensure_ascii=False))
            )
        # get text outside of tags
        if not text.startswith("<tool_call>"):
            text = text[: text.find("<tool_call>")]
        elif not text.endswith("</tool_call>"):
            text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
        else:
            text = ""
    elif "<|python_tag|>" in text:  # llama3.2
        # Everything after the first tag must be a single JSON object.
        action = json.loads(text.partition("<|python_tag|>")[2])
        call_info_list = [
            (
                action["name"],
                json.dumps(
                    action.get("parameters", action.get("arguments", {})),
                    ensure_ascii=False,
                ),
            )
        ]
    else:
        raise RuntimeError(f"Unexpected model response: {text}")

    # Build the name -> position lookup once; setdefault keeps the first
    # position for duplicate names, matching list.index().
    tool_positions = {}
    for position, tool in enumerate(tools or ()):
        tool_positions.setdefault(tool.function.name, position)

    indexed_calls = []
    for name, arguments in call_info_list:
        if name not in tool_positions:
            raise ValueError(f"Model called a tool that was not provided: {name!r}")
        indexed_calls.append((tool_positions[name], name, arguments))
    return text, indexed_calls
|
||||||
|
|||||||
@@ -622,6 +622,58 @@ class TestOpenAIServerEBNF(unittest.TestCase):
|
|||||||
text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
|
text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_function_calling_format(self):
    """Check that a chat completion given ``tools`` returns a tool call.

    Sends a request offering a single ``add`` tool and verifies that the
    reply has no text content, carries a non-empty ``tool_calls`` list, and
    that the first call names ``add``.
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    tools = [
        {
            "type": "function",
            "function": {
                "name": "add",
                "description": "Compute the sum of two numbers",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "a": {
                            "type": "int",
                            "description": "A number",
                        },
                        "b": {
                            "type": "int",
                            "description": "A number",
                        },
                    },
                    "required": ["a", "b"],
                },
            },
        }
    ]

    messages = [{"role": "user", "content": "Compute (3+5)"}]
    response = client.chat.completions.create(
        model=self.model,
        messages=messages,
        temperature=0.8,
        top_p=0.8,
        stream=False,
        tools=tools,
    )

    content = response.choices[0].message.content
    tool_calls = response.choices[0].message.tool_calls

    # unittest assertions are used instead of bare ``assert`` so the checks
    # still run when Python is started with -O.
    self.assertIsNone(
        content, "When tools provided by the response, content should be None"
    )
    self.assertIsInstance(
        tool_calls, list, "Format not matched, tool_calls should be a list"
    )
    self.assertGreater(
        len(tool_calls), 0, "Format not matched, tool_calls should be a list"
    )

    function_name = tool_calls[0].function.name
    self.assertEqual(
        function_name, "add", "Function name should be add for the above response"
    )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user