Update XGrammar to the latest API (#2176)
Co-authored-by: Ben Gitter <gitterbd@gmail.com>
This commit is contained in:
@@ -22,7 +22,7 @@ runtime_common = ["aiohttp", "decord", "fastapi",
|
|||||||
"packaging", "pillow", "prometheus-client>=0.20.0",
|
"packaging", "pillow", "prometheus-client>=0.20.0",
|
||||||
"psutil", "pydantic", "python-multipart",
|
"psutil", "pydantic", "python-multipart",
|
||||||
"pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop",
|
"pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop",
|
||||||
"modelscope", "xgrammar"]
|
"modelscope", "xgrammar==0.1.4"]
|
||||||
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]
|
srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]
|
||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
|
|||||||
@@ -17,21 +17,14 @@ import logging
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from xgrammar import (
|
||||||
try:
|
CompiledGrammar,
|
||||||
from xgrammar import (
|
GrammarCompiler,
|
||||||
CachedGrammarCompiler,
|
GrammarMatcher,
|
||||||
CompiledGrammar,
|
TokenizerInfo,
|
||||||
GrammarMatcher,
|
allocate_token_bitmask,
|
||||||
TokenizerInfo,
|
apply_token_bitmask_inplace,
|
||||||
)
|
)
|
||||||
|
|
||||||
import_error = None
|
|
||||||
except ImportError as e:
|
|
||||||
CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
|
|
||||||
ImportError
|
|
||||||
)
|
|
||||||
import_error = e
|
|
||||||
|
|
||||||
from sglang.srt.constrained.base_grammar_backend import (
|
from sglang.srt.constrained.base_grammar_backend import (
|
||||||
BaseGrammarBackend,
|
BaseGrammarBackend,
|
||||||
@@ -41,7 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MAX_ROLLBACK_TOKENS = 10
|
MAX_ROLLBACK_TOKENS = 200
|
||||||
|
|
||||||
|
|
||||||
class XGrammarGrammar(BaseGrammarObject):
|
class XGrammarGrammar(BaseGrammarObject):
|
||||||
@@ -86,21 +79,22 @@ class XGrammarGrammar(BaseGrammarObject):
|
|||||||
def allocate_vocab_mask(
|
def allocate_vocab_mask(
|
||||||
self, vocab_size: int, batch_size: int, device
|
self, vocab_size: int, batch_size: int, device
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
|
return allocate_token_bitmask(batch_size, vocab_size)
|
||||||
|
|
||||||
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
|
||||||
self.matcher.fill_next_token_bitmask(vocab_mask, idx)
|
self.matcher.fill_next_token_bitmask(vocab_mask, idx)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
|
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
|
||||||
GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)
|
if vocab_mask.device.type != logits.device.type:
|
||||||
|
# vocab_mask must then be on the same device as logits
|
||||||
|
# when applying the token bitmask, so we check and move if needed
|
||||||
|
vocab_mask = vocab_mask.to(logits.device)
|
||||||
|
|
||||||
|
apply_token_bitmask_inplace(logits, vocab_mask)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
matcher = GrammarMatcher(
|
matcher = GrammarMatcher(self.ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
||||||
self.ctx,
|
|
||||||
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
|
||||||
vocab_size=self.vocab_size,
|
|
||||||
)
|
|
||||||
return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
|
return XGrammarGrammar(matcher, self.vocab_size, self.ctx)
|
||||||
|
|
||||||
|
|
||||||
@@ -112,25 +106,18 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if import_error:
|
tokenizer_info = TokenizerInfo.from_huggingface(
|
||||||
logger.warning(
|
tokenizer, vocab_size=vocab_size
|
||||||
f"Ignore import error for the grammar backend: {import_error}"
|
)
|
||||||
)
|
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
||||||
self.grammar_cache = None
|
|
||||||
return
|
|
||||||
|
|
||||||
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
|
|
||||||
self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
|
||||||
if import_error:
|
|
||||||
raise import_error
|
|
||||||
|
|
||||||
key_type, key_string = key
|
key_type, key_string = key
|
||||||
if key_type == "json":
|
if key_type == "json":
|
||||||
try:
|
try:
|
||||||
ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
|
ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
f"Skip invalid json_schema: json_schema={key_string}, {e=}"
|
||||||
@@ -144,13 +131,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid key_type: {key_type}")
|
raise ValueError(f"Invalid key_type: {key_type}")
|
||||||
|
|
||||||
matcher = GrammarMatcher(
|
matcher = GrammarMatcher(ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS)
|
||||||
ctx,
|
|
||||||
max_rollback_tokens=MAX_ROLLBACK_TOKENS,
|
|
||||||
vocab_size=self.vocab_size,
|
|
||||||
)
|
|
||||||
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
return XGrammarGrammar(matcher, self.vocab_size, ctx)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
if self.grammar_cache:
|
if self.grammar_compiler:
|
||||||
self.grammar_cache.clear()
|
self.grammar_compiler.clear_cache()
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from sglang.test.test_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestJSONConstrained(unittest.TestCase):
|
class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
@@ -36,7 +36,12 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
cls.model,
|
cls.model,
|
||||||
cls.base_url,
|
cls.base_url,
|
||||||
timeout=300,
|
timeout=300,
|
||||||
other_args=["--max-running-requests", "10"],
|
other_args=[
|
||||||
|
"--max-running-requests",
|
||||||
|
"10",
|
||||||
|
"--grammar-backend",
|
||||||
|
"outlines",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -121,5 +126,33 @@ class TestJSONConstrained(unittest.TestCase):
|
|||||||
list(executor.map(self.run_decode, json_schemas))
|
list(executor.map(self.run_decode, json_schemas))
|
||||||
|
|
||||||
|
|
||||||
|
class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.json_schema = json.dumps(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"population": {"type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["name", "population"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=300,
|
||||||
|
other_args=[
|
||||||
|
"--max-running-requests",
|
||||||
|
"10",
|
||||||
|
"--grammar-backend",
|
||||||
|
"xgrammar",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user