Sync from v0.13
This commit is contained in:
353
vllm/v1/structured_output/__init__.py
Normal file
353
vllm/v1/structured_output/__init__.py
Normal file
@@ -0,0 +1,353 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import multiprocessing
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
)
|
||||
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.v1.request import Request
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
ReasoningParser = object
|
||||
Request = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StructuredOutputManager:
    """Engine-level manager for structured output requests.

    Lazily creates the (single) structured-output backend, compiles grammars
    for incoming requests -- on a thread pool unless running under the
    external_launcher executor -- and assembles the per-step token bitmask
    that the GPU runner applies during sampling.
    """

    def __init__(self, vllm_config: VllmConfig):
        # Created lazily in grammar_init() when the first structured-output
        # request arrives; a single backend serves all requests.
        self.backend: StructuredOutputBackend | None = None
        # Set only when a reasoning parser is configured; used to defer FSM
        # advancement until the model leaves its reasoning section.
        self.reasoner: ReasoningParser | None = None
        self.vllm_config = vllm_config

        # When in external_launcher mode, async grammar compilation causes deadlocks
        # due to external_launcher mode having a scheduler for each TP rank.
        # Async grammar compilation causes the WAITING_FOR_FSM → WAITING transition to
        # happen at different times on different TP ranks,
        # breaking the determinism assumption that external_launcher relies on.
        self._use_async_grammar_compilation = (
            vllm_config.parallel_config.distributed_executor_backend
            != "external_launcher"
        )

        # Allocated lazily in grammar_bitmask() via the backend.
        self._grammar_bitmask: torch.Tensor | None = None
        # -1 has all bits set: the "every token allowed" row value.
        self._full_mask = torch.tensor(-1, dtype=torch.int32)

        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
        # Above this many structured requests per step, bitmask rows are
        # filled on a thread pool instead of serially.
        self.fill_bitmask_parallel_threshold = 128
        if self.fill_bitmask_parallel_threshold < max_batch_size:
            # Number of (grammar, row) entries handed to each worker task.
            self.fill_bitmask_parallel_batch_size = 16
            # Use:
            # - at least 1 CPU
            # - at most half the number of CPUs or 8, whichever is less
            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
            self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers)

        if not self.vllm_config.model_config.skip_tokenizer_init:
            # The default max_workers if not specified is the number of
            # CPUs * 5, which is way too high since these tasks are CPU-bound,
            # not I/O bound. We also know we would never dominate CPU usage
            # with just grammar compilation, so we set it to half the number
            # of CPUs.
            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.tokenizer = cached_tokenizer_from_config(
                model_config=self.vllm_config.model_config
            )
            reasoning_parser = (
                self.vllm_config.structured_outputs_config.reasoning_parser
            )
            reasoning_parser_plugin = (
                self.vllm_config.structured_outputs_config.reasoning_parser_plugin
            )
            if reasoning_parser_plugin and len(reasoning_parser_plugin) > 3:
                ReasoningParserManager.import_reasoning_parser(reasoning_parser_plugin)

            # Re-read after the optional plugin import so any parser
            # registration side effects are in place before resolving it.
            reasoning_parser = (
                self.vllm_config.structured_outputs_config.reasoning_parser
            )
            if reasoning_parser:
                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                    reasoning_parser
                )
                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

        self.enable_in_reasoning = (
            self.vllm_config.structured_outputs_config.enable_in_reasoning
        )

    def grammar_init(self, request: Request) -> None:
        """Create (or schedule async creation of) the grammar for ``request``.

        Instantiates the backend on first use, chosen by the request's
        ``_backend`` field. When async compilation is enabled the grammar
        attribute holds a Future until it resolves.
        """
        if request.structured_output_request is None:
            return

        if TYPE_CHECKING:
            assert (
                request.sampling_params is not None
                and request.sampling_params.structured_outputs is not None
            )

        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        # _backend is set in Processor._validate_structured_output
        if self.backend is None:
            assert request.sampling_params is not None
            backend = request.sampling_params.structured_outputs._backend
            vocab_size = self.vllm_config.model_config.get_vocab_size()
            if backend == "xgrammar":
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "guidance":
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "outlines":
                # Imported lazily so the dependency is only required when
                # this backend is actually selected.
                from vllm.v1.structured_output.backend_outlines import OutlinesBackend

                self.backend = OutlinesBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "lm-format-enforcer":
                from vllm.v1.structured_output.backend_lm_format_enforcer import (  # noqa: E501
                    LMFormatEnforcerBackend,
                )

                self.backend = LMFormatEnforcerBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            else:
                raise ValueError(f"Unsupported structured output backend: {backend}")

        if self._use_async_grammar_compilation:
            # Compilation can be slow; run it off the scheduler thread.
            grammar = self.executor.submit(self._create_grammar, request)
        else:
            grammar = self._create_grammar(request)  # type: ignore[assignment]
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _create_grammar(
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
        """Compile the grammar for ``request`` on the calling thread."""
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def _fill_bitmasks(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> None:
        """Fill one bitmask row per (grammar, row index, apply?) entry.

        Rows whose grammar is terminated, or that should not be constrained
        this step, are reset to the all-ones "everything allowed" mask.
        """
        assert self._grammar_bitmask is not None
        for grammar, index, apply_bitmask in batch:
            if apply_bitmask and not grammar.is_terminated():
                grammar.fill_bitmask(self._grammar_bitmask, index)
            else:
                # Note that for thinking support, we will need to
                # reset the relevant part of the bitmask for consequent
                # requests here.
                self._grammar_bitmask[index].fill_(self._full_mask)

    def _async_submit_fill_bitmask(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> Future:
        """Queue one _fill_bitmasks() batch on the fill-mask thread pool."""
        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: list[str],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> "npt.NDArray[np.int32] | None":
        """Build the token bitmask for this step's structured requests.

        Returns an int32 numpy array with one row per checked token position
        (1 + num spec tokens per request), or None when no structured-output
        requests are scheduled.
        """
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = (
                self.vllm_config.speculative_config.num_speculative_tokens
            )

        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = self.backend.allocate_token_bitmask(
                max_batch_size * (1 + max_num_spec_tokens)
            )

        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        cumulative_index = 0

        # Optimized parallel filling of bitmasks for
        # non-spec, large-batch-size cases
        if (
            len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold
            and max_num_spec_tokens == 0
        ):
            promises = []
            batch = []
            for req_id in structured_output_request_ids:
                request = requests[req_id]
                structured_output_request = request.structured_output_request
                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None

                apply_bitmask = self.should_fill_bitmask(request)
                batch.append(
                    (structured_output_request.grammar, cumulative_index, apply_bitmask)
                )
                if len(batch) == self.fill_bitmask_parallel_batch_size:
                    promises.append(self._async_submit_fill_bitmask(batch))
                    batch = []

                cumulative_index += 1
            # Flush the final partial batch.
            if batch:
                promises.append(self._async_submit_fill_bitmask(batch))

            # Wait for all bitmask filling tasks to complete.
            for promise in promises:
                promise.result()
        else:
            # Fallback to serial filling of bitmasks for small-batch-size cases
            for req_id in structured_output_request_ids:
                request = requests[req_id]
                structured_output_request = request.structured_output_request

                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None
                apply_bitmask = self.should_fill_bitmask(request)

                # Advance the FSM through the speculative tokens to produce a
                # mask for every position, then roll it back afterwards.
                state_advancements = 0
                req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
                for i, token in enumerate(req_tokens + [None]):
                    self._fill_bitmasks(
                        [
                            (
                                structured_output_request.grammar,
                                cumulative_index,
                                apply_bitmask,
                            )
                        ]
                    )

                    if (
                        apply_bitmask
                        and token is not None
                        and not structured_output_request.grammar.is_terminated()
                    ):
                        accepted = structured_output_request.grammar.accept_tokens(
                            req_id, [token]
                        )
                        assert accepted, (token, req_id, scheduled_spec_decode_tokens)
                        state_advancements += 1
                    cumulative_index += 1
                if state_advancements > 0:
                    structured_output_request.grammar.rollback(state_advancements)

        bitmask_tensor = self._grammar_bitmask
        # Trim unused rows before shipping to the workers.
        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()

    def should_fill_bitmask(self, request: Request) -> bool:
        """Return True when this request's sampling should be constrained now."""
        # NOTE (Hanchen) if enable_in_reasoning is True, it means that
        # the model needs to be constrained in reasoning. So we should always
        # enable the bitmask filling.

        if self.reasoner is not None:
            if self.enable_in_reasoning:
                return True
            assert request.structured_output_request is not None
            # Lazily determine (once) whether the prompt already ends the
            # reasoning section.
            if request.structured_output_request.reasoning_ended is None:
                request.structured_output_request.reasoning_ended = (
                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
                )
            return request.structured_output_request.reasoning_ended
        return True

    def should_advance(self, request: Request) -> bool:
        """Return True when the request's FSM should consume new tokens."""
        if not request.use_structured_output:
            return False

        # To determine whether we can advance the FSM.
        # Supports thinking usage where we skip the reasoning components.
        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None
        # by default, we should always advance
        # for cases that don't use thinking mode.
        if self.reasoner is None:
            return True

        # if the model needs structured in reasoning, we should advance
        if self.enable_in_reasoning:
            return True

        structured_req = request.structured_output_request
        if structured_req.reasoning_ended:
            return True

        # Check if reasoning ends in *this* step
        if self.reasoner.is_reasoning_end_streaming(
            request.all_token_ids, request.all_token_ids[request.num_computed_tokens :]
        ):
            # Reasoning just ended, so we shouldn't advance til
            # next pass
            structured_req.reasoning_ended = True

        return False

    def clear_backend(self) -> None:
        """Release backend resources (no-op when no backend was created)."""
        if self.backend is not None:
            self.backend.destroy()
|
||||
265
vllm/v1/structured_output/backend_guidance.py
Normal file
265
vllm/v1/structured_output/backend_guidance.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions,
|
||||
)
|
||||
from vllm.v1.structured_output.request import get_structured_output_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import llguidance
|
||||
import llguidance.hf as llguidance_hf
|
||||
import llguidance.torch as llguidance_torch
|
||||
else:
|
||||
llguidance = LazyLoader("llguidance", globals(), "llguidance")
|
||||
llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
|
||||
llguidance_torch = LazyLoader("llguidance.torch", globals(), "llguidance.torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _walk_json_for_additional_properties(data: object):
|
||||
if isinstance(data, dict):
|
||||
for value in data.values():
|
||||
_walk_json_for_additional_properties(value)
|
||||
if "additionalProperties" not in data and (
|
||||
"properties" in data or "patternProperties" in data
|
||||
):
|
||||
data["additionalProperties"] = False
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
_walk_json_for_additional_properties(item)
|
||||
|
||||
|
||||
def process_for_additional_properties(
    guide_json: str | dict[str, Any],
) -> dict[str, Any]:
    """Return a schema object with ``additionalProperties: false`` enforced.

    Accepts either a JSON string or an already-parsed schema dict; the input
    dict is never mutated (a deep copy is modified instead).
    """
    guide_json_obj = (
        json.loads(guide_json)
        if isinstance(guide_json, str)
        # copy for modifications
        else copy.deepcopy(guide_json)
    )
    _walk_json_for_additional_properties(guide_json_obj)
    return guide_json_obj
|
||||
|
||||
|
||||
@dataclass
class GuidanceBackend(StructuredOutputBackend):
    """StructuredOutputBackend implementation backed by llguidance."""

    def __post_init__(self):
        self.disable_any_whitespace = (
            self.vllm_config.structured_outputs_config.disable_any_whitespace
        )
        self.disable_additional_properties = (
            self.vllm_config.structured_outputs_config.disable_additional_properties
        )

        self.ll_tokenizer = llguidance_hf.from_tokenizer(
            self.tokenizer, self.vocab_size
        )

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Serialize the spec into an llguidance grammar and build a matcher.

        NOTE: this method may run concurrently on the manager's compilation
        thread pool, so the serialized grammar is kept in a local variable.
        The previous pattern of writing ``self.serialized_grammar`` and then
        reading it back could observe another request's grammar when two
        compilations raced on this shared backend instance.
        """
        serialized_grammar = serialize_guidance_grammar(
            request_type,
            grammar_spec,
            self.disable_any_whitespace,
            self.disable_additional_properties,
        )
        # Kept for observability / backward compatibility only (last writer
        # wins); never read this back for correctness.
        self.serialized_grammar = serialized_grammar

        ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )

        r = GuidanceGrammar(
            ll_matcher=ll_matcher,
            ll_tokenizer=self.ll_tokenizer,
            vocab_size=self.vocab_size,
        )

        # Surface any construction-time matcher error in the logs.
        r.check_error()
        return r

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate the shared token bitmask tensor for up to max_num_seqs rows."""
        return llguidance_torch.allocate_token_bitmask(
            max_num_seqs, self.ll_tokenizer.vocab_size
        )

    def destroy(self):
        # No resources to release.
        pass
|
||||
|
||||
|
||||
@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
    """llguidance-backed grammar state for a single request."""

    ll_matcher: llguidance.LLMatcher
    ll_tokenizer: llguidance.LLTokenizer
    vocab_size: int
    # Whether a matcher error has already been logged (log once per grammar).
    printed_error: bool = False
    # True once EOS was seen while the matcher was already stopped.
    terminated: bool = False
    # 1 when `terminated` was set by an EOS the matcher itself never consumed,
    # so rollback() must rewind one token fewer than requested.
    rollback_lag: int = 0

    def check_error(self):
        """Log the matcher's pending error, at most once per grammar."""
        if not self.printed_error:
            err = self.ll_matcher.get_error()
            if err:
                self.printed_error = True
                logger.warning("LLMatcher error: %s", err)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the parser.

        Returns True if the parser was advanced successfully.
        Returns False if the parser failed to advance.
        """

        if self.ll_tokenizer.eos_token in tokens:
            if self.ll_matcher.is_stopped() and not self.terminated:
                # EOS arrives after the matcher already stopped, so it will
                # not be consumed below; record the lag so rollback() stays
                # aligned with the matcher's internal position.
                self.rollback_lag = 1
                self.terminated = True

        if self.ll_matcher.is_stopped():
            # Once stopped, anything (including EOS) is accepted as-is.
            return True

        # TODO - Add jump decoding support in the future:
        # self.ll_matcher.compute_ff_bytes() - this should always work
        # self.ll_matcher.compute_ff_tokens() - this only works for
        #   "canonical" tokenizers
        # For conversion between the two, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md

        r = self.ll_matcher.consume_tokens(tokens)

        self.check_error()

        return r

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the parser in sequence.
        Will not advance the parser.

        Returns the prefix list of tokens that are accepted by the parser.
        """
        if len(tokens) == 0:
            return []
        if self.ll_matcher.is_stopped():
            return []

        num_tokens = self.ll_matcher.validate_tokens(tokens)

        self.check_error()

        return tokens[:num_tokens]

    def rollback(self, num_tokens: int) -> None:
        """Rewind the matcher by ``num_tokens`` (used after spec decode).

        Subtracts ``rollback_lag`` because an EOS counted by the caller may
        never have been consumed by the matcher (see accept_tokens).
        """
        if num_tokens > 0:
            self.ll_matcher.rollback(num_tokens - self.rollback_lag)
            self.terminated = False
            self.rollback_lag = 0
            self.check_error()

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the allowed-token mask for the current state into row ``idx``."""
        # this will automatically return [EOS] mask if the matcher is stopped
        # or otherwise in an error state
        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
        self.check_error()

    def is_terminated(self) -> bool:
        """True once EOS was accepted after the matcher stopped."""
        return self.terminated

    def reset(self):
        """Reset the matcher to its initial state."""
        # This method may be not needed anymore? TODO
        self.ll_matcher.reset()
|
||||
|
||||
|
||||
def serialize_guidance_grammar(
    request_type: StructuredOutputOptions,
    grammar_spec: str | dict[str, Any],
    disable_any_whitespace: bool = False,
    disable_additional_properties: bool = False,
) -> str:
    """Convert a vLLM structured-output spec into an llguidance grammar string.

    Args:
        request_type: which kind of spec ``grammar_spec`` holds (JSON schema,
            regex, grammar, choice list, or structural tag).
        grammar_spec: the spec itself; a JSON string or an already-parsed dict
            for the JSON / structural-tag cases.
        disable_any_whitespace: forbid flexible whitespace in generated JSON.
        disable_additional_properties: force ``additionalProperties: false``
            on every object schema.

    Raises:
        ValueError: for unsupported request types or structural tags whose
            ``begin`` matches no trigger.
    """

    def _process_schema(
        grammar_spec: str | dict[str, Any],
    ) -> str:
        # Optionally pin additionalProperties=false before serializing.
        if disable_additional_properties:
            grammar_spec = process_for_additional_properties(grammar_spec)
        return llguidance.LLMatcher.grammar_from_json_schema(
            grammar_spec,
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            },
        )

    if request_type == StructuredOutputOptions.JSON:
        return _process_schema(grammar_spec)
    elif request_type == StructuredOutputOptions.JSON_OBJECT:
        # "Any JSON object" — a minimal object-typed schema.
        return llguidance.LLMatcher.grammar_from_json_schema(
            '{"type": "object"}',
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            },
        )
    else:
        if request_type == StructuredOutputOptions.REGEX:
            tp = "regex"
        elif request_type == StructuredOutputOptions.GRAMMAR:
            tp = "grammar"
        elif request_type == StructuredOutputOptions.CHOICE:
            tp = "choice"
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            if isinstance(grammar_spec, str):
                s_tag = json.loads(grammar_spec)
            else:
                s_tag = grammar_spec
            triggers: list[str] = s_tag["triggers"]
            tags: list[llguidance.StructTag] = []
            for s in s_tag["structures"]:
                begin: str = s["begin"]
                # Every structure's "begin" must start with one of the
                # declared triggers.
                trig = next((t for t in triggers if begin.startswith(t)), None)
                if trig is None:
                    raise ValueError(
                        f"Trigger {begin} not found in triggers {triggers}"
                    )
                tags.append(
                    llguidance.StructTag(
                        trigger=trig,
                        begin=s["begin"],
                        grammar=_process_schema(s["schema"]),
                        end=s["end"],
                    )
                )
            if not tags:
                raise ValueError("No structural tags found in the grammar spec.")
            return llguidance.StructTag.to_grammar(tags)
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
            )
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})"
            )
        # Reached only for REGEX / GRAMMAR / CHOICE.
        return llguidance.grammar_from(tp, grammar_spec)
|
||||
|
||||
|
||||
def validate_guidance_grammar(
    sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
) -> None:
    """Raise ValueError when llguidance rejects the request's grammar."""
    request_type, spec = get_structured_output_key(sampling_params.structured_outputs)
    serialized = serialize_guidance_grammar(request_type, spec)
    err = llguidance.LLMatcher.validate_grammar(serialized, tokenizer)
    if err:
        raise ValueError(f"Grammar error: {err}")
|
||||
177
vllm/v1/structured_output/backend_lm_format_enforcer.py
Normal file
177
vllm/v1/structured_output/backend_lm_format_enforcer.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import lmformatenforcer
|
||||
import lmformatenforcer.integrations.vllm as lmfe_vllm
|
||||
else:
|
||||
lmformatenforcer = LazyLoader("lmformatenforcer", globals(), "lmformatenforcer")
|
||||
lmfe_vllm = LazyLoader(
|
||||
"lmformatenforcer.integrations.vllm",
|
||||
globals(),
|
||||
"lmformatenforcer.integrations.vllm",
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
    tokenizer: PreTrainedTokenizerBase, vocab_size: int
) -> "lmfe_vllm.TokenEnforcerTokenizerData":
    """Build and memoize the lm-format-enforcer tokenizer data.

    Cached per (tokenizer, vocab_size) because construction is expensive and
    the result is reusable across all grammars for the same tokenizer.
    """
    tokenizer_data = lmfe_vllm.build_vllm_token_enforcer_tokenizer_data(
        tokenizer, use_bitmask=True, vocab_size=vocab_size
    )
    return tokenizer_data
|
||||
|
||||
|
||||
@dataclass
class LMFormatEnforcerGrammar(StructuredOutputGrammar):
    """lm-format-enforcer-backed grammar state for a single request.

    The enforcer itself is queried with an explicit token prefix, so the
    accepted-token history is tracked here in ``current_tokens_prefix``.
    """

    token_enforcer: lmformatenforcer.TokenEnforcer
    current_tokens_prefix: list[int] = field(default_factory=list)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Append ``tokens`` to the prefix if the enforcer allows each one.

        Atomic: on the first disallowed token the prefix is restored to its
        original state and False is returned; otherwise True.
        """
        original_len = len(self.current_tokens_prefix)
        for token in tokens:
            if not self.token_enforcer.get_allowed_tokens(
                self.current_tokens_prefix
            ).is_token_allowed(token):
                # Rollback partial updates to ensure atomicity.
                del self.current_tokens_prefix[original_len:]
                return False
            self.current_tokens_prefix.append(token)
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Return the longest prefix of ``tokens`` the enforcer would accept
        in sequence, without advancing any state."""
        for prefix_length in range(len(tokens)):
            prefix = tokens[:prefix_length]
            next_token = tokens[prefix_length]
            if not self.token_enforcer.get_allowed_tokens(
                self.current_tokens_prefix + prefix
            ).is_token_allowed(next_token):
                break
        else:
            # Every token (or an empty list) was accepted.
            return tokens

        return tokens[:prefix_length]

    def rollback(self, num_tokens: int) -> None:
        """Drop the last ``num_tokens`` accepted tokens.

        BUGFIX: guard num_tokens == 0 — the previous unconditional
        ``prefix[:-num_tokens]`` slice evaluates to ``[:0]`` when
        num_tokens is 0 and would wrongly erase the entire prefix.
        (Matches the guard in GuidanceGrammar.rollback.)
        """
        if num_tokens > 0:
            self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens]

    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        """Write the allowed-token bitmask for the current prefix into row
        ``batch_index`` of ``bitmask``."""
        allowed_tokens = self.token_enforcer.get_allowed_tokens(
            self.current_tokens_prefix
        )
        bitmask[batch_index] = allowed_tokens.allowed_tokens

    def is_terminated(self) -> bool:
        # We are considered terminated if the prefix ends with eos_token_id
        return_value = (
            len(self.current_tokens_prefix) > 0
            and self.current_tokens_prefix[-1] == self.token_enforcer.eos_token_id
        )
        return return_value

    def reset(self):
        """Forget all accepted tokens, returning to the initial state."""
        self.current_tokens_prefix = []
|
||||
|
||||
|
||||
@dataclass
class LMFormatEnforcerBackend(StructuredOutputBackend):
    """StructuredOutputBackend implementation backed by lm-format-enforcer."""

    def __post_init__(self):
        # Tokenizer data is expensive to build; memoized at module level per
        # (tokenizer, vocab_size).
        self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
            self.tokenizer, self.vocab_size
        )

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Build an LMFormatEnforcerGrammar for the given request.

        Raises:
            ValueError: if the request type is unsupported, or if speculative
                decoding is enabled (this backend cannot roll back tokens).
        """
        character_level_parser: lmformatenforcer.CharacterLevelParser
        if request_type == StructuredOutputOptions.JSON:
            spec_dict = json.loads(grammar_spec)
            character_level_parser = lmformatenforcer.JsonSchemaParser(spec_dict)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            # Schema-less parser: any JSON object is acceptable.
            character_level_parser = lmformatenforcer.JsonSchemaParser(None)
        elif request_type == StructuredOutputOptions.REGEX:
            character_level_parser = lmformatenforcer.RegexParser(grammar_spec)
        elif request_type == StructuredOutputOptions.CHOICE:
            # grammar_spec is a Python-literal list of choice strings.
            choices = ast.literal_eval(grammar_spec)
            character_level_parser = lmformatenforcer.UnionParser(
                [lmformatenforcer.StringParser(choice) for choice in choices]
            )
        else:
            raise ValueError(
                f"Invalid request type for LM Format Enforcer backend({request_type!s})"
            )
        max_rollback_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config is not None
            else 0
        )

        if max_rollback_tokens > 0:
            raise ValueError(
                "LM Format Enforcer backend does not support speculative tokens"
            )

        token_enforcer = lmformatenforcer.TokenEnforcer(
            tokenizer_data=self.tokenizer_data,
            parser=character_level_parser,
        )
        return LMFormatEnforcerGrammar(token_enforcer)

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        # One int32 holds 32 vocab bits; -1 (all bits set) means "all tokens
        # allowed".
        return torch.full(
            (max_num_seqs, (self.vocab_size + 31) // 32),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        # No resources to release.
        pass
|
||||
|
||||
|
||||
def validate_structured_output_request_lm_format_enforcer(params: SamplingParams):
    """Validate that the structured-output params are usable with the
    LM Format Enforcer backend, raising ValueError otherwise."""
    if params.structured_outputs is None:
        return

    so_params = params.structured_outputs

    if so_params.regex:
        return
    if so_params.json:
        if isinstance(so_params.json, str):
            try:
                # make sure schema is valid json
                json.loads(so_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            try:
                json.dumps(so_params.json)
            except Exception as e:
                raise ValueError(
                    f"Error serializing structured outputs jsonschema: {e}"
                ) from e
        return
    if so_params.choice:
        return
    if so_params.grammar:
        raise ValueError(
            "LM Format Enforcer structured outputs backend "
            "does not support grammar specifications"
        )
|
||||
324
vllm/v1/structured_output/backend_outlines.py
Normal file
324
vllm/v1/structured_output/backend_outlines.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from regex import escape as regex_escape
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions,
|
||||
)
|
||||
from vllm.v1.structured_output.utils import (
|
||||
OutlinesVocabulary,
|
||||
get_outlines_cache,
|
||||
get_outlines_vocabulary,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import outlines_core as oc
|
||||
import outlines_core.json_schema as json_schema
|
||||
else:
|
||||
oc = LazyLoader("oc", globals(), "outlines_core")
|
||||
json_schema = LazyLoader("json_schema", globals(), "outlines_core.json_schema")
|
||||
|
||||
# Python 3.11+ sre_parse and sre_constants
|
||||
# are deprecated, so we must import them from re
|
||||
if sys.version_info >= (3, 11):
|
||||
# Hack to get around pre-commit regex module rule
|
||||
# because going through re is the only way to get sre_parse
|
||||
# and sre_constants in Python 3.11+
|
||||
_re = importlib.import_module("re")
|
||||
sre_parse = _re._parser
|
||||
sre_constants = _re._constants
|
||||
else:
|
||||
import sre_constants
|
||||
import sre_parse
|
||||
|
||||
|
||||
@dataclass
class OutlinesBackend(StructuredOutputBackend):
    """StructuredOutputBackend implementation backed by outlines_core."""

    def __post_init__(self):
        self.vocabulary = get_outlines_vocabulary(self.tokenizer)
        # Cache of compiled indexes, keyed by vocabulary hash + regex string
        # (see get_outlines_cache for its backing store).
        self.cache = get_outlines_cache()

    def _compile_index(
        self, regex_string: str, vocabulary: OutlinesVocabulary
    ) -> oc.Index:
        """Compile (or fetch from cache) the automaton index for the regex."""
        cache_key = f"{vocabulary._hash}_{regex_string}"
        if cache_key in self.cache:
            return self.cache[cache_key]

        index = oc.Index(regex_string, vocabulary.inner)
        self.cache[cache_key] = index

        return index

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Translate the request spec to a regex and build an OutlinesGrammar.

        Raises:
            ValueError: for request types other than JSON / REGEX / CHOICE.
        """
        if request_type == StructuredOutputOptions.JSON:
            regex = json_schema.build_regex_from_schema(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            regex = grammar_spec
        elif request_type == StructuredOutputOptions.CHOICE:
            # grammar_spec is a Python-literal list of strings; build an
            # alternation over the regex-escaped choices.
            choices = ast.literal_eval(grammar_spec)
            choices = [regex_escape(c) for c in choices]
            regex = "(" + "|".join(choices) + ")"
        else:
            raise ValueError(
                f"Invalid request type for Outlines backend ({request_type!s})"
            )
        index = self._compile_index(regex, self.vocabulary)
        # Allow rolling back up to the per-step speculative token count.
        max_rollback_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config is not None
            else 0
        )
        return OutlinesGrammar(
            vocab_size=self.vocab_size,
            guide=oc.Guide(index, max_rollback=max_rollback_tokens),
        )

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        # One int32 holds 32 vocab bits; -1 (all bits set) means "all tokens
        # allowed".
        return torch.full(
            (max_num_seqs, (self.vocab_size + 31) // 32),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        # No resources to release.
        pass
|
||||
|
||||
|
||||
@dataclass
class OutlinesGrammar(StructuredOutputGrammar):
    """FSM state for a single request, backed by an outlines_core Guide."""

    vocab_size: int
    guide: oc.Guide = field(hash=False)
    num_processed_tokens: int = field(
        default_factory=lambda: 0, repr=False, hash=False, init=False
    )

    # outlines_core signals done on DFA accept; vLLM expects done after EOS.
    # We delay the finished flag by one step so EOS can still be emitted.
    _prev_finished: bool = field(default=False, init=False, repr=False, hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the FSM failed to advance.
        """
        if not self.guide.accepts_tokens(tokens):
            return False
        # The whole sequence was vetted above, so each individual
        # advance() below cannot fail.
        for tok in tokens:
            self.guide.advance(tok)
            self.num_processed_tokens += 1
        return True

    def rollback(self, num_tokens: int) -> None:
        """Rewind the guide and the processed-token counter."""
        self.guide.rollback_state(num_tokens)
        self.num_processed_tokens -= num_tokens

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Return the longest accepted prefix of ``tokens``; the FSM is
        never advanced."""
        prefix: list[int] = []
        for tok in tokens:
            if not self.guide.accepts_tokens(prefix + [tok]):
                break
            prefix.append(tok)
        return prefix

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the allowed-token mask for the next step into row ``idx``."""
        row = bitmask[idx]
        self.guide.write_mask_into(row.data_ptr(), row.numel(), row.element_size())

    def is_terminated(self) -> bool:
        """Report the previous step's finished state (one-step delay so the
        EOS token can still be emitted; see the class comment)."""
        now_finished = self.guide.is_finished()
        was_finished, self._prev_finished = self._prev_finished, now_finished
        return was_finished

    def reset(self):
        """Return the grammar to its initial state for reuse."""
        self.num_processed_tokens = 0
        self._prev_finished = False
        self.guide.reset()
|
||||
|
||||
|
||||
def validate_structured_output_request_outlines(params: SamplingParams):
    """Reject structured-output params that the Outlines backend cannot
    serve; raises ValueError on invalid or unsupported specs."""
    if params.structured_outputs is None:
        return

    so_params = params.structured_outputs

    if so_params.regex:
        validate_regex_is_buildable(so_params.regex)
        return
    if so_params.json:
        if isinstance(so_params.json, str):
            # Already a string: just verify it parses as JSON.
            try:
                json.loads(so_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
            schema = so_params.json
        else:
            try:
                schema = json.dumps(so_params.json)
            except Exception as e:
                raise ValueError(
                    f"Error serializing structured outputs jsonschema: {e}"
                ) from e
        validate_regex_is_buildable(json_schema.build_regex_from_schema(schema))
        return
    if so_params.choice:
        escaped = [regex_escape(str(choice)) for choice in so_params.choice]
        validate_regex_is_buildable("(" + "|".join(escaped) + ")")
        return
    if so_params.grammar:
        raise ValueError(
            "Outlines structured outputs backend "
            "does not support grammar specifications"
        )
|
||||
|
||||
|
||||
def _prefix_needs_context(parsed) -> bool:
    """Return True if there's a look-around/anchor before any consumer.

    ``parsed`` is a sre_parse parse tree (SubPattern or raw token list).
    A pattern whose prefix contains an anchor or look-around before any
    character-consuming token lacks an anchored universal start state
    (see validate_regex_is_buildable).
    """

    def subpattern_consumes(parsed) -> bool:
        """Return True if subpattern can consume at least one character."""
        # SubPattern objects carry the token list in .data; branches are
        # passed in as plain lists.
        tokens = parsed.data if hasattr(parsed, "data") else parsed
        for ttype, tval in tokens:
            # literal, character class, or dot always consumes
            if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
                return True
            # quantified subpattern: check inner pattern
            elif ttype == sre_parse.MAX_REPEAT:
                # tval is (min, max, subpattern); a zero-max repeat can
                # never consume anything.
                _, mx, sub = tval
                if mx != 0 and subpattern_consumes(sub):
                    return True
            # alternation: if any branch consumes, the whole does
            elif ttype == sre_parse.BRANCH:
                _, branches = tval
                if any(subpattern_consumes(br) for br in branches):
                    return True
            # grouped subpattern: recurse into its contents
            # (tval is (group, add_flags, del_flags, subpattern))
            elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(tval[3]):
                return True
        # No consumers, return False
        return False

    tokens = parsed.data if hasattr(parsed, "data") else parsed
    for ttype, tval in tokens:
        # Direct anchors or look-around
        if ttype == sre_parse.AT or ttype in (
            sre_constants.ASSERT,
            sre_constants.ASSERT_NOT,
        ):
            return True

        # Nested subpattern: check
        if ttype == sre_parse.SUBPATTERN:
            # tval: (group, add_flags, del_flags, subpattern)
            if _prefix_needs_context(tval[3]):
                return True
            # Once the group is known to consume a character the prefix is
            # over — nothing later can be a *prefix* anchor.
            if subpattern_consumes(tval[3]):
                return False

        # if any branch has a prefix anchor => True,
        # else if at least one branch consumes => prefix ends => False
        elif ttype == sre_parse.BRANCH:
            saw_consumer = False
            for br in tval[1]:
                if _prefix_needs_context(br):
                    return True
                if subpattern_consumes(br):
                    saw_consumer = True
            if saw_consumer:
                return False

        # Immediate consumer tokens
        elif ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
            return False

        # if subpattern has anchor => True, if it can consume => stop
        elif ttype == sre_parse.MAX_REPEAT:
            # tval is (min, max, subpattern)
            if _prefix_needs_context(tval[2]):
                return True
            if subpattern_consumes(tval[2]):
                return False

    return False
|
||||
|
||||
|
||||
def _check_unsupported(parsed) -> None:
|
||||
"""Check for regex features unsupported by regex-automata"""
|
||||
tokens = parsed.data if hasattr(parsed, "data") else parsed
|
||||
for ttype, tval in tokens:
|
||||
# backreference
|
||||
if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS):
|
||||
raise ValueError("Backreferences are unsupported.")
|
||||
|
||||
# look-around assertion
|
||||
elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT):
|
||||
raise ValueError("Look-Around assertion are unsupported.")
|
||||
|
||||
# unicode word boundaries
|
||||
elif ttype == sre_parse.AT:
|
||||
if tval in (sre_constants.AT_BOUNDARY, sre_constants.AT_NON_BOUNDARY):
|
||||
raise ValueError("Unicode word boundaries are unsupported.")
|
||||
|
||||
elif ttype == sre_parse.BRANCH:
|
||||
# tval is (None, branches)
|
||||
for branch in tval[1]:
|
||||
_check_unsupported(branch)
|
||||
|
||||
# tval is (min, max, subpattern)
|
||||
elif ttype == sre_parse.MAX_REPEAT:
|
||||
_check_unsupported(tval[2])
|
||||
|
||||
|
||||
def validate_regex_is_buildable(pattern: str) -> None:
    """
    Validates that the input regex is not using unsupported features
    of the `regex-automata` crate (outlines_core regex engine) and has a
    universal start state.
    definition of universal start state used can be found at:
    https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state

    Raises ValueError when the pattern cannot be parsed, uses an
    unsupported feature, or needs context before any token is matched.
    """
    try:
        parsed = sre_parse.parse(pattern)

    except sre_constants.error as e:
        raise ValueError(f"Error parsing regex: {e}") from e

    try:
        _check_unsupported(parsed)
    except ValueError as e:
        raise ValueError(
            f"Regex uses unsupported feature for structured outputs: {e}. "
            "Only basic matching constructs are supported—lookarounds, "
            "backreferences, and unicode boundaries are not."
        ) from e

    if _prefix_needs_context(parsed):
        # BUGFIX: the original message concatenated fragments without
        # separators ("...start stateThis means..." /
        # "...matched.structured outputs...") and read "a anchored".
        raise ValueError(
            "Regex does not have an anchored universal start state. "
            "This means that the Regex uses anchors (^) or look-arounds "
            "in a way which requires context before any token is matched. "
            "Structured outputs needs regexes that can match without needing "
            "that context. Try rewriting the pattern without using these "
            f"constructs. Pattern:\n{pattern}"
        )
|
||||
136
vllm/v1/structured_output/backend_types.py
Normal file
136
vllm/v1/structured_output/backend_types.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
else:
|
||||
VllmConfig = object
|
||||
TokenizerLike = object
|
||||
|
||||
|
||||
class StructuredOutputOptions(enum.Enum):
    """Kinds of structured-output constraint a request may carry."""

    JSON = enum.auto()  # JSON conforming to a user-provided JSON schema
    JSON_OBJECT = enum.auto()  # any JSON object, no specific schema
    REGEX = enum.auto()  # output must match a regular expression
    GRAMMAR = enum.auto()  # output must match a user-provided grammar
    CHOICE = enum.auto()  # output must be one of a fixed set of strings
    STRUCTURAL_TAG = enum.auto()  # structural-tag spec (JSON-encoded)


# (constraint kind, canonical spec string) — uniquely identifies a
# structured-output request, e.g. for grammar caching/deduplication.
StructuredOutputKey = tuple[StructuredOutputOptions, str]
|
||||
|
||||
|
||||
class StructuredOutputGrammar(ABC):
    """Per-request grammar interface for structured output requests.

    Implementations wrap a compiled grammar plus the mutable FSM state
    for a single request.
    """

    @abstractmethod
    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """
        Determines whether the provided tokens are accepted for the
        given request, advancing the FSM on success.

        Args:
            request_id (str): The unique identifier for the request.
            tokens (list[int]): A list of token IDs to evaluate.

        Returns:
            bool: True if the tokens are accepted, False otherwise.
        """

    @abstractmethod
    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """
        Validates the provided tokens against the grammar.
        Will not advance the FSM.

        Args:
            tokens (list[int]): A list of token IDs to validate.

        Returns:
            list[int]: A list of accepted token IDs. Will be a prefix
                of the input tokens, and empty if none are accepted.
        """

    @abstractmethod
    def rollback(self, num_tokens: int) -> None:
        """
        Rolls back the state of the grammar by a specified number of tokens.
        Will also revert counters for the number of processed tokens.
        Used e.g. when speculative draft tokens are rejected.

        Args:
            num_tokens (int): The number of tokens to roll back.
        """

    @abstractmethod
    def fill_bitmask(self, bitmask: "torch.Tensor", batch_index: int) -> None:
        """
        Fills the bitmask for a specific batch index with the tokens
        allowed at the current FSM state.

        Args:
            bitmask (torch.Tensor): The bitmask to fill
            batch_index (int): The index in the bitmask to fill
        """

    @abstractmethod
    def is_terminated(self) -> bool:
        """
        Checks whether the structured output process has terminated.

        Returns:
            bool: True if the process is terminated, False otherwise.
        """

    @abstractmethod
    def reset(self):
        """
        Resets the state of the structured output grammar so the
        instance can be reused.
        """
|
||||
|
||||
|
||||
@dataclass
class StructuredOutputBackend(ABC):
    """Engine-level backend for structured output requests."""

    # Full engine configuration; concrete backends read e.g. the
    # speculative-decoding and structured-outputs sections.
    vllm_config: VllmConfig
    # Tokenizer whose vocabulary the compiled grammars operate over.
    tokenizer: TokenizerLike
    # Size of the logits vocabulary dimension (may differ from the
    # tokenizer's own vocab size — see concrete backends).
    vocab_size: int

    @abstractmethod
    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """
        Compiles a grammar specification into a structured output grammar.

        Args:
            request_type (StructuredOutputOptions): The type of structured
                output request.
            grammar_spec (str): The grammar specification to compile.

        Returns:
            StructuredOutputGrammar: The compiled structured output grammar.
        """

    @abstractmethod
    def allocate_token_bitmask(self, max_num_seqs: int) -> "torch.Tensor":
        """
        Allocates a token bitmask for the specified maximum number of
        sequences (one row per sequence, 32 vocab entries per int32).

        Args:
            max_num_seqs (int): The maximum number of sequences for which
                to allocate the bitmask.
        """

    @abstractmethod
    def destroy(self):
        """
        Backend-specific cleanup.
        """
|
||||
378
vllm/v1/structured_output/backend_xgrammar.py
Normal file
378
vllm/v1/structured_output/backend_xgrammar.py
Normal file
@@ -0,0 +1,378 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions,
|
||||
)
|
||||
from vllm.v1.structured_output.utils import (
|
||||
choice_as_grammar,
|
||||
convert_lark_to_ebnf,
|
||||
grammar_is_likely_lark,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class XgrammarBackend(StructuredOutputBackend):
    """Engine-level structured output backend built on xgrammar."""

    def __post_init__(self):
        # Whether compiled JSON grammars must forbid arbitrary whitespace
        # between tokens (engine-wide setting).
        self.disable_any_whitespace = (
            self.vllm_config.structured_outputs_config.disable_any_whitespace
        )

        if isinstance(self.tokenizer, MistralTokenizer):
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            stop_token_ids = [self.tokenizer.eos_token_id]

            # not self.tokenizer.vocab_size as self.tokenizer.vocab
            # collapses all decoded errors into a single token.
            self.vocab_size = len(self.tokenizer.vocab)
            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                encoded_vocab=self.tokenizer.vocab,
                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
                vocab_type=xgr.VocabType.RAW
                if self.tokenizer.is_tekken
                else xgr.VocabType.BYTE_FALLBACK,
                vocab_size=self.vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )
        elif isinstance(self.tokenizer, DeepseekV32Tokenizer):
            # copy from xgr.TokenizerInfo.from_huggingface()
            # because we are using a custom tokenizer wrapper here.
            vocab_dict = self.tokenizer.get_vocab()
            tokenizer_vocab_size = max(len(vocab_dict), self.tokenizer.max_token_id + 1)
            vocab_size = self.vocab_size or tokenizer_vocab_size
            # maintain tokenizer's indexing
            encoded_vocab = [""] * vocab_size
            for token, idx in vocab_dict.items():
                if idx < vocab_size:
                    encoded_vocab[idx] = token
            stop_token_ids = [self.tokenizer.eos_token_id]
            backend_str = self.tokenizer.tokenizer.backend_tokenizer.to_str()
            # NOTE(review): relies on the private xgrammar helper
            # _detect_metadata_from_hf — may break across xgrammar versions.
            metadata = xgr.TokenizerInfo._detect_metadata_from_hf(backend_str)
            tokenizer_info = xgr.TokenizerInfo(
                encoded_vocab=encoded_vocab,
                vocab_type=metadata["vocab_type"],
                vocab_size=vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=metadata["add_prefix_space"],
            )
        else:
            # Standard HF tokenizers: let xgrammar derive the metadata.
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                self.tokenizer,
                vocab_size=self.vocab_size,
            )
        self.compiler = xgr.GrammarCompiler(
            tokenizer_info,
            max_threads=8,
            cache_enabled=True,
            cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
        )

        # Matchers must be able to roll back up to the number of draft
        # tokens produced per step by speculative decoding.
        self.num_speculative_tokens = 0
        if self.vllm_config.speculative_config is not None:
            self.num_speculative_tokens = (
                self.vllm_config.speculative_config.num_speculative_tokens
            )

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Compile ``grammar_spec`` into an XgrammarGrammar.

        Raises ValueError for request types that should have been
        rejected during validation.
        """
        if request_type == StructuredOutputOptions.JSON:
            ctx = self.compiler.compile_json_schema(
                grammar_spec, any_whitespace=not self.disable_any_whitespace
            )
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            # Any JSON object is acceptable; no user schema is involved.
            ctx = self.compiler.compile_json_schema(
                '{"type": "object"}', any_whitespace=not self.disable_any_whitespace
            )
        elif request_type == StructuredOutputOptions.GRAMMAR:
            ctx = self.compiler.compile_grammar(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            ctx = self.compiler.compile_regex(grammar_spec)
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            # grammar_spec is a JSON-encoded structural-tag document.
            s_tag = json.loads(grammar_spec)
            if "structures" in s_tag:
                # Falling back to deprecated method of compiling structural tag
                tags = [
                    xgr.StructuralTagItem(
                        begin=s["begin"],
                        schema=json.dumps(s["schema"]),
                        end=s["end"],
                    )
                    for s in s_tag["structures"]
                ]
                ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"])
            else:
                ctx = self.compiler.compile_structural_tag(grammar_spec)
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
            )
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})"
            )

        return XgrammarGrammar(
            matcher=xgr.GrammarMatcher(
                ctx,
                max_rollback_tokens=self.num_speculative_tokens,
            ),
            vocab_size=self.vocab_size,
            ctx=ctx,
        )

    def allocate_token_bitmask(self, max_num_seqs: int):
        # Delegate bitmask allocation/shape to xgrammar.
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)

    def destroy(self):
        # Drop the compiler so its cache and resources can be released.
        del self.compiler
|
||||
|
||||
|
||||
@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
    """Per-request FSM state backed by an xgrammar GrammarMatcher."""

    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    num_processed_tokens: int = field(
        default_factory=lambda: 0, repr=False, hash=False, init=False
    )
    # Cached result of matcher.is_terminated(); kept in sync by
    # accept_tokens(), rollback(), and reset().
    _is_terminated: bool = field(default=False, repr=False, hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the FSM failed to advance.
        """
        if self._is_terminated:
            # A terminated grammar cannot accept further tokens.
            return False
        for token in tokens:
            if not self.matcher.accept_token(token):
                logger.error(
                    "Failed to advance FSM for request %s "
                    "for token %s. Please file an issue.",
                    request_id,
                    token,
                )
                return False
            self.num_processed_tokens += 1
            self._is_terminated = self.matcher.is_terminated()
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the FSM in sequence.
        Will not advance the FSM.

        Returns the prefix list of tokens that are accepted by the FSM.
        """
        accepted_tokens = []
        for token in tokens:
            if self.matcher.accept_token(token):
                accepted_tokens.append(token)
            else:
                break
        if len(accepted_tokens) > 0:
            # Rollback the FSM to the initial state
            self.matcher.rollback(len(accepted_tokens))
        return accepted_tokens

    def rollback(self, num_tokens: int) -> None:
        """Rewind the matcher and resynchronize the cached terminal flag."""
        self.matcher.rollback(num_tokens)
        self.num_processed_tokens -= num_tokens
        self._is_terminated = self.matcher.is_terminated()

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the allowed-token mask for the next step into row ``idx``."""
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        return self._is_terminated

    def reset(self):
        """Return the grammar to its initial (non-terminated) state."""
        self.num_processed_tokens = 0
        # BUGFIX: clear the cached terminal flag. Previously a grammar
        # that had terminated kept rejecting all tokens after reset()
        # because _is_terminated was never cleared even though the
        # underlying matcher was reset (cf. OutlinesGrammar.reset, which
        # clears its analogous _prev_finished flag).
        self._is_terminated = False
        self.matcher.reset()
|
||||
|
||||
|
||||
# cf https://github.com/mlc-ai/xgrammar/blob/a32ac892676d2eedc0327416105b9b06edfb94b2/cpp/json_schema_converter.cc
STRING_SUPPORTED_FORMATS = {
    "email",
    "date",
    "time",
    "date-time",
    "duration",
    "ipv4",
    "ipv6",
    "hostname",
    "uuid",
    "uri",
    "uri-reference",
    "uri-template",
    "json-pointer",
    "relative-json-pointer",
}


def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar."""

    def check_object(obj: dict[str, Any]) -> bool:
        if not isinstance(obj, dict):
            return False

        obj_type = obj.get("type")

        # Numeric ranges: multipleOf is not supported.
        if obj_type in ("integer", "number") and "multipleOf" in obj:
            return True

        # Array keywords xgrammar cannot express.
        unsupported_array_keys = (
            "uniqueItems",
            "contains",
            "minContains",
            "maxContains",
        )
        if obj_type == "array" and any(k in obj for k in unsupported_array_keys):
            return True

        # String formats outside the converter's supported set.
        if (
            obj_type == "string"
            and "format" in obj
            and obj["format"] not in STRING_SUPPORTED_FORMATS
        ):
            return True

        # Object keywords xgrammar cannot express.
        if obj_type == "object" and any(
            k in obj for k in ("patternProperties", "propertyNames")
        ):
            return True

        # Recurse into nested schemas (dict values and dicts inside lists).
        children: list[dict[str, Any]] = []
        for value in obj.values():
            if isinstance(value, dict):
                children.append(value)
            elif isinstance(value, list):
                children.extend(item for item in value if isinstance(item, dict))
        return any(check_object(child) for child in children)

    return check_object(schema)
|
||||
|
||||
|
||||
def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    May normalize the request in place: a `choice` constraint is converted
    into an equivalent EBNF `grammar`, and a Lark `grammar` is converted
    to EBNF.

    Raises ValueError if the request is not supported.
    """
    if sampling_params.structured_outputs is None:
        return

    so_params = sampling_params.structured_outputs

    if so_params.regex:
        try:
            xgr.Grammar.from_regex(so_params.regex)
        except Exception as err:
            raise ValueError(
                f"Failed to transform regex into a grammar: {err}"
            ) from err

    if so_params.choice:
        choice_grammar = choice_as_grammar(so_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            # BUGFIX: this message was missing its f-prefix, so the literal
            # text "{err}" was raised instead of the underlying error.
            raise ValueError(
                f"Failed to transform choices into a grammar: {err}"
            ) from err
        # Hand the equivalent grammar to the backend instead of the choices.
        so_params.choice = None
        so_params.grammar = choice_grammar
        return

    if so_params.json:
        if isinstance(so_params.json, str):
            try:
                schema = json.loads(so_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = so_params.json

        try:
            xgr.Grammar.from_json_schema(schema)
        except Exception as err:
            raise ValueError(
                f"Failed to transform json schema into a grammar: {err}"
            ) from err

        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError(
                "The provided JSON schema contains features not supported by xgrammar."
            )
        return

    if so_params.grammar:
        if grammar_is_likely_lark(so_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                so_params.grammar = convert_lark_to_ebnf(so_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF. "
                ) from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(so_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
        return

    if so_params.structural_tag:
        try:
            s_tag = json.loads(so_params.structural_tag)

            # Using the deprecated method of compiling structural tag
            if "structures" in s_tag:
                tags = [
                    xgr.StructuralTagItem(
                        begin=s["begin"],
                        schema=json.dumps(s["schema"]),
                        end=s["end"],
                    )
                    for s in s_tag["structures"]
                ]
                xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
            else:
                xgr.Grammar.from_structural_tag(so_params.structural_tag)
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e
|
||||
94
vllm/v1/structured_output/request.py
Normal file
94
vllm/v1/structured_output/request.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import functools
|
||||
import json
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures._base import TimeoutError
|
||||
from typing import cast
|
||||
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputKey,
|
||||
StructuredOutputOptions,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class StructuredOutputRequest:
    """Tracks the (possibly still-compiling) grammar for one request."""

    # Structured-output constraints taken from SamplingParams.
    params: StructuredOutputsParams
    # Compiled grammar, a Future resolving to one (compilation is
    # asynchronous), or None before a grammar is assigned.
    _grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None
    # NOTE(review): presumably set by the reasoning parser once the
    # reasoning section ends; never written in this file — confirm.
    reasoning_ended: bool | None = None

    @staticmethod
    def from_sampling_params(
        sampling_params: SamplingParams | None,
    ) -> "StructuredOutputRequest | None":
        """Build a request wrapper, or None if no constraints are set."""
        if sampling_params is None:
            return None
        params = sampling_params.structured_outputs
        if params:
            if params.all_constraints_none():
                return None
            else:
                return StructuredOutputRequest(params=params)
        return None

    def _check_grammar_completion(self) -> bool:
        """Resolve the grammar Future if it finished; True when resolved."""
        # NOTE: We have to lazy import to gate circular imports
        from vllm.v1.request import RequestStatus

        if isinstance(self._grammar, Future):
            try:
                # We will check whether the future is ready within 100 us
                self._grammar = self._grammar.result(timeout=0.0001)
                # NOTE(review): `status` is not a declared field of this
                # dataclass; presumably read elsewhere once the FSM is
                # compiled — confirm against the scheduler.
                self.status = RequestStatus.WAITING
            except TimeoutError:
                return False
        return True

    @property
    def is_grammar_ready(self) -> bool:
        # Checking readiness also resolves a finished Future in place.
        return self._check_grammar_completion()

    @property
    def grammar(self) -> StructuredOutputGrammar | None:
        """The compiled grammar, or None while compilation is pending."""
        completed = self._check_grammar_completion()
        return (
            cast(StructuredOutputGrammar | None, self._grammar) if completed else None
        )

    @grammar.setter
    def grammar(
        self, grammar: StructuredOutputGrammar | Future[StructuredOutputGrammar]
    ) -> None:
        self._grammar = grammar

    @functools.cached_property
    def structured_output_key(self) -> StructuredOutputKey:
        # Cached: params are immutable for the lifetime of the request.
        return get_structured_output_key(self.params)
|
||||
|
||||
|
||||
def get_structured_output_key(params: StructuredOutputsParams) -> StructuredOutputKey:
    """Map ``params`` to a (constraint kind, canonical spec string) key.

    Precedence mirrors the order fields are checked: json, json_object,
    regex, choice, grammar, structural_tag. Raises ValueError when no
    constraint is set.
    """
    if params.json is not None:
        # Canonicalize dict/list schemas to a JSON string.
        spec = params.json if isinstance(params.json, str) else json.dumps(params.json)
        return StructuredOutputOptions.JSON, spec
    if params.json_object:
        return StructuredOutputOptions.JSON_OBJECT, ""
    if params.regex is not None:
        return StructuredOutputOptions.REGEX, params.regex
    if params.choice is not None:
        spec = (
            params.choice
            if isinstance(params.choice, str)
            else json.dumps(params.choice)
        )
        return StructuredOutputOptions.CHOICE, spec
    if params.grammar is not None:
        return StructuredOutputOptions.GRAMMAR, params.grammar
    if params.structural_tag is not None:
        return StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag
    raise ValueError("No valid structured output parameter found")
|
||||
469
vllm/v1/structured_output/utils.py
Normal file
469
vllm/v1/structured_output/utils.py
Normal file
@@ -0,0 +1,469 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import importlib.metadata
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import torch
|
||||
from cachetools import LRUCache
|
||||
from diskcache import Cache
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import outlines_core as oc
|
||||
import transformers.file_utils as file_utils
|
||||
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
|
||||
import xgrammar as xgr
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
oc = LazyLoader("oc", globals(), "outlines_core")
|
||||
file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
|
||||
tokenization_gpt2 = LazyLoader(
|
||||
"tokenization_gpt2",
|
||||
globals(),
|
||||
"transformers.models.gpt2.tokenization_gpt2",
|
||||
)
|
||||
|
||||
TokenizerLike = object
|
||||
SchedulerOutput = object
|
||||
InputBatch = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CACHE = None
|
||||
|
||||
|
||||
def apply_grammar_bitmask(
    scheduler_output: SchedulerOutput,
    grammar_output: GrammarOutput,
    input_batch: InputBatch,
    logits: torch.Tensor,
) -> None:
    """
    Apply grammar bitmask to output logits of the model with xgrammar function.

    Args:
        scheduler_output (SchedulerOutput): The result of engine scheduling.
        grammar_output (GrammarOutput): Per-step structured-output state: the
            compacted grammar bitmask (np.ndarray) and the ids of the
            structured-output requests it covers.
        input_batch (InputBatch): The input of model runner.
        logits (torch.Tensor): The output logits of model forward.
            Modified in place.
    """
    # Serialization of np.ndarray is much more efficient than a tensor,
    # so we receive it in that format.
    grammar_bitmask = grammar_output.grammar_bitmask

    # We receive the structured output bitmask from the scheduler,
    # compacted to contain bitmasks only for structured output requests.
    # The order of the requests in the bitmask is not guaranteed to be the
    # same as the order of the requests in the gpu runner's batch. We need
    # to sort the bitmask to match the order of the requests used here.

    # Get the batch indices of the structured output requests.
    # Keep track of the number of speculative tokens scheduled for every
    # request in the batch, as the logit indices are offset by this amount.
    struct_out_req_batch_indices: dict[str, int] = {}
    cumulative_offset = 0
    # Walk requests in batch order (req_id_to_index maps id -> batch slot).
    seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1])
    for req_id, batch_index in seq:
        logit_index = batch_index + cumulative_offset
        # Each speculative token scheduled for an earlier request shifts the
        # logit rows of all later requests by one.
        cumulative_offset += len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
        )
        if req_id in grammar_output.structured_output_request_ids:
            struct_out_req_batch_indices[req_id] = logit_index

    out_indices = []

    # Reorder the bitmask to match the order of the requests in the batch.
    # Rows default to -1 (all bits set), i.e. "every token allowed", for
    # logit rows that belong to non-structured requests.
    sorted_bitmask = np.full(
        shape=(logits.shape[0], grammar_bitmask.shape[1]),
        fill_value=-1,
        dtype=grammar_bitmask.dtype,
    )
    cumulative_index = 0
    for req_id in grammar_output.structured_output_request_ids:
        num_spec_tokens = len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
        )
        if req_id in struct_out_req_batch_indices:
            logit_index = struct_out_req_batch_indices[req_id]
            # One bitmask row per token: the base token plus each
            # speculative token of this request.
            for i in range(1 + num_spec_tokens):
                sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
                out_indices.append(logit_index + i)
        # The compacted bitmask always carries rows for this request, even if
        # it is absent from this runner's batch, so advance unconditionally.
        cumulative_index += 1 + num_spec_tokens

    # Copy async to device as tensor.
    grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
        logits.device, non_blocking=True
    )

    # If the length of out indices and the logits have the same shape
    # we don't need to pass indices to the kernel,
    # since the bitmask is already aligned with the logits.
    skip_out_indices = len(out_indices) == logits.shape[0]

    index_tensor = None
    if not skip_out_indices:
        # xgrammar expects a python list of indices but it will actually work with
        # a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
        # manner and there should be no cpu sync within xgrammar.
        index_tensor = torch.tensor(
            out_indices, dtype=torch.int32, device="cpu", pin_memory=True
        )
        index_tensor = index_tensor.to(logits.device, non_blocking=True)

    xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
|
||||
|
||||
|
||||
class OutlinesVocabulary:
    """Wrap an `outlines_core.Vocabulary` together with a stable hash.

    The hash is derived from a SHA-256 digest of the vocabulary's repr,
    giving a deterministic, non-negative integer that is safe to use as a
    cache key (Python's builtin ``hash`` is salted per process and can be
    negative).
    """

    def __init__(self, vocabulary: oc.Vocabulary) -> None:
        # Keep a reference to the wrapped vocabulary object.
        self.inner = vocabulary
        # SHA-256 of repr(vocabulary), interpreted as an integer cache key.
        digest = hashlib.sha256(repr(vocabulary).encode("utf-8")).hexdigest()
        self._hash = int(digest, 16)
|
||||
|
||||
|
||||
def get_outlines_cache_path() -> str:
    """Return the directory used to persist outlines' compiled-index cache.

    Resolution order: the ``OUTLINES_CACHE_DIR`` environment variable, then
    ``XDG_CACHE_HOME``, then ``~/.cache/outlines``, finally falling back to
    the system temp directory when no usable home exists (e.g. containers
    where ``~`` expands to ``/``).
    """
    explicit_dir = os.getenv("OUTLINES_CACHE_DIR")
    if explicit_dir:
        # An explicit override always takes precedence.
        return explicit_dir

    xdg_cache_home = os.getenv("XDG_CACHE_HOME")
    if xdg_cache_home:
        # NOTE(review): per the XDG Base Directory spec this arguably should
        # be ``$XDG_CACHE_HOME/outlines`` without the extra ".cache" segment;
        # kept as-is to preserve the existing cache location.
        return os.path.join(xdg_cache_home, ".cache", "outlines")

    home_dir = os.path.expanduser("~")
    if os.path.isdir(home_dir) and home_dir != "/":
        # Default Unix fallback. expanduser does not guarantee the path
        # exists, and "/" usually means a container without a real user, so
        # both conditions are checked before writing under the home dir.
        return os.path.join(home_dir, ".cache", "outlines")

    import tempfile

    # Last resort: the system temp directory is always writable.
    return os.path.join(tempfile.gettempdir(), ".cache", "outlines")
|
||||
|
||||
|
||||
def get_outlines_cache():
    """Return the cache instance used for outlines index caching.

    When ``VLLM_V1_USE_OUTLINES_CACHE`` is enabled, this is an unbounded
    on-disk ``diskcache.Cache`` that is cleared whenever the installed
    ``outlines_core`` version changes; otherwise a bounded in-memory LRU
    cache is used.
    """
    if not envs.VLLM_V1_USE_OUTLINES_CACHE:
        # Default: in-memory, bounded, safe for untrusted clients.
        return LRUCache(maxsize=128)

    logger.warning(
        "Enabling outlines cache. This is an unbounded on-disk "
        "cache. It may consume a lot of disk space and should "
        "not be used with untrusted clients."
    )
    disk_cache = Cache(
        get_outlines_cache_path(), eviction_policy="none", cull_limit=0
    )
    # Serialized indexes may be incompatible across outlines_core versions,
    # so invalidate the whole cache on a version change.
    current_version = importlib.metadata.version("outlines_core")
    if disk_cache.get("__version__", None) != current_version:
        disk_cache.clear()
    disk_cache.set("__version__", current_version)
    return disk_cache
|
||||
|
||||
|
||||
# Matches Llama-style byte-fallback tokens such as "<0x0A>" (uppercase hex).
re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
# Matches a run of U+FFFD replacement characters with up to 6 chars of
# context on either side, i.e. a decoded token that is purely an
# invalid-UTF-8 placeholder sequence. The original literal replacement
# character was mojibake-garbled in transit; \ufffd restores it explicitly.
re_replacement_seq = re.compile(r"^.{0,6}\ufffd+.{0,6}$")
|
||||
|
||||
|
||||
def _reduced_vocabulary(
    tokenizer: TokenizerLike,
    eos_token_id: int,
) -> dict[bytes, list[int]]:
    """Create a map from vocabulary tokens to lists of equivalent token ids.

    Special tokens are skipped, and the EOS token id is excluded from the
    mapping (it is handled separately by the outlines Vocabulary).

    Returns:
        A Dict of token string -> equivalent token ids
    """

    # Inverse of the GPT-2 byte<->unicode table: printable char -> raw byte.
    unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()}

    def convert_token_to_string(token: str) -> str:
        string = tokenizer.convert_tokens_to_string([token])

        # A hack to handle missing spaces to HF's Llama tokenizers
        # NOTE: precedence is (type(...) and startswith(...)) or token == "<0x20>",
        # i.e. the "<0x20>" check applies regardless of the token's type.
        if (
            type(token) is str
            and token.startswith(file_utils.SPIECE_UNDERLINE)
            or token == "<0x20>"
        ):
            return " " + string

        return string

    vocabulary: dict[bytes, list[int]] = {}
    # Ids whose decoded string is empty; accumulated but not returned here
    # (presumably kept for parity with upstream outlines — TODO confirm).
    empty_token_ids: list[int] = []
    for token, token_idx in tokenizer.get_vocab().items():
        if token in tokenizer.all_special_tokens:
            continue

        token_str = convert_token_to_string(token)
        if token_str:
            if isinstance(token, (bytes, bytearray)):
                # For BPE tokenizers where tokens are stored as bytes.

                # safe to ignore since token_str is of type (bytearray, bytes)
                # by this point.
                token_bytes = bytes(token_str)  # type: ignore[arg-type]

            elif "\ufffd" in token_str and not re_replacement_seq.match(token_str):
                # Handle tokens with invalid UTF-8 sequences.
                if re_llama_byte_token.match(token):
                    # Llama-like tokenizers use <0xXX> for incomplete sequences.
                    token_bytes = bytes([int(token[3:5], 16)])
                else:
                    # GPT2 tokenizers: map each byte back using unicode_to_bytes
                    byte_vals = [unicode_to_bytes.get(c) for c in token]
                    if None in byte_vals:
                        raise RuntimeError(
                            f"Cannot convert token `{token}`"
                            f" ({token_idx}) to bytes: {token_str}"
                        )
                    # safe to ignore, since if None in byte_vals,
                    # an error is thrown.
                    token_bytes = bytes(byte_vals)  # type: ignore[arg-type]
            else:
                # Common case: a valid UTF-8 string token.
                token_bytes = token_str.encode("utf-8")

            if token_idx != eos_token_id:
                # Several token ids can decode to the same byte string.
                vocabulary.setdefault(token_bytes, []).append(token_idx)
        else:
            empty_token_ids.append(token_idx)

    return vocabulary
|
||||
|
||||
|
||||
def get_outlines_vocabulary(tokenizer: TokenizerLike) -> OutlinesVocabulary:
    """Get the `OutlinesVocabulary` wrapper for a given tokenizer.

    The result is memoized on the tokenizer instance itself (attribute
    ``_outlines_vocabulary``) so the reduced vocabulary is built only once.

    Raises:
        ValueError: if the tokenizer has no usable ``eos_token_id`` or lacks
            the attributes needed to enumerate its vocabulary.
    """
    # Fast path: reuse the vocabulary cached on this tokenizer instance.
    if hasattr(tokenizer, "_outlines_vocabulary"):
        return tokenizer._outlines_vocabulary  # type: ignore

    try:
        if (
            hasattr(
                tokenizer,
                "eos_token_id",
            )
            and tokenizer.eos_token_id is not None
        ):
            eos_token_id = tokenizer.eos_token_id
        else:
            raise ValueError(
                f"Error during structured outputs setup for outlines: Tokenizer ({type(tokenizer)}) has no `eos_token_id` property, but `eos_token_id` is required for structured outputs to work properly."  # noqa: E501
            )

        reduced_vocab = _reduced_vocabulary(
            tokenizer,
            eos_token_id,  # type: ignore
        )
        vocabulary = OutlinesVocabulary(oc.Vocabulary(eos_token_id, reduced_vocab))
        # Memoize on the tokenizer so subsequent calls are O(1).
        tokenizer._outlines_vocabulary = vocabulary  # type: ignore

        return vocabulary
    except AttributeError as e:
        # Raised when the tokenizer lacks get_vocab/all_special_tokens etc.
        raise ValueError(
            f"Cannot get the vocabulary of the tokenizer "
            f"({type(tokenizer)}). The tokenizer should have a "
            "get_vocab method."
        ) from e
|
||||
|
||||
|
||||
def grammar_is_likely_lark(grammar_str: str) -> bool:
    """Heuristically decide whether a grammar string uses Lark syntax.

    Any non-comment line containing the EBNF definition operator ``::=``
    marks the grammar as EBNF; otherwise it is assumed to be Lark.

    Args:
        grammar_str: Input grammar string

    Returns:
        bool: True if grammar appears to be in Lark format, False otherwise

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
    """
    if not isinstance(grammar_str, str) or not grammar_str:
        return False

    # Strip both '#' and '//' comment styles, then ignore blank lines.
    meaningful = (
        re.sub(r"(#|//).*$", "", raw).strip() for raw in grammar_str.split("\n")
    )
    return not any("::=" in line for line in meaningful if line)
|
||||
|
||||
|
||||
def convert_lark_to_ebnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to EBNF format.

    EBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in EBNF format

    Raises:
        ValueError: on non-string/empty input, mismatched quotes, an
            alternative without a preceding rule, or references to
            undefined rules.

    Examples:
        >>> print(convert_lark_to_ebnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules = set()
    referenced_rules = set()
    output_lines = []

    def clean_line(line: str) -> str:
        """Remove comments and whitespace from line."""
        return re.sub(r"(#|//).*$", "", line).strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set[str]:
        """Extract rule references from text."""
        # Remove quoted strings and special characters
        text = re.sub(r'"[^"]*"', "", text)
        text = re.sub(r"[+*?()|\[\]{}]", " ", text)
        return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", text))

    # First pass: Find root rule and validate rule definitions
    lines = [clean_line(line) for line in grammar_str.split("\n")]
    first_rule = None

    for line_num, line in enumerate(lines, 1):
        if not line or line.startswith("|"):
            continue

        if ":" in line:
            try:
                # Lark allows a leading '?' (inlining marker); strip it.
                name = line.split(":", 1)[0].strip().strip("?")
                defined_rules.add(name)
                if first_rule is None:
                    first_rule = name
                # A rule literally named "start" wins as the root rule.
                if name == "start":
                    first_rule = "start"
            except IndexError as e:
                raise ValueError(
                    f"Invalid rule format on line {line_num}. "
                    "Expected 'rule_name: definition'"
                ) from e

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives
    current_rule = None
    current_definition = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if ":" in line and not line.startswith("|"):
                # Save previous rule if exists
                if current_rule:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}"
                    )

                # Process new rule
                name, definition = line.split(":", 1)
                current_rule = name.strip().strip("?")

                check_quotes(definition, f"rule '{current_rule}'", line_num)
                # EBNF uses double-quoted terminals; rewrite 'x' -> "x".
                definition = re.sub(r"'([^']*)'", r'"\1"', definition)
                referenced_rules.update(extract_references(definition))
                current_definition = [definition.strip()]

            elif line.startswith("|"):
                # Continuation line: another alternative for the current rule.
                if not current_rule:
                    raise ValueError(
                        f"Alternative '|' on line {line_num} "
                        "without a preceding rule definition"
                    )

                alt_def = line[1:].strip()
                check_quotes(
                    alt_def, f"alternative for rule '{current_rule}'", line_num
                )
                alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
                referenced_rules.update(extract_references(alt_def))
                current_definition.append(alt_def)

        except ValueError as e:
            # Re-raise with line context for easier debugging.
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists
    if current_rule:
        output_lines.append(f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined
    undefined_rules = referenced_rules - defined_rules - {"root"}
    if undefined_rules:
        raise ValueError(
            f"Referenced rules are not defined: {', '.join(sorted(undefined_rules))}"
        )

    return "\n".join(output_lines)
|
||||
|
||||
|
||||
def choice_as_grammar(choice: list[str]) -> str:
    """Build an EBNF grammar whose root matches exactly one of *choice*.

    Each option becomes a double-quoted terminal, with embedded double
    quotes and backslashes escaped so the literal stays valid EBNF.
    """

    def _escape(text: str) -> str:
        # Prefix every '"' and '\' with a backslash.
        return re.sub(r'(["\\])', r"\\\1", text)

    alternatives = " | ".join(f'"{_escape(option)}"' for option in choice)
    return "root ::= " + alternatives
|
||||
Reference in New Issue
Block a user