[Feature] Support "strict" in function calling (#4310)
This commit is contained in:
@@ -1,12 +1,21 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
from json import JSONDecodeError, JSONDecoder
|
from json import JSONDecodeError, JSONDecoder
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
|
||||||
|
|
||||||
import partial_json_parser
|
import partial_json_parser
|
||||||
|
from partial_json_parser.core.exceptions import MalformedJSON
|
||||||
from partial_json_parser.core.options import Allow
|
from partial_json_parser.core.options import Allow
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from sglang.srt.openai_api.protocol import (
|
||||||
|
StructuralTagResponseFormat,
|
||||||
|
StructuresResponseFormat,
|
||||||
|
Tool,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -19,14 +28,6 @@ TOOLS_TAG_LIST = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Function(BaseModel):
|
|
||||||
"""Function Tool Template."""
|
|
||||||
|
|
||||||
description: Optional[str] = Field(default=None, examples=[None])
|
|
||||||
name: Optional[str] = None
|
|
||||||
parameters: Optional[object] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ToolCallItem(BaseModel):
|
class ToolCallItem(BaseModel):
|
||||||
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
|
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
|
||||||
|
|
||||||
@@ -74,7 +75,22 @@ class StreamingParseResult:
|
|||||||
self.calls = calls or []
|
self.calls = calls or []
|
||||||
|
|
||||||
|
|
||||||
class BaseFormatDetector:
|
@dataclass
|
||||||
|
class StructureInfo:
|
||||||
|
begin: str
|
||||||
|
end: str
|
||||||
|
trigger: str
|
||||||
|
|
||||||
|
|
||||||
|
_GetInfoFunc = Callable[[str], StructureInfo]
|
||||||
|
"""
|
||||||
|
helper alias of function
|
||||||
|
ususally it is a function that takes a name string and returns a StructureInfo object,
|
||||||
|
which can be used to construct a structural_tag object
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFormatDetector(ABC):
|
||||||
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -90,26 +106,12 @@ class BaseFormatDetector:
|
|||||||
self.bot_token = ""
|
self.bot_token = ""
|
||||||
self.eot_token = ""
|
self.eot_token = ""
|
||||||
|
|
||||||
def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
|
def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
|
||||||
tool_indices = {
|
tool_indices = {
|
||||||
tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
|
tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
|
||||||
}
|
}
|
||||||
if not isinstance(action, list):
|
if not isinstance(action, list):
|
||||||
name = action.get("name")
|
action = [action]
|
||||||
if not name or name not in tool_indices:
|
|
||||||
logger.warning(f"Model attempted to call undefined function: {name}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
return [
|
|
||||||
ToolCallItem(
|
|
||||||
tool_index=tool_indices[name],
|
|
||||||
name=name,
|
|
||||||
parameters=json.dumps(
|
|
||||||
action.get("parameters") or action.get("arguments", {}),
|
|
||||||
ensure_ascii=False,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for act in action:
|
for act in action:
|
||||||
@@ -125,12 +127,13 @@ class BaseFormatDetector:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Model attempted to call undefined function: {name}")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def detect_and_parse(
|
@abstractmethod
|
||||||
self, text: str, tools: List[Function]
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||||
) -> StreamingParseResult:
|
|
||||||
"""
|
"""
|
||||||
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
||||||
Note that leftover_text here represents "content that this parser will not consume further".
|
Note that leftover_text here represents "content that this parser will not consume further".
|
||||||
@@ -139,7 +142,7 @@ class BaseFormatDetector:
|
|||||||
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||||
|
|
||||||
def parse_streaming_increment(
|
def parse_streaming_increment(
|
||||||
self, new_text: str, tools: List[Function]
|
self, new_text: str, tools: List[Tool]
|
||||||
) -> StreamingParseResult:
|
) -> StreamingParseResult:
|
||||||
"""
|
"""
|
||||||
Streaming incremental parsing with tool validation.
|
Streaming incremental parsing with tool validation.
|
||||||
@@ -198,7 +201,7 @@ class BaseFormatDetector:
|
|||||||
obj["arguments"] = obj["parameters"]
|
obj["arguments"] = obj["parameters"]
|
||||||
tool_call_arr.append(obj)
|
tool_call_arr.append(obj)
|
||||||
|
|
||||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
except MalformedJSON:
|
||||||
return StreamingParseResult()
|
return StreamingParseResult()
|
||||||
|
|
||||||
if len(tool_call_arr) == 0:
|
if len(tool_call_arr) == 0:
|
||||||
@@ -304,6 +307,14 @@ class BaseFormatDetector:
|
|||||||
logger.error(f"Error in parse_streaming_increment: {e}")
|
logger.error(f"Error in parse_streaming_increment: {e}")
|
||||||
return StreamingParseResult()
|
return StreamingParseResult()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def has_tool_call(self, text: str) -> bool:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class Qwen25Detector(BaseFormatDetector):
|
class Qwen25Detector(BaseFormatDetector):
|
||||||
"""
|
"""
|
||||||
@@ -324,9 +335,7 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
"""Check if the text contains a Qwen 2.5 format tool call."""
|
"""Check if the text contains a Qwen 2.5 format tool call."""
|
||||||
return self.bot_token in text
|
return self.bot_token in text
|
||||||
|
|
||||||
def detect_and_parse(
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||||
self, text: str, tools: List[Function]
|
|
||||||
) -> StreamingParseResult:
|
|
||||||
"""
|
"""
|
||||||
One-time parsing: Detects and parses tool calls in the provided text.
|
One-time parsing: Detects and parses tool calls in the provided text.
|
||||||
|
|
||||||
@@ -346,6 +355,13 @@ class Qwen25Detector(BaseFormatDetector):
|
|||||||
calls.extend(self.parse_base_json(match_result, tools))
|
calls.extend(self.parse_base_json(match_result, tools))
|
||||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||||
|
|
||||||
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
|
return lambda name: StructureInfo(
|
||||||
|
begin='<tool_call>{"name":"' + name + '", "arguments":',
|
||||||
|
end="}</tool_call>",
|
||||||
|
trigger="<tool_call>",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MistralDetector(BaseFormatDetector):
|
class MistralDetector(BaseFormatDetector):
|
||||||
"""
|
"""
|
||||||
@@ -380,9 +396,7 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
else:
|
else:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def detect_and_parse(
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||||
self, text: str, tools: List[Function]
|
|
||||||
) -> StreamingParseResult:
|
|
||||||
"""
|
"""
|
||||||
One-time parsing: Detects and parses tool calls in the provided text.
|
One-time parsing: Detects and parses tool calls in the provided text.
|
||||||
|
|
||||||
@@ -403,6 +417,13 @@ class MistralDetector(BaseFormatDetector):
|
|||||||
calls.extend(self.parse_base_json(match_result, tools))
|
calls.extend(self.parse_base_json(match_result, tools))
|
||||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||||
|
|
||||||
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
|
return lambda name: StructureInfo(
|
||||||
|
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
|
||||||
|
end="}]",
|
||||||
|
trigger="[TOOL_CALLS]",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Llama32Detector(BaseFormatDetector):
|
class Llama32Detector(BaseFormatDetector):
|
||||||
"""
|
"""
|
||||||
@@ -421,15 +442,15 @@ class Llama32Detector(BaseFormatDetector):
|
|||||||
# prefix the output with the <|python_tag|> token
|
# prefix the output with the <|python_tag|> token
|
||||||
return "<|python_tag|>" in text or text.startswith("{")
|
return "<|python_tag|>" in text or text.startswith("{")
|
||||||
|
|
||||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
|
||||||
"""Parse function calls from text, handling multiple JSON objects."""
|
"""Parse function calls from text, handling multiple JSON objects."""
|
||||||
if "<|python_tag|>" not in text and not text.startswith("{"):
|
if "<|python_tag|>" not in text and not text.startswith("{"):
|
||||||
return StreamingParseResult(normal_text=text, calls=[])
|
return StreamingParseResult(normal_text=text, calls=[])
|
||||||
|
|
||||||
if "<|python_tag|>" in text:
|
if "<|python_tag|>" in text:
|
||||||
_, action_text = text.split("<|python_tag|>")
|
normal_text, action_text = text.split("<|python_tag|>")
|
||||||
else:
|
else:
|
||||||
action_text = text
|
normal_text, action_text = "", text
|
||||||
|
|
||||||
# Split by semicolon and process each part
|
# Split by semicolon and process each part
|
||||||
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
|
||||||
@@ -449,6 +470,13 @@ class Llama32Detector(BaseFormatDetector):
|
|||||||
calls = self.parse_base_json(all_actions, tools)
|
calls = self.parse_base_json(all_actions, tools)
|
||||||
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
return StreamingParseResult(normal_text=normal_text, calls=calls)
|
||||||
|
|
||||||
|
def structure_info(self) -> _GetInfoFunc:
|
||||||
|
return lambda name: StructureInfo(
|
||||||
|
begin='<|python_tag|>{"name":"' + name + '", "arguments":',
|
||||||
|
end="}",
|
||||||
|
trigger="<|python_tag|>",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiFormatParser:
|
class MultiFormatParser:
|
||||||
def __init__(self, detectors: List[BaseFormatDetector]):
|
def __init__(self, detectors: List[BaseFormatDetector]):
|
||||||
@@ -458,7 +486,7 @@ class MultiFormatParser:
|
|||||||
self.detectors = detectors
|
self.detectors = detectors
|
||||||
|
|
||||||
def parse_once(
|
def parse_once(
|
||||||
self, text: str, tools: List[Function]
|
self, text: str, tools: List[Tool]
|
||||||
) -> Tuple[str, list[ToolCallItem]]:
|
) -> Tuple[str, list[ToolCallItem]]:
|
||||||
"""
|
"""
|
||||||
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
|
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
|
||||||
@@ -480,7 +508,7 @@ class MultiFormatParser:
|
|||||||
return final_normal_text, final_calls
|
return final_normal_text, final_calls
|
||||||
|
|
||||||
def parse_streaming_increment(
|
def parse_streaming_increment(
|
||||||
self, new_text: str, tools: List[Function]
|
self, new_text: str, tools: List[Tool]
|
||||||
) -> Tuple[str, list[ToolCallItem]]:
|
) -> Tuple[str, list[ToolCallItem]]:
|
||||||
"""
|
"""
|
||||||
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
|
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
|
||||||
@@ -512,13 +540,13 @@ class FunctionCallParser:
|
|||||||
and returns the resulting normal_text and calls to the upper layer (or SSE).
|
and returns the resulting normal_text and calls to the upper layer (or SSE).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
|
ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
|
||||||
"llama3": Llama32Detector,
|
"llama3": Llama32Detector,
|
||||||
"qwen25": Qwen25Detector,
|
"qwen25": Qwen25Detector,
|
||||||
"mistral": MistralDetector,
|
"mistral": MistralDetector,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, tools: List[Function], tool_call_parser: str = None):
|
def __init__(self, tools: List[Tool], tool_call_parser: str):
|
||||||
detectors = []
|
detectors = []
|
||||||
if tool_call_parser:
|
if tool_call_parser:
|
||||||
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
|
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
|
||||||
@@ -563,3 +591,40 @@ class FunctionCallParser:
|
|||||||
chunk_text, self.tools
|
chunk_text, self.tools
|
||||||
)
|
)
|
||||||
return normal_text, calls
|
return normal_text, calls
|
||||||
|
|
||||||
|
def structure_infos(self) -> List[_GetInfoFunc]:
|
||||||
|
"""
|
||||||
|
Returns a list of structure_info functions for each detector
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
detector.structure_info() for detector in self.multi_format_parser.detectors
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_structure_tag(self) -> StructuralTagResponseFormat:
|
||||||
|
tool_structures: List[StructuresResponseFormat] = list()
|
||||||
|
tool_trigger_set: Set[str] = set()
|
||||||
|
|
||||||
|
for wrapper in self.structure_infos():
|
||||||
|
for tool in self.tools:
|
||||||
|
function = tool.function
|
||||||
|
name = function.name
|
||||||
|
assert name is not None
|
||||||
|
info = wrapper(name)
|
||||||
|
|
||||||
|
# accept all if not strict, otherwise only accept the schema
|
||||||
|
schema = function.parameters if function.strict else {}
|
||||||
|
|
||||||
|
tool_structures.append(
|
||||||
|
StructuresResponseFormat(
|
||||||
|
begin=info.begin,
|
||||||
|
schema=schema, # type: ignore
|
||||||
|
end=info.end,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tool_trigger_set.add(info.trigger)
|
||||||
|
|
||||||
|
return StructuralTagResponseFormat(
|
||||||
|
type="structural_tag",
|
||||||
|
structures=tool_structures,
|
||||||
|
triggers=list(tool_trigger_set),
|
||||||
|
)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Dict, List
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
from fastapi import HTTPException, Request, UploadFile
|
from fastapi import HTTPException, Request, UploadFile
|
||||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
@@ -38,7 +38,7 @@ from sglang.srt.conversation import (
|
|||||||
generate_embedding_convs,
|
generate_embedding_convs,
|
||||||
register_conv_template,
|
register_conv_template,
|
||||||
)
|
)
|
||||||
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
|
from sglang.srt.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||||
from sglang.srt.openai_api.protocol import (
|
from sglang.srt.openai_api.protocol import (
|
||||||
BatchRequest,
|
BatchRequest,
|
||||||
@@ -915,6 +915,7 @@ def v1_chat_generate_request(
|
|||||||
# - image_data: None or a list of image strings (URLs or base64 strings).
|
# - image_data: None or a list of image strings (URLs or base64 strings).
|
||||||
# - audio_data: None or a list of audio strings (URLs).
|
# - audio_data: None or a list of audio strings (URLs).
|
||||||
# None skips any image processing in GenerateReqInput.
|
# None skips any image processing in GenerateReqInput.
|
||||||
|
strict_tag = None
|
||||||
if not isinstance(request.messages, str):
|
if not isinstance(request.messages, str):
|
||||||
# Apply chat template and its stop strings.
|
# Apply chat template and its stop strings.
|
||||||
tools = None
|
tools = None
|
||||||
@@ -929,6 +930,10 @@ def v1_chat_generate_request(
|
|||||||
else:
|
else:
|
||||||
tools = [item.function.model_dump() for item in request.tools]
|
tools = [item.function.model_dump() for item in request.tools]
|
||||||
|
|
||||||
|
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
|
||||||
|
parser = FunctionCallParser(request.tools, tool_call_parser)
|
||||||
|
strict_tag = parser.get_structure_tag()
|
||||||
|
|
||||||
if chat_template_name is None:
|
if chat_template_name is None:
|
||||||
openai_compatible_messages = []
|
openai_compatible_messages = []
|
||||||
for message in request.messages:
|
for message in request.messages:
|
||||||
@@ -1036,6 +1041,22 @@ def v1_chat_generate_request(
|
|||||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||||
request.response_format.model_dump(by_alias=True)
|
request.response_format.model_dump(by_alias=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if strict_tag is not None:
|
||||||
|
if (
|
||||||
|
sampling_params.get("regex")
|
||||||
|
or sampling_params.get("ebnf")
|
||||||
|
or sampling_params.get("structural_tag")
|
||||||
|
or sampling_params.get("json_schema")
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"Constrained decoding is not compatible with tool calls."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||||
|
strict_tag.model_dump(by_alias=True)
|
||||||
|
)
|
||||||
|
|
||||||
sampling_params_list.append(sampling_params)
|
sampling_params_list.append(sampling_params)
|
||||||
|
|
||||||
image_data_list.append(image_data)
|
image_data_list.append(image_data)
|
||||||
|
|||||||
@@ -287,6 +287,7 @@ class Function(BaseModel):
|
|||||||
description: Optional[str] = Field(default=None, examples=[None])
|
description: Optional[str] = Field(default=None, examples=[None])
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
parameters: Optional[object] = None
|
parameters: Optional[object] = None
|
||||||
|
strict: bool = False
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseModel):
|
class Tool(BaseModel):
|
||||||
|
|||||||
@@ -237,12 +237,61 @@ class TestOpenAIServerFunctionCalling(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertIn("a", args_obj, "Missing parameter 'a'")
|
self.assertIn("a", args_obj, "Missing parameter 'a'")
|
||||||
self.assertIn("b", args_obj, "Missing parameter 'b'")
|
self.assertIn("b", args_obj, "Missing parameter 'b'")
|
||||||
self.assertEqual(
|
self.assertEqual(str(args_obj["a"]), "5", "Parameter a should be 5")
|
||||||
args_obj["a"],
|
self.assertEqual(str(args_obj["b"]), "7", "Parameter b should be 7")
|
||||||
5,
|
|
||||||
"Parameter a should be 5",
|
def test_function_call_strict(self):
|
||||||
|
"""
|
||||||
|
Test: Whether the strict mode of function calling works as expected.
|
||||||
|
- When strict mode is enabled, the AI should not return a function call if the function name is not recognized.
|
||||||
|
"""
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "sub",
|
||||||
|
"description": "Compute the difference of two integers",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"int_a": {
|
||||||
|
"type": "int",
|
||||||
|
"description": "First integer",
|
||||||
|
},
|
||||||
|
"int_b": {
|
||||||
|
"type": "int",
|
||||||
|
"description": "Second integer",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["int_a", "int_b"],
|
||||||
|
},
|
||||||
|
"strict": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "Please compute 5 - 7, using your tool."}
|
||||||
|
]
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=0.8,
|
||||||
|
top_p=0.8,
|
||||||
|
stream=False,
|
||||||
|
tools=tools,
|
||||||
)
|
)
|
||||||
self.assertEqual(args_obj["b"], 7, "Parameter b should be 7")
|
|
||||||
|
tool_calls = response.choices[0].message.tool_calls
|
||||||
|
function_name = tool_calls[0].function.name
|
||||||
|
arguments = tool_calls[0].function.arguments
|
||||||
|
args_obj = json.loads(arguments)
|
||||||
|
|
||||||
|
self.assertEqual(function_name, "sub", "Function name should be 'sub'")
|
||||||
|
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
|
||||||
|
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user