[Feature] Support new parameter - EBNF in xgrammar (#2526)

This commit is contained in:
Adarsh Shirawalmath
2024-12-26 18:42:41 +05:30
committed by GitHub
parent 08effbff35
commit acb340728c
8 changed files with 384 additions and 2 deletions

View File

@@ -366,6 +366,11 @@ class OpenAI(BaseBackend):
def openai_completion(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
# if "ebnf" is in kwargs, warn and remove
if "ebnf" in kwargs:
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
del kwargs["ebnf"]
for attempt in range(retries):
try:
if is_chat:
@@ -398,6 +403,11 @@ def openai_completion(
def openai_completion_stream(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
# if "ebnf" is in kwargs, warn and remove
if "ebnf" in kwargs:
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
del kwargs["ebnf"]
for attempt in range(retries):
try:
if is_chat:

View File

@@ -126,6 +126,12 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
)
return None
elif key_type == "ebnf":
try:
ctx = self.grammar_compiler.compile_grammar(key_string)
except RuntimeError as e:
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
return None
elif key_type == "regex":
logger.warning(
"regex hasn't been supported by xgrammar yet. This is skipped."

View File

@@ -589,12 +589,15 @@ class Scheduler:
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
or req.sampling_params.ebnf is not None
):
assert self.grammar_backend is not None
if req.sampling_params.json_schema is not None:
key = ("json", req.sampling_params.json_schema)
elif req.sampling_params.regex is not None:
key = ("regex", req.sampling_params.regex)
elif req.sampling_params.ebnf is not None:
key = ("ebnf", req.sampling_params.ebnf)
req.grammar = self.grammar_backend.get_cached_value(key)
if not req.grammar:

View File

@@ -517,6 +517,7 @@ def v1_generate_request(
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"json_schema": request.json_schema,
"ebnf": request.ebnf,
"n": request.n,
"no_stop_trim": request.no_stop_trim,
"ignore_eos": request.ignore_eos,
@@ -692,6 +693,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
async def v1_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
if "extra_body" in request_json:
extra = request_json["extra_body"]
if "ebnf" in extra:
request_json["ebnf"] = extra["ebnf"]
if "regex" in extra:
request_json["regex"] = extra["regex"]
# remove extra_body to avoid pydantic conflict
del request_json["extra_body"]
all_requests = [CompletionRequest(**request_json)]
adapted_request, request = v1_generate_request(all_requests)
@@ -936,6 +945,7 @@ def v1_chat_generate_request(
"frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty,
"regex": request.regex,
"ebnf": request.ebnf,
"n": request.n,
"no_stop_trim": request.no_stop_trim,
"ignore_eos": request.ignore_eos,
@@ -1108,6 +1118,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
async def v1_chat_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()
if "extra_body" in request_json:
extra = request_json["extra_body"]
# For example, if 'ebnf' is given:
if "ebnf" in extra:
request_json["ebnf"] = extra["ebnf"]
if "regex" in extra:
request_json["regex"] = extra["regex"]
# remove extra_body to avoid pydantic conflict
del request_json["extra_body"]
all_requests = [ChatCompletionRequest(**request_json)]
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

View File

@@ -179,6 +179,7 @@ class CompletionRequest(BaseModel):
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
ebnf: Optional[str] = None
class CompletionResponseChoice(BaseModel):
@@ -288,6 +289,7 @@ class ChatCompletionRequest(BaseModel):
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
ebnf: Optional[str] = None
class ChatMessage(BaseModel):

View File

@@ -36,6 +36,7 @@ class SamplingParams:
regex: Optional[str] = None,
n: int = 1,
json_schema: Optional[str] = None,
ebnf: Optional[str] = None,
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
@@ -60,6 +61,7 @@ class SamplingParams:
self.regex = regex
self.n = n
self.json_schema = json_schema
self.ebnf = ebnf
self.no_stop_trim = no_stop_trim
# Process some special cases
@@ -111,8 +113,13 @@ class SamplingParams:
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
f"{self.min_new_tokens}."
)
if self.regex is not None and self.json_schema is not None:
raise ValueError("regex and json_schema cannot be both set.")
grammars = [
self.json_schema,
self.regex,
self.ebnf,
] # since mutually exclusive, only one can be set
if sum(x is not None for x in grammars) > 1:
raise ValueError("Only one of regex, json_schema, or ebnf can be set.")
def normalize(self, tokenizer):
# Process stop strings