First commit
This commit is contained in:
0
vllm/entrypoints/openai/__init__.py
Normal file
0
vllm/entrypoints/openai/__init__.py
Normal file
BIN
vllm/entrypoints/openai/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/api_server.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/api_server.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/cli_args.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/cli_args.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/protocol.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/protocol.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/run_batch.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/run_batch.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/entrypoints/openai/__pycache__/serving_chat.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/serving_chat.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
585
vllm/entrypoints/openai/api_server.py
Normal file
585
vllm/entrypoints/openai/api_server.py
Normal file
@@ -0,0 +1,585 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import socket
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Set
|
||||
|
||||
import uvloop
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from starlette.datastructures import State
|
||||
from starlette.routing import Mount
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.engine.multiprocessing.engine import run_mp_engine
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||
validate_parsed_serve_args)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
UnloadLoraAdapterRequest)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
|
||||
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||
|
||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||
logger = init_logger('vllm.entrypoints.openai.api_server')
|
||||
|
||||
_running_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
try:
|
||||
if app.state.log_stats:
|
||||
engine_client: EngineClient = app.state.engine_client
|
||||
|
||||
async def _force_log():
|
||||
while True:
|
||||
await asyncio.sleep(10.)
|
||||
await engine_client.do_log_stats()
|
||||
|
||||
task = asyncio.create_task(_force_log())
|
||||
_running_tasks.add(task)
|
||||
task.add_done_callback(_running_tasks.remove)
|
||||
else:
|
||||
task = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
finally:
|
||||
# Ensure app state including engine ref is gc'd
|
||||
del app.state
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def build_async_engine_client(
|
||||
args: Namespace) -> AsyncIterator[EngineClient]:
|
||||
|
||||
# Context manager to handle engine_client lifecycle
|
||||
# Ensures everything is shutdown and cleaned up on error/exit
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, args.disable_frontend_multiprocessing) as engine:
|
||||
yield engine
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def build_async_engine_client_from_engine_args(
|
||||
engine_args: AsyncEngineArgs,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
) -> AsyncIterator[EngineClient]:
|
||||
"""
|
||||
Create EngineClient, either:
|
||||
- in-process using the AsyncLLMEngine Directly
|
||||
- multiprocess using AsyncLLMEngine RPC
|
||||
|
||||
Returns the Client or None if the creation failed.
|
||||
"""
|
||||
|
||||
# Fall back
|
||||
# TODO: fill out feature matrix.
|
||||
if (MQLLMEngineClient.is_unsupported_config(engine_args)
|
||||
or disable_frontend_multiprocessing):
|
||||
engine_config = engine_args.create_engine_config()
|
||||
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
|
||||
"uses_ray", False)
|
||||
|
||||
build_engine = partial(AsyncLLMEngine.from_engine_args,
|
||||
engine_args=engine_args,
|
||||
engine_config=engine_config,
|
||||
usage_context=UsageContext.OPENAI_API_SERVER)
|
||||
if uses_ray:
|
||||
# Must run in main thread with ray for its signal handlers to work
|
||||
engine_client = build_engine()
|
||||
else:
|
||||
engine_client = await asyncio.get_running_loop().run_in_executor(
|
||||
None, build_engine)
|
||||
|
||||
yield engine_client
|
||||
return
|
||||
|
||||
# Otherwise, use the multiprocessing AsyncLLMEngine.
|
||||
else:
|
||||
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
|
||||
# Make TemporaryDirectory for prometheus multiprocessing
|
||||
# Note: global TemporaryDirectory will be automatically
|
||||
# cleaned up upon exit.
|
||||
global prometheus_multiproc_dir
|
||||
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
|
||||
os.environ[
|
||||
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
|
||||
else:
|
||||
logger.warning(
|
||||
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
|
||||
"This directory must be wiped between vLLM runs or "
|
||||
"you will find inaccurate metrics. Unset the variable "
|
||||
"and vLLM will properly handle cleanup.")
|
||||
|
||||
# Select random path for IPC.
|
||||
ipc_path = get_open_zmq_ipc_path()
|
||||
logger.info("Multiprocessing frontend to use %s for IPC Path.",
|
||||
ipc_path)
|
||||
|
||||
# Start RPCServer in separate process (holds the LLMEngine).
|
||||
# the current process might have CUDA context,
|
||||
# so we need to spawn a new process
|
||||
context = multiprocessing.get_context("spawn")
|
||||
|
||||
engine_process = context.Process(target=run_mp_engine,
|
||||
args=(engine_args,
|
||||
UsageContext.OPENAI_API_SERVER,
|
||||
ipc_path))
|
||||
engine_process.start()
|
||||
logger.info("Started engine process with PID %d", engine_process.pid)
|
||||
|
||||
# Build RPCClient, which conforms to EngineClient Protocol.
|
||||
# NOTE: Actually, this is not true yet. We still need to support
|
||||
# embedding models via RPC (see TODO above)
|
||||
engine_config = engine_args.create_engine_config()
|
||||
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
await mp_engine_client.setup()
|
||||
break
|
||||
except TimeoutError:
|
||||
if not engine_process.is_alive():
|
||||
raise RuntimeError(
|
||||
"Engine process failed to start") from None
|
||||
|
||||
yield mp_engine_client # type: ignore[misc]
|
||||
finally:
|
||||
# Ensure rpc server process was terminated
|
||||
engine_process.terminate()
|
||||
|
||||
# Close all open connections to the backend
|
||||
mp_engine_client.close()
|
||||
|
||||
# Wait for engine process to join
|
||||
engine_process.join(4)
|
||||
if engine_process.exitcode is None:
|
||||
# Kill if taking longer than 5 seconds to stop
|
||||
engine_process.kill()
|
||||
|
||||
# Lazy import for prometheus multiprocessing.
|
||||
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||
# before prometheus_client is imported.
|
||||
# See https://prometheus.github.io/client_python/multiprocess/
|
||||
from prometheus_client import multiprocess
|
||||
multiprocess.mark_process_dead(engine_process.pid)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def mount_metrics(app: FastAPI):
|
||||
# Lazy import for prometheus multiprocessing.
|
||||
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||
# before prometheus_client is imported.
|
||||
# See https://prometheus.github.io/client_python/multiprocess/
|
||||
from prometheus_client import (CollectorRegistry, make_asgi_app,
|
||||
multiprocess)
|
||||
|
||||
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
|
||||
if prometheus_multiproc_dir_path is not None:
|
||||
logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
|
||||
prometheus_multiproc_dir_path)
|
||||
registry = CollectorRegistry()
|
||||
multiprocess.MultiProcessCollector(registry)
|
||||
|
||||
# Add prometheus asgi middleware to route /metrics requests
|
||||
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
|
||||
else:
|
||||
# Add prometheus asgi middleware to route /metrics requests
|
||||
metrics_route = Mount("/metrics", make_asgi_app())
|
||||
|
||||
# Workaround for 307 Redirect for /metrics
|
||||
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
|
||||
app.routes.append(metrics_route)
|
||||
|
||||
|
||||
def chat(request: Request) -> OpenAIServingChat:
|
||||
return request.app.state.openai_serving_chat
|
||||
|
||||
|
||||
def completion(request: Request) -> OpenAIServingCompletion:
|
||||
return request.app.state.openai_serving_completion
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
def embedding(request: Request) -> OpenAIServingEmbedding:
|
||||
return request.app.state.openai_serving_embedding
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health(raw_request: Request) -> Response:
|
||||
"""Health check."""
|
||||
await engine_client(raw_request).check_health()
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.post("/tokenize")
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
generator = await tokenization(raw_request).create_tokenize(request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
elif isinstance(generator, TokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/detokenize")
|
||||
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
generator = await tokenization(raw_request).create_detokenize(request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
elif isinstance(generator, DetokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def show_available_models(raw_request: Request):
|
||||
models = await completion(raw_request).show_available_models()
|
||||
return JSONResponse(content=models.model_dump())
|
||||
|
||||
|
||||
@router.get("/version")
|
||||
async def show_version():
|
||||
ver = {"version": VLLM_VERSION}
|
||||
return JSONResponse(content=ver)
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def create_chat_completion(request: ChatCompletionRequest,
|
||||
raw_request: Request):
|
||||
|
||||
generator = await chat(raw_request).create_chat_completion(
|
||||
request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
|
||||
elif isinstance(generator, ChatCompletionResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/completions")
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
generator = await completion(raw_request).create_completion(
|
||||
request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
elif isinstance(generator, CompletionResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
generator = await embedding(raw_request).create_embedding(
|
||||
request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
elif isinstance(generator, EmbeddingResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
logger.warning(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!")
|
||||
|
||||
@router.post("/start_profile")
|
||||
async def start_profile(raw_request: Request):
|
||||
logger.info("Starting profiler...")
|
||||
await engine_client(raw_request).start_profile()
|
||||
logger.info("Profiler started.")
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/stop_profile")
|
||||
async def stop_profile(raw_request: Request):
|
||||
logger.info("Stopping profiler...")
|
||||
await engine_client(raw_request).stop_profile()
|
||||
logger.info("Profiler stopped.")
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
logger.warning(
|
||||
"Lora dynamic loading & unloading is enabled in the API server. "
|
||||
"This should ONLY be used for local development!")
|
||||
|
||||
@router.post("/v1/load_lora_adapter")
|
||||
async def load_lora_adapter(request: LoadLoraAdapterRequest,
|
||||
raw_request: Request):
|
||||
response = await chat(raw_request).load_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
|
||||
response = await completion(raw_request).load_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
|
||||
return Response(status_code=200, content=response)
|
||||
|
||||
@router.post("/v1/unload_lora_adapter")
|
||||
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
|
||||
raw_request: Request):
|
||||
response = await chat(raw_request).unload_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
|
||||
response = await completion(raw_request).unload_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
|
||||
return Response(status_code=200, content=response)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
|
||||
if args.disable_fastapi_docs:
|
||||
app = FastAPI(openapi_url=None,
|
||||
docs_url=None,
|
||||
redoc_url=None,
|
||||
lifespan=lifespan)
|
||||
else:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(router)
|
||||
app.root_path = args.root_path
|
||||
|
||||
mount_metrics(app)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=args.allowed_origins,
|
||||
allow_credentials=args.allow_credentials,
|
||||
allow_methods=args.allowed_methods,
|
||||
allow_headers=args.allowed_headers,
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(_, exc):
|
||||
chat = app.state.openai_serving_chat
|
||||
err = chat.create_error_response(message=str(exc))
|
||||
return JSONResponse(err.model_dump(),
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
if token := envs.VLLM_API_KEY or args.api_key:
|
||||
|
||||
@app.middleware("http")
|
||||
async def authentication(request: Request, call_next):
|
||||
root_path = "" if args.root_path is None else args.root_path
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
if not request.url.path.startswith(f"{root_path}/v1"):
|
||||
return await call_next(request)
|
||||
if request.headers.get("Authorization") != "Bearer " + token:
|
||||
return JSONResponse(content={"error": "Unauthorized"},
|
||||
status_code=401)
|
||||
return await call_next(request)
|
||||
|
||||
for middleware in args.middleware:
|
||||
module_path, object_name = middleware.rsplit(".", 1)
|
||||
imported = getattr(importlib.import_module(module_path), object_name)
|
||||
if inspect.isclass(imported):
|
||||
app.add_middleware(imported)
|
||||
elif inspect.iscoroutinefunction(imported):
|
||||
app.middleware("http")(imported)
|
||||
else:
|
||||
raise ValueError(f"Invalid middleware {middleware}. "
|
||||
f"Must be a function or a class.")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def init_app_state(
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
state: State,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
if args.disable_log_requests:
|
||||
request_logger = None
|
||||
else:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model)
|
||||
for name in served_model_names
|
||||
]
|
||||
|
||||
state.engine_client = engine_client
|
||||
state.log_stats = not args.disable_log_stats
|
||||
|
||||
state.openai_serving_chat = OpenAIServingChat(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
args.response_role,
|
||||
lora_modules=args.lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
chat_template=args.chat_template,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser)
|
||||
state.openai_serving_completion = OpenAIServingCompletion(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
prompt_adapters=args.prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
)
|
||||
state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
lora_modules=args.lora_modules,
|
||||
request_logger=request_logger,
|
||||
chat_template=args.chat_template,
|
||||
)
|
||||
|
||||
|
||||
async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
valide_tool_parses = ToolParserManager.tool_parsers.keys()
|
||||
if args.enable_auto_tool_choice \
|
||||
and args.tool_call_parser not in valide_tool_parses:
|
||||
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
|
||||
f"(chose from {{ {','.join(valide_tool_parses)} }})")
|
||||
|
||||
# workaround to make sure that we bind the port before the engine is set up.
|
||||
# This avoids race conditions with ray.
|
||||
# see https://github.com/vllm-project/vllm/issues/8204
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.bind(("", args.port))
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm while initializing
|
||||
raise KeyboardInterrupt("terminated")
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
async with build_async_engine_client(args) as engine_client:
|
||||
app = build_app(args)
|
||||
|
||||
model_config = await engine_client.get_model_config()
|
||||
init_app_state(engine_client, model_config, app.state, args)
|
||||
|
||||
shutdown_task = await serve_http(
|
||||
app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.uvicorn_log_level,
|
||||
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
|
||||
ssl_keyfile=args.ssl_keyfile,
|
||||
ssl_certfile=args.ssl_certfile,
|
||||
ssl_ca_certs=args.ssl_ca_certs,
|
||||
ssl_cert_reqs=args.ssl_cert_reqs,
|
||||
fd=sock.fileno(),
|
||||
**uvicorn_kwargs,
|
||||
)
|
||||
|
||||
# NB: Await server shutdown only after the backend context is exited
|
||||
await shutdown_task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# NOTE(simon):
|
||||
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
parser = make_arg_parser(parser)
|
||||
args = parser.parse_args()
|
||||
validate_parsed_serve_args(args)
|
||||
|
||||
uvloop.run(run_server(args))
|
||||
252
vllm/entrypoints/openai/cli_args.py
Normal file
252
vllm/entrypoints/openai/cli_args.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
This file contains the command line arguments for the vLLM's
|
||||
OpenAI-compatible server. It is kept in a separate file for documentation
|
||||
purposes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.chat_utils import validate_chat_template
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
lora_list: List[LoRAModulePath] = []
|
||||
for item in values:
|
||||
if item in [None, '']: # Skip if item is None or empty string
|
||||
continue
|
||||
if '=' in item and ',' not in item: # Old format: name=path
|
||||
name, path = item.split('=')
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
else: # Assume JSON format
|
||||
try:
|
||||
lora_dict = json.loads(item)
|
||||
lora = LoRAModulePath(**lora_dict)
|
||||
lora_list.append(lora)
|
||||
except json.JSONDecodeError:
|
||||
parser.error(
|
||||
f"Invalid JSON format for --lora-modules: {item}")
|
||||
except TypeError as e:
|
||||
parser.error(
|
||||
f"Invalid fields for --lora-modules: {item} - {str(e)}"
|
||||
)
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
class PromptAdapterParserAction(argparse.Action):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
adapter_list: List[PromptAdapterPath] = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
adapter_list.append(PromptAdapterPath(name, path))
|
||||
setattr(namespace, self.dest, adapter_list)
|
||||
|
||||
|
||||
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
parser.add_argument("--host",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="host name")
|
||||
parser.add_argument("--port", type=int, default=8000, help="port number")
|
||||
parser.add_argument(
|
||||
"--uvicorn-log-level",
|
||||
type=str,
|
||||
default="info",
|
||||
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
|
||||
help="log level for uvicorn")
|
||||
parser.add_argument("--allow-credentials",
|
||||
action="store_true",
|
||||
help="allow credentials")
|
||||
parser.add_argument("--allowed-origins",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed origins")
|
||||
parser.add_argument("--allowed-methods",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed methods")
|
||||
parser.add_argument("--allowed-headers",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed headers")
|
||||
parser.add_argument("--api-key",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="If provided, the server will require this key "
|
||||
"to be presented in the header.")
|
||||
parser.add_argument(
|
||||
"--lora-modules",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=LoRAParserAction,
|
||||
help="LoRA module configurations in either 'name=path' format"
|
||||
"or JSON format. "
|
||||
"Example (old format): 'name=path' "
|
||||
"Example (new format): "
|
||||
"'{\"name\": \"name\", \"local_path\": \"path\", "
|
||||
"\"base_model_name\": \"id\"}'")
|
||||
parser.add_argument(
|
||||
"--prompt-adapters",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=PromptAdapterParserAction,
|
||||
help="Prompt adapter configurations in the format name=path. "
|
||||
"Multiple adapters can be specified.")
|
||||
parser.add_argument("--chat-template",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="The file path to the chat template, "
|
||||
"or the template in single-line form "
|
||||
"for the specified model")
|
||||
parser.add_argument("--response-role",
|
||||
type=nullable_str,
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=true`.")
|
||||
parser.add_argument("--ssl-keyfile",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="The file path to the SSL key file")
|
||||
parser.add_argument("--ssl-certfile",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="The file path to the SSL cert file")
|
||||
parser.add_argument("--ssl-ca-certs",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="The CA certificates file")
|
||||
parser.add_argument(
|
||||
"--ssl-cert-reqs",
|
||||
type=int,
|
||||
default=int(ssl.CERT_NONE),
|
||||
help="Whether client certificate is required (see stdlib ssl module's)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root-path",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help="FastAPI root_path when app is behind a path based routing proxy")
|
||||
parser.add_argument(
|
||||
"--middleware",
|
||||
type=nullable_str,
|
||||
action="append",
|
||||
default=[],
|
||||
help="Additional ASGI middleware to apply to the app. "
|
||||
"We accept multiple --middleware arguments. "
|
||||
"The value should be an import path. "
|
||||
"If a function is provided, vLLM will add it to the server "
|
||||
"using @app.middleware('http'). "
|
||||
"If a class is provided, vLLM will add it to the server "
|
||||
"using app.add_middleware(). ")
|
||||
parser.add_argument(
|
||||
"--return-tokens-as-token-ids",
|
||||
action="store_true",
|
||||
help="When --max-logprobs is specified, represents single tokens as "
|
||||
"strings of the form 'token_id:{token_id}' so that tokens that "
|
||||
"are not JSON-encodable can be identified.")
|
||||
parser.add_argument(
|
||||
"--disable-frontend-multiprocessing",
|
||||
action="store_true",
|
||||
help="If specified, will run the OpenAI frontend server in the same "
|
||||
"process as the model serving engine.")
|
||||
|
||||
parser.add_argument(
|
||||
"--enable-auto-tool-choice",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
"Enable auto tool choice for supported models. Use --tool-call-parser"
|
||||
"to specify which parser to use")
|
||||
|
||||
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
|
||||
"--tool-parser-plugin",
|
||||
default=None,
|
||||
help=
|
||||
"Select the tool call parser depending on the model that you're using."
|
||||
" This is used to parse the model-generated tool call into OpenAI API "
|
||||
"format. Required for --enable-auto-tool-choice.")
|
||||
|
||||
parser.add_argument(
|
||||
"--tool-parser-plugin",
|
||||
type=str,
|
||||
default="",
|
||||
help=
|
||||
"Special the tool parser plugin write to parse the model-generated tool"
|
||||
" into OpenAI API format, the name register in this plugin can be used "
|
||||
"in --tool-call-parser.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--max-log-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Max number of prompt characters or prompt '
|
||||
'ID numbers being printed in log.'
|
||||
'\n\nDefault: Unlimited')
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-fastapi-docs",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def validate_parsed_serve_args(args: argparse.Namespace):
|
||||
"""Quick checks for model serve args that raise prior to loading."""
|
||||
if hasattr(args, "subparser") and args.subparser != "serve":
|
||||
return
|
||||
|
||||
# Ensure that the chat template is valid; raises if it likely isn't
|
||||
validate_chat_template(args.chat_template)
|
||||
|
||||
# Enable auto tool needs a tool call parser to be valid
|
||||
if args.enable_auto_tool_choice and not args.tool_call_parser:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
"--tool-call-parser")
|
||||
|
||||
|
||||
def create_parser_for_docs() -> FlexibleArgumentParser:
|
||||
parser_for_docs = FlexibleArgumentParser(
|
||||
prog="-m vllm.entrypoints.openai.api_server")
|
||||
return make_arg_parser(parser_for_docs)
|
||||
86
vllm/entrypoints/openai/logits_processors.py
Normal file
86
vllm/entrypoints/openai/logits_processors.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from functools import lru_cache, partial
|
||||
from typing import Dict, FrozenSet, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
class AllowedTokenIdsLogitsProcessor:
|
||||
"""Logits processor for constraining generated tokens to a
|
||||
specific set of token ids."""
|
||||
|
||||
def __init__(self, allowed_ids: Iterable[int]):
|
||||
self.allowed_ids: Optional[List[int]] = list(allowed_ids)
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
|
||||
def __call__(self, token_ids: List[int],
|
||||
logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.mask is None:
|
||||
self.mask = torch.ones((logits.shape[-1], ),
|
||||
dtype=torch.bool,
|
||||
device=logits.device)
|
||||
self.mask[self.allowed_ids] = False
|
||||
self.allowed_ids = None
|
||||
logits.masked_fill_(self.mask, float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_allowed_token_ids_logits_processor(
|
||||
allowed_token_ids: FrozenSet[int],
|
||||
vocab_size: int,
|
||||
) -> LogitsProcessor:
|
||||
if not allowed_token_ids:
|
||||
raise ValueError("Empty allowed_token_ids provided")
|
||||
if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
|
||||
raise ValueError("allowed_token_ids contains "
|
||||
"out-of-vocab token id")
|
||||
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
|
||||
|
||||
|
||||
def logit_bias_logits_processor(
|
||||
logit_bias: Dict[int, float],
|
||||
token_ids: List[int],
|
||||
logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for token_id, bias in logit_bias.items():
|
||||
logits[token_id] += bias
|
||||
return logits
|
||||
|
||||
|
||||
def get_logits_processors(
|
||||
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
|
||||
allowed_token_ids: Optional[List[int]],
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> List[LogitsProcessor]:
|
||||
logits_processors: List[LogitsProcessor] = []
|
||||
if logit_bias:
|
||||
try:
|
||||
# Convert token_id to integer
|
||||
# Clamp the bias between -100 and 100 per OpenAI API spec
|
||||
clamped_logit_bias: Dict[int, float] = {
|
||||
int(token_id): min(100.0, max(-100.0, bias))
|
||||
for token_id, bias in logit_bias.items()
|
||||
}
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"Found token_id in logit_bias that is not "
|
||||
"an integer or string representing an integer") from exc
|
||||
|
||||
# Check if token_id is within the vocab size
|
||||
for token_id, bias in clamped_logit_bias.items():
|
||||
if token_id < 0 or token_id >= tokenizer.vocab_size:
|
||||
raise ValueError(f"token_id {token_id} in logit_bias contains "
|
||||
"out-of-vocab token id")
|
||||
|
||||
logits_processors.append(
|
||||
partial(logit_bias_logits_processor, clamped_logit_bias))
|
||||
|
||||
if allowed_token_ids is not None:
|
||||
logits_processors.append(
|
||||
_get_allowed_token_ids_logits_processor(
|
||||
frozenset(allowed_token_ids), tokenizer.vocab_size))
|
||||
|
||||
return logits_processors
|
||||
992
vllm/entrypoints/openai/protocol.py
Normal file
992
vllm/entrypoints/openai/protocol.py
Normal file
@@ -0,0 +1,992 @@
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from openai.types.chat import ChatCompletionContentPartParam
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import Annotated, Required, TypedDict
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||
RequestOutputKind, SamplingParams)
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
# torch is mocked during docs generation,
|
||||
# so we have to provide the values as literals
|
||||
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
|
||||
_LONG_INFO: Union["torch.iinfo", Namespace]
|
||||
|
||||
try:
|
||||
from sphinx.ext.autodoc.mock import _MockModule
|
||||
|
||||
if isinstance(torch, _MockModule):
|
||||
_LONG_INFO = _MOCK_LONG_INFO
|
||||
else:
|
||||
_LONG_INFO = torch.iinfo(torch.long)
|
||||
except ModuleNotFoundError:
|
||||
_LONG_INFO = torch.iinfo(torch.long)
|
||||
|
||||
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
|
||||
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
|
||||
|
||||
|
||||
class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
"""Enables custom roles in the Chat Completion API."""
|
||||
role: Required[str]
|
||||
"""The role of the message's author."""
|
||||
|
||||
content: Union[str, List[ChatCompletionContentPartParam]]
|
||||
"""The contents of the message."""
|
||||
|
||||
name: str
|
||||
"""An optional name for the participant.
|
||||
|
||||
Provides the model information to differentiate between participants of the
|
||||
same role.
|
||||
"""
|
||||
|
||||
tool_call_id: Optional[str]
|
||||
|
||||
tool_calls: Optional[List[dict]]
|
||||
|
||||
|
||||
class OpenAIBaseModel(BaseModel):
|
||||
# OpenAI API does not allow extra fields
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ErrorResponse(OpenAIBaseModel):
|
||||
object: str = "error"
|
||||
message: str
|
||||
type: str
|
||||
param: Optional[str] = None
|
||||
code: int
|
||||
|
||||
|
||||
class ModelPermission(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
|
||||
object: str = "model_permission"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
allow_create_engine: bool = False
|
||||
allow_sampling: bool = True
|
||||
allow_logprobs: bool = True
|
||||
allow_search_indices: bool = False
|
||||
allow_view: bool = True
|
||||
allow_fine_tuning: bool = False
|
||||
organization: str = "*"
|
||||
group: Optional[str] = None
|
||||
is_blocking: bool = False
|
||||
|
||||
|
||||
class ModelCard(OpenAIBaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "vllm"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
max_model_len: Optional[int] = None
|
||||
permission: List[ModelPermission] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ModelList(OpenAIBaseModel):
|
||||
object: str = "list"
|
||||
data: List[ModelCard] = Field(default_factory=list)
|
||||
|
||||
|
||||
class UsageInfo(OpenAIBaseModel):
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
completion_tokens: Optional[int] = 0
|
||||
|
||||
|
||||
class RequestResponseMetadata(BaseModel):
|
||||
request_id: str
|
||||
final_usage_info: Optional[UsageInfo] = None
|
||||
|
||||
|
||||
class JsonSchemaResponseFormat(OpenAIBaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
# schema is the field in openai but that causes conflicts with pydantic so
|
||||
# instead use json_schema with an alias
|
||||
json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
|
||||
strict: Optional[bool] = None
|
||||
|
||||
|
||||
class ResponseFormat(OpenAIBaseModel):
|
||||
# type must be "json_schema", "json_object" or "text"
|
||||
type: Literal["text", "json_object", "json_schema"]
|
||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||
|
||||
|
||||
class StreamOptions(OpenAIBaseModel):
|
||||
include_usage: Optional[bool] = True
|
||||
continuous_usage_stats: Optional[bool] = True
|
||||
|
||||
|
||||
class FunctionDefinition(OpenAIBaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ChatCompletionToolsParam(OpenAIBaseModel):
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionDefinition
|
||||
|
||||
|
||||
class ChatCompletionNamedFunction(OpenAIBaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
|
||||
function: ChatCompletionNamedFunction
|
||||
type: Literal["function"] = "function"
|
||||
|
||||
|
||||
class ChatCompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/chat/create
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
model: str
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logprobs: Optional[bool] = False
|
||||
top_logprobs: Optional[int] = 0
|
||||
max_tokens: Optional[int] = None
|
||||
n: Optional[int] = 1
|
||||
presence_penalty: Optional[float] = 0.0
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||
stream: Optional[bool] = False
|
||||
stream_options: Optional[StreamOptions] = None
|
||||
temperature: Optional[float] = 0.7
|
||||
top_p: Optional[float] = 1.0
|
||||
tools: Optional[List[ChatCompletionToolsParam]] = None
|
||||
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
|
||||
ChatCompletionNamedToolChoiceParam]] = "none"
|
||||
|
||||
# NOTE this will be ignored by VLLM -- the model determines the behavior
|
||||
parallel_tool_calls: Optional[bool] = False
|
||||
user: Optional[str] = None
|
||||
|
||||
# doc: begin-chat-completion-sampling-params
|
||||
best_of: Optional[int] = None
|
||||
use_beam_search: bool = False
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
repetition_penalty: float = 1.0
|
||||
length_penalty: float = 1.0
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
include_stop_str_in_output: bool = False
|
||||
ignore_eos: bool = False
|
||||
min_tokens: int = 0
|
||||
skip_special_tokens: bool = True
|
||||
spaces_between_special_tokens: bool = True
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
prompt_logprobs: Optional[int] = None
|
||||
# doc: end-chat-completion-sampling-params
|
||||
|
||||
# doc: begin-chat-completion-extra-params
|
||||
echo: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, the new message will be prepended with the last message "
|
||||
"if they belong to the same role."),
|
||||
)
|
||||
add_generation_prompt: bool = Field(
|
||||
default=True,
|
||||
description=
|
||||
("If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."),
|
||||
)
|
||||
continue_final_message: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
("If this is set, the chat will be formatted so that the final "
|
||||
"message in the chat is open-ended, without any EOS tokens. The "
|
||||
"model will continue this message rather than starting a new one. "
|
||||
"This allows you to \"prefill\" part of the model's response for it. "
|
||||
"Cannot be used at the same time as `add_generation_prompt`."),
|
||||
)
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."),
|
||||
)
|
||||
documents: Optional[List[Dict[str, str]]] = Field(
|
||||
default=None,
|
||||
description=
|
||||
("A list of dicts representing documents that will be accessible to "
|
||||
"the model if it is performing RAG (retrieval-augmented generation)."
|
||||
" If the template does not support RAG, this argument will have no "
|
||||
"effect. We recommend that each document should be a dict containing "
|
||||
"\"title\" and \"text\" keys."),
|
||||
)
|
||||
chat_template: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."),
|
||||
)
|
||||
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."),
|
||||
)
|
||||
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
|
||||
default=None,
|
||||
description=("If specified, the output will follow the JSON schema."),
|
||||
)
|
||||
guided_regex: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the output will follow the regex pattern."),
|
||||
)
|
||||
guided_choice: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the output will be exactly one of the choices."),
|
||||
)
|
||||
guided_grammar: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the output will follow the context free grammar."),
|
||||
)
|
||||
guided_decoding_backend: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, will override the default guided decoding backend "
|
||||
"of the server for this specific request. If set, must be either "
|
||||
"'outlines' / 'lm-format-enforcer'"))
|
||||
guided_whitespace_pattern: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, will override the default whitespace pattern "
|
||||
"for guided json decoding."))
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
|
||||
# doc: end-chat-completion-extra-params
|
||||
|
||||
def to_beam_search_params(self,
|
||||
default_max_tokens: int) -> BeamSearchParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
n = self.n if self.n is not None else 1
|
||||
temperature = self.temperature if self.temperature is not None else 0.0
|
||||
|
||||
return BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=max_tokens,
|
||||
ignore_eos=self.ignore_eos,
|
||||
temperature=temperature,
|
||||
length_penalty=self.length_penalty,
|
||||
)
|
||||
|
||||
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
prompt_logprobs = self.prompt_logprobs
|
||||
if prompt_logprobs is None and self.echo:
|
||||
prompt_logprobs = self.top_logprobs
|
||||
|
||||
guided_json_object = None
|
||||
if (self.response_format is not None
|
||||
and self.response_format.type == "json_object"):
|
||||
guided_json_object = True
|
||||
|
||||
guided_decoding = GuidedDecodingParams.from_optional(
|
||||
json=self._get_guided_json_from_tool() or self.guided_json,
|
||||
regex=self.guided_regex,
|
||||
choice=self.guided_choice,
|
||||
grammar=self.guided_grammar,
|
||||
json_object=guided_json_object,
|
||||
backend=self.guided_decoding_backend,
|
||||
whitespace_pattern=self.guided_whitespace_pattern)
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
n=self.n,
|
||||
best_of=self.best_of,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
min_p=self.min_p,
|
||||
seed=self.seed,
|
||||
stop=self.stop,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
logprobs=self.top_logprobs if self.logprobs else None,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
ignore_eos=self.ignore_eos,
|
||||
max_tokens=max_tokens,
|
||||
min_tokens=self.min_tokens,
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
output_kind=RequestOutputKind.DELTA if self.stream \
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
guided_decoding=guided_decoding,
|
||||
logit_bias=self.logit_bias)
|
||||
|
||||
def _get_guided_json_from_tool(
|
||||
self) -> Optional[Union[str, dict, BaseModel]]:
|
||||
# user has chosen to not use any tool
|
||||
if self.tool_choice == "none" or self.tools is None:
|
||||
return None
|
||||
|
||||
# user has chosen to use a named tool
|
||||
if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
|
||||
tool_name = self.tool_choice.function.name
|
||||
tools = {tool.function.name: tool.function for tool in self.tools}
|
||||
if tool_name not in tools:
|
||||
raise ValueError(
|
||||
f"Tool '{tool_name}' has not been passed in `tools`.")
|
||||
tool = tools[tool_name]
|
||||
return tool.parameters
|
||||
|
||||
return None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
if data.get("stream_options") and not data.get("stream"):
|
||||
raise ValueError(
|
||||
"Stream options can only be defined when `stream=True`.")
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_logprobs(cls, data):
|
||||
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
|
||||
if data.get("stream") and prompt_logprobs > 0:
|
||||
raise ValueError(
|
||||
"`prompt_logprobs` are not available when `stream=True`.")
|
||||
|
||||
if prompt_logprobs < 0:
|
||||
raise ValueError("`prompt_logprobs` must be a positive value.")
|
||||
|
||||
if (top_logprobs := data.get("top_logprobs")) is not None:
|
||||
if top_logprobs < 0:
|
||||
raise ValueError("`top_logprobs` must be a positive value.")
|
||||
|
||||
if not data.get("logprobs"):
|
||||
raise ValueError(
|
||||
"when using `top_logprobs`, `logprobs` must be set to true."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_guided_decoding_count(cls, data):
|
||||
if isinstance(data, ValueError):
|
||||
raise data
|
||||
|
||||
guide_count = sum([
|
||||
"guided_json" in data and data["guided_json"] is not None,
|
||||
"guided_regex" in data and data["guided_regex"] is not None,
|
||||
"guided_choice" in data and data["guided_choice"] is not None
|
||||
])
|
||||
# you can only use one kind of guided decoding
|
||||
if guide_count > 1:
|
||||
raise ValueError(
|
||||
"You can only use one kind of guided decoding "
|
||||
"('guided_json', 'guided_regex' or 'guided_choice').")
|
||||
# you can only either use guided decoding or tools, not both
|
||||
if guide_count > 1 and data.get("tool_choice",
|
||||
"none") not in ("none", "auto"):
|
||||
raise ValueError(
|
||||
"You can only either use guided decoding or tools, not both.")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_tool_usage(cls, data):
|
||||
|
||||
# if "tool_choice" is not specified but tools are provided,
|
||||
# default to "auto" tool_choice
|
||||
if "tool_choice" not in data and data.get("tools"):
|
||||
data["tool_choice"] = "auto"
|
||||
|
||||
# if "tool_choice" is specified -- validation
|
||||
if "tool_choice" in data:
|
||||
|
||||
# ensure that if "tool choice" is specified, tools are present
|
||||
if "tools" not in data or data["tools"] is None:
|
||||
raise ValueError(
|
||||
"When using `tool_choice`, `tools` must be set.")
|
||||
|
||||
# make sure that tool choice is either a named tool
|
||||
# OR that it's set to "auto"
|
||||
if data["tool_choice"] != "auto" and not isinstance(
|
||||
data["tool_choice"], dict):
|
||||
raise ValueError(
|
||||
"`tool_choice` must either be a named tool or \"auto\". "
|
||||
"`tool_choice=\"none\" is not supported.")
|
||||
|
||||
# ensure that if "tool_choice" is specified as an object,
|
||||
# it matches a valid tool
|
||||
if isinstance(data["tool_choice"], dict):
|
||||
valid_tool = False
|
||||
specified_function = data["tool_choice"]["function"]
|
||||
if not specified_function:
|
||||
raise ValueError(
|
||||
"Incorrectly formatted `tool_choice`. Should be like "
|
||||
"`{\"type\": \"function\","
|
||||
" \"function\": {\"name\": \"my_function\"}}`")
|
||||
specified_function_name = specified_function["name"]
|
||||
if not specified_function_name:
|
||||
raise ValueError(
|
||||
"Incorrectly formatted `tool_choice`. Should be like "
|
||||
"`{\"type\": \"function\", "
|
||||
"\"function\": {\"name\": \"my_function\"}}`")
|
||||
for tool in data["tools"]:
|
||||
if tool["function"]["name"] == specified_function_name:
|
||||
valid_tool = True
|
||||
break
|
||||
if not valid_tool:
|
||||
raise ValueError(
|
||||
"The tool specified in `tool_choice` does not match any"
|
||||
" of the specified `tools`")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_generation_prompt(cls, data):
|
||||
if data.get("continue_final_message") and data.get(
|
||||
"add_generation_prompt"):
|
||||
raise ValueError("Cannot set both `continue_final_message` and "
|
||||
"`add_generation_prompt` to True.")
|
||||
return data
|
||||
|
||||
|
||||
class CompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/completions/create
|
||||
model: str
|
||||
prompt: Union[List[int], List[List[int]], str, List[str]]
|
||||
best_of: Optional[int] = None
|
||||
echo: Optional[bool] = False
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logprobs: Optional[int] = None
|
||||
max_tokens: Optional[int] = 16
|
||||
n: int = 1
|
||||
presence_penalty: Optional[float] = 0.0
|
||||
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
|
||||
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
||||
stream: Optional[bool] = False
|
||||
stream_options: Optional[StreamOptions] = None
|
||||
suffix: Optional[str] = None
|
||||
temperature: Optional[float] = 1.0
|
||||
top_p: Optional[float] = 1.0
|
||||
user: Optional[str] = None
|
||||
|
||||
# doc: begin-completion-sampling-params
|
||||
use_beam_search: bool = False
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
repetition_penalty: float = 1.0
|
||||
length_penalty: float = 1.0
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
include_stop_str_in_output: bool = False
|
||||
ignore_eos: bool = False
|
||||
min_tokens: int = 0
|
||||
skip_special_tokens: bool = True
|
||||
spaces_between_special_tokens: bool = True
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
allowed_token_ids: Optional[List[int]] = None
|
||||
prompt_logprobs: Optional[int] = None
|
||||
# doc: end-completion-sampling-params
|
||||
|
||||
# doc: begin-completion-extra-params
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."),
|
||||
)
|
||||
response_format: Optional[ResponseFormat] = Field(
|
||||
default=None,
|
||||
description=
|
||||
("Similar to chat completion, this parameter specifies the format of "
|
||||
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
|
||||
"supported."),
|
||||
)
|
||||
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
|
||||
default=None,
|
||||
description="If specified, the output will follow the JSON schema.",
|
||||
)
|
||||
guided_regex: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the output will follow the regex pattern."),
|
||||
)
|
||||
guided_choice: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the output will be exactly one of the choices."),
|
||||
)
|
||||
guided_grammar: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the output will follow the context free grammar."),
|
||||
)
|
||||
guided_decoding_backend: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, will override the default guided decoding backend "
|
||||
"of the server for this specific request. If set, must be one of "
|
||||
"'outlines' / 'lm-format-enforcer'"))
|
||||
guided_whitespace_pattern: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, will override the default whitespace pattern "
|
||||
"for guided json decoding."))
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
|
||||
# doc: end-completion-extra-params
|
||||
|
||||
def to_beam_search_params(self,
|
||||
default_max_tokens: int) -> BeamSearchParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
n = self.n if self.n is not None else 1
|
||||
temperature = self.temperature if self.temperature is not None else 0.0
|
||||
|
||||
return BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=max_tokens,
|
||||
ignore_eos=self.ignore_eos,
|
||||
temperature=temperature,
|
||||
length_penalty=self.length_penalty,
|
||||
)
|
||||
|
||||
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
prompt_logprobs = self.prompt_logprobs
|
||||
if prompt_logprobs is None and self.echo:
|
||||
prompt_logprobs = self.logprobs
|
||||
|
||||
echo_without_generation = self.echo and self.max_tokens == 0
|
||||
|
||||
guided_json_object = None
|
||||
if (self.response_format is not None
|
||||
and self.response_format.type == "json_object"):
|
||||
guided_json_object = True
|
||||
|
||||
guided_decoding = GuidedDecodingParams.from_optional(
|
||||
json=self.guided_json,
|
||||
regex=self.guided_regex,
|
||||
choice=self.guided_choice,
|
||||
grammar=self.guided_grammar,
|
||||
json_object=guided_json_object,
|
||||
backend=self.guided_decoding_backend,
|
||||
whitespace_pattern=self.guided_whitespace_pattern)
|
||||
|
||||
return SamplingParams.from_optional(
|
||||
n=self.n,
|
||||
best_of=self.best_of,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
top_k=self.top_k,
|
||||
min_p=self.min_p,
|
||||
seed=self.seed,
|
||||
stop=self.stop,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
logprobs=self.logprobs,
|
||||
ignore_eos=self.ignore_eos,
|
||||
max_tokens=max_tokens if not echo_without_generation else 1,
|
||||
min_tokens=self.min_tokens,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
output_kind=RequestOutputKind.DELTA if self.stream \
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
guided_decoding=guided_decoding,
|
||||
logit_bias=self.logit_bias,
|
||||
allowed_token_ids=self.allowed_token_ids)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_guided_decoding_count(cls, data):
|
||||
guide_count = sum([
|
||||
"guided_json" in data and data["guided_json"] is not None,
|
||||
"guided_regex" in data and data["guided_regex"] is not None,
|
||||
"guided_choice" in data and data["guided_choice"] is not None
|
||||
])
|
||||
if guide_count > 1:
|
||||
raise ValueError(
|
||||
"You can only use one kind of guided decoding "
|
||||
"('guided_json', 'guided_regex' or 'guided_choice').")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_logprobs(cls, data):
|
||||
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
|
||||
if data.get("stream") and prompt_logprobs > 0:
|
||||
raise ValueError(
|
||||
"`prompt_logprobs` are not available when `stream=True`.")
|
||||
|
||||
if prompt_logprobs < 0:
|
||||
raise ValueError("`prompt_logprobs` must be a positive value.")
|
||||
|
||||
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
|
||||
raise ValueError("`logprobs` must be a positive value.")
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, data):
|
||||
if data.get("stream_options") and not data.get("stream"):
|
||||
raise ValueError(
|
||||
"Stream options can only be defined when `stream=True`.")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class EmbeddingRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/embeddings
|
||||
model: str
|
||||
input: Union[List[int], List[List[int]], str, List[str]]
|
||||
encoding_format: Literal["float", "base64"] = "float"
|
||||
dimensions: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
|
||||
|
||||
# doc: begin-embedding-pooling-params
|
||||
additional_data: Optional[Any] = None
|
||||
|
||||
# doc: end-embedding-pooling-params
|
||||
|
||||
# doc: begin-embedding-extra-params
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
|
||||
# doc: end-embedding-extra-params
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
class CompletionLogProbs(OpenAIBaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
tokens: List[str] = Field(default_factory=list)
|
||||
top_logprobs: List[Optional[Dict[str,
|
||||
float]]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CompletionResponseChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
text: str
|
||||
logprobs: Optional[CompletionLogProbs] = None
|
||||
finish_reason: Optional[str] = None
|
||||
stop_reason: Optional[Union[int, str]] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The stop string or token id that caused the completion "
|
||||
"to stop, None if the completion finished for some other reason "
|
||||
"including encountering the EOS token"),
|
||||
)
|
||||
prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
|
||||
|
||||
|
||||
class CompletionResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||
object: str = "text_completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[CompletionResponseChoice]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class CompletionResponseStreamChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
text: str
|
||||
logprobs: Optional[CompletionLogProbs] = None
|
||||
finish_reason: Optional[str] = None
|
||||
stop_reason: Optional[Union[int, str]] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The stop string or token id that caused the completion "
|
||||
"to stop, None if the completion finished for some other reason "
|
||||
"including encountering the EOS token"),
|
||||
)
|
||||
|
||||
|
||||
class CompletionStreamResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||
object: str = "text_completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[CompletionResponseStreamChoice]
|
||||
usage: Optional[UsageInfo] = Field(default=None)
|
||||
|
||||
|
||||
class EmbeddingResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "embedding"
|
||||
embedding: Union[List[float], str]
|
||||
|
||||
|
||||
class EmbeddingResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
data: List[EmbeddingResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class FunctionCall(OpenAIBaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
|
||||
type: Literal["function"] = "function"
|
||||
function: FunctionCall
|
||||
|
||||
|
||||
class DeltaFunctionCall(BaseModel):
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
# a tool call delta where everything is optional
|
||||
class DeltaToolCall(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
|
||||
type: Literal["function"] = "function"
|
||||
index: int
|
||||
function: Optional[DeltaFunctionCall] = None
|
||||
|
||||
|
||||
class ExtractedToolCallInformation(BaseModel):
|
||||
# indicate if tools were called
|
||||
tools_called: bool
|
||||
|
||||
# extracted tool calls
|
||||
tool_calls: List[ToolCall]
|
||||
|
||||
# content - per OpenAI spec, content AND tool calls can be returned rarely
|
||||
# But some models will do this intentionally
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
class ChatMessage(OpenAIBaseModel):
|
||||
role: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatCompletionLogProb(OpenAIBaseModel):
|
||||
token: str
|
||||
logprob: float = -9999.0
|
||||
bytes: Optional[List[int]] = None
|
||||
|
||||
|
||||
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
|
||||
top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatCompletionLogProbs(OpenAIBaseModel):
|
||||
content: Optional[List[ChatCompletionLogProbsContent]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
logprobs: Optional[ChatCompletionLogProbs] = None
|
||||
# per OpenAI spec this is the default
|
||||
finish_reason: Optional[str] = "stop"
|
||||
# not part of the OpenAI spec but included in vLLM for legacy reasons
|
||||
stop_reason: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseChoice]
|
||||
usage: UsageInfo
|
||||
prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
|
||||
|
||||
|
||||
class DeltaMessage(OpenAIBaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: List[DeltaToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
logprobs: Optional[ChatCompletionLogProbs] = None
|
||||
finish_reason: Optional[str] = None
|
||||
stop_reason: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
choices: List[ChatCompletionResponseStreamChoice]
|
||||
usage: Optional[UsageInfo] = Field(default=None)
|
||||
|
||||
|
||||
class BatchRequestInput(OpenAIBaseModel):
|
||||
"""
|
||||
The per-line object of the batch input file.
|
||||
|
||||
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
|
||||
"""
|
||||
|
||||
# A developer-provided per-request id that will be used to match outputs to
|
||||
# inputs. Must be unique for each request in a batch.
|
||||
custom_id: str
|
||||
|
||||
# The HTTP method to be used for the request. Currently only POST is
|
||||
# supported.
|
||||
method: str
|
||||
|
||||
# The OpenAI API relative URL to be used for the request. Currently
|
||||
# /v1/chat/completions is supported.
|
||||
url: str
|
||||
|
||||
# The parameters of the request.
|
||||
body: Union[ChatCompletionRequest, EmbeddingRequest]
|
||||
|
||||
|
||||
class BatchResponseData(OpenAIBaseModel):
|
||||
# HTTP status code of the response.
|
||||
status_code: int = 200
|
||||
|
||||
# An unique identifier for the API request.
|
||||
request_id: str
|
||||
|
||||
# The body of the response.
|
||||
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None
|
||||
|
||||
|
||||
class BatchRequestOutput(OpenAIBaseModel):
|
||||
"""
|
||||
The per-line object of the batch output and error files
|
||||
"""
|
||||
|
||||
id: str
|
||||
|
||||
# A developer-provided per-request id that will be used to match outputs to
|
||||
# inputs.
|
||||
custom_id: str
|
||||
|
||||
response: Optional[BatchResponseData]
|
||||
|
||||
# For requests that failed with a non-HTTP error, this will contain more
|
||||
# information on the cause of the failure.
|
||||
error: Optional[Any]
|
||||
|
||||
|
||||
class TokenizeCompletionRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
|
||||
add_special_tokens: bool = Field(default=True)
|
||||
|
||||
|
||||
class TokenizeChatRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
|
||||
add_generation_prompt: bool = Field(default=True)
|
||||
continue_final_message: bool = Field(default=False)
|
||||
add_special_tokens: bool = Field(default=False)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_generation_prompt(cls, data):
|
||||
if data.get("continue_final_message") and data.get(
|
||||
"add_generation_prompt"):
|
||||
raise ValueError("Cannot set both `continue_final_message` and "
|
||||
"`add_generation_prompt` to True.")
|
||||
return data
|
||||
|
||||
|
||||
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
|
||||
|
||||
|
||||
class TokenizeResponse(OpenAIBaseModel):
|
||||
count: int
|
||||
max_model_len: int
|
||||
tokens: List[int]
|
||||
|
||||
|
||||
class DetokenizeRequest(OpenAIBaseModel):
|
||||
model: str
|
||||
tokens: List[int]
|
||||
|
||||
|
||||
class DetokenizeResponse(OpenAIBaseModel):
|
||||
prompt: str
|
||||
|
||||
|
||||
class LoadLoraAdapterRequest(BaseModel):
|
||||
lora_name: str
|
||||
lora_path: str
|
||||
|
||||
|
||||
class UnloadLoraAdapterRequest(BaseModel):
|
||||
lora_name: str
|
||||
lora_int_id: Optional[int] = Field(default=None)
|
||||
285
vllm/entrypoints/openai/run_batch.py
Normal file
285
vllm/entrypoints/openai/run_batch.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import asyncio
|
||||
from http import HTTPStatus
|
||||
from io import StringIO
|
||||
from typing import Awaitable, Callable, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from prometheus_client import start_http_server
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.logger import RequestLogger, logger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
||||
BatchRequestOutput,
|
||||
BatchResponseData,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse, ErrorResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM OpenAI-Compatible batch runner.")
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help=
|
||||
"The path or url to a single input file. Currently supports local file "
|
||||
"paths, or the http protocol (http or https). If a URL is specified, "
|
||||
"the file should be available via HTTP GET.")
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-file",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The path or url to a single output file. Currently supports "
|
||||
"local file paths, or web (http or https) urls. If a URL is specified,"
|
||||
" the file should be available via HTTP PUT.")
|
||||
parser.add_argument("--response-role",
|
||||
type=nullable_str,
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=True`.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--max-log-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Max number of prompt characters or prompt '
|
||||
'ID numbers being printed in log.'
|
||||
'\n\nDefault: Unlimited')
|
||||
|
||||
parser.add_argument("--enable-metrics",
|
||||
action="store_true",
|
||||
help="Enable Prometheus metrics")
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="URL to the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port number for the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# explicitly use pure text format, with a newline at the end
|
||||
# this makes it impossible to see the animation in the progress bar
|
||||
# but will avoid messing up with ray or multiprocessing, which wraps
|
||||
# each line of output with some prefix.
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
class BatchProgressTracker:
|
||||
|
||||
def __init__(self):
|
||||
self._total = 0
|
||||
self._pbar: Optional[tqdm] = None
|
||||
|
||||
def submitted(self):
|
||||
self._total += 1
|
||||
|
||||
def completed(self):
|
||||
if self._pbar:
|
||||
self._pbar.update()
|
||||
|
||||
def pbar(self) -> tqdm:
|
||||
enable_tqdm = not torch.distributed.is_initialized(
|
||||
) or torch.distributed.get_rank() == 0
|
||||
self._pbar = tqdm(total=self._total,
|
||||
unit="req",
|
||||
desc="Running batch",
|
||||
mininterval=5,
|
||||
disable=not enable_tqdm,
|
||||
bar_format=_BAR_FORMAT)
|
||||
return self._pbar
|
||||
|
||||
|
||||
async def read_file(path_or_url: str) -> str:
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
async with aiohttp.ClientSession() as session, \
|
||||
session.get(path_or_url) as resp:
|
||||
return await resp.text()
|
||||
else:
|
||||
with open(path_or_url, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
async def write_file(path_or_url: str, data: str) -> None:
|
||||
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
|
||||
async with aiohttp.ClientSession() as session, \
|
||||
session.put(path_or_url, data=data.encode("utf-8")):
|
||||
pass
|
||||
else:
|
||||
# We should make this async, but as long as this is always run as a
|
||||
# standalone program, blocking the event loop won't effect performance
|
||||
# in this particular case.
|
||||
with open(path_or_url, "w", encoding="utf-8") as f:
|
||||
f.write(data)
|
||||
|
||||
|
||||
def make_error_request_output(request: BatchRequestInput,
|
||||
error_msg: str) -> BatchRequestOutput:
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
request_id=f"vllm-batch-{random_uuid()}",
|
||||
),
|
||||
error=error_msg,
|
||||
)
|
||||
return batch_output
|
||||
|
||||
|
||||
async def make_async_error_request_output(
|
||||
request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
|
||||
return make_error_request_output(request, error_msg)
|
||||
|
||||
|
||||
async def run_request(serving_engine_func: Callable,
|
||||
request: BatchRequestInput,
|
||||
tracker: BatchProgressTracker) -> BatchRequestOutput:
|
||||
response = await serving_engine_func(request.body)
|
||||
|
||||
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
body=response, request_id=f"vllm-batch-{random_uuid()}"),
|
||||
error=None,
|
||||
)
|
||||
elif isinstance(response, ErrorResponse):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
response=BatchResponseData(
|
||||
status_code=response.code,
|
||||
request_id=f"vllm-batch-{random_uuid()}"),
|
||||
error=response,
|
||||
)
|
||||
else:
|
||||
batch_output = make_error_request_output(
|
||||
request, error_msg="Request must not be sent in stream mode")
|
||||
|
||||
tracker.completed()
|
||||
return batch_output
|
||||
|
||||
|
||||
async def main(args):
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
|
||||
|
||||
model_config = await engine.get_model_config()
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model)
|
||||
for name in served_model_names
|
||||
]
|
||||
|
||||
if args.disable_log_requests:
|
||||
request_logger = None
|
||||
else:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
|
||||
# Create the openai serving objects.
|
||||
openai_serving_chat = OpenAIServingChat(
|
||||
engine,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
args.response_role,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
)
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine,
|
||||
model_config,
|
||||
base_model_paths,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
|
||||
tracker = BatchProgressTracker()
|
||||
logger.info("Reading batch from %s...", args.input_file)
|
||||
|
||||
# Submit all requests in the file to the engine "concurrently".
|
||||
response_futures: List[Awaitable[BatchRequestOutput]] = []
|
||||
for request_json in (await read_file(args.input_file)).strip().split("\n"):
|
||||
# Skip empty lines.
|
||||
request_json = request_json.strip()
|
||||
if not request_json:
|
||||
continue
|
||||
|
||||
request = BatchRequestInput.model_validate_json(request_json)
|
||||
|
||||
# Determine the type of request and run it.
|
||||
if request.url == "/v1/chat/completions":
|
||||
response_futures.append(
|
||||
run_request(openai_serving_chat.create_chat_completion,
|
||||
request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url == "/v1/embeddings":
|
||||
response_futures.append(
|
||||
run_request(openai_serving_embedding.create_embedding, request,
|
||||
tracker))
|
||||
tracker.submitted()
|
||||
else:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
error_msg="Only /v1/chat/completions and "
|
||||
"/v1/embeddings are supported in the batch endpoint.",
|
||||
))
|
||||
|
||||
with tracker.pbar():
|
||||
responses = await asyncio.gather(*response_futures)
|
||||
|
||||
output_buffer = StringIO()
|
||||
for response in responses:
|
||||
print(response.model_dump_json(), file=output_buffer)
|
||||
|
||||
output_buffer.seek(0)
|
||||
await write_file(args.output_file, output_buffer.read().strip())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
|
||||
# to publish metrics at the /metrics endpoint.
|
||||
if args.enable_metrics:
|
||||
logger.info("Prometheus metrics enabled")
|
||||
start_http_server(port=args.port, addr=args.url)
|
||||
else:
|
||||
logger.info("Prometheus metrics disabled")
|
||||
|
||||
asyncio.run(main(args))
|
||||
891
vllm/entrypoints/openai/serving_chat.py
Normal file
891
vllm/entrypoints/openai/serving_chat.py
Normal file
@@ -0,0 +1,891 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
|
||||
Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
load_chat_template,
|
||||
parse_chat_messages_futures)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProb, ChatCompletionLogProbs,
|
||||
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
|
||||
ToolCall, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath,
|
||||
TextTokensPrompt)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import iterate_with_cancellation, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
def __init__(self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
response_role: str,
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_auto_tools: bool = False,
|
||||
tool_parser: Optional[str] = None):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
|
||||
self.response_role = response_role
|
||||
self.use_tool_use_model_template = False
|
||||
self.chat_template = load_chat_template(chat_template)
|
||||
|
||||
# set up tool use
|
||||
self.enable_auto_tools: bool = enable_auto_tools
|
||||
if self.enable_auto_tools:
|
||||
logger.info(
|
||||
"\"auto\" tool choice has been enabled please note that while"
|
||||
" the parallel_tool_calls client option is preset for "
|
||||
"compatibility reasons, it will be ignored.")
|
||||
|
||||
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
|
||||
if self.enable_auto_tools:
|
||||
try:
|
||||
self.tool_parser = ToolParserManager.get_tool_parser(
|
||||
tool_parser)
|
||||
except Exception as e:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
f"tool_parser:'{tool_parser}' which has not "
|
||||
"been registered") from e
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
|
||||
ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/chat/create
|
||||
for the API specification. This API mimics the OpenAI
|
||||
ChatCompletion API.
|
||||
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
logger.error("Error with model %s", error_check_ret)
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
model_config = self.model_config
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
conversation, mm_data_future = parse_chat_messages_futures(
|
||||
request.messages, model_config, tokenizer)
|
||||
|
||||
tool_dicts = None if request.tools is None else [
|
||||
tool.model_dump() for tool in request.tools
|
||||
]
|
||||
|
||||
prompt: Union[str, List[int]]
|
||||
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
|
||||
if is_mistral_tokenizer:
|
||||
prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
tools=tool_dicts,
|
||||
documents=request.documents,
|
||||
**(request.chat_template_kwargs or {}),
|
||||
)
|
||||
else:
|
||||
prompt = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
tools=tool_dicts,
|
||||
documents=request.documents,
|
||||
**(request.chat_template_kwargs or {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error in applying chat template from request")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
try:
|
||||
mm_data = await mm_data_future
|
||||
except Exception as e:
|
||||
logger.exception("Error in loading multi-modal data")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# validation for OpenAI tools
|
||||
# tool_choice = "required" is not supported
|
||||
if request.tool_choice == "required":
|
||||
return self.create_error_response(
|
||||
"tool_choice = \"required\" is not supported!")
|
||||
|
||||
if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
|
||||
self.enable_auto_tools and self.tool_parser is not None):
|
||||
# for hf tokenizers, "auto" tools requires
|
||||
# --enable-auto-tool-choice and --tool-call-parser
|
||||
return self.create_error_response(
|
||||
"\"auto\" tool choice requires "
|
||||
"--enable-auto-tool-choice and --tool-call-parser to be set")
|
||||
|
||||
request_id = f"chat-{random_uuid()}"
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
if self.enable_auto_tools and self.tool_parser:
|
||||
request = self.tool_parser(tokenizer).adjust_request(
|
||||
request=request)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt_inputs = self._tokenize_prompt_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
assert isinstance(prompt, list) and isinstance(
|
||||
prompt[0], int
|
||||
), "Prompt has to be either a string or a list of token ids"
|
||||
prompt_inputs = TextTokensPrompt(
|
||||
prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
|
||||
|
||||
assert prompt_inputs is not None
|
||||
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
default_max_tokens = self.max_model_len - len(
|
||||
prompt_inputs["prompt_token_ids"])
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
default_max_tokens)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens)
|
||||
|
||||
self._log_inputs(request_id,
|
||||
prompt_inputs,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
engine_inputs = TokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["prompt_token_ids"])
|
||||
if mm_data is not None:
|
||||
engine_inputs["multi_modal_data"] = mm_data
|
||||
|
||||
is_tracing_enabled = (await
|
||||
self.engine_client.is_tracing_enabled())
|
||||
trace_headers = None
|
||||
if is_tracing_enabled and raw_request:
|
||||
trace_headers = extract_trace_headers(raw_request.headers)
|
||||
if (not is_tracing_enabled and raw_request
|
||||
and contains_trace_headers(raw_request.headers)):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
assert isinstance(self.engine_client,
|
||||
(AsyncLLMEngine,
|
||||
MQLLMEngineClient)), \
|
||||
"Beam search is only supported with" \
|
||||
"AsyncLLMEngine and MQLLMEngineClient."
|
||||
result_generator = self.engine_client.beam_search(
|
||||
engine_inputs['prompt_token_ids'],
|
||||
request_id,
|
||||
sampling_params,
|
||||
)
|
||||
else:
|
||||
result_generator = self.engine_client.generate(
|
||||
engine_inputs,
|
||||
sampling_params,
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=request.priority,
|
||||
)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
if raw_request:
|
||||
result_generator = iterate_with_cancellation(
|
||||
result_generator, raw_request.is_disconnected)
|
||||
|
||||
# Streaming response
|
||||
if request.stream:
|
||||
return self.chat_completion_stream_generator(
|
||||
request, result_generator, request_id, conversation, tokenizer,
|
||||
request_metadata)
|
||||
|
||||
try:
|
||||
return await self.chat_completion_full_generator(
|
||||
request, result_generator, request_id, conversation, tokenizer,
|
||||
request_metadata)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
|
||||
if request.add_generation_prompt:
|
||||
return self.response_role
|
||||
return request.messages[-1]["role"]
|
||||
|
||||
async def chat_completion_stream_generator(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
result_generator: AsyncIterator[RequestOutput],
|
||||
request_id: str,
|
||||
conversation: List[ConversationMessage],
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
model_name = self.base_model_paths[0].name
|
||||
created_time = int(time.time())
|
||||
chunk_object_type: Final = "chat.completion.chunk"
|
||||
first_iteration = True
|
||||
|
||||
# Send response for each token for each request.n (index)
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_num_tokens = [0] * num_choices
|
||||
finish_reason_sent = [False] * num_choices
|
||||
num_prompt_tokens = 0
|
||||
|
||||
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
||||
tool_choice_function_name = request.tool_choice.function.name
|
||||
else:
|
||||
tool_choice_function_name = None
|
||||
|
||||
# Determine whether tools are in use with "auto" tool choice
|
||||
tool_choice_auto = (
|
||||
not tool_choice_function_name
|
||||
and self._should_stream_with_auto_tool_parsing(request))
|
||||
|
||||
all_previous_token_ids: Optional[List[List[int]]]
|
||||
if tool_choice_auto:
|
||||
# These are only required in "auto" tool choice case
|
||||
previous_texts = [""] * num_choices
|
||||
all_previous_token_ids = [[]] * num_choices
|
||||
else:
|
||||
previous_texts, all_previous_token_ids = None, None
|
||||
|
||||
# Prepare the tool parser if it's needed
|
||||
try:
|
||||
if tool_choice_auto and self.tool_parser:
|
||||
tool_parsers: List[Optional[ToolParser]] = [
|
||||
self.tool_parser(tokenizer)
|
||||
] * num_choices
|
||||
else:
|
||||
tool_parsers = [None] * num_choices
|
||||
except RuntimeError as e:
|
||||
logger.error("Error in tool parser creation: %s", e)
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
try:
|
||||
async for res in result_generator:
|
||||
if res.prompt_token_ids is not None:
|
||||
num_prompt_tokens = len(res.prompt_token_ids)
|
||||
if res.encoder_prompt_token_ids is not None:
|
||||
num_prompt_tokens += len(res.encoder_prompt_token_ids)
|
||||
|
||||
# We need to do it here, because if there are exceptions in
|
||||
# the result_generator, it needs to be sent as the FIRST
|
||||
# response (by the try...catch).
|
||||
if first_iteration:
|
||||
# Send first response for each request.n (index) with
|
||||
# the role
|
||||
role = self.get_chat_request_role(request)
|
||||
|
||||
# NOTE num_choices defaults to 1 so this usually executes
|
||||
# once per request
|
||||
for i in range(num_choices):
|
||||
tool_parser = tool_parsers[i]
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(
|
||||
role=role,
|
||||
content="",
|
||||
),
|
||||
logprobs=None,
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# if usage should be included
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
# if continuous usage stats are requested, add it
|
||||
if request.stream_options.continuous_usage_stats:
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=num_prompt_tokens)
|
||||
chunk.usage = usage
|
||||
# otherwise don't
|
||||
else:
|
||||
chunk.usage = None
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response to echo the input portion of the
|
||||
# last message
|
||||
if request.echo or request.continue_final_message:
|
||||
last_msg_content: str = ""
|
||||
if conversation and "content" in conversation[
|
||||
-1] and conversation[-1].get("role") == role:
|
||||
last_msg_content = conversation[-1]["content"] or ""
|
||||
|
||||
if last_msg_content:
|
||||
for i in range(num_choices):
|
||||
choice_data = (
|
||||
ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(
|
||||
content=last_msg_content),
|
||||
logprobs=None,
|
||||
finish_reason=None))
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
if (request.stream_options and
|
||||
request.stream_options.include_usage):
|
||||
if (request.stream_options.
|
||||
continuous_usage_stats):
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=num_prompt_tokens)
|
||||
chunk.usage = usage
|
||||
else:
|
||||
chunk.usage = None
|
||||
|
||||
data = chunk.model_dump_json(
|
||||
exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
first_iteration = False
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
tool_parser = tool_parsers[i]
|
||||
|
||||
if finish_reason_sent[i]:
|
||||
continue
|
||||
|
||||
if request.logprobs and request.top_logprobs is not None:
|
||||
assert output.logprobs is not None, (
|
||||
"Did not output logprobs")
|
||||
logprobs = self._create_chat_logprobs(
|
||||
token_ids=output.token_ids,
|
||||
top_logprobs=output.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
num_output_top_logprobs=request.top_logprobs,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
delta_text = output.text
|
||||
delta_message: Optional[DeltaMessage]
|
||||
|
||||
# handle streaming deltas for tools with named tool_choice
|
||||
if tool_choice_function_name:
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(function=DeltaFunctionCall(
|
||||
name=tool_choice_function_name,
|
||||
arguments=delta_text),
|
||||
index=i)
|
||||
])
|
||||
|
||||
# handle streaming deltas for tools with "auto" tool choice
|
||||
elif tool_choice_auto:
|
||||
assert previous_texts is not None
|
||||
assert all_previous_token_ids is not None
|
||||
assert tool_parser is not None
|
||||
#TODO optimize manipulation of these lists
|
||||
previous_text = previous_texts[i]
|
||||
previous_token_ids = all_previous_token_ids[i]
|
||||
current_text = previous_text + delta_text
|
||||
current_token_ids = previous_token_ids + list(
|
||||
output.token_ids)
|
||||
|
||||
delta_message = (
|
||||
tool_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=output.token_ids,
|
||||
request=request))
|
||||
|
||||
# update the previous values for the next iteration
|
||||
previous_texts[i] = current_text
|
||||
all_previous_token_ids[i] = current_token_ids
|
||||
|
||||
# handle streaming just a content delta
|
||||
else:
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
|
||||
# set the previous values for the next iteration
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
|
||||
# if the message delta is None (e.g. because it was a
|
||||
# "control token" for tool calls or the parser otherwise
|
||||
# wasn't ready to send a token, then
|
||||
# get the next token without streaming a chunk
|
||||
if delta_message is None:
|
||||
continue
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Send token-by-token response for each request.n
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=delta_message,
|
||||
logprobs=logprobs,
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
|
||||
# handle usage stats if requested & if continuous
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
if request.stream_options.continuous_usage_stats:
|
||||
completion_tokens = len(output.token_ids)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens +
|
||||
completion_tokens,
|
||||
)
|
||||
chunk.usage = usage
|
||||
else:
|
||||
chunk.usage = None
|
||||
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# if the model is finished generating
|
||||
else:
|
||||
# check to make sure we haven't "forgotten" to stream
|
||||
# any tokens that were generated but previously
|
||||
# matched by partial json parsing
|
||||
# only happens if we are NOT using guided decoding
|
||||
auto_tools_called = False
|
||||
if tool_parser:
|
||||
auto_tools_called = len(
|
||||
tool_parser.prev_tool_call_arr) > 0
|
||||
index = len(tool_parser.prev_tool_call_arr
|
||||
) - 1 if auto_tools_called else 0
|
||||
else:
|
||||
index = 0
|
||||
|
||||
if self._should_check_for_unstreamed_tool_arg_tokens(
|
||||
delta_message, output) and tool_parser:
|
||||
# get the expected call based on partial JSON
|
||||
# parsing which "autocompletes" the JSON
|
||||
expected_call = json.dumps(
|
||||
tool_parser.prev_tool_call_arr[index].get(
|
||||
"arguments", {}))
|
||||
|
||||
# get what we've streamed so far for arguments
|
||||
# for the current tool
|
||||
actual_call = tool_parser.streamed_args_for_tool[
|
||||
index]
|
||||
|
||||
# check to see if there's anything left to stream
|
||||
remaining_call = expected_call.replace(
|
||||
actual_call, "", 1)
|
||||
|
||||
# set that as a delta message
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=remaining_call).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
|
||||
# Send the finish response for each request.n only once
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=delta_message,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason
|
||||
if not auto_tools_called else "tool_calls",
|
||||
stop_reason=output.stop_reason)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
if request.stream_options.continuous_usage_stats:
|
||||
completion_tokens = len(output.token_ids)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens +
|
||||
completion_tokens,
|
||||
)
|
||||
chunk.usage = usage
|
||||
else:
|
||||
chunk.usage = None
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
finish_reason_sent[i] = True
|
||||
|
||||
# once the final token is handled, if stream_options.include_usage
|
||||
# is sent, send the usage
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
completion_tokens = previous_num_tokens[i]
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
final_usage_chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[],
|
||||
model=model_name,
|
||||
usage=final_usage)
|
||||
final_usage_data = (final_usage_chunk.model_dump_json(
|
||||
exclude_unset=True, exclude_none=True))
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
num_completion_tokens = sum(previous_num_tokens)
|
||||
request_metadata.final_usage_info = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_completion_tokens,
|
||||
total_tokens=num_prompt_tokens + num_completion_tokens)
|
||||
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
logger.error("error in chat completion stream generator: %s", e)
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def chat_completion_full_generator(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
result_generator: AsyncIterator[RequestOutput],
|
||||
request_id: str,
|
||||
conversation: List[ConversationMessage],
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> Union[ErrorResponse, ChatCompletionResponse]:
|
||||
|
||||
model_name = self.base_model_paths[0].name
|
||||
created_time = int(time.time())
|
||||
final_res: Optional[RequestOutput] = None
|
||||
|
||||
try:
|
||||
async for res in result_generator:
|
||||
final_res = res
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
|
||||
assert final_res is not None
|
||||
|
||||
choices: List[ChatCompletionResponseChoice] = []
|
||||
|
||||
role = self.get_chat_request_role(request)
|
||||
for output in final_res.outputs:
|
||||
token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
|
||||
if request.logprobs and request.top_logprobs is not None:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_chat_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
num_output_top_logprobs=request.top_logprobs,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
# In the OpenAI API the finish_reason is "tools_called"
|
||||
# if the tool choice is auto and the model produced a tool
|
||||
# call. The same is not true for named function calls
|
||||
auto_tools_called = False
|
||||
|
||||
# if auto tools are not enabled, and a named tool choice using
|
||||
# outlines is not being used
|
||||
if (not self.enable_auto_tools
|
||||
or not self.tool_parser) and not isinstance(
|
||||
request.tool_choice,
|
||||
ChatCompletionNamedToolChoiceParam):
|
||||
message = ChatMessage(role=role, content=output.text)
|
||||
|
||||
# if the request uses tools and specified a tool choice
|
||||
elif request.tool_choice and type(
|
||||
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
|
||||
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCall(function=FunctionCall(
|
||||
name=request.tool_choice.function.name,
|
||||
arguments=output.text))
|
||||
])
|
||||
|
||||
# if the request doesn't use tool choice
|
||||
# OR specifies to not use a tool
|
||||
elif not request.tool_choice or request.tool_choice == "none":
|
||||
|
||||
message = ChatMessage(role=role, content=output.text)
|
||||
|
||||
# handle when there are tools and tool choice is auto
|
||||
elif request.tools and (
|
||||
request.tool_choice == "auto"
|
||||
or request.tool_choice is None) and self.enable_auto_tools \
|
||||
and self.tool_parser:
|
||||
|
||||
try:
|
||||
tool_parser = self.tool_parser(tokenizer)
|
||||
except RuntimeError as e:
|
||||
logger.error("Error in tool parser creation: %s", e)
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
tool_call_info = tool_parser.extract_tool_calls(
|
||||
output.text, request=request)
|
||||
# In the OpenAI API the finish_reason is "tools_called"
|
||||
# if the tool choice is auto and the model produced a tool
|
||||
# call. The same is not true for named function calls
|
||||
auto_tools_called = tool_call_info.tools_called
|
||||
if tool_call_info.tools_called:
|
||||
message = ChatMessage(role=role,
|
||||
content=tool_call_info.content,
|
||||
tool_calls=tool_call_info.tool_calls)
|
||||
|
||||
else:
|
||||
# FOR NOW make it a chat message; we will have to detect
|
||||
# the type to make it later.
|
||||
message = ChatMessage(role=role, content=output.text)
|
||||
|
||||
# undetermined case that is still important to handle
|
||||
else:
|
||||
logger.error(
|
||||
"Error in chat_completion_full_generator - cannot determine"
|
||||
" if tools should be extracted. Returning a standard chat "
|
||||
"completion.")
|
||||
message = ChatMessage(role=role, content=output.text)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=message,
|
||||
logprobs=logprobs,
|
||||
finish_reason="tool_calls" if auto_tools_called else
|
||||
output.finish_reason if output.finish_reason else "stop",
|
||||
stop_reason=output.stop_reason)
|
||||
choices.append(choice_data)
|
||||
|
||||
if request.echo or request.continue_final_message:
|
||||
last_msg_content = ""
|
||||
if conversation and "content" in conversation[-1] and conversation[
|
||||
-1].get("role") == role:
|
||||
last_msg_content = conversation[-1]["content"] or ""
|
||||
|
||||
for choice in choices:
|
||||
full_message = last_msg_content + (choice.message.content
|
||||
or "")
|
||||
choice.message.content = full_message
|
||||
|
||||
assert final_res.prompt_token_ids is not None
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
if final_res.encoder_prompt_token_ids is not None:
|
||||
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
|
||||
request_metadata.final_usage_info = usage
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
prompt_logprobs=final_res.prompt_logprobs,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _get_top_logprobs(
|
||||
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
|
||||
tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
|
||||
return [
|
||||
ChatCompletionLogProb(token=(token := self._get_decoded_token(
|
||||
p[1],
|
||||
p[0],
|
||||
tokenizer,
|
||||
return_as_token_id=self.return_tokens_as_token_ids)),
|
||||
logprob=max(p[1].logprob, -9999.0),
|
||||
bytes=list(
|
||||
token.encode("utf-8", errors="replace")))
|
||||
for i, p in enumerate(logprobs.items())
|
||||
if top_logprobs and i < top_logprobs
|
||||
]
|
||||
|
||||
def _create_chat_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
||||
tokenizer: AnyTokenizer,
|
||||
num_output_top_logprobs: Optional[int] = None,
|
||||
) -> ChatCompletionLogProbs:
|
||||
"""Create OpenAI-style logprobs."""
|
||||
logprobs_content: List[ChatCompletionLogProbsContent] = []
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
token = tokenizer.decode(token_id)
|
||||
if self.return_tokens_as_token_ids:
|
||||
token = f"token_id:{token_id}"
|
||||
|
||||
logprobs_content.append(
|
||||
ChatCompletionLogProbsContent(
|
||||
token=token,
|
||||
bytes=list(token.encode("utf-8", errors="replace")),
|
||||
))
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
step_decoded = step_token.decoded_token
|
||||
|
||||
logprobs_content.append(
|
||||
ChatCompletionLogProbsContent(
|
||||
token=self._get_decoded_token(
|
||||
step_token,
|
||||
token_id,
|
||||
tokenizer,
|
||||
self.return_tokens_as_token_ids,
|
||||
),
|
||||
logprob=max(step_token.logprob, -9999.0),
|
||||
bytes=None if step_decoded is None else list(
|
||||
step_decoded.encode("utf-8", errors="replace")),
|
||||
top_logprobs=self._get_top_logprobs(
|
||||
step_top_logprobs,
|
||||
num_output_top_logprobs,
|
||||
tokenizer,
|
||||
),
|
||||
))
|
||||
|
||||
return ChatCompletionLogProbs(content=logprobs_content)
|
||||
|
||||
def _should_stream_with_auto_tool_parsing(self,
|
||||
request: ChatCompletionRequest):
|
||||
"""
|
||||
Utility function to check if streamed tokens should go through the tool
|
||||
call parser that was configured.
|
||||
|
||||
We only want to do this IF user-provided tools are set, a tool parser
|
||||
is configured, "auto" tool choice is enabled, and the request's tool
|
||||
choice field indicates that "auto" tool choice should be used.
|
||||
"""
|
||||
return (request.tools and self.tool_parser and self.enable_auto_tools
|
||||
and request.tool_choice in ['auto', None])
|
||||
|
||||
def _should_check_for_unstreamed_tool_arg_tokens(
|
||||
self,
|
||||
delta_message: Optional[DeltaMessage],
|
||||
output: CompletionOutput,
|
||||
) -> bool:
|
||||
"""
|
||||
Check to see if we should check for unstreamed tool arguments tokens.
|
||||
This is only applicable when auto tool parsing is enabled, the delta
|
||||
is a tool call with arguments.
|
||||
"""
|
||||
|
||||
# yapf: disable
|
||||
return bool(
|
||||
# if there is a delta message that includes tool calls which
|
||||
# include a function that has arguments
|
||||
output.finish_reason is not None
|
||||
and self.enable_auto_tools and self.tool_parser and delta_message
|
||||
and delta_message.tool_calls and delta_message.tool_calls[0]
|
||||
and delta_message.tool_calls[0].function
|
||||
and delta_message.tool_calls[0].function.arguments is not None
|
||||
)
|
||||
554
vllm/entrypoints/openai/serving_completion.py
Normal file
554
vllm/entrypoints/openai/serving_completion.py
Normal file
@@ -0,0 +1,554 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
|
||||
Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple, Union, cast
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TypeTokenIDs = List[int]
|
||||
TypeTopLogProbs = List[Optional[Dict[int, float]]]
|
||||
TypeCreateLogProbsFn = Callable[
|
||||
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=prompt_adapters,
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
for the API specification. This API mimics the OpenAI Completion API.
|
||||
|
||||
NOTE: Currently we do not support the following feature:
|
||||
- suffix (the language models we currently support do not support
|
||||
suffix)
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.suffix is not None:
|
||||
return self.create_error_response(
|
||||
"suffix is not currently supported")
|
||||
|
||||
model_name = self.base_model_paths[0].name
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
created_time = int(time.time())
|
||||
|
||||
request_metadata = RequestResponseMetadata(request_id=request_id)
|
||||
if raw_request:
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
prompts = list(
|
||||
self._tokenize_prompt_input_or_inputs(
|
||||
request,
|
||||
tokenizer,
|
||||
request.prompt,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
))
|
||||
|
||||
for i, prompt_inputs in enumerate(prompts):
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
default_max_tokens = self.max_model_len - len(
|
||||
prompt_inputs["prompt_token_ids"])
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
default_max_tokens)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
prompt_inputs,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
is_tracing_enabled = (await
|
||||
self.engine_client.is_tracing_enabled())
|
||||
trace_headers = None
|
||||
if is_tracing_enabled:
|
||||
trace_headers = extract_trace_headers(raw_request.headers)
|
||||
if not is_tracing_enabled and contains_trace_headers(
|
||||
raw_request.headers):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
assert isinstance(self.engine_client,
|
||||
(AsyncLLMEngine,
|
||||
MQLLMEngineClient)), \
|
||||
"Beam search is only supported with" \
|
||||
"AsyncLLMEngine and MQLLMEngineClient."
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt_inputs["prompt_token_ids"],
|
||||
request_id_item,
|
||||
sampling_params,
|
||||
)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
{
|
||||
"prompt_token_ids":
|
||||
prompt_inputs["prompt_token_ids"]
|
||||
},
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(
|
||||
*generators, is_cancelled=raw_request.is_disconnected)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. In addition, we do not stream the results when use
|
||||
# beam search.
|
||||
stream = (request.stream
|
||||
and (request.best_of is None or request.n == request.best_of)
|
||||
and not request.use_beam_search)
|
||||
|
||||
# Streaming response
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
num_prompts=len(prompts),
|
||||
tokenizer=tokenizer,
|
||||
request_metadata=request_metadata)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
for i, final_res in enumerate(final_res_batch):
|
||||
assert final_res is not None
|
||||
|
||||
# The output should contain the input text
|
||||
# We did not pass it into vLLM engine to avoid being redundant
|
||||
# with the inputs token IDs
|
||||
if final_res.prompt is None:
|
||||
final_res.prompt = prompts[i]["prompt"]
|
||||
|
||||
final_res_batch_checked = cast(List[RequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_completion_response(
|
||||
final_res_batch_checked,
|
||||
request,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
tokenizer,
|
||||
request_metadata,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
if request.stream:
|
||||
response_json = response.model_dump_json()
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return fake_stream_generator()
|
||||
|
||||
return response
|
||||
|
||||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
num_prompts: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_text_lens = [0] * num_choices * num_prompts
|
||||
previous_num_tokens = [0] * num_choices * num_prompts
|
||||
has_echoed = [False] * num_choices * num_prompts
|
||||
num_prompt_tokens = [0] * num_prompts
|
||||
|
||||
try:
|
||||
async for prompt_idx, res in result_generator:
|
||||
prompt_token_ids = res.prompt_token_ids
|
||||
prompt_logprobs = res.prompt_logprobs
|
||||
prompt_text = res.prompt
|
||||
|
||||
# Prompt details are excluded from later streamed outputs
|
||||
if res.prompt_token_ids is not None:
|
||||
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
|
||||
|
||||
delta_token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[Dict[
|
||||
int, Logprob]]]]
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index + prompt_idx * num_choices
|
||||
# TODO(simon): optimize the performance by avoiding full
|
||||
# text O(n^2) sending.
|
||||
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and request.max_tokens == 0:
|
||||
assert prompt_token_ids is not None
|
||||
assert prompt_text is not None
|
||||
# only return the prompt
|
||||
delta_text = prompt_text
|
||||
delta_token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
has_echoed[i] = True
|
||||
elif (request.echo and request.max_tokens > 0
|
||||
and not has_echoed[i]):
|
||||
assert prompt_token_ids is not None
|
||||
assert prompt_text is not None
|
||||
assert prompt_logprobs is not None
|
||||
# echo the prompt and first token
|
||||
delta_text = prompt_text + output.text
|
||||
delta_token_ids = [
|
||||
*prompt_token_ids, *output.token_ids
|
||||
]
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*(output.logprobs or []),
|
||||
]
|
||||
has_echoed[i] = True
|
||||
else:
|
||||
# return just the delta
|
||||
delta_text = output.text
|
||||
delta_token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, (
|
||||
"Did not output logprobs")
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
initial_text_offset=previous_text_lens[i],
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
previous_text_lens[i] += len(output.text)
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
finish_reason = output.finish_reason
|
||||
stop_reason = output.stop_reason
|
||||
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
])
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
if (request.stream_options.continuous_usage_stats
|
||||
or output.finish_reason is not None):
|
||||
prompt_tokens = num_prompt_tokens[prompt_idx]
|
||||
completion_tokens = previous_num_tokens[i]
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
if request.stream_options.continuous_usage_stats:
|
||||
chunk.usage = usage
|
||||
else:
|
||||
chunk.usage = None
|
||||
|
||||
response_json = chunk.model_dump_json(exclude_unset=False)
|
||||
yield f"data: {response_json}\n\n"
|
||||
|
||||
if (request.stream_options
|
||||
and request.stream_options.include_usage):
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[],
|
||||
usage=usage,
|
||||
)
|
||||
final_usage_data = (final_usage_chunk.model_dump_json(
|
||||
exclude_unset=False, exclude_none=True))
|
||||
yield f"data: {final_usage_data}\n\n"
|
||||
|
||||
# report to FastAPI middleware aggregate usage across all choices
|
||||
total_prompt_tokens = sum(num_prompt_tokens)
|
||||
total_completion_tokens = sum(previous_num_tokens)
|
||||
request_metadata.final_usage_info = UsageInfo(
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=total_completion_tokens,
|
||||
total_tokens=total_prompt_tokens + total_completion_tokens)
|
||||
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
data = self.create_streaming_error_response(str(e))
|
||||
yield f"data: {data}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
def request_output_to_completion_response(
|
||||
self,
|
||||
final_res_batch: List[RequestOutput],
|
||||
request: CompletionRequest,
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> CompletionResponse:
|
||||
choices: List[CompletionResponseChoice] = []
|
||||
num_prompt_tokens = 0
|
||||
num_generated_tokens = 0
|
||||
|
||||
for final_res in final_res_batch:
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
assert prompt_token_ids is not None
|
||||
prompt_logprobs = final_res.prompt_logprobs
|
||||
prompt_text = final_res.prompt
|
||||
|
||||
token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
|
||||
Logprob]]]]
|
||||
|
||||
for output in final_res.outputs:
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and request.max_tokens == 0:
|
||||
assert prompt_text is not None
|
||||
token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
output_text = prompt_text
|
||||
elif request.echo and request.max_tokens > 0:
|
||||
assert prompt_text is not None
|
||||
token_ids = [*prompt_token_ids, *output.token_ids]
|
||||
|
||||
if request.logprobs is None:
|
||||
out_logprobs = None
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
assert output.logprobs is not None
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*output.logprobs,
|
||||
]
|
||||
|
||||
output_text = prompt_text + output.text
|
||||
else:
|
||||
token_ids = output.token_ids
|
||||
out_logprobs = output.logprobs
|
||||
output_text = output.text
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert out_logprobs is not None, "Did not output logprobs"
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=out_logprobs,
|
||||
tokenizer=tokenizer,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=len(choices),
|
||||
text=output_text,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
prompt_logprobs=final_res.prompt_logprobs,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_generated_tokens += len(output.token_ids)
|
||||
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
|
||||
request_metadata.final_usage_info = usage
|
||||
|
||||
return CompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _create_completion_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
||||
num_output_top_logprobs: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
initial_text_offset: int = 0,
|
||||
) -> CompletionLogProbs:
|
||||
"""Create logprobs for OpenAI Completion API."""
|
||||
out_text_offset: List[int] = []
|
||||
out_token_logprobs: List[Optional[float]] = []
|
||||
out_tokens: List[str] = []
|
||||
out_top_logprobs: List[Optional[Dict[str, float]]] = []
|
||||
|
||||
last_token_len = 0
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
token = tokenizer.decode(token_id)
|
||||
if self.return_tokens_as_token_ids:
|
||||
token = f"token_id:{token_id}"
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(None)
|
||||
out_top_logprobs.append(None)
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
|
||||
token = self._get_decoded_token(
|
||||
step_token,
|
||||
token_id,
|
||||
tokenizer,
|
||||
return_as_token_id=self.return_tokens_as_token_ids,
|
||||
)
|
||||
token_logprob = max(step_token.logprob, -9999.0)
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(token_logprob)
|
||||
|
||||
# makes sure to add the top num_output_top_logprobs + 1
|
||||
# logprobs, as defined in the openai API
|
||||
# (cf. https://github.com/openai/openai-openapi/blob/
|
||||
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
|
||||
out_top_logprobs.append({
|
||||
# Convert float("-inf") to the
|
||||
# JSON-serializable float that OpenAI uses
|
||||
self._get_decoded_token(
|
||||
top_lp[1],
|
||||
top_lp[0],
|
||||
tokenizer,
|
||||
return_as_token_id=self.return_tokens_as_token_ids):
|
||||
max(top_lp[1].logprob, -9999.0)
|
||||
for i, top_lp in enumerate(step_top_logprobs.items())
|
||||
if num_output_top_logprobs >= i
|
||||
})
|
||||
|
||||
if len(out_text_offset) == 0:
|
||||
out_text_offset.append(initial_text_offset)
|
||||
else:
|
||||
out_text_offset.append(out_text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
return CompletionLogProbs(
|
||||
text_offset=out_text_offset,
|
||||
token_logprobs=out_token_logprobs,
|
||||
tokens=out_tokens,
|
||||
top_logprobs=out_top_logprobs,
|
||||
)
|
||||
203
vllm/entrypoints/openai/serving_embedding.py
Normal file
203
vllm/entrypoints/openai/serving_embedding.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from typing import AsyncGenerator, List, Literal, Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TypeTokenIDs = List[int]
|
||||
|
||||
|
||||
def _get_embedding(
|
||||
output: EmbeddingOutput,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> Union[List[float], str]:
|
||||
if encoding_format == "float":
|
||||
return output.embedding
|
||||
elif encoding_format == "base64":
|
||||
# Force to use float32 for base64 encoding
|
||||
# to match the OpenAI python client behavior
|
||||
embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
|
||||
return base64.b64encode(embedding_bytes).decode("utf-8")
|
||||
|
||||
assert_never(encoding_format)
|
||||
|
||||
|
||||
def request_output_to_embedding_response(
|
||||
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
|
||||
created_time: int, model_name: str,
|
||||
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
|
||||
data: List[EmbeddingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
embedding = _get_embedding(final_res.outputs, encoding_format)
|
||||
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
|
||||
data.append(embedding_data)
|
||||
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
total_tokens=num_prompt_tokens,
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
data=data,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIServingEmbedding(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
prompt_adapters=None,
|
||||
request_logger=request_logger)
|
||||
self._enabled = self._check_embedding_mode(model_config.embedding_mode)
|
||||
|
||||
async def create_embedding(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[EmbeddingResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
"""
|
||||
if not self._enabled:
|
||||
return self.create_error_response("Embedding API disabled")
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
encoding_format = request.encoding_format
|
||||
if request.dimensions is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"embd-{random_uuid()}"
|
||||
created_time = int(time.monotonic())
|
||||
|
||||
truncate_prompt_tokens = None
|
||||
|
||||
if request.truncate_prompt_tokens is not None:
|
||||
if request.truncate_prompt_tokens <= self.max_model_len:
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
else:
|
||||
return self.create_error_response(
|
||||
"truncate_prompt_tokens value is "
|
||||
"greater than max_model_len."
|
||||
" Please, select a smaller truncation size.")
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
prompts = list(
|
||||
self._tokenize_prompt_input_or_inputs(request, tokenizer,
|
||||
request.input,
|
||||
truncate_prompt_tokens))
|
||||
|
||||
for i, prompt_inputs in enumerate(prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
prompt_inputs,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError(
|
||||
"Prompt adapter is not supported "
|
||||
"for embedding models")
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(
|
||||
*generators,
|
||||
is_cancelled=raw_request.is_disconnected if raw_request else None,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[EmbeddingRequestOutput]]
|
||||
final_res_batch = [None] * len(prompts)
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
for final_res in final_res_batch:
|
||||
assert final_res is not None
|
||||
|
||||
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = request_output_to_embedding_response(
|
||||
final_res_batch_checked, request_id, created_time, model_name,
|
||||
encoding_format)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return response
|
||||
|
||||
def _check_embedding_mode(self, embedding_mode: bool) -> bool:
|
||||
if not embedding_mode:
|
||||
logger.warning(
|
||||
"embedding_mode is False. Embedding API will not work.")
|
||||
else:
|
||||
logger.info("Activating the server engine with embedding enabled.")
|
||||
return embedding_mode
|
||||
487
vllm/entrypoints/openai/serving_engine.py
Normal file
487
vllm/entrypoints/openai/serving_engine.py
Normal file
@@ -0,0 +1,487 @@
|
||||
import json
|
||||
import pathlib
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
DetokenizeRequest,
|
||||
EmbeddingRequest, ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeRequest,
|
||||
UnloadLoraAdapterRequest)
|
||||
# yapf: enable
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import AtomicCounter
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelPath:
|
||||
name: str
|
||||
model_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptAdapterPath:
|
||||
name: str
|
||||
local_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAModulePath:
|
||||
name: str
|
||||
path: str
|
||||
base_model_name: Optional[str] = None
|
||||
|
||||
|
||||
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
|
||||
EmbeddingRequest, TokenizeRequest]
|
||||
|
||||
|
||||
class TextTokensPrompt(TypedDict):
|
||||
prompt: str
|
||||
prompt_token_ids: List[int]
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
self.max_model_len = model_config.max_model_len
|
||||
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
self.lora_requests = []
|
||||
if lora_modules is not None:
|
||||
self.lora_requests = [
|
||||
LoRARequest(lora_name=lora.name,
|
||||
lora_int_id=i,
|
||||
lora_path=lora.path,
|
||||
base_model_name=lora.base_model_name
|
||||
if lora.base_model_name
|
||||
and self._is_model_supported(lora.base_model_name)
|
||||
else self.base_model_paths[0].name)
|
||||
for i, lora in enumerate(lora_modules, start=1)
|
||||
]
|
||||
|
||||
self.prompt_adapter_requests = []
|
||||
if prompt_adapters is not None:
|
||||
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
|
||||
with pathlib.Path(prompt_adapter.local_path,
|
||||
"adapter_config.json").open() as f:
|
||||
adapter_config = json.load(f)
|
||||
num_virtual_tokens = adapter_config["num_virtual_tokens"]
|
||||
self.prompt_adapter_requests.append(
|
||||
PromptAdapterRequest(
|
||||
prompt_adapter_name=prompt_adapter.name,
|
||||
prompt_adapter_id=i,
|
||||
prompt_adapter_local_path=prompt_adapter.local_path,
|
||||
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
|
||||
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. Right now we only have one model."""
|
||||
model_cards = [
|
||||
ModelCard(id=base_model.name,
|
||||
max_model_len=self.max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()])
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
lora_cards = [
|
||||
ModelCard(id=lora.lora_name,
|
||||
root=lora.local_path,
|
||||
parent=lora.base_model_name if lora.base_model_name else
|
||||
self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for lora in self.lora_requests
|
||||
]
|
||||
prompt_adapter_cards = [
|
||||
ModelCard(id=prompt_adapter.prompt_adapter_name,
|
||||
root=self.base_model_paths[0].name,
|
||||
permission=[ModelPermission()])
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
model_cards.extend(prompt_adapter_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
|
||||
return ErrorResponse(message=message,
|
||||
type=err_type,
|
||||
code=status_code.value)
|
||||
|
||||
def create_streaming_error_response(
|
||||
self,
|
||||
message: str,
|
||||
err_type: str = "BadRequestError",
|
||||
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
|
||||
json_str = json.dumps({
|
||||
"error":
|
||||
self.create_error_response(message=message,
|
||||
err_type=err_type,
|
||||
status_code=status_code).model_dump()
|
||||
})
|
||||
return json_str
|
||||
|
||||
async def _check_model(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
) -> Optional[ErrorResponse]:
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in [lora.lora_name for lora in self.lora_requests]:
|
||||
return None
|
||||
if request.model in [
|
||||
prompt_adapter.prompt_adapter_name
|
||||
for prompt_adapter in self.prompt_adapter_requests
|
||||
]:
|
||||
return None
|
||||
return self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
def _maybe_get_adapters(
|
||||
self, request: AnyRequest
|
||||
) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
|
||||
None, PromptAdapterRequest]]:
|
||||
if self._is_model_supported(request.model):
|
||||
return None, None
|
||||
for lora in self.lora_requests:
|
||||
if request.model == lora.lora_name:
|
||||
return lora, None
|
||||
for prompt_adapter in self.prompt_adapter_requests:
|
||||
if request.model == prompt_adapter.prompt_adapter_name:
|
||||
return None, prompt_adapter
|
||||
# if _check_model has been called earlier, this will be unreachable
|
||||
raise ValueError(f"The model `{request.model}` does not exist.")
|
||||
|
||||
def _normalize_prompt_text_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt: str,
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
|
||||
add_special_tokens: bool,
|
||||
) -> TextTokensPrompt:
|
||||
if truncate_prompt_tokens is None:
|
||||
encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
|
||||
else:
|
||||
encoded = tokenizer(prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=True,
|
||||
max_length=truncate_prompt_tokens)
|
||||
|
||||
input_ids = encoded.input_ids
|
||||
|
||||
input_text = prompt
|
||||
|
||||
return self._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _normalize_prompt_tokens_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_ids: List[int],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
|
||||
) -> TextTokensPrompt:
|
||||
if truncate_prompt_tokens is None:
|
||||
input_ids = prompt_ids
|
||||
else:
|
||||
input_ids = prompt_ids[-truncate_prompt_tokens:]
|
||||
|
||||
input_text = tokenizer.decode(input_ids)
|
||||
|
||||
return self._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
input_ids: List[int],
|
||||
input_text: str,
|
||||
) -> TextTokensPrompt:
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest doesn't have max_tokens
|
||||
if isinstance(request, EmbeddingRequest):
|
||||
if token_num > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the input for embedding "
|
||||
f"generation. Please reduce the length of the input.")
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
|
||||
# and does not require model context length validation
|
||||
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
|
||||
DetokenizeRequest)):
|
||||
return TextTokensPrompt(prompt=input_text,
|
||||
prompt_token_ids=input_ids)
|
||||
|
||||
if request.max_tokens is None:
|
||||
if token_num >= self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the messages, "
|
||||
f"Please reduce the length of the messages.")
|
||||
elif token_num + request.max_tokens > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"This model's maximum context length is "
|
||||
f"{self.max_model_len} tokens. However, you requested "
|
||||
f"{request.max_tokens + token_num} tokens "
|
||||
f"({token_num} in the messages, "
|
||||
f"{request.max_tokens} in the completion). "
|
||||
f"Please reduce the length of the messages or completion.")
|
||||
|
||||
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
def _tokenize_prompt_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_input: Union[str, List[int]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> TextTokensPrompt:
|
||||
"""
|
||||
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
|
||||
that assumes single input.
|
||||
"""
|
||||
return next(
|
||||
self._tokenize_prompt_inputs(
|
||||
request,
|
||||
tokenizer,
|
||||
[prompt_input],
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
))
|
||||
|
||||
def _tokenize_prompt_inputs(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt_inputs: Iterable[Union[str, List[int]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> Iterator[TextTokensPrompt]:
|
||||
"""
|
||||
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
|
||||
that assumes multiple inputs.
|
||||
"""
|
||||
for text in prompt_inputs:
|
||||
if isinstance(text, str):
|
||||
yield self._normalize_prompt_text_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt=text,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
yield self._normalize_prompt_tokens_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt_ids=text,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
def _tokenize_prompt_input_or_inputs(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: AnyTokenizer,
|
||||
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
|
||||
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
) -> Iterator[TextTokensPrompt]:
|
||||
"""
|
||||
Tokenize/detokenize depending on the input format.
|
||||
|
||||
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
|
||||
, each input can be a string or array of tokens. Note that each request
|
||||
can pass one or more inputs.
|
||||
"""
|
||||
for prompt_input in parse_and_batch_prompt(input_or_inputs):
|
||||
# Although our type checking is based on mypy,
|
||||
# VSCode Pyright extension should still work properly
|
||||
# "is True" is required for Pyright to perform type narrowing
|
||||
# See: https://github.com/microsoft/pyright/issues/7672
|
||||
if prompt_input["is_tokens"] is False:
|
||||
yield self._normalize_prompt_text_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt=prompt_input["content"],
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
yield self._normalize_prompt_tokens_to_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt_ids=prompt_input["content"],
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: Union[str, List[int], TextTokensPrompt],
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> None:
|
||||
if self.request_logger is None:
|
||||
return
|
||||
|
||||
if isinstance(inputs, str):
|
||||
prompt = inputs
|
||||
prompt_token_ids = None
|
||||
elif isinstance(inputs, list):
|
||||
prompt = None
|
||||
prompt_token_ids = inputs
|
||||
else:
|
||||
prompt = inputs["prompt"]
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
self.request_logger.log_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
prompt_token_ids,
|
||||
params=params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_decoded_token(logprob: Logprob,
|
||||
token_id: int,
|
||||
tokenizer: AnyTokenizer,
|
||||
return_as_token_id: bool = False) -> str:
|
||||
if return_as_token_id:
|
||||
return f"token_id:{token_id}"
|
||||
|
||||
if logprob.decoded_token is not None:
|
||||
return logprob.decoded_token
|
||||
return tokenizer.decode(token_id)
|
||||
|
||||
async def _check_load_lora_adapter_request(
|
||||
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if both 'lora_name' and 'lora_path' are provided
|
||||
if not request.lora_name or not request.lora_path:
|
||||
return self.create_error_response(
|
||||
message="Both 'lora_name' and 'lora_path' must be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name already exists
|
||||
if any(lora_request.lora_name == request.lora_name
|
||||
for lora_request in self.lora_requests):
|
||||
return self.create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' has already been"
|
||||
"loaded.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
async def _check_unload_lora_adapter_request(
|
||||
self,
|
||||
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
|
||||
# Check if either 'lora_name' or 'lora_int_id' is provided
|
||||
if not request.lora_name and not request.lora_int_id:
|
||||
return self.create_error_response(
|
||||
message=
|
||||
"either 'lora_name' and 'lora_int_id' needs to be provided.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Check if the lora adapter with the given name exists
|
||||
if not any(lora_request.lora_name == request.lora_name
|
||||
for lora_request in self.lora_requests):
|
||||
return self.create_error_response(
|
||||
message=
|
||||
f"The lora adapter '{request.lora_name}' cannot be found.",
|
||||
err_type="InvalidUserInput",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
return None
|
||||
|
||||
async def load_lora_adapter(
|
||||
self,
|
||||
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
|
||||
error_check_ret = await self._check_load_lora_adapter_request(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_name, lora_path = request.lora_name, request.lora_path
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
self.lora_requests.append(
|
||||
LoRARequest(lora_name=lora_name,
|
||||
lora_int_id=unique_id,
|
||||
lora_path=lora_path))
|
||||
return f"Success: LoRA adapter '{lora_name}' added successfully."
|
||||
|
||||
async def unload_lora_adapter(
|
||||
self,
|
||||
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
|
||||
error_check_ret = await self._check_unload_lora_adapter_request(request
|
||||
)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
lora_name = request.lora_name
|
||||
self.lora_requests = [
|
||||
lora_request for lora_request in self.lora_requests
|
||||
if lora_request.lora_name != lora_name
|
||||
]
|
||||
return f"Success: LoRA adapter '{lora_name}' removed successfully."
|
||||
|
||||
def _is_model_supported(self, model_name):
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
157
vllm/entrypoints/openai/serving_tokenization.py
Normal file
157
vllm/entrypoints/openai/serving_tokenization.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
load_chat_template,
|
||||
parse_chat_messages_futures)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
ErrorResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: List[BaseModelPath],
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=lora_modules,
|
||||
prompt_adapters=None,
|
||||
request_logger=request_logger)
|
||||
|
||||
# If this is None we use the tokenizer's default chat template
|
||||
# the list of commonly-used chat template names for HF named templates
|
||||
hf_chat_templates: List[str] = ['default', 'tool_use']
|
||||
self.chat_template = chat_template \
|
||||
if chat_template in hf_chat_templates \
|
||||
else load_chat_template(chat_template)
|
||||
|
||||
async def create_tokenize(
|
||||
self,
|
||||
request: TokenizeRequest,
|
||||
) -> Union[TokenizeResponse, ErrorResponse]:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"tokn-{random_uuid()}"
|
||||
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
prompt: Union[str, List[int]]
|
||||
if isinstance(request, TokenizeChatRequest):
|
||||
model_config = self.model_config
|
||||
|
||||
conversation, mm_data_future = parse_chat_messages_futures(
|
||||
request.messages, model_config, tokenizer)
|
||||
|
||||
mm_data = await mm_data_future
|
||||
if mm_data:
|
||||
logger.warning(
|
||||
"Multi-modal inputs are ignored during tokenization")
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=request.messages,
|
||||
chat_template=self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
)
|
||||
else:
|
||||
prompt = apply_hf_chat_template(
|
||||
tokenizer,
|
||||
conversation=conversation,
|
||||
chat_template=self.chat_template,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
)
|
||||
else:
|
||||
prompt = request.prompt
|
||||
|
||||
self._log_inputs(request_id,
|
||||
prompt,
|
||||
params=None,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
# Silently ignore prompt adapter since it does not affect tokenization
|
||||
|
||||
prompt_input = self._tokenize_prompt_input(
|
||||
request,
|
||||
tokenizer,
|
||||
prompt,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
input_ids = prompt_input["prompt_token_ids"]
|
||||
|
||||
return TokenizeResponse(tokens=input_ids,
|
||||
count=len(input_ids),
|
||||
max_model_len=self.max_model_len)
|
||||
|
||||
async def create_detokenize(
|
||||
self,
|
||||
request: DetokenizeRequest,
|
||||
) -> Union[DetokenizeResponse, ErrorResponse]:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"tokn-{random_uuid()}"
|
||||
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
self._log_inputs(request_id,
|
||||
request.tokens,
|
||||
params=None,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for tokenization")
|
||||
|
||||
prompt_input = self._tokenize_prompt_input(
|
||||
request,
|
||||
tokenizer,
|
||||
request.tokens,
|
||||
)
|
||||
input_text = prompt_input["prompt"]
|
||||
|
||||
return DetokenizeResponse(prompt=input_text)
|
||||
10
vllm/entrypoints/openai/tool_parsers/__init__.py
Normal file
10
vllm/entrypoints/openai/tool_parsers/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
from .internlm2_tool_parser import Internlm2ToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
|
||||
__all__ = [
|
||||
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
|
||||
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
161
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
Normal file
161
vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ToolParser:
|
||||
"""
|
||||
Abstract ToolParser class that should not be used directly. Provided
|
||||
properties and methods should be used in
|
||||
derived classes.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
# the index of the tool call that is currently being parsed
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: List[str] = []
|
||||
|
||||
self.model_tokenizer = tokenizer
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> Dict[str, int]:
|
||||
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
|
||||
# whereas all tokenizers have .get_vocab()
|
||||
return self.model_tokenizer.get_vocab()
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
"""
|
||||
Static method that used to adjust the request parameters.
|
||||
"""
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Static method that should be implemented for extracting tool calls from
|
||||
a complete model-generated string.
|
||||
Used for non-streaming responses where we have the entire model response
|
||||
available before sending to the client.
|
||||
Static because it's stateless.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"AbstractToolParser.extract_tool_calls has not been implemented!")
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Instance method that should be implemented for extracting tool calls
|
||||
from an incomplete response; for use when handling tool calls and
|
||||
streaming. Has to be an instance method because it requires state -
|
||||
the current tokens/diffs, but also the information about what has
|
||||
previously been parsed and extracted (see constructor)
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"AbstractToolParser.extract_tool_calls_streaming has not been "
|
||||
"implemented!")
|
||||
|
||||
|
||||
class ToolParserManager:
|
||||
tool_parsers: Dict[str, Type] = {}
|
||||
|
||||
@classmethod
|
||||
def get_tool_parser(cls, name) -> Type:
|
||||
"""
|
||||
Get tool parser by name which is registered by `register_module`.
|
||||
|
||||
Raise a KeyError exception if the name is not registered.
|
||||
"""
|
||||
if name in cls.tool_parsers:
|
||||
return cls.tool_parsers[name]
|
||||
|
||||
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
|
||||
|
||||
@classmethod
|
||||
def _register_module(cls,
|
||||
module: Type,
|
||||
module_name: Optional[Union[str, List[str]]] = None,
|
||||
force: bool = True) -> None:
|
||||
if not issubclass(module, ToolParser):
|
||||
raise TypeError(
|
||||
f'module must be subclass of ToolParser, but got {type(module)}'
|
||||
)
|
||||
if module_name is None:
|
||||
module_name = module.__name__
|
||||
if isinstance(module_name, str):
|
||||
module_name = [module_name]
|
||||
for name in module_name:
|
||||
if not force and name in cls.tool_parsers:
|
||||
existed_module = cls.tool_parsers[name]
|
||||
raise KeyError(f'{name} is already registered '
|
||||
f'at {existed_module.__module__}')
|
||||
cls.tool_parsers[name] = module
|
||||
|
||||
@classmethod
|
||||
def register_module(
|
||||
cls,
|
||||
name: Optional[Union[str, List[str]]] = None,
|
||||
force: bool = True,
|
||||
module: Union[Type, None] = None) -> Union[type, Callable]:
|
||||
"""
|
||||
Register module with the given name or name list. it can be used as a
|
||||
decoder(with module as None) or normal function(with module as not
|
||||
None).
|
||||
"""
|
||||
if not isinstance(force, bool):
|
||||
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||
|
||||
# raise the error ahead of time
|
||||
if not (name is None or isinstance(name, str)
|
||||
or is_list_of(name, str)):
|
||||
raise TypeError(
|
||||
'name must be None, an instance of str, or a sequence of str, '
|
||||
f'but got {type(name)}')
|
||||
|
||||
# use it as a normal method: x.register_module(module=SomeClass)
|
||||
if module is not None:
|
||||
cls._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
# use it as a decorator: @x.register_module()
|
||||
def _register(module):
|
||||
cls._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
return _register
|
||||
|
||||
@classmethod
|
||||
def import_tool_parser(cls, plugin_path: str) -> None:
|
||||
"""
|
||||
Import a user defined tool parser by the path of the tool parser define
|
||||
file.
|
||||
"""
|
||||
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
|
||||
spec = importlib.util.spec_from_file_location(module_name, plugin_path)
|
||||
if spec is None or spec.loader is None:
|
||||
logger.error("load %s from %s failed.", module_name, plugin_path)
|
||||
return
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
338
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Normal file
338
vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Normal file
@@ -0,0 +1,338 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("hermes")
|
||||
class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
logger.error(
|
||||
"Detected Mistral tokenizer when using a Hermes model")
|
||||
self.model_tokenizer = self.model_tokenizer.tokenizer
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: List[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
|
||||
self.scratch_pad_regex = re.compile(
|
||||
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
|
||||
raise RuntimeError(
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_call_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = (
|
||||
self.tool_call_regex.findall(model_output))
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
raw_function_calls = [
|
||||
json.loads(match[0] if match[0] else match[1])
|
||||
for match in function_call_tuples
|
||||
]
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"])))
|
||||
for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_call_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in extracting tool call from response %s",
|
||||
e)
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
if delta_text != self.tool_call_end_token:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# case: if tool open & close tag counts don't match, we're doing
|
||||
# imaginary "else" block here
|
||||
# something with tools with this diff.
|
||||
# flags for partial JSON parting. exported constants from
|
||||
# "Allow" are handled via BIT MASK
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count > prev_tool_end_count):
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = json.dumps(diff).replace(
|
||||
self.streamed_args_for_tool[self.current_tool_id], "")
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s", diff)
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
try:
|
||||
|
||||
current_tool_call = partial_json_parser.loads(
|
||||
tool_call_portion or "{}",
|
||||
flags) if tool_call_portion else None
|
||||
logger.debug("Parsed tool call %s", current_tool_call)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
else:
|
||||
return None
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = DeltaMessage(content=delta_text) \
|
||||
if text_portion is not None else None
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = (
|
||||
self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
cur_arguments_json = json.dumps(cur_arguments)
|
||||
logger.debug("finding %s in %s", delta_text,
|
||||
cur_arguments_json)
|
||||
|
||||
# get the location where previous args differ from current
|
||||
args_delta_start_loc = cur_arguments_json.index(delta_text) \
|
||||
+ len(delta_text)
|
||||
|
||||
# use that to find the actual delta
|
||||
arguments_delta = cur_arguments_json[:args_delta_start_loc]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= arguments_delta
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
logger.debug("Searching for diff between\n%s", cur_args_json)
|
||||
logger.debug("and\n%s", prev_args_json)
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
logger.debug("got argument diff %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= argument_diff
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[self.current_tool_id] = \
|
||||
current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error trying to handle streaming tool call: %s", e)
|
||||
return None # do not stream a delta. skip this token ID.
|
||||
208
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
208
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import json
|
||||
from typing import Dict, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module(["internlm"])
|
||||
class Internlm2ToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.position = 0
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
# do not skip special tokens because internlm use the special
|
||||
# tokens to indicated the start and end of the tool calls
|
||||
# information.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def get_argments(self, obj):
|
||||
if "parameters" in obj:
|
||||
return obj.get("parameters")
|
||||
elif "arguments" in obj:
|
||||
return obj.get("arguments")
|
||||
return None
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
if '<|action_start|>' not in current_text:
|
||||
self.position = len(current_text)
|
||||
return DeltaMessage(content=delta_text)
|
||||
# if the tool call is sended, return a empty delta message
|
||||
# to make sure the finish_reason will be send correctly.
|
||||
if self.current_tool_id > 0:
|
||||
return DeltaMessage(content='')
|
||||
|
||||
last_pos = self.position
|
||||
if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
|
||||
return None
|
||||
|
||||
new_delta = current_text[last_pos:]
|
||||
text, action = new_delta.split('<|action_start|><|plugin|>')
|
||||
|
||||
if len(text) > 0:
|
||||
self.position = self.position + len(text)
|
||||
return DeltaMessage(content=text)
|
||||
|
||||
action = action.strip()
|
||||
action = action.split('<|action_end|>'.strip())[0]
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
|
||||
try:
|
||||
parsable_arr = action
|
||||
|
||||
# tool calls are generated in an object in inernlm2
|
||||
# it's not support parallel tool calls
|
||||
try:
|
||||
tool_call_arr: Dict = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
if not self.current_tool_name_sent:
|
||||
function_name = tool_call_arr.get("name")
|
||||
if function_name:
|
||||
self.current_tool_id = self.current_tool_id + 1
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
self.streamed_args_for_tool.append("")
|
||||
else:
|
||||
delta = None
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
prev_arguments = self.get_argments(
|
||||
self.prev_tool_call_arr[self.current_tool_id])
|
||||
cur_arguments = self.get_argments(tool_call_arr)
|
||||
|
||||
# not arguments generated
|
||||
if not cur_arguments and not prev_arguments:
|
||||
delta = None
|
||||
# will never happen
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error(
|
||||
"INVARIANT - impossible to have arguments reset "
|
||||
"mid-arguments")
|
||||
delta = None
|
||||
# first time to get parameters
|
||||
elif cur_arguments and not prev_arguments:
|
||||
cur_arguments_json = json.dumps(cur_arguments)
|
||||
|
||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||
index(delta_text) +
|
||||
len(delta_text)]
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += arguments_delta
|
||||
# both prev and cur parameters, send the increase parameters
|
||||
elif cur_arguments and prev_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
# check to see if the name is defined and has been sent. if so,
|
||||
# stream the name - otherwise keep waiting
|
||||
# finish by setting old and returning None as base case
|
||||
tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
|
||||
self.prev_tool_call_arr = [tool_call_arr]
|
||||
return delta
|
||||
except Exception as e:
|
||||
logger.error("Error trying to handle streaming tool call: %s", e)
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
text = model_output
|
||||
tools = request.tools
|
||||
if '<|action_start|><|plugin|>' in text:
|
||||
text, action = text.split('<|action_start|><|plugin|>')
|
||||
action = action.split('<|action_end|>'.strip())[0]
|
||||
action = action[action.find('{'):]
|
||||
action_dict = json.loads(action)
|
||||
name, parameters = action_dict['name'], json.dumps(
|
||||
action_dict.get('parameters', action_dict.get('arguments',
|
||||
{})))
|
||||
|
||||
if not tools or name not in [t.function.name for t in tools]:
|
||||
ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=text)
|
||||
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
function=FunctionCall(name=name, arguments=parameters))
|
||||
]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=text if len(text) > 0 else None)
|
||||
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=text)
|
||||
277
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
Normal file
277
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
Normal file
@@ -0,0 +1,277 @@
|
||||
import json
|
||||
import re
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# partial_json_parser doesn't support extra data and
|
||||
# JSONDecorder.raw_decode doesn't support partial JSON
|
||||
def partial_json_loads(input_str, flags):
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def is_complete_json(input_str):
|
||||
try:
|
||||
json.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
@ToolParserManager.register_module("llama3_json")
|
||||
class Llama3JsonToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Llama 3.1 models intended for use with the
|
||||
examples/tool_chat_template_llama.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: List[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "<|python_tag|>"
|
||||
self.bot_token_id = tokenizer.encode(self.bot_token,
|
||||
add_special_tokens=False)[0]
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
# case -- if a tool call token is not present, return a text response
|
||||
if not (model_output.startswith(self.bot_token)
|
||||
or model_output.startswith('{')):
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
dec = JSONDecoder()
|
||||
function_call_arr = []
|
||||
|
||||
# depending on the prompt format the Llama model may or may not
|
||||
# prefix the output with the <|python_tag|> token
|
||||
start_idx = len(self.bot_token) if model_output.startswith(
|
||||
self.bot_token) else 0
|
||||
while start_idx < len(model_output):
|
||||
(obj, end_idx) = dec.raw_decode(model_output[start_idx:])
|
||||
start_idx += end_idx + len('; ')
|
||||
function_call_arr.append(obj)
|
||||
|
||||
tool_calls: List[ToolCall] = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(raw_function_call["arguments"] \
|
||||
if "arguments" in raw_function_call \
|
||||
else raw_function_call["parameters"])))
|
||||
for raw_function_call in function_call_arr
|
||||
]
|
||||
|
||||
# get any content before the tool call
|
||||
ret = ExtractedToolCallInformation(tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=None)
|
||||
return ret
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in extracting tool call from response: %s", e)
|
||||
print("ERROR", e)
|
||||
# return information to just treat the tool call as regular JSON
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if not (current_text.startswith(self.bot_token)
|
||||
or current_text.startswith('{')):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
tool_call_arr = []
|
||||
is_complete = []
|
||||
try:
|
||||
# depending on the prompt format the Llama model may or may not
|
||||
# prefix the output with the <|python_tag|> token
|
||||
start_idx = len(self.bot_token) if current_text.startswith(
|
||||
self.bot_token) else 0
|
||||
while start_idx < len(current_text):
|
||||
(obj,
|
||||
end_idx) = partial_json_loads(current_text[start_idx:],
|
||||
flags)
|
||||
is_complete.append(
|
||||
is_complete_json(current_text[start_idx:start_idx +
|
||||
end_idx]))
|
||||
start_idx += end_idx + len('; ')
|
||||
# depending on the prompt Llama can use
|
||||
# either arguments or parameters
|
||||
if "parameters" in obj:
|
||||
assert "arguments" not in obj, \
|
||||
"model generated both parameters and arguments"
|
||||
obj["arguments"] = obj["parameters"]
|
||||
tool_call_arr.append(obj)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
if cur_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
argument_diff = cur_args_json[sent:]
|
||||
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
delta = None
|
||||
|
||||
if cur_arguments:
|
||||
sent = len(
|
||||
self.streamed_args_for_tool[self.current_tool_id])
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
|
||||
argument_diff = None
|
||||
if is_complete[self.current_tool_id]:
|
||||
argument_diff = cur_args_json[sent:]
|
||||
elif prev_arguments:
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
if cur_args_json != prev_args_json:
|
||||
|
||||
prefix = find_common_prefix(
|
||||
prev_args_json, cur_args_json)
|
||||
argument_diff = prefix[sent:]
|
||||
|
||||
if argument_diff is not None:
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error trying to handle streaming tool call: %s", e)
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
306
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Normal file
306
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import json
|
||||
import re
|
||||
from random import choices
|
||||
from string import ascii_letters, digits
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ALPHANUMERIC = ascii_letters + digits
|
||||
|
||||
|
||||
class MistralToolCall(ToolCall):
|
||||
id: str = Field(
|
||||
default_factory=lambda: MistralToolCall.generate_random_id())
|
||||
|
||||
@staticmethod
|
||||
def generate_random_id():
|
||||
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
|
||||
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
|
||||
return "".join(choices(ALPHANUMERIC, k=9))
|
||||
|
||||
|
||||
@ToolParserManager.register_module("mistral")
|
||||
class MistralToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
|
||||
examples/tool_chat_template_mistral.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if not isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
logger.info("Non-Mistral tokenizer detected when using a Mistral "
|
||||
"model...")
|
||||
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: List[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
if not self.bot_token_id:
|
||||
raise RuntimeError(
|
||||
"Mistral Tool Parser could not locate the tool call token in "
|
||||
"the tokenizer!")
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response. Requires
|
||||
find-and-replacing single quotes with double quotes for JSON parsing,
|
||||
make sure your tool call arguments don't ever include quotes!
|
||||
"""
|
||||
|
||||
# case -- if a tool call token is not present, return a text response
|
||||
if self.bot_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
try:
|
||||
|
||||
# use a regex to find the tool call. remove the BOT token
|
||||
# and make sure to replace single quotes with double quotes
|
||||
raw_tool_call = self.tool_call_regex.findall(
|
||||
model_output.replace(self.bot_token, ""))[0]
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
tool_calls: List[MistralToolCall] = [
|
||||
MistralToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(raw_function_call["arguments"])))
|
||||
for raw_function_call in function_call_arr
|
||||
]
|
||||
|
||||
# get any content before the tool call
|
||||
content = model_output.split(self.bot_token)[0]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if len(content) > 0 else None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in extracting tool call from response: %s", e)
|
||||
# return information to just treat the tool call as regular JSON
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# if the tool call token is not in the tokens generated so far, append
|
||||
# output to contents since it's not a tool
|
||||
if self.bot_token not in current_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# if the tool call token ID IS in the tokens generated so far, that
|
||||
# means we're parsing as tool calls now
|
||||
|
||||
# handle if we detected the BOT token which means the start of tool
|
||||
# calling
|
||||
if (self.bot_token_id in delta_token_ids
|
||||
and len(delta_token_ids) == 1):
|
||||
# if it's the only token, return None, so we don't send a chat
|
||||
# completion any don't send a control token
|
||||
return None
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
|
||||
# replace BOT token with empty string, and convert single quotes
|
||||
# to double to allow parsing as JSON since mistral uses single
|
||||
# quotes instead of double for tool calls
|
||||
parsable_arr = current_text.split(self.bot_token)[-1]
|
||||
|
||||
# tool calls are generated in an array, so do partial JSON
|
||||
# parsing on the entire array
|
||||
try:
|
||||
tool_call_arr: List[Dict] = partial_json_parser.loads(
|
||||
parsable_arr, flags)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
return None
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
|
||||
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
|
||||
if len(tool_call_arr) > 0 else {}
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return None
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (len(tool_call_arr) > 0
|
||||
and len(tool_call_arr) > self.current_tool_id + 1):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
diff: Union[str, None] = current_tool_call.get("arguments")
|
||||
|
||||
if diff:
|
||||
diff = json.dumps(diff).replace(
|
||||
self.streamed_args_for_tool[self.current_tool_id],
|
||||
"")
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += diff
|
||||
else:
|
||||
delta = None
|
||||
else:
|
||||
delta = None
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("starting on new tool %d", self.current_tool_id)
|
||||
return delta
|
||||
|
||||
# case: update an existing tool - this is handled below
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
if not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=f"chatcmpl-tool-{random_uuid()}",
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
|
||||
prev_arguments = self.prev_tool_call_arr[
|
||||
self.current_tool_id].get("arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
new_text = delta_text.replace("\'", "\"")
|
||||
|
||||
if not cur_arguments and not prev_arguments:
|
||||
|
||||
delta = None
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error(
|
||||
"INVARIANT - impossible to have arguments reset "
|
||||
"mid-arguments")
|
||||
delta = None
|
||||
elif cur_arguments and not prev_arguments:
|
||||
cur_arguments_json = json.dumps(cur_arguments)
|
||||
logger.debug("finding %s in %s", new_text,
|
||||
cur_arguments_json)
|
||||
|
||||
arguments_delta = cur_arguments_json[:cur_arguments_json.
|
||||
index(new_text) +
|
||||
len(new_text)]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).
|
||||
model_dump(exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += arguments_delta
|
||||
|
||||
elif cur_arguments and prev_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
logger.debug("Searching for diff between \n%s\n%s",
|
||||
cur_args_json, prev_args_json)
|
||||
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
logger.debug("got arguments diff: %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] += argument_diff
|
||||
else:
|
||||
# try parsing it with regular JSON - if it works we're
|
||||
# at the end, and we need to send the difference between
|
||||
# tokens streamed so far and the valid JSON
|
||||
delta = None
|
||||
|
||||
# check to see if the name is defined and has been sent. if so,
|
||||
# stream the name - otherwise keep waiting
|
||||
# finish by setting old and returning None as base case
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error trying to handle streaming tool call: %s", e)
|
||||
logger.debug(
|
||||
"Skipping chunk as a result of tool streaming extraction "
|
||||
"error")
|
||||
return None
|
||||
87
vllm/entrypoints/openai/tool_parsers/utils.py
Normal file
87
vllm/entrypoints/openai/tool_parsers/utils.py
Normal file
@@ -0,0 +1,87 @@
|
||||
def find_common_prefix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common prefix that is shared between two strings, if there is one.
|
||||
Order of arguments is NOT important.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely.
|
||||
|
||||
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
|
||||
'{"fruit": "ap'
|
||||
"""
|
||||
prefix = ''
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(0, min_length):
|
||||
if s1[i] == s2[i]:
|
||||
prefix += s1[i]
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def find_common_suffix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common suffix shared between two strings, if there is one. Order of
|
||||
arguments is NOT important.
|
||||
Stops when the suffix ends OR it hits an alphanumeric character
|
||||
|
||||
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
|
||||
"""
|
||||
suffix = ''
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(1, min_length + 1):
|
||||
if s1[-i] == s2[-i] and not s1[-i].isalnum():
|
||||
suffix = s1[-i] + suffix
|
||||
else:
|
||||
break
|
||||
return suffix
|
||||
|
||||
|
||||
def extract_intermediate_diff(curr: str, old: str) -> str:
|
||||
"""
|
||||
Given two strings, extract the difference in the middle between two strings
|
||||
that are known to have a common prefix and/or suffix.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely. The order of arguments IS
|
||||
important - the new version of the partially-parsed JSON must be the first
|
||||
argument, and the secnod argument must be from the previous generation.
|
||||
|
||||
What it returns, is tokens that should be streamed to the client.
|
||||
|
||||
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
|
||||
-> 'ple'
|
||||
|
||||
"""
|
||||
suffix = find_common_suffix(curr, old)
|
||||
|
||||
old = old[::-1].replace(suffix[::-1], '', 1)[::-1]
|
||||
prefix = find_common_prefix(curr, old)
|
||||
diff = curr
|
||||
if len(suffix):
|
||||
diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]
|
||||
|
||||
if len(prefix):
|
||||
# replace the prefix only once in case it's mirrored
|
||||
diff = diff.replace(prefix, '', 1)
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
def find_all_indices(string, substring):
|
||||
"""
|
||||
Find all (starting) indices of a substring in a given string. Useful for
|
||||
tool call extraction
|
||||
"""
|
||||
indices = []
|
||||
index = -1
|
||||
while True:
|
||||
index = string.find(substring, index + 1)
|
||||
if index == -1:
|
||||
break
|
||||
indices.append(index)
|
||||
return indices
|
||||
Reference in New Issue
Block a user