[Feature] Support new parameter - EBNF in xgrammar (#2526)

This commit is contained in:
Adarsh Shirawalmath
2024-12-26 18:42:41 +05:30
committed by GitHub
parent 08effbff35
commit acb340728c
8 changed files with 384 additions and 2 deletions

View File

@@ -0,0 +1,247 @@
"""
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
"""
import json
import unittest
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
def setup_class(cls, disable_overlap: bool):
    """Start the test server with the xgrammar backend and record it on ``cls``.

    Sets ``cls.model``, ``cls.base_url``, ``cls.ebnf_grammar`` (a placeholder
    grammar) and ``cls.process`` (the handle returned by
    ``popen_launch_server``).

    Args:
        cls: The test class that receives the attributes above.
        disable_overlap: When true, ``--disable-overlap-schedule`` is appended
            to the server launch arguments.
    """
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST
    cls.ebnf_grammar = 'root ::= "test"'  # Default grammar

    # Optional flag goes last so the argument order matches an append.
    overlap_flags = ["--disable-overlap-schedule"] if disable_overlap else []
    launch_args = [
        "--max-running-requests",
        "10",
        "--grammar-backend",
        "xgrammar",
        *overlap_flags,
    ]

    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=launch_args,
    )
class TestEBNFConstrained(unittest.TestCase):
    """End-to-end tests for the ``ebnf`` sampling parameter of ``/generate``.

    A server is launched once per class with the xgrammar grammar backend.
    Each test posts a prompt together with an EBNF grammar and asserts that
    every returned sample matches a regular expression describing that grammar.
    """

    @classmethod
    def setUpClass(cls):
        setup_class(cls, disable_overlap=False)
        cls.check_jump_forward = False

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(
        self,
        ebnf,
        expected_patterns,
        prompt,
        return_logprob=False,
        top_logprobs_num=0,
        n=1,
    ):
        """Generate ``n`` samples under ``ebnf`` and assert each one is valid.

        Args:
            ebnf: EBNF grammar string sent as the ``ebnf`` sampling parameter.
            expected_patterns: Regular expressions; every sample must match at
                least one of them (matched with ``re.match``).
            prompt: Prompt text sent to the server.
            return_logprob: Forwarded as ``return_logprob``.
            top_logprobs_num: Forwarded as ``top_logprobs_num``.
            n: Number of samples. Temperature is 0 for a single sample and
                0.5 otherwise.

        Fails the test when the response has an unexpected shape, a sample is
        empty after stripping, or a sample matches none of the patterns.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "ebnf": ebnf,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        ret = response.json()
        print(json.dumps(ret, indent=2))
        print("=" * 100)

        # A single-sample request comes back as one result object instead of
        # a list; normalize it so the default ``n=1`` path is usable. Dicts
        # without a "text" key (e.g. error payloads) still fail below.
        if isinstance(ret, dict) and "text" in ret:
            ret = [ret]

        if not isinstance(ret, list):
            self.fail(f"Expected response to be a list, but got {type(ret)}")

        for item in ret:
            text = item.get("text", "").strip()
            if not text:
                self.fail("Generated text is empty.")
            if not any(
                self.regex_match(text, pattern) for pattern in expected_patterns
            ):
                self.fail(f"Text '{text}' does not match any of the allowed patterns.")

    def regex_match(self, text, pattern):
        """Return True when ``pattern`` matches at the start of ``text``."""
        import re

        return re.match(pattern, text) is not None

    def test_ebnf_generate_email(self):
        """A single-literal grammar must yield exactly that literal."""
        self.__class__.ebnf_grammar = 'root ::= "user@example.com"'
        allowed_patterns = [r"^user@example\.com$"]
        prompt = "Generate an email address:"
        self.run_decode(
            ebnf=self.__class__.ebnf_grammar,
            expected_patterns=allowed_patterns,
            prompt=prompt,
            n=3,
        )

    def test_ebnf_generate_greeting(self):
        """An alternation grammar must yield one of its alternatives."""
        self.__class__.ebnf_grammar = 'root ::= "Hello" | "Hi" | "Hey"'
        allowed_patterns = [r"^(Hello|Hi|Hey)$"]
        prompt = "Generate a greeting:"
        self.run_decode(
            ebnf=self.__class__.ebnf_grammar,
            expected_patterns=allowed_patterns,
            prompt=prompt,
            n=3,
        )

    def test_ebnf_generate_number(self):
        """A grammar of three digit rules must yield exactly three digits."""
        self.__class__.ebnf_grammar = """
root ::= digit digit digit
digit ::= [0-9]
"""
        allowed_patterns = [r"^\d{3}$"]
        prompt = "Generate a three-digit number:"
        self.run_decode(
            ebnf=self.__class__.ebnf_grammar,
            expected_patterns=allowed_patterns,
            prompt=prompt,
            n=3,
        )

    def test_ebnf_generate_phone(self):
        """Literals mixed with character classes: ``(ddd) ddd-dddd``."""
        self.__class__.ebnf_grammar = """
root ::= "(" area ")" " " prefix "-" line
area ::= [0-9] [0-9] [0-9]
prefix ::= [0-9] [0-9] [0-9]
line ::= [0-9] [0-9] [0-9] [0-9]
"""
        allowed_patterns = [r"^\(\d{3}\) \d{3}-\d{4}$"]
        prompt = "Generate a phone number:"
        self.run_decode(
            ebnf=self.__class__.ebnf_grammar,
            expected_patterns=allowed_patterns,
            prompt=prompt,
            n=3,
        )

    def test_ebnf_generate_date(self):
        """Enumerated month/day alternatives: ``2024-MM-DD``."""
        self.__class__.ebnf_grammar = """
root ::= year "-" month "-" day
year ::= "2024"
month ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" | "11" | "12"
day ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" |
"11" | "12" | "13" | "14" | "15" | "16" | "17" | "18" | "19" | "20" |
"21" | "22" | "23" | "24" | "25" | "26" | "27" | "28" | "29" | "30" | "31"
"""
        allowed_patterns = [r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"]
        prompt = "Generate a date in YYYY-MM-DD format:"
        self.run_decode(
            ebnf=self.__class__.ebnf_grammar,
            expected_patterns=allowed_patterns,
            prompt=prompt,
            n=3,
        )

    def test_ebnf_generate_hex_color(self):
        """A repeated rule reference: ``#`` followed by six uppercase hex digits."""
        self.__class__.ebnf_grammar = """
root ::= "#" hex hex hex hex hex hex
hex ::= [0-9] | [A-F]
"""
        allowed_patterns = [r"^#[0-9A-F]{6}$"]
        prompt = "Generate a hex color code:"
        self.run_decode(
            ebnf=self.__class__.ebnf_grammar,
            expected_patterns=allowed_patterns,
            prompt=prompt,
            n=3,
        )

    def test_ebnf_generate_complex_json(self):
        """Recursive-style grammar with repetition producing a small JSON object."""
        # NOTE(review): the grammar accepts the name/age/city pairs in any
        # order and any count (and a numeric "name"), while the pattern below
        # requires exactly name, age, city in that order with string name.
        # The test therefore also depends on the model following the prompt.
        self.__class__.ebnf_grammar = """
root ::= object
object ::= "{" ws pair (ws "," ws pair)* ws "}"
pair ::= "\\"name\\"" ws ":" ws value |
"\\"age\\"" ws ":" ws number |
"\\"city\\"" ws ":" ws string
value ::= string | number
string ::= "\\"" [a-zA-Z0-9 ]+ "\\""
number ::= [1-9] [0-9]*
ws ::= [ ]*
"""
        allowed_patterns = [
            r'^{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*}$',
        ]
        prompt = "Generate a simple JSON with name, age, and city:"
        self.run_decode(
            ebnf=self.__class__.ebnf_grammar,
            expected_patterns=allowed_patterns,
            prompt=prompt,
            n=3,
        )

    def test_ebnf_generate_custom_log_format(self):
        """Fixed log-line skeleton with one free lowercase word."""
        self.__class__.ebnf_grammar = """
root ::= logentry
logentry ::= "[" datetime "] " level ": System.process - " message
datetime ::= "2024-01-01T12:00:00Z"
level ::= "INFO"
message ::= "Operation " [a-z]+ " successfully"
"""
        allowed_patterns = [
            r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$"
        ]
        prompt = "Generate a log entry:"
        self.run_decode(
            ebnf=self.__class__.ebnf_grammar,
            expected_patterns=allowed_patterns,
            prompt=prompt,
            n=3,
        )
class TestJumpForward(TestEBNFConstrained):
    """Re-run the inherited EBNF tests against a differently configured server.

    The only differences from the parent are that the server is launched with
    ``disable_overlap=True`` and ``check_jump_forward`` is set to ``True``.

    NOTE(review): ``check_jump_forward`` is not read anywhere in the visible
    code — confirm whether a jump-forward assertion is intended.
    """

    @classmethod
    def setUpClass(cls):
        # Same shared setup as the parent class, with overlap scheduling disabled.
        setup_class(cls, disable_overlap=True)
        cls.check_jump_forward = True
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

View File

@@ -5,6 +5,7 @@ python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
"""
import json
import re
import time
import unittest
@@ -535,5 +536,92 @@ The SmartHome Mini is a compact smart home assistant available in black or white
)
# -------------------------------------------------------------------------
# EBNF Test Class: TestOpenAIServerEBNF
# Launches the server with xgrammar, has only EBNF tests
# -------------------------------------------------------------------------
class TestOpenAIServerEBNF(unittest.TestCase):
    """EBNF-constrained chat completions through the OpenAI-compatible API.

    The server is launched once for the class with ``--grammar-backend
    xgrammar``; each test sends a grammar via ``extra_body={"ebnf": ...}``
    and checks the reply against a regular expression.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            # passing xgrammar specifically
            other_args=["--grammar-backend", "xgrammar"],
        )
        # The OpenAI client talks to the /v1 prefix of the launched server.
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def _chat_with_ebnf(self, messages, ebnf_grammar, max_tokens):
        """Run one greedy chat completion under ``ebnf_grammar``; return stripped text."""
        api = openai.Client(api_key=self.api_key, base_url=self.base_url)
        completion = api.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=0,
            max_tokens=max_tokens,
            extra_body={"ebnf": ebnf_grammar},
        )
        return completion.choices[0].message.content.strip()

    def test_ebnf(self):
        """
        Ensure we can pass `ebnf` to the local openai server
        and that it enforces the grammar.
        """
        ebnf_grammar = r"""
root ::= "Hello" | "Hi" | "Hey"
"""
        pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$")
        conversation = [
            {"role": "system", "content": "You are a helpful EBNF test bot."},
            {"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."},
        ]

        text = self._chat_with_ebnf(conversation, ebnf_grammar, max_tokens=32)

        print("EBNF test output:", repr(text))
        self.assertTrue(len(text) > 0, "Got empty text from EBNF generation")
        self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")

    def test_ebnf_strict_json(self):
        """
        A stricter EBNF that produces exactly {"name":"Alice"} format
        with no trailing punctuation or extra fields.
        """
        ebnf_grammar = r"""
root ::= "{" pair "}"
pair ::= "\"name\"" ":" string
string ::= "\"" [A-Za-z]+ "\""
"""
        pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$')
        conversation = [
            {"role": "system", "content": "EBNF mini-JSON generator."},
            {
                "role": "user",
                "content": "Generate single key JSON with only letters.",
            },
        ]

        text = self._chat_with_ebnf(conversation, ebnf_grammar, max_tokens=64)

        print("EBNF strict JSON test output:", repr(text))
        self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test")
        self.assertRegex(
            text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
        )
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()