# vllm/v1/structured_output/backend_outlines.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import ast
import importlib
import json
import sys
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from regex import escape as regex_escape
from vllm.sampling_params import SamplingParams
from vllm.utils.import_utils import LazyLoader
from vllm.v1.structured_output.backend_types import (
StructuredOutputBackend,
StructuredOutputGrammar,
StructuredOutputOptions,
)
from vllm.v1.structured_output.utils import (
OutlinesVocabulary,
get_outlines_cache,
get_outlines_vocabulary,
)
if TYPE_CHECKING:
import outlines_core as oc
import outlines_core.json_schema as json_schema
else:
oc = LazyLoader("oc", globals(), "outlines_core")
json_schema = LazyLoader("json_schema", globals(), "outlines_core.json_schema")
# Python 3.11+ sre_parse and sre_constants
# are deprecated, so we must import them from re
if sys.version_info >= (3, 11):
# Hack to get around pre-commit regex module rule
# because going through re is the only way to get sre_parse
# and sre_constants in Python 3.11+
_re = importlib.import_module("re")
sre_parse = _re._parser
sre_constants = _re._constants
else:
import sre_constants
import sre_parse
@dataclass
class OutlinesBackend(StructuredOutputBackend):
    """Structured-output backend built on the outlines_core FSM engine."""

    def __post_init__(self):
        # Both of these are process-wide shared resources derived from the
        # tokenizer supplied by the base dataclass.
        self.vocabulary = get_outlines_vocabulary(self.tokenizer)
        self.cache = get_outlines_cache()

    def _compile_index(
        self, regex_string: str, vocabulary: OutlinesVocabulary
    ) -> oc.Index:
        """Return the DFA index for ``regex_string``, compiling on a cache miss."""
        key = f"{vocabulary._hash}_{regex_string}"
        if key not in self.cache:
            self.cache[key] = oc.Index(regex_string, vocabulary.inner)
        return self.cache[key]

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Translate a grammar spec of the given type into an OutlinesGrammar.

        Raises ValueError for request types this backend does not handle.
        """
        if request_type == StructuredOutputOptions.JSON:
            regex = json_schema.build_regex_from_schema(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            regex = grammar_spec
        elif request_type == StructuredOutputOptions.CHOICE:
            # grammar_spec is a Python-literal list of choices; escape each
            # one so it is matched verbatim inside the alternation.
            escaped = (regex_escape(c) for c in ast.literal_eval(grammar_spec))
            regex = "(" + "|".join(escaped) + ")"
        else:
            raise ValueError(
                f"Invalid request type for Outlines backend ({request_type!s})"
            )

        spec_config = self.vllm_config.speculative_config
        # Speculative decoding needs to be able to roll the FSM back by up
        # to num_speculative_tokens rejected tokens.
        rollback = 0 if spec_config is None else spec_config.num_speculative_tokens
        return OutlinesGrammar(
            vocab_size=self.vocab_size,
            guide=oc.Guide(
                self._compile_index(regex, self.vocabulary),
                max_rollback=rollback,
            ),
        )

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        """Allocate one 32-bit-packed token bitmask row per sequence."""
        words_per_row = (self.vocab_size + 31) // 32  # ceil(vocab_size / 32)
        return torch.full(
            (max_num_seqs, words_per_row),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        # Nothing to release: the vocabulary and cache are shared globals.
        pass
@dataclass
class OutlinesGrammar(StructuredOutputGrammar):
    """Per-request FSM state wrapping an outlines_core ``Guide``."""

    vocab_size: int
    guide: oc.Guide = field(hash=False)
    num_processed_tokens: int = field(
        default_factory=lambda: 0, repr=False, hash=False, init=False
    )
    # outlines_core signals done on DFA accept; vLLM expects done after EOS.
    # We delay the finished flag by one step so EOS can still be emitted.
    _prev_finished: bool = field(default=False, init=False, repr=False, hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the FSM failed to advance.
        """
        if not self.guide.accepts_tokens(tokens):
            return False
        # accepts_tokens() already validated the full sequence, so every
        # individual advance below is guaranteed to succeed.
        for tok in tokens:
            self.guide.advance(tok)
            self.num_processed_tokens += 1
        return True

    def rollback(self, num_tokens: int) -> None:
        """Undo the last ``num_tokens`` advances (speculative decoding)."""
        self.guide.rollback_state(num_tokens)
        self.num_processed_tokens -= num_tokens

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Return the longest prefix of ``tokens`` the FSM accepts."""
        prefix: list[int] = []
        for tok in tokens:
            if not self.guide.accepts_tokens(prefix + [tok]):
                break
            prefix.append(tok)
        return prefix

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the allowed-token bitmask for the current state into row ``idx``."""
        row = bitmask[idx]
        self.guide.write_mask_into(row.data_ptr(), row.numel(), row.element_size())

    def is_terminated(self) -> bool:
        """Report termination one step late so EOS can still be sampled."""
        finished_before, self._prev_finished = (
            self._prev_finished,
            self.guide.is_finished(),
        )
        return finished_before

    def reset(self):
        """Return the grammar to its initial state for reuse."""
        self.num_processed_tokens = 0
        self._prev_finished = False
        self.guide.reset()
def validate_structured_output_request_outlines(params: SamplingParams):
    """Reject sampling params whose structured-output spec Outlines can't build.

    Raises ValueError for invalid JSON schemas, regexes using unsupported
    features, or grammar specs (which this backend does not support at all).
    No-op when the request carries no structured-output spec.
    """
    so_params = params.structured_outputs
    if so_params is None:
        return

    if so_params.regex:
        validate_regex_is_buildable(so_params.regex)
    elif so_params.json:
        spec = so_params.json
        if isinstance(spec, str):
            try:
                # make sure schema is valid json
                json.loads(spec)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
            schema = spec
        else:
            try:
                schema = json.dumps(spec)
            except Exception as e:
                raise ValueError(
                    f"Error serializing structured outputs jsonschema: {e}"
                ) from e
        validate_regex_is_buildable(json_schema.build_regex_from_schema(schema))
    elif so_params.choice:
        alternatives = "|".join(
            regex_escape(str(choice)) for choice in so_params.choice
        )
        validate_regex_is_buildable("(" + alternatives + ")")
    elif so_params.grammar:
        raise ValueError(
            "Outlines structured outputs backend "
            "does not support grammar specifications"
        )
def _prefix_needs_context(parsed) -> bool:
"""Return True if there's a look-around/anchor before any consumer."""
def subpattern_consumes(parsed) -> bool:
"""Return True if subpattern can consume at least one character."""
tokens = parsed.data if hasattr(parsed, "data") else parsed
for ttype, tval in tokens:
# literal, character class, or dot always consumes
if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
return True
# quantified subpattern: check inner pattern
elif ttype == sre_parse.MAX_REPEAT:
_, mx, sub = tval
if mx != 0 and subpattern_consumes(sub):
return True
# alternation: if any branch consumes, the whole does
elif ttype == sre_parse.BRANCH:
_, branches = tval
if any(subpattern_consumes(br) for br in branches):
return True
# grouped subpattern: recurse into its contents
elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(tval[3]):
return True
# No consumers, return False
return False
tokens = parsed.data if hasattr(parsed, "data") else parsed
for ttype, tval in tokens:
# Direct anchors or look-around
if ttype == sre_parse.AT or ttype in (
sre_constants.ASSERT,
sre_constants.ASSERT_NOT,
):
return True
# Nested subpattern: check
if ttype == sre_parse.SUBPATTERN:
# tval: (group, add_flags, del_flags, subpattern)
if _prefix_needs_context(tval[3]):
return True
if subpattern_consumes(tval[3]):
return False
# if any branch has a prefix anchor => True,
# else if at least one branch consumes => prefix ends => False
elif ttype == sre_parse.BRANCH:
saw_consumer = False
for br in tval[1]:
if _prefix_needs_context(br):
return True
if subpattern_consumes(br):
saw_consumer = True
if saw_consumer:
return False
# Immediate consumer tokens
elif ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
return False
# if subpattern has anchor => True, if it can consume => stop
elif ttype == sre_parse.MAX_REPEAT:
if _prefix_needs_context(tval[2]):
return True
if subpattern_consumes(tval[2]):
return False
return False
def _check_unsupported(parsed) -> None:
"""Check for regex features unsupported by regex-automata"""
tokens = parsed.data if hasattr(parsed, "data") else parsed
for ttype, tval in tokens:
# backreference
if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS):
raise ValueError("Backreferences are unsupported.")
# look-around assertion
elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT):
raise ValueError("Look-Around assertion are unsupported.")
# unicode word boundaries
elif ttype == sre_parse.AT:
if tval in (sre_constants.AT_BOUNDARY, sre_constants.AT_NON_BOUNDARY):
raise ValueError("Unicode word boundaries are unsupported.")
elif ttype == sre_parse.BRANCH:
# tval is (None, branches)
for branch in tval[1]:
_check_unsupported(branch)
# tval is (min, max, subpattern)
elif ttype == sre_parse.MAX_REPEAT:
_check_unsupported(tval[2])
def validate_regex_is_buildable(pattern: str) -> None:
    """
    Validates that the input regex is not using unsupported features
    of the `regex-automata` crate (outlines_core regex engine) and has a
    universal start state.
    definition of universal start state used can be found at:
    https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state

    Raises ValueError when the pattern cannot be parsed, uses an
    unsupported feature, or requires context before the first token.
    """
    try:
        parsed = sre_parse.parse(pattern)
    except sre_constants.error as e:
        raise ValueError(f"Error parsing regex: {e}") from e
    try:
        _check_unsupported(parsed)
    except ValueError as e:
        raise ValueError(
            f"Regex uses unsupported feature for structured outputs: {e}. "
            "Only basic matching constructs are supported—lookarounds, "
            "backreferences, and unicode boundaries are not."
        ) from e
    if _prefix_needs_context(parsed):
        # bugfix: the original message concatenated its literals without
        # separators ("...start stateThis means..." / "...matched.structured...")
        # and read "a anchored".
        raise ValueError(
            "Regex does not have an anchored universal start state. "
            "This means that the Regex uses anchors (^) or look-arounds "
            "in a way which requires context before any token is matched. "
            "structured outputs needs regexes that can match without needing "
            "that context. Try rewriting the pattern without using these "
            f"constructs. Pattern:\n{pattern}"
        )