[Feature] Support "strict" in function calling (#4310)
This commit is contained in:
@@ -20,7 +20,7 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from fastapi import HTTPException, Request, UploadFile
|
||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||
@@ -38,7 +38,7 @@ from sglang.srt.conversation import (
|
||||
generate_embedding_convs,
|
||||
register_conv_template,
|
||||
)
|
||||
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
|
||||
from sglang.srt.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
BatchRequest,
|
||||
@@ -915,6 +915,7 @@ def v1_chat_generate_request(
|
||||
# - image_data: None or a list of image strings (URLs or base64 strings).
|
||||
# - audio_data: None or a list of audio strings (URLs).
|
||||
# None skips any image processing in GenerateReqInput.
|
||||
strict_tag = None
|
||||
if not isinstance(request.messages, str):
|
||||
# Apply chat template and its stop strings.
|
||||
tools = None
|
||||
@@ -929,6 +930,10 @@ def v1_chat_generate_request(
|
||||
else:
|
||||
tools = [item.function.model_dump() for item in request.tools]
|
||||
|
||||
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
|
||||
parser = FunctionCallParser(request.tools, tool_call_parser)
|
||||
strict_tag = parser.get_structure_tag()
|
||||
|
||||
if chat_template_name is None:
|
||||
openai_compatible_messages = []
|
||||
for message in request.messages:
|
||||
@@ -1036,6 +1041,22 @@ def v1_chat_generate_request(
|
||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||
request.response_format.model_dump(by_alias=True)
|
||||
)
|
||||
|
||||
if strict_tag is not None:
|
||||
if (
|
||||
sampling_params.get("regex")
|
||||
or sampling_params.get("ebnf")
|
||||
or sampling_params.get("structural_tag")
|
||||
or sampling_params.get("json_schema")
|
||||
):
|
||||
logger.warning(
|
||||
"Constrained decoding is not compatible with tool calls."
|
||||
)
|
||||
else:
|
||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||
strict_tag.model_dump(by_alias=True)
|
||||
)
|
||||
|
||||
sampling_params_list.append(sampling_params)
|
||||
|
||||
image_data_list.append(image_data)
|
||||
|
||||
@@ -287,6 +287,7 @@ class Function(BaseModel):
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
strict: bool = False
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user