init
This commit is contained in:
25
vllm/model_executor/guided_decoding/__init__.py
Normal file
25
vllm/model_executor/guided_decoding/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest)
|
||||
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
|
||||
get_lm_format_enforcer_guided_decoding_logits_processor)
|
||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
||||
get_outlines_guided_decoding_logits_processor)
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
|
||||
|
||||
async def get_guided_decoding_logits_processor(
|
||||
guided_decoding_backend: str, request: Union[CompletionRequest,
|
||||
ChatCompletionRequest],
|
||||
tokenizer) -> Optional[LogitsProcessor]:
|
||||
if guided_decoding_backend == 'outlines':
|
||||
return await get_outlines_guided_decoding_logits_processor(
|
||||
request, tokenizer)
|
||||
if guided_decoding_backend == 'lm-format-enforcer':
|
||||
return await get_lm_format_enforcer_guided_decoding_logits_processor(
|
||||
request, tokenizer)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
|
||||
"Must be one of 'outlines, 'lm-format-enforcer'")
|
||||
@@ -0,0 +1,70 @@
|
||||
from functools import lru_cache
|
||||
from json import loads as json_loads
|
||||
from typing import Optional, Union
|
||||
|
||||
from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
|
||||
RegexParser, StringParser,
|
||||
TokenEnforcerTokenizerData, UnionParser)
|
||||
from lmformatenforcer.integrations.vllm import (
|
||||
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest)
|
||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
||||
get_outlines_guided_decoding_logits_processor)
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
|
||||
|
||||
async def get_lm_format_enforcer_guided_decoding_logits_processor(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest],
|
||||
tokenizer) -> Optional[LogitsProcessor]:
|
||||
"""
|
||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
||||
and get the necessary logits processor for the given guide.
|
||||
We cache logit processors by (guide, tokenizer), and on cache hit
|
||||
we make a shallow copy to reuse the same underlying FSM.
|
||||
"""
|
||||
|
||||
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
|
||||
tokenizer)
|
||||
character_level_parser: CharacterLevelParser
|
||||
if request.guided_json:
|
||||
schema = _normalize_json_schema_object(request.guided_json)
|
||||
character_level_parser = JsonSchemaParser(schema)
|
||||
elif request.guided_choice:
|
||||
character_level_parser = UnionParser(
|
||||
[StringParser(choice) for choice in request.guided_choice])
|
||||
elif request.guided_regex:
|
||||
character_level_parser = RegexParser(request.guided_regex)
|
||||
elif request.guided_grammar:
|
||||
# CFG grammar not supported by LMFE, revert to outlines
|
||||
return await get_outlines_guided_decoding_logits_processor(
|
||||
request, tokenizer)
|
||||
elif (request.response_format is not None
|
||||
and request.response_format.type == "json_object"):
|
||||
character_level_parser = JsonSchemaParser(
|
||||
None) # None means any json object
|
||||
else:
|
||||
return None
|
||||
|
||||
logits_processor = build_vllm_logits_processor(tokenizer_data,
|
||||
character_level_parser)
|
||||
return logits_processor
|
||||
|
||||
|
||||
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
|
||||
if isinstance(schema, str):
|
||||
return json_loads(schema)
|
||||
if isinstance(schema, dict):
|
||||
return schema
|
||||
if isinstance(schema, BaseModel):
|
||||
return schema.model_json_schema()
|
||||
raise AssertionError(f"Unsupported schema type {schema}")
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _cached_build_vllm_token_enforcer_tokenizer_data(
|
||||
tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
|
||||
return build_vllm_token_enforcer_tokenizer_data(tokenizer)
|
||||
130
vllm/model_executor/guided_decoding/outlines_decoding.py
Normal file
130
vllm/model_executor/guided_decoding/outlines_decoding.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from copy import copy
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from json import dumps as json_dumps
|
||||
from re import escape as regex_escape
|
||||
from typing import Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest)
|
||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
|
||||
|
||||
|
||||
class GuidedDecodingMode(Enum):
|
||||
JSON = "json"
|
||||
REGEX = "regex"
|
||||
CHOICE = "choice"
|
||||
GRAMMAR = "grammar"
|
||||
|
||||
|
||||
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
|
||||
# the main difference is that we changed the start: value to
|
||||
# start: object | array, so we are denying scalar values as the root of the
|
||||
# JSON. Starting with scalars as the root seems to cause llama to generate
|
||||
# without stop.
|
||||
JSON_GRAMMAR = r"""
|
||||
?start: object | array
|
||||
|
||||
?value: object
|
||||
| array
|
||||
| UNESCAPED_STRING
|
||||
| SIGNED_NUMBER -> number
|
||||
| "true" -> true
|
||||
| "false" -> false
|
||||
| "null" -> null
|
||||
|
||||
array : "[" [value ("," value)*] "]"
|
||||
object : "{" [pair ("," pair)*] "}"
|
||||
pair : UNESCAPED_STRING ":" value
|
||||
|
||||
%import common.UNESCAPED_STRING
|
||||
%import common.SIGNED_NUMBER
|
||||
%import common.WS
|
||||
|
||||
%ignore WS
|
||||
"""
|
||||
|
||||
global_thread_pool = None # used for generating logits processor fsm
|
||||
|
||||
|
||||
async def get_outlines_guided_decoding_logits_processor(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest],
|
||||
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
|
||||
"""
|
||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
||||
and get the necessary logits processor for the given guide.
|
||||
We cache logit processors by (guide, tokenizer), and on cache hit
|
||||
we make a shallow copy to reuse the same underlying FSM.
|
||||
"""
|
||||
global global_thread_pool
|
||||
guide, mode = _get_guide_and_mode(request)
|
||||
if not guide:
|
||||
return None
|
||||
|
||||
if global_thread_pool is None:
|
||||
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=2)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
result = await loop.run_in_executor(global_thread_pool,
|
||||
_get_cached_logits_processor, guide,
|
||||
tokenizer, mode,
|
||||
request.guided_whitespace_pattern)
|
||||
|
||||
logits_processor = copy(result)
|
||||
# reset logits processor's internal state
|
||||
logits_processor.init_state()
|
||||
return logits_processor
|
||||
|
||||
|
||||
def _get_guide_and_mode(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest]
|
||||
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
|
||||
|
||||
if request.guided_json:
|
||||
json = request.guided_json
|
||||
if isinstance(json, dict):
|
||||
# turn dict into hashable string
|
||||
json = json_dumps(json)
|
||||
elif isinstance(json, BaseModel):
|
||||
# use pydantic signature so that different model classes
|
||||
# with the same fields will get hashed the same
|
||||
json = str(json.__signature__)
|
||||
return json, GuidedDecodingMode.JSON
|
||||
elif request.guided_regex:
|
||||
return request.guided_regex, GuidedDecodingMode.REGEX
|
||||
elif request.guided_choice:
|
||||
# choice just uses regex
|
||||
choices = [
|
||||
regex_escape(str(choice)) for choice in request.guided_choice
|
||||
]
|
||||
choices_regex = "(" + "|".join(choices) + ")"
|
||||
return choices_regex, GuidedDecodingMode.CHOICE
|
||||
elif request.guided_grammar:
|
||||
return request.guided_grammar, GuidedDecodingMode.GRAMMAR
|
||||
elif (request.response_format is not None
|
||||
and request.response_format.type == "json_object"):
|
||||
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_cached_logits_processor(guide: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
mode: GuidedDecodingMode,
|
||||
whitespace_pattern: Union[str, None]):
|
||||
if mode == GuidedDecodingMode.JSON:
|
||||
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern)
|
||||
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
|
||||
return RegexLogitsProcessor(guide, tokenizer)
|
||||
elif mode == GuidedDecodingMode.GRAMMAR:
|
||||
return CFGLogitsProcessor(guide, tokenizer)
|
||||
else:
|
||||
raise ValueError(f"Unknown guided decoding mode {mode}")
|
||||
@@ -0,0 +1,184 @@
|
||||
# Copyright 2024- the Outlines developers
|
||||
# This file is adapted from
|
||||
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import Callable, DefaultDict, Dict, List, Union
|
||||
|
||||
import torch
|
||||
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
|
||||
class BaseLogitsProcessor:
|
||||
|
||||
def __init__(self):
|
||||
# Child class should use initialize in their init.
|
||||
self.fsm: FSM
|
||||
|
||||
def init_state(self):
|
||||
"""Initialize the FSM states."""
|
||||
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
|
||||
|
||||
def __call__(self, input_ids: List[int],
|
||||
scores: torch.Tensor) -> torch.Tensor:
|
||||
"""Use the FSM to bias the logits before sampling the next token."""
|
||||
seq_id = hash(tuple(input_ids))
|
||||
|
||||
if len(input_ids) == 0:
|
||||
self.init_state()
|
||||
else:
|
||||
last_token = input_ids[-1]
|
||||
last_seq_id = hash(tuple(input_ids[:-1]))
|
||||
self.fsm_state[seq_id] = self.fsm.next_state(
|
||||
self.fsm_state[last_seq_id], last_token)
|
||||
|
||||
allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
|
||||
|
||||
mask = torch.full((scores.shape[-1], ),
|
||||
-math.inf,
|
||||
device=scores.device)
|
||||
mask[allowed_tokens] = 0
|
||||
scores.add_(mask)
|
||||
return scores
|
||||
|
||||
|
||||
class RegexLogitsProcessor(BaseLogitsProcessor):
|
||||
|
||||
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
|
||||
"""Compile the FSM that drives the regex-structured generation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
regex_string
|
||||
A string that represents a regular expression
|
||||
tokenizer
|
||||
The model's tokenizer
|
||||
|
||||
"""
|
||||
tokenizer = _adapt_tokenizer(tokenizer)
|
||||
fsm = RegexFSM(regex_string, tokenizer)
|
||||
self.fsm = fsm
|
||||
|
||||
|
||||
class JSONLogitsProcessor(RegexLogitsProcessor):
|
||||
|
||||
def __init__(self, schema: Union[str, Dict, BaseModel],
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
whitespace_pattern: Union[str, None]):
|
||||
"""Compile the FSM that drives the JSON-guided generation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
schema
|
||||
A JSON schema that encodes the structure we want the model to
|
||||
generate
|
||||
tokenizer
|
||||
The model's tokenizer
|
||||
whitespace_pattern
|
||||
Pattern to use for JSON syntactic whitespace (doesn't impact
|
||||
string literals)
|
||||
Example: allow only a single space or newline with
|
||||
`whitespace_pattern=r"[\n ]?"`
|
||||
"""
|
||||
if isinstance(schema, type(BaseModel)):
|
||||
schema_str = json.dumps(schema.model_json_schema())
|
||||
elif isinstance(schema, Dict):
|
||||
schema_str = json.dumps(schema)
|
||||
elif isinstance(schema, str):
|
||||
schema_str = schema
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot parse schema {schema}. The schema must be either "
|
||||
f"a Pydantic object, a dictionary or a string that contains "
|
||||
f"the JSON Schema specification")
|
||||
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
|
||||
super().__init__(regex_string, tokenizer)
|
||||
|
||||
|
||||
class CFGLogitsProcessor(BaseLogitsProcessor):
|
||||
|
||||
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
|
||||
"""Compile the FSM that drives the context free grammar generation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cfg
|
||||
A string that represents a context-free grammar
|
||||
tokenizer
|
||||
The model's tokenizer
|
||||
|
||||
"""
|
||||
tokenizer = _adapt_tokenizer(tokenizer)
|
||||
fsm = CFGFSM(cfg, tokenizer)
|
||||
self.fsm = fsm
|
||||
|
||||
def init_state(self):
|
||||
"""Initialize state with a CFGFSM copy."""
|
||||
super().init_state()
|
||||
self.fsm = self.fsm.copy()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
||||
"""Adapt vLLM's tokenizer to use to compile the FSM.
|
||||
|
||||
The API of Outlines tokenizers is slightly different to that of
|
||||
`transformers`. The decoder of outlines, returns a list whereas
|
||||
the decode of vLLM returns an str. To sync the vLLM decoder with
|
||||
outlines internal api, the decoder should be adapted. In addition
|
||||
we need to handle the missing spaces to Llama's tokenizer to be
|
||||
able to compile FSMs for this model.
|
||||
|
||||
"""
|
||||
if getattr(tokenizer, "_outlines_adapted", False):
|
||||
return tokenizer
|
||||
|
||||
tokenizer = copy.deepcopy(tokenizer)
|
||||
|
||||
tokenizer.vocabulary = tokenizer.get_vocab()
|
||||
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
||||
|
||||
def convert_token_to_string(token: str) -> str:
|
||||
from transformers.file_utils import SPIECE_UNDERLINE
|
||||
|
||||
string = tokenizer.convert_tokens_to_string([token])
|
||||
|
||||
# A hack to handle missing spaces to HF's Llama tokenizers
|
||||
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
|
||||
return " " + string
|
||||
|
||||
return string
|
||||
|
||||
def change_decoder(
|
||||
decoder: Callable[[List[int]],
|
||||
str]) -> Callable[[List[int]], List[str]]:
|
||||
"""Sync vLLM's decoder with the outlines by returning list."""
|
||||
|
||||
def new_decoder(inp_tokens: List[int]) -> List[str]:
|
||||
return [decoder(inp_tokens)]
|
||||
|
||||
return new_decoder
|
||||
|
||||
tokenizer.convert_token_to_string = convert_token_to_string
|
||||
tokenizer.decode = change_decoder(tokenizer.decode)
|
||||
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
|
||||
|
||||
return tokenizer
|
||||
Reference in New Issue
Block a user