init
This commit is contained in:
426
model_executor/guided_decoding/xgrammar_decoding.py
Normal file
426
model_executor/guided_decoding/xgrammar_decoding.py
Normal file
@@ -0,0 +1,426 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# noqa: UP007
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
import vllm.envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
try:
|
||||
import xgrammar as xgr
|
||||
xgr_installed = True
|
||||
except ImportError:
|
||||
xgr_installed = False
|
||||
pass
|
||||
|
||||
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
|
||||
grammar_is_likely_lark)
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_local_xgrammar_guided_decoding_logits_processor(
|
||||
guided_params: GuidedDecodingParams,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model_config: ModelConfig,
|
||||
reasoner: ReasoningParser | None,
|
||||
max_threads: int = 8):
|
||||
config = GrammarConfig.from_guided_params(guided_params=guided_params,
|
||||
model_config=model_config,
|
||||
tokenizer=tokenizer,
|
||||
max_threads=max_threads)
|
||||
return XGrammarLogitsProcessor(config, reasoner)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenizerData:
|
||||
"""Immutable container for cached tokenizer data."""
|
||||
metadata: str
|
||||
encoded_vocab: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class TokenizerDataCache:
|
||||
"""Cache manager for tokenizer data to avoid repeated processing."""
|
||||
_cache: dict[int, TokenizerData] = {}
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer_data(
|
||||
cls,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
/,
|
||||
*,
|
||||
tokenizer_hash: int,
|
||||
vocab_size: int,
|
||||
) -> TokenizerData:
|
||||
|
||||
if tokenizer_hash not in cls._cache:
|
||||
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
|
||||
tokenizer,
|
||||
# NOTE: We will need to use lm_head's vocab_size
|
||||
# to determine correct special_token_ids for this tokenizer.
|
||||
# See https://github.com/mlc-ai/xgrammar/commit/70c959fb6d9cea75aae33c414763cd0602022d92 # noqa: E501
|
||||
vocab_size=vocab_size,
|
||||
)
|
||||
metadata = json.loads(tokenizer_info.dump_metadata())
|
||||
|
||||
# Vendored from xgrammar logic to get encoded_vocab
|
||||
# https://github.com/mlc-ai/xgrammar/blob/989222175c2a30fb7987d8bcce35bec1bf6817f2/python/xgrammar/tokenizer_info.py#L127 # noqa: E501
|
||||
try:
|
||||
vocab_dict = tokenizer.get_vocab()
|
||||
except AttributeError as e:
|
||||
raise ValueError(
|
||||
f"Cannot get the vocabulary of the tokenizer "
|
||||
f"{type(tokenizer)}. The tokenizer should have a "
|
||||
"get_vocab method.") from e
|
||||
|
||||
# maintain tokenizer's indexing
|
||||
encoded_vocab = [""] * tokenizer_info.vocab_size
|
||||
for token, idx in vocab_dict.items():
|
||||
if idx < tokenizer_info.vocab_size:
|
||||
encoded_vocab[idx] = token
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
|
||||
metadata.update({
|
||||
"vocab_type": xgr.VocabType.BYTE_FALLBACK,
|
||||
"add_prefix_space": True
|
||||
})
|
||||
|
||||
cls._cache[tokenizer_hash] = TokenizerData(
|
||||
encoded_vocab=encoded_vocab,
|
||||
metadata=json.dumps(metadata),
|
||||
)
|
||||
|
||||
return cls._cache[tokenizer_hash]
|
||||
|
||||
|
||||
class GrammarCompilerCache:
|
||||
"""
|
||||
Cache for GrammarCompiler instances based on tokenizer.
|
||||
|
||||
This cache reduces the overhead of creating new compiler instances when
|
||||
using the same tokenizer configuration.
|
||||
"""
|
||||
_cache: dict[str, xgr.GrammarCompiler] = {}
|
||||
|
||||
@classmethod
|
||||
def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
|
||||
cache_key = str(config.tokenizer_hash)
|
||||
|
||||
if cache_key not in cls._cache:
|
||||
config_data = config.tokenizer_data
|
||||
|
||||
# In TokenizerDataCache.get_tokenizer_data, a serializable
|
||||
# tokenizer_data is created and cached. This data is used to build
|
||||
# a tokenizer_info and create an xgrammar compiler.
|
||||
tokenizer_info = xgr.TokenizerInfo.from_vocab_and_metadata(
|
||||
encoded_vocab=config_data.encoded_vocab,
|
||||
metadata=config_data.metadata,
|
||||
)
|
||||
cache_size = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
|
||||
cls._cache[cache_key] = xgr.GrammarCompiler(
|
||||
tokenizer_info,
|
||||
max_threads=config.max_threads,
|
||||
cache_enabled=True,
|
||||
cache_limit_bytes=cache_size,
|
||||
)
|
||||
|
||||
return cls._cache[cache_key]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GrammarConfig:
|
||||
"""Serializable configuration for grammar compilation"""
|
||||
tokenizer_hash: int
|
||||
tokenizer_data: TokenizerData
|
||||
json_str: str | None = None
|
||||
grammar_str: str | None = None
|
||||
json_object: bool | None = None
|
||||
any_whitespace: bool = True
|
||||
regex_str: str | None = None
|
||||
max_threads: int = 8
|
||||
|
||||
@classmethod
|
||||
def from_guided_params(cls,
|
||||
guided_params: GuidedDecodingParams,
|
||||
model_config: ModelConfig,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_threads: int = 8) -> GrammarConfig:
|
||||
|
||||
tokenizer_hash = hash(tokenizer)
|
||||
tokenizer_data = TokenizerDataCache.get_tokenizer_data(
|
||||
tokenizer,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
vocab_size=model_config.hf_text_config.vocab_size,
|
||||
)
|
||||
|
||||
if guided_params.json:
|
||||
if not isinstance(guided_params.json, str):
|
||||
json_str = json.dumps(guided_params.json)
|
||||
else:
|
||||
json_str = guided_params.json
|
||||
|
||||
any_whitespace = not guided_params.disable_any_whitespace
|
||||
|
||||
# Check and log if model with xgrammar and whitespace have history
|
||||
# of runaway generation of whitespaces.
|
||||
# References:
|
||||
# https://github.com/vllm-project/vllm/pull/12744
|
||||
# https://github.com/mlc-ai/xgrammar/issues/212
|
||||
model_with_warn = None
|
||||
|
||||
if 'Mistral' in model_config.model:
|
||||
model_with_warn = 'Mistral'
|
||||
elif 'Qwen' in model_config.model:
|
||||
model_with_warn = 'Qwen'
|
||||
|
||||
if model_with_warn is not None and any_whitespace:
|
||||
logger.info_once(
|
||||
"%s model detected, consider setting `disable_any_whitespace` to prevent runaway generation of whitespaces.", # noqa: E501
|
||||
model_with_warn,
|
||||
)
|
||||
# Validate the schema and raise ValueError here if it is invalid.
|
||||
# This is to avoid exceptions in model execution, which will crash
|
||||
# the engine worker process.
|
||||
try:
|
||||
xgr.Grammar.from_json_schema(json_str,
|
||||
any_whitespace=any_whitespace)
|
||||
except RuntimeError as err:
|
||||
raise ValueError(str(err)) from err
|
||||
|
||||
return cls(json_str=json_str,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads,
|
||||
tokenizer_data=tokenizer_data,
|
||||
any_whitespace=any_whitespace)
|
||||
elif guided_params.grammar:
|
||||
# XGrammar only supports GBNF grammars, so we must convert Lark
|
||||
if grammar_is_likely_lark(guided_params.grammar):
|
||||
try:
|
||||
grammar_str = convert_lark_to_gbnf(guided_params.grammar)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"Failed to convert the grammar from Lark to GBNF. "
|
||||
"Please either use GBNF grammar directly or specify"
|
||||
" --guided-decoding-backend=outlines.\n"
|
||||
f"Conversion error: {str(e)}") from e
|
||||
else:
|
||||
grammar_str = guided_params.grammar
|
||||
|
||||
# Validate the grammar and raise ValueError here if it is invalid.
|
||||
# This is to avoid exceptions in model execution, which will crash
|
||||
# the engine worker process.
|
||||
try:
|
||||
xgr.Grammar.from_ebnf(grammar_str)
|
||||
except RuntimeError as err:
|
||||
raise ValueError(str(err)) from err
|
||||
|
||||
return cls(grammar_str=grammar_str,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads,
|
||||
tokenizer_data=tokenizer_data)
|
||||
elif guided_params.json_object:
|
||||
return cls(
|
||||
json_object=True,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads,
|
||||
tokenizer_data=tokenizer_data,
|
||||
)
|
||||
elif guided_params.choice:
|
||||
choice_str = GrammarConfig.choice_as_grammar(guided_params.choice)
|
||||
try:
|
||||
xgr.Grammar.from_ebnf(choice_str)
|
||||
except RuntimeError as err:
|
||||
raise ValueError(str(err)) from err
|
||||
|
||||
return cls(
|
||||
grammar_str=choice_str,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads,
|
||||
tokenizer_data=tokenizer_data,
|
||||
)
|
||||
elif guided_params.regex:
|
||||
return cls(
|
||||
regex_str=guided_params.regex,
|
||||
tokenizer_hash=tokenizer_hash,
|
||||
max_threads=max_threads,
|
||||
tokenizer_data=tokenizer_data,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Currently only support JSON and EBNF grammar mode for xgrammar"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def escape_ebnf_string(s: str) -> str:
|
||||
"""Escape special characters in a EBNF string."""
|
||||
# Escape double quotes and backslashes
|
||||
return re.sub(r'(["\\])', r'\\\1', s)
|
||||
|
||||
@staticmethod
|
||||
def choice_as_grammar(choice: list[str] | None) -> str:
|
||||
if choice is None:
|
||||
raise ValueError("Choice is not set")
|
||||
escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
|
||||
grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
|
||||
return grammar
|
||||
|
||||
@staticmethod
|
||||
def tokenizer_info(tokenizer_data: TokenizerData) -> xgr.TokenizerInfo:
|
||||
return xgr.TokenizerInfo.from_vocab_and_metadata(
|
||||
encoded_vocab=tokenizer_data.encoded_vocab,
|
||||
metadata=tokenizer_data.metadata,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class XGrammarLogitsProcessor:
|
||||
"""Wrapper class to support pickle protocol"""
|
||||
config: GrammarConfig
|
||||
reasoner: ReasoningParser | None = None
|
||||
|
||||
ctx: xgr.CompiledGrammar | None = None
|
||||
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
|
||||
token_bitmask: torch.Tensor = None # type: ignore[assignment]
|
||||
matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
|
||||
batch_size: int = field(default=1)
|
||||
prefilled: bool = field(default=False)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer_info is None:
|
||||
self.tokenizer_info = self.config.tokenizer_info(
|
||||
self.config.tokenizer_data)
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
return {'config': self.config, 'reasoner': self.reasoner}
|
||||
|
||||
def __setstate__(self, state: dict[str, Any]):
|
||||
self.config = state['config']
|
||||
self.reasoner = state['reasoner']
|
||||
|
||||
self.tokenizer_info = GrammarConfig.tokenizer_info(
|
||||
self.config.tokenizer_data)
|
||||
self.ctx = None
|
||||
self.matchers = []
|
||||
self.batch_size = 1
|
||||
self.token_bitmask = None # type: ignore[assignment]
|
||||
self.prefilled = False
|
||||
|
||||
def _ensure_ctx(self):
|
||||
"""Lazily initialize the processor in the worker process"""
|
||||
if self.ctx is None:
|
||||
compiler = GrammarCompilerCache.get_compiler(self.config)
|
||||
if self.config.json_str is not None:
|
||||
any_whitespace = self.config.any_whitespace
|
||||
self.ctx = compiler\
|
||||
.compile_json_schema(self.config.json_str,
|
||||
any_whitespace=any_whitespace)
|
||||
elif self.config.grammar_str is not None:
|
||||
self.ctx = compiler.compile_grammar(self.config.grammar_str)
|
||||
elif self.config.json_object:
|
||||
any_whitespace = self.config.any_whitespace
|
||||
self.ctx = compiler\
|
||||
.compile_json_schema('{"type": "object"}',
|
||||
any_whitespace=any_whitespace)
|
||||
elif self.config.regex_str:
|
||||
self.ctx = compiler.compile_regex(self.config.regex_str)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid configuration for xgrammar logits processor")
|
||||
|
||||
def __call__(self, input_ids: list[int],
|
||||
scores: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# Skip the structured logits processing if reasoning is not finished.
|
||||
# reasoner is not None only when `--reasoning-parser` is set.
|
||||
if self.reasoner is not None and \
|
||||
not self.reasoner.is_reasoning_end(
|
||||
input_ids):
|
||||
return scores
|
||||
|
||||
if self.ctx is None:
|
||||
self._ensure_ctx()
|
||||
|
||||
if len(self.matchers) == 0:
|
||||
self.matchers = [
|
||||
xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
|
||||
]
|
||||
self.token_bitmask = xgr.allocate_token_bitmask(
|
||||
self.batch_size, self.tokenizer_info.vocab_size)
|
||||
|
||||
if not self.prefilled:
|
||||
# Have not sampled a token yet
|
||||
self.prefilled = True
|
||||
else:
|
||||
for i, matcher in enumerate(self.matchers):
|
||||
if not matcher.is_terminated():
|
||||
sampled_token = input_ids[-1]
|
||||
assert self.matchers[i].accept_token(sampled_token)
|
||||
|
||||
for i, matcher in enumerate(self.matchers):
|
||||
if not matcher.is_terminated():
|
||||
# @ubospica: ideally, fill_next_token_bitmask should be
|
||||
# parallelized with model decoding
|
||||
# See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303
|
||||
matcher.fill_next_token_bitmask(self.token_bitmask, i)
|
||||
|
||||
# token_bitmask is a CPU tensor for use with accept_token and
|
||||
# fill_next_token_bitmask so we move it to the device of scores
|
||||
device_type = scores.device.type
|
||||
dtype = scores.dtype
|
||||
if device_type != "cuda":
|
||||
# xgrammar on cpu only supports float32 scores
|
||||
# see: https://github.com/mlc-ai/xgrammar/blob/c1b64920cad24f44f235778c1c00bb52d57da01a/python/xgrammar/kernels/apply_token_bitmask_inplace_cpu.py#L22
|
||||
scores = scores.to("cpu").float().unsqueeze(0)
|
||||
|
||||
# Note: In this method, if the tensors have different dimensions
|
||||
# on CPU device fails, but on GPU it runs without error. Hence the
|
||||
# unsqueeze above for scores, to match the token bitmask shape
|
||||
xgr.apply_token_bitmask_inplace(
|
||||
scores, self.token_bitmask.to(scores.device, non_blocking=True))
|
||||
if device_type != "cuda":
|
||||
scores = scores.to(dtype).to(device_type).squeeze()
|
||||
|
||||
return scores
|
||||
|
||||
def clone(self) -> XGrammarLogitsProcessor:
|
||||
"""Create a new instance with shared compiled grammar
|
||||
but separate state"""
|
||||
new_processor = XGrammarLogitsProcessor(self.config, self.reasoner,
|
||||
None, self.tokenizer_info)
|
||||
|
||||
# Share the compiled grammar context (immutable after compilation)
|
||||
new_processor.ctx = self.ctx
|
||||
|
||||
# Create fresh matchers for the new sequence
|
||||
if self.ctx is not None:
|
||||
new_processor.matchers = [
|
||||
xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
|
||||
]
|
||||
|
||||
# Create a new token bitmask with the same size
|
||||
if hasattr(self, 'token_bitmask') and self.token_bitmask is not None:
|
||||
new_processor.token_bitmask = self.token_bitmask
|
||||
|
||||
# Copy simple attributes
|
||||
new_processor.batch_size = self.batch_size
|
||||
# Reset prefilled state for new sequence
|
||||
new_processor.prefilled = False
|
||||
|
||||
return new_processor
|
||||
Reference in New Issue
Block a user