[Feature] integrate Structural Tag in xgrammar backend for function calling (#3566)
Co-authored-by: shuaills <shishuaiuoe@gmail.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# ==============================================================================
|
||||
"""Constrained decoding with xgrammar backend."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
@@ -21,6 +22,7 @@ from xgrammar import (
|
||||
CompiledGrammar,
|
||||
GrammarCompiler,
|
||||
GrammarMatcher,
|
||||
StructuralTagItem,
|
||||
TokenizerInfo,
|
||||
allocate_token_bitmask,
|
||||
apply_token_bitmask_inplace,
|
||||
@@ -138,6 +140,23 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
||||
return None
|
||||
elif key_type == "structural_tag":
|
||||
try:
|
||||
structural_tag = json.loads(key_string)
|
||||
tags = [
|
||||
StructuralTagItem(
|
||||
begin=structure["begin"],
|
||||
schema=json.dumps(structure["schema"]),
|
||||
end=structure["end"],
|
||||
)
|
||||
for structure in structural_tag["structures"]
|
||||
]
|
||||
ctx = self.grammar_compiler.compile_structural_tag(
|
||||
tags, structural_tag["triggers"]
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Invalid key_type: {key_type}")
|
||||
|
||||
|
||||
@@ -710,6 +710,7 @@ class Scheduler:
|
||||
req.sampling_params.json_schema is not None
|
||||
or req.sampling_params.regex is not None
|
||||
or req.sampling_params.ebnf is not None
|
||||
or req.sampling_params.structural_tag is not None
|
||||
):
|
||||
assert self.grammar_backend is not None
|
||||
if req.sampling_params.json_schema is not None:
|
||||
@@ -718,6 +719,8 @@ class Scheduler:
|
||||
key = ("regex", req.sampling_params.regex)
|
||||
elif req.sampling_params.ebnf is not None:
|
||||
key = ("ebnf", req.sampling_params.ebnf)
|
||||
elif req.sampling_params.structural_tag:
|
||||
key = ("structural_tag", req.sampling_params.structural_tag)
|
||||
|
||||
req.grammar = self.grammar_backend.get_cached_value(key)
|
||||
if not req.grammar:
|
||||
|
||||
@@ -994,10 +994,17 @@ def v1_chat_generate_request(
|
||||
"ignore_eos": request.ignore_eos,
|
||||
"skip_special_tokens": request.skip_special_tokens,
|
||||
}
|
||||
|
||||
if request.response_format and request.response_format.type == "json_schema":
|
||||
sampling_params["json_schema"] = convert_json_schema_to_str(
|
||||
request.response_format.json_schema.schema_
|
||||
)
|
||||
elif (
|
||||
request.response_format and request.response_format.type == "structural_tag"
|
||||
):
|
||||
sampling_params["structural_tag"] = convert_json_schema_to_str(
|
||||
request.response_format.model_dump(by_alias=True)
|
||||
)
|
||||
sampling_params_list.append(sampling_params)
|
||||
|
||||
image_data_list.append(image_data)
|
||||
|
||||
@@ -258,6 +258,18 @@ class ResponseFormat(BaseModel):
|
||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||
|
||||
|
||||
class StructuresResponseFormat(BaseModel):
|
||||
begin: str
|
||||
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
|
||||
end: str
|
||||
|
||||
|
||||
class StructuralTagResponseFormat(BaseModel):
|
||||
type: Literal["structural_tag"]
|
||||
structures: List[StructuresResponseFormat]
|
||||
triggers: List[str]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""Function descriptions."""
|
||||
|
||||
@@ -298,7 +310,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
max_tokens: Optional[int] = None
|
||||
n: int = 1
|
||||
presence_penalty: float = 0.0
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
response_format: Union[ResponseFormat, StructuralTagResponseFormat] = None
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stream: bool = False
|
||||
|
||||
@@ -45,6 +45,7 @@ class SamplingParams:
|
||||
json_schema: Optional[str] = None,
|
||||
regex: Optional[str] = None,
|
||||
ebnf: Optional[str] = None,
|
||||
structural_tag: Optional[str] = None,
|
||||
no_stop_trim: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
@@ -72,6 +73,7 @@ class SamplingParams:
|
||||
self.n = n
|
||||
self.json_schema = json_schema
|
||||
self.ebnf = ebnf
|
||||
self.structural_tag = structural_tag
|
||||
self.no_stop_trim = no_stop_trim
|
||||
self.return_hidden_states = return_hidden_states
|
||||
self.custom_params = custom_params
|
||||
|
||||
Reference in New Issue
Block a user