# SPDX-License-Identifier: Apache-2.0
"""Structured-output backend that drives decoding with llguidance."""

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)
from vllm.v1.structured_output.request import get_structured_output_key

if TYPE_CHECKING:
    import llguidance
    import llguidance.hf as llguidance_hf
    import llguidance.torch as llguidance_torch
else:
    # Outside of type checking the llguidance modules are wrapped in
    # LazyLoader instead of being imported directly.
    llguidance = LazyLoader("llguidance", globals(), "llguidance")
    llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
    llguidance_torch = LazyLoader("llguidance.torch", globals(),
                                  "llguidance.torch")

logger = init_logger(__name__)


class GuidanceBackend(StructuredOutputBackend):
    """Structured-output backend that compiles grammars with llguidance."""

    def __init__(self, vllm_config: VllmConfig):
        """Build the llguidance tokenizer and cache settings from the config.

        Args:
            vllm_config: Engine configuration. Its model, scheduler, parallel
                and LoRA sub-configs are used to initialise the tokenizer
                group, and its decoding config is inspected for the
                "disable-any-whitespace" option.
        """
        self.vllm_config = vllm_config
        tokenizer_group = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
        tokenizer_group.ping()
        self.vocab_size = vllm_config.model_config.get_vocab_size()

        # The option is detected by substring match on the backend string.
        self.disable_any_whitespace = (
            "disable-any-whitespace"
            in vllm_config.decoding_config.guided_decoding_backend)

        tokenizer = tokenizer_group.get_lora_tokenizer(None)
        self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        """Compile ``grammar_spec`` into a ``GuidanceGrammar``.

        The serialized grammar is also stored on ``self.serialized_grammar``.
        The matcher's log level is read from the ``LLGUIDANCE_LOG_LEVEL``
        environment variable (default ``"1"``).

        Args:
            request_type: Kind of structured output being requested.
            grammar_spec: The specification string for that kind.

        Returns:
            A ``GuidanceGrammar`` wrapping a freshly created matcher. Any
            error the matcher already reports is logged before returning.
        """
        self.serialized_grammar = serialize_guidance_grammar(
            request_type, grammar_spec, self.disable_any_whitespace)

        ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            self.serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )

        grammar = GuidanceGrammar(
            ll_matcher=ll_matcher,
            ll_tokenizer=self.ll_tokenizer,
            vocab_size=self.vocab_size,
        )

        grammar.check_error()
        return grammar

    def allocate_token_bitmask(self, max_num_seqs: int):
        """Allocate a token bitmask for ``max_num_seqs`` sequences.

        The mask is sized with the llguidance tokenizer's ``vocab_size``.
        """
        return llguidance_torch.allocate_token_bitmask(
            max_num_seqs, self.ll_tokenizer.vocab_size)


@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
    """Per-request grammar state backed by an llguidance matcher."""

    ll_matcher: llguidance.LLMatcher
    ll_tokenizer: llguidance.LLTokenizer
    vocab_size: int
    # Set once a matcher error has been logged, so it is reported only once.
    printed_error: bool = False
    # Set once the tokenizer's EOS token has been seen in accepted tokens.
    terminated: bool = False

    def check_error(self) -> None:
        """Log the matcher's error, if any, the first time it is observed."""
        if self.printed_error:
            return
        err = self.ll_matcher.get_error()
        if err:
            self.printed_error = True
            logger.warning("LLMatcher error: %s", err)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the parser.

        Returns True if the parser was advanced successfully.
        Returns False if the parser failed to advance.

        ``request_id`` is part of the interface but is not used here.
        """
        if self.ll_tokenizer.eos_token in tokens:
            self.terminated = True

        # A stopped matcher is not fed any further tokens.
        if self.ll_matcher.is_stopped():
            return True

        # TODO - Add jump decoding support in the future:
        # self.ll_matcher.compute_ff_bytes() - this should always work
        # self.ll_matcher.compute_ff_tokens() - this only works for
        #   "canonical" tokenizers
        # For conversion between the two, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md

        accepted = self.ll_matcher.consume_tokens(tokens)

        self.check_error()

        return accepted

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        """Fill entry ``idx`` of ``bitmask`` with the next-token mask."""
        # this will automatically return [EOS] mask if the matcher is stopped
        # or otherwise in an error state
        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
        self.check_error()

    def is_terminated(self) -> bool:
        """Return True once an EOS token has been accepted."""
        return self.terminated

    def reset(self) -> None:
        """Reset the matcher and clear the termination flag."""
        # TODO: This method may not be needed anymore?
        self.ll_matcher.reset()
        # Without this, is_terminated() would keep reporting True for a
        # grammar whose matcher has just been reset.
        self.terminated = False


def serialize_guidance_grammar(request_type: StructuredOutputOptions,
                               grammar_spec: str,
                               disable_any_whitespace: bool = False) -> str:
    """Convert a structured-output request into an llguidance grammar.

    Args:
        request_type: Kind of structured output being requested.
        grammar_spec: The specification string. It is ignored for
            ``JSON_OBJECT``, which always uses the schema
            ``{"type": "object"}``.
        disable_any_whitespace: When True, JSON grammars are built with
            ``whitespace_flexible`` set to False.

    Returns:
        The serialized grammar produced by llguidance.

    Raises:
        ValueError: If ``request_type`` is not one of the handled options.
    """
    if request_type == StructuredOutputOptions.JSON:
        return llguidance.LLMatcher.grammar_from_json_schema(
            grammar_spec,
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })

    if request_type == StructuredOutputOptions.JSON_OBJECT:
        return llguidance.LLMatcher.grammar_from_json_schema(
            '{"type": "object"}',
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })

    if request_type == StructuredOutputOptions.REGEX:
        tp = "regex"
    elif request_type == StructuredOutputOptions.GRAMMAR:
        tp = "grammar"
    elif request_type == StructuredOutputOptions.CHOICE:
        tp = "choice"
    else:
        logger.error("Validation should have already occurred. "
                     "Please file an issue.")
        raise ValueError("grammar is not of valid supported types. "
                         f"({request_type!s})")
    return llguidance.grammar_from(tp, grammar_spec)


def validate_guidance_grammar(
        sampling_params: SamplingParams,
        tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
    """Validate the structured-output grammar carried by ``sampling_params``.

    The grammar is serialized with the default whitespace setting and then
    checked with ``LLMatcher.validate_grammar``.

    Args:
        sampling_params: Sampling parameters holding the structured-output
            request.
        tokenizer: Optional llguidance tokenizer forwarded to the validator.

    Raises:
        ValueError: If llguidance reports a grammar error, or if the request
            type is not supported by ``serialize_guidance_grammar``.
    """
    tp, grm = get_structured_output_key(sampling_params)
    guidance_grm = serialize_guidance_grammar(tp, grm)
    err = llguidance.LLMatcher.validate_grammar(guidance_grm,
                                                tokenizer=tokenizer)
    if err:
        raise ValueError(f"Grammar error: {err}")