First commit
This commit is contained in:
0
vllm/entrypoints/__init__.py
Normal file
0
vllm/entrypoints/__init__.py
Normal file
BIN
vllm/entrypoints/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/api_server.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/api_server.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/chat_utils.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/chat_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/launcher.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/launcher.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/llm.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/llm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/__pycache__/logger.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/logger.cpython-310.pyc
Normal file
Binary file not shown.
163
vllm/entrypoints/api_server.py
Normal file
163
vllm/entrypoints/api_server.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
NOTE: This API server is used only for demonstrating usage of AsyncEngine
|
||||
and simple performance benchmarks. It is not intended for production use.
|
||||
For production use, we recommend using our OpenAI compatible server.
|
||||
We are also not going to accept PRs modifying this file, please
|
||||
change `vllm/entrypoints/openai/api_server.py` instead.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import ssl
|
||||
from argparse import Namespace
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
|
||||
random_uuid)
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger("vllm.entrypoints.api_server")
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||
app = FastAPI()
|
||||
engine = None
|
||||
|
||||
|
||||
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answers with an empty HTTP 200 response."""
    ok_response = Response(status_code=200)
    return ok_response
|
||||
|
||||
|
||||
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request body must be a JSON object containing:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not (defaults to False).
    - other fields: forwarded as keyword arguments to `SamplingParams`.

    Each streamed chunk is a JSON document terminated by a NUL byte. In the
    non-streaming case a status 499 response is returned if iteration is
    cancelled before a final output is produced.
    """
    payload = await request.json()
    prompt = payload.pop("prompt")
    stream = payload.pop("stream", False)
    # Everything still left in the payload is treated as a sampling parameter.
    sampling_params = SamplingParams(**payload)
    request_id = random_uuid()

    assert engine is not None
    outputs = iterate_with_cancellation(
        engine.generate(prompt, sampling_params, request_id),
        is_cancelled=request.is_disconnected,
    )

    def as_payload(request_output) -> dict:
        # Every returned text is the echoed prompt followed by one completion.
        echoed_prompt = request_output.prompt
        assert echoed_prompt is not None
        return {
            "text": [
                echoed_prompt + completion.text
                for completion in request_output.outputs
            ]
        }

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in outputs:
            chunk = json.dumps(as_payload(request_output)) + "\0"
            yield chunk.encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case: drain the generator and keep only the last output.
    final_output = None
    try:
        async for request_output in outputs:
            final_output = request_output
    except asyncio.CancelledError:
        return Response(status_code=499)

    assert final_output is not None
    return JSONResponse(as_payload(final_output))
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
    """Set ``root_path`` on the module-level app from ``args`` and return it."""
    global app

    configured_app = app
    configured_app.root_path = args.root_path
    return configured_app
|
||||
|
||||
|
||||
async def init_app(
    args: Namespace,
    llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
    """Configure the FastAPI app and bind the engine that serves requests.

    Args:
        args: Parsed CLI arguments; used for the app's ``root_path`` and to
            build ``AsyncEngineArgs``.
        llm_engine: Engine to use as-is. When ``None`` a new
            ``AsyncLLMEngine`` is created from ``args``.

    Returns:
        The configured application.
    """
    app = build_app(args)

    global engine

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is not None:
        engine = llm_engine
    else:
        engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER)

    # The RuntimeError handler installed by the launcher reads
    # `request.app.state.engine_client`; without this assignment that handler
    # itself fails with AttributeError instead of checking engine health.
    app.state.engine_client = engine

    return app
|
||||
|
||||
|
||||
async def run_server(args: Namespace,
                     llm_engine: Optional[AsyncLLMEngine] = None,
                     **uvicorn_kwargs: Any) -> None:
    """Initialise the app, serve it over HTTP and await the shutdown task.

    Args:
        args: Parsed CLI arguments (host, port, log level, SSL settings, ...).
        llm_engine: Optional engine forwarded to ``init_app``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    app = await init_app(args, llm_engine)
    assert engine is not None

    server_options = {
        "host": args.host,
        "port": args.port,
        "log_level": args.log_level,
        "timeout_keep_alive": TIMEOUT_KEEP_ALIVE,
        "ssl_keyfile": args.ssl_keyfile,
        "ssl_certfile": args.ssl_certfile,
        "ssl_ca_certs": args.ssl_ca_certs,
        "ssl_cert_reqs": args.ssl_cert_reqs,
    }
    shutdown_task = await serve_http(app, **server_options, **uvicorn_kwargs)

    await shutdown_task
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Server-level options are registered first; engine options are appended
    # by AsyncEngineArgs.add_cli_args below.
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--ssl-ca-certs",
        type=str,
        default=None,
        help="The CA certificates file",
    )
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)",
    )
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy",
    )
    parser.add_argument("--log-level", type=str, default="debug")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))
|
||||
581
vllm/entrypoints/chat_utils.py
Normal file
581
vllm/entrypoints/chat_utils.py
Normal file
@@ -0,0 +1,581 @@
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
|
||||
Mapping, Optional, Tuple, TypeVar, Union, cast)
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from openai.types.chat import (ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionContentPartImageParam)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
|
||||
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
|
||||
ChatCompletionContentPartTextParam)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
|
||||
from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionToolMessageParam)
|
||||
# yapf: enable
|
||||
# pydantic needs the TypedDict from typing_extensions
|
||||
from pydantic import ConfigDict
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from typing_extensions import Required, TypeAlias, TypedDict
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
||||
async_get_and_parse_image,
|
||||
get_and_parse_audio, get_and_parse_image)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AudioURL(TypedDict, total=False):
    # Required field locating the audio referenced by a content part.
    url: Required[str]
    """
    Either a plain URL of the audio, or a data URL carrying base64 encoded
    audio data.
    """
|
||||
|
||||
|
||||
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    # The audio payload, expressed as an `AudioURL` mapping.
    audio_url: Required[AudioURL]

    type: Required[Literal["audio_url"]]
    """The type of the content part (always "audio_url" for this part)."""
|
||||
|
||||
|
||||
class CustomChatCompletionContentPartParam(TypedDict, total=False):
    # extra="allow": keys beyond `type` are accepted by pydantic validation.
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

    type: Required[str]
    """The type of the content part."""
|
||||
|
||||
|
||||
ChatCompletionContentPartParam: TypeAlias = Union[
|
||||
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
|
||||
ChatCompletionContentPartRefusalParam,
|
||||
CustomChatCompletionContentPartParam]
|
||||
|
||||
|
||||
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    # Only `role` is required; every other key may be omitted (total=False).
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message: plain text or a list of content parts."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
|
||||
|
||||
|
||||
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
|
||||
CustomChatCompletionMessageParam]
|
||||
|
||||
|
||||
# TODO: Make fields ReadOnly once mypy supports it
|
||||
class ConversationMessage(TypedDict, total=False):
    # Message shape produced by the parsing helpers in this module.
    role: Required[str]
    """The role of the message's author."""

    content: Optional[str]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
|
||||
|
||||
|
||||
ModalityStr = Literal["image", "audio", "video"]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class BaseMultiModalItemTracker(ABC, Generic[_T]):
    """
    Collects the multi-modal items of a single request and rejects any item
    that would push a modality past its configured per-prompt maximum.
    """

    def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
        super().__init__()

        self._model_config = model_config
        self._tokenizer = tokenizer

        # Without a multimodal config there are no explicit limits; `add`
        # then falls back to one item per modality.
        mm_config = model_config.multimodal_config
        self._allowed_items = mm_config.limit_per_prompt if mm_config else {}
        self._consumed_items = dict.fromkeys(self._allowed_items, 0)

        self._items: List[_T] = []

    @staticmethod
    @lru_cache(maxsize=None)
    def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
        """Decode ``token_index`` with ``tokenizer``; results are memoized."""
        return tokenizer.decode(token_index)

    def _placeholder_str(self, modality: ModalityStr,
                         current_count: int) -> Optional[str]:
        """
        Return the prompt placeholder for the ``current_count``-th item of
        ``modality``, or ``None`` when the model uses no placeholder.

        Raises ``TypeError`` for an unknown model type or modality.
        """
        # TODO: Let user specify how to insert image tokens into prompt
        # (similar to chat template)
        hf_config = self._model_config.hf_config
        model_type = hf_config.model_type

        if modality == "image":
            # Placeholders that embed the running item count.
            if model_type == "phi3_v":
                # Workaround since this token is not defined in the tokenizer
                return f"<|image_{current_count}|>"
            if model_type == "qwen":
                return f"Picture {current_count}: <img></img>"
            if model_type.startswith("llava"):
                return self._cached_token_str(self._tokenizer,
                                              hf_config.image_token_index)
            if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
                              "pixtral"):
                # These models do not use image tokens in the prompt
                return None

            # Remaining models map to a fixed string.
            fixed_image_placeholders = {
                "minicpmv": "(<image>./</image>)",
                "chameleon": "<image>",
                "internvl_chat": "<image>",
                "NVLM_D": "<image>",
                "mllama": "<|image|>",
                "qwen2_vl": "<|vision_start|><|image_pad|><|vision_end|>",
                "molmo": "",
            }
            if model_type in fixed_image_placeholders:
                return fixed_image_placeholders[model_type]

            raise TypeError(f"Unknown model type: {model_type}")

        if modality == "audio":
            if model_type == "ultravox":
                return "<|reserved_special_token_0|>"
            raise TypeError(f"Unknown model type: {model_type}")

        if modality == "video":
            if model_type == "qwen2_vl":
                return "<|vision_start|><|video_pad|><|vision_end|>"
            raise TypeError(f"Unknown model type: {model_type}")

        raise TypeError(f"Unknown modality: {modality}")

    @staticmethod
    def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
        """Merge per-item dicts into one dict, grouping values by key."""
        merged: Mapping[str, List[object]] = defaultdict(list)

        # Merge all the multi-modal items
        for mm_data in items:
            for key, value in mm_data.items():
                bucket = merged[key]
                if isinstance(value, list):
                    bucket.extend(value)
                else:
                    bucket.append(value)

        # Unpack any single item lists for models that don't expect multiple.
        combined = {}
        for key, values in merged.items():
            combined[key] = values[0] if len(values) == 1 else values
        return combined

    def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
        """
        Register a multi-modal item for the current prompt and return the
        placeholder string to use, if any.

        Raises ``ValueError`` when the modality's limit would be exceeded.
        """
        limit = self._allowed_items.get(modality, 1)
        new_count = self._consumed_items.get(modality, 0) + 1
        if new_count > limit:
            raise ValueError(
                f"At most {limit} {modality}(s) may be provided in "
                "one request.")

        self._consumed_items[modality] = new_count
        self._items.append(item)

        return self._placeholder_str(modality, new_count)

    @abstractmethod
    def create_parser(self) -> "BaseMultiModalContentParser":
        raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
    """Tracker whose items are already-loaded multi-modal data dicts."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Return the combined data of all items, or None if none was added."""
        if not self._items:
            return None
        return self._combine(self._items)

    def create_parser(self) -> "BaseMultiModalContentParser":
        return MultiModalContentParser(self)
|
||||
|
||||
|
||||
class AsyncMultiModalItemTracker(
        BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
    """Tracker whose items are awaitables resolving to multi-modal data."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await every tracked item and combine them; None if none was added."""
        if not self._items:
            return None

        resolved_items = await asyncio.gather(*self._items)
        return self._combine(resolved_items)

    def create_parser(self) -> "BaseMultiModalContentParser":
        return AsyncMultiModalContentParser(self)
|
||||
|
||||
|
||||
class BaseMultiModalContentParser(ABC):
    """Base for content parsers; counts how often each placeholder is used."""

    def __init__(self) -> None:
        super().__init__()

        # multimodal placeholder_string : count
        self._placeholder_counts: Dict[str, int] = defaultdict(int)

    def _add_placeholder(self, placeholder: Optional[str]):
        """Count one use of ``placeholder``; None and "" are ignored."""
        if not placeholder:
            return
        self._placeholder_counts[placeholder] += 1

    def mm_placeholder_counts(self) -> Dict[str, int]:
        """Return a plain-dict copy of the placeholder counts."""
        return dict(self._placeholder_counts)

    @abstractmethod
    def parse_image(self, image_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalContentParser(BaseMultiModalContentParser):
    """Parser that loads each referenced item before handing it to the tracker."""

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

    def parse_image(self, image_url: str) -> None:
        """Load the image at ``image_url`` and register it with the tracker."""
        loaded_image = get_and_parse_image(image_url)
        self._add_placeholder(self._tracker.add("image", loaded_image))

    def parse_audio(self, audio_url: str) -> None:
        """Load the audio at ``audio_url`` and register it with the tracker."""
        loaded_audio = get_and_parse_audio(audio_url)
        self._add_placeholder(self._tracker.add("audio", loaded_audio))
|
||||
|
||||
|
||||
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """Parser that registers not-yet-awaited loaders with the tracker."""

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

    def parse_image(self, image_url: str) -> None:
        """Register the async image loader for ``image_url`` with the tracker."""
        pending_image = async_get_and_parse_image(image_url)
        self._add_placeholder(self._tracker.add("image", pending_image))

    def parse_audio(self, audio_url: str) -> None:
        """Register the async audio loader for ``audio_url`` with the tracker."""
        pending_audio = async_get_and_parse_audio(audio_url)
        self._add_placeholder(self._tracker.add("audio", pending_audio))
|
||||
|
||||
|
||||
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raises if the provided chat template appears invalid.

    Args:
        chat_template: ``None``, a ``Path`` to a template file, or a string
            holding either an inline template or a path to a template file.

    Raises:
        FileNotFoundError: A ``Path`` was given but does not exist.
        ValueError: A string was given that contains none of the characters
            ``{``, ``}`` or newline (so it looks like a path) and no file
            exists at that path.
        TypeError: The argument is neither ``None``, ``Path`` nor ``str``.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        if not chat_template.exists():
            raise FileNotFoundError(
                "the supplied chat template path doesn't exist")
        # An existing Path is valid. It must return here rather than fall
        # through to the TypeError below, which is reserved for other types.
        return

    if isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
        looks_like_template = any(c in chat_template for c in JINJA_CHARS)
        if not looks_like_template and not Path(chat_template).exists():
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!")
        return

    raise TypeError(
        f"{type(chat_template)} is not a valid chat template type")
|
||||
|
||||
|
||||
def load_chat_template(
        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
    """Resolve ``chat_template`` to template text.

    The argument is first opened as a file. If that fails with ``OSError``:
    a ``Path`` re-raises the error; a string without ``{``, ``}`` or newline
    raises ``ValueError``; any other string is used as the template itself
    after decoding its escape sequences. ``None`` is returned unchanged.
    """
    if chat_template is None:
        return None

    try:
        with open(chat_template, "r") as template_file:
            resolved_chat_template = template_file.read()
    except OSError as e:
        if isinstance(chat_template, Path):
            raise

        JINJA_CHARS = "{}\n"
        looks_like_template = any(c in chat_template for c in JINJA_CHARS)
        if not looks_like_template:
            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {e}")
            raise ValueError(msg) from e

        # The string could not be opened as a file, so treat it as the
        # template text itself; decode it so that escape sequences in the
        # argument are interpreted correctly.
        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")

    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
    return resolved_chat_template
|
||||
|
||||
|
||||
# TODO: Let user specify how to insert multimodal tokens into prompt
|
||||
# (similar to chat template)
|
||||
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
|
||||
text_prompt: str) -> str:
|
||||
"""Combine multimodal prompts for a multimodal language model."""
|
||||
|
||||
# Look through the text prompt to check for missing placeholders
|
||||
missing_placeholders: List[str] = []
|
||||
for placeholder in placeholder_counts:
|
||||
|
||||
# For any existing placeholder in the text prompt, we leave it as is
|
||||
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
|
||||
|
||||
if placeholder_counts[placeholder] < 0:
|
||||
raise ValueError(
|
||||
f"Found more '{placeholder}' placeholders in input prompt than "
|
||||
"actual multimodal data items.")
|
||||
|
||||
missing_placeholders.extend([placeholder] *
|
||||
placeholder_counts[placeholder])
|
||||
|
||||
# NOTE: For now we always add missing placeholders at the front of
|
||||
# the prompt. This may change to be customizable in the future.
|
||||
return "\n".join(missing_placeholders + [text_prompt])
|
||||
|
||||
|
||||
# No need to validate using Pydantic again
|
||||
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
|
||||
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
|
||||
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
|
||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
|
||||
|
||||
|
||||
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
    """Turn the content parts of one message into a conversation message.

    Text and refusal parts are joined with newlines. Image and audio parts
    are registered with a parser created from ``mm_tracker``. For model types
    listed in ``MODEL_KEEP_MULTI_MODAL_CONTENT`` the content stays a list of
    typed parts; otherwise it is a single string with any missing
    placeholders prepended.

    Raises:
        NotImplementedError: A part has an unknown ``type``.
    """
    text_chunks: List[str] = []

    mm_parser = mm_tracker.create_parser()
    model_type = mm_tracker._model_config.hf_config.model_type
    keep_multimodal_content = model_type in MODEL_KEEP_MULTI_MODAL_CONTENT

    saw_image = False
    for part in parts:
        kind = part["type"]
        if kind == "text":
            text_chunks.append(_TextParser(part)["text"])
        elif kind == "image_url":
            image_url = _ImageParser(part)["image_url"]

            if image_url.get("detail", "auto") != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported and "
                    "will be ignored.")

            mm_parser.parse_image(image_url["url"])
            saw_image = True
        elif kind == "audio_url":
            audio_url = _AudioParser(part)["audio_url"]
            mm_parser.parse_audio(audio_url["url"])
        elif kind == "refusal":
            text_chunks.append(_RefusalParser(part)["refusal"])
        else:
            raise NotImplementedError(f"Unknown part type: {kind}")

    text_prompt = "\n".join(text_chunks)

    if not keep_multimodal_content:
        # Flatten to plain text, prepending placeholders the text lacks.
        mm_placeholder_counts = mm_parser.mm_placeholder_counts()
        if mm_placeholder_counts:
            text_prompt = _get_full_multimodal_text_prompt(
                mm_placeholder_counts, text_prompt)
        return [ConversationMessage(role=role, content=text_prompt)]

    # Structured content: an optional image marker followed by the text part.
    role_content = [{'type': 'text', 'text': text_prompt}]
    if saw_image:
        role_content = [{'type': 'image'}, *role_content]
    return [ConversationMessage(role=role,
                                content=role_content)]  # type: ignore
|
||||
|
||||
|
||||
# No need to validate using Pydantic again
|
||||
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
|
||||
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
|
||||
|
||||
|
||||
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
    """Parse one chat message into conversation messages.

    Missing content becomes an empty part list and string content becomes a
    single text part. ``tool_calls`` (assistant role), ``tool_call_id`` (tool
    role) and a string ``name`` are copied onto every resulting message.
    """
    role = message["role"]
    raw_content = message.get("content")

    if raw_content is None:
        parts = []
    elif isinstance(raw_content, str):
        parts = [
            ChatCompletionContentPartTextParam(type="text", text=raw_content)
        ]
    else:
        parts = raw_content

    result = _parse_chat_message_content_parts(
        role,
        parts,  # type: ignore
        mm_tracker,
    )

    is_assistant = role == 'assistant'
    is_tool = role == "tool"
    for parsed in result:
        if is_assistant:
            assistant_msg = _AssistantParser(message)

            if "tool_calls" in assistant_msg:
                parsed["tool_calls"] = list(assistant_msg["tool_calls"])
        elif is_tool:
            tool_msg = _ToolParser(message)
            if "tool_call_id" in tool_msg:
                parsed["tool_call_id"] = tool_msg["tool_call_id"]

        if "name" in message and isinstance(message["name"], str):
            parsed["name"] = message["name"]

    return result
|
||||
|
||||
|
||||
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
|
||||
# per the Transformers docs & maintainers, tool call arguments in
|
||||
# assistant-role messages with tool_calls need to be dicts not JSON str -
|
||||
# this is how tool-use chat templates will expect them moving forwards
|
||||
# so, for messages that have tool_calls, parse the string (which we get
|
||||
# from openAI format) to dict
|
||||
for message in messages:
|
||||
if (message["role"] == "assistant" and "tool_calls" in message
|
||||
and isinstance(message["tool_calls"], list)):
|
||||
|
||||
for item in message["tool_calls"]:
|
||||
item["function"]["arguments"] = json.loads(
|
||||
item["function"]["arguments"])
|
||||
|
||||
|
||||
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
    """Parse ``messages`` and load their multi-modal data synchronously.

    Returns the parsed conversation together with the combined multi-modal
    data of all messages (``None`` when there is none).
    """
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    conversation: List[ConversationMessage] = [
        parsed for msg in messages
        for parsed in _parse_chat_message_content(msg, mm_tracker)
    ]

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()
|
||||
|
||||
|
||||
def parse_chat_messages_futures(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
    """Parse ``messages`` while deferring the loading of multi-modal data.

    Returns the parsed conversation together with an awaitable that resolves
    to the combined multi-modal data (``None`` when there is none).
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    conversation: List[ConversationMessage] = [
        parsed for msg in messages
        for parsed in _parse_chat_message_content(msg, mm_tracker)
    ]

    _postprocess_messages(conversation)

    return conversation, mm_tracker.all_mm_data()
|
||||
|
||||
|
||||
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
    *,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render ``conversation`` through the tokenizer's ``apply_chat_template``.

    Raises:
        ValueError: Neither ``chat_template`` nor the tokenizer provides a
            chat template.
    """
    if chat_template is None:
        if tokenizer.chat_template is None:
            raise ValueError(
                "As of transformers v4.44, default chat template is no longer "
                "allowed, so you must provide a chat template if the tokenizer "
                "does not define one.")

    rendered = tokenizer.apply_chat_template(
        conversation=conversation,  # type: ignore[arg-type]
        chat_template=chat_template,
        tokenize=tokenize,
        **kwargs,
    )
    return rendered
|
||||
|
||||
|
||||
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: List[ChatCompletionMessageParam],
    chat_template: Optional[str] = None,
    **kwargs: Any,
) -> List[int]:
    """Apply the Mistral tokenizer's chat template to ``messages``.

    A supplied ``chat_template`` and the ``add_generation_prompt`` /
    ``continue_final_message`` options only trigger a warning; ``kwargs`` is
    still forwarded unchanged to the tokenizer.
    """
    if chat_template is not None:
        logger.warning(
            "'chat_template' cannot be overridden for mistral tokenizer.")

    unsupported_option_warnings = (
        ("add_generation_prompt",
         "'add_generation_prompt' is not supported for mistral tokenizer, "
         "so it will be ignored."),
        ("continue_final_message",
         "'continue_final_message' is not supported for mistral tokenizer, "
         "so it will be ignored."),
    )
    for option, warning_text in unsupported_option_warnings:
        if option in kwargs:
            logger.warning(warning_text)

    return tokenizer.apply_chat_template(
        messages=messages,
        **kwargs,
    )
|
||||
103
vllm/entrypoints/launcher.py
Normal file
103
vllm/entrypoints/launcher.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import asyncio
|
||||
import signal
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request, Response
|
||||
|
||||
from vllm import envs
|
||||
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
||||
from vllm.engine.multiprocessing import MQEngineDeadError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_process_using_port
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
    """Serve ``app`` with uvicorn until it exits or a signal cancels it.

    Logs the available routes, installs the fatal-error handlers, and runs
    the server as a task that SIGINT/SIGTERM cancel.

    Returns:
        A coroutine for the caller to await: a no-op when the server task
        finished on its own, ``server.shutdown()`` when it was cancelled.
    """
    logger.info("Available routes are:")
    for route in app.routes:
        methods = getattr(route, "methods", None)
        path = getattr(route, "path", None)

        # Only routes exposing both attributes are listed.
        if methods is not None and path is not None:
            logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

    server = uvicorn.Server(uvicorn.Config(app, **uvicorn_kwargs))
    _add_shutdown_handlers(app, server)

    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.serve())

    def signal_handler() -> None:
        # Cancel the serving task ourselves so that uvicorn's own signal
        # handler does not exit early.
        server_task.cancel()

    async def dummy_shutdown() -> None:
        pass

    for shutdown_signal in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(shutdown_signal, signal_handler)

    try:
        await server_task
    except asyncio.CancelledError:
        port = uvicorn_kwargs["port"]
        process = find_process_using_port(port)
        if process is not None:
            logger.debug(
                "port %s is used by process %s launched with command:\n%s",
                port, process, " ".join(process.cmdline()))
        logger.info("Shutting down FastAPI HTTP server.")
        return server.shutdown()
    else:
        return dummy_shutdown()
|
||||
|
||||
|
||||
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """Adds handlers for fatal errors that should crash the server"""

    def internal_error() -> Response:
        # Every handler below answers the failing request with HTTP 500.
        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    async def runtime_error_handler(request: Request, __):
        """On generic runtime error, check to see if the engine has died.
        It probably has, in which case the server will no longer be able to
        handle requests. Ask the uvicorn server to exit."""
        engine = request.app.state.engine_client
        engine_is_dead = (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
                          and engine.errored and not engine.is_running)
        if engine_is_dead:
            logger.fatal("AsyncLLMEngine has failed, terminating server "
                         "process")
            # See discussions here on shutting down a uvicorn server
            # https://github.com/encode/uvicorn/discussions/1103
            # In this case we cannot await the server shutdown here because
            # this handler must first return to close the connection for
            # this request.
            server.should_exit = True

        return internal_error()

    async def async_engine_dead_handler(_, __):
        """Kill the server if the async engine is already dead. It will
        not handle any further requests."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
            logger.fatal("AsyncLLMEngine is already dead, terminating server "
                         "process")
            server.should_exit = True

        return internal_error()

    async def mq_engine_dead_handler(_, __):
        """Kill the server if the mq engine is already dead. It will
        not handle any further requests."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
            logger.fatal("MQLLMEngine is already dead, terminating server "
                         "process")
            server.should_exit = True

        return internal_error()

    # Register in the same order as the handlers are defined above.
    app.exception_handler(RuntimeError)(runtime_error_handler)
    app.exception_handler(AsyncEngineDeadError)(async_engine_dead_handler)
    app.exception_handler(MQEngineDeadError)(mq_engine_dead_handler)
|
||||
909
vllm/entrypoints/llm.py
Normal file
909
vllm/entrypoints/llm.py
Normal file
@@ -0,0 +1,909 @@
|
||||
import itertools
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
|
||||
Union, cast, overload)
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
|
||||
BeamSearchSequence, get_beam_search_score)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages)
|
||||
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
||||
GuidedDecodingRequest, LLMGuidedOptions)
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||
RequestOutputKind, SamplingParams)
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
get_cached_tokenizer)
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter, deprecate_kwargs, is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LLM:
|
||||
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||
|
||||
This class includes a tokenizer, a language model (possibly distributed
|
||||
across multiple GPUs), and GPU memory space allocated for intermediate
|
||||
states (aka KV cache). Given a batch of prompts and sampling parameters,
|
||||
this class generates texts from the model, using an intelligent batching
|
||||
mechanism and efficient memory management.
|
||||
|
||||
Args:
|
||||
model: The name or path of a HuggingFace Transformers model.
|
||||
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
|
||||
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
|
||||
if available, and "slow" will always use the slow tokenizer.
|
||||
skip_tokenizer_init: If true, skip initialization of tokenizer and
|
||||
detokenizer. Expect valid prompt_token_ids and None for prompt
|
||||
from the input.
|
||||
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||
downloading the model and tokenizer.
|
||||
tensor_parallel_size: The number of GPUs to use for distributed
|
||||
execution with tensor parallelism.
|
||||
dtype: The data type for the model weights and activations. Currently,
|
||||
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
|
||||
the `torch_dtype` attribute specified in the model config file.
|
||||
However, if the `torch_dtype` in the config is `float32`, we will
|
||||
use `float16` instead.
|
||||
quantization: The method used to quantize the model weights. Currently,
|
||||
we support "awq", "gptq", and "fp8" (experimental).
|
||||
If None, we first check the `quantization_config` attribute in the
|
||||
model config file. If that is None, we assume the model weights are
|
||||
not quantized and use `dtype` to determine the data type of
|
||||
the weights.
|
||||
revision: The specific model version to use. It can be a branch name,
|
||||
a tag name, or a commit id.
|
||||
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||
branch name, a tag name, or a commit id.
|
||||
seed: The seed to initialize the random number generator for sampling.
|
||||
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
||||
reserve for the model weights, activations, and KV cache. Higher
|
||||
values will increase the KV cache size and thus improve the model's
|
||||
throughput. However, if the value is too high, it may cause out-of-
|
||||
memory (OOM) errors.
|
||||
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||
This can be used for temporarily storing the states of the requests
|
||||
when their `best_of` sampling parameters are larger than 1. If all
|
||||
requests will have `best_of=1`, you can safely set this to 0.
|
||||
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
||||
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
|
||||
the model weights. This virtually increases the GPU memory space
|
||||
you can use to hold the model weights, at the cost of CPU-GPU data
|
||||
transfer for every forward pass.
|
||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||
disable CUDA graph and always execute the model in eager mode.
|
||||
If False, we will use CUDA graph and eager execution in hybrid.
|
||||
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
|
||||
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
|
||||
When a sequence has context length larger than this, we fall back
|
||||
to eager mode. Additionally for encoder-decoder models, if the
|
||||
sequence length of the encoder input is larger than this, we fall
|
||||
back to the eager mode.
|
||||
disable_custom_all_reduce: See ParallelConfig
|
||||
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
|
||||
:ref:`engine_args`)
|
||||
|
||||
Note:
|
||||
This class is intended to be used for offline inference. For online
|
||||
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
|
||||
"""
|
||||
|
||||
DEPRECATE_LEGACY: ClassVar[bool] = False
|
||||
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
|
||||
|
||||
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
    """Context manager that enables legacy-API deprecation for its body.

    ``DEPRECATE_LEGACY`` is set to True on entry and reset to False on
    exit. The reset happens in a ``finally`` clause so that an exception
    raised inside the ``with`` body cannot leave the class-wide flag
    permanently switched on.
    """
    cls.DEPRECATE_LEGACY = True
    try:
        yield
    finally:
        # Always restore the flag, including when the body raised.
        cls.DEPRECATE_LEGACY = False
|
||||
|
||||
def __init__(
    self,
    model: str,
    tokenizer: Optional[str] = None,
    tokenizer_mode: str = "auto",
    skip_tokenizer_init: bool = False,
    trust_remote_code: bool = False,
    tensor_parallel_size: int = 1,
    dtype: str = "auto",
    quantization: Optional[str] = None,
    revision: Optional[str] = None,
    tokenizer_revision: Optional[str] = None,
    seed: int = 0,
    gpu_memory_utilization: float = 0.9,
    swap_space: float = 4,
    cpu_offload_gb: float = 0,
    enforce_eager: Optional[bool] = None,
    max_context_len_to_capture: Optional[int] = None,
    max_seq_len_to_capture: int = 8192,
    disable_custom_all_reduce: bool = False,
    disable_async_output_proc: bool = False,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> None:
    """Create the engine and the request counter for this LLM.

    All named arguments and any extra ``**kwargs`` are forwarded to
    ``EngineArgs``; ``disable_log_stats`` is set to True unless the
    caller passes it explicitly.

    Note: if enforce_eager is unset (enforce_eager is None)
    it defaults to False.
    """
    # Stats logging stays disabled unless the caller supplies the flag.
    kwargs.setdefault("disable_log_stats", True)

    # The named parameters can never also appear in ``kwargs`` (they would
    # have been bound to the parameters), so the two mappings never clash.
    engine_options = dict(
        model=model,
        tokenizer=tokenizer,
        tokenizer_mode=tokenizer_mode,
        skip_tokenizer_init=skip_tokenizer_init,
        trust_remote_code=trust_remote_code,
        tensor_parallel_size=tensor_parallel_size,
        dtype=dtype,
        quantization=quantization,
        revision=revision,
        tokenizer_revision=tokenizer_revision,
        seed=seed,
        gpu_memory_utilization=gpu_memory_utilization,
        swap_space=swap_space,
        cpu_offload_gb=cpu_offload_gb,
        enforce_eager=enforce_eager,
        max_context_len_to_capture=max_context_len_to_capture,
        max_seq_len_to_capture=max_seq_len_to_capture,
        disable_custom_all_reduce=disable_custom_all_reduce,
        disable_async_output_proc=disable_async_output_proc,
        mm_processor_kwargs=mm_processor_kwargs,
    )
    engine_args = EngineArgs(**engine_options, **kwargs)

    self.llm_engine = LLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.LLM_CLASS)
    self.request_counter = Counter()
|
||||
|
||||
def get_tokenizer(self) -> AnyTokenizer:
    """Return the tokenizer stored on the engine's tokenizer group."""
    tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
    return tokenizer_group.tokenizer
|
||||
|
||||
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
    """Install ``tokenizer`` on the engine's tokenizer group.

    A tokenizer whose class name starts with ``"Cached"`` is stored as
    is; any other tokenizer is first wrapped by ``get_cached_tokenizer``.
    """
    group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

    # The cached-tokenizer class is created dynamically, so the class name
    # is the only thing available to compare against. A user-defined
    # tokenizer class whose name begins with 'Cached' is misjudged.
    looks_cached = tokenizer.__class__.__name__.startswith("Cached")
    group.tokenizer = (tokenizer
                       if looks_cached else get_cached_tokenizer(tokenizer))
|
||||
|
||||
# The definitions below are typing-only ``@overload`` stubs for
# ``generate``; their bodies are ``...`` and carry no runtime logic.

# A single text prompt, optionally accompanied by its token ids.
@overload  # LEGACY: single (prompt + optional token ids)
def generate(
    self,
    prompts: str,
    sampling_params: Optional[Union[SamplingParams,
                                    List[SamplingParams]]] = None,
    prompt_token_ids: Optional[List[int]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
    ...

# A list of text prompts, optionally accompanied by their token ids.
@overload  # LEGACY: multi (prompt + optional token ids)
def generate(
    self,
    prompts: List[str],
    sampling_params: Optional[Union[SamplingParams,
                                    List[SamplingParams]]] = None,
    prompt_token_ids: Optional[List[List[int]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
    ...

# Token ids for a single prompt given by keyword; the text is optional.
@overload  # LEGACY: single (token ids + optional prompt)
def generate(
    self,
    prompts: Optional[str] = None,
    sampling_params: Optional[Union[SamplingParams,
                                    List[SamplingParams]]] = None,
    *,
    prompt_token_ids: List[int],
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
    ...

# Token ids for several prompts given by keyword; the texts are optional.
@overload  # LEGACY: multi (token ids + optional prompt)
def generate(
    self,
    prompts: Optional[List[str]] = None,
    sampling_params: Optional[Union[SamplingParams,
                                    List[SamplingParams]]] = None,
    *,
    prompt_token_ids: List[List[int]],
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
    ...

# Token ids passed as the third argument with both ``prompts`` and
# ``sampling_params`` given as None.
@overload  # LEGACY: single or multi token ids [pos-only]
def generate(
    self,
    prompts: None,
    sampling_params: None,
    prompt_token_ids: Union[List[int], List[List[int]]],
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
    ...

# Non-legacy form: prompts are positional-only, everything else is
# keyword-only, and there is no ``prompt_token_ids`` parameter.
@overload
def generate(
    self,
    prompts: Union[PromptType, Sequence[PromptType]],
    /,
    *,
    sampling_params: Optional[Union[SamplingParams,
                                    Sequence[SamplingParams]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
    ...
|
||||
|
||||
@deprecate_kwargs(
    "prompt_token_ids",
    is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
    additional_message="Please use the 'prompts' parameter instead.",
)
def generate(
    self,
    prompts: Union[Union[PromptType, Sequence[PromptType]],
                   Optional[Union[str, List[str]]]] = None,
    sampling_params: Optional[Union[SamplingParams,
                                    Sequence[SamplingParams]]] = None,
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    guided_options_request: Optional[Union[LLMGuidedOptions,
                                           GuidedDecodingRequest]] = None,
    priority: Optional[List[int]] = None,
) -> List[RequestOutput]:
    """Generates the completions for the input prompts.

    The given prompts are batched automatically, considering the memory
    constraint. For the best performance, put all of your prompts into a
    single list and pass it to this method.

    Args:
        prompts: The prompts to the LLM. You may pass a sequence of prompts
            for batch inference. See :class:`~vllm.inputs.PromptType`
            for more details about the format of each prompts.
        sampling_params: The sampling parameters for text generation. If
            None, we use the default sampling parameters.
            When it is a single value, it is applied to every prompt.
            When it is a list, the list must have the same length as the
            prompts and it is paired one by one with the prompt.
        prompt_token_ids: Legacy way of passing pre-tokenized prompts;
            when given, the inputs are converted with
            ``_convert_v1_inputs``.
        use_tqdm: Whether to use tqdm to display the progress bar.
        lora_request: LoRA request to use for generation, if any.
        prompt_adapter_request: Prompt Adapter request to use for
            generation, if any.
        guided_options_request: Guided decoding options, either as a
            ``GuidedDecodingRequest`` or as a dict holding at most one
            entry.
        priority: The priority of the requests, if any.
            Only applicable when priority scheduling policy is enabled.

    Returns:
        A list of ``RequestOutput`` objects containing the
        generated completions in the same order as the input prompts.

    Raises:
        ValueError: If the model runs in embedding mode, or if a dict
            ``guided_options_request`` has more than one entry.

    Note:
        Passing ``prompt_token_ids`` is considered legacy and may be
        deprecated in the future. You should instead pass all inputs via
        the ``prompts`` parameter.
    """
    model_config = self.llm_engine.model_config
    if model_config.embedding_mode:
        raise ValueError(
            "LLM.generate() is only supported for (conditional) generation "
            "models (XForCausalLM, XForConditionalGeneration).")

    if prompt_token_ids is None:
        parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                              prompts)
    else:
        # Legacy call style: merge text prompts and token ids.
        parsed_prompts = self._convert_v1_inputs(
            prompts=cast(Optional[Union[str, List[str]]], prompts),
            prompt_token_ids=prompt_token_ids,
        )

    if isinstance(guided_options_request, dict):
        if len(guided_options_request) > 1:
            raise ValueError(
                "You can only use one guided decoding but multiple is "
                f"specified: {guided_options_request}")
        guided_options_request = GuidedDecodingRequest(
            **guided_options_request)

    # Fall back to the default sampling params when none were given.
    effective_params = (SamplingParams()
                        if sampling_params is None else sampling_params)

    self._validate_and_add_requests(
        prompts=parsed_prompts,
        params=effective_params,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
        guided_options=guided_options_request,
        priority=priority)

    engine_outputs = self._run_engine(use_tqdm=use_tqdm)
    return LLMEngine.validate_outputs(engine_outputs, RequestOutput)
|
||||
|
||||
def beam_search(
    self,
    prompts: List[Union[str, List[int]]],
    params: BeamSearchParams,
) -> List[BeamSearchOutput]:
    """
    Generate sequences using beam search.

    Args:
        prompts: A list of prompts. Each prompt can be a string or a list
            of token IDs.
        params: The beam search parameters.

    Returns:
        One ``BeamSearchOutput`` per prompt, in prompt order, holding at
        most ``params.beam_width`` sequences sorted by descending score.

    TODO: how does beam search work together with length penalty, frequency
    penalty, and stopping criteria, etc.?
    """

    beam_width = params.beam_width
    max_tokens = params.max_tokens
    temperature = params.temperature
    ignore_eos = params.ignore_eos
    length_penalty = params.length_penalty

    # Bound before the closure below so the helper reads top to bottom.
    tokenizer = self.get_tokenizer()

    def sort_beams_key(x: BeamSearchSequence) -> float:
        return get_beam_search_score(x.tokens, x.cum_logprob,
                                     tokenizer.eos_token_id,
                                     length_penalty)

    # generate 2 * beam_width candidates at each step
    # following the huggingface transformers implementation
    # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
    beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                        max_tokens=1,
                                        temperature=temperature)
    instances: List[BeamSearchInstance] = []

    for prompt in prompts:
        prompt_tokens = prompt if isinstance(
            prompt, list) else tokenizer.encode(prompt)
        instances.append(BeamSearchInstance(prompt_tokens))

    for _ in range(max_tokens):
        # Flatten the live beams of every instance in linear time;
        # ``sum(lists, [])`` copies the accumulated list on every addition.
        all_beams: List[BeamSearchSequence] = list(
            itertools.chain.from_iterable(instance.beams
                                          for instance in instances))
        # pos[k] is the offset in ``all_beams`` of instance k's first beam.
        pos = [0] + list(
            itertools.accumulate(
                len(instance.beams) for instance in instances))
        instance_start_and_end: List[Tuple[int, int]] = list(
            zip(pos[:-1], pos[1:]))

        if len(all_beams) == 0:
            break

        prompts_batch = [
            TokensPrompt(prompt_token_ids=beam.tokens)
            for beam in all_beams
        ]

        # only runs for one step
        # we don't need to use tqdm here
        output = self.generate(prompts_batch,
                               sampling_params=beam_search_params,
                               use_tqdm=False)

        for (start, end), instance in zip(instance_start_and_end,
                                          instances):
            instance_new_beams = []
            for i in range(start, end):
                current_beam = all_beams[i]
                result = output[i]

                if result.outputs[0].logprobs is not None:
                    # if `result.outputs[0].logprobs` is None, it means
                    # the sequence is completed because of the max-model-len
                    # or abortion. we don't need to add it to the new beams.
                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob_obj in logprobs.items():
                        new_beam = BeamSearchSequence(
                            tokens=current_beam.tokens + [token_id],
                            cum_logprob=current_beam.cum_logprob +
                            logprob_obj.logprob)

                        if token_id == tokenizer.eos_token_id and \
                            not ignore_eos:
                            instance.completed.append(new_beam)
                        else:
                            instance_new_beams.append(new_beam)
            sorted_beams = sorted(instance_new_beams,
                                  key=sort_beams_key,
                                  reverse=True)
            instance.beams = sorted_beams[:beam_width]

    outputs = []
    for instance in instances:
        # Beams still open when the step budget ran out compete with the
        # completed ones for the final ranking.
        instance.completed.extend(instance.beams)
        sorted_completed = sorted(instance.completed,
                                  key=sort_beams_key,
                                  reverse=True)
        best_beams = sorted_completed[:beam_width]

        for beam in best_beams:
            beam.text = tokenizer.decode(beam.tokens)
        outputs.append(BeamSearchOutput(sequences=best_beams))

    return outputs
|
||||
|
||||
def chat(
    self,
    messages: Union[List[ChatCompletionMessageParam],
                    List[List[ChatCompletionMessageParam]]],
    sampling_params: Optional[Union[SamplingParams,
                                    List[SamplingParams]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[LoRARequest] = None,
    chat_template: Optional[str] = None,
    add_generation_prompt: bool = True,
    continue_final_message: bool = False,
    tools: Optional[List[Dict[str, Any]]] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> List[RequestOutput]:
    """
    Generate responses for a chat conversation.

    The chat conversation is converted into a text prompt using the
    tokenizer and calls the :meth:`generate` method to generate the
    responses.

    Multi-modal inputs can be passed in the same way you would pass them
    to the OpenAI API.

    Args:
        messages: A list of conversations or a single conversation.
          - Each conversation is represented as a list of messages.
          - Each message is a dictionary with 'role' and 'content' keys.
        sampling_params: The sampling parameters for text generation.
            If None, we use the default sampling parameters. When it
            is a single value, it is applied to every prompt. When it
            is a list, the list must have the same length as the
            prompts and it is paired one by one with the prompt.
        use_tqdm: Whether to use tqdm to display the progress bar.
        lora_request: LoRA request to use for generation, if any.
        chat_template: The template to use for structuring the chat.
          If not provided, the model's default chat template will be used.
        add_generation_prompt: If True, adds a generation template
            to each message.
        continue_final_message: If True, continues the final message in
            the conversation instead of starting a new one. Cannot be `True`
            if `add_generation_prompt` is also `True`.
        tools: Optional tool definitions; passed through unchanged to the
            chat template application.
        mm_processor_kwargs: Multimodal processor kwarg overrides for this
            chat request. Only used for offline requests.

    Returns:
        A list of ``RequestOutput`` objects containing the generated
        responses in the same order as the input messages.
    """
    list_of_messages: List[List[ChatCompletionMessageParam]]

    # Handle multi and single conversations
    if is_list_of(messages, list):
        # messages is List[List[...]]
        list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                messages)
    else:
        # messages is List[...]
        list_of_messages = [
            cast(List[ChatCompletionMessageParam], messages)
        ]

    # Neither lookup depends on the conversation being processed, so do
    # them once instead of once per conversation.
    tokenizer = self.get_tokenizer()
    model_config = self.llm_engine.get_model_config()
    uses_mistral_template = isinstance(tokenizer, MistralTokenizer)

    prompts: List[Union[TokensPrompt, TextPrompt]] = []

    for msgs in list_of_messages:
        # NOTE: _parse_chat_message_content_parts() currently doesn't
        # handle mm_processor_kwargs, since there is no implementation in
        # the chat message parsing for it.
        conversation, mm_data = parse_chat_messages(
            msgs, model_config, tokenizer)

        prompt_data: Union[str, List[int]]
        if uses_mistral_template:
            prompt_data = apply_mistral_chat_template(
                tokenizer,
                messages=msgs,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
            )
        else:
            prompt_data = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
            )

        # The template step yields either token ids or text; wrap the
        # result in the matching prompt type.
        prompt: Union[TokensPrompt, TextPrompt]
        if is_list_of(prompt_data, int):
            prompt = TokensPrompt(prompt_token_ids=prompt_data)
        else:
            prompt = TextPrompt(prompt=prompt_data)

        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data

        if mm_processor_kwargs is not None:
            prompt["mm_processor_kwargs"] = mm_processor_kwargs

        prompts.append(prompt)

    return self.generate(
        prompts,
        sampling_params=sampling_params,
        use_tqdm=use_tqdm,
        lora_request=lora_request,
    )
|
||||
|
||||
# The definitions below are typing-only ``@overload`` stubs for
# ``encode``; their bodies are ``...`` and carry no runtime logic.

# A single text prompt, optionally accompanied by its token ids.
@overload  # LEGACY: single (prompt + optional token ids)
def encode(
    self,
    prompts: str,
    pooling_params: Optional[Union[PoolingParams,
                                   Sequence[PoolingParams]]] = None,
    prompt_token_ids: Optional[List[int]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
    ...

# A list of text prompts, optionally accompanied by their token ids.
@overload  # LEGACY: multi (prompt + optional token ids)
def encode(
    self,
    prompts: List[str],
    pooling_params: Optional[Union[PoolingParams,
                                   Sequence[PoolingParams]]] = None,
    prompt_token_ids: Optional[List[List[int]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
    ...

# Token ids for a single prompt given by keyword; the text is optional.
@overload  # LEGACY: single (token ids + optional prompt)
def encode(
    self,
    prompts: Optional[str] = None,
    pooling_params: Optional[Union[PoolingParams,
                                   Sequence[PoolingParams]]] = None,
    *,
    prompt_token_ids: List[int],
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
    ...

# Token ids for several prompts given by keyword; the texts are optional.
@overload  # LEGACY: multi (token ids + optional prompt)
def encode(
    self,
    prompts: Optional[List[str]] = None,
    pooling_params: Optional[Union[PoolingParams,
                                   Sequence[PoolingParams]]] = None,
    *,
    prompt_token_ids: List[List[int]],
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
    ...

# Token ids passed as the third argument with both ``prompts`` and
# ``pooling_params`` given as None.
@overload  # LEGACY: single or multi token ids [pos-only]
def encode(
    self,
    prompts: None,
    pooling_params: None,
    prompt_token_ids: Union[List[int], List[List[int]]],
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
    ...

# Non-legacy form: prompts are positional-only, everything else is
# keyword-only, and there is no ``prompt_token_ids`` parameter.
@overload
def encode(
    self,
    prompts: Union[PromptType, Sequence[PromptType]],
    /,
    *,
    pooling_params: Optional[Union[PoolingParams,
                                   Sequence[PoolingParams]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
    ...
|
||||
|
||||
@deprecate_kwargs(
    "prompt_token_ids",
    is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
    additional_message="Please use the 'prompts' parameter instead.",
)
def encode(
    self,
    prompts: Union[Union[PromptType, Sequence[PromptType]],
                   Optional[Union[str, List[str]]]] = None,
    pooling_params: Optional[Union[PoolingParams,
                                   Sequence[PoolingParams]]] = None,
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
    """Generates the embeddings for the input prompts.

    The given prompts are batched automatically, considering the memory
    constraint. For the best performance, put all of your prompts into a
    single list and pass it to this method.

    Args:
        prompts: The prompts to the LLM. You may pass a sequence of prompts
            for batch inference. See :class:`~vllm.inputs.PromptType`
            for more details about the format of each prompts.
        pooling_params: The pooling parameters for pooling. If None, we
            use the default pooling parameters.
        prompt_token_ids: Legacy way of passing pre-tokenized prompts;
            when given, the inputs are converted with
            ``_convert_v1_inputs``.
        use_tqdm: Whether to use tqdm to display the progress bar.
        lora_request: LoRA request to use for generation, if any.
        prompt_adapter_request: Prompt Adapter request to use for
            generation, if any.

    Returns:
        A list of `EmbeddingRequestOutput` objects containing the
        generated embeddings in the same order as the input prompts.

    Raises:
        ValueError: If the model does not run in embedding mode.

    Note:
        Passing ``prompt_token_ids`` is considered legacy and may be
        deprecated in the future. You should instead pass all inputs via
        the ``prompts`` parameter.
    """
    model_config = self.llm_engine.model_config
    if not model_config.embedding_mode:
        raise ValueError(
            "LLM.encode() is only supported for embedding models (XModel)."
        )

    if prompt_token_ids is None:
        parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                              prompts)
    else:
        # Legacy call style: merge text prompts and token ids.
        parsed_prompts = self._convert_v1_inputs(
            prompts=cast(Optional[Union[str, List[str]]], prompts),
            prompt_token_ids=prompt_token_ids,
        )

    # Fall back to the default pooling params when none were given.
    effective_params = (PoolingParams()
                        if pooling_params is None else pooling_params)

    self._validate_and_add_requests(
        prompts=parsed_prompts,
        params=effective_params,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )

    engine_outputs = self._run_engine(use_tqdm=use_tqdm)
    return LLMEngine.validate_outputs(engine_outputs,
                                      EmbeddingRequestOutput)
|
||||
|
||||
def start_profile(self) -> None:
    """Forward the start-profile request to the underlying engine."""
    engine = self.llm_engine
    engine.start_profile()
|
||||
|
||||
def stop_profile(self) -> None:
    """Forward the stop-profile request to the underlying engine."""
    engine = self.llm_engine
    engine.stop_profile()
|
||||
|
||||
# LEGACY
|
||||
def _convert_v1_inputs(
    self,
    prompts: Optional[Union[str, List[str]]],
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
):
    """Turn legacy ``prompts`` / ``prompt_token_ids`` into prompt objects.

    Both inputs are first normalised to batches with
    ``parse_and_batch_prompt``. When text prompts are present, one
    ``TextPrompt`` is built per text; otherwise one ``TokensPrompt`` is
    built per token-id list.

    Raises:
        ValueError: If both inputs are given with different batch sizes,
            or if neither input is given.
    """
    # skip_tokenizer_init is now checked in engine

    text_batch = None
    if prompts is not None:
        text_batch = [
            parsed["content"] for parsed in parse_and_batch_prompt(prompts)
        ]

    ids_batch = None
    if prompt_token_ids is not None:
        ids_batch = [
            parsed["content"]
            for parsed in parse_and_batch_prompt(prompt_token_ids)
        ]

    if (text_batch is not None and ids_batch is not None
            and len(text_batch) != len(ids_batch)):
        raise ValueError("The lengths of prompts and prompt_token_ids "
                         "must be the same.")
    if text_batch is None and ids_batch is None:
        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")

    parsed_prompts: List[PromptType]
    if text_batch is not None:
        # Text takes precedence: when both inputs are supplied, only the
        # text prompts end up in the result.
        parsed_prompts = [TextPrompt(prompt=text) for text in text_batch]
    else:
        parsed_prompts = [
            TokensPrompt(prompt_token_ids=ids) for ids in ids_batch
        ]

    return parsed_prompts
|
||||
|
||||
def _validate_and_add_requests(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
||||
Sequence[PoolingParams]],
|
||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
guided_options: Optional[GuidedDecodingRequest] = None,
|
||||
priority: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
if guided_options is not None:
|
||||
warnings.warn(
|
||||
"guided_options_request is deprecated, use "
|
||||
"SamplingParams.guided_decoding instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if isinstance(prompts, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
prompts = [prompts]
|
||||
|
||||
num_requests = len(prompts)
|
||||
if isinstance(params, list) and len(params) != num_requests:
|
||||
raise ValueError("The lengths of prompts and params "
|
||||
"must be the same.")
|
||||
if isinstance(lora_request,
|
||||
list) and len(lora_request) != num_requests:
|
||||
raise ValueError("The lengths of prompts and lora_request "
|
||||
"must be the same.")
|
||||
|
||||
for sp in params if isinstance(params, list) else (params, ):
|
||||
if isinstance(sp, SamplingParams):
|
||||
self._add_guided_params(sp, guided_options)
|
||||
|
||||
# We only care about the final output
|
||||
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
# Add requests to the engine.
|
||||
for i, prompt in enumerate(prompts):
|
||||
self._add_request(
|
||||
prompt,
|
||||
params[i] if isinstance(params, Sequence) else params,
|
||||
lora_request=lora_request[i] if isinstance(
|
||||
lora_request, Sequence) else lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority[i] if priority else 0,
|
||||
)
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
) -> None:
|
||||
request_id = str(next(self.request_counter))
|
||||
self.llm_engine.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
def _add_guided_params(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
guided_options: Optional[GuidedDecodingRequest] = None):
|
||||
if guided_options is None:
|
||||
return params
|
||||
|
||||
if params.guided_decoding is not None:
|
||||
raise ValueError("Cannot set both guided_options_request and"
|
||||
"params.guided_decoding.")
|
||||
|
||||
params.guided_decoding = GuidedDecodingParams(
|
||||
json=guided_options.guided_json,
|
||||
regex=guided_options.guided_regex,
|
||||
choice=guided_options.guided_choice,
|
||||
grammar=guided_options.guided_grammar,
|
||||
json_object=guided_options.guided_json_object,
|
||||
backend=guided_options.guided_decoding_backend,
|
||||
whitespace_pattern=guided_options.guided_whitespace_pattern)
|
||||
return params
|
||||
|
||||
def _run_engine(
|
||||
self, *, use_tqdm: bool
|
||||
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
# Initialize tqdm.
|
||||
if use_tqdm:
|
||||
num_requests = self.llm_engine.get_num_unfinished_requests()
|
||||
pbar = tqdm(
|
||||
total=num_requests,
|
||||
desc="Processed prompts",
|
||||
dynamic_ncols=True,
|
||||
postfix=(f"est. speed input: {0:.2f} toks/s, "
|
||||
f"output: {0:.2f} toks/s"),
|
||||
)
|
||||
|
||||
# Run the engine.
|
||||
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
|
||||
total_in_toks = 0
|
||||
total_out_toks = 0
|
||||
while self.llm_engine.has_unfinished_requests():
|
||||
step_outputs = self.llm_engine.step()
|
||||
for output in step_outputs:
|
||||
if output.finished:
|
||||
outputs.append(output)
|
||||
if use_tqdm:
|
||||
if isinstance(output, RequestOutput):
|
||||
# Calculate tokens only for RequestOutput
|
||||
assert output.prompt_token_ids is not None
|
||||
total_in_toks += len(output.prompt_token_ids)
|
||||
in_spd = total_in_toks / pbar.format_dict["elapsed"]
|
||||
total_out_toks += sum(
|
||||
len(stp.token_ids) for stp in output.outputs)
|
||||
out_spd = (total_out_toks /
|
||||
pbar.format_dict["elapsed"])
|
||||
pbar.postfix = (
|
||||
f"est. speed input: {in_spd:.2f} toks/s, "
|
||||
f"output: {out_spd:.2f} toks/s")
|
||||
pbar.update(1)
|
||||
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
# Sort the outputs by request ID.
|
||||
# This is necessary because some requests may be finished earlier than
|
||||
# its previous requests.
|
||||
return sorted(outputs, key=lambda x: int(x.request_id))
|
||||
|
||||
def _is_encoder_decoder_model(self):
|
||||
return self.llm_engine.is_encoder_decoder_model()
|
||||
|
||||
def _is_embedding_model(self):
|
||||
return self.llm_engine.is_embedding_model()
|
||||
42
vllm/entrypoints/logger.py
Normal file
42
vllm/entrypoints/logger.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RequestLogger:
    """Writes one INFO log record per received request.

    ``max_log_len`` caps how many prompt characters and how many prompt
    token ids are included in the record; ``None`` disables truncation.
    """

    def __init__(self, *, max_log_len: Optional[int]) -> None:
        super().__init__()

        self.max_log_len = max_log_len

    def log_inputs(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[List[int]],
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Log the request, truncating the prompt and token ids to
        ``self.max_log_len`` entries when a limit is configured."""
        limit = self.max_log_len
        if limit is not None:
            prompt = None if prompt is None else prompt[:limit]
            prompt_token_ids = (None if prompt_token_ids is None else
                                prompt_token_ids[:limit])

        logger.info(
            "Received request %s: prompt: %r, "
            "params: %s, prompt_token_ids: %s, "
            "lora_request: %s, prompt_adapter_request: %s.", request_id,
            prompt, params, prompt_token_ids, lora_request,
            prompt_adapter_request)
|
||||
0
vllm/entrypoints/openai/__init__.py
Normal file
0
vllm/entrypoints/openai/__init__.py
Normal file
BIN
vllm/entrypoints/openai/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/api_server.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/api_server.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/cli_args.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/cli_args.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/protocol.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/protocol.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/run_batch.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/run_batch.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/serving_chat.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/serving_chat.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
585
vllm/entrypoints/openai/api_server.py
Normal file
585
vllm/entrypoints/openai/api_server.py
Normal file
@@ -0,0 +1,585 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import socket
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Set
|
||||
|
||||
import uvloop
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from starlette.datastructures import State
|
||||
from starlette.routing import Mount
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.engine.multiprocessing.engine import run_mp_engine
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||
validate_parsed_serve_args)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
UnloadLoraAdapterRequest)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
|
||||
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||
|
||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||
logger = init_logger('vllm.entrypoints.openai.api_server')
|
||||
|
||||
_running_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    When ``app.state.log_stats`` is truthy, a background task is started
    that calls ``engine_client.do_log_stats()`` every 10 seconds; the task
    is cancelled when the context exits. ``app.state`` is deleted on exit
    in every case.
    """
    try:
        stats_task = None
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(10.)
                    await engine_client.do_log_stats()

            stats_task = asyncio.create_task(_force_log())
            # Hold a reference in the module-level set until the task is
            # done; the done-callback removes it again.
            _running_tasks.add(stats_task)
            stats_task.add_done_callback(_running_tasks.remove)

        try:
            yield
        finally:
            if stats_task is not None:
                stats_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
|
||||
|
||||
|
||||
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[EngineClient]:
    """Yield an engine client built from parsed CLI arguments.

    The client's lifecycle is delegated to
    ``build_async_engine_client_from_engine_args``, which handles shutdown
    and cleanup on error/exit.
    """
    engine_args = AsyncEngineArgs.from_cli_args(args)
    client_cm = build_async_engine_client_from_engine_args(
        engine_args, args.disable_frontend_multiprocessing)

    async with client_cm as engine:
        yield engine
|
||||
|
||||
|
||||
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[EngineClient]:
    """
    Create EngineClient, either:
        - in-process using the AsyncLLMEngine Directly
        - multiprocess using AsyncLLMEngine RPC

    Yields the client. Failures during creation propagate as exceptions
    (``RuntimeError`` if the engine process dies before the client finishes
    its setup). In the multiprocess case the engine process is always
    terminated on exit, including when client construction itself fails.
    """

    # Fall back
    # TODO: fill out feature matrix.
    if (MQLLMEngineClient.is_unsupported_config(engine_args)
            or disable_frontend_multiprocessing):
        engine_config = engine_args.create_engine_config()
        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
                           "uses_ray", False)

        build_engine = partial(AsyncLLMEngine.from_engine_args,
                               engine_args=engine_args,
                               engine_config=engine_config,
                               usage_context=UsageContext.OPENAI_API_SERVER)
        if uses_ray:
            # Must run in main thread with ray for its signal handlers to work
            engine_client = build_engine()
        else:
            engine_client = await asyncio.get_running_loop().run_in_executor(
                None, build_engine)

        yield engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        # Make TemporaryDirectory for prometheus multiprocessing
        # Note: global TemporaryDirectory will be automatically
        #   cleaned up upon exit.
        global prometheus_multiproc_dir
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    else:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")

    # Select random path for IPC.
    ipc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for IPC Path.", ipc_path)

    # Start RPCServer in separate process (holds the LLMEngine).
    # the current process might have CUDA context,
    # so we need to spawn a new process
    context = multiprocessing.get_context("spawn")

    engine_process = context.Process(target=run_mp_engine,
                                     args=(engine_args,
                                           UsageContext.OPENAI_API_SERVER,
                                           ipc_path))
    engine_process.start()
    logger.info("Started engine process with PID %d", engine_process.pid)

    # From here on the child process exists, so everything (including
    # building the client) runs under the ``finally`` that tears it down.
    mp_engine_client = None
    try:
        # Build RPCClient, which conforms to EngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        engine_config = engine_args.create_engine_config()
        mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)

        # Retry setup on timeout for as long as the engine process lives.
        while True:
            try:
                await mp_engine_client.setup()
                break
            except TimeoutError:
                if not engine_process.is_alive():
                    raise RuntimeError(
                        "Engine process failed to start") from None

        yield mp_engine_client  # type: ignore[misc]
    finally:
        # Ensure rpc server process was terminated
        engine_process.terminate()

        # Close all open connections to the backend
        if mp_engine_client is not None:
            mp_engine_client.close()

        # Wait for engine process to join
        engine_process.join(4)
        if engine_process.exitcode is None:
            # Kill if it has not exited within the 4 second grace period
            engine_process.kill()

        # Lazy import for prometheus multiprocessing.
        # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
        # before prometheus_client is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(engine_process.pid)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def mount_metrics(app: FastAPI):
    """Attach the Prometheus ``/metrics`` ASGI app to ``app``.

    When ``PROMETHEUS_MULTIPROC_DIR`` is set, the metrics app is backed by
    a fresh ``CollectorRegistry`` with a ``MultiProcessCollector`` attached;
    otherwise ``make_asgi_app()`` is called without a registry.
    """
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_app = make_asgi_app(registry=registry)

    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
|
||||
|
||||
|
||||
def chat(request: Request) -> OpenAIServingChat:
    """Return ``openai_serving_chat`` from the request's app state."""
    app_state = request.app.state
    return app_state.openai_serving_chat
|
||||
|
||||
|
||||
def completion(request: Request) -> OpenAIServingCompletion:
    """Return ``openai_serving_completion`` from the request's app state."""
    app_state = request.app.state
    return app_state.openai_serving_completion
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return ``openai_serving_tokenization`` from the request's app state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
|
||||
|
||||
|
||||
def embedding(request: Request) -> OpenAIServingEmbedding:
    """Return ``openai_serving_embedding`` from the request's app state."""
    app_state = request.app.state
    return app_state.openai_serving_embedding
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
    """Return ``engine_client`` from the request's app state."""
    app_state = request.app.state
    return app_state.engine_client
|
||||
|
||||
|
||||
@router.get("/health")
async def health(raw_request: Request) -> Response:
    """Health check."""
    client = engine_client(raw_request)
    await client.check_health()
    return Response(status_code=200)
|
||||
|
||||
|
||||
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest, raw_request: Request):
    """Handle ``POST /tokenize``.

    Returns the handler's result as JSON, using the error's own ``code``
    as the HTTP status when the result is an ``ErrorResponse``.
    """
    handler = tokenization(raw_request)
    result = await handler.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
|
||||
|
||||
|
||||
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    """Handle ``POST /detokenize``.

    Returns the handler's result as JSON, using the error's own ``code``
    as the HTTP status when the result is an ``ErrorResponse``.
    """
    handler = tokenization(raw_request)
    result = await handler.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
|
||||
|
||||
|
||||
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    """Handle ``GET /v1/models`` using the completion handler's model list."""
    handler = completion(raw_request)
    models = await handler.show_available_models()
    payload = models.model_dump()
    return JSONResponse(content=payload)
|
||||
|
||||
|
||||
@router.get("/version")
async def show_version():
    """Handle ``GET /version``: report ``VLLM_VERSION`` as JSON."""
    return JSONResponse(content={"version": VLLM_VERSION})
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle ``POST /v1/chat/completions``.

    Error and complete responses are returned as JSON; any other result is
    streamed back as ``text/event-stream``.
    """
    handler = chat(raw_request)
    generator = await handler.create_chat_completion(request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)

    if isinstance(generator, ChatCompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle ``POST /v1/completions``.

    Error and complete responses are returned as JSON; any other result is
    streamed back as ``text/event-stream``.
    """
    handler = completion(raw_request)
    generator = await handler.create_completion(request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)

    if isinstance(generator, CompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """Handle ``POST /v1/embeddings``.

    Returns the handler's result as JSON, using the error's own ``code``
    as the HTTP status when the result is an ``ErrorResponse``.
    """
    handler = embedding(raw_request)
    result = await handler.create_embedding(request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
|
||||
|
||||
|
||||
# The profiler routes are only registered when VLLM_TORCH_PROFILER_DIR is set.
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile(raw_request: Request):
        """Start the profiler on the engine client and reply with 200."""
        client = engine_client(raw_request)
        logger.info("Starting profiler...")
        await client.start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile(raw_request: Request):
        """Stop the profiler on the engine client and reply with 200."""
        client = engine_client(raw_request)
        logger.info("Stopping profiler...")
        await client.stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)
|
||||
|
||||
|
||||
# The LoRA management routes are only registered when
# VLLM_ALLOW_RUNTIME_LORA_UPDATING is set.
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                raw_request: Request):
        """Load the adapter into the chat handler, then the completion
        handler, returning the first error encountered as JSON."""
        for get_handler in (chat, completion):
            response = await get_handler(raw_request).load_lora_adapter(
                request)
            if isinstance(response, ErrorResponse):
                return JSONResponse(content=response.model_dump(),
                                    status_code=response.code)

        # ``response`` is the completion handler's (last) result.
        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                  raw_request: Request):
        """Unload the adapter from the chat handler, then the completion
        handler, returning the first error encountered as JSON."""
        for get_handler in (chat, completion):
            response = await get_handler(raw_request).unload_lora_adapter(
                request)
            if isinstance(response, ErrorResponse):
                return JSONResponse(content=response.model_dump(),
                                    status_code=response.code)

        # ``response`` is the completion handler's (last) result.
        return Response(status_code=200, content=response)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
    """Build the FastAPI application from parsed CLI arguments.

    Sets up, in order: the app itself (docs endpoints are disabled when
    ``args.disable_fastapi_docs`` is set), the module router, the metrics
    mount, CORS middleware, a handler that turns request-validation errors
    into 400 responses, optional bearer-token authentication for paths
    under ``{root_path}/v1``, and any extra middleware named in
    ``args.middleware`` as dotted import paths.

    Raises:
        ValueError: if an entry in ``args.middleware`` resolves to
            something that is neither a class nor a coroutine function.
    """
    # Function-scope import: only needed for the auth comparison below.
    import secrets

    if args.disable_fastapi_docs:
        app = FastAPI(openapi_url=None,
                      docs_url=None,
                      redoc_url=None,
                      lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        chat = app.state.openai_serving_chat
        err = chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Compare as bytes: compare_digest rejects non-ASCII str, and the
        # header value is attacker-controlled.
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            supplied_auth = request.headers.get("Authorization") or ""
            # Constant-time comparison so response timing does not leak
            # how much of the key matched.
            if not secrets.compare_digest(supplied_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app
|
||||
|
||||
|
||||
def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and the serving handlers.

    The served model names default to ``[args.model]`` when
    ``args.served_model_name`` is ``None``. Request logging is disabled
    (``request_logger`` is ``None``) when ``args.disable_log_requests`` is
    set.
    """
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    request_logger = (None if args.disable_log_requests else RequestLogger(
        max_log_len=args.max_log_len))

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats

    # Leading positional arguments shared by every serving handler.
    common = (engine_client, model_config, base_model_paths)

    state.openai_serving_chat = OpenAIServingChat(
        *common,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser)
    state.openai_serving_completion = OpenAIServingCompletion(
        *common,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    state.openai_serving_embedding = OpenAIServingEmbedding(
        *common,
        request_logger=request_logger,
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        *common,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
|
||||
|
||||
|
||||
async def run_server(args, **uvicorn_kwargs) -> None:
    """Run the OpenAI-compatible API server until it shuts down.

    The listening socket is bound to ``args.host`` / ``args.port`` before
    the engine is created and is handed to the HTTP server by file
    descriptor; it is closed again when this coroutine finishes.

    Raises:
        KeyError: if auto tool choice is enabled and
            ``args.tool_call_parser`` is not a registered tool parser.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice \
            and args.tool_call_parser not in valid_tool_parsers:
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"(choose from {{ {','.join(valid_tool_parsers)} }})")

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    #
    # The server is started from this socket's fd, so the address bound
    # here is the one actually served: honour args.host instead of always
    # binding every interface. A ':' can only appear in an IPv6 literal.
    bind_host = args.host or ""
    family = socket.AF_INET6 if ":" in bind_host else socket.AF_INET
    sock = socket.socket(family, socket.SOCK_STREAM)
    try:
        sock.bind((bind_host, args.port))

        def signal_handler(*_) -> None:
            # Interrupt server on sigterm while initializing
            raise KeyboardInterrupt("terminated")

        signal.signal(signal.SIGTERM, signal_handler)

        async with build_async_engine_client(args) as engine_client:
            app = build_app(args)

            model_config = await engine_client.get_model_config()
            init_app_state(engine_client, model_config, app.state, args)

            shutdown_task = await serve_http(
                app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                fd=sock.fileno(),
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        # Release the pre-bound socket on every exit path.
        sock.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    cli_parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    cli_args = cli_parser.parse_args()
    validate_parsed_serve_args(cli_args)

    uvloop.run(run_server(cli_args))
|
||||
252
vllm/entrypoints/openai/cli_args.py
Normal file
252
vllm/entrypoints/openai/cli_args.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
This file contains the command line arguments for vLLM's
|
||||
OpenAI-compatible server. It is kept in a separate file for documentation
|
||||
purposes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.chat_utils import validate_chat_template
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class LoRAParserAction(argparse.Action):
    """argparse action that parses LoRA module specifications.

    Each value is either the old ``name=path`` form or a JSON object whose
    keys are passed to ``LoRAModulePath`` as keyword arguments. ``None``
    and empty-string entries are skipped. The parsed list is stored on the
    namespace under ``self.dest``.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        """Parse ``values`` into ``LoRAModulePath`` objects.

        Raises:
            TypeError: if ``values`` is a single string rather than a list.

        Invalid JSON, or JSON with fields ``LoRAModulePath`` rejects, is
        reported through ``parser.error``.
        """
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        lora_list: List[LoRAModulePath] = []
        for item in values:
            if item in [None, '']:  # Skip if item is None or empty string
                continue
            if '=' in item and ',' not in item:  # Old format: name=path
                # Split on the first '=' only, so a path that itself
                # contains '=' no longer fails to unpack.
                name, path = item.split('=', 1)
                lora_list.append(LoRAModulePath(name, path))
            else:  # Assume JSON format
                try:
                    lora_dict = json.loads(item)
                    lora = LoRAModulePath(**lora_dict)
                    lora_list.append(lora)
                except json.JSONDecodeError:
                    parser.error(
                        f"Invalid JSON format for --lora-modules: {item}")
                except TypeError as e:
                    parser.error(
                        f"Invalid fields for --lora-modules: {item} - {str(e)}"
                    )
        setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
class PromptAdapterParserAction(argparse.Action):
    """argparse action that turns ``--prompt-adapters`` values into
    ``PromptAdapterPath`` objects.

    Each value must have the form ``name=path``. The resulting list is stored
    on the namespace under ``self.dest``.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        """Parse ``values`` and store a ``List[PromptAdapterPath]``.

        Raises:
            TypeError: if ``values`` is a single string instead of a list.

        A value without ``=`` is reported through ``parser.error``.
        """
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        adapter_list: List[PromptAdapterPath] = []
        for item in values:
            # Mirror LoRAParserAction: ignore None / empty entries instead of
            # crashing on them.
            if item in [None, '']:
                continue
            # Split on the first '=' only so paths containing '=' still work.
            name, sep, path = item.partition('=')
            if not sep:
                parser.error(
                    f"Invalid format for --prompt-adapters: {item} "
                    "(expected name=path)")
            adapter_list.append(PromptAdapterPath(name, path))
        setattr(namespace, self.dest, adapter_list)
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Register the OpenAI-compatible server's CLI options on ``parser``.

    Front-end options (host/port, CORS, API key, LoRA and prompt adapters,
    chat template, SSL, middleware, tool calling) are added first, then the
    engine options via ``AsyncEngineArgs.add_cli_args``, and finally the
    logging / docs switches.

    Args:
        parser: the parser to extend.

    Returns:
        The parser returned by ``AsyncEngineArgs.add_cli_args`` with the
        remaining options added to it.
    """
    parser.add_argument("--host",
                        type=nullable_str,
                        default=None,
                        help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--uvicorn-log-level",
        type=str,
        default="info",
        choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
        help="log level for uvicorn")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="allow credentials")
    # The CORS options below take a JSON-encoded list on the command line.
    parser.add_argument("--allowed-origins",
                        type=json.loads,
                        default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods",
                        type=json.loads,
                        default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers",
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument("--api-key",
                        type=nullable_str,
                        default=None,
                        help="If provided, the server will require this key "
                        "to be presented in the header.")
    parser.add_argument(
        "--lora-modules",
        type=nullable_str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        help="LoRA module configurations in either 'name=path' format "
        "or JSON format. "
        "Example (old format): 'name=path' "
        "Example (new format): "
        "'{\"name\": \"name\", \"local_path\": \"path\", "
        "\"base_model_name\": \"id\"}'")
    parser.add_argument(
        "--prompt-adapters",
        type=nullable_str,
        default=None,
        nargs='+',
        action=PromptAdapterParserAction,
        help="Prompt adapter configurations in the format name=path. "
        "Multiple adapters can be specified.")
    parser.add_argument("--chat-template",
                        type=nullable_str,
                        default=None,
                        help="The file path to the chat template, "
                        "or the template in single-line form "
                        "for the specified model")
    parser.add_argument("--response-role",
                        type=nullable_str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=true`.")
    parser.add_argument("--ssl-keyfile",
                        type=nullable_str,
                        default=None,
                        help="The file path to the SSL key file")
    parser.add_argument("--ssl-certfile",
                        type=nullable_str,
                        default=None,
                        help="The file path to the SSL cert file")
    parser.add_argument("--ssl-ca-certs",
                        type=nullable_str,
                        default=None,
                        help="The CA certificates file")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl "
        "module's CERT_* constants)")
    parser.add_argument(
        "--root-path",
        type=nullable_str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument(
        "--middleware",
        type=nullable_str,
        action="append",
        default=[],
        help="Additional ASGI middleware to apply to the app. "
        "We accept multiple --middleware arguments. "
        "The value should be an import path. "
        "If a function is provided, vLLM will add it to the server "
        "using @app.middleware('http'). "
        "If a class is provided, vLLM will add it to the server "
        "using app.add_middleware(). ")
    parser.add_argument(
        "--return-tokens-as-token-ids",
        action="store_true",
        help="When --max-logprobs is specified, represents single tokens as "
        "strings of the form 'token_id:{token_id}' so that tokens that "
        "are not JSON-encodable can be identified.")
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        help="If specified, will run the OpenAI frontend server in the same "
        "process as the model serving engine.")

    parser.add_argument(
        "--enable-auto-tool-choice",
        action="store_true",
        default=False,
        help=
        "Enable auto tool choice for supported models. Use --tool-call-parser "
        "to specify which parser to use")

    # The metavar lists the built-in parsers; plugin-registered names are
    # accepted too, which is why no `choices=` restriction is applied.
    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
    parser.add_argument(
        "--tool-call-parser",
        type=str,
        metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
        "--tool-parser-plugin",
        default=None,
        help=
        "Select the tool call parser depending on the model that you're using."
        " This is used to parse the model-generated tool call into OpenAI API "
        "format. Required for --enable-auto-tool-choice.")

    parser.add_argument(
        "--tool-parser-plugin",
        type=str,
        default="",
        help=
        "Specify the tool parser plugin used to parse the model-generated "
        "tool calls into OpenAI API format. The parser names registered in "
        "this plugin can be used in --tool-call-parser.")

    # Engine options are registered after the front-end options above.
    parser = AsyncEngineArgs.add_cli_args(parser)

    parser.add_argument('--max-log-len',
                        type=int,
                        default=None,
                        help='Max number of prompt characters or prompt '
                        'ID numbers being printed in log.'
                        '\n\nDefault: Unlimited')

    parser.add_argument(
        "--disable-fastapi-docs",
        action='store_true',
        default=False,
        help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
    )

    return parser
|
||||
|
||||
|
||||
def validate_parsed_serve_args(args: argparse.Namespace):
    """Run cheap sanity checks on parsed serve arguments before model load.

    Nothing is checked when ``args`` carries a ``subparser`` attribute whose
    value is not ``"serve"``.

    Raises:
        TypeError: if ``--enable-auto-tool-choice`` is set without
            ``--tool-call-parser``.

    ``validate_chat_template`` may raise as well if the template looks
    invalid.
    """
    targets_other_command = (hasattr(args, "subparser")
                             and args.subparser != "serve")
    if targets_other_command:
        return

    # Raises if the chat template is likely invalid.
    validate_chat_template(args.chat_template)

    # Auto tool choice only works together with a tool call parser.
    if not args.enable_auto_tool_choice:
        return
    if args.tool_call_parser:
        return
    raise TypeError("Error: --enable-auto-tool-choice requires "
                    "--tool-call-parser")
|
||||
|
||||
|
||||
def create_parser_for_docs() -> FlexibleArgumentParser:
    """Build the fully populated server argument parser for documentation.

    The program name is fixed to ``-m vllm.entrypoints.openai.api_server``.
    """
    return make_arg_parser(
        FlexibleArgumentParser(prog="-m vllm.entrypoints.openai.api_server"))
|
||||
86
vllm/entrypoints/openai/logits_processors.py
Normal file
86
vllm/entrypoints/openai/logits_processors.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from functools import lru_cache, partial
|
||||
from typing import Dict, FrozenSet, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
class AllowedTokenIdsLogitsProcessor:
    """Logits processor that restricts generation to a fixed set of token
    ids by setting every other logit to ``-inf``."""

    def __init__(self, allowed_ids: Iterable[int]):
        # The ids are kept only until the mask is built on the first call.
        self.allowed_ids: Optional[List[int]] = list(allowed_ids)
        self.mask: Optional[torch.Tensor] = None

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        """Mask ``logits`` in place and return the same tensor.

        ``token_ids`` is accepted for interface compatibility but not used.
        """
        if self.mask is None:
            # Build the mask lazily: the vocabulary size and device are only
            # known once the first logits tensor arrives.
            vocab_dim = logits.shape[-1]
            self.mask = torch.ones((vocab_dim, ),
                                   dtype=torch.bool,
                                   device=logits.device)
            self.mask[self.allowed_ids] = False
            self.allowed_ids = None
        return logits.masked_fill_(self.mask, float("-inf"))
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
    allowed_token_ids: FrozenSet[int],
    vocab_size: int,
) -> LogitsProcessor:
    """Return a processor restricting output to ``allowed_token_ids``.

    Results are memoised (up to 32 entries) on the
    ``(allowed_token_ids, vocab_size)`` pair, so identical requests share
    one processor instance.

    Raises:
        ValueError: if the set is empty or contains an id outside
            ``[0, vocab_size)``.
    """
    if not allowed_token_ids:
        raise ValueError("Empty allowed_token_ids provided")

    has_out_of_vocab = any(tid < 0 or tid >= vocab_size
                           for tid in allowed_token_ids)
    if has_out_of_vocab:
        raise ValueError("allowed_token_ids contains "
                         "out-of-vocab token id")

    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
|
||||
|
||||
|
||||
def logit_bias_logits_processor(
    logit_bias: Dict[int, float],
    token_ids: List[int],
    logits: torch.Tensor,
) -> torch.Tensor:
    """Add each per-token bias to ``logits`` in place and return it.

    ``token_ids`` is accepted for interface compatibility but not used.
    """
    for biased_id, delta in logit_bias.items():
        logits[biased_id] += delta
    return logits
|
||||
|
||||
|
||||
def get_logits_processors(
    logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
    allowed_token_ids: Optional[List[int]],
    tokenizer: AnyTokenizer,
) -> List[LogitsProcessor]:
    """Build the logits processors implied by the request options.

    Args:
        logit_bias: mapping from token id (int or numeric string) to bias.
            Biases are clamped to ``[-100, 100]``.
        allowed_token_ids: if not None, generation is restricted to these
            ids.
        tokenizer: supplies ``vocab_size`` for range checks.

    Returns:
        The processors to apply: the bias processor first (if any), then
        the allowed-ids processor (if any).

    Raises:
        ValueError: if a ``logit_bias`` key is not an integer, or is outside
            the vocabulary. The allowed-ids helper raises for an empty or
            out-of-vocab ``allowed_token_ids``.
    """
    processors: List[LogitsProcessor] = []

    if logit_bias:
        clamped_bias: Dict[int, float] = {}
        try:
            for raw_id, raw_bias in logit_bias.items():
                # Convert the key first, then clamp the bias between -100
                # and 100 per OpenAI API spec.
                int_id = int(raw_id)
                clamped_bias[int_id] = min(100.0, max(-100.0, raw_bias))
        except ValueError as exc:
            raise ValueError(
                "Found token_id in logit_bias that is not "
                "an integer or string representing an integer") from exc

        # Every biased token id must lie within the vocabulary.
        for token_id in clamped_bias:
            if not (0 <= token_id < tokenizer.vocab_size):
                raise ValueError(f"token_id {token_id} in logit_bias contains "
                                 "out-of-vocab token id")

        bias_processor = partial(logit_bias_logits_processor, clamped_bias)
        processors.append(bias_processor)

    if allowed_token_ids is not None:
        allowed_processor = _get_allowed_token_ids_logits_processor(
            frozenset(allowed_token_ids), tokenizer.vocab_size)
        processors.append(allowed_processor)

    return processors
|
||||
992
vllm/entrypoints/openai/protocol.py
Normal file
992
vllm/entrypoints/openai/protocol.py
Normal file
@@ -0,0 +1,992 @@
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from openai.types.chat import ChatCompletionContentPartParam
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Annotated, Required, TypedDict
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||
RequestOutputKind, SamplingParams)
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
# torch is mocked during docs generation, so torch.iinfo cannot be queried
# there; the int64 bounds are therefore also provided as plain literals.
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]

try:
    from sphinx.ext.autodoc.mock import _MockModule
except ModuleNotFoundError:
    # Sphinx is not installed, so torch cannot be a Sphinx mock.
    _LONG_INFO = torch.iinfo(torch.long)
else:
    _LONG_INFO = (_MOCK_LONG_INFO if isinstance(torch, _MockModule) else
                  torch.iinfo(torch.long))

# The literal fallback must stay in sync with the real torch limits.
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
|
||||
|
||||
|
||||
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    # total=False makes every key optional except those wrapped in Required.
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    # presumably the id of the tool call this message answers — TODO confirm
    tool_call_id: Optional[str]

    # presumably the tool calls issued by this message's author — TODO confirm
    tool_calls: Optional[List[dict]]
|
||||
|
||||
|
||||
# Common pydantic base for the OpenAI-protocol models in this module.
class OpenAIBaseModel(BaseModel):
    # OpenAI API does not allow extra fields
    model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
# Error payload; `message`, `type` and `code` are required, `param` is
# optional and `object` defaults to "error".
class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int
|
||||
|
||||
|
||||
# Permission entry listed in ModelCard.permission. Every field has a default.
class ModelPermission(OpenAIBaseModel):
    # Fresh "modelperm-<uuid>" identifier generated per instance.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    # Creation time as whole seconds since the epoch, taken at instantiation.
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False
|
||||
|
||||
|
||||
# Description of a single model; only `id` is required.
class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    # Creation time as whole seconds since the epoch, taken at instantiation.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    # default_factory gives each instance its own empty list.
    permission: List[ModelPermission] = Field(default_factory=list)
|
||||
|
||||
|
||||
# Container holding a list of ModelCard entries; `object` defaults to "list".
class ModelList(OpenAIBaseModel):
    object: str = "list"
    # default_factory gives each instance its own empty list.
    data: List[ModelCard] = Field(default_factory=list)
|
||||
|
||||
|
||||
# Token usage counters; all default to 0 and `completion_tokens` may be None.
class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
|
||||
|
||||
|
||||
# Per-request metadata. Derives from plain BaseModel rather than
# OpenAIBaseModel, so the extra="forbid" config is not inherited here.
class RequestResponseMetadata(BaseModel):
    request_id: str
    final_usage_info: Optional[UsageInfo] = None
|
||||
|
||||
|
||||
# Named JSON schema carried by ResponseFormat.json_schema.
class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None
|
||||
|
||||
|
||||
# Requested output format of a completion.
class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object" or "text"
    type: Literal["text", "json_object", "json_schema"]
    # Optional schema payload; presumably only meaningful when
    # type == "json_schema" — TODO confirm against the consumers.
    json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||
|
||||
|
||||
# Options for streamed responses; both flags default to True.
class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True
|
||||
|
||||
|
||||
# Definition of a callable function exposed as a tool; only `name` is
# required.
class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    # Used as the guided-decoding JSON schema when this function is selected
    # through a named `tool_choice` (see ChatCompletionRequest).
    parameters: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# One entry of a request's `tools` list; "function" is the only accepted type.
class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition
|
||||
|
||||
|
||||
# Reference to a function by name, used inside a named tool choice.
class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str
|
||||
|
||||
|
||||
# `tool_choice` value that selects one specific function by name.
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"
|
||||
|
||||
|
||||
# Request body of the chat completions endpoint: the OpenAI-documented fields
# first, then vLLM sampling parameters and vLLM-specific extras. The
# `model_validator(mode="before")` hooks below operate on the raw input
# mapping, before field validation.
class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                ChatCompletionNamedToolChoiceParam]] = "none"

    # NOTE this will be ignored by VLLM -- the model determines the behavior
    parallel_tool_calls: Optional[bool] = False
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: bool = False
    top_k: int = -1
    min_p: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: bool = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-chat-completion-extra-params

    def to_beam_search_params(self,
                              default_max_tokens: int) -> BeamSearchParams:
        """Convert this request into ``BeamSearchParams``.

        ``default_max_tokens`` is used when ``max_tokens`` is unset; ``n``
        (default 1) becomes the beam width and an unset temperature becomes
        0.0.
        """
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        n = self.n if self.n is not None else 1
        temperature = self.temperature if self.temperature is not None else 0.0

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
        )

    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
        """Convert this request into ``SamplingParams``.

        ``default_max_tokens`` is used when ``max_tokens`` is unset.
        """
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # With echo and no explicit prompt_logprobs, reuse top_logprobs.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        guided_json_object = None
        if (self.response_format is not None
                and self.response_format.type == "json_object"):
            guided_json_object = True

        # A named tool's parameter schema takes precedence over guided_json.
        guided_decoding = GuidedDecodingParams.from_optional(
            json=self._get_guided_json_from_tool() or self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA if self.stream \
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias)

    def _get_guided_json_from_tool(
            self) -> Optional[Union[str, dict, BaseModel]]:
        """Return the ``parameters`` schema of the named tool, if any.

        Returns None when no tool is in use or ``tool_choice`` is not a named
        tool.

        Raises:
            ValueError: if the named tool is not present in ``tools``.
        """
        # user has chosen to not use any tool
        if self.tool_choice == "none" or self.tools is None:
            return None

        # user has chosen to use a named tool
        if isinstance(self.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_name = self.tool_choice.function.name
            tools = {tool.function.name: tool.function for tool in self.tools}
            if tool_name not in tools:
                raise ValueError(
                    f"Tool '{tool_name}' has not been passed in `tools`.")
            tool = tools[tool_name]
            return tool.parameters

        return None

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject ``stream_options`` unless ``stream`` is enabled."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate ``prompt_logprobs`` / ``top_logprobs`` combinations."""
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (top_logprobs := data.get("top_logprobs")) is not None:
            if top_logprobs < 0:
                raise ValueError("`top_logprobs` must be a positive value.")

            if not data.get("logprobs"):
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Allow at most one guided-decoding option, and none with a named
        tool choice."""
        if isinstance(data, ValueError):
            raise data

        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both.
        # This must trigger with a single guided option: counts above one
        # have already been rejected, so `> 1` here could never be true.
        if guide_count >= 1 and data.get("tool_choice",
                                         "none") not in ("none", "auto"):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Default and validate ``tool_choice`` against ``tools``.

        Malformed ``tool_choice`` objects and tool entries raise ValueError
        rather than KeyError/TypeError, so they surface as validation errors.
        """

        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data:

            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto"
            if data["tool_choice"] != "auto" and not isinstance(
                    data["tool_choice"], dict):
                raise ValueError(
                    "`tool_choice` must either be a named tool or \"auto\". "
                    "`tool_choice=\"none\" is not supported.")

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            if isinstance(data["tool_choice"], dict):
                # .get() so a missing key reaches the ValueError below
                # instead of raising KeyError.
                specified_function = data["tool_choice"].get("function")
                if not specified_function or not isinstance(
                        specified_function, dict):
                    raise ValueError(
                        "Incorrectly formatted `tool_choice`. Should be like "
                        "`{\"type\": \"function\","
                        " \"function\": {\"name\": \"my_function\"}}`")
                specified_function_name = specified_function.get("name")
                if not specified_function_name:
                    raise ValueError(
                        "Incorrectly formatted `tool_choice`. Should be like "
                        "`{\"type\": \"function\", "
                        "\"function\": {\"name\": \"my_function\"}}`")
                # Malformed tool entries simply never match.
                valid_tool = any(
                    isinstance(tool, dict)
                    and isinstance(tool.get("function"), dict)
                    and tool["function"].get("name") == specified_function_name
                    for tool in data["tools"])
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that set both ``continue_final_message`` and
        ``add_generation_prompt``."""
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data
|
||||
|
||||
|
||||
class CompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/completions/create
|
||||
model: str
|
||||
prompt: Union[List[int], List[List[int]], str, List[str]]
|
||||
best_of: Optional[int] = None
|
||||
echo: Optional[bool] = False
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logprobs: Optional[int] = None
|
||||
max_tokens: Optional[int] = 16
|
||||
n: int = 1
|
||||
presence_penalty: Optional[float] = 0.0
|
||||
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||
stream: Optional[bool] = False
|
||||
stream_options: Optional[StreamOptions] = None
|
||||
suffix: Optional[str] = None
|
||||
temperature: Optional[float] = 1.0
|
||||
top_p: Optional[float] = 1.0
|
||||
user: Optional[str] = None
|
||||
|
||||
# doc: begin-completion-sampling-params
|
||||
use_beam_search: bool = False
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
repetition_penalty: float = 1.0
|
||||
length_penalty: float = 1.0
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
include_stop_str_in_output: bool = False
|
||||
ignore_eos: bool = False
|
||||
min_tokens: int = 0
|
||||
skip_special_tokens: bool = True
|
||||
spaces_between_special_tokens: bool = True
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
allowed_token_ids: Optional[List[int]] = None
|
||||
prompt_logprobs: Optional[int] = None
|
||||
# doc: end-completion-sampling-params
|
||||
|
||||
# doc: begin-completion-extra-params
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."),
|
||||
)
|
||||
response_format: Optional[ResponseFormat] = Field(
|
||||
default=None,
|
||||
description=
|
||||
("Similar to chat completion, this parameter specifies the format of "
|
||||
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
|
||||
"supported."),
|
||||
)
|
||||
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
|
||||
default=None,
|
||||
description="If specified, the output will follow the JSON schema.",
|
||||
)
|
||||
guided_regex: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the output will follow the regex pattern."),
|
||||
)
|
||||
guided_choice: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the output will be exactly one of the choices."),
|
||||
)
|
||||
guided_grammar: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the output will follow the context free grammar."),
|
||||
)
|
||||
guided_decoding_backend: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, will override the default guided decoding backend "
|
||||
"of the server for this specific request. If set, must be one of "
|
||||
"'outlines' / 'lm-format-enforcer'"))
|
||||
guided_whitespace_pattern: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, will override the default whitespace pattern "
|
||||
"for guided json decoding."))
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
|
||||
# doc: end-completion-extra-params
|
||||
|
||||
def to_beam_search_params(self,
                          default_max_tokens: int) -> BeamSearchParams:
    """Build the beam-search parameters for this request.

    Args:
        default_max_tokens: Token budget used when the request leaves
            ``max_tokens`` unset.

    Returns:
        A ``BeamSearchParams`` whose beam width is ``n`` (1 when unset)
        and whose temperature is 0.0 when unset.
    """
    token_budget = (default_max_tokens
                    if self.max_tokens is None else self.max_tokens)
    beam_width = 1 if self.n is None else self.n
    sampling_temperature = (0.0 if self.temperature is None else
                            self.temperature)

    return BeamSearchParams(
        beam_width=beam_width,
        max_tokens=token_budget,
        ignore_eos=self.ignore_eos,
        temperature=sampling_temperature,
        length_penalty=self.length_penalty,
    )
|
||||
|
||||
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
    """Translate this completion request into engine sampling parameters.

    Args:
        default_max_tokens: Token budget used when the request leaves
            ``max_tokens`` unset.

    Returns:
        The ``SamplingParams`` built from the request fields, including
        any guided-decoding constraints.
    """
    # An echo-only request (echo with max_tokens == 0) is still submitted
    # with a budget of one token.
    if self.echo and self.max_tokens == 0:
        effective_max_tokens = 1
    elif self.max_tokens is None:
        effective_max_tokens = default_max_tokens
    else:
        effective_max_tokens = self.max_tokens

    # When echoing without an explicit prompt_logprobs, reuse `logprobs`.
    if self.prompt_logprobs is None and self.echo:
        effective_prompt_logprobs = self.logprobs
    else:
        effective_prompt_logprobs = self.prompt_logprobs

    response_format = self.response_format
    wants_json_object = (response_format is not None
                         and response_format.type == "json_object")
    guided_json_object = True if wants_json_object else None

    guided_decoding = GuidedDecodingParams.from_optional(
        json=self.guided_json,
        regex=self.guided_regex,
        choice=self.guided_choice,
        grammar=self.guided_grammar,
        json_object=guided_json_object,
        backend=self.guided_decoding_backend,
        whitespace_pattern=self.guided_whitespace_pattern)

    # Streaming requests receive deltas; others only the final output.
    if self.stream:
        output_kind = RequestOutputKind.DELTA
    else:
        output_kind = RequestOutputKind.FINAL_ONLY

    return SamplingParams.from_optional(
        n=self.n,
        best_of=self.best_of,
        presence_penalty=self.presence_penalty,
        frequency_penalty=self.frequency_penalty,
        repetition_penalty=self.repetition_penalty,
        temperature=self.temperature,
        top_p=self.top_p,
        top_k=self.top_k,
        min_p=self.min_p,
        seed=self.seed,
        stop=self.stop,
        stop_token_ids=self.stop_token_ids,
        logprobs=self.logprobs,
        ignore_eos=self.ignore_eos,
        max_tokens=effective_max_tokens,
        min_tokens=self.min_tokens,
        prompt_logprobs=effective_prompt_logprobs,
        skip_special_tokens=self.skip_special_tokens,
        spaces_between_special_tokens=self.spaces_between_special_tokens,
        include_stop_str_in_output=self.include_stop_str_in_output,
        truncate_prompt_tokens=self.truncate_prompt_tokens,
        output_kind=output_kind,
        guided_decoding=guided_decoding,
        logit_bias=self.logit_bias,
        allowed_token_ids=self.allowed_token_ids)
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
    """Reject requests that set more than one of the checked guided options.

    Only ``guided_json``, ``guided_regex`` and ``guided_choice`` are
    counted here.

    Raises:
        ValueError: If more than one of those keys is present and not None.
    """
    exclusive_keys = ("guided_json", "guided_regex", "guided_choice")
    active_guides = [
        key for key in exclusive_keys
        if key in data and data[key] is not None
    ]
    if len(active_guides) > 1:
        raise ValueError(
            "You can only use one kind of guided decoding "
            "('guided_json', 'guided_regex' or 'guided_choice').")
    return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
    """Validate the ``prompt_logprobs`` and ``logprobs`` request values.

    Zero is accepted for both fields; only negative values are rejected,
    and the error messages say so.

    Raises:
        ValueError: If ``prompt_logprobs`` is greater than zero while
            ``stream`` is set, or if either value is negative.
    """
    prompt_logprobs = data.get("prompt_logprobs")
    if prompt_logprobs is not None:
        if data.get("stream") and prompt_logprobs > 0:
            raise ValueError(
                "`prompt_logprobs` are not available when `stream=True`.")

        if prompt_logprobs < 0:
            raise ValueError(
                "`prompt_logprobs` must be a non-negative value.")

    logprobs = data.get("logprobs")
    if logprobs is not None and logprobs < 0:
        raise ValueError("`logprobs` must be a non-negative value.")

    return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
    """Ensure ``stream_options`` is only supplied for streaming requests.

    Raises:
        ValueError: If ``stream_options`` is truthy while ``stream`` is
            falsy or absent.
    """
    stream_options = data.get("stream_options")
    if stream_options and not data.get("stream"):
        raise ValueError(
            "Stream options can only be defined when `stream=True`.")
    return data
|
||||
|
||||
|
||||
class EmbeddingRequest(OpenAIBaseModel):
    # Request body for the embeddings endpoint.
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    # When set, must be at least 1 (enforced by the `ge=1` constraint).
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None

    # doc: end-embedding-pooling-params

    # doc: begin-embedding-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )

    # doc: end-embedding-extra-params

    def to_pooling_params(self):
        """Return pooling parameters carrying this request's extra data."""
        return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
class CompletionLogProbs(OpenAIBaseModel):
    # Log-probability lists for a completion; every list defaults to empty.
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(
        default_factory=list)
|
||||
|
||||
|
||||
class CompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming completion response.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    # Per-position prompt log-probabilities; entries may be None.
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
|
||||
|
||||
|
||||
class CompletionResponse(OpenAIBaseModel):
    # Non-streaming completion response body.
    # `id` defaults to a fresh "cmpl-<uuid>" value per instance.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Creation time in whole seconds, taken from time.time().
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo
|
||||
|
||||
|
||||
class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed completion chunk.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
|
||||
|
||||
|
||||
class CompletionStreamResponse(OpenAIBaseModel):
    # One streamed completion chunk; `usage` is optional here.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Creation time in whole seconds, taken from time.time().
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
|
||||
|
||||
|
||||
class EmbeddingResponseData(OpenAIBaseModel):
    # One embedding entry of an embedding response.
    index: int
    object: str = "embedding"
    # Either a list of floats or a string.
    embedding: Union[List[float], str]
|
||||
|
||||
|
||||
class EmbeddingResponse(OpenAIBaseModel):
    # Embedding response body.
    # NOTE: the default id uses the "cmpl-" prefix, same as completions.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    # Creation time in whole seconds, taken from time.time().
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo
|
||||
|
||||
|
||||
class FunctionCall(OpenAIBaseModel):
    # Function name plus its arguments; both are required strings.
    name: str
    arguments: str
|
||||
|
||||
|
||||
class ToolCall(OpenAIBaseModel):
    # A complete tool call; "function" is the only accepted `type`.
    # `id` defaults to a fresh "chatcmpl-tool-<uuid>" value per instance.
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall
|
||||
|
||||
|
||||
class DeltaFunctionCall(BaseModel):
    # Counterpart of FunctionCall in which both fields are optional.
    name: Optional[str] = None
    arguments: Optional[str] = None
|
||||
|
||||
|
||||
# a tool call delta where everything is optional
|
||||
class DeltaToolCall(OpenAIBaseModel):
    # `id` defaults to a fresh "chatcmpl-tool-<uuid>" value per instance.
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    # Required index of this tool call.
    index: int
    function: Optional[DeltaFunctionCall] = None
|
||||
|
||||
|
||||
class ExtractedToolCallInformation(BaseModel):
    # Whether any tools were called.
    tools_called: bool

    # The tool calls that were extracted.
    tool_calls: List[ToolCall]

    # Per the OpenAI spec, content is rarely returned together with tool
    # calls, but some models do this intentionally.
    content: Optional[str] = None
|
||||
|
||||
|
||||
class ChatMessage(OpenAIBaseModel):
    # A chat message; `tool_calls` defaults to an empty list.
    role: str
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatCompletionLogProb(OpenAIBaseModel):
    # Log-probability record for a single token.
    token: str
    # Defaults to -9999.0 when no value is supplied.
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None
|
||||
|
||||
|
||||
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # Extends the per-token record with a list of alternatives
    # (empty by default).
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatCompletionLogProbs(OpenAIBaseModel):
    # Optional list of per-token log-probability entries.
    content: Optional[List[ChatCompletionLogProbsContent]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming chat completion response.
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponse(OpenAIBaseModel):
    # Non-streaming chat completion response body.
    # `id` defaults to a fresh "chatcmpl-<uuid>" value per instance.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    # Creation time in whole seconds, taken from time.time().
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    # Per-position prompt log-probabilities; entries may be None.
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
|
||||
|
||||
|
||||
class DeltaMessage(OpenAIBaseModel):
    # Message delta: role and content are optional, and `tool_calls`
    # defaults to an empty list.
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice inside a streamed chat completion chunk.
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # Unlike the non-streaming choice, this defaults to None.
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(OpenAIBaseModel):
    # One streamed chat completion chunk; `usage` is optional here.
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Creation time in whole seconds, taken from time.time().
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
|
||||
|
||||
|
||||
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` and `/v1/embeddings`
    endpoints are supported.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions and /v1/embeddings are supported.
    url: str

    # The parameters of the request: a chat completion request or an
    # embedding request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]
|
||||
|
||||
|
||||
class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response (200 unless overridden).
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response; None when no body is attached.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None
|
||||
|
||||
|
||||
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # No default is declared, so the field must be passed explicitly
    # (it may be None).
    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]
|
||||
|
||||
|
||||
class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenize request carrying a plain-text prompt.
    model: str
    prompt: str

    # Defaults to True here; TokenizeChatRequest defaults it to False.
    add_special_tokens: bool = Field(default=True)
|
||||
|
||||
|
||||
class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenize request carrying a list of chat messages.
    model: str
    messages: List[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(default=True)
    continue_final_message: bool = Field(default=False)
    add_special_tokens: bool = Field(default=False)

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject input that sets both mutually exclusive flags.

        Raises:
            ValueError: If both ``continue_final_message`` and
                ``add_generation_prompt`` are truthy in the raw input.
        """
        # NOTE(review): this reads the raw input, so an omitted
        # `add_generation_prompt` (whose field default is True) is not
        # caught here -- confirm that this is intended.
        continuing = data.get("continue_final_message")
        if continuing and data.get("add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data
|
||||
|
||||
|
||||
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
|
||||
|
||||
|
||||
class TokenizeResponse(OpenAIBaseModel):
    # Tokenization result; all three fields are required.
    count: int
    max_model_len: int
    tokens: List[int]
|
||||
|
||||
|
||||
class DetokenizeRequest(OpenAIBaseModel):
    # Detokenize request: a model name and a list of integer tokens.
    model: str
    tokens: List[int]
|
||||
|
||||
|
||||
class DetokenizeResponse(OpenAIBaseModel):
    # Detokenize result: a single required string.
    prompt: str
|
||||
|
||||
|
||||
class LoadLoraAdapterRequest(BaseModel):
    # Request naming a LoRA adapter and its path; both are required.
    lora_name: str
    lora_path: str
|
||||
|
||||
|
||||
class UnloadLoraAdapterRequest(BaseModel):
    # Request naming a LoRA adapter; the integer id is optional.
    lora_name: str
    lora_int_id: Optional[int] = Field(default=None)
|
||||
285
vllm/entrypoints/openai/run_batch.py
Normal file
285
vllm/entrypoints/openai/run_batch.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import asyncio
|
||||
from http import HTTPStatus
|
||||
from io import StringIO
|
||||
from typing import Awaitable, Callable, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from prometheus_client import start_http_server
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.logger import RequestLogger, logger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
||||
BatchRequestOutput,
|
||||
BatchResponseData,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse, ErrorResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
|
||||
def parse_args():
    """Build the batch runner's CLI parser and parse the process arguments.

    The runner's own options are registered around the engine options
    added by ``AsyncEngineArgs.add_cli_args``.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    input_file_help = (
        "The path or url to a single input file. Currently supports local file "
        "paths, or the http protocol (http or https). If a URL is specified, "
        "the file should be available via HTTP GET.")
    output_file_help = (
        "The path or url to a single output file. Currently supports "
        "local file paths, or web (http or https) urls. If a URL is specified,"
        " the file should be available via HTTP PUT.")
    response_role_help = ("The role name to return if "
                          "`request.add_generation_prompt=True`.")
    max_log_len_help = ('Max number of prompt characters or prompt '
                        'ID numbers being printed in log.'
                        '\n\nDefault: Unlimited')
    metrics_url_help = ("URL to the Prometheus metrics server "
                        "(only needed if enable-metrics is set).")
    metrics_port_help = ("Port number for the Prometheus metrics server "
                         "(only needed if enable-metrics is set).")

    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible batch runner.")

    # Batch input/output locations and the response role.
    parser.add_argument("-i",
                        "--input-file",
                        required=True,
                        type=str,
                        help=input_file_help)
    parser.add_argument("-o",
                        "--output-file",
                        required=True,
                        type=str,
                        help=output_file_help)
    parser.add_argument("--response-role",
                        type=nullable_str,
                        default="assistant",
                        help=response_role_help)

    # Engine options are contributed by AsyncEngineArgs.
    parser = AsyncEngineArgs.add_cli_args(parser)

    parser.add_argument('--max-log-len',
                        type=int,
                        default=None,
                        help=max_log_len_help)

    # Prometheus metrics options; --url is passed to start_http_server
    # as its `addr` argument by the script entry point.
    parser.add_argument("--enable-metrics",
                        action="store_true",
                        help="Enable Prometheus metrics")
    parser.add_argument("--url",
                        type=str,
                        default="0.0.0.0",
                        help=metrics_url_help)
    parser.add_argument("--port",
                        type=int,
                        default=8000,
                        help=metrics_port_help)

    return parser.parse_args()
|
||||
|
||||
|
||||
# explicitly use pure text format, with a newline at the end
|
||||
# this makes it impossible to see the animation in the progress bar
|
||||
# but will avoid messing up with ray or multiprocessing, which wraps
|
||||
# each line of output with some prefix.
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
class BatchProgressTracker:
    """Counts submitted requests and advances a tqdm bar as they finish."""

    def __init__(self):
        # Number of requests registered through submitted().
        self._total = 0
        # Created by pbar(); completed() is a no-op until then.
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        """Record that one more request has been submitted."""
        self._total += 1

    def completed(self):
        """Advance the progress bar by one step, if the bar is truthy."""
        if not self._pbar:
            return
        self._pbar.update()

    def pbar(self) -> tqdm:
        """Create, store and return the bar sized to the submitted count."""
        # The bar is disabled when torch.distributed is initialized and
        # this process is not rank 0.
        hide_bar = (torch.distributed.is_initialized()
                    and torch.distributed.get_rank() != 0)
        progress_bar = tqdm(total=self._total,
                            unit="req",
                            desc="Running batch",
                            mininterval=5,
                            disable=hide_bar,
                            bar_format=_BAR_FORMAT)
        self._pbar = progress_bar
        return progress_bar
|
||||
|
||||
|
||||
async def read_file(path_or_url: str) -> str:
    """Return the text content of a local file or an HTTP(S) resource.

    Args:
        path_or_url: A local file path, or a URL starting with
            ``http://`` or ``https://`` that is fetched with HTTP GET.

    Returns:
        The response text for URLs, or the file content decoded as UTF-8.

    Raises:
        aiohttp.ClientResponseError: If the server answers with an error
            status (400 or higher).
        OSError: If the local file cannot be opened.
    """
    if path_or_url.startswith(("http://", "https://")):
        async with aiohttp.ClientSession() as session, \
                session.get(path_or_url) as resp:
            # Fail on 4xx/5xx instead of handing an error page to the
            # caller as if it were the batch input.
            resp.raise_for_status()
            return await resp.text()

    with open(path_or_url, "r", encoding="utf-8") as f:
        return f.read()
|
||||
|
||||
|
||||
async def write_file(path_or_url: str, data: str) -> None:
    """Write text to a local file or upload it with HTTP PUT.

    Args:
        path_or_url: A local file path, or a URL starting with
            ``http://`` or ``https://`` that receives the data via PUT.
        data: The text to store; it is encoded as UTF-8.

    Raises:
        aiohttp.ClientResponseError: If the server answers the PUT with
            an error status (400 or higher).
        OSError: If the local file cannot be opened for writing.
    """
    if path_or_url.startswith(("http://", "https://")):
        async with aiohttp.ClientSession() as session, \
                session.put(path_or_url,
                            data=data.encode("utf-8")) as resp:
            # Surface a rejected upload instead of silently dropping
            # the batch output.
            resp.raise_for_status()
        return

    # We should make this async, but as long as this is always run as a
    # standalone program, blocking the event loop won't affect performance
    # in this particular case.
    with open(path_or_url, "w", encoding="utf-8") as f:
        f.write(data)
|
||||
|
||||
|
||||
def make_error_request_output(request: BatchRequestInput,
                              error_msg: str) -> BatchRequestOutput:
    """Build a batch output line reporting a failed request.

    Args:
        request: The batch input whose ``custom_id`` is echoed back.
        error_msg: The message stored in the output's ``error`` field.

    Returns:
        A ``BatchRequestOutput`` whose response carries status 400
        (``HTTPStatus.BAD_REQUEST``) and no body.
    """
    error_response = BatchResponseData(
        status_code=HTTPStatus.BAD_REQUEST,
        request_id=f"vllm-batch-{random_uuid()}",
    )
    return BatchRequestOutput(
        id=f"vllm-{random_uuid()}",
        custom_id=request.custom_id,
        response=error_response,
        error=error_msg,
    )
|
||||
|
||||
|
||||
async def make_async_error_request_output(
        request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
    """Awaitable form of ``make_error_request_output``.

    Returns the same value as the synchronous helper, so the result can
    be awaited like any other coroutine.
    """
    error_output = make_error_request_output(request, error_msg)
    return error_output
|
||||
|
||||
|
||||
async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput,
                      tracker: BatchProgressTracker) -> BatchRequestOutput:
    """Run one batch request and wrap the outcome as a batch output line.

    Args:
        serving_engine_func: Awaitable callable invoked with the request
            body.
        request: The batch input line being processed.
        tracker: Progress tracker notified once the request has finished.

    Returns:
        A ``BatchRequestOutput`` holding the response body on success,
        the ``ErrorResponse`` on failure, or a generic error output for
        any other result type.
    """
    engine_result = await serving_engine_func(request.body)

    if isinstance(engine_result, (ChatCompletionResponse, EmbeddingResponse)):
        success_data = BatchResponseData(
            body=engine_result, request_id=f"vllm-batch-{random_uuid()}")
        batch_output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=success_data,
            error=None,
        )
    elif isinstance(engine_result, ErrorResponse):
        # Propagate the error's own status code and keep the error object.
        failure_data = BatchResponseData(
            status_code=engine_result.code,
            request_id=f"vllm-batch-{random_uuid()}")
        batch_output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=failure_data,
            error=engine_result,
        )
    else:
        # Any other result type is reported as a stream-mode request.
        batch_output = make_error_request_output(
            request, error_msg="Request must not be sent in stream mode")

    tracker.completed()
    return batch_output
|
||||
|
||||
|
||||
async def main(args):
    """Run every request of the input file and write the output file.

    Builds the engine and the chat/embedding serving objects from
    ``args``, schedules one coroutine per non-empty input line, awaits
    them together, and writes one JSON line per response.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
    """
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)

    model_config = await engine.get_model_config()
    base_model_paths = [
        BaseModelPath(name=served_name, model_path=args.model)
        for served_name in served_model_names
    ]

    request_logger = (None if args.disable_log_requests else RequestLogger(
        max_log_len=args.max_log_len))

    # Create the openai serving objects.
    openai_serving_chat = OpenAIServingChat(
        engine,
        model_config,
        base_model_paths,
        args.response_role,
        lora_modules=None,
        prompt_adapters=None,
        request_logger=request_logger,
        chat_template=None,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        model_config,
        base_model_paths,
        request_logger=request_logger,
    )

    # Supported batch URLs mapped to the coroutine that serves them.
    handlers = {
        "/v1/chat/completions": openai_serving_chat.create_chat_completion,
        "/v1/embeddings": openai_serving_embedding.create_embedding,
    }

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)

    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = []
    batch_text = await read_file(args.input_file)
    for raw_line in batch_text.strip().split("\n"):
        # Skip empty lines.
        request_json = raw_line.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)

        # Determine the type of request and run it.
        handler = handlers.get(request.url)
        if handler is None:
            # Unsupported URLs yield an error line and are not counted
            # by the progress tracker.
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg="Only /v1/chat/completions and "
                    "/v1/embeddings are supported in the batch endpoint.",
                ))
            continue

        response_futures.append(run_request(handler, request, tracker))
        tracker.submitted()

    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    # One JSON document per line, without a trailing newline.
    serialized = "".join(f"{response.model_dump_json()}\n"
                         for response in responses)
    await write_file(args.output_file, serialized.strip())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the CLI, optionally expose Prometheus
    # metrics, then run the batch to completion.
    args = parse_args()

    logger.info("vLLM batch processing API version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Start the Prometheus metrics server. LLMEngine uses the Prometheus client
    # to publish metrics at the /metrics endpoint.
    if not args.enable_metrics:
        logger.info("Prometheus metrics disabled")
    else:
        logger.info("Prometheus metrics enabled")
        start_http_server(port=args.port, addr=args.url)

    asyncio.run(main(args))
|
||||
891
vllm/entrypoints/openai/serving_chat.py
Normal file
891
vllm/entrypoints/openai/serving_chat.py
Normal file
@@ -0,0 +1,891 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
|
||||
Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
load_chat_template,
|
||||
parse_chat_messages_futures)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProb, ChatCompletionLogProbs,
|
||||
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
|
||||
ToolCall, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath,
|
||||
TextTokensPrompt)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import iterate_with_cancellation, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
def __init__(self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
response_role: str,
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_auto_tools: bool = False,
|
||||
tool_parser: Optional[str] = None):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
|
||||
self.response_role = response_role
|
||||
self.use_tool_use_model_template = False
|
||||
self.chat_template = load_chat_template(chat_template)
|
||||
|
||||
# set up tool use
|
||||
self.enable_auto_tools: bool = enable_auto_tools
|
||||
if self.enable_auto_tools:
|
||||
logger.info(
|
||||
"\"auto\" tool choice has been enabled please note that while"
|
||||
" the parallel_tool_calls client option is preset for "
|
||||
"compatibility reasons, it will be ignored.")
|
||||
|
||||
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
|
||||
if self.enable_auto_tools:
|
||||
try:
|
||||
self.tool_parser = ToolParserManager.get_tool_parser(
|
||||
tool_parser)
|
||||
except Exception as e:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
f"tool_parser:'{tool_parser}' which has not "
|
||||
"been registered") from e
|
||||
|
||||
async def create_chat_completion(
    self,
    request: ChatCompletionRequest,
    raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
           ErrorResponse]:
    """Chat Completion API similar to OpenAI's API.

    See https://platform.openai.com/docs/api-reference/chat/create
    for the API specification. This API mimics the OpenAI
    ChatCompletion API.

    Args:
        request: The chat completion request to serve.
        raw_request: The underlying HTTP request, if available. When given,
            the request metadata is stored on ``raw_request.state``, its
            headers are inspected for trace headers, and the result stream
            is wrapped with ``iterate_with_cancellation`` keyed on
            ``raw_request.is_disconnected``.

    Returns:
        The streaming generator when ``request.stream`` is set, otherwise
        the full response; an error response when the model check, chat
        template application, multi-modal loading, tool-choice validation
        or a ``ValueError`` during scheduling/generation fails.

    Raises:
        The engine client's ``dead_error`` when the engine has errored.
    """
    error_check_ret = await self._check_model(request)
    if error_check_ret is not None:
        logger.error("Error with model %s", error_check_ret)
        return error_check_ret

    # If the engine is dead, raise the engine's DEAD_ERROR.
    # This is required for the streaming case, where we return a
    # success status before we actually start generating text :).
    if self.engine_client.errored:
        raise self.engine_client.dead_error

    try:
        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        model_config = self.model_config
        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        conversation, mm_data_future = parse_chat_messages_futures(
            request.messages, model_config, tokenizer)

        tool_dicts = None if request.tools is None else [
            tool.model_dump() for tool in request.tools
        ]

        # The Mistral path receives the raw request messages, the HF path
        # the parsed conversation; the request-level template overrides the
        # server default in both.
        prompt: Union[str, List[int]]
        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
        if is_mistral_tokenizer:
            prompt = apply_mistral_chat_template(
                tokenizer,
                messages=request.messages,
                chat_template=request.chat_template or self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
                tools=tool_dicts,
                documents=request.documents,
                **(request.chat_template_kwargs or {}),
            )
        else:
            prompt = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=request.chat_template or self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
                tools=tool_dicts,
                documents=request.documents,
                **(request.chat_template_kwargs or {}),
            )
    except Exception as e:
        logger.exception("Error in applying chat template from request")
        return self.create_error_response(str(e))

    try:
        mm_data = await mm_data_future
    except Exception as e:
        logger.exception("Error in loading multi-modal data")
        return self.create_error_response(str(e))

    # validation for OpenAI tools
    # tool_choice = "required" is not supported
    if request.tool_choice == "required":
        return self.create_error_response(
            "tool_choice = \"required\" is not supported!")

    if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
            self.enable_auto_tools and self.tool_parser is not None):
        # for hf tokenizers, "auto" tools requires
        # --enable-auto-tool-choice and --tool-call-parser
        return self.create_error_response(
            "\"auto\" tool choice requires "
            "--enable-auto-tool-choice and --tool-call-parser to be set")

    request_id = f"chat-{random_uuid()}"

    request_metadata = RequestResponseMetadata(request_id=request_id)
    if raw_request:
        raw_request.state.request_metadata = request_metadata

    try:
        if self.enable_auto_tools and self.tool_parser:
            request = self.tool_parser(tokenizer).adjust_request(
                request=request)

        if isinstance(prompt, str):
            prompt_inputs = self._tokenize_prompt_input(
                request,
                tokenizer,
                prompt,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        else:
            assert isinstance(prompt, list) and isinstance(
                prompt[0], int
            ), "Prompt has to be either a string or a list of token ids"
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

        assert prompt_inputs is not None

        # Whatever context length the prompt leaves over is handed to the
        # request as its default generation budget.
        sampling_params: Union[SamplingParams, BeamSearchParams]
        default_max_tokens = self.max_model_len - len(
            prompt_inputs["prompt_token_ids"])
        if request.use_beam_search:
            sampling_params = request.to_beam_search_params(
                default_max_tokens)
        else:
            sampling_params = request.to_sampling_params(
                default_max_tokens)

        self._log_inputs(request_id,
                         prompt_inputs,
                         params=sampling_params,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        engine_inputs = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_inputs["multi_modal_data"] = mm_data

        is_tracing_enabled = (await
                              self.engine_client.is_tracing_enabled())
        trace_headers = None
        if is_tracing_enabled and raw_request:
            trace_headers = extract_trace_headers(raw_request.headers)
        if (not is_tracing_enabled and raw_request
                and contains_trace_headers(raw_request.headers)):
            log_tracing_disabled_warning()

        if isinstance(sampling_params, BeamSearchParams):
            # NOTE: the literals below are concatenated; the trailing space
            # after "with" keeps the words from running together.
            assert isinstance(self.engine_client,
                              (AsyncLLMEngine,
                               MQLLMEngineClient)), \
                "Beam search is only supported with " \
                "AsyncLLMEngine and MQLLMEngineClient."
            result_generator = self.engine_client.beam_search(
                engine_inputs['prompt_token_ids'],
                request_id,
                sampling_params,
            )
        else:
            result_generator = self.engine_client.generate(
                engine_inputs,
                sampling_params,
                request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=request.priority,
            )
    except ValueError as e:
        # TODO: Use a vllm-specific Validation Error
        return self.create_error_response(str(e))

    if raw_request:
        result_generator = iterate_with_cancellation(
            result_generator, raw_request.is_disconnected)

    # Streaming response
    if request.stream:
        return self.chat_completion_stream_generator(
            request, result_generator, request_id, conversation, tokenizer,
            request_metadata)

    try:
        return await self.chat_completion_full_generator(
            request, result_generator, request_id, conversation, tokenizer,
            request_metadata)
    except ValueError as e:
        # TODO: Use a vllm-specific Validation Error
        return self.create_error_response(str(e))
|
||||
|
||||
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
    """Return the role to report for the generated message.

    This is the configured ``response_role`` when the request adds a
    generation prompt, otherwise the role of the last request message.
    """
    return (self.response_role if request.add_generation_prompt else
            request.messages[-1]["role"])
|
||||
|
||||
async def chat_completion_stream_generator(
    self,
    request: ChatCompletionRequest,
    result_generator: AsyncIterator[RequestOutput],
    request_id: str,
    conversation: List[ConversationMessage],
    tokenizer: AnyTokenizer,
    request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
    """Stream a chat completion as ``data: ...`` event strings.

    For every choice an initial chunk carrying the role is sent, optionally
    followed by a chunk echoing the last conversation message. Each engine
    output is then turned into a delta chunk (plain content, a named tool
    call, or whatever the configured tool parser extracts). After the
    engine stream ends, a usage-only chunk is sent when
    ``stream_options.include_usage`` is set, and the stream is always
    terminated by ``data: [DONE]``.

    Args:
        request: The chat completion request being served.
        result_generator: Async iterator of engine outputs. Their text and
            token ids are treated as increments (they are appended to
            running totals here).
        request_id: Identifier placed on every chunk.
        conversation: Parsed conversation, used only for echoing the last
            message.
        tokenizer: Tokenizer handed to the tool parser and logprob helpers.
        request_metadata: Receives the aggregate usage in
            ``final_usage_info`` once the stream completes.

    Yields:
        Event strings of the form ``"data: <json>\\n\\n"``. A ``ValueError``
        raised while streaming, or a ``RuntimeError`` raised while creating
        the tool parsers, is reported as an error payload instead of being
        propagated.
    """
    model_name = self.base_model_paths[0].name
    created_time = int(time.time())
    chunk_object_type: Final = "chat.completion.chunk"
    first_iteration = True

    # Send response for each token for each request.n (index)
    num_choices = 1 if request.n is None else request.n
    previous_num_tokens = [0] * num_choices
    finish_reason_sent = [False] * num_choices
    num_prompt_tokens = 0

    stream_options = request.stream_options
    include_usage = bool(stream_options and stream_options.include_usage)
    include_continuous_usage = bool(
        include_usage and stream_options.continuous_usage_stats)

    def _set_chunk_usage(chunk: ChatCompletionStreamResponse,
                         completion_tokens: int) -> None:
        """Fill in ``chunk.usage`` according to the stream options."""
        if not include_usage:
            # Leave the field unset so it is dropped by exclude_unset.
            return
        if include_continuous_usage:
            chunk.usage = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=num_prompt_tokens + completion_tokens,
            )
        else:
            # Explicit assignment keeps "usage": null in the payload.
            chunk.usage = None

    if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
        tool_choice_function_name = request.tool_choice.function.name
    else:
        tool_choice_function_name = None

    # Determine whether tools are in use with "auto" tool choice
    tool_choice_auto = (
        not tool_choice_function_name
        and self._should_stream_with_auto_tool_parsing(request))

    all_previous_token_ids: Optional[List[List[int]]]
    if tool_choice_auto:
        # These are only required in "auto" tool choice case
        previous_texts = [""] * num_choices
        all_previous_token_ids = [[] for _ in range(num_choices)]
    else:
        previous_texts, all_previous_token_ids = None, None

    # Prepare the tool parser if it's needed. Each choice gets its OWN
    # parser instance: the finish branch below reads per-call state from it
    # (prev_tool_call_arr, streamed_args_for_tool), so one shared instance
    # would mix the tool calls of different choices when n > 1.
    try:
        if tool_choice_auto and self.tool_parser:
            tool_parsers: List[Optional[ToolParser]] = [
                self.tool_parser(tokenizer) for _ in range(num_choices)
            ]
        else:
            tool_parsers = [None] * num_choices
    except RuntimeError as e:
        logger.error("Error in tool parser creation: %s", e)
        data = self.create_streaming_error_response(str(e))
        yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"
        return

    try:
        async for res in result_generator:
            if res.prompt_token_ids is not None:
                num_prompt_tokens = len(res.prompt_token_ids)
                if res.encoder_prompt_token_ids is not None:
                    num_prompt_tokens += len(res.encoder_prompt_token_ids)

            # We need to do it here, because if there are exceptions in
            # the result_generator, it needs to be sent as the FIRST
            # response (by the try...catch).
            if first_iteration:
                # Send first response for each request.n (index) with
                # the role
                role = self.get_chat_request_role(request)

                # NOTE num_choices defaults to 1 so this usually executes
                # once per request
                for i in range(num_choices):
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=DeltaMessage(
                            role=role,
                            content="",
                        ),
                        logprobs=None,
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)

                    # nothing has been generated yet
                    _set_chunk_usage(chunk, 0)

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

                # Send response to echo the input portion of the
                # last message
                if request.echo or request.continue_final_message:
                    last_msg_content: str = ""
                    if conversation and "content" in conversation[
                            -1] and conversation[-1].get("role") == role:
                        last_msg_content = conversation[-1]["content"] or ""

                    if last_msg_content:
                        for i in range(num_choices):
                            choice_data = (
                                ChatCompletionResponseStreamChoice(
                                    index=i,
                                    delta=DeltaMessage(
                                        content=last_msg_content),
                                    logprobs=None,
                                    finish_reason=None))
                            chunk = ChatCompletionStreamResponse(
                                id=request_id,
                                object=chunk_object_type,
                                created=created_time,
                                choices=[choice_data],
                                model=model_name)
                            _set_chunk_usage(chunk, 0)

                            data = chunk.model_dump_json(
                                exclude_unset=True)
                            yield f"data: {data}\n\n"
                first_iteration = False

            for output in res.outputs:
                i = output.index
                tool_parser = tool_parsers[i]

                if finish_reason_sent[i]:
                    continue

                if request.logprobs and request.top_logprobs is not None:
                    assert output.logprobs is not None, (
                        "Did not output logprobs")
                    logprobs = self._create_chat_logprobs(
                        token_ids=output.token_ids,
                        top_logprobs=output.logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.top_logprobs,
                    )
                else:
                    logprobs = None

                delta_text = output.text
                delta_message: Optional[DeltaMessage]

                # handle streaming deltas for tools with named tool_choice
                if tool_choice_function_name:
                    delta_message = DeltaMessage(tool_calls=[
                        DeltaToolCall(function=DeltaFunctionCall(
                            name=tool_choice_function_name,
                            arguments=delta_text),
                                      index=i)
                    ])

                # handle streaming deltas for tools with "auto" tool choice
                elif tool_choice_auto:
                    assert previous_texts is not None
                    assert all_previous_token_ids is not None
                    assert tool_parser is not None
                    # TODO optimize manipulation of these lists
                    previous_text = previous_texts[i]
                    previous_token_ids = all_previous_token_ids[i]
                    current_text = previous_text + delta_text
                    current_token_ids = previous_token_ids + list(
                        output.token_ids)

                    delta_message = (
                        tool_parser.extract_tool_calls_streaming(
                            previous_text=previous_text,
                            current_text=current_text,
                            delta_text=delta_text,
                            previous_token_ids=previous_token_ids,
                            current_token_ids=current_token_ids,
                            delta_token_ids=output.token_ids,
                            request=request))

                    # update the previous values for the next iteration
                    previous_texts[i] = current_text
                    all_previous_token_ids[i] = current_token_ids

                # handle streaming just a content delta
                else:
                    delta_message = DeltaMessage(content=delta_text)

                # set the previous values for the next iteration
                previous_num_tokens[i] += len(output.token_ids)

                # if the message delta is None (e.g. because it was a
                # "control token" for tool calls or the parser otherwise
                # wasn't ready to send a token, then
                # get the next token without streaming a chunk
                if delta_message is None:
                    continue

                if output.finish_reason is None:
                    # Send token-by-token response for each request.n
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=delta_message,
                        logprobs=logprobs,
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)

                    # Continuous usage reports the running total for this
                    # choice, not just the size of the current increment.
                    _set_chunk_usage(chunk, previous_num_tokens[i])

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

                # if the model is finished generating
                else:
                    # check to make sure we haven't "forgotten" to stream
                    # any tokens that were generated but previously
                    # matched by partial json parsing
                    # only happens if we are NOT using guided decoding
                    auto_tools_called = False
                    if tool_parser:
                        auto_tools_called = len(
                            tool_parser.prev_tool_call_arr) > 0
                        index = len(tool_parser.prev_tool_call_arr
                                    ) - 1 if auto_tools_called else 0
                    else:
                        index = 0

                    if self._should_check_for_unstreamed_tool_arg_tokens(
                            delta_message, output) and tool_parser:
                        # get the expected call based on partial JSON
                        # parsing which "autocompletes" the JSON
                        expected_call = json.dumps(
                            tool_parser.prev_tool_call_arr[index].get(
                                "arguments", {}))

                        # get what we've streamed so far for arguments
                        # for the current tool
                        actual_call = tool_parser.streamed_args_for_tool[
                            index]

                        # check to see if there's anything left to stream
                        remaining_call = expected_call.replace(
                            actual_call, "", 1)

                        # set that as a delta message
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=index,
                                          function=DeltaFunctionCall(
                                              arguments=remaining_call).
                                          model_dump(exclude_none=True))
                        ])

                    # Send the finish response for each request.n only once
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=delta_message,
                        logprobs=logprobs,
                        finish_reason=output.finish_reason
                        if not auto_tools_called else "tool_calls",
                        stop_reason=output.stop_reason)
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)
                    _set_chunk_usage(chunk, previous_num_tokens[i])
                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"
                    finish_reason_sent[i] = True

        # Aggregate over ALL choices; indexing with a leftover loop variable
        # would only count one choice (and fail on an empty stream).
        num_completion_tokens = sum(previous_num_tokens)

        # once the final token is handled, if stream_options.include_usage
        # is sent, send the usage
        if include_usage:
            final_usage = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
                total_tokens=num_prompt_tokens + num_completion_tokens,
            )

            final_usage_chunk = ChatCompletionStreamResponse(
                id=request_id,
                object=chunk_object_type,
                created=created_time,
                choices=[],
                model=model_name,
                usage=final_usage)
            final_usage_data = (final_usage_chunk.model_dump_json(
                exclude_unset=True, exclude_none=True))
            yield f"data: {final_usage_data}\n\n"

        # report to FastAPI middleware aggregate usage across all choices
        request_metadata.final_usage_info = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_completion_tokens,
            total_tokens=num_prompt_tokens + num_completion_tokens)

    except ValueError as e:
        # TODO: Use a vllm-specific Validation Error
        logger.error("error in chat completion stream generator: %s", e)
        data = self.create_streaming_error_response(str(e))
        yield f"data: {data}\n\n"
    # Send the final done message after all response.n are finished
    yield "data: [DONE]\n\n"
|
||||
|
||||
async def chat_completion_full_generator(
    self,
    request: ChatCompletionRequest,
    result_generator: AsyncIterator[RequestOutput],
    request_id: str,
    conversation: List[ConversationMessage],
    tokenizer: AnyTokenizer,
    request_metadata: RequestResponseMetadata,
) -> Union[ErrorResponse, ChatCompletionResponse]:
    """Drain the engine stream and build a non-streaming chat response.

    Only the last item produced by ``result_generator`` is used. Each of its
    outputs becomes one choice whose message is plain content, a named tool
    call, or the result of the configured tool parser, depending on the
    request's ``tool_choice`` and the server's tool settings.

    Args:
        request: The chat completion request being served.
        result_generator: Async iterator of engine outputs.
        request_id: Identifier placed on the response.
        conversation: Parsed conversation, used only for echoing the last
            message into each choice.
        tokenizer: Tokenizer handed to the tool parser and logprob helpers.
        request_metadata: Receives the usage in ``final_usage_info``.

    Returns:
        The assembled response, or an error response if iteration is
        cancelled or the tool parser cannot be created.
    """
    model_name = self.base_model_paths[0].name
    created_time = int(time.time())
    final_res: Optional[RequestOutput] = None

    try:
        async for res in result_generator:
            final_res = res
    except asyncio.CancelledError:
        return self.create_error_response("Client disconnected")

    assert final_res is not None

    choices: List[ChatCompletionResponseChoice] = []

    role = self.get_chat_request_role(request)

    # Evaluated once and used by both tool-choice branches below, so they
    # cannot disagree about what counts as a named tool choice.
    is_named_tool_choice = isinstance(request.tool_choice,
                                      ChatCompletionNamedToolChoiceParam)

    for output in final_res.outputs:
        token_ids = output.token_ids
        out_logprobs = output.logprobs

        if request.logprobs and request.top_logprobs is not None:
            assert out_logprobs is not None, "Did not output logprobs"
            logprobs = self._create_chat_logprobs(
                token_ids=token_ids,
                top_logprobs=out_logprobs,
                num_output_top_logprobs=request.top_logprobs,
                tokenizer=tokenizer,
            )
        else:
            logprobs = None

        # The finish_reason is reported as "tool_calls" only when the tool
        # choice is auto and the model produced a tool call. The same is
        # not true for named function calls.
        auto_tools_called = False

        # if auto tools are not enabled, and a named tool choice using
        # outlines is not being used
        if (not self.enable_auto_tools
                or not self.tool_parser) and not is_named_tool_choice:
            message = ChatMessage(role=role, content=output.text)

        # if the request uses tools and specified a tool choice
        elif is_named_tool_choice:
            message = ChatMessage(
                role=role,
                content="",
                tool_calls=[
                    ToolCall(function=FunctionCall(
                        name=request.tool_choice.function.name,
                        arguments=output.text))
                ])

        # if the request doesn't use tool choice
        # OR specifies to not use a tool
        elif not request.tool_choice or request.tool_choice == "none":
            message = ChatMessage(role=role, content=output.text)

        # handle when there are tools and tool choice is auto
        elif request.tools and (
                request.tool_choice == "auto"
                or request.tool_choice is None) and self.enable_auto_tools \
                and self.tool_parser:

            try:
                tool_parser = self.tool_parser(tokenizer)
            except RuntimeError as e:
                logger.error("Error in tool parser creation: %s", e)
                return self.create_error_response(str(e))

            tool_call_info = tool_parser.extract_tool_calls(
                output.text, request=request)
            auto_tools_called = tool_call_info.tools_called
            if tool_call_info.tools_called:
                message = ChatMessage(role=role,
                                      content=tool_call_info.content,
                                      tool_calls=tool_call_info.tool_calls)

            else:
                # FOR NOW make it a chat message; we will have to detect
                # the type to make it later.
                message = ChatMessage(role=role, content=output.text)

        # undetermined case that is still important to handle
        else:
            logger.error(
                "Error in chat_completion_full_generator - cannot determine"
                " if tools should be extracted. Returning a standard chat "
                "completion.")
            message = ChatMessage(role=role, content=output.text)

        choice_data = ChatCompletionResponseChoice(
            index=output.index,
            message=message,
            logprobs=logprobs,
            finish_reason="tool_calls" if auto_tools_called else
            output.finish_reason if output.finish_reason else "stop",
            stop_reason=output.stop_reason)
        choices.append(choice_data)

    if request.echo or request.continue_final_message:
        # Prefix every choice with the last conversation message, but only
        # if that message has content and carries the response role.
        last_msg_content = ""
        if conversation and "content" in conversation[-1] and conversation[
                -1].get("role") == role:
            last_msg_content = conversation[-1]["content"] or ""

        for choice in choices:
            full_message = last_msg_content + (choice.message.content
                                               or "")
            choice.message.content = full_message

    assert final_res.prompt_token_ids is not None
    num_prompt_tokens = len(final_res.prompt_token_ids)
    if final_res.encoder_prompt_token_ids is not None:
        num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
    num_generated_tokens = sum(
        len(output.token_ids) for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )

    request_metadata.final_usage_info = usage

    response = ChatCompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
        prompt_logprobs=final_res.prompt_logprobs,
    )

    return response
|
||||
|
||||
def _get_top_logprobs(
        self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
        tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
    """Convert the first ``top_logprobs`` entries of ``logprobs``.

    Entries are taken in the mapping's iteration order. A ``top_logprobs``
    of ``None`` or ``0`` produces an empty list. Log-probabilities are
    floored at ``-9999.0``.
    """
    converted: List[ChatCompletionLogProb] = []
    if not top_logprobs:
        return converted

    for position, (token_id, entry) in enumerate(logprobs.items()):
        # Positions only grow, so nothing after this point can qualify.
        if position >= top_logprobs:
            break

        token = self._get_decoded_token(
            entry,
            token_id,
            tokenizer,
            return_as_token_id=self.return_tokens_as_token_ids)
        floored_logprob = max(entry.logprob, -9999.0)
        token_bytes = list(token.encode("utf-8", errors="replace"))
        converted.append(
            ChatCompletionLogProb(token=token,
                                  logprob=floored_logprob,
                                  bytes=token_bytes))

    return converted
|
||||
|
||||
def _create_chat_logprobs(
    self,
    token_ids: GenericSequence[int],
    top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
    tokenizer: AnyTokenizer,
    num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
    """Build the OpenAI-style logprobs payload for a token sequence.

    ``top_logprobs`` is indexed in step with ``token_ids``. A position
    whose entry is ``None`` yields a content item with only the token text
    and its bytes; otherwise the item also carries the token's logprob
    (floored at ``-9999.0``) and its top alternatives.
    """
    items: List[ChatCompletionLogProbsContent] = []

    for position, token_id in enumerate(token_ids):
        candidates = top_logprobs[position]

        if candidates is None:
            # No logprob data for this step: report the token text only.
            token_text = tokenizer.decode(token_id)
            if self.return_tokens_as_token_ids:
                token_text = f"token_id:{token_id}"
            token_bytes = list(token_text.encode("utf-8", errors="replace"))
            items.append(
                ChatCompletionLogProbsContent(token=token_text,
                                              bytes=token_bytes))
            continue

        chosen = candidates[token_id]
        decoded = chosen.decoded_token

        shown_token = self._get_decoded_token(
            chosen,
            token_id,
            tokenizer,
            self.return_tokens_as_token_ids,
        )
        floored_logprob = max(chosen.logprob, -9999.0)
        if decoded is None:
            decoded_bytes = None
        else:
            decoded_bytes = list(decoded.encode("utf-8", errors="replace"))
        alternatives = self._get_top_logprobs(
            candidates,
            num_output_top_logprobs,
            tokenizer,
        )

        items.append(
            ChatCompletionLogProbsContent(token=shown_token,
                                          logprob=floored_logprob,
                                          bytes=decoded_bytes,
                                          top_logprobs=alternatives))

    return ChatCompletionLogProbs(content=items)
|
||||
|
||||
def _should_stream_with_auto_tool_parsing(self,
                                          request: ChatCompletionRequest):
    """
    Decide whether streamed tokens should be routed through the configured
    tool call parser.

    The result is truthy only when the request supplies tools, a tool
    parser is configured, "auto" tool choice is enabled, and the request's
    tool choice is either 'auto' or None. As with a plain ``and`` chain,
    the first falsy operand is returned as-is.
    """
    tools_with_parser = request.tools and self.tool_parser
    auto_enabled = tools_with_parser and self.enable_auto_tools
    return auto_enabled and request.tool_choice in ['auto', None]
|
||||
|
||||
def _should_check_for_unstreamed_tool_arg_tokens(
    self,
    delta_message: Optional[DeltaMessage],
    output: CompletionOutput,
) -> bool:
    """
    Tell whether leftover tool-call argument text may still need streaming.

    True only when the output has a finish reason, auto tool parsing is
    enabled with a parser configured, and the delta's first tool call has a
    function whose arguments are not None.
    """
    # Nothing to reconcile until generation has finished.
    if output.finish_reason is None:
        return False

    if not (self.enable_auto_tools and self.tool_parser):
        return False

    if not delta_message or not delta_message.tool_calls:
        return False

    first_call = delta_message.tool_calls[0]
    if not first_call or not first_call.function:
        return False

    return first_call.function.arguments is not None
|
||||
554
vllm/entrypoints/openai/serving_completion.py
Normal file
554
vllm/entrypoints/openai/serving_completion.py
Normal file
@@ -0,0 +1,554 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
|
||||
Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple, Union, cast
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Alias for a list of ints (by its name, the token ids of one sequence).
TypeTokenIDs = List[int]
|
||||
# Alias for a list whose elements are either None or an int -> float mapping
# (presumably token id -> log-probability per position; TODO confirm).
TypeTopLogProbs = List[Optional[Dict[int, float]]]
|
||||
# Alias for a callable taking (TypeTokenIDs, TypeTopLogProbs, Optional[int],
# int) and returning CompletionLogProbs. NOTE(review): the meaning of the last
# two arguments is not visible in this part of the file.
TypeCreateLogProbsFn = Callable[
|
||||
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(
    self,
    engine_client: EngineClient,
    model_config: ModelConfig,
    base_model_paths: List[BaseModelPath],
    *,
    lora_modules: Optional[List[LoRAModulePath]],
    prompt_adapters: Optional[List[PromptAdapterPath]],
    request_logger: Optional[RequestLogger],
    return_tokens_as_token_ids: bool = False,
):
    """Set up the completion handler.

    Every argument is handed to the parent initializer under the same
    keyword; nothing else is configured here.
    """
    parent_kwargs = dict(
        engine_client=engine_client,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
        prompt_adapters=prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=return_tokens_as_token_ids,
    )
    super().__init__(**parent_kwargs)
|
||||
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following feature:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.suffix is not None:
|
||||
return self.create_error_response(
|
||||
"suffix is not currently supported")
|
||||
|
||||
model_name = self.base_model_paths[0].name
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
created_time = int(time.time())
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
prompts = list(
|
||||
self._tokenize_prompt_input_or_inputs(
|
||||
request,
|
||||
tokenizer,
|
||||
request.prompt,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
))
|
||||
|
||||
for i, prompt_inputs in enumerate(prompts):
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
default_max_tokens = self.max_model_len - len(
|
||||
prompt_inputs["prompt_token_ids"])
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
default_max_tokens)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
prompt_inputs,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
is_tracing_enabled = (await
|
||||
self.engine_client.is_tracing_enabled())
|
||||
trace_headers = None
|
||||
if is_tracing_enabled:
|
||||
trace_headers = extract_trace_headers(raw_request.headers)
|
||||
if not is_tracing_enabled and contains_trace_headers(
|
||||
raw_request.headers):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
assert isinstance(self.engine_client,
|
||||
(AsyncLLMEngine,
|
||||
MQLLMEngineClient)), \
|
||||
"Beam search is only supported with" \
|
||||
"AsyncLLMEngine and MQLLMEngineClient."
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt_inputs["prompt_token_ids"],
|
||||
request_id_item,
|
||||
sampling_params,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
{
|
||||
"prompt_token_ids":
|
||||
prompt_inputs["prompt_token_ids"]
|
||||
},
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(
|
||||
*generators, is_cancelled=raw_request.is_disconnected)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. In addition, we do not stream the results when use
|
||||
# beam search.
|
||||
stream = (request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
num_prompts=len(prompts),
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=request_metadata)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
for i, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
|
||||
# The output should contain the input text
|
||||
# We did not pass it into vLLM engine to avoid being redundant
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
final_res.prompt = prompts[i]["prompt"]
|
||||
|
||||
final_res_batch_checked = cast(List[RequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_completion_response(
|
||||
final_res_batch_checked,
|
||||
request,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
tokenizer,
|
||||
request_metadata,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
if request.stream:
|
||||
response_json = response.model_dump_json()
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return fake_stream_generator()
|
||||
|
||||
return response
|
||||
|
||||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
num_prompts: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_text_lens = [0] * num_choices * num_prompts
|
||||
previous_num_tokens = [0] * num_choices * num_prompts
|
||||
has_echoed = [False] * num_choices * num_prompts
|
||||
num_prompt_tokens = [0] * num_prompts
|
||||
|
||||
try:
|
||||
async for prompt_idx, res in result_generator:
|
||||
prompt_token_ids = res.prompt_token_ids
|
||||
prompt_logprobs = res.prompt_logprobs
|
||||
prompt_text = res.prompt
|
||||
|
||||
# Prompt details are excluded from later streamed outputs
|
||||
if res.prompt_token_ids is not None:
|
||||
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
|
||||
|
||||
delta_token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[Dict[
|
||||
int, Logprob]]]]
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index + prompt_idx * num_choices
|
||||
# TODO(simon): optimize the performance by avoiding full
|
||||
# text O(n^2) sending.
|
||||
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and request.max_tokens == 0:
|
||||
assert prompt_token_ids is not None
|
||||
assert prompt_text is not None
|
||||
# only return the prompt
|
||||
delta_text = prompt_text
|
||||
delta_token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
has_echoed[i] = True
|
||||
elif (request.echo and request.max_tokens > 0
|
||||
and not has_echoed[i]):
|
||||
assert prompt_token_ids is not None
|
||||
assert prompt_text is not None
|
||||
assert prompt_logprobs is not None
|
||||
# echo the prompt and first token
|
||||
delta_text = prompt_text + output.text
|
||||
delta_token_ids = [
|
||||
*prompt_token_ids, *output.token_ids
|
||||
]
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*(output.logprobs or []),
|
||||
]
|
||||
has_echoed[i] = True
|
||||
else:
|
||||
# return just the delta
|
||||
delta_text = output.text
|
||||
delta_token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, (
|
||||
"Did not output logprobs")
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
initial_text_offset=previous_text_lens[i],
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
previous_text_lens[i] += len(output.text)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
finish_reason = output.finish_reason
|
||||
stop_reason = output.stop_reason
|
||||
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
])
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
if (request.stream_options.continuous_usage_stats
|
||||
or output.finish_reason is not None):
|
||||
prompt_tokens = num_prompt_tokens[prompt_idx]
|
||||
completion_tokens = previous_num_tokens[i]
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
if request.stream_options.continuous_usage_stats:
|
||||
chunk.usage = usage
|
||||
else:
|
||||
chunk.usage = None
|
||||
|
||||
response_json = chunk.model_dump_json(exclude_unset=False)
|
||||
yield f"data: {response_json}\n\n"
|
||||
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[],
|
||||
usage=usage,
|
||||
)
|
||||
final_usage_data = (final_usage_chunk.model_dump_json(
|
||||
exclude_unset=False, exclude_none=True))
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
total_prompt_tokens = sum(num_prompt_tokens)
|
||||
total_completion_tokens = sum(previous_num_tokens)
|
||||
request_metadata.final_usage_info = UsageInfo(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=total_completion_tokens,
|
||||
total_tokens=total_prompt_tokens + total_completion_tokens)
|
||||
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def request_output_to_completion_response(
|
||||
self,
|
||||
final_res_batch: List[RequestOutput],
|
||||
request: CompletionRequest,
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> CompletionResponse:
|
||||
choices: List[CompletionResponseChoice] = []
|
||||
num_prompt_tokens = 0
|
||||
num_generated_tokens = 0
|
||||
|
||||
for final_res in final_res_batch:
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
assert prompt_token_ids is not None
|
||||
prompt_logprobs = final_res.prompt_logprobs
|
||||
prompt_text = final_res.prompt
|
||||
|
||||
token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
|
||||
Logprob]]]]
|
||||
|
||||
for output in final_res.outputs:
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and request.max_tokens == 0:
|
||||
assert prompt_text is not None
|
||||
token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
output_text = prompt_text
|
||||
elif request.echo and request.max_tokens > 0:
|
||||
assert prompt_text is not None
|
||||
token_ids = [*prompt_token_ids, *output.token_ids]
|
||||
|
||||
if request.logprobs is None:
|
||||
out_logprobs = None
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
assert output.logprobs is not None
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*output.logprobs,
|
||||
]
|
||||
|
||||
output_text = prompt_text + output.text
|
||||
else:
|
||||
token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
output_text = output.text
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
tokenizer=tokenizer,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=len(choices),
|
||||
text=output_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
prompt_logprobs=final_res.prompt_logprobs,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_generated_tokens += len(output.token_ids)
|
||||
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
|
||||
request_metadata.final_usage_info = usage
|
||||
|
||||
return CompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _create_completion_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
||||
num_output_top_logprobs: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
initial_text_offset: int = 0,
|
||||
) -> CompletionLogProbs:
|
||||
"""Create logprobs for OpenAI Completion API."""
|
||||
out_text_offset: List[int] = []
|
||||
out_token_logprobs: List[Optional[float]] = []
|
||||
out_tokens: List[str] = []
|
||||
out_top_logprobs: List[Optional[Dict[str, float]]] = []
|
||||
|
||||
last_token_len = 0
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
token = tokenizer.decode(token_id)
|
||||
if self.return_tokens_as_token_ids:
|
||||
token = f"token_id:{token_id}"
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(None)
|
||||
out_top_logprobs.append(None)
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
|
||||
token = self._get_decoded_token(
|
||||
step_token,
|
||||
token_id,
|
||||
tokenizer,
|
||||
return_as_token_id=self.return_tokens_as_token_ids,
|
||||
)
|
||||
token_logprob = max(step_token.logprob, -9999.0)
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(token_logprob)
|
||||
|
||||
# makes sure to add the top num_output_top_logprobs + 1
|
||||
# logprobs, as defined in the openai API
|
||||
# (cf. https://github.com/openai/openai-openapi/blob/
|
||||
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
|
||||
out_top_logprobs.append({
|
||||
# Convert float("-inf") to the
|
||||
# JSON-serializable float that OpenAI uses
|
||||
self._get_decoded_token(
|
||||
top_lp[1],
|
||||
top_lp[0],
|
||||
tokenizer,
|
||||
return_as_token_id=self.return_tokens_as_token_ids):
|
||||
max(top_lp[1].logprob, -9999.0)
|
||||
for i, top_lp in enumerate(step_top_logprobs.items())
|
||||
if num_output_top_logprobs >= i
|
||||
})
|
||||
|
||||
if len(out_text_offset) == 0:
|
||||
out_text_offset.append(initial_text_offset)
|
||||
else:
|
||||
out_text_offset.append(out_text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
return CompletionLogProbs(
|
||||
text_offset=out_text_offset,
|
||||
token_logprobs=out_token_logprobs,
|
||||
tokens=out_tokens,
|
||||
top_logprobs=out_top_logprobs,
|
||||
)
|
||||
203
vllm/entrypoints/openai/serving_embedding.py
Normal file
203
vllm/entrypoints/openai/serving_embedding.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from typing import AsyncGenerator, List, Literal, Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
# Logger for this module, created by vLLM's logger factory.
logger = init_logger(__name__)

# A list of token ids.
TypeTokenIDs = List[int]
|
||||
|
||||
|
||||
def _get_embedding(
    output: EmbeddingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
    """Return the embedding of *output* in the requested wire format.

    ``"float"`` returns ``output.embedding`` unchanged; ``"base64"`` returns
    the base64 text of the embedding packed as float32 values.
    """
    if encoding_format == "base64":
        # Force to use float32 for base64 encoding
        # to match the OpenAI python client behavior
        packed = np.array(output.embedding, dtype="float32").tobytes()
        encoded = base64.b64encode(packed)
        return encoded.decode("utf-8")

    if encoding_format == "float":
        return output.embedding

    # Any other value is outside the declared Literal type.
    assert_never(encoding_format)
|
||||
|
||||
|
||||
def request_output_to_embedding_response(
        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
        created_time: int, model_name: str,
        encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
    """Build an ``EmbeddingResponse`` from the final engine outputs.

    Each entry of ``final_res_batch`` becomes one ``EmbeddingResponseData``
    whose index is its position in the batch; usage counts the prompt tokens
    of all entries.
    """
    data: List[EmbeddingResponseData] = []
    prompt_token_total = 0

    for position, result in enumerate(final_res_batch):
        token_ids = result.prompt_token_ids
        encoded_embedding = _get_embedding(result.outputs, encoding_format)
        data.append(
            EmbeddingResponseData(index=position,
                                  embedding=encoded_embedding))
        prompt_token_total = prompt_token_total + len(token_ids)

    # Embedding requests generate no tokens, so total equals prompt tokens.
    usage = UsageInfo(
        prompt_tokens=prompt_token_total,
        total_tokens=prompt_token_total,
    )

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )
|
||||
|
||||
|
||||
class OpenAIServingEmbedding(OpenAIServing):
    """Handler for the OpenAI-compatible embeddings endpoint.

    The endpoint is only enabled when the model config reports
    ``embedding_mode``; otherwise every request gets an error response.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        request_logger: Optional[RequestLogger],
    ):
        """Initialise the base server without static LoRA modules or prompt
        adapters and record whether the embedding API is enabled."""
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=None,
                         prompt_adapters=None,
                         request_logger=request_logger)
        self._enabled = self._check_embedding_mode(model_config.embedding_mode)

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.

        Args:
            request: The parsed embedding request.
            raw_request: Optional HTTP request, used only to detect client
                disconnects.

        Returns:
            An ``EmbeddingResponse`` on success, or an ``ErrorResponse`` when
            the API is disabled, validation fails, or the client disconnects.

        Raises:
            NotImplementedError: If the request resolves to a prompt adapter.
        """
        if not self._enabled:
            return self.create_error_response("Embedding API disabled")
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        encoding_format = request.encoding_format
        if request.dimensions is not None:
            return self.create_error_response(
                "dimensions is currently not supported")

        model_name = request.model
        request_id = f"embd-{random_uuid()}"
        # `created` is a wall-clock Unix timestamp; time.monotonic() has an
        # undefined reference point and must not be used here.
        created_time = int(time.time())

        truncate_prompt_tokens = None

        if request.truncate_prompt_tokens is not None:
            if request.truncate_prompt_tokens <= self.max_model_len:
                truncate_prompt_tokens = request.truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len."
                    " Please, select a smaller truncation size.")

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            pooling_params = request.to_pooling_params()

            prompts = list(
                self._tokenize_prompt_input_or_inputs(request, tokenizer,
                                                      request.input,
                                                      truncate_prompt_tokens))

            for i, prompt_inputs in enumerate(prompts):
                # Each input of a batched request gets its own engine id.
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 prompt_inputs,
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                if prompt_adapter_request is not None:
                    raise NotImplementedError(
                        "Prompt adapter is not supported "
                        "for embedding models")

                generator = self.engine_client.encode(
                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # Yields (input index, output) pairs from all per-input generators.
        result_generator = merge_async_iterators(
            *generators,
            is_cancelled=raw_request.is_disconnected if raw_request else None,
        )

        # Non-streaming response
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * len(prompts)
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for final_res in final_res_batch:
                assert final_res is not None

            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
                                           final_res_batch)

            response = request_output_to_embedding_response(
                final_res_batch_checked, request_id, created_time, model_name,
                encoding_format)
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def _check_embedding_mode(self, embedding_mode: bool) -> bool:
        """Log whether the embedding API is usable and return the flag."""
        if not embedding_mode:
            logger.warning(
                "embedding_mode is False. Embedding API will not work.")
        else:
            logger.info("Activating the server engine with embedding enabled.")
        return embedding_mode
|
||||
487
vllm/entrypoints/openai/serving_engine.py
Normal file
487
vllm/entrypoints/openai/serving_engine.py
Normal file
@@ -0,0 +1,487 @@
|
||||
import json
|
||||
import pathlib
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
DetokenizeRequest,
|
||||
EmbeddingRequest, ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeRequest,
|
||||
UnloadLoraAdapterRequest)
|
||||
# yapf: enable
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import AtomicCounter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BaseModelPath:
    """A served base model: its public name and the path it was loaded from."""
    # Name matched against incoming requests and shown as the model card id.
    name: str
    # Reported as the model card's ``root``.
    model_path: str
|
||||
|
||||
|
||||
@dataclass
class PromptAdapterPath:
    """A prompt adapter to register at start-up."""
    # Name under which the adapter is exposed.
    name: str
    # Directory expected to contain ``adapter_config.json``.
    local_path: str
|
||||
|
||||
|
||||
@dataclass
class LoRAModulePath:
    """A LoRA module to register at start-up."""
    # Name under which the LoRA is exposed.
    name: str
    # Passed to ``LoRARequest`` as ``lora_path``.
    path: str
    # Optional name of the base model this LoRA belongs to.
    base_model_name: Optional[str] = None
|
||||
|
||||
|
||||
# Union of the request models used to annotate the `request` parameter of
# the shared OpenAIServing helper methods.
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
                   EmbeddingRequest, TokenizeRequest]
|
||||
|
||||
|
||||
class TextTokensPrompt(TypedDict):
    """A prompt held both as text and as token ids."""
    # The prompt text.
    prompt: str
    # The token ids corresponding to the prompt.
    prompt_token_ids: List[int]
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
|
||||
def __init__(
    self,
    engine_client: EngineClient,
    model_config: ModelConfig,
    base_model_paths: List[BaseModelPath],
    *,
    lora_modules: Optional[List[LoRAModulePath]],
    prompt_adapters: Optional[List[PromptAdapterPath]],
    request_logger: Optional[RequestLogger],
    return_tokens_as_token_ids: bool = False,
):
    """Store the shared serving state and register static adapters.

    LoRA modules and prompt adapters are numbered from 1 in the order given.
    For each prompt adapter, ``adapter_config.json`` is read from its
    ``local_path`` to obtain ``num_virtual_tokens``.
    """
    super().__init__()

    self.engine_client = engine_client
    self.model_config = model_config
    self.max_model_len = model_config.max_model_len

    self.base_model_paths = base_model_paths

    self.lora_id_counter = AtomicCounter(0)
    self.lora_requests = []
    if lora_modules is not None:
        static_loras = []
        for lora_id, module in enumerate(lora_modules, start=1):
            # Fall back to the first base model when the LoRA names no
            # base model or names one that is not served.
            parent_name = module.base_model_name
            if not (parent_name
                    and self._is_model_supported(parent_name)):
                parent_name = self.base_model_paths[0].name
            static_loras.append(
                LoRARequest(lora_name=module.name,
                            lora_int_id=lora_id,
                            lora_path=module.path,
                            base_model_name=parent_name))
        self.lora_requests = static_loras

    self.prompt_adapter_requests = []
    if prompt_adapters is not None:
        for adapter_id, adapter in enumerate(prompt_adapters, start=1):
            config_path = pathlib.Path(adapter.local_path,
                                       "adapter_config.json")
            with config_path.open() as config_file:
                adapter_config = json.load(config_file)
            virtual_tokens = adapter_config["num_virtual_tokens"]
            adapter_request = PromptAdapterRequest(
                prompt_adapter_name=adapter.name,
                prompt_adapter_id=adapter_id,
                prompt_adapter_local_path=adapter.local_path,
                prompt_adapter_num_virtual_tokens=virtual_tokens)
            self.prompt_adapter_requests.append(adapter_request)

    self.request_logger = request_logger
    self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
|
||||
async def show_available_models(self) -> ModelList:
    """List every servable model as a ``ModelList``.

    Cards appear in this order: base models, then registered LoRA
    adapters, then registered prompt adapters.
    """
    cards = []

    for base in self.base_model_paths:
        cards.append(
            ModelCard(id=base.name,
                      max_model_len=self.max_model_len,
                      root=base.model_path,
                      permission=[ModelPermission()]))

    default_parent = self.base_model_paths[0].name if (
        self.lora_requests or self.prompt_adapter_requests) else None

    for lora in self.lora_requests:
        # A LoRA without an explicit base model hangs off the first one.
        parent_name = lora.base_model_name or default_parent
        cards.append(
            ModelCard(id=lora.lora_name,
                      root=lora.local_path,
                      parent=parent_name,
                      permission=[ModelPermission()]))

    for adapter in self.prompt_adapter_requests:
        cards.append(
            ModelCard(id=adapter.prompt_adapter_name,
                      root=default_parent,
                      permission=[ModelPermission()]))

    return ModelList(data=cards)
|
||||
|
||||
def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
    """Wrap an error message in an ``ErrorResponse``.

    The response's ``code`` is the numeric value of ``status_code``.
    """
    numeric_code = status_code.value
    return ErrorResponse(message=message, type=err_type, code=numeric_code)
|
||||
|
||||
def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
    """Return a JSON string ``{"error": {...}}`` for use in a stream.

    The inner object is the dumped ``ErrorResponse`` built by
    ``create_error_response`` from the same arguments.
    """
    error = self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code)
    payload = {"error": error.model_dump()}
    return json.dumps(payload)
|
||||
|
||||
async def _check_model(
    self,
    request: AnyRequest,
) -> Optional[ErrorResponse]:
    """Return ``None`` if ``request.model`` is known, else a 404 error.

    A model is known when it is a supported base model, a registered LoRA
    name, or a registered prompt adapter name.
    """
    if self._is_model_supported(request.model):
        return None

    lora_names = [lora.lora_name for lora in self.lora_requests]
    if request.model in lora_names:
        return None

    adapter_names = [
        adapter.prompt_adapter_name
        for adapter in self.prompt_adapter_requests
    ]
    if request.model in adapter_names:
        return None

    return self.create_error_response(
        message=f"The model `{request.model}` does not exist.",
        err_type="NotFoundError",
        status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
def _maybe_get_adapters(
    self, request: AnyRequest
) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
        None, PromptAdapterRequest]]:
    """Resolve ``request.model`` to a ``(lora, prompt_adapter)`` pair.

    Returns ``(None, None)`` for a supported base model, ``(lora, None)``
    for a registered LoRA name, and ``(None, adapter)`` for a registered
    prompt adapter name.

    Raises:
        ValueError: If the model name matches none of the above.
    """
    requested = request.model
    if self._is_model_supported(requested):
        return None, None

    # First LoRA whose name matches wins.
    for candidate in self.lora_requests:
        if candidate.lora_name == requested:
            return candidate, None

    # Otherwise the first matching prompt adapter.
    for candidate in self.prompt_adapter_requests:
        if candidate.prompt_adapter_name == requested:
            return None, candidate

    # if _check_model has been called earlier, this will be unreachable
    raise ValueError(f"The model `{request.model}` does not exist.")
|
||||
|
||||
def _normalize_prompt_text_to_input(
    self,
    request: AnyRequest,
    tokenizer: AnyTokenizer,
    prompt: str,
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    add_special_tokens: bool,
) -> TextTokensPrompt:
    """Tokenize a text prompt and validate the result.

    When ``truncate_prompt_tokens`` is given, the tokenizer is called with
    ``truncation=True`` and that value as ``max_length``. The original
    prompt string is kept as the prompt text.
    """
    tokenizer_options = {"add_special_tokens": add_special_tokens}
    if truncate_prompt_tokens is not None:
        tokenizer_options["truncation"] = True
        tokenizer_options["max_length"] = truncate_prompt_tokens

    encoding = tokenizer(prompt, **tokenizer_options)

    return self._validate_input(request, encoding.input_ids, prompt)
|
||||
|
||||
def _normalize_prompt_tokens_to_input(
    self,
    request: AnyRequest,
    tokenizer: AnyTokenizer,
    prompt_ids: List[int],
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
) -> TextTokensPrompt:
    """Validate a prompt given as token ids.

    When ``truncate_prompt_tokens`` is given, only that many trailing ids
    are kept. The prompt text is obtained by decoding the kept ids.
    """
    kept_ids = (prompt_ids if truncate_prompt_tokens is None else
                prompt_ids[-truncate_prompt_tokens:])
    decoded_text = tokenizer.decode(kept_ids)
    return self._validate_input(request, kept_ids, decoded_text)
|
||||
|
||||
def _validate_input(
    self,
    request: AnyRequest,
    input_ids: List[int],
    input_text: str,
) -> TextTokensPrompt:
    """
    Check the prompt length against ``self.max_model_len`` and wrap the
    prompt text and token ids in a ``TextTokensPrompt``.

    * Embedding requests: the prompt may not exceed the context length.
    * Tokenize / detokenize requests: no length validation.
    * All other requests: with ``max_tokens`` unset the prompt must be
      strictly shorter than the context length; otherwise prompt plus
      ``max_tokens`` may not exceed it.

    Raises ``ValueError`` when the applicable limit is violated.
    """
    prompt_len = len(input_ids)

    if isinstance(request, EmbeddingRequest):
        # Note: EmbeddingRequest doesn't have max_tokens
        if prompt_len > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{prompt_len} tokens in the input for embedding "
                f"generation. Please reduce the length of the input.")
    elif isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                              DetokenizeRequest)):
        # Note: tokenize and detokenize requests don't have max_tokens
        # and do not require model context length validation
        pass
    elif request.max_tokens is None:
        if prompt_len >= self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{prompt_len} tokens in the messages, "
                f"Please reduce the length of the messages.")
    elif prompt_len + request.max_tokens > self.max_model_len:
        raise ValueError(
            f"This model's maximum context length is "
            f"{self.max_model_len} tokens. However, you requested "
            f"{request.max_tokens + prompt_len} tokens "
            f"({prompt_len} in the messages, "
            f"{request.max_tokens} in the completion). "
            f"Please reduce the length of the messages or completion.")

    return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
def _tokenize_prompt_input(
    self,
    request: AnyRequest,
    tokenizer: AnyTokenizer,
    prompt_input: Union[str, List[int]],
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
    add_special_tokens: bool = True,
) -> TextTokensPrompt:
    """
    A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
    that assumes single input.

    The input is wrapped in a one-element list, passed to
    :meth:`_tokenize_prompt_inputs`, and the first yielded prompt is
    returned.
    """
    single_result_iter = self._tokenize_prompt_inputs(
        request,
        tokenizer,
        [prompt_input],
        truncate_prompt_tokens=truncate_prompt_tokens,
        add_special_tokens=add_special_tokens,
    )
    return next(single_result_iter)
|
||||
|
||||
def _tokenize_prompt_inputs(
    self,
    request: AnyRequest,
    tokenizer: AnyTokenizer,
    prompt_inputs: Iterable[Union[str, List[int]]],
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
    add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
    """
    A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
    that assumes multiple inputs.

    Lazily yields one normalized prompt per element: ``str`` elements are
    tokenized, anything else is treated as a list of token ids and decoded.
    ``add_special_tokens`` only applies to the text path.
    """
    for item in prompt_inputs:
        if not isinstance(item, str):
            # Token-id input: detokenize to recover the prompt text.
            yield self._normalize_prompt_tokens_to_input(
                request,
                tokenizer,
                prompt_ids=item,
                truncate_prompt_tokens=truncate_prompt_tokens,
            )
            continue

        yield self._normalize_prompt_text_to_input(
            request,
            tokenizer,
            prompt=item,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )
|
||||
|
||||
def _tokenize_prompt_input_or_inputs(
    self,
    request: AnyRequest,
    tokenizer: AnyTokenizer,
    input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
    add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
    """
    Tokenize/detokenize depending on the input format.

    According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
    , each input can be a string or array of tokens. Note that each request
    can pass one or more inputs.

    The raw input is first split into individual prompts by
    ``parse_and_batch_prompt``; one normalized prompt is yielded per entry.
    """
    for parsed in parse_and_batch_prompt(input_or_inputs):
        # Although our type checking is based on mypy,
        # VSCode Pyright extension should still work properly
        # An identity comparison against the bool literal is required for
        # Pyright to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672
        if parsed["is_tokens"] is not False:
            yield self._normalize_prompt_tokens_to_input(
                request,
                tokenizer,
                prompt_ids=parsed["content"],
                truncate_prompt_tokens=truncate_prompt_tokens,
            )
            continue

        yield self._normalize_prompt_text_to_input(
            request,
            tokenizer,
            prompt=parsed["content"],
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )
|
||||
|
||||
def _log_inputs(
    self,
    request_id: str,
    inputs: Union[str, List[int], TextTokensPrompt],
    params: Optional[Union[SamplingParams, PoolingParams,
                           BeamSearchParams]],
    lora_request: Optional[LoRARequest],
    prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
    """
    Forward the request inputs to ``self.request_logger``, if one is set.

    ``inputs`` may be a prompt string (logged without token ids), a list
    (logged as token ids without text), or a mapping providing both
    ``"prompt"`` and ``"prompt_token_ids"``.
    """
    request_logger = self.request_logger
    if request_logger is None:
        return

    if isinstance(inputs, str):
        text, token_ids = inputs, None
    elif isinstance(inputs, list):
        text, token_ids = None, inputs
    else:
        text, token_ids = inputs["prompt"], inputs["prompt_token_ids"]

    request_logger.log_inputs(
        request_id,
        text,
        token_ids,
        params=params,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )
|
||||
|
||||
@staticmethod
|
||||
def _get_decoded_token(logprob: Logprob,
|
||||
token_id: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
return_as_token_id: bool = False) -> str:
|
||||
if return_as_token_id:
|
||||
return f"token_id:{token_id}"
|
||||
|
||||
if logprob.decoded_token is not None:
|
||||
return logprob.decoded_token
|
||||
return tokenizer.decode(token_id)
|
||||
|
||||
async def _check_load_lora_adapter_request(
        self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
    """
    Validate a request to load a LoRA adapter.

    Returns an error response (HTTP 400) when ``lora_name`` or ``lora_path``
    is missing/empty, or when an adapter with the same ``lora_name`` is
    already registered in ``self.lora_requests``. Returns ``None`` when the
    request is acceptable.
    """
    # Check if both 'lora_name' and 'lora_path' are provided
    if not request.lora_name or not request.lora_path:
        return self.create_error_response(
            message="Both 'lora_name' and 'lora_path' must be provided.",
            err_type="InvalidUserInput",
            status_code=HTTPStatus.BAD_REQUEST)

    # Check if the lora adapter with the given name already exists
    if any(lora_request.lora_name == request.lora_name
           for lora_request in self.lora_requests):
        # The trailing space before the implicit string concatenation
        # matters: without it the message reads "...beenloaded."
        return self.create_error_response(
            message=
            f"The lora adapter '{request.lora_name}' has already been "
            "loaded.",
            err_type="InvalidUserInput",
            status_code=HTTPStatus.BAD_REQUEST)

    return None
|
||||
|
||||
async def _check_unload_lora_adapter_request(
        self,
        request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
    """
    Validate a request to unload a LoRA adapter.

    Returns an error response (HTTP 400) when neither ``lora_name`` nor
    ``lora_int_id`` is provided, or when no registered adapter in
    ``self.lora_requests`` has the requested ``lora_name``. Returns ``None``
    when the request is acceptable.

    Note that the existence check is by ``lora_name`` only; a request that
    carries just ``lora_int_id`` passes the first check but is reported as
    not found.
    """
    # Check if either 'lora_name' or 'lora_int_id' is provided
    if not request.lora_name and not request.lora_int_id:
        return self.create_error_response(
            message=
            "either 'lora_name' or 'lora_int_id' needs to be provided.",
            err_type="InvalidUserInput",
            status_code=HTTPStatus.BAD_REQUEST)

    # Check if the lora adapter with the given name exists
    if not any(lora_request.lora_name == request.lora_name
               for lora_request in self.lora_requests):
        return self.create_error_response(
            message=
            f"The lora adapter '{request.lora_name}' cannot be found.",
            err_type="InvalidUserInput",
            status_code=HTTPStatus.BAD_REQUEST)

    return None
|
||||
|
||||
async def load_lora_adapter(
        self,
        request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
    """
    Register a new LoRA adapter so it can be selected by name.

    The request is validated first; a validation error is returned as is.
    Otherwise a fresh integer id is drawn from ``self.lora_id_counter``, a
    ``LoRARequest`` is appended to ``self.lora_requests`` and a success
    message is returned.
    """
    validation_error = await self._check_load_lora_adapter_request(request)
    if validation_error is not None:
        return validation_error

    lora_name = request.lora_name
    lora_path = request.lora_path
    new_adapter = LoRARequest(lora_name=lora_name,
                              lora_int_id=self.lora_id_counter.inc(1),
                              lora_path=lora_path)
    self.lora_requests.append(new_adapter)
    return f"Success: LoRA adapter '{lora_name}' added successfully."
|
||||
|
||||
async def unload_lora_adapter(
        self,
        request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
    """
    Remove a previously registered LoRA adapter by name.

    The request is validated first; a validation error is returned as is.
    Otherwise ``self.lora_requests`` is rebound to a new list without the
    entries whose ``lora_name`` equals the requested name, and a success
    message is returned.
    """
    validation_error = await self._check_unload_lora_adapter_request(request)
    if validation_error is not None:
        return validation_error

    lora_name = request.lora_name
    remaining = []
    for registered in self.lora_requests:
        if registered.lora_name != lora_name:
            remaining.append(registered)
    self.lora_requests = remaining
    return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
|
||||
def _is_model_supported(self, model_name):
    """Return True if ``model_name`` equals the ``name`` of any entry in
    ``self.base_model_paths``, False otherwise."""
    for base_model in self.base_model_paths:
        if base_model.name == model_name:
            return True
    return False
|
||||
157
vllm/entrypoints/openai/serving_tokenization.py
Normal file
157
vllm/entrypoints/openai/serving_tokenization.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
load_chat_template,
|
||||
parse_chat_messages_futures)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
ErrorResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTokenization(OpenAIServing):
    """Serving handler implementing the tokenize and detokenize operations."""

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
    ):
        """Set up the base serving state (without prompt adapters) and
        resolve the chat template used for chat-style tokenize requests."""
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=lora_modules,
                         prompt_adapters=None,
                         request_logger=request_logger)

        # If this is None we use the tokenizer's default chat template
        # the list of commonly-used chat template names for HF named templates
        hf_chat_templates: List[str] = ['default', 'tool_use']
        if chat_template in hf_chat_templates:
            # Named HF templates are stored by name, not loaded.
            self.chat_template = chat_template
        else:
            self.chat_template = load_chat_template(chat_template)

    async def create_tokenize(
        self,
        request: TokenizeRequest,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """
        Tokenize a completion-style prompt or a chat conversation.

        Chat requests are first rendered to a prompt through the chat
        template (Mistral or HF flavour, depending on the tokenizer);
        multi-modal data is parsed but ignored with a warning. Returns the
        token ids, their count and the model's maximum length, or the error
        response produced by the model check.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        request_id = f"tokn-{random_uuid()}"

        lora_request, prompt_adapter_request = self._maybe_get_adapters(
            request)

        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        prompt: Union[str, List[int]]
        if not isinstance(request, TokenizeChatRequest):
            prompt = request.prompt
        else:
            conversation, mm_data_future = parse_chat_messages_futures(
                request.messages, self.model_config, tokenizer)

            if await mm_data_future:
                logger.warning(
                    "Multi-modal inputs are ignored during tokenization")

            # Keyword arguments shared by both chat-template flavours.
            template_kwargs = dict(
                chat_template=self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
            )
            if isinstance(tokenizer, MistralTokenizer):
                prompt = apply_mistral_chat_template(
                    tokenizer,
                    messages=request.messages,
                    **template_kwargs,
                )
            else:
                prompt = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    **template_kwargs,
                )

        self._log_inputs(request_id,
                         prompt,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # Silently ignore prompt adapter since it does not affect tokenization

        tokenized = self._tokenize_prompt_input(
            request,
            tokenizer,
            prompt,
            add_special_tokens=request.add_special_tokens,
        )
        token_ids = tokenized["prompt_token_ids"]

        return TokenizeResponse(tokens=token_ids,
                                count=len(token_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """
        Convert ``request.tokens`` back into text.

        Returns the decoded prompt, or the error response produced by the
        model check. Raises ``NotImplementedError`` if the requested model
        resolves to a prompt adapter.
        """
        model_error = await self._check_model(request)
        if model_error is not None:
            return model_error

        request_id = f"tokn-{random_uuid()}"

        lora_request, prompt_adapter_request = self._maybe_get_adapters(
            request)

        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        if prompt_adapter_request is not None:
            raise NotImplementedError("Prompt adapter is not supported "
                                      "for tokenization")

        detokenized = self._tokenize_prompt_input(
            request,
            tokenizer,
            request.tokens,
        )

        return DetokenizeResponse(prompt=detokenized["prompt"])
|
||||
10
vllm/entrypoints/openai/tool_parsers/__init__.py
Normal file
10
vllm/entrypoints/openai/tool_parsers/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
from .internlm2_tool_parser import Internlm2ToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
|
||||
# Public names re-exported by this package.
__all__ = [
    "ToolParser",
    "ToolParserManager",
    "Hermes2ProToolParser",
    "MistralToolParser",
    "Internlm2ToolParser",
    "Llama3JsonToolParser",
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
161
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
Normal file
161
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ToolParser:
    """
    Base class for tool-call parsers. It is not meant to be used directly:
    derived classes rely on the state and helpers defined here and override
    the extraction methods.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        self.model_tokenizer = tokenizer

        # Streaming state shared by derived parsers.
        self.prev_tool_call_arr: List[Dict] = []
        # the index of the tool call that is currently being parsed
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = []

    @cached_property
    def vocab(self) -> Dict[str, int]:
        """Vocabulary returned by the tokenizer's ``get_vocab()``, computed
        once per parser instance."""
        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
        # whereas all tokenizers have .get_vocab()
        tokenizer = self.model_tokenizer
        return tokenizer.get_vocab()

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """
        Hook for adjusting the request parameters before generation.
        The base implementation returns the request unchanged.
        """
        return request

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract tool calls from a complete model-generated string.
        Used for non-streaming responses where the entire model response is
        available before sending to the client.

        Must be implemented by derived classes; the base implementation
        raises ``NotImplementedError``.
        """
        not_implemented_msg = (
            "AbstractToolParser.extract_tool_calls has not been implemented!")
        raise NotImplementedError(not_implemented_msg)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Extract tool calls from an incomplete response; for use when
        handling tool calls and streaming. This requires state - the current
        tokens/diffs, but also the information about what has previously
        been parsed and extracted (see constructor).

        Must be implemented by derived classes; the base implementation
        raises ``NotImplementedError``.
        """
        not_implemented_msg = (
            "AbstractToolParser.extract_tool_calls_streaming has not been "
            "implemented!")
        raise NotImplementedError(not_implemented_msg)
|
||||
|
||||
|
||||
class ToolParserManager:
    """Registry mapping tool-parser names to ``ToolParser`` subclasses."""

    # Shared, class-level registry: name -> parser class.
    tool_parsers: Dict[str, Type] = {}

    @classmethod
    def get_tool_parser(cls, name) -> Type:
        """
        Get tool parser by name which is registered by `register_module`.

        Raise a KeyError exception if the name is not registered.
        """
        if name not in cls.tool_parsers:
            raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
        return cls.tool_parsers[name]

    @classmethod
    def _register_module(cls,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = True) -> None:
        """Store ``module`` in the registry under one or several names.

        ``module_name`` defaults to the class name. With ``force`` False, a
        name that is already registered raises ``KeyError``; otherwise the
        existing entry is overwritten. ``TypeError`` is raised when
        ``module`` is not a ``ToolParser`` subclass.
        """
        if not issubclass(module, ToolParser):
            raise TypeError(
                f'module must be subclass of ToolParser, but got {type(module)}'
            )

        if module_name is None:
            names = [module.__name__]
        elif isinstance(module_name, str):
            names = [module_name]
        else:
            names = module_name

        for name in names:
            if not force and name in cls.tool_parsers:
                existed_module = cls.tool_parsers[name]
                raise KeyError(f'{name} is already registered '
                               f'at {existed_module.__module__}')
            cls.tool_parsers[name] = module

    @classmethod
    def register_module(
            cls,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = True,
            module: Union[Type, None] = None) -> Union[type, Callable]:
        """
        Register module with the given name or name list. It can be used as
        a decorator (with module as None) or as a normal function (with
        module as not None).
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # raise the error ahead of time
        name_is_valid = (name is None or isinstance(name, str)
                         or is_list_of(name, str))
        if not name_is_valid:
            raise TypeError(
                'name must be None, an instance of str, or a sequence of str, '
                f'but got {type(name)}')

        def _register(target):
            cls._register_module(module=target, module_name=name, force=force)
            return target

        # use it as a decorator: @x.register_module()
        if module is None:
            return _register

        # use it as a normal method: x.register_module(module=SomeClass)
        return _register(module)

    @classmethod
    def import_tool_parser(cls, plugin_path: str) -> None:
        """
        Import a user defined tool parser by the path of the file that
        defines it. The module is named after the file's base name without
        its extension; if no import spec/loader can be created, an error is
        logged and nothing is imported.
        """
        file_name = os.path.basename(plugin_path)
        module_name, _ext = os.path.splitext(file_name)

        spec = importlib.util.spec_from_file_location(module_name, plugin_path)
        if spec is None or spec.loader is None:
            logger.error("load %s from %s failed.", module_name, plugin_path)
            return

        plugin_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(plugin_module)
|
||||
338
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Normal file
338
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Normal file
@@ -0,0 +1,338 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("hermes")
class Hermes2ProToolParser(ToolParser):
    """
    Tool parser for Hermes-style output, where each tool call is a JSON
    object wrapped in ``<tool_call>`` ... ``</tool_call>`` tags. Registered
    under the name ``"hermes"``.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """
        Set up parsing state, the tool-call tags/regexes and the token ids
        of the start/end tags.

        Raises:
            ValueError: if no tokenizer is available.
            RuntimeError: if the start or end tag is missing from the
                tokenizer's vocabulary.
        """
        super().__init__(tokenizer)

        if isinstance(self.model_tokenizer, MistralTokenizer):
            logger.error(
                "Detected Mistral tokenizer when using a Hermes model")
            self.model_tokenizer = self.model_tokenizer.tokenizer

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: List[str] = [
        ]  # map what has been streamed for each tool so far to a list

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
        self.scratch_pad_regex = re.compile(
            r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")
        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        # Compare against None explicitly: ``dict.get`` returns None for a
        # missing token, while a token id of 0 is valid but falsy.
        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "Hermes 2 Pro Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!")

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract all tool calls from a complete model output.

        Text before the first ``<tool_call>`` tag is returned as content. If
        the output contains no start tag, or if parsing fails for any
        reason, the whole output is returned as plain content with
        ``tools_called=False``.
        """

        # sanity check; avoid unnecessary processing
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        else:

            try:
                # there are two possible captures - between tags, or between a
                # tag and end-of-string so the result of
                # findall is an array of tuples where one is a function call and
                # the other is empty
                function_call_tuples = (
                    self.tool_call_regex.findall(model_output))

                # load the JSON, and then use it to build the Function and
                # Tool Call
                raw_function_calls = [
                    json.loads(match[0] if match[0] else match[1])
                    for match in function_call_tuples
                ]
                tool_calls = [
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=function_call["name"],
                            # function call args are JSON but as a string
                            arguments=json.dumps(function_call["arguments"])))
                    for function_call in raw_function_calls
                ]

                content = model_output[:model_output.
                                       find(self.tool_call_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None)

            except Exception as e:
                # Best effort: malformed tool-call JSON degrades to returning
                # the raw output as content.
                logger.error("Error in extracting tool call from response %s",
                             e)
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Produce the streaming delta for the newest chunk of model output.

        The position in the stream is derived by counting tool-call
        start/end token ids in the previous and current token sequences. The
        text after the last start tag is parsed as partial JSON and diffed
        against the previously parsed state to decide what to emit.

        Returns a ``DeltaMessage`` with either text content or a tool-call
        fragment (name or argument diff), or ``None`` when nothing should be
        streamed for this chunk (including when an unexpected error occurs).
        """

        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)
        # check to see if we should be streaming a tool call - is there a
        # tool call start token anywhere in the output so far?
        if self.tool_call_start_token_id not in current_token_ids:
            logger.debug("No tool call tokens found!")
            return DeltaMessage(content=delta_text)

        try:

            # figure out where we are in the parsing by counting tool call
            # start & end tags
            prev_tool_start_count = previous_token_ids.count(
                self.tool_call_start_token_id)
            prev_tool_end_count = previous_token_ids.count(
                self.tool_call_end_token_id)
            cur_tool_start_count = current_token_ids.count(
                self.tool_call_start_token_id)
            cur_tool_end_count = current_token_ids.count(
                self.tool_call_end_token_id)

            # case: if we're generating text, OR rounding out a tool call
            if (cur_tool_start_count == cur_tool_end_count
                    and prev_tool_end_count == cur_tool_end_count):
                logger.debug("Generating text content! skipping tool parsing.")
                if delta_text != self.tool_call_end_token:
                    return DeltaMessage(content=delta_text)

            # case: if tool open & close tag counts don't match, we're doing
            # imaginary "else" block here
            # something with tools with this diff.
            # flags for partial JSON parting. exported constants from
            # "Allow" are handled via BIT MASK
            flags = Allow.ALL if self.current_tool_name_sent \
                else Allow.ALL & ~Allow.STR

            # case -- we're starting a new tool call
            if (cur_tool_start_count > cur_tool_end_count
                    and cur_tool_start_count > prev_tool_start_count):
                if len(delta_token_ids) > 1:
                    tool_call_portion = current_text.split(
                        self.tool_call_start_token)[-1]
                else:
                    tool_call_portion = None
                    delta = None

                text_portion = None

                # set cursors and state appropriately
                self.current_tool_id += 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("Starting on a new tool %s", self.current_tool_id)

            # case -- we're updating an existing tool call
            elif (cur_tool_start_count > cur_tool_end_count
                  and cur_tool_start_count == prev_tool_start_count):

                # get the portion of the text that's the tool call
                tool_call_portion = current_text.split(
                    self.tool_call_start_token)[-1]
                text_portion = None

            # case -- the current tool call is being closed.
            elif (cur_tool_start_count == cur_tool_end_count
                  and cur_tool_end_count > prev_tool_end_count):
                diff = self.prev_tool_call_arr[self.current_tool_id].get(
                    "arguments")
                if diff:
                    # flush whatever part of the final arguments JSON has
                    # not been streamed yet
                    diff = json.dumps(diff).replace(
                        self.streamed_args_for_tool[self.current_tool_id], "")
                    logger.debug(
                        "Finishing tool and found diff that had not "
                        "been streamed yet: %s", diff)
                    self.streamed_args_for_tool[self.current_tool_id] \
                        += diff
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=diff).model_dump(
                                              exclude_none=True))
                    ])

                # Nothing left to flush for the tool being closed. Return
                # explicitly: no tool call portion exists in this branch, so
                # continuing would only fail on the unbound variable below
                # and be reported as an error by the outer handler.
                logger.debug("Closing tool call with no unstreamed arguments")
                return None

            # case -- otherwise we're just generating text
            else:
                text = delta_text.replace(self.tool_call_start_token, "")
                text = text.replace(self.tool_call_end_token, "")
                delta = DeltaMessage(tool_calls=[], content=text)
                return delta

            try:

                current_tool_call = partial_json_parser.loads(
                    tool_call_portion or "{}",
                    flags) if tool_call_portion else None
                logger.debug("Parsed tool call %s", current_tool_call)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # case - we haven't sent the tool name yet. If it's available, send
            # it. otherwise, wait until it's available.
            if not self.current_tool_name_sent:
                if current_tool_call is None:
                    # No tool call text to parse yet (e.g. only the start
                    # tag has arrived); wait for more tokens instead of
                    # failing on ``None.get``.
                    return None
                function_name: Union[str, None] = current_tool_call.get("name")
                if function_name:
                    self.current_tool_name_sent = True
                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                else:
                    return None
            # case -- otherwise, send the tool call delta

            # if the tool call portion is None, send the delta as text
            if tool_call_portion is None:
                # if there's text but not tool calls, send that -
                # otherwise None to skip chunk
                delta = DeltaMessage(content=delta_text) \
                    if text_portion is not None else None
                return delta

            # now, the nitty-gritty of tool calls
            # now we have the portion to parse as tool call.

            logger.debug("Trying to parse current tool call with ID %s",
                         self.current_tool_id)

            # if we're starting a new tool call, push an empty object in as
            # a placeholder for the arguments
            if len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})

            # main logic for tool parsing here - compare prev. partially-parsed
            # JSON to the current partially-parsed JSON
            prev_arguments = (
                self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
            cur_arguments = current_tool_call.get("arguments")

            logger.debug("diffing old arguments: %s", prev_arguments)
            logger.debug("against new ones: %s", cur_arguments)

            # case -- no arguments have been created yet. skip sending a delta.
            if not cur_arguments and not prev_arguments:
                logger.debug("Skipping text %s - no arguments", delta_text)
                delta = None

            # case -- prev arguments are defined, but none are now.
            # probably impossible, but not a fatal error - just keep going
            elif not cur_arguments and prev_arguments:
                logger.error("should be impossible to have arguments reset "
                             "mid-call. skipping streaming anything.")
                delta = None

            # case -- we now have the first info about arguments available from
            # autocompleting the JSON
            elif cur_arguments and not prev_arguments:

                cur_arguments_json = json.dumps(cur_arguments)
                logger.debug("finding %s in %s", delta_text,
                             cur_arguments_json)

                # get the location where previous args differ from current
                args_delta_start_loc = cur_arguments_json.index(delta_text) \
                    + len(delta_text)

                # use that to find the actual delta
                arguments_delta = cur_arguments_json[:args_delta_start_loc]
                logger.debug("First tokens in arguments received: %s",
                             arguments_delta)

                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=arguments_delta).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[self.current_tool_id] \
                    += arguments_delta

            # last case -- we have an update to existing arguments.
            elif cur_arguments and prev_arguments:

                cur_args_json = json.dumps(cur_arguments)
                prev_args_json = json.dumps(prev_arguments)
                logger.debug("Searching for diff between\n%s", cur_args_json)
                logger.debug("and\n%s", prev_args_json)
                argument_diff = extract_intermediate_diff(
                    cur_args_json, prev_args_json)
                logger.debug("got argument diff %s", argument_diff)
                delta = DeltaMessage(tool_calls=[
                    DeltaToolCall(index=self.current_tool_id,
                                  function=DeltaFunctionCall(
                                      arguments=argument_diff).model_dump(
                                          exclude_none=True))
                ])
                self.streamed_args_for_tool[self.current_tool_id] \
                    += argument_diff

            # handle saving the state for the current tool into
            # the "prev" list for use in diffing for the next iteration
            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
                self.prev_tool_call_arr[self.current_tool_id] = \
                    current_tool_call
            else:
                self.prev_tool_call_arr.append(current_tool_call)

            return delta

        except Exception as e:
            # Deliberate best effort: a parsing failure must not break the
            # stream, so log it and emit nothing for this chunk.
            logger.error("Error trying to handle streaming tool call: %s", e)
            return None  # do not stream a delta. skip this token ID.
|
||||
208
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
208
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import json
|
||||
from typing import Dict, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module(["internlm"])
class Internlm2ToolParser(ToolParser):
    """
    Tool call parser for InternLM2-style output, where a tool call is a JSON
    object emitted between the '<|action_start|><|plugin|>' and
    '<|action_end|>' markers. Only a single tool call per response is
    handled; parallel tool calls are not supported.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # offset into the generated text up to which plain content has
        # already been handed back to the caller while streaming
        self.position = 0

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """
        Disable special-token skipping when tools are in play and return the
        (mutated) request.
        """
        if request.tools and request.tool_choice != 'none':
            # do not skip special tokens because internlm uses the special
            # tokens to indicate the start and end of the tool call
            # information.
            request.skip_special_tokens = False
        return request

    def get_argments(self, obj):
        """
        Return the value stored in `obj` under "parameters" or, failing
        that, under "arguments". Returns None when neither key is present.
        """
        if "parameters" in obj:
            return obj.get("parameters")
        elif "arguments" in obj:
            return obj.get("arguments")
        return None

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Turn the newest chunk of generated text into a streaming delta.

        Returns a DeltaMessage carrying plain content, the tool name, or an
        increment of the JSON-encoded arguments; returns None when nothing
        should be streamed for this chunk (not enough text to parse yet, or
        an error occurred while parsing).
        """
        if '<|action_start|>' not in current_text:
            self.position = len(current_text)
            return DeltaMessage(content=delta_text)
        # if the tool call has been sent, return an empty delta message
        # to make sure the finish_reason will be sent correctly.
        if self.current_tool_id > 0:
            return DeltaMessage(content='')

        last_pos = self.position
        if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
            return None

        new_delta = current_text[last_pos:]
        # partition() splits on the first marker only, so text that repeats
        # the marker cannot break the two-way unpacking.
        text, _, action = new_delta.partition('<|action_start|><|plugin|>')

        # flush any plain content that precedes the tool call first
        if len(text) > 0:
            self.position = self.position + len(text)
            return DeltaMessage(content=text)

        action = action.strip()
        action = action.split('<|action_end|>')[0]

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR

        try:
            parsable_arr = action

            # tool calls are generated as a single object in internlm2;
            # parallel tool calls are not supported
            try:
                tool_call_arr: Dict = partial_json_parser.loads(
                    parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = tool_call_arr.get("name")
                if function_name:
                    self.current_tool_id = self.current_tool_id + 1
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                    self.streamed_args_for_tool.append("")
                else:
                    delta = None
            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                prev_arguments = self.get_argments(
                    self.prev_tool_call_arr[self.current_tool_id])
                cur_arguments = self.get_argments(tool_call_arr)

                # no arguments generated yet
                if not cur_arguments and not prev_arguments:
                    delta = None
                # will never happen
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset "
                        "mid-arguments")
                    delta = None
                # first time we see arguments
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments)

                    # stream the serialized arguments up to and including
                    # the newest text; index() raises ValueError when the
                    # delta is not found, which the handler below turns
                    # into a skipped chunk
                    arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                         index(delta_text) +
                                                         len(delta_text)]
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=arguments_delta).
                                      model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += arguments_delta
                # both prev and cur arguments: send only the increment
                else:
                    cur_args_json = json.dumps(cur_arguments)
                    prev_args_json = json.dumps(prev_arguments)

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json)

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=argument_diff).model_dump(
                                              exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += argument_diff

            # remember the current parse (normalized to the "arguments"
            # key) so the next chunk can be diffed against it
            tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
            self.prev_tool_call_arr = [tool_call_arr]
            return delta
        except Exception:
            # logger.exception keeps the traceback for debugging
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool call from a complete model response.

        Returns tools_called=False when the output has no tool call marker,
        when the tool call JSON cannot be decoded, or when the named tool is
        not among `request.tools`.
        """
        text = model_output
        tools = request.tools
        if '<|action_start|><|plugin|>' not in text:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=text)

        # split on the first marker only so repeated markers cannot raise
        text, _, action = text.partition('<|action_start|><|plugin|>')
        action = action.split('<|action_end|>')[0]
        # drop anything before the opening brace of the JSON object
        action = action[action.find('{'):]

        try:
            action_dict = json.loads(action)
            name = action_dict['name']
            parameters = json.dumps(
                action_dict.get('parameters',
                                action_dict.get('arguments', {})))
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # treat undecodable tool call text as regular content
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        # a tool that was not offered in the request is not a tool call
        if not tools or name not in [t.function.name for t in tools]:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=text)

        tool_calls = [
            ToolCall(function=FunctionCall(name=name, arguments=parameters))
        ]
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=text if len(text) > 0 else None)
|
||||
277
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
Normal file
277
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import json
|
||||
import re
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
|
||||
def partial_json_loads(input_str, flags):
    """
    Parse `input_str` as (possibly partial) JSON.

    Returns a tuple of the parsed object and the number of characters
    consumed. When the partial parser reports trailing "Extra data", the
    leading JSON value is decoded with JSONDecoder.raw_decode instead; any
    other JSONDecodeError is re-raised.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if "Extra data" not in err.msg:
            raise
        return JSONDecoder().raw_decode(input_str)
    return (parsed, len(input_str))
|
||||
|
||||
|
||||
def is_complete_json(input_str):
    """Return True when `input_str` decodes with json.loads, else False."""
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    return True
|
||||
|
||||
|
||||
@ToolParserManager.register_module("llama3_json")
class Llama3JsonToolParser(ToolParser):
    """
    Tool call parser for Llama 3.1 models intended for use with the
    examples/tool_chat_template_llama.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser llama3_json are
    all set
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = [
        ]  # map what has been streamed for each tool so far to a list
        self.bot_token = "<|python_tag|>"
        self.bot_token_id = tokenizer.encode(self.bot_token,
                                             add_special_tokens=False)[0]
        self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        The output is treated as tool calls only when it starts with the
        <|python_tag|> token or with '{'. If decoding fails, the whole
        output is returned as plain content with tools_called=False.
        """
        # case -- if a tool call token is not present, return a text response
        if not model_output.startswith((self.bot_token, '{')):
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            # load the JSON, and then use it to build the Function and
            # Tool Call
            dec = JSONDecoder()
            function_call_arr = []

            # depending on the prompt format the Llama model may or may not
            # prefix the output with the <|python_tag|> token
            start_idx = len(self.bot_token) if model_output.startswith(
                self.bot_token) else 0
            while start_idx < len(model_output):
                (obj, end_idx) = dec.raw_decode(model_output[start_idx:])
                # skip the decoded object plus a '; ' separator
                start_idx += end_idx + len('; ')
                function_call_arr.append(obj)

            tool_calls: List[ToolCall] = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(
                            raw_function_call["arguments"]
                            if "arguments" in raw_function_call else
                            raw_function_call["parameters"])))
                for raw_function_call in function_call_arr
            ]

            # the whole output was consumed as tool calls, so there is no
            # plain content to return
            return ExtractedToolCallInformation(tools_called=True,
                                                tool_calls=tool_calls,
                                                content=None)

        except Exception:
            # logger.exception keeps the traceback for debugging
            logger.exception("Error in extracting tool call from response.")
            # return information to just treat the tool call as regular JSON
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Turn the newest chunk of generated text into a streaming delta.

        Returns a DeltaMessage carrying plain content, a tool name, or an
        increment of the JSON-encoded arguments; returns None when nothing
        should be streamed for this chunk.
        """

        if not current_text.startswith((self.bot_token, '{')):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            # is_complete[i] records whether the text of tool call i is
            # already complete JSON
            is_complete = []
            try:
                # depending on the prompt format the Llama model may or may
                # not prefix the output with the <|python_tag|> token
                start_idx = len(self.bot_token) if current_text.startswith(
                    self.bot_token) else 0
                while start_idx < len(current_text):
                    (obj,
                     end_idx) = partial_json_loads(current_text[start_idx:],
                                                   flags)
                    is_complete.append(
                        is_complete_json(current_text[start_idx:start_idx +
                                                      end_idx]))
                    start_idx += end_idx + len('; ')
                    # depending on the prompt Llama can use
                    # either arguments or parameters
                    if "parameters" in obj:
                        assert "arguments" not in obj, \
                            "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # select as the current tool call the one we're on the state at
            current_tool_call: Dict = tool_call_arr[self.current_tool_id]

            # case: we are starting a new tool in the array
            #   -> length has moved past the cursor
            if len(tool_call_arr) > self.current_tool_id + 1:

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                delta = None
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(
                            self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(
                        self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[
                        self.current_tool_id].get("arguments")

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # the call is complete: send everything not yet sent
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            # only the prefix shared with the previous
                            # parse is sent; the rest may be closers the
                            # partial parser filled in
                            prefix = find_common_prefix(
                                prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=argument_diff).
                                          model_dump(exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            # logger.exception keeps the traceback for debugging
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
|
||||
306
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Normal file
306
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import json
|
||||
import re
|
||||
from random import choices
|
||||
from string import ascii_letters, digits
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ALPHANUMERIC = ascii_letters + digits
|
||||
|
||||
|
||||
class MistralToolCall(ToolCall):
    # ToolCall variant whose `id` defaults to a freshly generated random id
    # (see generate_random_id). The lambda defers the lookup of
    # MistralToolCall until the default is actually needed.
    id: str = Field(
        default_factory=lambda: MistralToolCall.generate_random_id())

    @staticmethod
    def generate_random_id():
        # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
        # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
        id_chars = choices(ALPHANUMERIC, k=9)
        return "".join(id_chars)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser):
    """
    Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
    examples/tool_chat_template_mistral.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        if not isinstance(self.model_tokenizer, MistralTokenizer):
            logger.info("Non-Mistral tokenizer detected when using a Mistral "
                        "model...")

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: List[Dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = [
        ]  # map what has been streamed for each tool so far to a list
        self.bot_token = "[TOOL_CALLS]"
        self.bot_token_id = self.vocab.get(self.bot_token)
        self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
        # compare against None explicitly: a token id of 0 is falsy but
        # would still be a valid vocabulary entry
        if self.bot_token_id is None:
            raise RuntimeError(
                "Mistral Tool Parser could not locate the tool call token in "
                "the tokenizer!")

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        The first JSON array found after removing the [TOOL_CALLS] token is
        decoded with json.loads, so the tool calls must be valid JSON. If
        extraction fails, the whole output is returned as plain content
        with tools_called=False.
        """

        # case -- if a tool call token is not present, return a text response
        if self.bot_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)
        try:

            # use a regex to find the tool call after removing the BOT token
            raw_tool_call = self.tool_call_regex.findall(
                model_output.replace(self.bot_token, ""))[0]

            # load the JSON, and then use it to build the Function and
            # Tool Call
            function_call_arr = json.loads(raw_tool_call)
            tool_calls: List[MistralToolCall] = [
                MistralToolCall(
                    type="function",
                    function=FunctionCall(
                        name=raw_function_call["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(raw_function_call["arguments"])))
                for raw_function_call in function_call_arr
            ]

            # get any content before the tool call
            content = model_output.split(self.bot_token)[0]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if len(content) > 0 else None)

        except Exception:
            # logger.exception keeps the traceback for debugging
            logger.exception("Error in extracting tool call from response.")
            # return information to just treat the tool call as regular JSON
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Turn the newest chunk of generated text into a streaming delta.

        Returns a DeltaMessage carrying plain content, a tool name, or an
        increment of the JSON-encoded arguments; returns None when nothing
        should be streamed for this chunk.
        """

        # if the tool call token is not in the tokens generated so far, append
        # output to contents since it's not a tool
        if self.bot_token not in current_text:
            return DeltaMessage(content=delta_text)

        # if the tool call token ID IS in the tokens generated so far, that
        # means we're parsing as tool calls now

        # handle if we detected the BOT token which means the start of tool
        # calling
        if (self.bot_token_id in delta_token_ids
                and len(delta_token_ids) == 1):
            # if it's the only token, return None, so we don't send a chat
            # completion and don't send a control token
            return None

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR
        try:

            # keep only the text after the last BOT token; that is the
            # (possibly partial) JSON array of tool calls
            parsable_arr = current_text.split(self.bot_token)[-1]

            # tool calls are generated in an array, so do partial JSON
            # parsing on the entire array
            try:
                tool_call_arr: List[Dict] = partial_json_parser.loads(
                    parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # case -- if no tokens have been streamed for the tool, e.g.
            # only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # select as the current tool call the one we're on the state at
            current_tool_call: Dict = tool_call_arr[self.current_tool_id]

            # case: we are starting a new tool in the array
            #   -> length has moved past the cursor
            if len(tool_call_arr) > self.current_tool_id + 1:

                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                delta = None
                if self.current_tool_id >= 0:
                    final_arguments = current_tool_call.get("arguments")

                    if final_arguments:
                        # send only what follows the part already streamed.
                        # Slice by length instead of str.replace(), which
                        # would delete every occurrence of the streamed
                        # text, not just the leading one.
                        sent = len(
                            self.streamed_args_for_tool[self.current_tool_id])
                        diff = json.dumps(final_arguments)[sent:]
                        delta = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=diff).model_dump(
                                                  exclude_none=True))
                        ])
                        self.streamed_args_for_tool[
                            self.current_tool_id] += diff
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # case: update an existing tool - this is handled below

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:

                prev_arguments = self.prev_tool_call_arr[
                    self.current_tool_id].get("arguments")
                cur_arguments = current_tool_call.get("arguments")

                # the delta is matched against json.dumps output, which
                # uses double quotes
                new_text = delta_text.replace("\'", "\"")

                if not cur_arguments and not prev_arguments:

                    delta = None
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset "
                        "mid-arguments")
                    delta = None
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments)
                    logger.debug("finding %s in %s", new_text,
                                 cur_arguments_json)

                    # stream the serialized arguments up to and including
                    # the newest text; index() raises ValueError when the
                    # delta is not found, which the handler below turns
                    # into a skipped chunk
                    arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                         index(new_text) +
                                                         len(new_text)]
                    logger.debug("First tokens in arguments received: %s",
                                 arguments_delta)
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=arguments_delta).
                                      model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += arguments_delta

                else:
                    # both previous and current arguments exist: send only
                    # the increment between the two serializations
                    cur_args_json = json.dumps(cur_arguments)
                    prev_args_json = json.dumps(prev_arguments)
                    logger.debug("Searching for diff between \n%s\n%s",
                                 cur_args_json, prev_args_json)

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json)
                    logger.debug("got arguments diff: %s", argument_diff)
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=argument_diff).model_dump(
                                              exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += argument_diff

            # remember the current parse so the next chunk can be diffed
            # against it
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            # logger.exception keeps the traceback for debugging
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
|
||||
87
vllm/entrypoints/openai/tool_parsers/utils.py
Normal file
87
vllm/entrypoints/openai/tool_parsers/utils.py
Normal file
@@ -0,0 +1,87 @@
|
||||
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Finds a common prefix that is shared between two strings, if there is one.
    Order of arguments is NOT important.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'
    """
    # count matching leading characters, then slice once instead of growing
    # the result one character at a time
    common_len = 0
    for ch1, ch2 in zip(s1, s2):
        if ch1 != ch2:
            break
        common_len += 1
    return s1[:common_len]
|
||||
|
||||
|
||||
def find_common_suffix(s1: str, s2: str) -> str:
    """
    Finds a common suffix shared between two strings, if there is one. Order of
    arguments is NOT important.
    Stops when the suffix ends OR it hits an alphanumeric character

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    # count matching non-alphanumeric trailing characters, then slice once
    # instead of prepending to the result one character at a time
    common_len = 0
    for ch1, ch2 in zip(reversed(s1), reversed(s2)):
        if ch1 != ch2 or ch1.isalnum():
            break
        common_len += 1
    # slice from an explicit start index: s1[-0:] would be the whole string
    return s1[len(s1) - common_len:]
|
||||
|
||||
|
||||
def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Given two strings, extract the difference in the middle between two strings
    that are known to have a common prefix and/or suffix.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely. The order of arguments IS
    important - the new version of the partially-parsed JSON must be the first
    argument, and the second argument must be from the previous generation.

    What it returns, is tokens that should be streamed to the client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
        -> 'ple'

    """
    shared_tail = find_common_suffix(curr, old)
    tail_len = len(shared_tail)

    # both strings end with the shared tail, so trimming it is a plain slice
    old_head = old[:len(old) - tail_len]
    shared_head = find_common_prefix(curr, old_head)

    middle = curr[:len(curr) - tail_len]
    # strip the shared head only when it actually leads the remaining text
    if shared_head and middle.startswith(shared_head):
        middle = middle[len(shared_head):]

    return middle
|
||||
|
||||
|
||||
def find_all_indices(string, substring):
    """
    Find all (starting) indices of a substring in a given string. Useful for
    tool call extraction
    """
    indices = []
    cursor = string.find(substring)
    while cursor != -1:
        indices.append(cursor)
        # resume one past the last hit so overlapping matches are found too
        cursor = string.find(substring, cursor + 1)
    return indices
|
||||
Reference in New Issue
Block a user