[Feature] Support "strict" in function calling (#4310)

This commit is contained in:
DarkSharpness
2025-03-25 14:15:25 +09:00
committed by GitHub
parent 2d1b83e57a
commit ac3fae8445
4 changed files with 188 additions and 52 deletions

View File

@@ -20,7 +20,7 @@ import os
import time
import uuid
from http import HTTPStatus
from typing import Dict, List
from typing import Any, Dict, List, Set
from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -38,7 +38,7 @@ from sglang.srt.conversation import (
generate_embedding_convs,
register_conv_template,
)
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import (
BatchRequest,
@@ -915,6 +915,7 @@ def v1_chat_generate_request(
# - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput.
strict_tag = None
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
tools = None
@@ -929,6 +930,10 @@ def v1_chat_generate_request(
else:
tools = [item.function.model_dump() for item in request.tools]
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser)
strict_tag = parser.get_structure_tag()
if chat_template_name is None:
openai_compatible_messages = []
for message in request.messages:
@@ -1036,6 +1041,22 @@ def v1_chat_generate_request(
sampling_params["structural_tag"] = convert_json_schema_to_str(
request.response_format.model_dump(by_alias=True)
)
if strict_tag is not None:
if (
sampling_params.get("regex")
or sampling_params.get("ebnf")
or sampling_params.get("structural_tag")
or sampling_params.get("json_schema")
):
logger.warning(
"Constrained decoding is not compatible with tool calls."
)
else:
sampling_params["structural_tag"] = convert_json_schema_to_str(
strict_tag.model_dump(by_alias=True)
)
sampling_params_list.append(sampling_params)
image_data_list.append(image_data)

View File

@@ -287,6 +287,7 @@ class Function(BaseModel):
description: Optional[str] = Field(default=None, examples=[None])
name: Optional[str] = None
parameters: Optional[object] = None
strict: bool = False
class Tool(BaseModel):