diff --git a/docs/backend/openai_api_completions.ipynb b/docs/backend/openai_api_completions.ipynb
index 8660da2f9..42cdbb112 100644
--- a/docs/backend/openai_api_completions.ipynb
+++ b/docs/backend/openai_api_completions.ipynb
@@ -219,7 +219,7 @@
     "SGLang supports two grammar backends:\n",
     "\n",
     "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
-    "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints.\n",
+    "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
     "  - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n",
     "\n",
     "Initialize the XGrammar backend using `--grammar-backend xgrammar` flag\n",
diff --git a/docs/backend/structured_outputs.ipynb b/docs/backend/structured_outputs.ipynb
index 55ca0b627..a5e6f2335 100644
--- a/docs/backend/structured_outputs.ipynb
+++ b/docs/backend/structured_outputs.ipynb
@@ -16,7 +16,8 @@
     "SGLang supports two grammar backends:\n",
     "\n",
     "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
-    "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints and currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md).\n",
+    "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
+    "  - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n",
     "\n",
     "We suggest using XGrammar whenever possible for its better performance. For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n",
     "\n",
diff --git a/docs/references/sampling_params.md b/docs/references/sampling_params.md
index 5dad3fd12..cdc53da61 100644
--- a/docs/references/sampling_params.md
+++ b/docs/references/sampling_params.md
@@ -189,7 +189,7 @@ You can specify a JSON schema, regular expression or [EBNF](https://en.wikipedia
 SGLang supports two grammar backends:
 
 - [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.
-- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints.
+- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.
   - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)
 
 Initialize the XGrammar backend using `--grammar-backend xgrammar` flag
diff --git a/python/pyproject.toml b/python/pyproject.toml
index f1fcc4679..f97c9c266 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -23,7 +23,7 @@ runtime_common = [
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
     "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
-    "xgrammar>=0.1.6"
+    "xgrammar>=0.1.10"
 ]
 srt = [
     "sglang[runtime_common]", "cuda-python",
diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py
index b0b2c31c2..c423a567e 100644
--- a/python/sglang/srt/constrained/xgrammar_backend.py
+++ b/python/sglang/srt/constrained/xgrammar_backend.py
@@ -19,6 +19,7 @@ from typing import List, Tuple
 import torch
 from xgrammar import (
     CompiledGrammar,
+    Grammar,
     GrammarCompiler,
     GrammarMatcher,
     TokenizerInfo,
@@ -133,10 +134,13 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
                 logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
                 return None
         elif key_type == "regex":
-            logger.warning(
-                "regex hasn't been supported by xgrammar yet. This is skipped."
-            )
-            return None
+            try:
+                ctx = self.grammar_compiler.compile_grammar(
+                    Grammar.from_regex(key_string)
+                )
+            except RuntimeError as e:
+                logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
+                return None
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index fb1c6abf2..2ed252275 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -31,6 +31,7 @@ suites = {
         "test_openai_server.py",
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
+        "test_regex_constrained.py",
         "test_release_memory_occupation.py",
         "test_request_length_validation.py",
         "test_retract_decode.py",
diff --git a/test/srt/test_regex_constrained.py b/test/srt/test_regex_constrained.py
new file mode 100644
index 000000000..6d5acec15
--- /dev/null
+++ b/test/srt/test_regex_constrained.py
@@ -0,0 +1,200 @@
+"""
+python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email
+python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
+"""
+
+import json
+import re
+import unittest
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+def setup_class(cls, disable_overlap: bool):
+    """Launch a server using the XGrammar backend and store its handle on cls."""
+    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+    cls.base_url = DEFAULT_URL_FOR_TEST
+
+    other_args = [
+        "--max-running-requests",
+        "10",
+        "--grammar-backend",
+        "xgrammar",
+    ]
+
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]
+
+    cls.process = popen_launch_server(
+        cls.model,
+        cls.base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_args,
+    )
+
+
+class TestRegexConstrained(unittest.TestCase):
+    """Regex-constrained generation with the overlap scheduler left enabled."""
+
+    @classmethod
+    def setUpClass(cls):
+        setup_class(cls, disable_overlap=False)
+        cls.check_jump_forward = False
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def run_decode(
+        self,
+        regex,
+        prompt,
+        return_logprob=False,
+        top_logprobs_num=0,
+        n=1,
+    ):
+        """Generate n samples under regex and fail unless each one matches it."""
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": prompt,
+                "sampling_params": {
+                    "temperature": 0 if n == 1 else 0.5,
+                    "max_new_tokens": 128,
+                    "n": n,
+                    "regex": regex,
+                },
+                "stream": False,
+                "return_logprob": return_logprob,
+                "top_logprobs_num": top_logprobs_num,
+                "logprob_start_len": 0,
+            },
+        )
+        # Surface server-side errors directly instead of as a confusing
+        # type/shape failure further down.
+        self.assertEqual(response.status_code, 200, response.text)
+
+        ret = response.json()
+        print(json.dumps(ret, indent=2))
+        print("=" * 100)
+
+        # With n == 1 the endpoint may return a single object instead of a
+        # list; normalize so the default n works (NOTE(review): confirm).
+        if isinstance(ret, dict):
+            ret = [ret]
+
+        if not isinstance(ret, list):
+            self.fail(f"Expected response to be a list, but got {type(ret)}")
+
+        for item in ret:
+            text = item.get("text", "").strip()
+            if not text:
+                self.fail("Generated text is empty.")
+
+            if not self.regex_match(text, regex):
+                self.fail(f"Text '{text}' does not match regex pattern.")
+
+    def regex_match(self, text, pattern):
+        """Return True if pattern matches at the start of text."""
+        return re.match(pattern, text) is not None
+
+    def test_regex_generate_email(self):
+        pattern = r"^user@example\.com$"
+        prompt = "Generate an email address:"
+
+        self.run_decode(
+            regex=pattern,
+            prompt=prompt,
+            n=3,
+        )
+
+    def test_regex_generate_greeting(self):
+        pattern = r"^(Hello|Hi|Hey)$"
+        prompt = "Generate a greeting:"
+
+        self.run_decode(
+            regex=pattern,
+            prompt=prompt,
+            n=3,
+        )
+
+    def test_regex_generate_number(self):
+        pattern = r"^\d{3}$"
+        prompt = "Generate a three-digit number:"
+
+        self.run_decode(
+            regex=pattern,
+            prompt=prompt,
+            n=3,
+        )
+
+    def test_regex_generate_phone(self):
+        pattern = r"^\(\d{3}\) \d{3}-\d{4}$"
+        prompt = "Generate a phone number:"
+
+        self.run_decode(
+            regex=pattern,
+            prompt=prompt,
+            n=3,
+        )
+
+    def test_regex_generate_date(self):
+        pattern = r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"
+        prompt = "Generate a date in YYYY-MM-DD format:"
+
+        self.run_decode(
+            regex=pattern,
+            prompt=prompt,
+            n=3,
+        )
+
+    def test_regex_generate_hex_color(self):
+        pattern = r"^#[0-9A-F]{6}$"
+        prompt = "Generate a hex color code:"
+
+        self.run_decode(
+            regex=pattern,
+            prompt=prompt,
+            n=3,
+        )
+
+    def test_regex_generate_complex_json(self):
+        pattern = r'^\{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*\}$'
+        prompt = "Generate a simple JSON with name, age, and city:"
+
+        self.run_decode(
+            regex=pattern,
+            prompt=prompt,
+            n=3,
+        )
+
+    def test_regex_generate_custom_log_format(self):
+        pattern = r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$"
+        prompt = "Generate a log entry:"
+
+        self.run_decode(
+            regex=pattern,
+            prompt=prompt,
+            n=3,
+        )
+
+
+class TestJumpForward(TestRegexConstrained):
+    """Same regex tests, with --disable-overlap-schedule passed to the server."""
+
+    @classmethod
+    def setUpClass(cls):
+        setup_class(cls, disable_overlap=True)
+        cls.check_jump_forward = True
+
+
+if __name__ == "__main__":
+    unittest.main()