[Feature] Support llguidance for constrained decoding (#3298)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
|
||||
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
|
||||
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_email
|
||||
python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_greeting
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -17,7 +19,7 @@ from sglang.test.test_utils import (
|
||||
)
|
||||
|
||||
|
||||
def setup_class(cls, disable_overlap: bool):
|
||||
def setup_class(cls, backend: str, disable_overlap: bool):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.ebnf_grammar = 'root ::= "test"' # Default grammar
|
||||
@@ -26,7 +28,7 @@ def setup_class(cls, disable_overlap: bool):
|
||||
"--max-running-requests",
|
||||
"10",
|
||||
"--grammar-backend",
|
||||
"xgrammar",
|
||||
backend,
|
||||
]
|
||||
|
||||
if disable_overlap:
|
||||
@@ -43,7 +45,7 @@ def setup_class(cls, disable_overlap: bool):
|
||||
class TestEBNFConstrained(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, disable_overlap=False)
|
||||
setup_class(cls, "xgrammar", disable_overlap=False)
|
||||
cls.check_jump_forward = False
|
||||
|
||||
@classmethod
|
||||
@@ -236,5 +238,12 @@ class TestEBNFConstrained(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, "llguidance", disable_overlap=False)
|
||||
cls.check_jump_forward = False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
|
||||
python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
|
||||
python3 -m unittest test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -30,6 +31,7 @@ def setup_class(cls, backend: str, disable_overlap: bool):
|
||||
"population": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -146,5 +148,12 @@ class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
|
||||
cls.check_jump_forward = False
|
||||
|
||||
|
||||
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, backend="llguidance", disable_overlap=False)
|
||||
cls.check_jump_forward = False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
"""
|
||||
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email
|
||||
python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
|
||||
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_email
|
||||
python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_greeting
|
||||
python3 -m unittest test_regex_constrained.TestJumpForwardLLGuidance.test_regex_generate_email
|
||||
python3 -m unittest test_regex_constrained.TestJumpForwardLLGuidance.test_regex_generate_greeting
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -17,7 +21,7 @@ from sglang.test.test_utils import (
|
||||
)
|
||||
|
||||
|
||||
def setup_class(cls, disable_overlap: bool):
|
||||
def setup_class(cls, backend: str, disable_overlap: bool):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
|
||||
@@ -25,7 +29,7 @@ def setup_class(cls, disable_overlap: bool):
|
||||
"--max-running-requests",
|
||||
"10",
|
||||
"--grammar-backend",
|
||||
"xgrammar",
|
||||
backend,
|
||||
]
|
||||
|
||||
if disable_overlap:
|
||||
@@ -42,7 +46,7 @@ def setup_class(cls, disable_overlap: bool):
|
||||
class TestRegexConstrained(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, disable_overlap=False)
|
||||
setup_class(cls, "xgrammar", disable_overlap=False)
|
||||
cls.check_jump_forward = False
|
||||
|
||||
@classmethod
|
||||
@@ -178,9 +182,22 @@ class TestRegexConstrained(unittest.TestCase):
|
||||
class TestJumpForward(TestRegexConstrained):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, disable_overlap=True)
|
||||
setup_class(cls, "xgrammar", disable_overlap=True)
|
||||
cls.check_jump_forward = True
|
||||
|
||||
|
||||
class TestJumpForwardLLGuidance(TestRegexConstrained):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, "llguidance", disable_overlap=True)
|
||||
cls.check_jump_forward = True
|
||||
|
||||
|
||||
class TestRegexConstrainedLLGuidance(TestRegexConstrained):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, "llguidance", disable_overlap=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user