From acb340728c169a9338e16783ff65510ab21179be Mon Sep 17 00:00:00 2001 From: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com> Date: Thu, 26 Dec 2024 18:42:41 +0530 Subject: [PATCH] [Feature] Support new parameter - EBNF in xgrammar (#2526) --- python/sglang/lang/backend/openai.py | 10 + .../srt/constrained/xgrammar_backend.py | 6 + python/sglang/srt/managers/scheduler.py | 3 + python/sglang/srt/openai_api/adapter.py | 19 ++ python/sglang/srt/openai_api/protocol.py | 2 + python/sglang/srt/sampling/sampling_params.py | 11 +- test/srt/test_ebnf_constrained.py | 247 ++++++++++++++++++ test/srt/test_openai_server.py | 88 +++++++ 8 files changed, 384 insertions(+), 2 deletions(-) create mode 100644 test/srt/test_ebnf_constrained.py diff --git a/python/sglang/lang/backend/openai.py b/python/sglang/lang/backend/openai.py index 6fa93d9b2..4f37da79b 100644 --- a/python/sglang/lang/backend/openai.py +++ b/python/sglang/lang/backend/openai.py @@ -366,6 +366,11 @@ class OpenAI(BaseBackend): def openai_completion( client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs ): + # if "ebnf" is in kwargs, warn and remove + if "ebnf" in kwargs: + warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.") + del kwargs["ebnf"] + for attempt in range(retries): try: if is_chat: @@ -398,6 +403,11 @@ def openai_completion( def openai_completion_stream( client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs ): + # if "ebnf" is in kwargs, warn and remove + if "ebnf" in kwargs: + warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.") + del kwargs["ebnf"] + for attempt in range(retries): try: if is_chat: diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index 91cd17c6f..b0b2c31c2 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -126,6 +126,12 @@ class XGrammarGrammarBackend(BaseGrammarBackend): f"Skip invalid json_schema: json_schema={key_string}, {e=}" ) return None + elif key_type == "ebnf": + try: + ctx = self.grammar_compiler.compile_grammar(key_string) + except RuntimeError as e: + logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}") + return None elif key_type == "regex": logger.warning( "regex hasn't been supported by xgrammar yet. This is skipped." diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 6b4d0276c..9be35b0ee 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -589,12 +589,15 @@ class Scheduler: if ( req.sampling_params.json_schema is not None or req.sampling_params.regex is not None + or req.sampling_params.ebnf is not None ): assert self.grammar_backend is not None if req.sampling_params.json_schema is not None: key = ("json", req.sampling_params.json_schema) elif req.sampling_params.regex is not None: key = ("regex", req.sampling_params.regex) + elif req.sampling_params.ebnf is not None: + key = ("ebnf", req.sampling_params.ebnf) req.grammar = self.grammar_backend.get_cached_value(key) if not req.grammar: diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index d8fd731c4..46094c556 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -517,6 +517,7 @@ def v1_generate_request( "repetition_penalty": request.repetition_penalty, "regex": request.regex, "json_schema": request.json_schema, + "ebnf": request.ebnf, "n": request.n, "no_stop_trim": request.no_stop_trim, "ignore_eos": request.ignore_eos, @@ -692,6 +693,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False): async def v1_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() + if "extra_body" in request_json: + extra = request_json["extra_body"] + if "ebnf" in extra: + request_json["ebnf"] = extra["ebnf"] + if "regex" in extra: + request_json["regex"] = extra["regex"] + # remove extra_body to avoid pydantic conflict + del request_json["extra_body"] all_requests = [CompletionRequest(**request_json)] adapted_request, request = v1_generate_request(all_requests) @@ -936,6 +945,7 @@ def v1_chat_generate_request( "frequency_penalty": request.frequency_penalty, "repetition_penalty": request.repetition_penalty, "regex": request.regex, + "ebnf": request.ebnf, "n": request.n, "no_stop_trim": request.no_stop_trim, "ignore_eos": request.ignore_eos, @@ -1108,6 +1118,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): async def v1_chat_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() + if "extra_body" in request_json: + extra = request_json["extra_body"] + # For example, if 'ebnf' is given: + if "ebnf" in extra: + request_json["ebnf"] = extra["ebnf"] + if "regex" in extra: + request_json["regex"] = extra["regex"] + # remove extra_body to avoid pydantic conflict + del request_json["extra_body"] all_requests = [ChatCompletionRequest(**request_json)] adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 9fe3f25d5..4f7833a23 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -179,6 +179,7 @@ class CompletionRequest(BaseModel): ignore_eos: bool = False skip_special_tokens: bool = True lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + ebnf: Optional[str] = None class CompletionResponseChoice(BaseModel): @@ -288,6 +289,7 @@ class ChatCompletionRequest(BaseModel): ignore_eos: bool = False skip_special_tokens: bool = True lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + ebnf: Optional[str] = None class ChatMessage(BaseModel): diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index 64d5e0783..55a2c910d 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -36,6 +36,7 @@ class SamplingParams: regex: Optional[str] = None, n: int = 1, json_schema: Optional[str] = None, + ebnf: Optional[str] = None, no_stop_trim: bool = False, ignore_eos: bool = False, skip_special_tokens: bool = True, @@ -60,6 +61,7 @@ class SamplingParams: self.regex = regex self.n = n self.json_schema = json_schema + self.ebnf = ebnf self.no_stop_trim = no_stop_trim # Process some special cases @@ -111,8 +113,13 @@ class SamplingParams: f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got " f"{self.min_new_tokens}." ) - if self.regex is not None and self.json_schema is not None: - raise ValueError("regex and json_schema cannot be both set.") + grammars = [ + self.json_schema, + self.regex, + self.ebnf, + ] # since mutually exclusive, only one can be set + if sum(x is not None for x in grammars) > 1: + raise ValueError("Only one of regex, json_schema, or ebnf can be set.") def normalize(self, tokenizer): # Process stop strings diff --git a/test/srt/test_ebnf_constrained.py b/test/srt/test_ebnf_constrained.py new file mode 100644 index 000000000..97b6f7561 --- /dev/null +++ b/test/srt/test_ebnf_constrained.py @@ -0,0 +1,247 @@ +""" +python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email +python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting +""" + +import json +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +def setup_class(cls, disable_overlap: bool): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.ebnf_grammar = 'root ::= "test"' # Default grammar + + other_args = [ + "--max-running-requests", + "10", + "--grammar-backend", + "xgrammar", + ] + + if disable_overlap: + other_args += ["--disable-overlap-schedule"] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + +class TestEBNFConstrained(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_class(cls, disable_overlap=False) + cls.check_jump_forward = False + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_decode( + self, + ebnf, + expected_patterns, + prompt, + return_logprob=False, + top_logprobs_num=0, + n=1, + ): + response = requests.post( + self.base_url + "/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": 128, + "n": n, + "ebnf": ebnf, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": 0, + }, + ) + + ret = response.json() + print(json.dumps(ret, indent=2)) + print("=" * 100) + + if not isinstance(ret, list): + self.fail(f"Expected response to be a list, but got {type(ret)}") + + for item in ret: + text = item.get("text", "").strip() + if not text: + self.fail("Generated text is empty.") + + match = False + for pattern in expected_patterns: + if self.regex_match(text, pattern): + match = True + break + if not match: + self.fail(f"Text '{text}' does not match any of the allowed patterns.") + + def regex_match(self, text, pattern): + import re + + return re.match(pattern, text) is not None + + def test_ebnf_generate_email(self): + self.__class__.ebnf_grammar = 'root ::= "user@example.com"' + allowed_patterns = [r"^user@example\.com$"] + prompt = "Generate an email address:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_greeting(self): + self.__class__.ebnf_grammar = 'root ::= "Hello" | "Hi" | "Hey"' + allowed_patterns = [r"^(Hello|Hi|Hey)$"] + prompt = "Generate a greeting:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_number(self): + self.__class__.ebnf_grammar = """ + root ::= digit digit digit + digit ::= [0-9] + """ + allowed_patterns = [r"^\d{3}$"] + prompt = "Generate a three-digit number:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_phone(self): + self.__class__.ebnf_grammar = """ + root ::= "(" area ")" " " prefix "-" line + area ::= [0-9] [0-9] [0-9] + prefix ::= [0-9] [0-9] [0-9] + line ::= [0-9] [0-9] [0-9] [0-9] + """ + allowed_patterns = [r"^\(\d{3}\) \d{3}-\d{4}$"] + prompt = "Generate a phone number:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_date(self): + self.__class__.ebnf_grammar = """ + root ::= year "-" month "-" day + year ::= "2024" + month ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" | "11" | "12" + day ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" | + "11" | "12" | "13" | "14" | "15" | "16" | "17" | "18" | "19" | "20" | + "21" | "22" | "23" | "24" | "25" | "26" | "27" | "28" | "29" | "30" | "31" + """ + allowed_patterns = [r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"] + prompt = "Generate a date in YYYY-MM-DD format:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_hex_color(self): + self.__class__.ebnf_grammar = """ + root ::= "#" hex hex hex hex hex hex + hex ::= [0-9] | [A-F] + """ + allowed_patterns = [r"^#[0-9A-F]{6}$"] + prompt = "Generate a hex color code:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_complex_json(self): + self.__class__.ebnf_grammar = """ + root ::= object + object ::= "{" ws pair (ws "," ws pair)* ws "}" + pair ::= "\\"name\\"" ws ":" ws value | + "\\"age\\"" ws ":" ws number | + "\\"city\\"" ws ":" ws string + value ::= string | number + string ::= "\\"" [a-zA-Z0-9 ]+ "\\"" + number ::= [1-9] [0-9]* + ws ::= [ ]* + """ + allowed_patterns = [ + r'^{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*}$', + ] + prompt = "Generate a simple JSON with name, age, and city:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_custom_log_format(self): + self.__class__.ebnf_grammar = """ + root ::= logentry + logentry ::= "[" datetime "] " level ": System.process - " message + datetime ::= "2024-01-01T12:00:00Z" + level ::= "INFO" + message ::= "Operation " [a-z]+ " successfully" + """ + allowed_patterns = [ + r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$" + ] + prompt = "Generate a log entry:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + +class TestJumpForward(TestEBNFConstrained): + @classmethod + def setUpClass(cls): + setup_class(cls, disable_overlap=True) + cls.check_jump_forward = True + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index d007bed31..47932ae41 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -5,6 +5,7 @@ python3 -m unittest test_openai_server.TestOpenAIServer.test_completion """ import json +import re import time import unittest @@ -535,5 +536,92 @@ The SmartHome Mini is a compact smart home assistant available in black or white ) +# ------------------------------------------------------------------------- +# EBNF Test Class: TestOpenAIServerEBNF +# Launches the server with xgrammar, has only EBNF tests +# ------------------------------------------------------------------------- +class TestOpenAIServerEBNF(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + # passing xgrammar specifically + other_args = ["--grammar-backend", "xgrammar"] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=other_args, + ) + cls.base_url += "/v1" + cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_ebnf(self): + """ + Ensure we can pass `ebnf` to the local openai server + and that it enforces the grammar. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + ebnf_grammar = r""" + root ::= "Hello" | "Hi" | "Hey" + """ + pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$") + + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful EBNF test bot."}, + {"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."}, + ], + temperature=0, + max_tokens=32, + extra_body={"ebnf": ebnf_grammar}, + ) + text = response.choices[0].message.content.strip() + print("EBNF test output:", repr(text)) + self.assertTrue(len(text) > 0, "Got empty text from EBNF generation") + self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices") + + def test_ebnf_strict_json(self): + """ + A stricter EBNF that produces exactly {"name":"Alice"} format + with no trailing punctuation or extra fields. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + ebnf_grammar = r""" + root ::= "{" pair "}" + pair ::= "\"name\"" ":" string + string ::= "\"" [A-Za-z]+ "\"" + """ + pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$') + + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "EBNF mini-JSON generator."}, + { + "role": "user", + "content": "Generate single key JSON with only letters.", + }, + ], + temperature=0, + max_tokens=64, + extra_body={"ebnf": ebnf_grammar}, + ) + text = response.choices[0].message.content.strip() + print("EBNF strict JSON test output:", repr(text)) + self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test") + self.assertRegex( + text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape" + ) + + if __name__ == "__main__": unittest.main()