[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
222
vllm/v1/structured_output/__init__.py
Normal file
222
vllm/v1/structured_output/__init__.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||
StructuredOutputGrammar)
|
||||
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.v1.request import Request
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StructuredOutputManager:
    """Engine-level manager for structured output requests.

    Lazily creates a single grammar backend (xgrammar or guidance), compiles
    grammars asynchronously on a thread pool, and builds the per-step token
    bitmasks that are shipped to the GPU workers for constrained sampling.
    """

    def __init__(self, vllm_config: VllmConfig):
        # Backend is created lazily by grammar_init() on first use.
        self.backend: Optional[StructuredOutputBackend] = None
        self.reasoner: Optional[ReasoningParser] = None
        self.vllm_config = vllm_config

        # Reusable bitmask tensor, allocated on first grammar_bitmask() call.
        self._grammar_bitmask: Optional[torch.Tensor] = None
        # -1 as int32 has all bits set, i.e. an "allow every token" row.
        self._full_mask = torch.tensor(-1, dtype=torch.int32)

        # The default max_workers if not specified is the number of CPUs * 5,
        # which is way too high since these tasks are CPU-bound, not I/O bound.
        # We also know we would never dominate CPU usage with just grammar
        # compilation, so we set it to half the number of CPUs.
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.vllm_config.model_config,
            scheduler_config=self.vllm_config.scheduler_config,
            lora_config=self.vllm_config.lora_config,
        ).get_lora_tokenizer(None)
        reasoning_backend = vllm_config.decoding_config.reasoning_backend
        if reasoning_backend:
            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
                reasoning_backend)
            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

    def grammar_init(self, request: Request) -> None:
        """Start asynchronous grammar compilation for *request*.

        A Future is stored in ``request.structured_output_request.grammar``
        (hence the type-ignore below); it resolves to the compiled grammar.
        """
        if request.structured_output_request is None:
            return

        if TYPE_CHECKING:
            assert request.sampling_params.guided_decoding is not None

        # Initialize the backend the first time it is needed.
        #
        # NOTE: We only support a single backend. We do NOT support different
        # backends on a per-request basis in V1 (for now, anyway...).
        if self.backend is None:
            backend = request.sampling_params.guided_decoding.backend
            vocab_size = self.vllm_config.model_config.get_vocab_size()
            if backend == "xgrammar":
                self.backend = XgrammarBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            elif backend == "guidance":
                self.backend = GuidanceBackend(
                    self.vllm_config,
                    tokenizer=self.tokenizer,
                    vocab_size=vocab_size,
                )
            else:
                raise ValueError(
                    f"Unsupported structured output backend: {backend}")

        grammar = self.executor.submit(self._async_create_grammar, request)
        request.structured_output_request.grammar = grammar  # type: ignore[assignment]

    def _async_create_grammar(
        self,
        request: Request,
    ) -> StructuredOutputGrammar:
        """Compile the grammar for *request*; runs on ``self.executor``."""
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]

        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures,
        # though it should be unlikely as we test that up front as well.
        request_type, grammar_spec = key

        assert self.backend is not None
        return self.backend.compile_grammar(request_type, grammar_spec)

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        scheduled_spec_decode_tokens: dict[str, list[int]],
    ) -> Optional[npt.NDArray[np.int32]]:
        """Build the batched token bitmask for this scheduling step.

        Returns None when no scheduled request uses structured output,
        otherwise a numpy array with one bitmask row per checked token
        position (rows are ordered by the request indices in
        *structured_output_request_ids*).
        """
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        max_num_spec_tokens = 0
        if self.vllm_config.speculative_config is not None:
            max_num_spec_tokens = \
                self.vllm_config.speculative_config.num_speculative_tokens

        if self._grammar_bitmask is None:
            assert self.backend is not None
            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs

            # Allocate a bitmask for each token needing to be checked:
            # one for each speculative position, and one more for the
            # bonus token / non-speculative token.
            self._grammar_bitmask = \
                self.backend.allocate_token_bitmask(
                    max_batch_size * (1 + max_num_spec_tokens))

        bitmask_tensor = self._grammar_bitmask
        # Generate a batched bitmask for all structured output requests.
        # When speculative decoding is enabled, we need to include multiple
        # masks for each request, one for each possible bonus token position.
        # These are stored inline in the tensor and unpacked by the gpu runner.
        cumulative_index = 0
        ordered_seq = sorted(structured_output_request_ids.items(),
                             key=lambda x: x[1])

        # Note that for thinking support, we will need to
        # reset the relevant part of the bitmask for consequent
        # request here.
        bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
            self._full_mask)

        # NOTE: This outer loop can likely be parallelized to improve
        # performance of bitmask generation for large batches.
        for req_id, _ in ordered_seq:
            request = requests[req_id]
            structured_output_request = request.structured_output_request

            if TYPE_CHECKING:
                assert structured_output_request is not None
                assert structured_output_request.grammar is not None
            # While the request is still inside its reasoning section we
            # leave its rows as the "allow all" mask filled above.
            apply_bitmask: bool = True
            if self.reasoner is not None:
                if structured_output_request.reasoning_ended is None:
                    structured_output_request.reasoning_ended = \
                        self.reasoner.is_reasoning_end(request.prompt_token_ids)
                apply_bitmask = structured_output_request.reasoning_ended

            state_advancements = 0
            # Trailing None marks the bonus / non-speculative position.
            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
            for i, token in enumerate(req_tokens):
                if apply_bitmask and not \
                        structured_output_request.grammar.is_terminated():
                    structured_output_request.grammar.fill_bitmask(
                        bitmask_tensor, cumulative_index)
                if token is not None:
                    # In order to generate the correct bitmask for each
                    # position in the speculative sequence, we advance
                    # the FSM state for each speculative token and rollback
                    # to restore the previous state when we are finished.
                    assert structured_output_request.grammar.accept_tokens(
                        req_id, [token])
                    state_advancements += 1
                cumulative_index += 1
            if state_advancements > 0:
                structured_output_request.grammar.rollback(state_advancements)

        if cumulative_index < bitmask_tensor.shape[0]:
            bitmask_tensor = bitmask_tensor[:cumulative_index]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()

    def should_advance(self, request: Request) -> bool:
        """Decide whether this request's FSM should consume the new tokens.

        With a reasoning parser configured, the FSM is held still while the
        request is inside its reasoning (thinking) section.
        """
        if not request.use_structured_output:
            return False

        # To determine whether we can advance the FSM.
        # Supports thinking usage where we skip the reasoning components.
        if TYPE_CHECKING:
            assert request.structured_output_request is not None
            assert request.structured_output_request.grammar is not None
        # by default, we should always advance
        # for cases that don't use thinking mode.
        if self.reasoner is not None:
            structured_req = request.structured_output_request

            if structured_req.reasoning_ended:
                return True

            # Check if reasoning ends in *this* step
            if self.reasoner.is_reasoning_end(request.all_token_ids):
                # Reasoning just ended, so we shouldn't advance until
                # the next pass
                structured_req.reasoning_ended = True

            return False
        else:
            return True

    def clear_backend(self) -> None:
        """Release backend resources (e.g. the grammar compiler)."""
        if self.backend is not None:
            self.backend.destroy()
|
||||
245
vllm/v1/structured_output/backend_guidance.py
Normal file
245
vllm/v1/structured_output/backend_guidance.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions)
|
||||
from vllm.v1.structured_output.request import get_structured_output_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import llguidance
|
||||
import llguidance.hf as llguidance_hf
|
||||
import llguidance.torch as llguidance_torch
|
||||
else:
|
||||
llguidance = LazyLoader("llguidance", globals(), "llguidance")
|
||||
llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
|
||||
llguidance_torch = LazyLoader("llguidance.torch", globals(),
|
||||
"llguidance.torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _walk_json_for_additional_properties(data: object):
    """Recursively force ``additionalProperties: False`` on schema nodes.

    Mutates *data* in place: any dict that declares ``properties`` or
    ``patternProperties`` but says nothing about ``additionalProperties``
    gets it set to False. Lists are traversed element by element.
    """
    if isinstance(data, list):
        for element in data:
            _walk_json_for_additional_properties(element)
    elif isinstance(data, dict):
        for nested in data.values():
            _walk_json_for_additional_properties(nested)
        declares_props = ('properties' in data
                          or 'patternProperties' in data)
        if declares_props and 'additionalProperties' not in data:
            data['additionalProperties'] = False


def process_for_additional_properties(
        guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
    """Return a copy of *guide_json* with additional properties disabled.

    Accepts either a JSON string or an already-parsed schema dict; the
    caller's dict is never mutated.
    """
    if isinstance(guide_json, str):
        parsed = json.loads(guide_json)
    else:
        # Deep copy so the in-place walk below can't touch the input.
        parsed = copy.deepcopy(guide_json)
    _walk_json_for_additional_properties(parsed)
    return parsed
|
||||
|
||||
|
||||
@dataclass
class GuidanceBackend(StructuredOutputBackend):
    """Structured-output backend built on llguidance."""

    def __post_init__(self):
        decoding_config = self.vllm_config.decoding_config
        self.disable_any_whitespace = decoding_config.disable_any_whitespace
        self.disable_additional_properties = \
            decoding_config.disable_additional_properties

        # llguidance keeps its own tokenizer wrapper around the HF tokenizer.
        self.ll_tokenizer = llguidance_hf.from_tokenizer(
            self.tokenizer, self.vocab_size)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Serialize *grammar_spec* and wrap it in a GuidanceGrammar."""
        self.serialized_grammar = serialize_guidance_grammar(
            request_type, grammar_spec, self.disable_any_whitespace,
            self.disable_additional_properties)

        matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            self.serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )

        grammar = GuidanceGrammar(
            ll_matcher=matcher,
            ll_tokenizer=self.ll_tokenizer,
            vocab_size=self.vocab_size,
        )

        # Surface any construction-time matcher error in the log right away.
        grammar.check_error()
        return grammar

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate a bitmask sized to the llguidance tokenizer's vocab."""
        return llguidance_torch.allocate_token_bitmask(
            max_num_seqs, self.ll_tokenizer.vocab_size)

    def destroy(self):
        """llguidance needs no explicit cleanup."""
        pass
|
||||
|
||||
|
||||
@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
    """Per-request llguidance matcher state."""

    ll_matcher: llguidance.LLMatcher
    ll_tokenizer: llguidance.LLTokenizer
    vocab_size: int
    printed_error: bool = False
    terminated: bool = False

    def check_error(self):
        """Emit the matcher's pending error to the log, at most once."""
        if self.printed_error:
            return
        matcher_error = self.ll_matcher.get_error()
        if matcher_error:
            self.printed_error = True
            logger.warning("LLMatcher error: %s", matcher_error)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the parser.

        Returns True if the parser was advanced successfully.
        Returns False if the parser failed to advance.
        """
        if self.ll_tokenizer.eos_token in tokens:
            self.terminated = True

        # A stopped matcher accepts anything further without advancing.
        if self.ll_matcher.is_stopped():
            return True

        # TODO - Add jump decoding support in the future:
        #   self.ll_matcher.compute_ff_bytes() - this should always work
        #   self.ll_matcher.compute_ff_tokens() - this only works for
        #   "canonical" tokenizers
        # For conversion between the two, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md
        advanced = self.ll_matcher.consume_tokens(tokens)
        self.check_error()
        return advanced

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the parser in
        sequence, without advancing it.

        Returns the prefix list of tokens that are accepted by the parser.
        """
        if not tokens or self.ll_matcher.is_stopped():
            return []

        accepted_count = self.ll_matcher.validate_tokens(tokens)
        self.check_error()
        return tokens[:accepted_count]

    def rollback(self, num_tokens: int) -> None:
        """Rewind the matcher by *num_tokens* tokens."""
        self.ll_matcher.rollback(num_tokens)
        self.check_error()

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the allowed-token mask for the next position into row *idx*.

        llguidance automatically returns an [EOS] mask if the matcher is
        stopped or otherwise in an error state.
        """
        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
        self.check_error()

    def is_terminated(self) -> bool:
        """True once an EOS token has been consumed."""
        return self.terminated

    def reset(self):
        # This method may be not needed anymore? TODO
        self.ll_matcher.reset()
|
||||
|
||||
|
||||
def serialize_guidance_grammar(
    request_type: StructuredOutputOptions,
    grammar_spec: Union[str, dict[str, Any]],
    disable_any_whitespace: bool = False,
    disable_additional_properties: bool = False,
) -> str:
    """Turn a structured-output request into an llguidance grammar string.

    Args:
        request_type: The kind of constraint (JSON, regex, choice, ...).
        grammar_spec: The raw spec, either a string or a parsed dict.
        disable_any_whitespace: Forbid flexible whitespace in JSON output.
        disable_additional_properties: Force additionalProperties: False
            on every object schema node.
    """

    def _schema_grammar(schema: Union[str, dict[str, Any]]) -> str:
        # Optionally tighten the schema before handing it to llguidance.
        if disable_additional_properties:
            schema = process_for_additional_properties(schema)
        return llguidance.LLMatcher.grammar_from_json_schema(
            schema,
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })

    if request_type == StructuredOutputOptions.JSON:
        return _schema_grammar(grammar_spec)

    if request_type == StructuredOutputOptions.JSON_OBJECT:
        # Any JSON object: the maximally permissive object schema.
        return llguidance.LLMatcher.grammar_from_json_schema(
            '{"type": "object"}',
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })

    if request_type == StructuredOutputOptions.STRUCTURAL_TAG:
        if isinstance(grammar_spec, str):
            s_tag = json.loads(grammar_spec)
        else:
            s_tag = grammar_spec
        triggers: list[str] = s_tag["triggers"]
        tags: list[llguidance.StructTag] = []
        for structure in s_tag["structures"]:
            begin: str = structure["begin"]
            # Each structure must be reachable from one of the triggers.
            trigger = next((t for t in triggers if begin.startswith(t)), None)
            if trigger is None:
                raise ValueError(
                    f"Trigger {begin} not found in triggers {triggers}")
            tags.append(
                llguidance.StructTag(
                    trigger=trigger,
                    begin=structure["begin"],
                    grammar=_schema_grammar(structure["schema"]),
                    end=structure["end"],
                ))
        if not tags:
            raise ValueError(
                "No structural tags found in the grammar spec.")
        return llguidance.StructTag.to_grammar(tags)

    # Remaining kinds map directly onto llguidance grammar types.
    simple_kinds = {
        StructuredOutputOptions.REGEX: "regex",
        StructuredOutputOptions.GRAMMAR: "grammar",
        StructuredOutputOptions.CHOICE: "choice",
    }
    grammar_kind = simple_kinds.get(request_type)
    if grammar_kind is None:
        logger.error("Validation should have already occurred. "
                     "Please file an issue.")
        raise ValueError("grammar is not of valid supported types. "
                         f"({request_type!s})")
    return llguidance.grammar_from(grammar_kind, grammar_spec)
|
||||
|
||||
|
||||
def validate_guidance_grammar(
        sampling_params: SamplingParams,
        tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
    """Raise ValueError if the request's grammar fails llguidance validation."""
    request_type, spec = get_structured_output_key(sampling_params)
    serialized = serialize_guidance_grammar(request_type, spec)
    validation_error = llguidance.LLMatcher.validate_grammar(
        serialized, tokenizer)
    if validation_error:
        raise ValueError(f"Grammar error: {validation_error}")
|
||||
134
vllm/v1/structured_output/backend_types.py
Normal file
134
vllm/v1/structured_output/backend_types.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
class StructuredOutputOptions(enum.Enum):
    """Kinds of structured-output constraints a request may carry."""

    JSON = enum.auto()            # full JSON schema
    JSON_OBJECT = enum.auto()     # any JSON object, no specific schema
    REGEX = enum.auto()           # regular-expression constraint
    GRAMMAR = enum.auto()         # EBNF (or convertible) grammar
    CHOICE = enum.auto()          # one of a fixed list of strings
    STRUCTURAL_TAG = enum.auto()  # tag-triggered embedded schemas


# (request kind, serialized grammar spec) — identifies a compiled grammar.
StructuredOutputKey = tuple[StructuredOutputOptions, str]
|
||||
|
||||
|
||||
class StructuredOutputGrammar(ABC):
    """Request-level backend for structured output requests.

    Each instance owns the matcher/FSM state of a single request. The
    engine drives it: tokens are accepted (or speculatively validated),
    per-step token bitmasks are filled, and speculative advances are
    rolled back.
    """

    @abstractmethod
    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """
        Determines whether the provided tokens are accepted for the
        given request, advancing the underlying state on success.

        Args:
            request_id (str): The unique identifier for the request.
            tokens (list[int]): A list of token IDs to evaluate.

        Returns:
            bool: True if the tokens are accepted, False otherwise.
        """

    @abstractmethod
    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """
        Validates the provided tokens against the grammar.
        Will not advance the FSM.

        Args:
            tokens (list[int]): A list of token IDs to validate.

        Returns:
            list[int]: A list of accepted token IDs. Will be a prefix
            of the input tokens, and empty if none are accepted.
        """

    @abstractmethod
    def rollback(self, num_tokens: int) -> None:
        """
        Rolls back the state of the grammar by a specified number of tokens.
        Will also revert counters for the number of processed tokens.

        Args:
            num_tokens (int): The number of tokens to roll back.
        """

    @abstractmethod
    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        """
        Fills the bitmask for a specific batch index with the tokens
        currently allowed by the grammar.

        Args:
            bitmask (torch.Tensor): The bitmask to fill
            batch_index (int): The index in the bitmask to fill
        """

    @abstractmethod
    def is_terminated(self) -> bool:
        """
        Checks whether the structured output process has terminated.

        Returns:
            bool: True if the process is terminated, False otherwise.
        """

    @abstractmethod
    def reset(self):
        """
        Resets the state of the structured output grammar.
        """
|
||||
|
||||
|
||||
@dataclass
class StructuredOutputBackend(ABC):
    """Engine-level backend for structured output requests.

    Holds the tokenizer-derived state shared by all grammars of one engine
    and acts as a factory for per-request StructuredOutputGrammar objects.
    """

    # Full engine configuration; subclasses read decoding options from it.
    vllm_config: VllmConfig
    # Tokenizer the compiled grammars must agree with.
    tokenizer: AnyTokenizer
    # Size of the model's vocabulary.
    vocab_size: int

    @abstractmethod
    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """
        Compiles a grammar specification into a structured output grammar.

        Args:
            request_type (StructuredOutputOptions): The type of structured
                output request.
            grammar_spec (str): The grammar specification to compile.

        Returns:
            StructuredOutputGrammar: The compiled structured output grammar.
        """

    @abstractmethod
    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        """
        Allocates a token bitmask for the specified maximum number of
        sequences.

        Args:
            max_num_seqs (int): The maximum number of sequences for which
                to allocate the bitmask.
        """

    @abstractmethod
    def destroy(self):
        """
        Backend-specific cleanup.
        """
|
||||
318
vllm/v1/structured_output/backend_xgrammar.py
Normal file
318
vllm/v1/structured_output/backend_xgrammar.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions)
|
||||
from vllm.v1.structured_output.utils import (choice_as_grammar,
|
||||
convert_lark_to_ebnf,
|
||||
grammar_is_likely_lark)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class XgrammarBackend(StructuredOutputBackend):
    """Structured-output backend built on xgrammar."""

    def __post_init__(self):
        # Build the grammar compiler for the configured tokenizer.
        self.disable_any_whitespace = \
            self.vllm_config.decoding_config.disable_any_whitespace

        if isinstance(self.tokenizer, MistralTokenizer):
            # Mistral tokenizers need a hand-assembled TokenizerInfo.
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
            try:
                if self.tokenizer.is_tekken:
                    encoded_vocab = self.tokenizer._vocab
                else:
                    # Vocabulary entries ordered by token id.
                    encoded_vocab = [
                        token for token, _ in sorted(
                            self.tokenizer.get_vocab().items(),
                            key=lambda x: x[1],
                        )
                    ]
                stop_token_ids = None
                if (hasattr(
                        self.tokenizer,
                        "eos_token_id",
                ) and self.tokenizer.eos_token_id is not None):
                    stop_token_ids = [self.tokenizer.eos_token_id]
            except AttributeError as e:
                raise ValueError(
                    f"Cannot get the vocabulary of the tokenizer "
                    f"{type(self.tokenizer)}. The tokenizer should have a "
                    "get_vocab method.") from e
            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
                encoded_vocab=encoded_vocab,
                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
                vocab_type=xgr.VocabType.RAW
                if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK,
                vocab_size=self.vocab_size,
                stop_token_ids=stop_token_ids,
                add_prefix_space=True,
            )
        else:
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                self.tokenizer,
                vocab_size=self.vocab_size,
            )
        self.compiler = xgr.GrammarCompiler(
            tokenizer_info,
            max_threads=8,
            cache_enabled=True,
            # Cache budget configured in MB via VLLM_XGRAMMAR_CACHE_MB.
            cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024,
        )

        # Rollback depth required by speculative decoding (0 when disabled).
        self.num_speculative_tokens = 0
        if self.vllm_config.speculative_config is not None:
            self.num_speculative_tokens = \
                self.vllm_config.speculative_config.num_speculative_tokens

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Compile *grammar_spec* into an XgrammarGrammar for this request."""
        if request_type == StructuredOutputOptions.JSON:
            ctx = self.compiler.compile_json_schema(
                grammar_spec, any_whitespace=not self.disable_any_whitespace)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            ctx = self.compiler.compile_json_schema(
                '{"type": "object"}',
                any_whitespace=not self.disable_any_whitespace)
        elif request_type == StructuredOutputOptions.GRAMMAR:
            ctx = self.compiler.compile_grammar(grammar_spec)
        elif request_type == StructuredOutputOptions.REGEX:
            ctx = self.compiler.compile_regex(grammar_spec)
        elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
            s_tag = json.loads(grammar_spec)
            tags = [
                xgr.StructuralTagItem(
                    begin=s["begin"],
                    schema=json.dumps(s["schema"]),
                    end=s["end"],
                ) for s in s_tag["structures"]
            ]
            ctx = self.compiler.compile_structural_tag(tags, s_tag["triggers"])
        else:
            logger.error(
                "Validation should have already occurred. Please file an issue."
            )
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})")

        return XgrammarGrammar(
            matcher=xgr.GrammarMatcher(
                ctx,
                max_rollback_tokens=self.num_speculative_tokens,
            ),
            vocab_size=self.vocab_size,
            ctx=ctx,
        )

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate an xgrammar token bitmask sized for *max_num_seqs* rows."""
        return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)

    def destroy(self):
        """Drop the grammar compiler (and with it, its cache)."""
        del self.compiler
|
||||
|
||||
|
||||
@dataclass
class XgrammarGrammar(StructuredOutputGrammar):
    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    num_processed_tokens: int = field(default_factory=lambda: 0,
                                      repr=False,
                                      hash=False,
                                      init=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the FSM failed to advance.
        """
        for token in tokens:
            if not self.matcher.accept_token(token):
                logger.error(
                    "Failed to advance FSM for request %s "
                    "for tokens %s. Please file an issue.", request_id, token)
                return False
            self.num_processed_tokens += 1
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        """Checks if the list of tokens are accepted by the FSM in
        sequence, without advancing it.

        Returns the prefix list of tokens that are accepted by the FSM.
        """
        accepted: list[int] = []
        for token in tokens:
            if not self.matcher.accept_token(token):
                break
            accepted.append(token)
        if accepted:
            # Undo the trial advances above; this probe must leave the
            # FSM exactly where it started.
            self.matcher.rollback(len(accepted))
        return accepted

    def rollback(self, num_tokens: int) -> None:
        """Rewind the FSM and the processed-token counter by *num_tokens*."""
        self.matcher.rollback(num_tokens)
        self.num_processed_tokens -= num_tokens

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Write the allowed-token mask for the next position into row *idx*."""
        self.matcher.fill_next_token_bitmask(bitmask, idx)

    def is_terminated(self) -> bool:
        """Whether the FSM has reached a terminal state."""
        return self.matcher.is_terminated()

    def reset(self):
        """Return the FSM and counters to their initial state."""
        self.matcher.reset()
        self.num_processed_tokens = 0
|
||||
|
||||
|
||||
def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    Recursively scans every nested object and array in *schema* and
    returns True as soon as an unsupported keyword is found.
    """

    def _has_unsupported(obj: dict[str, Any]) -> bool:
        if not isinstance(obj, dict):
            return False

        node_type = obj.get("type")

        # Numeric ranges are unsupported.
        if node_type in ("integer", "number") and "multipleOf" in obj:
            return True

        # Unsupported array keywords.
        if node_type == "array" and any(
                key in obj for key in ("uniqueItems", "contains",
                                       "minContains", "maxContains")):
            return True

        # Unsupported string keywords.
        if node_type == "string" and "format" in obj:
            return True

        # Unsupported object keywords.
        if node_type == "object" and any(
                key in obj for key in ("minProperties", "maxProperties",
                                       "propertyNames", "patternProperties")):
            return True

        # Recurse into nested schemas (dict values and dicts inside lists).
        for value in obj.values():
            if isinstance(value, dict):
                if _has_unsupported(value):
                    return True
            elif isinstance(value, list):
                if any(
                        isinstance(item, dict) and _has_unsupported(item)
                        for item in value):
                    return True

        return False

    return _has_unsupported(schema)
|
||||
|
||||
|
||||
def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    Checks each guided-decoding spec kind (regex, choice, json, grammar,
    structural_tag) by attempting to build the corresponding xgrammar
    grammar without compiling it.

    Raises ValueError if the request is not supported.
    """
    if sampling_params.guided_decoding is None:
        return

    gd_params = sampling_params.guided_decoding

    if gd_params.regex:
        try:
            xgr.Grammar.from_regex(gd_params.regex)
        except Exception as err:
            raise ValueError("Failed to transform regex into a grammar: "
                             f"{err}") from err

    if gd_params.choice:
        choice_grammar = choice_as_grammar(gd_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            # BUG FIX: the second literal was missing the f-prefix, so the
            # raised message contained the text "{err}" instead of the
            # actual underlying error.
            raise ValueError("Failed to transform choices into a grammar: "
                             f"{err}") from err
        # Normalize the request: from here on the choices are represented
        # as an EBNF grammar, so the backend handles one spec kind fewer.
        gd_params.choice = None
        gd_params.grammar = choice_grammar
        return

    if gd_params.json:
        if isinstance(gd_params.json, str):
            try:
                schema = json.loads(gd_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = gd_params.json

        try:
            xgr.Grammar.from_json_schema(schema)
        except Exception as err:
            raise ValueError("Failed to transform json schema into a grammar: "
                             f"{err}") from err

        # Even a schema xgrammar parses may use keywords it cannot enforce.
        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError("The provided JSON schema contains features not "
                             "supported by xgrammar.")
        return

    if gd_params.grammar:
        if grammar_is_likely_lark(gd_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF. ") from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(gd_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
        return

    if gd_params.structural_tag:
        try:
            s_tag = json.loads(gd_params.structural_tag)
            tags = [
                xgr.StructuralTagItem(
                    begin=s["begin"],
                    schema=json.dumps(s["schema"]),
                    end=s["end"],
                ) for s in s_tag["structures"]
            ]
            xgr.Grammar.from_structural_tag(tags, s_tag["triggers"])
        except Exception as e:
            raise ValueError("Invalid structural tag specification.") from e
||||
[diff metadata] New file added: vllm/v1/structured_output/request.py (86 lines; hunk @@ -0,0 +1,86 @@)
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import json
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures._base import TimeoutError
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar,
|
||||
StructuredOutputKey,
|
||||
StructuredOutputOptions)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class StructuredOutputRequest:
    """Per-request state for structured (guided) decoding.

    Holds the sampling params the grammar was requested through and the
    compiled grammar itself, which may initially be a ``Future`` while
    compilation runs on a background executor.
    """

    # Sampling parameters carrying the guided_decoding spec for this request.
    sampling_params: SamplingParams
    # Either the compiled grammar, or a Future resolving to it while the
    # backend is still compiling on a worker thread.
    _grammar: Optional[Union[Future[StructuredOutputGrammar],
                             StructuredOutputGrammar]] = None
    # Presumably flipped once the reasoning parser reports the end of the
    # reasoning section (None = undetermined) — confirm against the scheduler.
    reasoning_ended: Optional[bool] = None

    def _check_grammar_completion(self) -> bool:
        """Resolve the grammar Future if compilation has finished.

        Returns True when ``self._grammar`` no longer holds a pending
        Future (including when it is None); False while compilation is
        still in flight.
        """
        # NOTE: We have to lazy import to gate circular imports
        from vllm.v1.request import RequestStatus

        if isinstance(self._grammar, Future):
            try:
                # We will check whether the future is ready within 100 us
                self._grammar = self._grammar.result(timeout=0.0001)
                # NOTE(review): this dataclass declares no ``status`` field,
                # so this creates a new attribute on the fly. It looks like
                # it is meant to mirror vllm.v1.request.Request.status —
                # confirm the intended target object.
                self.status = RequestStatus.WAITING
            except TimeoutError:
                return False
        return True

    @property
    def is_grammar_ready(self) -> bool:
        """True once the grammar is available (resolves the Future as a
        side effect)."""
        return self._check_grammar_completion()

    @property
    def grammar(self) -> Optional[StructuredOutputGrammar]:
        """The compiled grammar, or None while compilation is pending."""
        completed = self._check_grammar_completion()
        return cast(Optional[StructuredOutputGrammar],
                    self._grammar) if completed else None

    @grammar.setter
    def grammar(
        self, grammar: Union[StructuredOutputGrammar,
                             Future[StructuredOutputGrammar]]
    ) -> None:
        # Accepts either a ready grammar or a Future scheduled by the manager.
        self._grammar = grammar

    @functools.cached_property
    def structured_output_key(self) -> StructuredOutputKey:
        # Cached: the guided_decoding spec is immutable for the request's
        # lifetime, so the key never changes.
        return get_structured_output_key(self.sampling_params)
|
||||
|
||||
def get_structured_output_key(
        sampling_params: SamplingParams) -> StructuredOutputKey:
    """Derive the (option kind, canonical spec string) key identifying the
    structured-output request described by *sampling_params*.

    Non-string json/choice specs are canonicalized via ``json.dumps`` so
    equivalent requests map to the same key.
    """
    params = sampling_params.guided_decoding
    assert params is not None, "params can't be None."

    if params.json is not None:
        spec = (params.json
                if isinstance(params.json, str) else json.dumps(params.json))
        return (StructuredOutputOptions.JSON, spec)
    if params.json_object:
        return (StructuredOutputOptions.JSON_OBJECT, "")
    if params.regex is not None:
        return (StructuredOutputOptions.REGEX, params.regex)
    if params.choice is not None:
        spec = (params.choice if isinstance(params.choice, str) else
                json.dumps(params.choice))
        return (StructuredOutputOptions.CHOICE, spec)
    if params.grammar is not None:
        return (StructuredOutputOptions.GRAMMAR, params.grammar)
    if params.structural_tag is not None:
        return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag)
    raise ValueError("No valid structured output parameter found")
||||
[diff metadata] New file added: vllm/v1/structured_output/utils.py (175 lines; hunk @@ -0,0 +1,175 @@)
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import regex as re
|
||||
|
||||
|
||||
def grammar_is_likely_lark(grammar_str: str) -> bool:
    """
    Check if grammar appears to use Lark syntax.

    Heuristic: a grammar containing any ``::=`` rule definition (outside
    comments) is assumed to be EBNF; otherwise it is assumed to be Lark.

    Args:
        grammar_str: Input grammar string

    Returns:
        bool: True if grammar appears to be in Lark format, False otherwise

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
    """
    if not grammar_str or not isinstance(grammar_str, str):
        return False

    # Strip both '#' and '//' comment styles before inspecting each line.
    comment_pattern = re.compile(r'(#|//).*$')
    for raw_line in grammar_str.split('\n'):
        stripped = comment_pattern.sub('', raw_line).strip()
        # An EBNF rule definition marker anywhere means "not Lark".
        if stripped and '::=' in stripped:
            return False
    return True
|
||||
|
||||
def convert_lark_to_ebnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to EBNF format.

    EBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in EBNF format

    Raises:
        ValueError: on non-string/empty input, mismatched quotes,
            alternatives without a preceding rule, or references to
            undefined rules.

    Examples:
        >>> print(convert_lark_to_ebnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules = set()
    referenced_rules = set()
    output_lines = []

    def clean_line(line: str) -> str:
        """Remove comments and whitespace from line."""
        # Handles both '#' and '//' comment styles.
        return re.sub(r'(#|//).*$', '', line).strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        # Parity check only; does not account for escaped quotes.
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(
                f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set:
        """Extract rule references from text."""
        # Remove quoted strings and special characters
        # (single-quoted literals have already been rewritten to
        # double-quoted form by the caller before this runs).
        text = re.sub(r'"[^"]*"', '', text)
        text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
        return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))

    # First pass: Find root rule and validate rule definitions
    lines = [clean_line(line) for line in grammar_str.split('\n')]
    first_rule = None

    for line_num, line in enumerate(lines, 1):
        # Alternative continuations ('|') are handled in the second pass.
        if not line or line.startswith('|'):
            continue

        if ':' in line:
            try:
                # Lark allows a leading '?' on rule names (inline rules);
                # strip it so references resolve.
                name = line.split(':', 1)[0].strip().strip('?')
                defined_rules.add(name)
                if first_rule is None:
                    first_rule = name
                # A rule literally named 'start' always wins as the root.
                if name == 'start':
                    first_rule = 'start'
            except IndexError as e:
                raise ValueError(f"Invalid rule format on line {line_num}. "
                                 "Expected 'rule_name: definition'") from e

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives
    current_rule = None
    current_definition = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if ':' in line and not line.startswith('|'):
                # Save previous rule if exists
                if current_rule:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}")

                # Process new rule
                name, definition = line.split(':', 1)
                current_rule = name.strip().strip('?')

                check_quotes(definition, f"rule '{current_rule}'", line_num)
                # Rewrite Lark's single-quoted literals to EBNF's
                # double-quoted form.
                definition = re.sub(r"'([^']*)'", r'"\1"', definition)
                referenced_rules.update(extract_references(definition))
                current_definition = [definition.strip()]

            elif line.startswith('|'):
                if not current_rule:
                    raise ValueError(f"Alternative '|' on line {line_num} "
                                     "without a preceding rule definition")

                alt_def = line[1:].strip()
                check_quotes(alt_def, f"alternative for rule '{current_rule}'",
                             line_num)
                alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
                referenced_rules.update(extract_references(alt_def))
                current_definition.append(alt_def)

        except ValueError as e:
            # Wrap with the line number for easier grammar debugging.
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists
    if current_rule:
        output_lines.append(
            f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined
    undefined_rules = referenced_rules - defined_rules - {'root'}
    if undefined_rules:
        raise ValueError("Referenced rules are not defined: "
                         f"{', '.join(sorted(undefined_rules))}")

    return '\n'.join(output_lines)
|
||||
|
||||
def choice_as_grammar(choice: list[str]) -> str:
    """Build an EBNF grammar whose root matches exactly one of *choice*."""

    def _escape(text: str) -> str:
        # Backslash-escape double quotes and backslashes so each option
        # is a valid EBNF string literal.
        return re.sub(r'(["\\])', r'\\\1', text)

    alternatives = ' | '.join(f'"{_escape(option)}"' for option in choice)
    return 'root ::= ' + alternatives
Reference in New Issue
Block a user