update
This commit is contained in:
@@ -1,338 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import multiprocessing
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
)
|
||||
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.v1.request import Request
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
ReasoningParser = object
|
||||
Request = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StructuredOutputManager:
    """Engine-level manager for structured output requests.

    Owns the single, lazily created structured output backend, the thread
    pools used for grammar compilation and bitmask filling, and the optional
    reasoning parser consulted before constraints are applied.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Create executors, tokenizer and reasoning parser from the config.

        Args:
            vllm_config: Engine configuration. The scheduler, model,
                speculative and structured-outputs sub-configs are read.
        """
        self.backend: StructuredOutputBackend | None = None
        self.reasoner: ReasoningParser | None = None
        self.vllm_config = vllm_config

        self._grammar_bitmask: torch.Tensor | None = None
        # int32 -1 has every bit set; used to reset a bitmask row
        # (presumably "no token masked out" -- confirm with the backends).
        self._full_mask = torch.tensor(-1, dtype=torch.int32)

        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
        self.fill_bitmask_parallel_threshold = 128
        # The parallel fill executor only exists when a batch can actually
        # exceed the threshold checked in grammar_bitmask().
        if self.fill_bitmask_parallel_threshold < max_batch_size:
            self.fill_bitmask_parallel_batch_size = 16
            # Use:
            # - at least 1 CPU
            # - at most half the number of CPUs or 8, whichever is less
            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
            self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers)

        if not self.vllm_config.model_config.skip_tokenizer_init:
            # The default max_workers if not specified is the number of
            # CPUs * 5, which is way too high since these tasks are CPU-bound,
            # not I/O bound. We also know we would never dominate CPU usage
            # with just grammar compilation, so we set it to half the number
            # of CPUs.
            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.tokenizer = init_tokenizer_from_configs(
                model_config=self.vllm_config.model_config
            )
            reasoning_parser_plugin = (
                self.vllm_config.structured_outputs_config.reasoning_parser_plugin
            )
            if reasoning_parser_plugin and len(reasoning_parser_plugin) > 3:
                ReasoningParserManager.import_reasoning_parser(reasoning_parser_plugin)

            # Looked up after the plugin import so a plugin-provided parser
            # is registered before it is resolved by name.
            reasoning_parser = (
                self.vllm_config.structured_outputs_config.reasoning_parser
            )
            if reasoning_parser:
                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                    reasoning_parser
                )
                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

        self.enable_in_reasoning = (
            self.vllm_config.structured_outputs_config.enable_in_reasoning
        )

    def grammar_init(self, request: Request) -> None:
        """Start asynchronous grammar compilation for ``request``.

        Does nothing when the request has no structured output request.
        Otherwise creates the backend on first use and stores a ``Future``
        for the compiled grammar on ``request.structured_output_request``.

        Raises:
            ValueError: If the backend named on the request is not supported.
        """
        if request.structured_output_request is None:
            return

        if TYPE_CHECKING:
            assert (
                request.sampling_params is not None
                and request.sampling_params.structured_outputs is not None
            )

        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        # _backend is set in Processor._validate_structured_output
        if self.backend is None:
            assert request.sampling_params is not None
            backend = request.sampling_params.structured_outputs._backend
            vocab_size = self.vllm_config.model_config.get_vocab_size()
            if backend == "xgrammar":
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "guidance":
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "outlines":
                from vllm.v1.structured_output.backend_outlines import OutlinesBackend

                self.backend = OutlinesBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "lm-format-enforcer":
                from vllm.v1.structured_output.backend_lm_format_enforcer import (  # noqa: E501
                    LMFormatEnforcerBackend,
                )

                self.backend = LMFormatEnforcerBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            else:
                raise ValueError(f"Unsupported structured output backend: {backend}")

        grammar = self.executor.submit(self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
        """Compile the grammar for ``request``; runs on ``self.executor``."""
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def _fill_bitmasks(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> None:
        """Fill or reset one bitmask row per ``(grammar, index, apply)`` entry.

        A row is filled by the grammar when ``apply`` is true and the grammar
        has not terminated; otherwise it is overwritten with the full mask.
        """
        assert self._grammar_bitmask is not None
        for grammar, index, apply_bitmask in batch:
            if apply_bitmask and not grammar.is_terminated():
                grammar.fill_bitmask(self._grammar_bitmask, index)
            else:
                # Note that for thinking support, we will need to
                # reset the relevant part of the bitmask for subsequent
                # requests here.
                self._grammar_bitmask[index].fill_(self._full_mask)

    def _async_submit_fill_bitmask(
        self,
        batch: list[tuple[StructuredOutputGrammar, int, bool]],
    ) -> Future:
        """Submit ``_fill_bitmasks(batch)`` to the fill-mask executor."""
        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: list[str],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> "npt.NDArray[np.int32] | None":
        """Build the batched token bitmask for the scheduled requests.

        Args:
            requests: Mapping from request id to request.
            structured_output_request_ids: Ids of the requests that need a
                bitmask, in row order.
            scheduled_spec_decode_tokens: Speculative tokens per request id;
                a missing id is treated as having none.

        Returns:
            ``None`` when there are no structured output requests, otherwise
            the filled rows of the bitmask as a NumPy array.
        """
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = (
                self.vllm_config.speculative_config.num_speculative_tokens
            )

        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = self.backend.allocate_token_bitmask(
                max_batch_size * (1 + max_num_spec_tokens)
            )

        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        cumulative_index = 0

        # Optimized parallel filling of bitmasks for
        # non-spec, large-batch-size cases
        if (
            len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold
            and max_num_spec_tokens == 0
        ):
            promises = []
            batch = []
            for req_id in structured_output_request_ids:
                request = requests[req_id]
                structured_output_request = request.structured_output_request
                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None

                apply_bitmask = self.should_fill_bitmask(request)
                batch.append(
                    (structured_output_request.grammar, cumulative_index, apply_bitmask)
                )
                if len(batch) == self.fill_bitmask_parallel_batch_size:
                    promises.append(self._async_submit_fill_bitmask(batch))
                    batch = []

                cumulative_index += 1
            if batch:
                promises.append(self._async_submit_fill_bitmask(batch))

            # Wait for all bitmask filling tasks to complete.
            for promise in promises:
                promise.result()
        else:
            # Fallback to serial filling of bitmasks for small-batch-size cases
            for req_id in structured_output_request_ids:
                request = requests[req_id]
                structured_output_request = request.structured_output_request

                if TYPE_CHECKING:
                    assert structured_output_request is not None
                    assert structured_output_request.grammar is not None
                apply_bitmask = self.should_fill_bitmask(request)

                state_advancements = 0
                req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
                # The trailing None stands for the bonus / non-speculative
                # position: its row is filled but no token is accepted.
                for token in req_tokens + [None]:
                    self._fill_bitmasks(
                        [
                            (
                                structured_output_request.grammar,
                                cumulative_index,
                                apply_bitmask,
                            )
                        ]
                    )

                    if (
                        apply_bitmask
                        and token is not None
                        and not structured_output_request.grammar.is_terminated()
                    ):
                        accepted = structured_output_request.grammar.accept_tokens(
                            req_id, [token]
                        )
                        assert accepted, (token, req_id, scheduled_spec_decode_tokens)
                        state_advancements += 1
                    cumulative_index += 1
                # Undo the speculative advances made while filling the rows.
                if state_advancements > 0:
                    structured_output_request.grammar.rollback(state_advancements)

        bitmask_tensor = self._grammar_bitmask
        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()

    def should_fill_bitmask(self, request: Request) -> bool:
        """Return whether the grammar bitmask should be applied to ``request``.

        Always true without a reasoner or when ``enable_in_reasoning`` is set.
        Otherwise returns the request's ``reasoning_ended`` flag, computing it
        from the prompt token ids the first time it is still ``None``.
        """
        # NOTE (Hanchen) if enable_in_reasoning is True, it means that
        # the model needs to be constrained in reasoning. So we should always
        # enable the bitmask filling.

        if self.reasoner is not None:
            if self.enable_in_reasoning:
                return True
            assert request.structured_output_request is not None
            if request.structured_output_request.reasoning_ended is None:
                request.structured_output_request.reasoning_ended = (
                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
                )
            return request.structured_output_request.reasoning_ended
        return True

    def should_advance(self, request: Request) -> bool:
        """Return whether the grammar state may be advanced for ``request``.

        False for requests that do not use structured output. With a reasoner
        and ``enable_in_reasoning`` unset, advancing starts only on the step
        after the one in which reasoning is detected to have ended.
        """
        if not request.use_structured_output:
            return False

        # To determine whether we can advance the FSM.
        # Supports thinking usage where we skip the reasoning components.
        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None
        # by default, we should always advance
        # for cases that don't use thinking mode.
        if self.reasoner is None:
            return True

        # if the model needs structured in reasoning, we should advance
        if self.enable_in_reasoning:
            return True

        structured_req = request.structured_output_request
        if structured_req.reasoning_ended:
            return True

        # Check if reasoning ends in *this* step
        if self.reasoner.is_reasoning_end(request.all_token_ids):
            # Reasoning just ended, so we shouldn't advance til
            # next pass
            structured_req.reasoning_ended = True

        return False

    def clear_backend(self) -> None:
        """Destroy the backend if one has been created."""
        if self.backend is not None:
            self.backend.destroy()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,265 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions,
|
||||
)
|
||||
from vllm.v1.structured_output.request import get_structured_output_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import llguidance
|
||||
import llguidance.hf as llguidance_hf
|
||||
import llguidance.torch as llguidance_torch
|
||||
else:
|
||||
llguidance = LazyLoader("llguidance", globals(), "llguidance")
|
||||
llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
|
||||
llguidance_torch = LazyLoader("llguidance.torch", globals(), "llguidance.torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _walk_json_for_additional_properties(data: object):
|
||||
if isinstance(data, dict):
|
||||
for value in data.values():
|
||||
_walk_json_for_additional_properties(value)
|
||||
if "additionalProperties" not in data and (
|
||||
"properties" in data or "patternProperties" in data
|
||||
):
|
||||
data["additionalProperties"] = False
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
_walk_json_for_additional_properties(item)
|
||||
|
||||
|
||||
def process_for_additional_properties(
    guide_json: str | dict[str, Any],
) -> dict[str, Any]:
    """Return a schema with ``additionalProperties: False`` filled in.

    A string input is parsed as JSON; a dict input is deep-copied so the
    caller's object is not modified. The result is then rewritten in place by
    ``_walk_json_for_additional_properties``.
    """
    schema = (
        json.loads(guide_json)
        if isinstance(guide_json, str)
        # copy for modifications
        else copy.deepcopy(guide_json)
    )
    _walk_json_for_additional_properties(schema)
    return schema
|
||||
|
||||
|
||||
@dataclass
class GuidanceBackend(StructuredOutputBackend):
    """Structured output backend that compiles grammars with llguidance."""

    def __post_init__(self):
        """Cache config flags and build the llguidance tokenizer."""
        self.disable_any_whitespace = (
            self.vllm_config.structured_outputs_config.disable_any_whitespace
        )
        self.disable_additional_properties = (
            self.vllm_config.structured_outputs_config.disable_additional_properties
        )

        self.ll_tokenizer = llguidance_hf.from_tokenizer(
            self.tokenizer, self.vocab_size
        )

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Compile ``grammar_spec`` into a ``GuidanceGrammar``.

        Args:
            request_type: Kind of structured output requested.
            grammar_spec: The specification to serialize and compile.

        Returns:
            A grammar wrapping a fresh ``LLMatcher`` for this spec.
        """
        # Keep the serialized grammar in a local: this method is submitted to
        # a multi-worker thread pool by the manager, so reading it back from a
        # shared attribute could pick up another request's grammar.
        serialized_grammar = serialize_guidance_grammar(
            request_type,
            grammar_spec,
            self.disable_any_whitespace,
            self.disable_additional_properties,
        )
        # Still published for backward compatibility with any reader of this
        # attribute; it holds whichever grammar was serialized most recently.
        self.serialized_grammar = serialized_grammar

        ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )

        r = GuidanceGrammar(
            ll_matcher=ll_matcher,
            ll_tokenizer=self.ll_tokenizer,
            vocab_size=self.vocab_size,
        )

        # Surfaces a matcher construction error as a warning right away.
        r.check_error()
        return r

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate a bitmask with ``max_num_seqs`` rows via llguidance."""
        return llguidance_torch.allocate_token_bitmask(
            max_num_seqs, self.ll_tokenizer.vocab_size
        )

    def destroy(self):
        """No resources to release for this backend."""
        pass
|
||||
|
||||
|
||||
@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
    """Per-request grammar state backed by an llguidance ``LLMatcher``."""

    ll_matcher: llguidance.LLMatcher
    ll_tokenizer: llguidance.LLTokenizer
    vocab_size: int
    # Set once an error has been logged so it is reported only one time.
    printed_error: bool = False
    # Set when an EOS token has been passed to accept_tokens().
    terminated: bool = False
    # Number of tokens to subtract from the next rollback request.
    rollback_lag: int = 0

    def check_error(self):
        """Log the matcher's error, if any, the first time it is seen."""
        if self.printed_error:
            return
        err = self.ll_matcher.get_error()
        if err:
            self.printed_error = True
            logger.warning("LLMatcher error: %s", err)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the parser.

        Returns True if the parser was advanced successfully.
        Returns False if the parser failed to advance.
        """
        matcher = self.ll_matcher

        if self.ll_tokenizer.eos_token in tokens:
            # EOS arriving on an already stopped matcher is not consumed
            # below, so remember to roll back one token fewer.
            if matcher.is_stopped() and not self.terminated:
                self.rollback_lag = 1
            self.terminated = True

        if matcher.is_stopped():
            return True

        # TODO - Add jump decoding support in the future:
        # self.ll_matcher.compute_ff_bytes() - this should always work
        # self.ll_matcher.compute_ff_tokens() - this only works for
        #   "canonical" tokenizers
        # For conversion between the two, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md

        advanced = matcher.consume_tokens(tokens)
        self.check_error()
        return advanced

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the parser in sequence.
        Will not advance the parser.

        Returns the prefix list of tokens that are accepted by the parser.
        """
        if len(tokens) == 0 or self.ll_matcher.is_stopped():
            return []

        accepted_count = self.ll_matcher.validate_tokens(tokens)
        self.check_error()
        return tokens[:accepted_count]

    def rollback(self, num_tokens: int) -> None:
        """Roll the matcher back by ``num_tokens`` less any pending lag.

        Non-positive ``num_tokens`` is a no-op. A rollback also clears the
        terminated flag and the pending lag.
        """
        if num_tokens <= 0:
            return
        self.ll_matcher.rollback(num_tokens - self.rollback_lag)
        self.terminated = False
        self.rollback_lag = 0
        self.check_error()

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the next-token mask for this grammar into row ``idx``."""
        # this will automatically return [EOS] mask if the matcher is stopped
        # or otherwise in an error state
        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
        self.check_error()

    def is_terminated(self) -> bool:
        """Return whether an EOS token has been accepted."""
        return self.terminated

    def reset(self):
        """Reset the underlying matcher."""
        # This method may be not needed anymore? TODO
        self.ll_matcher.reset()
|
||||
|
||||
|
||||
def serialize_guidance_grammar(
    request_type: StructuredOutputOptions,
    grammar_spec: str | dict[str, Any],
    disable_any_whitespace: bool = False,
    disable_additional_properties: bool = False,
) -> str:
    """Serialize a structured output request into an llguidance grammar.

    Args:
        request_type: Kind of structured output requested.
        grammar_spec: The specification; a JSON schema, regex, grammar,
            choice spec or structural-tag description depending on the type.
        disable_any_whitespace: When true, JSON grammars are built with
            ``whitespace_flexible`` turned off.
        disable_additional_properties: When true, JSON schemas are rewritten
            to forbid additional properties before compilation.

    Returns:
        The serialized grammar.

    Raises:
        ValueError: For an unsupported request type, a structural tag whose
            ``begin`` matches no trigger, or a spec without any structures.
    """

    def _schema_to_grammar(schema: str | dict[str, Any]) -> str:
        # Shared JSON-schema entry point; a fresh defaults dict per call.
        return llguidance.LLMatcher.grammar_from_json_schema(
            schema,
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            },
        )

    def _process_schema(
        grammar_spec: str | dict[str, Any],
    ) -> str:
        if disable_additional_properties:
            grammar_spec = process_for_additional_properties(grammar_spec)
        return _schema_to_grammar(grammar_spec)

    if request_type == StructuredOutputOptions.JSON:
        return _process_schema(grammar_spec)

    if request_type == StructuredOutputOptions.JSON_OBJECT:
        return _schema_to_grammar('{"type": "object"}')

    # Request types that map directly onto an llguidance grammar kind.
    simple_kinds = (
        (StructuredOutputOptions.REGEX, "regex"),
        (StructuredOutputOptions.GRAMMAR, "grammar"),
        (StructuredOutputOptions.CHOICE, "choice"),
    )
    for option, tp in simple_kinds:
        if request_type == option:
            return llguidance.grammar_from(tp, grammar_spec)

    if request_type == StructuredOutputOptions.STRUCTURAL_TAG:
        s_tag = json.loads(grammar_spec) if isinstance(grammar_spec, str) else grammar_spec
        triggers: list[str] = s_tag["triggers"]
        tags: list[llguidance.StructTag] = []
        for s in s_tag["structures"]:
            begin: str = s["begin"]
            # The first trigger that prefixes this structure's begin marker.
            trig = next((t for t in triggers if begin.startswith(t)), None)
            if trig is None:
                raise ValueError(
                    f"Trigger {begin} not found in triggers {triggers}"
                )
            tags.append(
                llguidance.StructTag(
                    trigger=trig,
                    begin=s["begin"],
                    grammar=_process_schema(s["schema"]),
                    end=s["end"],
                )
            )
        if not tags:
            raise ValueError("No structural tags found in the grammar spec.")
        return llguidance.StructTag.to_grammar(tags)

    logger.error(
        "Validation should have already occurred. Please file an issue."
    )
    raise ValueError(
        f"grammar is not of valid supported types. ({request_type!s})"
    )
|
||||
|
||||
|
||||
def validate_guidance_grammar(
    sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
) -> None:
    """Validate the structured output spec in ``sampling_params``.

    The spec is serialized with default whitespace and additional-property
    settings and checked by ``LLMatcher.validate_grammar``.

    Raises:
        ValueError: If llguidance reports a grammar error.
    """
    request_type, grammar_spec = get_structured_output_key(
        sampling_params.structured_outputs
    )
    serialized = serialize_guidance_grammar(request_type, grammar_spec)
    err = llguidance.LLMatcher.validate_grammar(serialized, tokenizer)
    if not err:
        return
    raise ValueError(f"Grammar error: {err}")
|
||||
@@ -1,177 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import lmformatenforcer
|
||||
import lmformatenforcer.integrations.vllm as lmfe_vllm
|
||||
else:
|
||||
lmformatenforcer = LazyLoader("lmformatenforcer", globals(), "lmformatenforcer")
|
||||
lmfe_vllm = LazyLoader(
|
||||
"lmformatenforcer.integrations.vllm",
|
||||
globals(),
|
||||
"lmformatenforcer.integrations.vllm",
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
    tokenizer: PreTrainedTokenizerBase, vocab_size: int
) -> "lmfe_vllm.TokenEnforcerTokenizerData":
    """Build LM Format Enforcer tokenizer data with bitmask support.

    Results are memoized by ``lru_cache`` on ``(tokenizer, vocab_size)``.
    """
    tokenizer_data = lmfe_vllm.build_vllm_token_enforcer_tokenizer_data(
        tokenizer, use_bitmask=True, vocab_size=vocab_size
    )
    return tokenizer_data
|
||||
|
||||
|
||||
@dataclass
class LMFormatEnforcerGrammar(StructuredOutputGrammar):
    """Grammar state for LM Format Enforcer, tracked as a token prefix."""

    token_enforcer: lmformatenforcer.TokenEnforcer
    # Tokens accepted so far; the enforcer is queried with this prefix.
    current_tokens_prefix: list[int] = field(default_factory=list)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Append ``tokens`` to the prefix if every one of them is allowed.

        Returns True when all tokens were accepted. On the first disallowed
        token the prefix is restored to its previous length and False is
        returned.
        """
        original_len = len(self.current_tokens_prefix)
        for token in tokens:
            if not self.token_enforcer.get_allowed_tokens(
                self.current_tokens_prefix
            ).is_token_allowed(token):
                # Rollback partial updates to ensure atomicity.
                del self.current_tokens_prefix[original_len:]
                return False
            self.current_tokens_prefix.append(token)
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Return the longest prefix of ``tokens`` the enforcer would allow.

        The stored prefix is not modified.
        """
        for prefix_length in range(len(tokens)):
            prefix = tokens[:prefix_length]
            next_token = tokens[prefix_length]
            if not self.token_enforcer.get_allowed_tokens(
                self.current_tokens_prefix + prefix
            ).is_token_allowed(next_token):
                break
        else:
            # No break: every token was allowed (also covers empty input).
            return tokens

        return tokens[:prefix_length]

    def rollback(self, num_tokens: int) -> None:
        """Drop the last ``num_tokens`` tokens from the prefix.

        Non-positive counts are a no-op. Without the guard a count of 0
        would slice with ``[:-0]``, which equals ``[:0]`` and discards the
        entire prefix.
        """
        if num_tokens > 0:
            self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens]

    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        """Write the allowed-token mask for the current prefix into a row."""
        allowed_tokens = self.token_enforcer.get_allowed_tokens(
            self.current_tokens_prefix
        )
        bitmask[batch_index] = allowed_tokens.allowed_tokens

    def is_terminated(self) -> bool:
        """Return whether the prefix ends with the enforcer's EOS token."""
        # We are considered terminated if the prefix ends with eos_token_id
        return_value = (
            len(self.current_tokens_prefix) > 0
            and self.current_tokens_prefix[-1] == self.token_enforcer.eos_token_id
        )
        return return_value

    def reset(self):
        """Forget all accepted tokens."""
        self.current_tokens_prefix = []
|
||||
|
||||
|
||||
@dataclass
class LMFormatEnforcerBackend(StructuredOutputBackend):
    """Structured output backend built on LM Format Enforcer."""

    def __post_init__(self):
        """Build (or fetch the cached) tokenizer data for the enforcer."""
        self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
            self.tokenizer, self.vocab_size
        )

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Create an ``LMFormatEnforcerGrammar`` for the given spec.

        Raises:
            ValueError: If the request type is not JSON, JSON object, regex
                or choice, or if speculative tokens are configured.
        """
        character_level_parser: lmformatenforcer.CharacterLevelParser
        if request_type == StructuredOutputOptions.JSON:
            character_level_parser = lmformatenforcer.JsonSchemaParser(
                json.loads(grammar_spec)
            )
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            character_level_parser = lmformatenforcer.JsonSchemaParser(None)
        elif request_type == StructuredOutputOptions.REGEX:
            character_level_parser = lmformatenforcer.RegexParser(grammar_spec)
        elif request_type == StructuredOutputOptions.CHOICE:
            # The choice spec is a Python literal listing the options.
            option_parsers = [
                lmformatenforcer.StringParser(choice)
                for choice in ast.literal_eval(grammar_spec)
            ]
            character_level_parser = lmformatenforcer.UnionParser(option_parsers)
        else:
            raise ValueError(
                f"Invalid request type for LM Format Enforcer backend({request_type!s})"
            )

        speculative_config = self.vllm_config.speculative_config
        max_rollback_tokens = (
            0
            if speculative_config is None
            else speculative_config.num_speculative_tokens
        )
        if max_rollback_tokens > 0:
            raise ValueError(
                "LM Format Enforcer backend does not support speculative tokens"
            )

        return LMFormatEnforcerGrammar(
            lmformatenforcer.TokenEnforcer(
                tokenizer_data=self.tokenizer_data,
                parser=character_level_parser,
            )
        )

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        """Allocate an int32 bitmask filled with -1, one row per sequence.

        Each row holds one bit per vocabulary entry, rounded up to whole
        32-bit words. Memory is pinned when CUDA is available.
        """
        words_per_row = (self.vocab_size + 31) // 32
        return torch.full(
            (max_num_seqs, words_per_row),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        """No resources to release for this backend."""
        pass
|
||||
|
||||
|
||||
def validate_structured_output_request_lm_format_enforcer(params: SamplingParams):
    """Check that the structured output params suit LM Format Enforcer.

    Regex and choice specs are accepted as-is. A JSON spec must parse (when
    given as a string) or serialize (otherwise). Grammar specs are rejected.

    Raises:
        ValueError: For an invalid JSON spec or any grammar spec.
    """
    so_params = params.structured_outputs
    if so_params is None:
        return

    if so_params.regex:
        return

    if so_params.json:
        if isinstance(so_params.json, str):
            try:
                # make sure schema is valid json
                json.loads(so_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            try:
                json.dumps(so_params.json)
            except Exception as e:
                raise ValueError(
                    f"Error serializing structured outputs jsonschema: {e}"
                ) from e
        return

    if so_params.choice:
        return

    if so_params.grammar:
        raise ValueError(
            "LM Format Enforcer structured outputs backend "
            "does not support grammar specifications"
        )
|
||||
@@ -1,324 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import importlib
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from regex import escape as regex_escape
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions,
|
||||
)
|
||||
from vllm.v1.structured_output.utils import (
|
||||
OutlinesVocabulary,
|
||||
get_outlines_cache,
|
||||
get_outlines_vocabulary,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import outlines_core as oc
|
||||
import outlines_core.json_schema as json_schema
|
||||
else:
|
||||
oc = LazyLoader("oc", globals(), "outlines_core")
|
||||
json_schema = LazyLoader("json_schema", globals(), "outlines_core.json_schema")
|
||||
|
||||
# Python 3.11+ sre_parse and sre_constants
|
||||
# are deprecated, so we must import them from re
|
||||
if sys.version_info >= (3, 11):
|
||||
# Hack to get around pre-commit regex module rule
|
||||
# because going through re is the only way to get sre_parse
|
||||
# and sre_constants in Python 3.11+
|
||||
_re = importlib.import_module("re")
|
||||
sre_parse = _re._parser
|
||||
sre_constants = _re._constants
|
||||
else:
|
||||
import sre_constants
|
||||
import sre_parse
|
||||
|
||||
|
||||
@dataclass
class OutlinesBackend(StructuredOutputBackend):
    """Structured output backend that compiles regexes with outlines_core."""

    def __post_init__(self):
        """Load the Outlines vocabulary for the tokenizer and the index cache."""
        self.vocabulary = get_outlines_vocabulary(self.tokenizer)
        self.cache = get_outlines_cache()

    def _compile_index(
        self, regex_string: str, vocabulary: OutlinesVocabulary
    ) -> oc.Index:
        """Return the index for ``regex_string``, building it on a cache miss.

        Entries are keyed by the vocabulary hash and the regex text.
        """
        cache_key = f"{vocabulary._hash}_{regex_string}"
        if cache_key not in self.cache:
            index = oc.Index(regex_string, vocabulary.inner)
            self.cache[cache_key] = index
            return index

        return self.cache[cache_key]

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Turn the spec into a regex and wrap its index in a grammar.

        Raises:
            ValueError: If the request type is not JSON, regex or choice.
        """
        if request_type == StructuredOutputOptions.JSON:
            regex = json_schema.build_regex_from_schema(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            regex = grammar_spec
        elif request_type == StructuredOutputOptions.CHOICE:
            # Each choice is matched literally, so escape regex metacharacters.
            escaped = [regex_escape(c) for c in ast.literal_eval(grammar_spec)]
            regex = "(" + "|".join(escaped) + ")"
        else:
            raise ValueError(
                f"Invalid request type for Outlines backend ({request_type!s})"
            )

        index = self._compile_index(regex, self.vocabulary)

        speculative_config = self.vllm_config.speculative_config
        max_rollback_tokens = (
            0
            if speculative_config is None
            else speculative_config.num_speculative_tokens
        )
        guide = oc.Guide(index, max_rollback=max_rollback_tokens)
        return OutlinesGrammar(vocab_size=self.vocab_size, guide=guide)

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        """Allocate an int32 bitmask filled with -1, one row per sequence.

        Each row holds one bit per vocabulary entry, rounded up to whole
        32-bit words. Memory is pinned when CUDA is available.
        """
        words_per_row = (self.vocab_size + 31) // 32
        return torch.full(
            (max_num_seqs, words_per_row),
            -1,
            dtype=torch.int32,
            pin_memory=torch.cuda.is_available(),
        )

    def destroy(self):
        """No resources to release for this backend."""
        pass
|
||||
|
||||
|
||||
@dataclass
class OutlinesGrammar(StructuredOutputGrammar):
    """Request-level grammar state backed by an `outlines_core` guide."""

    vocab_size: int
    guide: oc.Guide = field(hash=False)
    num_processed_tokens: int = field(
        default_factory=lambda: 0, repr=False, hash=False, init=False
    )

    # outlines_core reports "finished" as soon as the DFA accepts, whereas
    # vLLM expects termination only after EOS. The flag is therefore reported
    # one call late so that EOS can still be emitted.
    _prev_finished: bool = field(default=False, init=False, repr=False, hash=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Advance the guide over `tokens` if the whole list is acceptable.

        Returns True when the guide was advanced, False (with no state
        change) when the guide does not accept the token list.
        """
        if not self.guide.accepts_tokens(tokens):
            return False

        # The pre-check above means the individual advances cannot fail.
        for token in tokens:
            self.guide.advance(token)
            self.num_processed_tokens += 1
        return True

    def rollback(self, num_tokens: int) -> None:
        """Undo the last `num_tokens` tokens in the guide and the counter."""
        self.guide.rollback_state(num_tokens)
        self.num_processed_tokens -= num_tokens

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Return the longest prefix of `tokens` that the guide accepts."""
        prefix: list[int] = []
        for token in tokens:
            prefix.append(token)
            if self.guide.accepts_tokens(prefix):
                continue
            # First rejected token: drop it and stop extending the prefix.
            prefix.pop()
            break
        return prefix

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Let the guide write its allowed-token mask into row `idx`."""
        row = bitmask[idx]
        self.guide.write_mask_into(row.data_ptr(), row.numel(), row.element_size())

    def is_terminated(self) -> bool:
        """Return the finished flag observed on the previous call.

        Note that this method mutates state: each call records the guide's
        current finished flag for the next call to report.
        """
        finished_now = self.guide.is_finished()
        reported, self._prev_finished = self._prev_finished, finished_now
        return reported

    def reset(self):
        """Clear the counters and the delayed flag, then reset the guide."""
        self.num_processed_tokens = 0
        self._prev_finished = False
        self.guide.reset()
|
||||
|
||||
|
||||
def validate_structured_output_request_outlines(params: SamplingParams):
    """Validate a structured-output request for the Outlines backend.

    Exactly one constraint is examined, in the order regex, json, choice,
    grammar. Regex-like constraints are lowered to a regex and passed to
    `validate_regex_is_buildable`.

    Raises:
        ValueError: for an invalid/unserializable JSON schema, an
            unsupported regex, or any grammar specification.
    """
    so_params = params.structured_outputs
    if so_params is None:
        return

    if so_params.regex:
        validate_regex_is_buildable(so_params.regex)
        return

    if so_params.json:
        spec = so_params.json
        if isinstance(spec, str):
            try:
                # Parse only to make sure the schema is valid JSON.
                json.loads(spec)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
            schema = spec
        else:
            try:
                schema = json.dumps(spec)
            except Exception as e:
                raise ValueError(
                    f"Error serializing structured outputs jsonschema: {e}"
                ) from e
        validate_regex_is_buildable(json_schema.build_regex_from_schema(schema))
        return

    if so_params.choice:
        # Choices are matched literally: escape, then join as alternatives.
        alternatives = "|".join(
            [regex_escape(str(choice)) for choice in so_params.choice]
        )
        validate_regex_is_buildable("(" + alternatives + ")")
        return

    if so_params.grammar:
        raise ValueError(
            "Outlines structured outputs backend "
            "does not support grammar specifications"
        )
|
||||
|
||||
|
||||
def _prefix_needs_context(parsed) -> bool:
|
||||
"""Return True if there's a look-around/anchor before any consumer."""
|
||||
|
||||
def subpattern_consumes(parsed) -> bool:
|
||||
"""Return True if subpattern can consume at least one character."""
|
||||
tokens = parsed.data if hasattr(parsed, "data") else parsed
|
||||
for ttype, tval in tokens:
|
||||
# literal, character class, or dot always consumes
|
||||
if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
|
||||
return True
|
||||
# quantified subpattern: check inner pattern
|
||||
elif ttype == sre_parse.MAX_REPEAT:
|
||||
_, mx, sub = tval
|
||||
if mx != 0 and subpattern_consumes(sub):
|
||||
return True
|
||||
# alternation: if any branch consumes, the whole does
|
||||
elif ttype == sre_parse.BRANCH:
|
||||
_, branches = tval
|
||||
if any(subpattern_consumes(br) for br in branches):
|
||||
return True
|
||||
# grouped subpattern: recurse into its contents
|
||||
elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(tval[3]):
|
||||
return True
|
||||
# No consumers, return False
|
||||
return False
|
||||
|
||||
tokens = parsed.data if hasattr(parsed, "data") else parsed
|
||||
for ttype, tval in tokens:
|
||||
# Direct anchors or look-around
|
||||
if ttype == sre_parse.AT or ttype in (
|
||||
sre_constants.ASSERT,
|
||||
sre_constants.ASSERT_NOT,
|
||||
):
|
||||
return True
|
||||
|
||||
# Nested subpattern: check
|
||||
if ttype == sre_parse.SUBPATTERN:
|
||||
# tval: (group, add_flags, del_flags, subpattern)
|
||||
if _prefix_needs_context(tval[3]):
|
||||
return True
|
||||
if subpattern_consumes(tval[3]):
|
||||
return False
|
||||
|
||||
# if any branch has a prefix anchor => True,
|
||||
# else if at least one branch consumes => prefix ends => False
|
||||
elif ttype == sre_parse.BRANCH:
|
||||
saw_consumer = False
|
||||
for br in tval[1]:
|
||||
if _prefix_needs_context(br):
|
||||
return True
|
||||
if subpattern_consumes(br):
|
||||
saw_consumer = True
|
||||
if saw_consumer:
|
||||
return False
|
||||
|
||||
# Immediate consumer tokens
|
||||
elif ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY):
|
||||
return False
|
||||
|
||||
# if subpattern has anchor => True, if it can consume => stop
|
||||
elif ttype == sre_parse.MAX_REPEAT:
|
||||
if _prefix_needs_context(tval[2]):
|
||||
return True
|
||||
if subpattern_consumes(tval[2]):
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _check_unsupported(parsed) -> None:
|
||||
"""Check for regex features unsupported by regex-automata"""
|
||||
tokens = parsed.data if hasattr(parsed, "data") else parsed
|
||||
for ttype, tval in tokens:
|
||||
# backreference
|
||||
if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS):
|
||||
raise ValueError("Backreferences are unsupported.")
|
||||
|
||||
# look-around assertion
|
||||
elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT):
|
||||
raise ValueError("Look-Around assertion are unsupported.")
|
||||
|
||||
# unicode word boundaries
|
||||
elif ttype == sre_parse.AT:
|
||||
if tval in (sre_constants.AT_BOUNDARY, sre_constants.AT_NON_BOUNDARY):
|
||||
raise ValueError("Unicode word boundaries are unsupported.")
|
||||
|
||||
elif ttype == sre_parse.BRANCH:
|
||||
# tval is (None, branches)
|
||||
for branch in tval[1]:
|
||||
_check_unsupported(branch)
|
||||
|
||||
# tval is (min, max, subpattern)
|
||||
elif ttype == sre_parse.MAX_REPEAT:
|
||||
_check_unsupported(tval[2])
|
||||
|
||||
|
||||
def validate_regex_is_buildable(pattern: str) -> None:
    """Validate that `pattern` can be used for structured outputs.

    The pattern must not use features unsupported by the `regex-automata`
    crate (the outlines_core regex engine) and must have a universal start
    state. The definition of universal start state used can be found at:
    https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state

    Args:
        pattern: the regular expression to validate.

    Raises:
        ValueError: if the pattern cannot be parsed, uses an unsupported
            feature, or needs context before its first consumed character.
    """
    try:
        parsed = sre_parse.parse(pattern)
    except sre_constants.error as e:
        raise ValueError(f"Error parsing regex: {e}") from e

    try:
        _check_unsupported(parsed)
    except ValueError as e:
        raise ValueError(
            f"Regex uses unsupported feature for structured outputs: {e}. "
            "Only basic matching constructs are supported—lookarounds, "
            "backreferences, and unicode boundaries are not."
        ) from e

    if _prefix_needs_context(parsed):
        # Adjacent literals are concatenated verbatim, so every sentence
        # carries its own trailing separator.
        raise ValueError(
            "Regex does not have a anchored universal start state. "
            "This means that the Regex uses anchors (^) or look-arounds "
            "in a way which requires context before any token is matched. "
            "structured outputs needs regexes that can match without needing "
            "that context. Try rewriting the pattern without using these "
            f"constructs. Pattern:\n{pattern}"
        )
|
||||
@@ -1,136 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
else:
|
||||
VllmConfig = object
|
||||
AnyTokenizer = object
|
||||
|
||||
|
||||
class StructuredOutputOptions(enum.Enum):
    """Kinds of structured-output constraint a request can carry."""

    # Values are assigned by `enum.auto()` in declaration order (1..6).
    JSON = enum.auto()
    JSON_OBJECT = enum.auto()
    REGEX = enum.auto()
    GRAMMAR = enum.auto()
    CHOICE = enum.auto()
    STRUCTURAL_TAG = enum.auto()


# (constraint kind, constraint specification as a string)
StructuredOutputKey = tuple[StructuredOutputOptions, str]
|
||||
|
||||
|
||||
class StructuredOutputGrammar(ABC):
    """Request-level backend for structured output requests."""

    @abstractmethod
    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Decide whether `tokens` are accepted for the given request.

        Args:
            request_id (str): The unique identifier for the request.
            tokens (list[int]): Token IDs to evaluate.

        Returns:
            bool: True if the tokens are accepted, False otherwise.
        """

    @abstractmethod
    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Check `tokens` against the grammar without advancing the FSM.

        Args:
            tokens (list[int]): Token IDs to validate.

        Returns:
            list[int]: The accepted token IDs; always a prefix of the
                input, and empty if none are accepted.
        """

    @abstractmethod
    def rollback(self, num_tokens: int) -> None:
        """Roll the grammar state back by `num_tokens` tokens.

        The counters for the number of processed tokens are reverted too.

        Args:
            num_tokens (int): The number of tokens to roll back.
        """

    @abstractmethod
    def fill_bitmask(self, bitmask: "torch.Tensor", batch_index: int) -> None:
        """Fill the bitmask entry belonging to one batch index.

        Args:
            bitmask (torch.Tensor): The bitmask to fill.
            batch_index (int): The index in the bitmask to fill.
        """

    @abstractmethod
    def is_terminated(self) -> bool:
        """Tell whether the structured output process has terminated.

        Returns:
            bool: True if the process is terminated, False otherwise.
        """

    @abstractmethod
    def reset(self):
        """Reset the state of the structured output grammar."""
|
||||
|
||||
|
||||
@dataclass
class StructuredOutputBackend(ABC):
    """Engine-level backend for structured output requests."""

    # Configuration, tokenizer and vocabulary size shared by every grammar
    # the backend compiles.
    vllm_config: VllmConfig
    tokenizer: AnyTokenizer
    vocab_size: int

    @abstractmethod
    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Compile a grammar specification into a structured output grammar.

        Args:
            request_type (StructuredOutputOptions): The type of structured
                output request.
            grammar_spec (str): The grammar specification to compile.

        Returns:
            StructuredOutputGrammar: The compiled structured output grammar.
        """

    @abstractmethod
    def allocate_token_bitmask(self, max_num_seqs: int) -> "torch.Tensor":
        """Allocate a token bitmask for up to `max_num_seqs` sequences.

        Args:
            max_num_seqs (int): The maximum number of sequences for which
                to allocate the bitmask.
        """

    @abstractmethod
    def destroy(self):
        """Perform backend-specific cleanup."""
|
||||
@@ -1,362 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions,
|
||||
)
|
||||
from vllm.v1.structured_output.utils import (
|
||||
choice_as_grammar,
|
||||
convert_lark_to_ebnf,
|
||||
grammar_is_likely_lark,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class XgrammarBackend(StructuredOutputBackend):
    """Engine-level structured-output backend built on xgrammar."""

    def __post_init__(self):
        self.disable_any_whitespace = (
            self.vllm_config.structured_outputs_config.disable_any_whitespace
        )

        if not isinstance(self.tokenizer, MistralTokenizer):
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                self.tokenizer,
                vocab_size=self.vocab_size,
            )
        else:
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            stop_token_ids = [self.tokenizer.eos_token_id]

            # not self.tokenizer.vocab_size as self.tokenizer.vocab
            # collapses all decoded errors into a single token.
            self.vocab_size = len(self.tokenizer.vocab)
            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                encoded_vocab=self.tokenizer.vocab,
                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
                vocab_type=(
                    xgr.VocabType.RAW
                    if self.tokenizer.is_tekken
                    else xgr.VocabType.BYTE_FALLBACK
                ),
                vocab_size=self.vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )

        cache_limit_bytes = vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024
        self.compiler = xgr.GrammarCompiler(
            tokenizer_info,
            max_threads=8,
            cache_enabled=True,
            cache_limit_bytes=cache_limit_bytes,
        )

        # Matchers must be able to undo one step's worth of speculated tokens.
        spec_config = self.vllm_config.speculative_config
        if spec_config is None:
            self.num_speculative_tokens = 0
        else:
            self.num_speculative_tokens = spec_config.num_speculative_tokens

    def compile_grammar(
        self, request_type: StructuredOutputOptions, grammar_spec: str
    ) -> StructuredOutputGrammar:
        """Compile `grammar_spec` and wrap it in an `XgrammarGrammar`.

        Raises:
            ValueError: if `request_type` is not one of JSON, JSON_OBJECT,
                GRAMMAR, REGEX or STRUCTURAL_TAG.
        """
        compiler = self.compiler

        if request_type == StructuredOutputOptions.JSON:
            ctx = compiler.compile_json_schema(
                grammar_spec, any_whitespace=not self.disable_any_whitespace
            )
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            ctx = compiler.compile_json_schema(
                '{"type": "object"}', any_whitespace=not self.disable_any_whitespace
            )
        elif request_type == StructuredOutputOptions.GRAMMAR:
            ctx = compiler.compile_grammar(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            ctx = compiler.compile_regex(grammar_spec)
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            s_tag = json.loads(grammar_spec)
            if "structures" not in s_tag:
                ctx = compiler.compile_structural_tag(grammar_spec)
            else:
                # Falling back to deprecated method of compiling structural tag
                tags = [
                    xgr.StructuralTagItem(
                        begin=structure["begin"],
                        schema=json.dumps(structure["schema"]),
                        end=structure["end"],
                    )
                    for structure in s_tag["structures"]
                ]
                ctx = compiler.compile_structural_tag(tags, s_tag["triggers"])
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
            )
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})"
            )

        matcher = xgr.GrammarMatcher(
            ctx,
            max_rollback_tokens=self.num_speculative_tokens,
        )
        return XgrammarGrammar(
            matcher=matcher,
            vocab_size=self.vocab_size,
            ctx=ctx,
        )

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Delegate bitmask allocation to xgrammar for this vocabulary size."""
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)

    def destroy(self):
        """Drop the reference to the grammar compiler."""
        del self.compiler
|
||||
|
||||
|
||||
@dataclass
|
||||
class XgrammarGrammar(StructuredOutputGrammar):
|
||||
# NOTE: This would be a generic-enough class for
|
||||
# supporting different backends, in the future.
|
||||
# For now, just xgrammar.
|
||||
#
|
||||
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
|
||||
# for jump-forward decoding
|
||||
|
||||
vocab_size: int
|
||||
matcher: xgr.GrammarMatcher = field(hash=False)
|
||||
ctx: xgr.CompiledGrammar = field(hash=False)
|
||||
num_processed_tokens: int = field(
|
||||
default_factory=lambda: 0, repr=False, hash=False, init=False
|
||||
)
|
||||
_is_terminated: bool = field(default=False, repr=False, hash=False)
|
||||
|
||||
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
|
||||
"""Accepts a list of tokens and advances the FSM.
|
||||
|
||||
Returns True if the FSM was advanced successfully.
|
||||
Returns False if the FSM failed to advance.
|
||||
"""
|
||||
if self._is_terminated:
|
||||
return False
|
||||
for token in tokens:
|
||||
if not self.matcher.accept_token(token):
|
||||
logger.error(
|
||||
"Failed to advance FSM for request %s "
|
||||
"for tokens %s. Please file an issue.",
|
||||
request_id,
|
||||
token,
|
||||
)
|
||||
return False
|
||||
self.num_processed_tokens += 1
|
||||
self._is_terminated = self.matcher.is_terminated()
|
||||
return True
|
||||
|
||||
def validate_tokens(self, tokens: list[int]) -> list[int]:
|
||||
"""Checks if the list of tokens are accepted by the FSM in sequence.
|
||||
Will not advance the FSM.
|
||||
|
||||
Returns the prefix list of tokens that are accepted by the FSM.
|
||||
"""
|
||||
accepted_tokens = []
|
||||
for token in tokens:
|
||||
if self.matcher.accept_token(token):
|
||||
accepted_tokens.append(token)
|
||||
else:
|
||||
break
|
||||
if len(accepted_tokens) > 0:
|
||||
# Rollback the FSM to the initial state
|
||||
self.matcher.rollback(len(accepted_tokens))
|
||||
return accepted_tokens
|
||||
|
||||
def rollback(self, num_tokens: int) -> None:
|
||||
self.matcher.rollback(num_tokens)
|
||||
self.num_processed_tokens -= num_tokens
|
||||
self._is_terminated = self.matcher.is_terminated()
|
||||
|
||||
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
|
||||
self.matcher.fill_next_token_bitmask(bitmask, idx)
|
||||
|
||||
def is_terminated(self) -> bool:
|
||||
return self._is_terminated
|
||||
|
||||
def reset(self):
|
||||
self.num_processed_tokens = 0
|
||||
self.matcher.reset()
|
||||
|
||||
|
||||
# String `format` values accepted by the JSON-schema feature check in this
# module.
# cf https://github.com/mlc-ai/xgrammar/blob/a32ac892676d2eedc0327416105b9b06edfb94b2/cpp/json_schema_converter.cc
STRING_SUPPORTED_FORMATS = {
    # date and time
    "date",
    "date-time",
    "duration",
    "time",
    # network identifiers
    "email",
    "hostname",
    "ipv4",
    "ipv6",
    # resource identifiers
    "uri",
    "uri-reference",
    "uri-template",
    "uuid",
    # JSON pointers
    "json-pointer",
    "relative-json-pointer",
}
|
||||
|
||||
|
||||
def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    The schema is searched recursively through nested dicts and through
    dicts contained in lists.
    """
    numeric_types = ("integer", "number")
    array_keywords = ("uniqueItems", "contains", "minContains", "maxContains")
    object_keywords = (
        "minProperties",
        "maxProperties",
        "propertyNames",
        "patternProperties",
    )

    def check_object(obj: dict[str, Any]) -> bool:
        if not isinstance(obj, dict):
            return False

        declared_type = obj.get("type")

        # `multipleOf` on numeric types
        if declared_type in numeric_types and "multipleOf" in obj:
            return True

        # Unsupported keywords for arrays
        if declared_type == "array" and any(key in obj for key in array_keywords):
            return True

        # String formats outside the supported set
        if (
            declared_type == "string"
            and "format" in obj
            and obj["format"] not in STRING_SUPPORTED_FORMATS
        ):
            return True

        # Unsupported keywords for objects
        if declared_type == "object" and any(key in obj for key in object_keywords):
            return True

        # Descend into nested dicts, and into dicts held directly in lists.
        for child in obj.values():
            if isinstance(child, dict):
                candidates = (child,)
            elif isinstance(child, list):
                candidates = child
            else:
                continue
            for candidate in candidates:
                if isinstance(candidate, dict) and check_object(candidate):
                    return True

        return False

    return check_object(schema)
|
||||
|
||||
|
||||
def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    The constraints are checked in the order regex, choice, json, grammar,
    structural tag. Two of the checks rewrite the params in place: a
    choice list is replaced by an equivalent EBNF grammar, and a Lark-like
    grammar is converted to EBNF.

    Raises ValueError if the request is not supported.
    """
    if sampling_params.structured_outputs is None:
        return

    so_params = sampling_params.structured_outputs

    if so_params.regex:
        try:
            xgr.Grammar.from_regex(so_params.regex)
        except Exception as err:
            raise ValueError(
                f"Failed to transform regex into a grammar: {err}"
            ) from err

    if so_params.choice:
        choice_grammar = choice_as_grammar(so_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            # f-string so the underlying error is actually interpolated.
            raise ValueError(
                f"Failed to transform choices into a grammar: {err}"
            ) from err
        # From here on the request is handled as a plain grammar request.
        so_params.choice = None
        so_params.grammar = choice_grammar
        return

    if so_params.json:
        if isinstance(so_params.json, str):
            try:
                schema = json.loads(so_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = so_params.json

        try:
            xgr.Grammar.from_json_schema(schema)
        except Exception as err:
            raise ValueError(
                f"Failed to transform json schema into a grammar: {err}"
            ) from err

        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError(
                "The provided JSON schema contains features not supported by xgrammar."
            )
        return

    if so_params.grammar:
        if grammar_is_likely_lark(so_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                so_params.grammar = convert_lark_to_ebnf(so_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF. "
                ) from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(so_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
        return

    if so_params.structural_tag:
        try:
            s_tag = json.loads(so_params.structural_tag)

            # Using the deprecated method of compiling structural tag
            if "structures" in s_tag:
                tags = [
                    xgr.StructuralTagItem(
                        begin=s["begin"],
                        schema=json.dumps(s["schema"]),
                        end=s["end"],
                    )
                    for s in s_tag["structures"]
                ]
                xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
            else:
                xgr.Grammar.from_structural_tag(so_params.structural_tag)
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e
|
||||
@@ -1,94 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import functools
|
||||
import json
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures._base import TimeoutError
|
||||
from typing import cast
|
||||
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputKey,
|
||||
StructuredOutputOptions,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class StructuredOutputRequest:
    """Structured-output state attached to a single request.

    `_grammar` holds either a compiled grammar or a `Future` that will
    produce one; the `grammar` / `is_grammar_ready` properties resolve the
    future opportunistically.
    """

    params: StructuredOutputsParams
    _grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None
    reasoning_ended: bool | None = None

    @staticmethod
    def from_sampling_params(
        sampling_params: SamplingParams | None,
    ) -> "StructuredOutputRequest | None":
        """Build a request from sampling params, or return None.

        None is returned when there are no sampling params, no structured
        output params, or the params carry no constraint at all.
        """
        if sampling_params is None:
            return None
        params = sampling_params.structured_outputs
        if not params or params.all_constraints_none():
            return None
        return StructuredOutputRequest(params=params)

    def _check_grammar_completion(self) -> bool:
        """Resolve a pending grammar future; return False if still pending."""
        # NOTE: We have to lazy import to gate circular imports
        from vllm.v1.request import RequestStatus

        pending = self._grammar
        if not isinstance(pending, Future):
            return True

        try:
            # We will check whether the future is ready within 100 us
            self._grammar = pending.result(timeout=0.0001)
        except TimeoutError:
            return False

        # NOTE(review): `status` is not a declared field of this dataclass;
        # it is simply set on the instance here.
        self.status = RequestStatus.WAITING
        return True

    @property
    def is_grammar_ready(self) -> bool:
        """True once the grammar is available (not a pending future)."""
        return self._check_grammar_completion()

    @property
    def grammar(self) -> StructuredOutputGrammar | None:
        """The compiled grammar, or None while compilation is pending."""
        if not self._check_grammar_completion():
            return None
        return cast(StructuredOutputGrammar | None, self._grammar)

    @grammar.setter
    def grammar(
        self, grammar: StructuredOutputGrammar | Future[StructuredOutputGrammar]
    ) -> None:
        self._grammar = grammar

    @functools.cached_property
    def structured_output_key(self) -> StructuredOutputKey:
        """Key for these params, computed once and then cached."""
        return get_structured_output_key(self.params)
|
||||
|
||||
|
||||
def get_structured_output_key(params: StructuredOutputsParams) -> StructuredOutputKey:
    """Map structured-output params to a `(kind, specification)` key.

    The first constraint that is set wins, checked in the order json,
    json_object, regex, choice, grammar, structural_tag. Non-string json
    and choice values are serialized with `json.dumps`.

    Raises:
        ValueError: if no constraint is set.
    """

    def _as_text(value) -> str:
        # Strings are used verbatim; anything else is JSON-encoded.
        return value if isinstance(value, str) else json.dumps(value)

    if params.json is not None:
        return StructuredOutputOptions.JSON, _as_text(params.json)
    if params.json_object:
        return StructuredOutputOptions.JSON_OBJECT, ""
    if params.regex is not None:
        return StructuredOutputOptions.REGEX, params.regex
    if params.choice is not None:
        return StructuredOutputOptions.CHOICE, _as_text(params.choice)
    if params.grammar is not None:
        return StructuredOutputOptions.GRAMMAR, params.grammar
    if params.structural_tag is not None:
        return StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag
    raise ValueError("No valid structured output parameter found")
|
||||
@@ -1,469 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import importlib.metadata
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import torch
|
||||
from cachetools import LRUCache
|
||||
from diskcache import Cache
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import outlines_core as oc
|
||||
import transformers.file_utils as file_utils
|
||||
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
|
||||
import xgrammar as xgr
|
||||
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
oc = LazyLoader("oc", globals(), "outlines_core")
|
||||
file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
|
||||
tokenization_gpt2 = LazyLoader(
|
||||
"tokenization_gpt2",
|
||||
globals(),
|
||||
"transformers.models.gpt2.tokenization_gpt2",
|
||||
)
|
||||
|
||||
AnyTokenizer = object
|
||||
SchedulerOutput = object
|
||||
InputBatch = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CACHE = None
|
||||
|
||||
|
||||
def apply_grammar_bitmask(
    scheduler_output: SchedulerOutput,
    grammar_output: GrammarOutput,
    input_batch: InputBatch,
    logits: torch.Tensor,
) -> None:
    """
    Apply grammar bitmask to output logits of the model with xgrammar function.

    Args:
        scheduler_output (SchedulerOutput): The result of engine scheduling.
        grammar_output (GrammarOutput): The compacted bitmask together with
            the ids of the structured output requests it belongs to.
        input_batch (InputBatch): The input of model runner.
        logits (torch.Tensor): The output logits of model forward.
    """
    # Serialization of np.ndarray is much more efficient than a tensor,
    # so we receive it in that format.
    compact_bitmask = grammar_output.grammar_bitmask
    structured_ids = grammar_output.structured_output_request_ids
    spec_tokens = scheduler_output.scheduled_spec_decode_tokens

    # The bitmask from the scheduler only holds rows for structured output
    # requests, and their order is not guaranteed to match the order of the
    # requests in the gpu runner's batch, so the rows are re-sorted below.

    # Pass 1: walk the batch in index order and find the first logit row of
    # every structured output request. Logit rows are offset by the number
    # of speculative tokens scheduled for all earlier requests.
    first_logit_row: dict[str, int] = {}
    spec_offset = 0
    batch_order = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1])
    for req_id, batch_index in batch_order:
        row = batch_index + spec_offset
        spec_offset += len(spec_tokens.get(req_id, []))
        if req_id in structured_ids:
            first_logit_row[req_id] = row

    # Pass 2: copy every request's rows (1 + its speculative tokens) from
    # the compact bitmask to its position in a logits-aligned bitmask.
    # Rows that stay untouched keep the all-ones fill value.
    aligned_bitmask = np.full(
        shape=(logits.shape[0], compact_bitmask.shape[1]),
        fill_value=-1,
        dtype=compact_bitmask.dtype,
    )
    out_indices = []
    source_row = 0
    for req_id in structured_ids:
        rows_for_request = 1 + len(spec_tokens.get(req_id, []))
        if req_id in first_logit_row:
            target_row = first_logit_row[req_id]
            for i in range(rows_for_request):
                aligned_bitmask[target_row + i] = compact_bitmask[source_row + i]
                out_indices.append(target_row + i)
        # The compact bitmask is consumed even for requests not in the batch.
        source_row += rows_for_request

    # Copy async to device as tensor.
    device_bitmask = torch.from_numpy(aligned_bitmask).to(
        logits.device, non_blocking=True
    )

    # When every logits row is covered, the bitmask is already aligned with
    # the logits and no indices need to be passed to the kernel.
    index_tensor = None
    if len(out_indices) != logits.shape[0]:
        # xgrammar expects a python list of indices but it will actually work with
        # a tensor. If we copy the tensor ourselves here we can do it in a
        # non_blocking manner and there should be no cpu sync within xgrammar.
        index_tensor = torch.tensor(
            out_indices, dtype=torch.int32, device="cpu", pin_memory=True
        )
        index_tensor = index_tensor.to(logits.device, non_blocking=True)

    xgr.apply_token_bitmask_inplace(logits, device_bitmask, indices=index_tensor)
|
||||
|
||||
|
||||
class OutlinesVocabulary:
    """
    Wrapper class for `outlines_core.Vocabulary`,
    which allows us to store a hash with the vocabulary.

    The hash is computed once, at construction time, and kept in `_hash`.
    """

    def __init__(self, vocabulary: oc.Vocabulary) -> None:
        # Actual vocabulary object
        self.inner = vocabulary
        # The hash is used as a cache key, so it is derived from a SHA-256
        # digest of the vocabulary's repr instead of the built-in hash():
        # parsing the hex digest always yields a non-negative integer.
        digest = hashlib.sha256(repr(vocabulary).encode("utf-8")).hexdigest()
        self._hash = int(digest, 16)
|
||||
|
||||
|
||||
def get_outlines_cache_path() -> str:
    """Return the directory path to use for the outlines cache.

    Candidates are tried in this order and the first usable one wins:

    1. ``OUTLINES_CACHE_DIR``, returned unchanged.
    2. ``XDG_CACHE_HOME`` with ``.cache/outlines`` appended.
    3. ``~/.cache/outlines``, when the expanded home directory exists and
       is not ``/``.
    4. ``.cache/outlines`` under the system temporary directory.

    The returned directory is not created by this function.
    """
    explicit_dir = os.getenv("OUTLINES_CACHE_DIR")
    if explicit_dir:
        # OUTLINES_CACHE_DIR takes precedence over everything else.
        return explicit_dir

    xdg_root = os.getenv("XDG_CACHE_HOME")
    if xdg_root:
        return os.path.join(xdg_root, ".cache", "outlines")

    # A home of "/" usually means we are inside a container without a real
    # user, where writing to the root would be problematic. expanduser does
    # not guarantee that the path exists either, so check that as well.
    user_home = os.path.expanduser("~")
    if user_home != "/" and os.path.isdir(user_home):
        # Default Unix location: ~/.cache/outlines
        return os.path.join(user_home, ".cache", "outlines")

    import tempfile

    # No usable home directory: fall back to the temp directory.
    return os.path.join(tempfile.gettempdir(), ".cache", "outlines")
|
||||
|
||||
|
||||
def get_outlines_cache():
    """Return the cache instance to be used for index caching.

    When ``envs.VLLM_V1_USE_OUTLINES_CACHE`` is falsy, a fresh in-memory
    ``LRUCache(maxsize=128)`` is returned. Otherwise a ``Cache`` rooted at
    ``get_outlines_cache_path()`` is returned; it is cleared and re-stamped
    whenever the ``outlines_core`` version stored under ``"__version__"``
    differs from the installed one.
    """
    cache_dir = get_outlines_cache_path()

    if not envs.VLLM_V1_USE_OUTLINES_CACHE:
        return LRUCache(maxsize=128)

    logger.warning(
        "Enabling outlines cache. This is an unbounded on-disk "
        "cache. It may consume a lot of disk space and should "
        "not be used with untrusted clients."
    )
    disk_cache = Cache(cache_dir, eviction_policy="none", cull_limit=0)
    installed_version = importlib.metadata.version("outlines_core")

    # Entries written by a different outlines_core version are discarded.
    if disk_cache.get("__version__", None) != installed_version:
        disk_cache.clear()
        disk_cache.set("__version__", installed_version)
    return disk_cache
|
||||
|
||||
|
||||
# Matches the `<0xXX>` token spelling (two uppercase hex digits) that
# Llama-like tokenizers use for incomplete byte sequences.
re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
# Matches a run of U+FFFD replacement characters with at most six other
# characters on each side. The character is written as an escape so that the
# pattern cannot be corrupted when this file is re-encoded.
re_replacement_seq = re.compile(r"^.{0,6}\ufffd+.{0,6}$")
|
||||
|
||||
|
||||
def _reduced_vocabulary(
    tokenizer: AnyTokenizer,
    eos_token_id: int,
) -> dict[bytes, list[int]]:
    """Create a map from vocabulary tokens to lists of equivalent token ids.

    Special tokens, tokens that convert to an empty string and the token
    whose id equals `eos_token_id` are left out of the result.

    Args:
        tokenizer: Tokenizer providing `get_vocab()`, `all_special_tokens`
            and `convert_tokens_to_string()`.
        eos_token_id: Id of the end-of-sequence token, which is excluded.

    Returns:
        A dict mapping each token's byte representation to the ids of all
        tokens that share that representation.

    Raises:
        RuntimeError: If a token containing U+FFFD is neither a `<0xXX>`
            byte token nor fully expressible through the GPT-2
            byte-to-unicode table.
    """

    unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()}
    # Fetch the special tokens once instead of re-evaluating the attribute
    # and scanning the resulting list for every vocabulary entry.
    special_tokens = frozenset(tokenizer.all_special_tokens)  # type: ignore

    def convert_token_to_string(token: str) -> str:
        string = tokenizer.convert_tokens_to_string([token])

        # A hack to handle missing spaces to HF's Llama tokenizers
        if (
            type(token) is str and token.startswith(file_utils.SPIECE_UNDERLINE)
        ) or token == "<0x20>":
            return " " + string

        return string

    vocabulary: dict[bytes, list[int]] = {}
    for token, token_idx in tokenizer.get_vocab().items():
        if token in special_tokens:
            continue

        token_str = convert_token_to_string(token)
        if not token_str:
            # Tokens that convert to an empty string have no byte form to
            # key on.
            continue

        if isinstance(token, (bytes, bytearray)):
            # For BPE tokenizers where tokens are stored as bytes.

            # safe to ignore since token_str is of type (bytearray, bytes)
            # by this point.
            token_bytes = bytes(token_str)  # type: ignore[arg-type]

        elif "\ufffd" in token_str and not re_replacement_seq.match(token_str):
            # Handle tokens with invalid UTF-8 sequences.
            if re_llama_byte_token.match(token):
                # Llama-like tokenizers use <0xXX> for incomplete sequences.
                token_bytes = bytes([int(token[3:5], 16)])
            else:
                # GPT2 tokenizers: map each byte back using unicode_to_bytes
                byte_vals = [unicode_to_bytes.get(c) for c in token]
                if None in byte_vals:
                    raise RuntimeError(
                        f"Cannot convert token `{token}`"
                        f" ({token_idx}) to bytes: {token_str}"
                    )
                # safe to ignore, since if None in byte_vals,
                # an error is thrown.
                token_bytes = bytes(byte_vals)  # type: ignore[arg-type]
        else:
            token_bytes = token_str.encode("utf-8")

        if token_idx != eos_token_id:
            vocabulary.setdefault(token_bytes, []).append(token_idx)

    return vocabulary
|
||||
|
||||
|
||||
def get_outlines_vocabulary(tokenizer: AnyTokenizer) -> oc.Vocabulary:
    """Get the vocabulary object for a given tokenizer.

    The result is an `OutlinesVocabulary` wrapping an `oc.Vocabulary`. It is
    stored on the tokenizer as `_outlines_vocabulary`, so later calls with
    the same tokenizer return the stored object.

    Raises:
        ValueError: If the tokenizer has no usable `eos_token_id`, or if an
            `AttributeError` is raised while building the vocabulary.
    """
    try:
        # Reuse the vocabulary attached by an earlier call, if there is one.
        return tokenizer._outlines_vocabulary  # type: ignore
    except AttributeError:
        pass

    try:
        eos_token_id = getattr(tokenizer, "eos_token_id", None)
        if eos_token_id is None:
            raise ValueError(
                "Error during structured outputs setup for outlines: "
                f"Tokenizer ({type(tokenizer)}) has no `eos_token_id` property, "
                "but `eos_token_id` is required for structured outputs to "
                "work properly."
            )

        token_map = _reduced_vocabulary(
            tokenizer,
            eos_token_id,  # type: ignore
        )
        wrapped = OutlinesVocabulary(oc.Vocabulary(eos_token_id, token_map))
        tokenizer._outlines_vocabulary = wrapped  # type: ignore
        return wrapped
    except AttributeError as e:
        raise ValueError(
            f"Cannot get the vocabulary of the tokenizer "
            f"({type(tokenizer)}). The tokenizer should have a "
            "get_vocab method."
        ) from e
|
||||
|
||||
|
||||
def grammar_is_likely_lark(grammar_str: str) -> bool:
    """
    Guess whether a grammar is written in Lark syntax.

    The grammar is treated as Lark unless some line, after `#` and `//`
    comments are stripped, contains the EBNF rule operator `::=`.

    Args:
        grammar_str: Input grammar string

    Returns:
        bool: True if grammar appears to be in Lark format, False otherwise
        (including when the input is empty or not a string)

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
    """
    if not (grammar_str and isinstance(grammar_str, str)):
        return False

    # A single EBNF rule definition anywhere is enough to rule out Lark.
    return not any(
        "::=" in re.sub(r"(#|//).*$", "", raw_line)
        for raw_line in grammar_str.split("\n")
    )
|
||||
|
||||
|
||||
def convert_lark_to_ebnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to EBNF format.

    EBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    Only `name: definition` lines and `| alternative` continuation lines are
    processed; any other non-empty line is ignored. `#` and `//` comments
    are stripped, single-quoted literals are rewritten with double quotes,
    and a `root` rule is prepended that points at `start` if such a rule is
    defined, otherwise at the first rule.

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in EBNF format

    Raises:
        ValueError: If the input is not a string, is blank, defines no
            rules, has mismatched quotes, starts an alternative before any
            rule, or references a rule that is never defined.

    Examples:
        >>> print(convert_lark_to_ebnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules: set[str] = set()
    referenced_rules: set[str] = set()
    output_lines: list[str] = []

    def clean_line(line: str) -> str:
        """Remove comments and whitespace from line."""
        return re.sub(r"(#|//).*$", "", line).strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set[str]:
        """Extract rule references from text."""
        # Remove quoted strings and special characters
        text = re.sub(r'"[^"]*"', "", text)
        text = re.sub(r"[+*?()|\[\]{}]", " ", text)
        return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", text))

    def normalize_definition(text: str, owner: str, line_num: int) -> str:
        """Validate quotes, switch to double quotes and record references."""
        check_quotes(text, owner, line_num)
        text = re.sub(r"'([^']*)'", r'"\1"', text)
        referenced_rules.update(extract_references(text))
        return text.strip()

    # First pass: Find root rule and collect rule names
    lines = [clean_line(line) for line in grammar_str.split("\n")]
    first_rule = None

    for line in lines:
        if not line or line.startswith("|") or ":" not in line:
            continue

        # A leading "?" on the rule name is dropped.
        name = line.split(":", 1)[0].strip().strip("?")
        defined_rules.add(name)
        # The first rule is the root unless a rule named "start" exists.
        if first_rule is None or name == "start":
            first_rule = name

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives
    current_rule = None
    current_definition: list[str] = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if line.startswith("|"):
                if not current_rule:
                    raise ValueError(
                        f"Alternative '|' on line {line_num} "
                        "without a preceding rule definition"
                    )

                current_definition.append(
                    normalize_definition(
                        line[1:].strip(),
                        f"alternative for rule '{current_rule}'",
                        line_num,
                    )
                )

            elif ":" in line:
                # Save previous rule if exists
                if current_rule:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}"
                    )

                # Process new rule
                name, definition = line.split(":", 1)
                current_rule = name.strip().strip("?")
                current_definition = [
                    normalize_definition(
                        definition, f"rule '{current_rule}'", line_num
                    )
                ]

        except ValueError as e:
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists
    if current_rule:
        output_lines.append(f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined
    undefined_rules = referenced_rules - defined_rules - {"root"}
    if undefined_rules:
        raise ValueError(
            f"Referenced rules are not defined: {', '.join(sorted(undefined_rules))}"
        )

    return "\n".join(output_lines)
|
||||
|
||||
|
||||
def choice_as_grammar(choice: list[str]) -> str:
    """Build an EBNF grammar string from a list of literal choices.

    The result is a single `root` rule whose alternatives are the given
    strings as double-quoted literals, with backslashes and double quotes
    backslash-escaped.
    """

    def as_literal(option: str) -> str:
        """Return *option* as an escaped, double-quoted EBNF literal."""
        # Backslashes go first so that the ones added for the quotes are not
        # escaped a second time.
        escaped = option.replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    alternatives = " | ".join(map(as_literal, choice))
    return "root ::= " + alternatives
|
||||
Reference in New Issue
Block a user