[Feature] Support new parameter - EBNF in xgrammar (#2526)
Commit acb340728c (parent 08effbff35), committed via GitHub.
@@ -366,6 +366,11 @@ class OpenAI(BaseBackend):
|
||||
def openai_completion(
|
||||
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
||||
):
|
||||
# if "ebnf" is in kwargs, warn and remove
|
||||
if "ebnf" in kwargs:
|
||||
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
|
||||
del kwargs["ebnf"]
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
if is_chat:
|
||||
@@ -398,6 +403,11 @@ def openai_completion(
|
||||
def openai_completion_stream(
|
||||
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
||||
):
|
||||
# if "ebnf" is in kwargs, warn and remove
|
||||
if "ebnf" in kwargs:
|
||||
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
|
||||
del kwargs["ebnf"]
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
if is_chat:
|
||||
|
||||
@@ -126,6 +126,12 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
||||
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
||||
)
|
||||
return None
|
||||
elif key_type == "ebnf":
|
||||
try:
|
||||
ctx = self.grammar_compiler.compile_grammar(key_string)
|
||||
except RuntimeError as e:
|
||||
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
|
||||
return None
|
||||
elif key_type == "regex":
|
||||
logger.warning(
|
||||
"regex hasn't been supported by xgrammar yet. This is skipped."
|
||||
|
||||
@@ -589,12 +589,15 @@ class Scheduler:
|
||||
if (
|
||||
req.sampling_params.json_schema is not None
|
||||
or req.sampling_params.regex is not None
|
||||
or req.sampling_params.ebnf is not None
|
||||
):
|
||||
assert self.grammar_backend is not None
|
||||
if req.sampling_params.json_schema is not None:
|
||||
key = ("json", req.sampling_params.json_schema)
|
||||
elif req.sampling_params.regex is not None:
|
||||
key = ("regex", req.sampling_params.regex)
|
||||
elif req.sampling_params.ebnf is not None:
|
||||
key = ("ebnf", req.sampling_params.ebnf)
|
||||
|
||||
req.grammar = self.grammar_backend.get_cached_value(key)
|
||||
if not req.grammar:
|
||||
|
||||
@@ -517,6 +517,7 @@ def v1_generate_request(
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"json_schema": request.json_schema,
|
||||
"ebnf": request.ebnf,
|
||||
"n": request.n,
|
||||
"no_stop_trim": request.no_stop_trim,
|
||||
"ignore_eos": request.ignore_eos,
|
||||
@@ -692,6 +693,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
|
||||
|
||||
async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
request_json = await raw_request.json()
|
||||
if "extra_body" in request_json:
|
||||
extra = request_json["extra_body"]
|
||||
if "ebnf" in extra:
|
||||
request_json["ebnf"] = extra["ebnf"]
|
||||
if "regex" in extra:
|
||||
request_json["regex"] = extra["regex"]
|
||||
# remove extra_body to avoid pydantic conflict
|
||||
del request_json["extra_body"]
|
||||
all_requests = [CompletionRequest(**request_json)]
|
||||
adapted_request, request = v1_generate_request(all_requests)
|
||||
|
||||
@@ -936,6 +945,7 @@ def v1_chat_generate_request(
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"ebnf": request.ebnf,
|
||||
"n": request.n,
|
||||
"no_stop_trim": request.no_stop_trim,
|
||||
"ignore_eos": request.ignore_eos,
|
||||
@@ -1108,6 +1118,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||
|
||||
async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
request_json = await raw_request.json()
|
||||
if "extra_body" in request_json:
|
||||
extra = request_json["extra_body"]
|
||||
# For example, if 'ebnf' is given:
|
||||
if "ebnf" in extra:
|
||||
request_json["ebnf"] = extra["ebnf"]
|
||||
if "regex" in extra:
|
||||
request_json["regex"] = extra["regex"]
|
||||
# remove extra_body to avoid pydantic conflict
|
||||
del request_json["extra_body"]
|
||||
all_requests = [ChatCompletionRequest(**request_json)]
|
||||
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
|
||||
|
||||
|
||||
@@ -179,6 +179,7 @@ class CompletionRequest(BaseModel):
|
||||
ignore_eos: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
ebnf: Optional[str] = None
|
||||
|
||||
|
||||
class CompletionResponseChoice(BaseModel):
|
||||
@@ -288,6 +289,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
ignore_eos: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
ebnf: Optional[str] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
||||
@@ -36,6 +36,7 @@ class SamplingParams:
|
||||
regex: Optional[str] = None,
|
||||
n: int = 1,
|
||||
json_schema: Optional[str] = None,
|
||||
ebnf: Optional[str] = None,
|
||||
no_stop_trim: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
skip_special_tokens: bool = True,
|
||||
@@ -60,6 +61,7 @@ class SamplingParams:
|
||||
self.regex = regex
|
||||
self.n = n
|
||||
self.json_schema = json_schema
|
||||
self.ebnf = ebnf
|
||||
self.no_stop_trim = no_stop_trim
|
||||
|
||||
# Process some special cases
|
||||
@@ -111,8 +113,13 @@ class SamplingParams:
|
||||
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
|
||||
f"{self.min_new_tokens}."
|
||||
)
|
||||
if self.regex is not None and self.json_schema is not None:
|
||||
raise ValueError("regex and json_schema cannot be both set.")
|
||||
grammars = [
|
||||
self.json_schema,
|
||||
self.regex,
|
||||
self.ebnf,
|
||||
] # since mutually exclusive, only one can be set
|
||||
if sum(x is not None for x in grammars) > 1:
|
||||
raise ValueError("Only one of regex, json_schema, or ebnf can be set.")
|
||||
|
||||
def normalize(self, tokenizer):
|
||||
# Process stop strings
|
||||
|
||||
test/srt/test_ebnf_constrained.py — new file, 247 lines added
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
|
||||
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def setup_class(cls, disable_overlap: bool):
    """Configure *cls* with the test model/URL and launch a server process.

    Args:
        cls: The TestCase class being configured (used like a classmethod).
        disable_overlap: When True, pass ``--disable-overlap-schedule`` so the
            scheduler runs without the overlap optimization.
    """
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST
    # Default grammar; individual tests overwrite this per case.
    cls.ebnf_grammar = 'root ::= "test"'

    server_args = [
        "--max-running-requests",
        "10",
        "--grammar-backend",
        "xgrammar",
    ]
    if disable_overlap:
        server_args.append("--disable-overlap-schedule")

    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=server_args,
    )
|
||||
|
||||
|
||||
class TestEBNFConstrained(unittest.TestCase):
    """End-to-end tests that /generate enforces user-supplied EBNF grammars."""

    @classmethod
    def setUpClass(cls):
        # Standard (overlap-enabled) server; the jump-forward variant subclasses
        # this and launches with the overlap schedule disabled.
        setup_class(cls, disable_overlap=False)
        cls.check_jump_forward = False

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(
        self,
        ebnf,
        expected_patterns,
        prompt,
        return_logprob=False,
        top_logprobs_num=0,
        n=1,
    ):
        """Send one grammar-constrained generation request and validate it.

        Every returned completion must be non-empty and match at least one
        regex in ``expected_patterns``.
        """
        payload = {
            "text": prompt,
            "sampling_params": {
                # Greedy when sampling once; mild randomness for n > 1 so the
                # samples can differ while still obeying the grammar.
                "temperature": 0 if n == 1 else 0.5,
                "max_new_tokens": 128,
                "n": n,
                "ebnf": ebnf,
            },
            "stream": False,
            "return_logprob": return_logprob,
            "top_logprobs_num": top_logprobs_num,
            "logprob_start_len": 0,
        }
        response = requests.post(self.base_url + "/generate", json=payload)

        ret = response.json()
        print(json.dumps(ret, indent=2))
        print("=" * 100)

        if not isinstance(ret, list):
            self.fail(f"Expected response to be a list, but got {type(ret)}")

        for item in ret:
            text = item.get("text", "").strip()
            if not text:
                self.fail("Generated text is empty.")
            if not any(self.regex_match(text, pat) for pat in expected_patterns):
                self.fail(f"Text '{text}' does not match any of the allowed patterns.")

    def regex_match(self, text, pattern):
        """Return True iff *pattern* matches at the start of *text*."""
        import re

        return re.match(pattern, text) is not None

    def test_ebnf_generate_email(self):
        type(self).ebnf_grammar = 'root ::= "user@example.com"'
        self.run_decode(
            ebnf=type(self).ebnf_grammar,
            expected_patterns=[r"^user@example\.com$"],
            prompt="Generate an email address:",
            n=3,
        )

    def test_ebnf_generate_greeting(self):
        type(self).ebnf_grammar = 'root ::= "Hello" | "Hi" | "Hey"'
        self.run_decode(
            ebnf=type(self).ebnf_grammar,
            expected_patterns=[r"^(Hello|Hi|Hey)$"],
            prompt="Generate a greeting:",
            n=3,
        )

    def test_ebnf_generate_number(self):
        type(self).ebnf_grammar = """
root ::= digit digit digit
digit ::= [0-9]
"""
        self.run_decode(
            ebnf=type(self).ebnf_grammar,
            expected_patterns=[r"^\d{3}$"],
            prompt="Generate a three-digit number:",
            n=3,
        )

    def test_ebnf_generate_phone(self):
        type(self).ebnf_grammar = """
root ::= "(" area ")" " " prefix "-" line
area ::= [0-9] [0-9] [0-9]
prefix ::= [0-9] [0-9] [0-9]
line ::= [0-9] [0-9] [0-9] [0-9]
"""
        self.run_decode(
            ebnf=type(self).ebnf_grammar,
            expected_patterns=[r"^\(\d{3}\) \d{3}-\d{4}$"],
            prompt="Generate a phone number:",
            n=3,
        )

    def test_ebnf_generate_date(self):
        type(self).ebnf_grammar = """
root ::= year "-" month "-" day
year ::= "2024"
month ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" | "11" | "12"
day ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" |
"11" | "12" | "13" | "14" | "15" | "16" | "17" | "18" | "19" | "20" |
"21" | "22" | "23" | "24" | "25" | "26" | "27" | "28" | "29" | "30" | "31"
"""
        self.run_decode(
            ebnf=type(self).ebnf_grammar,
            expected_patterns=[r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"],
            prompt="Generate a date in YYYY-MM-DD format:",
            n=3,
        )

    def test_ebnf_generate_hex_color(self):
        type(self).ebnf_grammar = """
root ::= "#" hex hex hex hex hex hex
hex ::= [0-9] | [A-F]
"""
        self.run_decode(
            ebnf=type(self).ebnf_grammar,
            expected_patterns=[r"^#[0-9A-F]{6}$"],
            prompt="Generate a hex color code:",
            n=3,
        )

    def test_ebnf_generate_complex_json(self):
        type(self).ebnf_grammar = """
root ::= object
object ::= "{" ws pair (ws "," ws pair)* ws "}"
pair ::= "\\"name\\"" ws ":" ws value |
"\\"age\\"" ws ":" ws number |
"\\"city\\"" ws ":" ws string
value ::= string | number
string ::= "\\"" [a-zA-Z0-9 ]+ "\\""
number ::= [1-9] [0-9]*
ws ::= [ ]*
"""
        self.run_decode(
            ebnf=type(self).ebnf_grammar,
            expected_patterns=[
                r'^{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*}$',
            ],
            prompt="Generate a simple JSON with name, age, and city:",
            n=3,
        )

    def test_ebnf_generate_custom_log_format(self):
        type(self).ebnf_grammar = """
root ::= logentry
logentry ::= "[" datetime "] " level ": System.process - " message
datetime ::= "2024-01-01T12:00:00Z"
level ::= "INFO"
message ::= "Operation " [a-z]+ " successfully"
"""
        self.run_decode(
            ebnf=type(self).ebnf_grammar,
            expected_patterns=[
                r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$"
            ],
            prompt="Generate a log entry:",
            n=3,
        )
|
||||
|
||||
|
||||
class TestJumpForward(TestEBNFConstrained):
    """Re-run the whole EBNF suite with the overlap schedule disabled,
    which exercises the jump-forward decoding path."""

    @classmethod
    def setUpClass(cls):
        cls.check_jump_forward = True
        setup_class(cls, disable_overlap=True)
|
||||
|
||||
|
||||
# Allow running this test file directly, e.g.:
#   python3 test_ebnf_constrained.py
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -5,6 +5,7 @@ python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import unittest
|
||||
|
||||
@@ -535,5 +536,92 @@ The SmartHome Mini is a compact smart home assistant available in black or white
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# EBNF Test Class: TestOpenAIServerEBNF
|
||||
# Launches the server with xgrammar, has only EBNF tests
|
||||
# -------------------------------------------------------------------------
|
||||
class TestOpenAIServerEBNF(unittest.TestCase):
    """EBNF-only tests against the OpenAI-compatible server.

    The server is launched with the xgrammar backend, which is required for
    EBNF support; `ebnf` is passed through the client's `extra_body`.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # xgrammar specifically: EBNF is not supported by other backends.
        launch_args = ["--grammar-backend", "xgrammar"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=launch_args,
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_ebnf(self):
        """Pass `ebnf` through `extra_body` and check the grammar is enforced."""
        api_client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        greeting_pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$")
        ebnf_grammar = r"""
root ::= "Hello" | "Hi" | "Hey"
"""

        response = api_client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful EBNF test bot."},
                {"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."},
            ],
            temperature=0,
            max_tokens=32,
            extra_body={"ebnf": ebnf_grammar},
        )
        text = response.choices[0].message.content.strip()
        print("EBNF test output:", repr(text))
        self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
        self.assertRegex(
            text, greeting_pattern, f"Text '{text}' doesn't match EBNF choices"
        )

    def test_ebnf_strict_json(self):
        """A stricter grammar producing exactly the {"name":"..."} shape,
        with no trailing punctuation or extra fields."""
        api_client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        json_pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$')
        ebnf_grammar = r"""
root ::= "{" pair "}"
pair ::= "\"name\"" ":" string
string ::= "\"" [A-Za-z]+ "\""
"""

        response = api_client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "EBNF mini-JSON generator."},
                {
                    "role": "user",
                    "content": "Generate single key JSON with only letters.",
                },
            ],
            temperature=0,
            max_tokens=64,
            extra_body={"ebnf": ebnf_grammar},
        )
        text = response.choices[0].message.content.strip()
        print("EBNF strict JSON test output:", repr(text))
        self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
        self.assertRegex(
            text, json_pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
        )
|
||||
|
||||
|
||||
# Allow running this test file directly, e.g.:
#   python3 test_openai_server.py
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user