[Feature] Support new parameter - EBNF in xgrammar (#2526)
This commit is contained in:
committed by
GitHub
parent
08effbff35
commit
acb340728c
@@ -366,6 +366,11 @@ class OpenAI(BaseBackend):
|
||||
def openai_completion(
|
||||
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
||||
):
|
||||
# if "ebnf" is in kwargs, warn and remove
|
||||
if "ebnf" in kwargs:
|
||||
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
|
||||
del kwargs["ebnf"]
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
if is_chat:
|
||||
@@ -398,6 +403,11 @@ def openai_completion(
|
||||
def openai_completion_stream(
|
||||
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
||||
):
|
||||
# if "ebnf" is in kwargs, warn and remove
|
||||
if "ebnf" in kwargs:
|
||||
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
|
||||
del kwargs["ebnf"]
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
if is_chat:
|
||||
|
||||
@@ -126,6 +126,12 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
||||
)
|
||||
return None
|
||||
elif key_type == "ebnf":
|
||||
try:
|
||||
ctx = self.grammar_compiler.compile_grammar(key_string)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
|
||||
return None
|
||||
elif key_type == "regex":
|
||||
logger.warning(
|
||||
"regex hasn't been supported by xgrammar yet. This is skipped."
|
||||
|
||||
@@ -589,12 +589,15 @@ class Scheduler:
|
||||
if (
|
||||
req.sampling_params.json_schema is not None
|
||||
or req.sampling_params.regex is not None
|
||||
or req.sampling_params.ebnf is not None
|
||||
):
|
||||
assert self.grammar_backend is not None
|
||||
if req.sampling_params.json_schema is not None:
|
||||
key = ("json", req.sampling_params.json_schema)
|
||||
elif req.sampling_params.regex is not None:
|
||||
key = ("regex", req.sampling_params.regex)
|
||||
elif req.sampling_params.ebnf is not None:
|
||||
key = ("ebnf", req.sampling_params.ebnf)
|
||||
|
||||
req.grammar = self.grammar_backend.get_cached_value(key)
|
||||
if not req.grammar:
|
||||
|
||||
@@ -517,6 +517,7 @@ def v1_generate_request(
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"json_schema": request.json_schema,
|
||||
"ebnf": request.ebnf,
|
||||
"n": request.n,
|
||||
"no_stop_trim": request.no_stop_trim,
|
||||
"ignore_eos": request.ignore_eos,
|
||||
@@ -692,6 +693,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
|
||||
|
||||
async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
request_json = await raw_request.json()
|
||||
if "extra_body" in request_json:
|
||||
extra = request_json["extra_body"]
|
||||
if "ebnf" in extra:
|
||||
request_json["ebnf"] = extra["ebnf"]
|
||||
if "regex" in extra:
|
||||
request_json["regex"] = extra["regex"]
|
||||
# remove extra_body to avoid pydantic conflict
|
||||
del request_json["extra_body"]
|
||||
all_requests = [CompletionRequest(**request_json)]
|
||||
adapted_request, request = v1_generate_request(all_requests)
|
||||
|
||||
@@ -936,6 +945,7 @@ def v1_chat_generate_request(
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"ebnf": request.ebnf,
|
||||
"n": request.n,
|
||||
"no_stop_trim": request.no_stop_trim,
|
||||
"ignore_eos": request.ignore_eos,
|
||||
@@ -1108,6 +1118,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||
|
||||
async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
request_json = await raw_request.json()
|
||||
if "extra_body" in request_json:
|
||||
extra = request_json["extra_body"]
|
||||
# For example, if 'ebnf' is given:
|
||||
if "ebnf" in extra:
|
||||
request_json["ebnf"] = extra["ebnf"]
|
||||
if "regex" in extra:
|
||||
request_json["regex"] = extra["regex"]
|
||||
# remove extra_body to avoid pydantic conflict
|
||||
del request_json["extra_body"]
|
||||
all_requests = [ChatCompletionRequest(**request_json)]
|
||||
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
|
||||
|
||||
|
||||
@@ -179,6 +179,7 @@ class CompletionRequest(BaseModel):
|
||||
ignore_eos: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
ebnf: Optional[str] = None
|
||||
|
||||
|
||||
class CompletionResponseChoice(BaseModel):
|
||||
@@ -288,6 +289,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
ignore_eos: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
ebnf: Optional[str] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
||||
@@ -36,6 +36,7 @@ class SamplingParams:
|
||||
regex: Optional[str] = None,
|
||||
n: int = 1,
|
||||
json_schema: Optional[str] = None,
|
||||
ebnf: Optional[str] = None,
|
||||
no_stop_trim: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
@@ -60,6 +61,7 @@ class SamplingParams:
|
||||
self.regex = regex
|
||||
self.n = n
|
||||
self.json_schema = json_schema
|
||||
self.ebnf = ebnf
|
||||
self.no_stop_trim = no_stop_trim
|
||||
|
||||
# Process some special cases
|
||||
@@ -111,8 +113,13 @@ class SamplingParams:
|
||||
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
|
||||
f"{self.min_new_tokens}."
|
||||
)
|
||||
if self.regex is not None and self.json_schema is not None:
|
||||
raise ValueError("regex and json_schema cannot be both set.")
|
||||
grammars = [
|
||||
self.json_schema,
|
||||
self.regex,
|
||||
self.ebnf,
|
||||
] # since mutually exclusive, only one can be set
|
||||
if sum(x is not None for x in grammars) > 1:
|
||||
raise ValueError("Only one of regex, json_schema, or ebnf can be set.")
|
||||
|
||||
def normalize(self, tokenizer):
|
||||
# Process stop strings
|
||||
|
||||
Reference in New Issue
Block a user