Add reasoning parser mechanism + qwen3 parser + bugfixes
This commit is contained in:
595
qwen3_6_scripts/api_server.py
Normal file
595
qwen3_6_scripts/api_server.py
Normal file
@@ -0,0 +1,595 @@
|
|||||||
|
import asyncio
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import regex as re
|
||||||
|
import signal
|
||||||
|
import socket
|
||||||
|
import tempfile
|
||||||
|
from argparse import Namespace
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from functools import partial
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import AsyncIterator, Set
|
||||||
|
|
||||||
|
import uvloop
|
||||||
|
from fastapi import APIRouter, FastAPI, Request
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
|
from starlette.datastructures import State
|
||||||
|
from starlette.routing import Mount
|
||||||
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||||
|
from vllm.engine.multiprocessing.engine import run_mp_engine
|
||||||
|
from vllm.engine.protocol import EngineClient
|
||||||
|
from vllm.entrypoints.launcher import serve_http
|
||||||
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
|
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||||
|
validate_parsed_serve_args)
|
||||||
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
CompletionRequest,
|
||||||
|
CompletionResponse,
|
||||||
|
DetokenizeRequest,
|
||||||
|
DetokenizeResponse,
|
||||||
|
EmbeddingRequest,
|
||||||
|
EmbeddingResponse, ErrorResponse,
|
||||||
|
LoadLoraAdapterRequest,
|
||||||
|
TokenizeRequest,
|
||||||
|
TokenizeResponse,
|
||||||
|
UnloadLoraAdapterRequest)
|
||||||
|
# yapf: enable
|
||||||
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
|
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||||
|
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||||
|
from vllm.entrypoints.openai.serving_tokenization import (
|
||||||
|
OpenAIServingTokenization)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
|
from vllm.reasoning import ReasoningParserManager
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.usage.usage_lib import UsageContext
|
||||||
|
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
|
||||||
|
from vllm.version import __version__ as VLLM_VERSION
|
||||||
|
|
||||||
|
# Passed to serve_http() as `timeout_keep_alive` in run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Declared here, assigned in build_async_engine_client_from_engine_args()
# when PROMETHEUS_MULTIPROC_DIR is not already set in the environment.
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

# Background tasks created in lifespan(); each task removes itself from this
# set through a done-callback.
_running_tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that optionally runs periodic stats logging.

    When ``app.state.log_stats`` is truthy, a background task calls
    ``engine_client.do_log_stats()`` after every 10 second sleep. The task is
    cancelled when the application leaves the lifespan context, and
    ``app.state`` is deleted on exit in every case.
    """
    try:
        stats_task = None
        if app.state.log_stats:
            client: EngineClient = app.state.engine_client

            async def _log_stats_forever():
                while True:
                    await asyncio.sleep(10.)
                    await client.do_log_stats()

            stats_task = asyncio.create_task(_log_stats_forever())
            # Track the task in the module-level set until it finishes.
            _running_tasks.add(stats_task)
            stats_task.add_done_callback(_running_tasks.remove)

        try:
            yield
        finally:
            if stats_task is not None:
                stats_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[EngineClient]:
    """Yield an engine client built from parsed CLI arguments.

    Converts ``args`` to ``AsyncEngineArgs`` and delegates to
    ``build_async_engine_client_from_engine_args``, forwarding
    ``args.disable_frontend_multiprocessing``. The delegated context manager
    owns the client's lifecycle, so leaving this context also leaves that one.
    """
    cli_engine_args = AsyncEngineArgs.from_cli_args(args)
    client_ctx = build_async_engine_client_from_engine_args(
        cli_engine_args, args.disable_frontend_multiprocessing)

    async with client_ctx as client:
        yield client
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[EngineClient]:
    """
    Create EngineClient, either:
        - in-process using the AsyncLLMEngine Directly
        - multiprocess using AsyncLLMEngine RPC

    The in-process path is taken when
    ``MQLLMEngineClient.is_unsupported_config(engine_args)`` is true or when
    ``disable_frontend_multiprocessing`` is set.

    Yields the client. In the multiprocess path a ``RuntimeError`` is raised
    if the engine process dies while the client is still being set up, and
    the engine process is terminated when the context exits.
    """

    # Fall back
    # TODO: fill out feature matrix.
    if (MQLLMEngineClient.is_unsupported_config(engine_args)
            or disable_frontend_multiprocessing):
        engine_config = engine_args.create_engine_config()
        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
                           "uses_ray", False)

        build_engine = partial(AsyncLLMEngine.from_engine_args,
                               engine_args=engine_args,
                               engine_config=engine_config,
                               usage_context=UsageContext.OPENAI_API_SERVER)
        if uses_ray:
            # Must run in main thread with ray for its signal handlers to work
            engine_client = build_engine()
        else:
            # Build in the loop's default executor so the event loop is not
            # blocked while the engine is constructed.
            engine_client = await asyncio.get_running_loop().run_in_executor(
                None, build_engine)

        yield engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            #   cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between vLLM runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and vLLM will properly handle cleanup.")

        # Select random path for IPC.
        ipc_path = get_open_zmq_ipc_path()
        logger.info("Multiprocessing frontend to use %s for IPC Path.",
                    ipc_path)

        # Start RPCServer in separate process (holds the LLMEngine).
        # the current process might have CUDA context,
        # so we need to spawn a new process
        context = multiprocessing.get_context("spawn")

        engine_process = context.Process(target=run_mp_engine,
                                         args=(engine_args,
                                               UsageContext.OPENAI_API_SERVER,
                                               ipc_path))
        engine_process.start()
        logger.info("Started engine process with PID %d", engine_process.pid)

        # Build RPCClient, which conforms to EngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        engine_config = engine_args.create_engine_config()
        mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)

        try:
            # Retry setup() on every TimeoutError for as long as the engine
            # process is still alive; give up only once it has died.
            while True:
                try:
                    await mp_engine_client.setup()
                    break
                except TimeoutError:
                    if not engine_process.is_alive():
                        raise RuntimeError(
                            "Engine process failed to start") from None

            yield mp_engine_client  # type: ignore[misc]
        finally:
            # Ensure rpc server process was terminated
            engine_process.terminate()

            # Close all open connections to the backend
            mp_engine_client.close()

            # Wait for engine process to join
            engine_process.join(4)
            if engine_process.exitcode is None:
                # Kill if the process has not exited after the 4 second join
                engine_process.kill()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(engine_process.pid)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level router; the route handlers below register on it and
# build_app() includes it in the FastAPI application.
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def mount_metrics(app: FastAPI):
    """Append a ``/metrics`` route serving the Prometheus ASGI app to ``app``.

    If ``PROMETHEUS_MULTIPROC_DIR`` is set in the environment, the ASGI app is
    backed by a fresh ``CollectorRegistry`` wired to a
    ``MultiProcessCollector``; otherwise ``make_asgi_app()`` is called without
    a registry argument.
    """
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_app = make_asgi_app(registry=registry)

    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
|
||||||
|
|
||||||
|
|
||||||
|
def chat(request: Request) -> OpenAIServingChat:
    """Return the chat handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_chat
|
||||||
|
|
||||||
|
|
||||||
|
def completion(request: Request) -> OpenAIServingCompletion:
    """Return the completion handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_completion
|
||||||
|
|
||||||
|
|
||||||
|
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the tokenization handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
|
||||||
|
|
||||||
|
|
||||||
|
def embedding(request: Request) -> OpenAIServingEmbedding:
    """Return the embedding handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_embedding
|
||||||
|
|
||||||
|
|
||||||
|
def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on the application state."""
    app_state = request.app.state
    return app_state.engine_client
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
async def health(raw_request: Request) -> Response:
    """Health check: await the engine's ``check_health()`` and reply 200."""
    client = engine_client(raw_request)
    await client.check_health()
    return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest, raw_request: Request):
    """Handle ``POST /tokenize``.

    Returns the handler's ``ErrorResponse`` as JSON with its own status code,
    or the ``TokenizeResponse`` as JSON. Any other result type is passed to
    ``assert_never``.
    """
    handler = tokenization(raw_request)
    result = await handler.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    """Handle ``POST /detokenize``.

    Returns the handler's ``ErrorResponse`` as JSON with its own status code,
    or the ``DetokenizeResponse`` as JSON. Any other result type is passed to
    ``assert_never``.
    """
    handler = tokenization(raw_request)
    result = await handler.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    """Handle ``GET /v1/models`` by dumping the completion handler's model list."""
    handler = completion(raw_request)
    model_list = await handler.show_available_models()
    return JSONResponse(content=model_list.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/version")
async def show_version():
    """Handle ``GET /version``: report ``VLLM_VERSION`` under the ``version`` key."""
    return JSONResponse(content={"version": VLLM_VERSION})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle ``POST /v1/chat/completions``.

    An ``ErrorResponse`` is returned as JSON with its own status code and a
    ``ChatCompletionResponse`` as plain JSON. Any other result is sent as a
    ``text/event-stream`` streaming response.
    """
    handler = chat(raw_request)
    result = await handler.create_chat_completion(request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if isinstance(result, ChatCompletionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle ``POST /v1/completions``.

    An ``ErrorResponse`` is returned as JSON with its own status code and a
    ``CompletionResponse`` as plain JSON. Any other result is sent as a
    ``text/event-stream`` streaming response.
    """
    handler = completion(raw_request)
    result = await handler.create_completion(request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, CompletionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """Handle ``POST /v1/embeddings``.

    Returns the handler's ``ErrorResponse`` as JSON with its own status code,
    or the ``EmbeddingResponse`` as JSON. Any other result type is passed to
    ``assert_never``.
    """
    handler = embedding(raw_request)
    result = await handler.create_embedding(request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)
|
||||||
|
|
||||||
|
|
||||||
|
# The profiler routes are only registered when VLLM_TORCH_PROFILER_DIR is set.
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile(raw_request: Request):
        """Ask the engine to start profiling, then reply 200."""
        logger.info("Starting profiler...")
        client = engine_client(raw_request)
        await client.start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile(raw_request: Request):
        """Ask the engine to stop profiling, then reply 200."""
        logger.info("Stopping profiler...")
        client = engine_client(raw_request)
        await client.stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
# The LoRA management routes are only registered when
# VLLM_ALLOW_RUNTIME_LORA_UPDATING is set.
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                raw_request: Request):
        """Load the adapter on the chat handler, then the completion handler.

        Stops at the first ``ErrorResponse`` and returns it as JSON; otherwise
        replies 200 with the completion handler's response as content.
        """
        for get_handler in (chat, completion):
            response = await get_handler(raw_request).load_lora_adapter(
                request)
            if isinstance(response, ErrorResponse):
                return JSONResponse(content=response.model_dump(),
                                    status_code=response.code)

        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                  raw_request: Request):
        """Unload the adapter on the chat handler, then the completion handler.

        Stops at the first ``ErrorResponse`` and returns it as JSON; otherwise
        replies 200 with the completion handler's response as content.
        """
        for get_handler in (chat, completion):
            response = await get_handler(raw_request).unload_lora_adapter(
                request)
            if isinstance(response, ErrorResponse):
                return JSONResponse(content=response.model_dump(),
                                    status_code=response.code)

        return Response(status_code=200, content=response)
|
||||||
|
|
||||||
|
|
||||||
|
def build_app(args: Namespace) -> FastAPI:
    """Build the FastAPI application for the OpenAI-compatible server.

    The app gets the module-level ``router``, the ``/metrics`` mount, CORS
    middleware configured from ``args``, a handler that turns request
    validation errors into 400 responses, optional bearer-token
    authentication, and any extra middleware listed in ``args.middleware``.

    Args:
        args: Parsed server arguments. Reads ``disable_fastapi_docs``,
            ``root_path``, the ``allowed_*`` / ``allow_credentials`` CORS
            settings, ``api_key`` and ``middleware``.

    Returns:
        The configured ``FastAPI`` instance. Its ``state`` is not populated
        here; the validation handler reads ``app.state.openai_serving_chat``
        at request time.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    # Function-scope import: only needed for the API-key comparison below.
    import secrets

    if args.disable_fastapi_docs:
        app = FastAPI(openapi_url=None,
                      docs_url=None,
                      redoc_url=None,
                      lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        chat = app.state.openai_serving_chat
        err = chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Compare as bytes: compare_digest() rejects non-ASCII str operands.
        expected_header = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            # Only non-OPTIONS requests under "<root_path>/v1" need the key.
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)

            supplied_header = request.headers.get("Authorization")
            # Constant-time comparison so response timing does not reveal how
            # much of a guessed key matched.
            if supplied_header is None or not secrets.compare_digest(
                    supplied_header.encode("utf-8"), expected_header):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        # Each entry is a dotted path "package.module.object".
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             "Must be a function or a class.")

    return app
|
||||||
|
|
||||||
|
|
||||||
|
def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and the OpenAI handlers.

    Sets ``engine_client``, ``log_stats`` and the four
    ``openai_serving_*`` handlers (chat, completion, embedding,
    tokenization) on ``state``. The served model names come from
    ``args.served_model_name`` when given, otherwise from ``args.model``.
    Request logging is disabled when ``args.disable_log_requests`` is set.
    """
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    # Every served name maps to the same underlying model path.
    base_model_paths = []
    for served_name in served_model_names:
        base_model_paths.append(
            BaseModelPath(name=served_name, model_path=args.model))

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats

    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,
        base_model_paths,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        # Optional CLI argument: absent on older parsers, so default to None.
        reasoning_parser=getattr(args, 'reasoning_parser', None))
    state.openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
        base_model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    state.openai_serving_embedding = OpenAIServingEmbedding(
        engine_client,
        model_config,
        base_model_paths,
        request_logger=request_logger,
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        model_config,
        base_model_paths,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def run_server(args, **uvicorn_kwargs) -> None:
    """Validate parser options, start the engine and serve the HTTP API.

    Args:
        args: Parsed server arguments (host/port, SSL settings, tool and
            reasoning parser selection, engine options).
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.

    Raises:
        KeyError: If auto tool choice is enabled with an unregistered tool
            call parser, or if an unregistered reasoning parser is requested.
        KeyboardInterrupt: If SIGTERM arrives while the handler installed
            here is active.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice \
            and args.tool_call_parser not in valid_tool_parsers:
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"(chose from {{ {','.join(valid_tool_parsers)} }})")

    # `reasoning_parser` is optional on the namespace, hence getattr().
    reasoning_parser = getattr(args, 'reasoning_parser', None)
    if reasoning_parser:
        valid_reasoning = ReasoningParserManager.list_registered()
        if reasoning_parser not in valid_reasoning:
            raise KeyError(
                f"invalid reasoning parser: {reasoning_parser} "
                f"(chose from {{ {','.join(valid_reasoning)} }})")

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("", args.port))

        def signal_handler(*_) -> None:
            # Interrupt server on sigterm while initializing
            raise KeyboardInterrupt("terminated")

        signal.signal(signal.SIGTERM, signal_handler)

        async with build_async_engine_client(args) as engine_client:
            app = build_app(args)

            model_config = await engine_client.get_model_config()
            init_app_state(engine_client, model_config, app.state, args)

            shutdown_task = await serve_http(
                app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                fd=sock.fileno(),
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        # Release the pre-bound socket on every exit path, including bind or
        # engine start-up failures.
        sock.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    # Run the server coroutine to completion on a uvloop event loop.
    uvloop.run(run_server(args))
|
||||||
598
qwen3_6_scripts/chat_utils.py
Normal file
598
qwen3_6_scripts/chat_utils.py
Normal file
@@ -0,0 +1,598 @@
|
|||||||
|
import asyncio
|
||||||
|
import codecs
|
||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import defaultdict
|
||||||
|
from functools import lru_cache, partial
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
|
||||||
|
Mapping, Optional, Tuple, TypeVar, Union, cast)
|
||||||
|
|
||||||
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
|
from openai.types.chat import (ChatCompletionAssistantMessageParam,
|
||||||
|
ChatCompletionContentPartImageParam)
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
|
||||||
|
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
|
||||||
|
ChatCompletionContentPartTextParam)
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
|
||||||
|
from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
||||||
|
ChatCompletionToolMessageParam)
|
||||||
|
# yapf: enable
|
||||||
|
# pydantic needs the TypedDict from typing_extensions
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
|
from typing_extensions import Required, TypeAlias, TypedDict
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.multimodal import MultiModalDataDict
|
||||||
|
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
||||||
|
async_get_and_parse_image,
|
||||||
|
get_and_parse_audio, get_and_parse_image)
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
|
|
||||||
|
# Module-level logger named after this module.
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Value type of the `audio_url` field in ChatCompletionContentPartAudioParam.
class AudioURL(TypedDict, total=False):
    # The only key, and it is required despite total=False.
    url: Required[str]
    """
    Either a URL of the audio or a data URL with base64 encoded audio data.
    """
|
||||||
|
|
||||||
|
|
||||||
|
# Content part carrying an audio reference; both keys are required.
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
    # Audio location, see AudioURL.
    audio_url: Required[AudioURL]

    # Always the literal "audio_url" for this kind of part.
    type: Required[Literal["audio_url"]]
    """The type of the content part."""
|
||||||
|
|
||||||
|
|
||||||
|
# Content part with a free-form `type`; the pydantic config below allows
# extra keys beyond the declared one.
class CustomChatCompletionContentPartParam(TypedDict, total=False):
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

    type: Required[str]
    """The type of the content part."""
|
||||||
|
|
||||||
|
|
||||||
|
# Any content part accepted in a message: the OpenAI-defined parts plus the
# audio, refusal and custom parts.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPartParam]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    # `role` is the only required key; all other keys may be omitted
    # because the TypedDict is declared with total=False.
    role: Required[str]
    """The role of the message's author."""

    # Either plain text or a list of content parts.
    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""

    # NOTE(review): the rendering described below is not implemented in this
    # class; confirm where the <think>...</think> wrapping actually happens.
    reasoning_content: Optional[str]
    """Reasoning / thinking content for assistant messages (vLLM extension).
    When present in a previous assistant turn, it is rendered as
    <think>...</think> before the main content so the model sees its own
    chain-of-thought in subsequent turns."""
|
||||||
|
|
||||||
|
|
||||||
|
# A chat message is either an OpenAI-defined message or the custom-role
# variant declared in this module.
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Make fields ReadOnly once mypy supports it
# Message shape with string-or-None content; only `role` is required.
class ConversationMessage(TypedDict, total=False):
    role: Required[str]
    """The role of the message's author."""

    content: Optional[str]
    """The contents of the message"""

    tool_call_id: Optional[str]
    """Tool call that this message is responding to."""

    name: Optional[str]
    """The name of the function to call"""

    tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
    """The tool calls generated by the model, such as function calls."""
|
||||||
|
|
||||||
|
|
||||||
|
# Modalities recognised by the multi-modal item tracker.
ModalityStr = Literal["image", "audio", "video"]
# Item type parameter of BaseMultiModalItemTracker.
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||||
|
"""
|
||||||
|
Tracks multi-modal items in a given request and ensures that the number
|
||||||
|
of multi-modal items in a given request does not exceed the configured
|
||||||
|
maximum per prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
    """Store the model config and tokenizer and reset per-request counters.

    The per-modality limits come from
    ``model_config.multimodal_config.limit_per_prompt`` when a multimodal
    config is present, otherwise no limits are recorded. Every limited
    modality starts with a consumed count of zero.
    """
    super().__init__()

    self._model_config = model_config
    self._tokenizer = tokenizer

    if model_config.multimodal_config:
        self._allowed_items = model_config.multimodal_config.limit_per_prompt
    else:
        self._allowed_items = {}
    self._consumed_items = dict.fromkeys(self._allowed_items, 0)

    self._items: List[_T] = []
|
||||||
|
|
||||||
|
@staticmethod
@lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
    """Decode ``token_index`` with ``tokenizer``, memoised per argument pair."""
    token_text = tokenizer.decode(token_index)
    return token_text
|
||||||
|
|
||||||
|
def _placeholder_str(self, modality: ModalityStr,
|
||||||
|
current_count: int) -> Optional[str]:
|
||||||
|
# TODO: Let user specify how to insert image tokens into prompt
|
||||||
|
# (similar to chat template)
|
||||||
|
hf_config = self._model_config.hf_config
|
||||||
|
model_type = hf_config.model_type
|
||||||
|
|
||||||
|
if modality == "image":
|
||||||
|
if model_type == "phi3_v":
|
||||||
|
# Workaround since this token is not defined in the tokenizer
|
||||||
|
return f"<|image_{current_count}|>"
|
||||||
|
if model_type == "minicpmv":
|
||||||
|
return "(<image>./</image>)"
|
||||||
|
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
|
||||||
|
"pixtral"):
|
||||||
|
# These models do not use image tokens in the prompt
|
||||||
|
return None
|
||||||
|
if model_type == "qwen":
|
||||||
|
return f"Picture {current_count}: <img></img>"
|
||||||
|
if model_type.startswith("llava"):
|
||||||
|
return self._cached_token_str(self._tokenizer,
|
||||||
|
hf_config.image_token_index)
|
||||||
|
if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
|
||||||
|
return "<image>"
|
||||||
|
if model_type == "mllama":
|
||||||
|
return "<|image|>"
|
||||||
|
if model_type in ("qwen2_vl","qwen2_5_vl"):
|
||||||
|
return "<|vision_start|><|image_pad|><|vision_end|>"
|
||||||
|
if model_type == "molmo":
|
||||||
|
return ""
|
||||||
|
|
||||||
|
raise TypeError(f"Unknown model type: {model_type}")
|
||||||
|
elif modality == "audio":
|
||||||
|
if model_type == "ultravox":
|
||||||
|
return "<|reserved_special_token_0|>"
|
||||||
|
raise TypeError(f"Unknown model type: {model_type}")
|
||||||
|
elif modality == "video":
|
||||||
|
if model_type in ("qwen2_vl","qwen2_5_vl"):
|
||||||
|
return "<|vision_start|><|video_pad|><|vision_end|>"
|
||||||
|
raise TypeError(f"Unknown model type: {model_type}")
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unknown modality: {modality}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
|
||||||
|
mm_lists: Mapping[str, List[object]] = defaultdict(list)
|
||||||
|
|
||||||
|
# Merge all the multi-modal items
|
||||||
|
for single_mm_data in items:
|
||||||
|
for mm_key, mm_item in single_mm_data.items():
|
||||||
|
if isinstance(mm_item, list):
|
||||||
|
mm_lists[mm_key].extend(mm_item)
|
||||||
|
else:
|
||||||
|
mm_lists[mm_key].append(mm_item)
|
||||||
|
|
||||||
|
# Unpack any single item lists for models that don't expect multiple.
|
||||||
|
return {
|
||||||
|
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
|
||||||
|
for mm_key, mm_list in mm_lists.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Add a multi-modal item to the current prompt and returns the
|
||||||
|
placeholder string to use, if any.
|
||||||
|
"""
|
||||||
|
allowed_count = self._allowed_items.get(modality, 1)
|
||||||
|
current_count = self._consumed_items.get(modality, 0) + 1
|
||||||
|
if current_count > allowed_count:
|
||||||
|
raise ValueError(
|
||||||
|
f"At most {allowed_count} {modality}(s) may be provided in "
|
||||||
|
"one request.")
|
||||||
|
|
||||||
|
self._consumed_items[modality] = current_count
|
||||||
|
self._items.append(item)
|
||||||
|
|
||||||
|
return self._placeholder_str(modality, current_count)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
    # Synchronous tracker: items are stored already loaded.

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Return the merged items, or ``None`` when nothing was added."""
        if not self._items:
            return None
        return self._combine(self._items)

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return a synchronous content parser bound to this tracker."""
        return MultiModalContentParser(self)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMultiModalItemTracker(
        BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
    # Asynchronous tracker: items are stored as awaitables.

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await every stored item and return the merged result, or
        ``None`` when nothing was added."""
        if not self._items:
            return None

        loaded = await asyncio.gather(*self._items)
        return self._combine(loaded)

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return an asynchronous content parser bound to this tracker."""
        return AsyncMultiModalContentParser(self)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMultiModalContentParser(ABC):
    # Base for parsers that count the placeholders their items produce.

    def __init__(self) -> None:
        super().__init__()

        # multimodal placeholder_string : count
        self._placeholder_counts: Dict[str, int] = defaultdict(int)

    def _add_placeholder(self, placeholder: Optional[str]):
        """Count ``placeholder``; ``None`` and ``""`` are ignored."""
        if not placeholder:
            return
        self._placeholder_counts[placeholder] += 1

    def mm_placeholder_counts(self) -> Dict[str, int]:
        """Return a plain-dict copy of the placeholder counts."""
        return dict(self._placeholder_counts)

    @abstractmethod
    def parse_image(self, image_url: str) -> None:
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalContentParser(BaseMultiModalContentParser):
    # Synchronous parser: loads each item, then registers it on the tracker.

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

    def parse_image(self, image_url: str) -> None:
        """Load the image at ``image_url`` and register it."""
        loaded_image = get_and_parse_image(image_url)
        self._add_placeholder(self._tracker.add("image", loaded_image))

    def parse_audio(self, audio_url: str) -> None:
        """Load the audio at ``audio_url`` and register it."""
        loaded_audio = get_and_parse_audio(audio_url)
        self._add_placeholder(self._tracker.add("audio", loaded_audio))
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    # Asynchronous parser: registers the not-yet-awaited result of each
    # async loader on the tracker.

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

    def parse_image(self, image_url: str) -> None:
        """Register the pending load of the image at ``image_url``."""
        pending_image = async_get_and_parse_image(image_url)
        self._add_placeholder(self._tracker.add("image", pending_image))

    def parse_audio(self, audio_url: str) -> None:
        """Register the pending load of the audio at ``audio_url``."""
        pending_audio = async_get_and_parse_audio(audio_url)
        self._add_placeholder(self._tracker.add("audio", pending_audio))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raises if the provided chat template appears invalid.

    Accepted inputs:
      * ``None`` - nothing to validate.
      * ``Path`` - must exist, otherwise ``FileNotFoundError``.
      * ``str``  - either an inline template (contains ``{``, ``}`` or a
        newline) or the path of an existing file, otherwise ``ValueError``.

    Any other type raises ``TypeError``.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        if not chat_template.exists():
            raise FileNotFoundError(
                "the supplied chat template path doesn't exist")
        # An existing Path is valid; return here so it does not reach the
        # TypeError at the end of the function.
        return

    if isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
        looks_like_template = any(c in chat_template for c in JINJA_CHARS)
        if not looks_like_template and not Path(chat_template).exists():
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!")
        return

    raise TypeError(
        f"{type(chat_template)} is not a valid chat template type")
|
||||||
|
|
||||||
|
|
||||||
|
def load_chat_template(
        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
    """Resolve ``chat_template`` to template text.

    The value is first opened as a file. If that fails with ``OSError``:
    a ``Path`` re-raises; a string without ``{``, ``}`` or a newline raises
    ``ValueError``; any other string is used as the template itself after
    ``unicode_escape`` decoding. ``None`` is returned unchanged.
    """
    if chat_template is None:
        return None

    try:
        with open(chat_template, "r") as template_file:
            resolved_chat_template = template_file.read()
    except OSError as open_error:
        if isinstance(chat_template, Path):
            raise

        JINJA_CHARS = "{}\n"
        looks_like_template = any(c in chat_template for c in JINJA_CHARS)
        if not looks_like_template:
            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {open_error}")
            raise ValueError(msg) from open_error

        # If opening a file fails, set chat template to be args to
        # ensure we decode so our escape are interpreted correctly
        resolved_chat_template = codecs.decode(chat_template, "unicode_escape")

    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
    return resolved_chat_template
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Let user specify how to insert multimodal tokens into prompt
|
||||||
|
# (similar to chat template)
|
||||||
|
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
                                     text_prompt: str) -> str:
    """Combine multimodal prompts for a multimodal language model.

    ``placeholder_counts`` maps each placeholder string to the number of
    items that need it. Placeholders already present in ``text_prompt`` are
    left as is; the missing ones are prepended, one per line. The mapping is
    read but never modified.

    Raises ``ValueError`` when ``text_prompt`` contains more occurrences of
    a placeholder than there are items for it.
    """
    # Look through the text prompt to check for missing placeholders
    missing_placeholders: List[str] = []
    for placeholder, expected in placeholder_counts.items():
        # Work on a local count so the caller's dict stays untouched.
        missing = expected - text_prompt.count(placeholder)

        if missing < 0:
            raise ValueError(
                f"Found more '{placeholder}' placeholders in input prompt than "
                "actual multimodal data items.")

        missing_placeholders.extend([placeholder] * missing)

    # NOTE: For now we always add missing placeholders at the front of
    # the prompt. This may change to be customizable in the future.
    return "\n".join(missing_placeholders + [text_prompt])
|
||||||
|
|
||||||
|
|
||||||
|
# No need to validate using Pydantic again: these are typing-only casts.
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)

# Model types whose messages keep list-shaped content.
MODEL_KEEP_MULTI_MODAL_CONTENT = {
    'mllama',
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
    """Turn the content parts of one message into a conversation message.

    Text and refusal parts are joined with newlines. Image and audio parts
    are registered on a parser created from ``mm_tracker``. Model types in
    ``MODEL_KEEP_MULTI_MODAL_CONTENT`` get list-shaped content; all others
    get a string with any missing placeholders prepended.

    Raises ``NotImplementedError`` for an unknown part type.
    """
    text_chunks: List[str] = []

    mm_parser = mm_tracker.create_parser()
    model_type = mm_tracker._model_config.hf_config.model_type
    keep_multimodal_content = model_type in MODEL_KEEP_MULTI_MODAL_CONTENT

    saw_image = False
    for part in parts:
        part_type = part["type"]
        if part_type == "text":
            text_chunks.append(_TextParser(part)["text"])
        elif part_type == "image_url":
            image_url = _ImageParser(part)["image_url"]

            if image_url.get("detail", "auto") != "auto":
                logger.warning(
                    "'image_url.detail' is currently not supported and "
                    "will be ignored.")

            mm_parser.parse_image(image_url["url"])
            saw_image = True
        elif part_type == "audio_url":
            audio_url = _AudioParser(part)["audio_url"]

            mm_parser.parse_audio(audio_url["url"])
        elif part_type == "refusal":
            text_chunks.append(_RefusalParser(part)["refusal"])
        else:
            raise NotImplementedError(f"Unknown part type: {part_type}")

    text_prompt = "\n".join(text_chunks)

    if not keep_multimodal_content:
        mm_placeholder_counts = mm_parser.mm_placeholder_counts()
        if mm_placeholder_counts:
            text_prompt = _get_full_multimodal_text_prompt(
                mm_placeholder_counts, text_prompt)
        return [ConversationMessage(role=role, content=text_prompt)]

    # List-shaped content: a single image marker (if any image was seen)
    # followed by the joined text.
    role_content = [{'type': 'text', 'text': text_prompt}]
    if saw_image:
        role_content = [{'type': 'image'}] + role_content
    return [ConversationMessage(role=role,
                                content=role_content)]  # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
# No need to validate using Pydantic again: these are typing-only casts.
_AssistantParser = partial(
    cast,
    ChatCompletionAssistantMessageParam,
)
_ToolParser = partial(
    cast,
    ChatCompletionToolMessageParam,
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
    """Convert one incoming chat message into conversation messages.

    The content (``None``, a string, or a list of parts) goes through
    ``_parse_chat_message_content_parts``. Role-specific fields are then
    copied across: ``tool_calls`` and ``reasoning_content`` for assistant
    messages, ``tool_call_id`` for tool messages, and ``name`` for any role.
    """
    role = message["role"]
    content = message.get("content")

    if content is None:
        content = []
    elif isinstance(content, str):
        content = [
            ChatCompletionContentPartTextParam(type="text", text=content)
        ]

    result = _parse_chat_message_content_parts(
        role,
        content,  # type: ignore
        mm_tracker,
    )

    for result_msg in result:
        if role == 'assistant':
            parsed_msg = _AssistantParser(message)

            # An explicit ``"tool_calls": null`` would make list(None)
            # raise, so only copy a real value.
            tool_calls = parsed_msg.get("tool_calls")
            if tool_calls is not None:
                result_msg["tool_calls"] = list(tool_calls)

            # Prepend reasoning_content as <think>...</think> so the model
            # sees its own chain-of-thought in multi-turn conversations.
            reasoning = message.get("reasoning_content")  # type: ignore[arg-type]
            if reasoning and isinstance(reasoning, str):
                think_block = f"<think>{reasoning}</think>"
                existing = result_msg.get("content")
                if isinstance(existing, list):
                    # List-shaped content: put the reasoning in the first
                    # text part rather than formatting the list itself
                    # into the prompt.
                    for part in existing:
                        if (isinstance(part, dict)
                                and part.get("type") == "text"):
                            text = part.get("text") or ""
                            part["text"] = (f"{think_block}\n\n{text}"
                                            if text else think_block)
                            break
                    else:
                        existing.append({
                            "type": "text",
                            "text": think_block
                        })
                elif existing:
                    result_msg["content"] = f"{think_block}\n\n{existing}"
                else:
                    result_msg["content"] = think_block

        elif role == "tool":
            parsed_msg = _ToolParser(message)
            if "tool_call_id" in parsed_msg:
                result_msg["tool_call_id"] = parsed_msg["tool_call_id"]

        if "name" in message and isinstance(message["name"], str):
            result_msg["name"] = message["name"]

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
    """Convert tool-call arguments of assistant messages to dicts, in place.

    Empty or missing arguments become ``{}``, JSON text is parsed, and
    values that are already parsed are left untouched, so calling this
    twice on the same messages is safe.

    Raises ``json.JSONDecodeError`` if non-empty argument text is not
    valid JSON.
    """
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
        if message["role"] != "assistant":
            continue

        tool_calls = message.get("tool_calls")
        if not isinstance(tool_calls, list):
            continue

        for item in tool_calls:
            function = item["function"]
            arguments = function.get("arguments")
            if not arguments:
                # "", None or a missing key: a call with no arguments.
                function["arguments"] = {}
            elif isinstance(arguments, (str, bytes, bytearray)):
                function["arguments"] = json.loads(arguments)
            # Anything else is already parsed; keep it as is.
|
||||||
|
|
||||||
|
|
||||||
|
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
    """Parse ``messages`` synchronously.

    Returns the post-processed conversation and the tracker's combined
    multi-modal data.
    """
    mm_tracker = MultiModalItemTracker(model_config, tokenizer)

    conversation: List[ConversationMessage] = [
        parsed for msg in messages
        for parsed in _parse_chat_message_content(msg, mm_tracker)
    ]

    _postprocess_messages(conversation)

    mm_data = mm_tracker.all_mm_data()
    return conversation, mm_data
|
||||||
|
|
||||||
|
|
||||||
|
def parse_chat_messages_futures(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
    """Parse ``messages`` with an async tracker.

    Returns the post-processed conversation and an awaitable that yields
    the tracker's combined multi-modal data.
    """
    mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)

    conversation: List[ConversationMessage] = [
        parsed for msg in messages
        for parsed in _parse_chat_message_content(msg, mm_tracker)
    ]

    _postprocess_messages(conversation)

    mm_data_future = mm_tracker.all_mm_data()
    return conversation, mm_data_future
|
||||||
|
|
||||||
|
|
||||||
|
def apply_hf_chat_template(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
    *,
    tokenize: bool = False,  # Different from HF's default
    **kwargs: Any,
) -> str:
    """Render ``conversation`` with the tokenizer's ``apply_chat_template``.

    Raises ``ValueError`` when neither ``chat_template`` nor the tokenizer
    provides a template. Extra keyword arguments are forwarded unchanged.
    """
    no_template = chat_template is None and tokenizer.chat_template is None
    if no_template:
        raise ValueError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one.")

    rendered = tokenizer.apply_chat_template(
        conversation=conversation,  # type: ignore[arg-type]
        chat_template=chat_template,
        tokenize=tokenize,
        **kwargs,
    )
    return rendered
|
||||||
|
|
||||||
|
|
||||||
|
def apply_mistral_chat_template(
    tokenizer: MistralTokenizer,
    messages: List[ChatCompletionMessageParam],
    chat_template: Optional[str] = None,
    **kwargs: Any,
) -> List[int]:
    """Apply the Mistral tokenizer's chat template to ``messages``.

    Logs a warning for ``chat_template``, ``add_generation_prompt`` and
    ``continue_final_message``. ``chat_template`` is not forwarded; every
    entry of ``kwargs`` is passed to the tokenizer unchanged.
    """
    template_overridden = chat_template is not None
    if template_overridden:
        logger.warning(
            "'chat_template' cannot be overridden for mistral tokenizer.")

    wants_generation_prompt = "add_generation_prompt" in kwargs
    if wants_generation_prompt:
        logger.warning(
            "'add_generation_prompt' is not supported for mistral tokenizer, "
            "so it will be ignored.")

    wants_continue_final = "continue_final_message" in kwargs
    if wants_continue_final:
        logger.warning(
            "'continue_final_message' is not supported for mistral tokenizer, "
            "so it will be ignored.")

    token_ids = tokenizer.apply_chat_template(
        messages=messages,
        **kwargs,
    )
    return token_ids
|
||||||
261
qwen3_6_scripts/cli_args.py
Normal file
261
qwen3_6_scripts/cli_args.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
"""
This file contains the command line arguments for vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import ssl
|
||||||
|
from typing import List, Optional, Sequence, Union
|
||||||
|
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||||
|
from vllm.entrypoints.chat_utils import validate_chat_template
|
||||||
|
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||||
|
PromptAdapterPath)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAParserAction(argparse.Action):
    """argparse action that turns ``--lora-modules`` values into a list of
    ``LoRAModulePath`` objects.

    Each value is either ``name=path`` (old format) or a JSON object whose
    keys are passed to ``LoRAModulePath``. ``None`` and empty values are
    skipped. Malformed JSON or unexpected fields go to ``parser.error``.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        lora_list: List[LoRAModulePath] = []
        for item in values:
            if item in [None, '']:  # Skip if item is None or empty string
                continue
            if '=' in item and ',' not in item:  # Old format: name=path
                # Split on the first '=' only, so a path that itself
                # contains '=' does not break the unpacking.
                name, path = item.split('=', 1)
                lora_list.append(LoRAModulePath(name, path))
            else:  # Assume JSON format
                try:
                    lora_dict = json.loads(item)
                    lora = LoRAModulePath(**lora_dict)
                    lora_list.append(lora)
                except json.JSONDecodeError:
                    parser.error(
                        f"Invalid JSON format for --lora-modules: {item}")
                except TypeError as e:
                    parser.error(
                        f"Invalid fields for --lora-modules: {item} - {str(e)}"
                    )
        setattr(namespace, self.dest, lora_list)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptAdapterParserAction(argparse.Action):
    """argparse action that turns ``--prompt-adapters`` values into a list
    of ``PromptAdapterPath`` objects.

    Each value must be ``name=path``. ``None`` and empty values are skipped,
    matching ``LoRAParserAction``. A value without ``=`` goes to
    ``parser.error``.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        adapter_list: List[PromptAdapterPath] = []
        for item in values:
            if item in [None, '']:  # Skip if item is None or empty string
                continue
            # Split on the first '=' only, so a path that itself contains
            # '=' is kept whole.
            name, separator, path = item.partition('=')
            if not separator:
                parser.error(
                    f"Invalid format for --prompt-adapters: {item} "
                    "(expected name=path)")
            adapter_list.append(PromptAdapterPath(name, path))
        setattr(namespace, self.dest, adapter_list)
|
||||||
|
|
||||||
|
|
||||||
|
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Register the OpenAI-compatible server's CLI options on ``parser``.

    Server options are added first, then the engine options through
    ``AsyncEngineArgs.add_cli_args``, then the remaining logging and docs
    flags. Returns the parser that results from that sequence.
    """
    parser.add_argument("--host",
                        type=nullable_str,
                        default=None,
                        help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--uvicorn-log-level",
        type=str,
        default="info",
        choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
        help="log level for uvicorn")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins",
                        type=json.loads,
                        default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods",
                        type=json.loads,
                        default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers",
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument("--api-key",
                        type=nullable_str,
                        default=None,
                        help="If provided, the server will require this key "
                        "to be presented in the header.")
    parser.add_argument(
        "--lora-modules",
        type=nullable_str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        help="LoRA module configurations in either 'name=path' format "
        "or JSON format. "
        "Example (old format): 'name=path' "
        "Example (new format): "
        "'{\"name\": \"name\", \"local_path\": \"path\", "
        "\"base_model_name\": \"id\"}'")
    parser.add_argument(
        "--prompt-adapters",
        type=nullable_str,
        default=None,
        nargs='+',
        action=PromptAdapterParserAction,
        help="Prompt adapter configurations in the format name=path. "
        "Multiple adapters can be specified.")
    parser.add_argument("--chat-template",
                        type=nullable_str,
                        default=None,
                        help="The file path to the chat template, "
                        "or the template in single-line form "
                        "for the specified model")
    parser.add_argument("--response-role",
                        type=nullable_str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=true`.")
    parser.add_argument("--ssl-keyfile",
                        type=nullable_str,
                        default=None,
                        help="The file path to the SSL key file")
    parser.add_argument("--ssl-certfile",
                        type=nullable_str,
                        default=None,
                        help="The file path to the SSL cert file")
    parser.add_argument("--ssl-ca-certs",
                        type=nullable_str,
                        default=None,
                        help="The CA certificates file")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see the stdlib ssl "
        "module's CERT_* constants)")
    parser.add_argument(
        "--root-path",
        type=nullable_str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument(
        "--middleware",
        type=nullable_str,
        action="append",
        default=[],
        help="Additional ASGI middleware to apply to the app. "
        "We accept multiple --middleware arguments. "
        "The value should be an import path. "
        "If a function is provided, vLLM will add it to the server "
        "using @app.middleware('http'). "
        "If a class is provided, vLLM will add it to the server "
        "using app.add_middleware(). ")
    parser.add_argument(
        "--return-tokens-as-token-ids",
        action="store_true",
        help="When --max-logprobs is specified, represents single tokens as "
        "strings of the form 'token_id:{token_id}' so that tokens that "
        "are not JSON-encodable can be identified.")
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        help="If specified, will run the OpenAI frontend server in the same "
        "process as the model serving engine.")

    parser.add_argument(
        "--enable-auto-tool-choice",
        action="store_true",
        default=False,
        help=
        "Enable auto tool choice for supported models. Use --tool-call-parser "
        "to specify which parser to use")

    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
    parser.add_argument(
        "--tool-call-parser",
        type=str,
        metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
        "--tool-parser-plugin",
        default=None,
        help=
        "Select the tool call parser depending on the model that you're using."
        " This is used to parse the model-generated tool call into OpenAI API "
        "format. Required for --enable-auto-tool-choice.")

    parser.add_argument(
        "--tool-parser-plugin",
        type=str,
        default="",
        help=
        "Specify the tool parser plugin used to parse the model-generated "
        "tool calls into OpenAI API format. The names registered in this "
        "plugin can be used in --tool-call-parser.")

    parser.add_argument(
        "--reasoning-parser",
        type=str,
        default=None,
        help=
        "Select the reasoning parser to split <think>...</think> content into "
        "reasoning_content vs content in the response. "
        "Supported: qwen3")

    parser = AsyncEngineArgs.add_cli_args(parser)

    parser.add_argument('--max-log-len',
                        type=int,
                        default=None,
                        help='Max number of prompt characters or prompt '
                        'ID numbers being printed in log.'
                        '\n\nDefault: Unlimited')

    parser.add_argument(
        "--disable-fastapi-docs",
        action='store_true',
        default=False,
        help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
    )

    return parser
|
||||||
|
|
||||||
|
|
||||||
|
def validate_parsed_serve_args(args: argparse.Namespace):
    """Quick checks for model serve args that raise prior to loading.

    Does nothing when ``args`` has a ``subparser`` attribute other than
    ``"serve"``. Raises ``TypeError`` when auto tool choice is enabled
    without a tool call parser.
    """
    # A namespace without ``subparser`` is treated like "serve".
    if getattr(args, "subparser", "serve") != "serve":
        return

    # Ensure that the chat template is valid; raises if it likely isn't
    validate_chat_template(args.chat_template)

    # Enable auto tool needs a tool call parser to be valid
    missing_tool_parser = (args.enable_auto_tool_choice
                           and not args.tool_call_parser)
    if missing_tool_parser:
        raise TypeError("Error: --enable-auto-tool-choice requires "
                        "--tool-call-parser")
|
||||||
|
|
||||||
|
|
||||||
|
def create_parser_for_docs() -> FlexibleArgumentParser:
    """Build a fully populated server argument parser for documentation."""
    return make_arg_parser(
        FlexibleArgumentParser(prog="-m vllm.entrypoints.openai.api_server"))
|
||||||
@@ -58,3 +58,14 @@ python3 ./patch_xformers_sdpa_seq.py
|
|||||||
# Use at server start: --tool-call-parser qwen3_coder --enable-auto-tool-choice
|
# Use at server start: --tool-call-parser qwen3_coder --enable-auto-tool-choice
|
||||||
cp ./qwen3coder_tool_parser.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/tool_parsers/
|
cp ./qwen3coder_tool_parser.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/tool_parsers/
|
||||||
python3 ./patch_vllm_tool_parser.py
|
python3 ./patch_vllm_tool_parser.py
|
||||||
|
|
||||||
|
# --- reasoning parser: Qwen3 <think>...</think> split ------------------------
|
||||||
|
# Adds --reasoning-parser qwen3 support.
|
||||||
|
# Routes thinking tokens to reasoning_content, rest to content in the delta.
|
||||||
|
# Works together with --tool-call-parser qwen3_coder (think → tool call flow).
|
||||||
|
cp -r ./reasoning /usr/local/corex/lib/python3/dist-packages/vllm/
|
||||||
|
cp ./protocol.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/protocol.py
|
||||||
|
cp ./cli_args.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/cli_args.py
|
||||||
|
cp ./serving_chat.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/serving_chat.py
|
||||||
|
cp ./api_server.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/api_server.py
|
||||||
|
cp ./chat_utils.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/chat_utils.py
|
||||||
|
|||||||
995
qwen3_6_scripts/protocol.py
Normal file
995
qwen3_6_scripts/protocol.py
Normal file
@@ -0,0 +1,995 @@
|
|||||||
|
# Adapted from
|
||||||
|
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||||
|
import time
|
||||||
|
from argparse import Namespace
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from openai.types.chat import ChatCompletionContentPartParam
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
from typing_extensions import Annotated, Required, TypedDict
|
||||||
|
|
||||||
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||||
|
RequestOutputKind, SamplingParams)
|
||||||
|
from vllm.sequence import Logprob
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
|
# torch is mocked during docs generation,
|
||||||
|
# so we have to provide the values as literals
|
||||||
|
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]

# Use the literal bounds only when sphinx is importable and has replaced
# torch with its mock module; otherwise ask torch directly.
try:
    from sphinx.ext.autodoc.mock import _MockModule
except ModuleNotFoundError:
    _torch_is_mocked = False
else:
    _torch_is_mocked = isinstance(torch, _MockModule)

if _torch_is_mocked:
    _LONG_INFO = _MOCK_LONG_INFO
else:
    _LONG_INFO = torch.iinfo(torch.long)

# The literals must agree with what torch reports.
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
|
||||||
|
|
||||||
|
|
||||||
|
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    # total=False makes every key optional except those marked Required.
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """

    # NOTE(review): presumably the id of the tool call this message answers
    # -- confirm against the chat utilities that consume this dict.
    tool_call_id: Optional[str]

    # NOTE(review): presumably the tool calls attached to an assistant
    # message, kept as raw dicts -- confirm against the consumer.
    tool_calls: Optional[List[dict]]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIBaseModel(BaseModel):
    # Shared base class for the API models declared in this module.
    # extra="forbid" makes pydantic reject fields a model does not declare.
    # OpenAI API does not allow extra fields
    model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(OpenAIBaseModel):
    # Error payload model: `message`, `type` and `code` are required,
    # `param` is optional, and `object` defaults to the tag "error".
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPermission(OpenAIBaseModel):
    # Permission record embedded in ModelCard.permission. Every field has a
    # default, so the model can be constructed with no arguments.
    # A fresh "modelperm-<uuid>" id is generated per instance.
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    # Creation time in whole seconds since the epoch, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCard(OpenAIBaseModel):
    # Description of one model; only `id` is required.
    id: str
    object: str = "model"
    # Creation time in whole seconds since the epoch, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    # Each instance gets its own (initially empty) permission list.
    permission: List[ModelPermission] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelList(OpenAIBaseModel):
    # Container of ModelCard entries; `data` starts as a fresh empty list.
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class UsageInfo(OpenAIBaseModel):
    # Token accounting attached to responses; counters default to 0.
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
    # Defaults to None (unset) rather than 0.
    # NOTE(review): presumably populated only when a reasoning parser is
    # active -- confirm against the serving code.
    reasoning_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RequestResponseMetadata(BaseModel):
    # Derives from plain BaseModel, not OpenAIBaseModel, so the
    # extra="forbid" configuration is not inherited here.
    request_id: str
    # None until a final usage record is assigned.
    final_usage_info: Optional[UsageInfo] = None
|
||||||
|
|
||||||
|
|
||||||
|
class JsonSchemaResponseFormat(OpenAIBaseModel):
    # Named JSON schema carried by ResponseFormat.json_schema.
    name: str
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFormat(OpenAIBaseModel):
    # Requested output format; `type` is restricted to three literals.
    # type must be "json_schema", "json_object" or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||||
|
|
||||||
|
|
||||||
|
class StreamOptions(OpenAIBaseModel):
    # Streaming options; both flags default to True in this file.
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionDefinition(OpenAIBaseModel):
    # Function description used inside ChatCompletionToolsParam.
    name: str
    description: Optional[str] = None
    # Free-form dict; ChatCompletionRequest returns it as the guided-JSON
    # value when this function is selected by a named tool choice.
    parameters: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionToolsParam(OpenAIBaseModel):
    # One entry of a request's `tools` list; "function" is the only
    # accepted `type`.
    type: Literal["function"] = "function"
    function: FunctionDefinition
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionNamedFunction(OpenAIBaseModel):
    # Refers to a function by name only.
    name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    # Object form of `tool_choice`: names one specific function.
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequest(OpenAIBaseModel):
    # Request body of the chat completions endpoint, plus vLLM-specific
    # sampling and "extra" parameters. The validators below run in
    # mode="before", i.e. on the raw input mapping.
    #
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                ChatCompletionNamedToolChoiceParam]] = "none"

    # NOTE this will be ignored by VLLM -- the model determines the behavior
    parallel_tool_calls: Optional[bool] = False
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: bool = False
    top_k: int = -1
    min_p: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: bool = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: bool = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    continue_final_message: bool = Field(
        default=False,
        description=
        ("If this is set, the chat will be formatted so that the final "
         "message in the chat is open-ended, without any EOS tokens. The "
         "model will continue this message rather than starting a new one. "
         "This allows you to \"prefill\" part of the model's response for it. "
         "Cannot be used at the same time as `add_generation_prompt`."),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-chat-completion-extra-params

    def to_beam_search_params(self,
                              default_max_tokens: int) -> BeamSearchParams:
        """Build BeamSearchParams from this request.

        Args:
            default_max_tokens: Used when the request leaves `max_tokens`
                unset (None).

        Returns:
            BeamSearchParams whose beam width is `n` (1 when `n` is None)
            and whose temperature is 0.0 when `temperature` is None.
        """
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        n = self.n if self.n is not None else 1
        temperature = self.temperature if self.temperature is not None else 0.0

        return BeamSearchParams(
            beam_width=n,
            max_tokens=max_tokens,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
        )

    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
        """Build SamplingParams from this request.

        Args:
            default_max_tokens: Used when the request leaves `max_tokens`
                unset (None).

        Returns:
            SamplingParams mirroring the request fields, including the
            guided-decoding configuration derived from the request.

        Raises:
            ValueError: If a named tool choice refers to a tool that is not
                present in `tools` (see `_get_guided_json_from_tool`).
        """
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # With echo and no explicit prompt_logprobs, fall back to
        # top_logprobs for the prompt tokens.
        prompt_logprobs = self.prompt_logprobs
        if prompt_logprobs is None and self.echo:
            prompt_logprobs = self.top_logprobs

        guided_json_object = None
        if (self.response_format is not None
                and self.response_format.type == "json_object"):
            guided_json_object = True

        # A named tool's parameter schema takes precedence over guided_json.
        guided_decoding = GuidedDecodingParams.from_optional(
            json=self._get_guided_json_from_tool() or self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=guided_json_object,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            # top_logprobs only applies when logprobs were requested.
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=prompt_logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=RequestOutputKind.DELTA if self.stream \
                else RequestOutputKind.FINAL_ONLY,
            guided_decoding=guided_decoding,
            logit_bias=self.logit_bias)

    def _get_guided_json_from_tool(
            self) -> Optional[Union[str, dict, BaseModel]]:
        """Return the parameter schema of the named tool, if one is chosen.

        Returns:
            The `parameters` of the function selected through a named
            `tool_choice`, or None when no tool is used or the choice is
            not a named tool.

        Raises:
            ValueError: If the named tool is not among `tools`.
        """
        # user has chosen to not use any tool
        if self.tool_choice == "none" or self.tools is None:
            return None

        # user has chosen to use a named tool
        if isinstance(self.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_name = self.tool_choice.function.name
            tools = {tool.function.name: tool.function for tool in self.tools}
            if tool_name not in tools:
                raise ValueError(
                    f"Tool '{tool_name}' has not been passed in `tools`.")
            tool = tools[tool_name]
            return tool.parameters

        return None

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` unless `stream` is truthy."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs`, `top_logprobs` and `logprobs`.

        Raises:
            ValueError: On negative values, on a positive `prompt_logprobs`
                combined with streaming, or when `top_logprobs` is given
                without a truthy `logprobs`.
        """
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")

            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        if (top_logprobs := data.get("top_logprobs")) is not None:
            if top_logprobs < 0:
                raise ValueError("`top_logprobs` must be a positive value.")

            if not data.get("logprobs"):
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )

        return data

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Ensure at most one guide is set and it is not mixed with a tool.

        Raises:
            ValueError: If more than one of `guided_json`, `guided_regex`
                and `guided_choice` is set, or if any of them is combined
                with a `tool_choice` other than "none"/"auto".
        """
        if isinstance(data, ValueError):
            raise data

        guide_count = sum(
            data.get(key) is not None
            for key in ("guided_json", "guided_regex", "guided_choice"))
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both.
        # The threshold is >= 1: with "> 1" this check could never fire,
        # because that case has already raised just above.
        if guide_count >= 1 and data.get("tool_choice",
                                         "none") not in ("none", "auto"):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        """Default and validate `tool_choice` against `tools`.

        When `tools` is given without `tool_choice`, the choice is set to
        "auto" on the input mapping. An explicit `tool_choice` must come
        with `tools` and be either "auto" or a dict naming one of them.

        Raises:
            ValueError: If `tools` is missing, `tool_choice` has an
                unsupported or malformed value, or the named function is
                not among `tools`.
        """

        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and data.get("tools"):
            data["tool_choice"] = "auto"

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data:

            # ensure that if "tool choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")

            tool_choice = data["tool_choice"]

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto"
            if tool_choice != "auto" and not isinstance(tool_choice, dict):
                raise ValueError(
                    "`tool_choice` must either be a named tool or \"auto\". "
                    "`tool_choice=\"none\" is not supported.")

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            if isinstance(tool_choice, dict):
                bad_format = (
                    "Incorrectly formatted `tool_choice`. Should be like "
                    "`{\"type\": \"function\","
                    " \"function\": {\"name\": \"my_function\"}}`")

                # Tolerant lookups: a missing or mistyped key must surface
                # as the ValueError below, not as a KeyError/TypeError that
                # escapes request validation.
                specified_function = tool_choice.get("function")
                if (not specified_function
                        or not isinstance(specified_function, dict)):
                    raise ValueError(bad_format)
                specified_function_name = specified_function.get("name")
                if not specified_function_name:
                    raise ValueError(bad_format)

                valid_tool = False
                for tool in data["tools"]:
                    function = (tool.get("function")
                                if isinstance(tool, dict) else None)
                    if (isinstance(function, dict) and
                            function.get("name") == specified_function_name):
                        valid_tool = True
                        break
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match any"
                        " of the specified `tools`")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests setting both mutually exclusive prompt flags."""
        if data.get("continue_final_message") and data.get(
                "add_generation_prompt"):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionRequest(OpenAIBaseModel):
    # Request body of the text completions endpoint, plus vLLM-specific
    # sampling and "extra" parameters. Validators run in mode="before",
    # i.e. on the raw input mapping.
    #
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(default=None,
                                ge=_LONG_INFO.min,
                                le=_LONG_INFO.max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: bool = False
    top_k: int = -1
    min_p: float = 0.0
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    include_stop_str_in_output: bool = False
    ignore_eos: bool = False
    min_tokens: int = 0
    skip_special_tokens: bool = True
    spaces_between_special_tokens: bool = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    prompt_logprobs: Optional[int] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."),
    )
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description="If specified, the output will follow the JSON schema.",
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-completion-extra-params

    def to_beam_search_params(self,
                              default_max_tokens: int) -> BeamSearchParams:
        """Translate this request into BeamSearchParams.

        `default_max_tokens` replaces an unset (None) `max_tokens`; an unset
        `n` becomes a beam width of 1 and an unset `temperature` becomes 0.0.
        """
        token_limit = (default_max_tokens
                       if self.max_tokens is None else self.max_tokens)
        beam_width = 1 if self.n is None else self.n
        temp = 0.0 if self.temperature is None else self.temperature

        return BeamSearchParams(
            beam_width=beam_width,
            max_tokens=token_limit,
            ignore_eos=self.ignore_eos,
            temperature=temp,
            length_penalty=self.length_penalty,
        )

    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
        """Translate this request into SamplingParams.

        `default_max_tokens` replaces an unset (None) `max_tokens`.
        """
        token_limit = (default_max_tokens
                       if self.max_tokens is None else self.max_tokens)

        # With echo and no explicit prompt_logprobs, reuse `logprobs`.
        prompt_lp = self.prompt_logprobs
        if prompt_lp is None and self.echo:
            prompt_lp = self.logprobs

        # Echo with max_tokens == 0: sampling is still run with a limit of
        # one token (see `effective_limit` below).
        echo_without_generation = self.echo and self.max_tokens == 0
        effective_limit = 1 if echo_without_generation else token_limit

        wants_json_object = (self.response_format is not None
                             and self.response_format.type == "json_object")
        json_object_flag = True if wants_json_object else None

        guide = GuidedDecodingParams.from_optional(
            json=self.guided_json,
            regex=self.guided_regex,
            choice=self.guided_choice,
            grammar=self.guided_grammar,
            json_object=json_object_flag,
            backend=self.guided_decoding_backend,
            whitespace_pattern=self.guided_whitespace_pattern)

        if self.stream:
            kind = RequestOutputKind.DELTA
        else:
            kind = RequestOutputKind.FINAL_ONLY

        return SamplingParams.from_optional(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            logprobs=self.logprobs,
            ignore_eos=self.ignore_eos,
            max_tokens=effective_limit,
            min_tokens=self.min_tokens,
            prompt_logprobs=prompt_lp,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            output_kind=kind,
            guided_decoding=guide,
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids)

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Allow at most one of guided_json / guided_regex / guided_choice."""
        guide_keys = ("guided_json", "guided_regex", "guided_choice")
        active = [
            key for key in guide_keys
            if key in data and data[key] is not None
        ]
        if len(active) > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Validate `prompt_logprobs` and `logprobs` on the raw input."""
        prompt_lp = data.get("prompt_logprobs")
        if prompt_lp is not None:
            if data.get("stream") and prompt_lp > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")
            if prompt_lp < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")

        lp = data.get("logprobs")
        if lp is not None and lp < 0:
            raise ValueError("`logprobs` must be a positive value.")

        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject `stream_options` unless `stream` is truthy."""
        has_options = data.get("stream_options")
        if has_options and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when `stream=True`.")

        return data
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingRequest(OpenAIBaseModel):
    # Request body of the embeddings endpoint.
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-embedding-pooling-params
    # Passed through unchanged to PoolingParams by to_pooling_params().
    additional_data: Optional[Any] = None

    # doc: end-embedding-pooling-params

    # doc: begin-embedding-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-embedding-extra-params

    def to_pooling_params(self):
        """Return PoolingParams carrying this request's `additional_data`."""
        return PoolingParams(additional_data=self.additional_data)
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionLogProbs(OpenAIBaseModel):
    # Log-probability data of a text completion, held as four lists.
    # Each list defaults to a fresh empty list per instance.
    # NOTE(review): the lists are presumably index-aligned per token --
    # confirm against the code that fills them.
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponseChoice(OpenAIBaseModel):
    # One choice of a non-streaming completion response.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    # Per-position mapping of int keys to Logprob; entries may be None.
    # NOTE(review): keys are presumably token ids -- confirm.
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponse(OpenAIBaseModel):
    # Non-streaming completion response; `usage` is required here.
    # A fresh "cmpl-<uuid>" id is generated when none is supplied.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Creation time in whole seconds since the epoch, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice of a streamed completion chunk. Same fields as
    # CompletionResponseChoice, minus `prompt_logprobs`.
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionStreamResponse(OpenAIBaseModel):
    # Streamed completion chunk; unlike CompletionResponse, `usage` is
    # optional and defaults to None.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Creation time in whole seconds since the epoch, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingResponseData(OpenAIBaseModel):
    # One embedding result inside EmbeddingResponse.data.
    index: int
    object: str = "embedding"
    # Either a list of floats or a string.
    # NOTE(review): the string form presumably corresponds to
    # encoding_format="base64" on the request -- confirm.
    embedding: Union[List[float], str]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingResponse(OpenAIBaseModel):
    # Response of the embeddings endpoint. The default id uses the same
    # "cmpl-" prefix as the completion responses.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    # Creation time in whole seconds since the epoch, taken at construction.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionCall(OpenAIBaseModel):
    # A complete function invocation; both fields are required.

    name: str
    # Arguments are carried as a string.
    # NOTE(review): presumably JSON-encoded - confirm with the tool parsers.
    arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(OpenAIBaseModel):
    # A complete tool call attached to a chat message.

    # Generated per instance; always prefixed with "chatcmpl-tool-".
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    # The only accepted value is "function".
    type: Literal["function"] = "function"
    function: FunctionCall
|
||||||
|
|
||||||
|
|
||||||
|
class DeltaFunctionCall(BaseModel):
    # Partial function call: both fields default to None so a fragment may
    # carry only one of them.
    # Note: derives from BaseModel directly, not OpenAIBaseModel.

    name: Optional[str] = None
    arguments: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
# A tool call delta: `index` is required, `id` and `type` have defaults, and
# `function` may be omitted.
class DeltaToolCall(OpenAIBaseModel):
    # Generated per instance; always prefixed with "chatcmpl-tool-".
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    # Required: no default is provided.
    index: int
    function: Optional[DeltaFunctionCall] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractedToolCallInformation(BaseModel):
    # Result container for tool-call extraction.
    # Note: derives from BaseModel directly, not OpenAIBaseModel.

    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: List[ToolCall]

    # content - per OpenAI spec, content AND tool calls can be returned rarely
    # But some models will do this intentionally
    content: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(OpenAIBaseModel):
    # A complete chat message as returned in a non-streaming response.

    role: str
    # Reasoning text, kept in a field separate from `content`.
    reasoning_content: Optional[str] = None
    content: Optional[str] = None
    # Defaults to a fresh empty list per instance.
    tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionLogProb(OpenAIBaseModel):
    # Log-probability record for a single token.

    token: str
    # Defaults to -9999.0 when no value is supplied.
    # NOTE(review): presumably a sentinel for "unavailable" - confirm.
    logprob: float = -9999.0
    # Optional list of ints.
    # NOTE(review): presumably the token's encoded byte values - confirm.
    bytes: Optional[List[int]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    # A token log-prob entry extended with a list of alternative candidates.

    # Defaults to a fresh empty list per instance.
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionLogProbs(OpenAIBaseModel):
    # Wrapper holding the per-token log-prob entries of a chat choice.

    # None when no entries are supplied.
    content: Optional[List[ChatCompletionLogProbsContent]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponseChoice(OpenAIBaseModel):
    # One choice entry in a non-streaming chat completion response.

    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in vLLM for legacy reasons
    stop_reason: Optional[Union[int, str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponse(OpenAIBaseModel):
    # Response body for a non-streaming chat completion.

    # Generated per instance; always prefixed with "chatcmpl-".
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    # The only accepted value is "chat.completion".
    object: Literal["chat.completion"] = "chat.completion"
    # Construction time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    # Optional list whose entries are either None or a dict from int keys
    # to Logprob values.
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DeltaMessage(OpenAIBaseModel):
    # Incremental chat message fragment used in streamed chunks; every field
    # has a default, so a fragment may populate any subset.

    role: Optional[str] = None
    # Reasoning text fragment, kept separate from `content`.
    reasoning_content: Optional[str] = None
    content: Optional[str] = None
    # Defaults to a fresh empty list per instance.
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    # One choice entry inside a streamed chat completion chunk.

    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # Defaults to None here, whereas ChatCompletionResponseChoice
    # defaults to "stop".
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionStreamResponse(OpenAIBaseModel):
    # Envelope for one streamed chat completion chunk.

    # Generated per instance; always prefixed with "chatcmpl-".
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    # The only accepted value is "chat.completion.chunk".
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Construction time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    # Optional; None when no usage is attached to the chunk.
    usage: Optional[UsageInfo] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """
    # NOTE(review): the docstring above names only the chat endpoint, yet
    # `body` below also accepts EmbeddingRequest - confirm which is current.

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]
|
||||||
|
|
||||||
|
|
||||||
|
class BatchResponseData(OpenAIBaseModel):
    # Per-request response payload embedded in a batch output line.

    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    # Annotated Optional but declared without a default value.
    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    # Annotated Optional but declared without a default value.
    error: Optional[Any]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizeCompletionRequest(OpenAIBaseModel):
    # Tokenize request carrying a plain prompt string.

    model: str
    prompt: str

    # Defaults to True here; TokenizeChatRequest defaults the same flag
    # to False.
    add_special_tokens: bool = Field(default=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizeChatRequest(OpenAIBaseModel):
    # Tokenize request carrying chat messages instead of a plain prompt.

    model: str
    messages: List[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(default=True)
    continue_final_message: bool = Field(default=False)
    # Defaults to False here; TokenizeCompletionRequest defaults it to True.
    add_special_tokens: bool = Field(default=False)

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        # Reject raw input that explicitly sets both mutually exclusive flags.
        #
        # A "before" validator sees the raw, unvalidated input, which is not
        # guaranteed to be a dict. Only inspect it when it is one, so other
        # inputs reach pydantic's own validation instead of failing here
        # with AttributeError on `.get`.
        #
        # NOTE(review): only keys present in the raw input are checked, so a
        # payload that sets `continue_final_message` and relies on the
        # `add_generation_prompt` default of True is not rejected here.
        if (isinstance(data, dict) and data.get("continue_final_message")
                and data.get("add_generation_prompt")):
            raise ValueError("Cannot set both `continue_final_message` and "
                             "`add_generation_prompt` to True.")
        return data
|
||||||
|
|
||||||
|
|
||||||
|
# A tokenize request is either the plain-prompt form or the chat-messages form.
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizeResponse(OpenAIBaseModel):
    # Result of a tokenize request; all three fields are required.

    # NOTE(review): presumably len(tokens) - confirm against the handler.
    count: int
    max_model_len: int
    tokens: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
class DetokenizeRequest(OpenAIBaseModel):
    # Request to turn a list of token ids back into text.

    model: str
    tokens: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
class DetokenizeResponse(OpenAIBaseModel):
    # Result of a detokenize request: a single required string field.

    prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
class LoadLoraAdapterRequest(BaseModel):
    # Request naming a LoRA adapter and the path it is found at.
    # Note: derives from BaseModel directly, not OpenAIBaseModel.

    lora_name: str
    lora_path: str
|
||||||
|
|
||||||
|
|
||||||
|
class UnloadLoraAdapterRequest(BaseModel):
    # Request naming a LoRA adapter to unload.
    # Note: derives from BaseModel directly, not OpenAIBaseModel.

    lora_name: str
    # Optional integer id; None when not supplied.
    lora_int_id: Optional[int] = Field(default=None)
|
||||||
16
qwen3_6_scripts/reasoning/__init__.py
Normal file
16
qwen3_6_scripts/reasoning/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
Reasoning parser module for vLLM 0.6.3 (BI-V100 / Qwen3.6-27B adaptation).
|
||||||
|
|
||||||
|
Usage: --reasoning-parser qwen3
|
||||||
|
"""
|
||||||
|
|
||||||
|
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
||||||
|
|
||||||
|
__all__ = ["ReasoningParser", "ReasoningParserManager"]
|
||||||
|
|
||||||
|
# Record where the Qwen3 parser lives without importing it yet; the module is
# only imported when get_reasoning_parser("qwen3") is first called.
ReasoningParserManager.register_lazy(
    name="qwen3",
    module_path="vllm.reasoning.qwen3_reasoning_parser",
    class_name="Qwen3ReasoningParser",
)
|
||||||
243
qwen3_6_scripts/reasoning/abs_reasoning_parsers.py
Normal file
243
qwen3_6_scripts/reasoning/abs_reasoning_parsers.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
"""
|
||||||
|
Abstract reasoning parser base classes for vLLM 0.6.3.
|
||||||
|
Adapted from vllm-original/vllm/reasoning/abs_reasoning_parsers.py:
|
||||||
|
- Removed vllm.entrypoints.mcp, vllm.utils.collection_utils, import_utils
|
||||||
|
- DeltaMessage from vllm 0.6.3 protocol path
|
||||||
|
- TokenizerLike -> AnyTokenizer
|
||||||
|
- ReasoningParserManager: simplified eager + lazy registration
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from abc import abstractmethod
|
||||||
|
from collections.abc import Iterable, Sequence
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.entrypoints.openai.protocol import DeltaMessage
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
else:
|
||||||
|
DeltaMessage = Any
|
||||||
|
AnyTokenizer = Any
|
||||||
|
|
||||||
|
|
||||||
|
class ReasoningParser:
    """Abstract base for all reasoning parsers.

    A reasoning parser separates a model's "reasoning" output from its
    final "content", both for complete outputs and for streamed deltas.

    NOTE: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers below document the intended contract but do
    not prevent instantiation of a class that leaves them unimplemented.
    """

    def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
        # Extra positional and keyword arguments are accepted and ignored
        # here; subclasses may read them.
        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> dict:
        # Result of the tokenizer's get_vocab(), computed on first access
        # and cached on the instance afterwards.
        return self.model_tokenizer.get_vocab()

    @abstractmethod
    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Return True once the reasoning block has closed in input_ids."""

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Streaming variant of ``is_reasoning_end``.

        The default implementation ignores ``delta_ids`` and delegates to
        ``is_reasoning_end(input_ids)``.
        """
        return self.is_reasoning_end(input_ids)

    @abstractmethod
    def extract_content_ids(self, input_ids: list) -> list:
        """Return token ids that belong to the content (post-reasoning) part."""

    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        """Return the number of reasoning tokens in ``token_ids``.

        The default implementation always returns 0.
        """
        return 0

    @abstractmethod
    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> "tuple[Optional[str], Optional[str]]":
        """
        Split a complete model output into (reasoning_text, content_text).
        Either part may be None.
        """

    @abstractmethod
    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Optional["DeltaMessage"]:
        """
        Extract reasoning from a streaming delta.
        Returns a DeltaMessage with reasoning_content and/or content set,
        or None if this delta should be suppressed (control token).
        """
|
||||||
|
|
||||||
|
|
||||||
|
class BaseThinkingReasoningParser(ReasoningParser):
    """
    Base for parsers whose reasoning is delimited by a start/end token pair.

    Subclasses define the ``start_token`` / ``end_token`` properties (the
    literal token text). Both tokens must exist in the tokenizer vocabulary;
    their ids are resolved once in ``__init__``.
    """

    @property
    @abstractmethod
    def start_token(self) -> str:
        """Literal text of the token that opens the reasoning block."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_token(self) -> str:
        """Literal text of the token that closes the reasoning block."""
        raise NotImplementedError

    def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs):
        """Resolve the delimiter token ids from the tokenizer vocabulary.

        Raises:
            ValueError: if the tokenizer is falsy, or either delimiter
                property returns an empty value.
            RuntimeError: if either delimiter is missing from the vocab.
        """
        super().__init__(tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError("Tokenizer must be passed to ReasoningParser.")
        if not self.start_token or not self.end_token:
            raise ValueError("start_token and end_token must be defined.")

        self.start_token_id: Optional[int] = self.vocab.get(self.start_token)
        self.end_token_id: Optional[int] = self.vocab.get(self.end_token)
        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                f"{self.__class__.__name__}: could not find think tokens "
                f"'{self.start_token}'/'{self.end_token}' in tokenizer vocab."
            )

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Return True if the most recent delimiter in ``input_ids`` is the
        end token.

        Scans from the end: hitting the start token first means a block is
        still open; hitting the end token first means it has closed. With no
        delimiter at all the result is False.
        """
        for token_id in reversed(input_ids):
            if token_id == self.start_token_id:
                return False
            if token_id == self.end_token_id:
                return True
        return False

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Return True if the end token id appears in this delta."""
        return self.end_token_id in delta_ids

    def extract_content_ids(self, input_ids: list) -> list:
        """Return the ids that follow the first end token.

        Returns an empty list when the end token is absent or is the very
        last id (nothing follows it yet).
        """
        if self.end_token_id not in input_ids[:-1]:
            return []
        return input_ids[input_ids.index(self.end_token_id) + 1:]

    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        """Count tokens that sit inside a start/end pair.

        Delimiter tokens themselves are not counted. An end token seen while
        no block is open is ignored.
        """
        count = 0
        depth = 0
        for tid in token_ids:
            if tid == self.start_token_id:
                depth += 1
            elif tid == self.end_token_id:
                if depth > 0:
                    depth -= 1
            elif depth > 0:
                count += 1
        return count

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> "tuple[Optional[str], Optional[str]]":
        """Split a complete output into ``(reasoning, content)``.

        Text up to and including the first start token is discarded. Without
        an end token the whole remainder is returned as reasoning and content
        is None. ``request`` is not used by this implementation.
        """
        # Strip the start token if the model generated it (old-style template).
        parts = model_output.partition(self.start_token)
        model_output = parts[2] if parts[1] else parts[0]

        if self.end_token not in model_output:
            return model_output, None
        reasoning, _, content = model_output.partition(self.end_token)
        return reasoning, content or None

    @staticmethod
    def _build_delta(reasoning: str, content: str) -> Optional["DeltaMessage"]:
        """Build a DeltaMessage from the non-empty parts, or None if both
        parts are empty (nothing to emit for this delta)."""
        if not reasoning and not content:
            return None
        # Imported lazily, as in extract_reasoning_streaming.
        from vllm.entrypoints.openai.protocol import DeltaMessage as _DeltaMessage
        return _DeltaMessage(
            reasoning_content=reasoning or None,
            content=content or None,
        )

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Optional["DeltaMessage"]:
        """Classify one streamed delta as reasoning and/or content.

        Decisions are driven by where the delimiter token ids appear
        (previous ids vs. this delta); delimiter text is stripped from what
        is emitted. Returns None when the delta carries nothing to emit.

        If the start token id never appears in the generated ids, every
        delta is treated as content.
        """
        from vllm.entrypoints.openai.protocol import DeltaMessage as _DeltaMessage

        # Suppress lone control tokens.
        if len(delta_token_ids) == 1 and delta_token_ids[0] in (
            self.start_token_id, self.end_token_id
        ):
            return None

        start_in_prev = self.start_token_id in previous_token_ids
        start_in_delta = self.start_token_id in delta_token_ids
        end_in_prev = self.end_token_id in previous_token_ids
        end_in_delta = self.end_token_id in delta_token_ids

        if start_in_prev:
            if end_in_delta:
                # The block closes inside this delta: split at the end token.
                reasoning, found, content = delta_text.partition(
                    self.end_token)
                if not found:
                    # Token id present but its text is not in delta_text:
                    # there is nothing that can be attributed safely.
                    return None
                return self._build_delta(reasoning, content)
            elif end_in_prev:
                return _DeltaMessage(content=delta_text)
            else:
                return _DeltaMessage(reasoning_content=delta_text)

        elif start_in_delta:
            # Drop everything up to and including the start token so the
            # delimiter text never leaks into reasoning_content. partition()
            # avoids slicing with an unchecked find() == -1.
            _, start_found, tail = delta_text.partition(self.start_token)
            if not start_found:
                tail = delta_text
            if end_in_delta:
                reasoning, end_found, content = tail.partition(self.end_token)
                if not end_found:
                    content = ""
            else:
                reasoning, content = tail, ""
            return self._build_delta(reasoning, content)

        else:
            return _DeltaMessage(content=delta_text)
|
||||||
|
|
||||||
|
|
||||||
|
class ReasoningParserManager:
    """
    Name-based registry of ReasoningParser classes.

    A parser is either registered eagerly (the class object is stored) or
    lazily (only its module path and class name are stored, and the import
    happens on first lookup).
    """

    _parsers: dict = {}   # name -> class (eager)
    _lazy: dict = {}      # name -> (module_path, class_name)

    @classmethod
    def register_module(cls, name: str, parser_cls: type) -> None:
        """Store ``parser_cls`` under ``name``.

        Raises:
            TypeError: if ``parser_cls`` is not a ReasoningParser subclass.
        """
        if issubclass(parser_cls, ReasoningParser):
            cls._parsers[name] = parser_cls
            return
        raise TypeError(f"{parser_cls} is not a ReasoningParser subclass.")

    @classmethod
    def register_lazy(cls, name: str, module_path: str, class_name: str) -> None:
        """Remember where to import the parser for ``name`` from, without
        importing it now."""
        cls._lazy[name] = (module_path, class_name)

    @classmethod
    def get_reasoning_parser(cls, name: str) -> type:
        """Return the parser class registered under ``name``.

        Eager registrations win. A lazy registration is imported on first
        use and the resolved class is cached as an eager entry.

        Raises:
            KeyError: if ``name`` is registered neither eagerly nor lazily.
        """
        try:
            return cls._parsers[name]
        except KeyError:
            pass

        target = cls._lazy.get(name)
        if target is None:
            registered = sorted(set(cls._parsers) | set(cls._lazy))
            raise KeyError(
                f"Reasoning parser '{name}' not found. "
                f"Available: {registered}"
            )

        module_path, class_name = target
        resolved = getattr(importlib.import_module(module_path), class_name)
        cls._parsers[name] = resolved
        return resolved

    @classmethod
    def list_registered(cls) -> list:
        """Return the sorted union of eager and lazy registration names."""
        names = set(cls._lazy)
        names.update(cls._parsers)
        return sorted(names)
|
||||||
108
qwen3_6_scripts/reasoning/qwen3_reasoning_parser.py
Normal file
108
qwen3_6_scripts/reasoning/qwen3_reasoning_parser.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Reasoning parser for Qwen3 / Qwen3.5 / Qwen3.6 model family.
|
||||||
|
Adapted from vllm-original/vllm/reasoning/qwen3_reasoning_parser.py.
|
||||||
|
|
||||||
|
The model uses <think>...</think> to wrap chain-of-thought output.
|
||||||
|
For Qwen3.5+ the chat template injects <think> into the prompt, so only
|
||||||
|
</think> appears in the generated tokens; older templates generate <think>
|
||||||
|
themselves. Both styles are handled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Sequence, Any
|
||||||
|
|
||||||
|
from vllm.reasoning.abs_reasoning_parsers import (
|
||||||
|
BaseThinkingReasoningParser,
|
||||||
|
ReasoningParserManager,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3ReasoningParser(BaseThinkingReasoningParser):
    """Reasoning parser using ``<think>`` / ``</think>`` as delimiters.

    Handles both outputs where the model emits ``<think>`` itself and
    outputs where only ``</think>`` shows up in the generated tokens.
    ``thinking_enabled`` is read from the ``enable_thinking`` entry of the
    ``chat_template_kwargs`` keyword argument and defaults to True.
    """

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # A missing or None chat_template_kwargs behaves like an empty dict.
        template_options = kwargs.get("chat_template_kwargs") or {}
        self.thinking_enabled = template_options.get("enable_thinking", True)

    @property
    def start_token(self) -> str:
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> "tuple[Optional[str], Optional[str]]":
        """Split a complete output into ``(reasoning, content)``.

        Without ``</think>`` in the output: everything is content when
        thinking is disabled, otherwise everything is reasoning (the output
        ended before the block was closed).
        """
        # Discard everything up to and including a model-generated <think>.
        _, opener, after_opener = model_output.partition(self.start_token)
        if opener:
            model_output = after_opener

        reasoning, closer, content = model_output.partition(self.end_token)
        if closer:
            return reasoning, content or None
        if self.thinking_enabled:
            return model_output, None
        return None, model_output

    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        """Count reasoning tokens for both delimiter styles."""
        ids = list(token_ids)
        if self.start_token_id in ids:
            # The model generated <think> itself: use the base class's
            # depth counting.
            return super().count_reasoning_tokens(ids)
        try:
            # Only </think> is present: every id before it is reasoning.
            return ids.index(self.end_token_id)
        except ValueError:
            # No </think> at all: all reasoning if thinking is enabled,
            # none otherwise.
            return len(ids) if self.thinking_enabled else 0

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ):
        """Classify one streamed delta; returns a DeltaMessage or None when
        there is nothing to emit."""
        from vllm.entrypoints.openai.protocol import DeltaMessage

        if not self.thinking_enabled:
            if not delta_text:
                return None
            return DeltaMessage(content=delta_text)

        # Drop a model-generated <think> (and anything before it).
        if self.start_token_id in delta_token_ids:
            _, opener, remainder = delta_text.partition(self.start_token)
            if opener:
                delta_text = remainder

        if self.end_token_id in delta_token_ids:
            # The block closes in this delta: split around </think>.
            reasoning, closer, content = delta_text.partition(self.end_token)
            if not closer or not (reasoning or content):
                return None
            return DeltaMessage(
                reasoning_content=reasoning or None,
                content=content or None,
            )

        if not delta_text:
            return None
        if self.end_token_id in previous_token_ids:
            return DeltaMessage(content=delta_text)
        return DeltaMessage(reasoning_content=delta_text)
|
||||||
|
|
||||||
|
|
||||||
|
# Import-time side effect: eagerly register this parser class under the name
# "qwen3" as soon as the module is loaded.
ReasoningParserManager.register_module("qwen3", Qwen3ReasoningParser)
|
||||||
994
qwen3_6_scripts/serving_chat.py
Normal file
994
qwen3_6_scripts/serving_chat.py
Normal file
@@ -0,0 +1,994 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
|
||||||
|
Optional)
|
||||||
|
from typing import Sequence as GenericSequence
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||||
|
from vllm.engine.protocol import EngineClient
|
||||||
|
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||||
|
apply_hf_chat_template,
|
||||||
|
apply_mistral_chat_template,
|
||||||
|
load_chat_template,
|
||||||
|
parse_chat_messages_futures)
|
||||||
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
ChatCompletionLogProb, ChatCompletionLogProbs,
|
||||||
|
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
|
||||||
|
ChatCompletionRequest, ChatCompletionResponse,
|
||||||
|
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||||
|
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||||
|
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
|
||||||
|
ToolCall, UsageInfo)
|
||||||
|
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||||
|
LoRAModulePath,
|
||||||
|
OpenAIServing,
|
||||||
|
PromptAdapterPath,
|
||||||
|
TextTokensPrompt)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||||
|
from vllm.inputs import TokensPrompt
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
|
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||||
|
from vllm.sequence import Logprob
|
||||||
|
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||||
|
log_tracing_disabled_warning)
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
|
from vllm.utils import iterate_with_cancellation, random_uuid
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIServingChat(OpenAIServing):
|
||||||
|
|
||||||
|
def __init__(self,
             engine_client: EngineClient,
             model_config: ModelConfig,
             base_model_paths: List[BaseModelPath],
             response_role: str,
             *,
             lora_modules: Optional[List[LoRAModulePath]],
             prompt_adapters: Optional[List[PromptAdapterPath]],
             request_logger: Optional[RequestLogger],
             chat_template: Optional[str],
             return_tokens_as_token_ids: bool = False,
             enable_auto_tools: bool = False,
             tool_parser: Optional[str] = None,
             reasoning_parser: Optional[str] = None):
    """Set up chat serving state, the optional tool parser and the optional
    reasoning parser.

    Raises:
        TypeError: if auto tool choice is enabled and ``tool_parser`` cannot
            be resolved, or if ``reasoning_parser`` is given and cannot be
            loaded. The underlying exception is chained as the cause.
    """
    super().__init__(engine_client=engine_client,
                     model_config=model_config,
                     base_model_paths=base_model_paths,
                     lora_modules=lora_modules,
                     prompt_adapters=prompt_adapters,
                     request_logger=request_logger,
                     return_tokens_as_token_ids=return_tokens_as_token_ids)

    self.response_role = response_role
    self.use_tool_use_model_template = False
    self.chat_template = load_chat_template(chat_template)

    # Tool use: the parser is only resolved when auto tool choice is on.
    self.enable_auto_tools: bool = enable_auto_tools
    self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
    if self.enable_auto_tools:
        logger.info(
            '"auto" tool choice has been enabled please note that while'
            ' the parallel_tool_calls client option is preset for '
            'compatibility reasons, it will be ignored.')
        try:
            self.tool_parser = ToolParserManager.get_tool_parser(tool_parser)
        except Exception as e:
            raise TypeError("Error: --enable-auto-tool-choice requires "
                            f"tool_parser:'{tool_parser}' which has not "
                            "been registered") from e

    # Reasoning parser: stays None unless a parser name was given.
    self.reasoning_parser_cls = None
    if reasoning_parser:
        try:
            # Imported only when a reasoning parser is requested.
            from vllm.reasoning import ReasoningParserManager
            self.reasoning_parser_cls = (
                ReasoningParserManager.get_reasoning_parser(reasoning_parser))
            logger.info("Reasoning parser '%s' enabled.", reasoning_parser)
        except Exception as e:
            raise TypeError(
                f"Error: --reasoning-parser '{reasoning_parser}' could not "
                "be loaded. Make sure vllm/reasoning/ is installed."
            ) from e
|
||||||
|
|
||||||
|
async def create_chat_completion(
    self,
    request: ChatCompletionRequest,
    raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
           ErrorResponse]:
    """Completion API similar to OpenAI's API.

    See https://platform.openai.com/docs/api-reference/chat/create
    for the API specification. This API mimics the OpenAI
    ChatCompletion API.

    Flow: check the model, render the chat template into a prompt, await
    multi-modal data, validate the tool choice, tokenize, build sampling
    (or beam search) parameters, and start generation. Returns the
    streaming generator when ``request.stream`` is set, otherwise the full
    response. Template, multi-modal and ValueError failures are returned as
    error responses rather than raised.

    Raises:
        Exception: ``engine_client.dead_error`` when the engine has errored.
    """
    error_check_ret = await self._check_model(request)
    if error_check_ret is not None:
        logger.error("Error with model %s", error_check_ret)
        return error_check_ret

    # If the engine is dead, raise the engine's DEAD_ERROR.
    # This is required for the streaming case, where we return a
    # success status before we actually start generating text :).
    if self.engine_client.errored:
        raise self.engine_client.dead_error

    try:
        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        model_config = self.model_config
        tokenizer = await self.engine_client.get_tokenizer(lora_request)

        conversation, mm_data_future = parse_chat_messages_futures(
            request.messages, model_config, tokenizer)

        tool_dicts = None if request.tools is None else [
            tool.model_dump() for tool in request.tools
        ]

        # The Mistral path is given the raw request messages; the HF path is
        # given the parsed conversation.
        prompt: Union[str, List[int]]
        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
        if is_mistral_tokenizer:
            prompt = apply_mistral_chat_template(
                tokenizer,
                messages=request.messages,
                chat_template=request.chat_template or self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
                tools=tool_dicts,
                documents=request.documents,
                **(request.chat_template_kwargs or {}),
            )
        else:
            prompt = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=request.chat_template or self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
                tools=tool_dicts,
                documents=request.documents,
                **(request.chat_template_kwargs or {}),
            )
    except Exception as e:
        logger.exception("Error in applying chat template from request")
        return self.create_error_response(str(e))

    try:
        mm_data = await mm_data_future
    except Exception as e:
        logger.exception("Error in loading multi-modal data")
        return self.create_error_response(str(e))

    # validation for OpenAI tools
    # tool_choice = "required" is not supported
    if request.tool_choice == "required":
        return self.create_error_response(
            "tool_choice = \"required\" is not supported!")

    if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
            self.enable_auto_tools and self.tool_parser is not None):
        # for hf tokenizers, "auto" tools requires
        # --enable-auto-tool-choice and --tool-call-parser
        return self.create_error_response(
            "\"auto\" tool choice requires "
            "--enable-auto-tool-choice and --tool-call-parser to be set")

    request_id = f"chat-{random_uuid()}"

    request_metadata = RequestResponseMetadata(request_id=request_id)
    if raw_request:
        raw_request.state.request_metadata = request_metadata

    try:
        if self.enable_auto_tools and self.tool_parser:
            request = self.tool_parser(tokenizer).adjust_request(
                request=request)

        if isinstance(prompt, str):
            prompt_inputs = self._tokenize_prompt_input(
                request,
                tokenizer,
                prompt,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
                add_special_tokens=request.add_special_tokens,
            )
        else:
            assert isinstance(prompt, list) and isinstance(
                prompt[0], int
            ), "Prompt has to be either a string or a list of token ids"
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)

        assert prompt_inputs is not None

        # Default generation budget: whatever the context window has left
        # after the prompt tokens.
        sampling_params: Union[SamplingParams, BeamSearchParams]
        default_max_tokens = self.max_model_len - len(
            prompt_inputs["prompt_token_ids"])
        if request.use_beam_search:
            sampling_params = request.to_beam_search_params(
                default_max_tokens)
        else:
            sampling_params = request.to_sampling_params(
                default_max_tokens)

        self._log_inputs(request_id,
                         prompt_inputs,
                         params=sampling_params,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        engine_inputs = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_inputs["multi_modal_data"] = mm_data

        is_tracing_enabled = (await
                              self.engine_client.is_tracing_enabled())
        trace_headers = None
        if is_tracing_enabled and raw_request:
            trace_headers = extract_trace_headers(raw_request.headers)
        if (not is_tracing_enabled and raw_request
                and contains_trace_headers(raw_request.headers)):
            log_tracing_disabled_warning()

        if isinstance(sampling_params, BeamSearchParams):
            # Fixed: the two message fragments used to concatenate without a
            # separating space ("...supported withAsyncLLMEngine...").
            assert isinstance(self.engine_client,
                              (AsyncLLMEngine,
                               MQLLMEngineClient)), \
                "Beam search is only supported with " \
                "AsyncLLMEngine and MQLLMEngineClient."
            result_generator = self.engine_client.beam_search(
                engine_inputs['prompt_token_ids'],
                request_id,
                sampling_params,
            )
        else:
            result_generator = self.engine_client.generate(
                engine_inputs,
                sampling_params,
                request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=request.priority,
            )
    except ValueError as e:
        # TODO: Use a vllm-specific Validation Error
        return self.create_error_response(str(e))

    # With a live HTTP request, wrap the generator so it is driven together
    # with the request's is_disconnected check.
    if raw_request:
        result_generator = iterate_with_cancellation(
            result_generator, raw_request.is_disconnected)

    # Streaming response
    if request.stream:
        return self.chat_completion_stream_generator(
            request, result_generator, request_id, conversation, tokenizer,
            request_metadata)

    try:
        return await self.chat_completion_full_generator(
            request, result_generator, request_id, conversation, tokenizer,
            request_metadata)
    except ValueError as e:
        # TODO: Use a vllm-specific Validation Error
        return self.create_error_response(str(e))
|
||||||
|
|
||||||
|
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
    """Return the role to attach to the generated message.

    When the request does not add a generation prompt, the response
    continues the last message of the request, so that message's role is
    reused; otherwise the configured ``self.response_role`` is used.
    """
    if not request.add_generation_prompt:
        return request.messages[-1]["role"]
    return self.response_role
|
||||||
|
|
||||||
|
async def chat_completion_stream_generator(
    self,
    request: ChatCompletionRequest,
    result_generator: AsyncIterator[RequestOutput],
    request_id: str,
    conversation: List[ConversationMessage],
    tokenizer: AnyTokenizer,
    request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
    """Stream a chat completion as server-sent-event ``data:`` payloads.

    For every choice an initial chunk carrying the role is sent, optionally
    followed by a chunk echoing the last conversation message.  Each engine
    output is then routed through exactly one of: named tool choice,
    reasoning parser (until its end token has been seen), auto tool parser,
    or plain content.  A finish chunk is sent once per choice, followed by
    an optional usage chunk and the terminating ``[DONE]`` payload.

    Args:
        request: The chat completion request being served.
        result_generator: Async iterator of engine outputs for the request.
        request_id: Identifier placed in every streamed chunk.
        conversation: Parsed conversation; only its last message is read
            (for echo / continue_final_message).
        tokenizer: Tokenizer handed to the tool/reasoning parsers and to
            the logprob formatting helper.
        request_metadata: Receives the aggregate usage in
            ``final_usage_info`` once streaming completes.

    Yields:
        Serialized chunks formatted as ``data: ...`` SSE lines.
    """
    model_name = self.base_model_paths[0].name
    created_time = int(time.time())
    chunk_object_type: Final = "chat.completion.chunk"
    first_iteration = True

    # Send response for each token for each request.n (index)
    num_choices = 1 if request.n is None else request.n
    previous_num_tokens = [0] * num_choices
    finish_reason_sent = [False] * num_choices
    num_prompt_tokens = 0

    # Usage reporting flags are fixed for the whole stream.
    include_usage = bool(request.stream_options
                         and request.stream_options.include_usage)
    include_continuous_usage = bool(
        include_usage and request.stream_options.continuous_usage_stats)

    if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
        tool_choice_function_name = request.tool_choice.function.name
    else:
        tool_choice_function_name = None

    # Determine whether tools are in use with "auto" tool choice
    tool_choice_auto = (
        not tool_choice_function_name
        and self._should_stream_with_auto_tool_parsing(request))

    use_reasoning = self.reasoning_parser_cls is not None

    all_previous_token_ids: Optional[List[List[int]]]
    # previous_texts / all_previous_token_ids are needed for both tool
    # parsing and reasoning parsing (both require full-history context).
    if tool_choice_auto or use_reasoning:
        previous_texts = [""] * num_choices
        # One independent list per choice; ``[[]] * n`` would alias a
        # single list object across every slot.
        all_previous_token_ids = [[] for _ in range(num_choices)]
    else:
        previous_texts, all_previous_token_ids = None, None

    # Prepare the tool parsers if needed.  Each choice gets its own
    # instance because the parser keeps per-stream state
    # (prev_tool_call_arr, streamed_args_for_tool) that must not be shared
    # between choices when n > 1.
    try:
        if tool_choice_auto and self.tool_parser:
            tool_parsers: List[Optional[ToolParser]] = [
                self.tool_parser(tokenizer) for _ in range(num_choices)
            ]
        else:
            tool_parsers = [None] * num_choices
    except RuntimeError as e:
        logger.error("Error in tool parser creation: %s", e)
        data = self.create_streaming_error_response(str(e))
        yield f"data: {data}\n\n"
        yield "data: [DONE]\n\n"
        return

    # Prepare reasoning parsers (one instance per choice for state isolation)
    reasoning_parsers: List[Optional[object]] = [None] * num_choices
    reasoning_end_arr: List[bool] = [False] * num_choices
    reasoning_token_counts: List[int] = [0] * num_choices
    if use_reasoning:
        try:
            reasoning_parsers = [
                self.reasoning_parser_cls(
                    tokenizer,
                    chat_template_kwargs=request.chat_template_kwargs)
                for _ in range(num_choices)
            ]
            # If thinking is disabled per-request, mark reasoning as
            # already ended so the tool-auto branch is reachable.
            for idx, rp in enumerate(reasoning_parsers):
                if hasattr(rp,
                           'thinking_enabled') and not rp.thinking_enabled:
                    reasoning_end_arr[idx] = True
        except RuntimeError as e:
            logger.error("Error in reasoning parser creation: %s", e)
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
            return

    try:
        async for res in result_generator:
            if res.prompt_token_ids is not None:
                num_prompt_tokens = len(res.prompt_token_ids)
                if res.encoder_prompt_token_ids is not None:
                    num_prompt_tokens += len(res.encoder_prompt_token_ids)

            # We need to do it here, because if there are exceptions in
            # the result_generator, it needs to be sent as the FIRST
            # response (by the try...catch).
            if first_iteration:
                # Send first response for each request.n (index) with
                # the role
                role = self.get_chat_request_role(request)

                # NOTE num_choices defaults to 1 so this usually executes
                # once per request
                for i in range(num_choices):
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=DeltaMessage(
                            role=role,
                            content="",
                        ),
                        logprobs=None,
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)

                    # if usage should be included
                    if include_usage:
                        # if continuous usage stats are requested, add it
                        if include_continuous_usage:
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=0,
                                total_tokens=num_prompt_tokens)
                        # otherwise explicitly set it to None so the
                        # field is serialized despite exclude_unset
                        else:
                            chunk.usage = None

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

                # Send response to echo the input portion of the
                # last message
                if request.echo or request.continue_final_message:
                    last_msg_content: str = ""
                    if conversation and "content" in conversation[
                            -1] and conversation[-1].get("role") == role:
                        last_msg_content = conversation[-1]["content"] or ""

                    if last_msg_content:
                        for i in range(num_choices):
                            choice_data = (
                                ChatCompletionResponseStreamChoice(
                                    index=i,
                                    delta=DeltaMessage(
                                        content=last_msg_content),
                                    logprobs=None,
                                    finish_reason=None))
                            chunk = ChatCompletionStreamResponse(
                                id=request_id,
                                object=chunk_object_type,
                                created=created_time,
                                choices=[choice_data],
                                model=model_name)
                            if include_usage:
                                if include_continuous_usage:
                                    chunk.usage = UsageInfo(
                                        prompt_tokens=num_prompt_tokens,
                                        completion_tokens=0,
                                        total_tokens=num_prompt_tokens)
                                else:
                                    chunk.usage = None

                            data = chunk.model_dump_json(
                                exclude_unset=True)
                            yield f"data: {data}\n\n"
                first_iteration = False

            for output in res.outputs:
                i = output.index
                tool_parser = tool_parsers[i]

                if finish_reason_sent[i]:
                    continue

                if request.logprobs and request.top_logprobs is not None:
                    assert output.logprobs is not None, (
                        "Did not output logprobs")
                    logprobs = self._create_chat_logprobs(
                        token_ids=output.token_ids,
                        top_logprobs=output.logprobs,
                        tokenizer=tokenizer,
                        num_output_top_logprobs=request.top_logprobs,
                    )
                else:
                    logprobs = None

                delta_text = output.text
                delta_message: Optional[DeltaMessage]

                # Maintain text/token history when either reasoning or
                # auto-tool parsing is active.
                assert previous_texts is not None or not (
                    tool_choice_auto or use_reasoning)
                if previous_texts is not None:
                    assert all_previous_token_ids is not None
                    previous_text = previous_texts[i]
                    previous_token_ids = all_previous_token_ids[i]
                    current_text = previous_text + delta_text
                    current_token_ids = previous_token_ids + list(
                        output.token_ids)
                    previous_texts[i] = current_text
                    all_previous_token_ids[i] = current_token_ids
                else:
                    previous_text = ""
                    previous_token_ids = []
                    current_text = delta_text
                    current_token_ids = list(output.token_ids)

                # handle streaming deltas for tools with named tool_choice
                if tool_choice_function_name:
                    delta_message = DeltaMessage(tool_calls=[
                        DeltaToolCall(function=DeltaFunctionCall(
                            name=tool_choice_function_name,
                            arguments=delta_text),
                                      index=i)
                    ])

                # handle reasoning: route through reasoning parser while
                # </think> has not yet been seen.
                elif use_reasoning and not reasoning_end_arr[i]:
                    r_parser = reasoning_parsers[i]
                    delta_message = r_parser.extract_reasoning_streaming(
                        previous_text=previous_text,
                        current_text=current_text,
                        delta_text=delta_text,
                        previous_token_ids=previous_token_ids,
                        current_token_ids=current_token_ids,
                        delta_token_ids=output.token_ids,
                    )
                    # Mark reasoning as ended when end token appears.
                    if r_parser.end_token_id in current_token_ids:
                        reasoning_end_arr[i] = True

                # handle streaming deltas for tools with "auto" tool choice
                # (only reached after reasoning block, if any, has ended)
                elif tool_choice_auto:
                    assert tool_parser is not None
                    delta_message = (
                        tool_parser.extract_tool_calls_streaming(
                            previous_text=previous_text,
                            current_text=current_text,
                            delta_text=delta_text,
                            previous_token_ids=previous_token_ids,
                            current_token_ids=current_token_ids,
                            delta_token_ids=output.token_ids,
                            request=request))

                # handle streaming just a content delta
                else:
                    delta_message = DeltaMessage(content=delta_text)

                # set the previous values for the next iteration
                previous_num_tokens[i] += len(output.token_ids)

                # if the message delta is None (e.g. because it was a
                # "control token" for tool calls or the parser otherwise
                # wasn't ready to send a token, then
                # get the next token without streaming a chunk
                if delta_message is None:
                    continue

                if output.finish_reason is None:
                    # Send token-by-token response for each request.n
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=delta_message,
                        logprobs=logprobs,
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)

                    # handle usage stats if requested & if continuous
                    if include_usage:
                        if include_continuous_usage:
                            completion_tokens = len(output.token_ids)
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=completion_tokens,
                                total_tokens=num_prompt_tokens +
                                completion_tokens,
                            )
                        else:
                            chunk.usage = None

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

                # if the model is finished generating
                else:
                    # check to make sure we haven't "forgotten" to stream
                    # any tokens that were generated but previously
                    # matched by partial json parsing
                    # only happens if we are NOT using guided decoding
                    auto_tools_called = False
                    if tool_parser:
                        auto_tools_called = len(
                            tool_parser.prev_tool_call_arr) > 0
                        index = len(tool_parser.prev_tool_call_arr
                                    ) - 1 if auto_tools_called else 0
                    else:
                        index = 0

                    if self._should_check_for_unstreamed_tool_arg_tokens(
                            delta_message, output) and tool_parser:
                        # get the expected call based on partial JSON
                        # parsing which "autocompletes" the JSON
                        expected_call = json.dumps(
                            tool_parser.prev_tool_call_arr[index].get(
                                "arguments", {}))

                        # get what we've streamed so far for arguments
                        # for the current tool
                        actual_call = tool_parser.streamed_args_for_tool[
                            index]

                        # check to see if there's anything left to stream
                        remaining_call = expected_call.replace(
                            actual_call, "", 1)

                        # set that as a delta message
                        delta_message = DeltaMessage(tool_calls=[
                            DeltaToolCall(index=index,
                                          function=DeltaFunctionCall(
                                              arguments=remaining_call).
                                          model_dump(exclude_none=True))
                        ])

                    # Count reasoning tokens for this choice at finish time.
                    if use_reasoning and all_previous_token_ids is not None:
                        r_parser = reasoning_parsers[i]
                        reasoning_token_counts[i] = \
                            r_parser.count_reasoning_tokens(
                                all_previous_token_ids[i])

                    # Send the finish response for each request.n only once
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=delta_message,
                        logprobs=logprobs,
                        finish_reason=output.finish_reason
                        if not auto_tools_called else "tool_calls",
                        stop_reason=output.stop_reason)
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)
                    if include_usage:
                        if include_continuous_usage:
                            completion_tokens = len(output.token_ids)
                            chunk.usage = UsageInfo(
                                prompt_tokens=num_prompt_tokens,
                                completion_tokens=completion_tokens,
                                total_tokens=num_prompt_tokens +
                                completion_tokens,
                            )
                        else:
                            chunk.usage = None
                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"
                    finish_reason_sent[i] = True

        # Aggregate usage across all choices.  The totals must not depend
        # on a loop variable left over from the iteration above: that
        # would report a single choice's count and be undefined when the
        # generator produced no results.
        num_completion_tokens = sum(previous_num_tokens)
        total_reasoning = (sum(reasoning_token_counts)
                           if use_reasoning else None)

        # once the final token is handled, if stream_options.include_usage
        # is sent, send the usage
        if include_usage:
            final_usage = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=num_completion_tokens,
                total_tokens=num_prompt_tokens + num_completion_tokens,
                reasoning_tokens=total_reasoning,
            )

            final_usage_chunk = ChatCompletionStreamResponse(
                id=request_id,
                object=chunk_object_type,
                created=created_time,
                choices=[],
                model=model_name,
                usage=final_usage)
            final_usage_data = (final_usage_chunk.model_dump_json(
                exclude_unset=True, exclude_none=True))
            yield f"data: {final_usage_data}\n\n"

        # report to FastAPI middleware aggregate usage across all choices
        request_metadata.final_usage_info = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_completion_tokens,
            total_tokens=num_prompt_tokens + num_completion_tokens,
            reasoning_tokens=total_reasoning)

    except ValueError as e:
        # TODO: Use a vllm-specific Validation Error
        logger.error("error in chat completion stream generator: %s", e)
        data = self.create_streaming_error_response(str(e))
        yield f"data: {data}\n\n"
    # Send the final done message after all response.n are finished
    yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
async def chat_completion_full_generator(
    self,
    request: ChatCompletionRequest,
    result_generator: AsyncIterator[RequestOutput],
    request_id: str,
    conversation: List[ConversationMessage],
    tokenizer: AnyTokenizer,
    request_metadata: RequestResponseMetadata,
) -> Union[ErrorResponse, ChatCompletionResponse]:
    """Build the complete (non-streaming) chat completion response.

    Drains ``result_generator`` and keeps only the last result, then turns
    each of its outputs into one response choice.  When a reasoning parser
    is configured, the reasoning text is split off first and tool-call
    extraction runs on the remaining text.

    Args:
        request: The chat completion request being served.
        result_generator: Async iterator of engine outputs for the request.
        request_id: Identifier placed in the response.
        conversation: Parsed conversation; only its last message is read
            (for echo / continue_final_message).
        tokenizer: Tokenizer handed to the parsers and logprob helper.
        request_metadata: Receives the usage in ``final_usage_info``.

    Returns:
        The assembled ``ChatCompletionResponse``, or an error response if
        the client disconnected or a parser could not be created.
    """
    model_name = self.base_model_paths[0].name
    created_time = int(time.time())
    final_res: Optional[RequestOutput] = None

    try:
        async for res in result_generator:
            final_res = res
    except asyncio.CancelledError:
        return self.create_error_response("Client disconnected")

    assert final_res is not None

    choices: List[ChatCompletionResponseChoice] = []
    # Running total of reasoning tokens across all outputs; only reported
    # when a reasoning parser is configured.
    reasoning_token_total = 0

    role = self.get_chat_request_role(request)
    for output in final_res.outputs:
        token_ids = output.token_ids
        out_logprobs = output.logprobs

        if request.logprobs and request.top_logprobs is not None:
            assert out_logprobs is not None, "Did not output logprobs"
            logprobs = self._create_chat_logprobs(
                token_ids=token_ids,
                top_logprobs=out_logprobs,
                num_output_top_logprobs=request.top_logprobs,
                tokenizer=tokenizer,
            )
        else:
            logprobs = None

        # The finish_reason reported below is "tool_calls" if the tool
        # choice is auto and the model produced a tool call. The same is
        # not true for named function calls
        auto_tools_called = False

        # Extract reasoning content if parser is configured.
        # output_text is what remains after stripping <think>...</think>.
        reasoning_text: Optional[str] = None
        output_text: str = output.text
        if self.reasoning_parser_cls:
            # Mirror the tool-parser handling: a parser that cannot be
            # constructed yields an error response instead of an
            # unhandled exception.
            try:
                r_parser = self.reasoning_parser_cls(
                    tokenizer,
                    chat_template_kwargs=request.chat_template_kwargs)
            except RuntimeError as e:
                logger.error("Error in reasoning parser creation: %s", e)
                return self.create_error_response(str(e))
            reasoning_text, extracted = r_parser.extract_reasoning(
                output.text, request)
            output_text = extracted or ""
            # Count with the parser already built for this output rather
            # than constructing a second one afterwards.
            reasoning_token_total += r_parser.count_reasoning_tokens(
                list(token_ids))

        # if auto tools are not enabled, and a named tool choice using
        # outlines is not being used
        if (not self.enable_auto_tools
                or not self.tool_parser) and not isinstance(
                    request.tool_choice,
                    ChatCompletionNamedToolChoiceParam):
            message = ChatMessage(role=role,
                                  reasoning_content=reasoning_text,
                                  content=output_text)

        # if the request uses tools and specified a tool choice
        elif request.tool_choice and isinstance(
                request.tool_choice, ChatCompletionNamedToolChoiceParam):

            message = ChatMessage(
                role=role,
                reasoning_content=reasoning_text,
                content="",
                tool_calls=[
                    ToolCall(function=FunctionCall(
                        name=request.tool_choice.function.name,
                        arguments=output_text))
                ])

        # if the request doesn't use tool choice
        # OR specifies to not use a tool
        elif not request.tool_choice or request.tool_choice == "none":

            message = ChatMessage(role=role,
                                  reasoning_content=reasoning_text,
                                  content=output_text)

        # handle when there are tools and tool choice is auto
        elif request.tools and (
                request.tool_choice == "auto"
                or request.tool_choice is None) and self.enable_auto_tools \
                and self.tool_parser:

            try:
                tool_parser = self.tool_parser(tokenizer)
            except RuntimeError as e:
                logger.error("Error in tool parser creation: %s", e)
                return self.create_error_response(str(e))

            # Parse tool calls from the post-reasoning content.
            tool_call_info = tool_parser.extract_tool_calls(
                output_text, request=request)
            auto_tools_called = tool_call_info.tools_called
            if tool_call_info.tools_called:
                message = ChatMessage(
                    role=role,
                    reasoning_content=reasoning_text,
                    content=tool_call_info.content,
                    tool_calls=tool_call_info.tool_calls)
            else:
                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_text,
                                      content=output_text)

        # undetermined case that is still important to handle
        else:
            logger.error(
                "Error in chat_completion_full_generator - cannot determine"
                " if tools should be extracted. Returning a standard chat "
                "completion.")
            message = ChatMessage(role=role,
                                  reasoning_content=reasoning_text,
                                  content=output_text)

        choice_data = ChatCompletionResponseChoice(
            index=output.index,
            message=message,
            logprobs=logprobs,
            finish_reason="tool_calls" if auto_tools_called else
            output.finish_reason if output.finish_reason else "stop",
            stop_reason=output.stop_reason)
        choices.append(choice_data)

    if request.echo or request.continue_final_message:
        last_msg_content = ""
        if conversation and "content" in conversation[-1] and conversation[
                -1].get("role") == role:
            last_msg_content = conversation[-1]["content"] or ""

        # Prefix every choice with the echoed last message.
        for choice in choices:
            full_message = last_msg_content + (choice.message.content
                                               or "")
            choice.message.content = full_message

    assert final_res.prompt_token_ids is not None
    num_prompt_tokens = len(final_res.prompt_token_ids)
    if final_res.encoder_prompt_token_ids is not None:
        num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
    num_generated_tokens = sum(
        len(output.token_ids) for output in final_res.outputs)
    total_reasoning_tokens: Optional[int] = (
        reasoning_token_total if self.reasoning_parser_cls else None)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
        reasoning_tokens=total_reasoning_tokens,
    )

    request_metadata.final_usage_info = usage

    response = ChatCompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
        prompt_logprobs=final_res.prompt_logprobs,
    )

    return response
|
||||||
|
|
||||||
|
def _get_top_logprobs(
|
||||||
|
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
|
||||||
|
tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
|
||||||
|
return [
|
||||||
|
ChatCompletionLogProb(token=(token := self._get_decoded_token(
|
||||||
|
p[1],
|
||||||
|
p[0],
|
||||||
|
tokenizer,
|
||||||
|
return_as_token_id=self.return_tokens_as_token_ids)),
|
||||||
|
logprob=max(p[1].logprob, -9999.0),
|
||||||
|
bytes=list(
|
||||||
|
token.encode("utf-8", errors="replace")))
|
||||||
|
for i, p in enumerate(logprobs.items())
|
||||||
|
if top_logprobs and i < top_logprobs
|
||||||
|
]
|
||||||
|
|
||||||
|
def _create_chat_logprobs(
    self,
    token_ids: GenericSequence[int],
    top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
    tokenizer: AnyTokenizer,
    num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
    """Create OpenAI-style logprobs.

    Produces one content entry per generated token.  Positions whose
    ``top_logprobs`` entry is ``None`` get only the token text and its
    bytes; all others also carry the token's logprob and its top
    alternatives.
    """
    entries: List[ChatCompletionLogProbsContent] = []

    for position, token_id in enumerate(token_ids):
        candidates = top_logprobs[position]

        if candidates is not None:
            chosen = candidates[token_id]
            decoded_text = chosen.decoded_token

            token_repr = self._get_decoded_token(
                chosen,
                token_id,
                tokenizer,
                self.return_tokens_as_token_ids,
            )
            # clamp very negative values at -9999.0
            clamped_logprob = max(chosen.logprob, -9999.0)
            if decoded_text is None:
                token_bytes = None
            else:
                token_bytes = list(
                    decoded_text.encode("utf-8", errors="replace"))
            alternatives = self._get_top_logprobs(
                candidates,
                num_output_top_logprobs,
                tokenizer,
            )
            entries.append(
                ChatCompletionLogProbsContent(
                    token=token_repr,
                    logprob=clamped_logprob,
                    bytes=token_bytes,
                    top_logprobs=alternatives,
                ))
            continue

        # No logprob information for this position: fall back to decoding
        # the token id directly.
        token = tokenizer.decode(token_id)
        if self.return_tokens_as_token_ids:
            token = f"token_id:{token_id}"
        entries.append(
            ChatCompletionLogProbsContent(
                token=token,
                bytes=list(token.encode("utf-8", errors="replace")),
            ))

    return ChatCompletionLogProbs(content=entries)
|
||||||
|
|
||||||
|
def _should_stream_with_auto_tool_parsing(self,
|
||||||
|
request: ChatCompletionRequest):
|
||||||
|
"""
|
||||||
|
Utility function to check if streamed tokens should go through the tool
|
||||||
|
call parser that was configured.
|
||||||
|
|
||||||
|
We only want to do this IF user-provided tools are set, a tool parser
|
||||||
|
is configured, "auto" tool choice is enabled, and the request's tool
|
||||||
|
choice field indicates that "auto" tool choice should be used.
|
||||||
|
"""
|
||||||
|
return (request.tools and self.tool_parser and self.enable_auto_tools
|
||||||
|
and request.tool_choice in ['auto', None])
|
||||||
|
|
||||||
|
def _should_check_for_unstreamed_tool_arg_tokens(
|
||||||
|
self,
|
||||||
|
delta_message: Optional[DeltaMessage],
|
||||||
|
output: CompletionOutput,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check to see if we should check for unstreamed tool arguments tokens.
|
||||||
|
This is only applicable when auto tool parsing is enabled, the delta
|
||||||
|
is a tool call with arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
|
return bool(
|
||||||
|
# if there is a delta message that includes tool calls which
|
||||||
|
# include a function that has arguments
|
||||||
|
output.finish_reason is not None
|
||||||
|
and self.enable_auto_tools and self.tool_parser and delta_message
|
||||||
|
and delta_message.tool_calls and delta_message.tool_calls[0]
|
||||||
|
and delta_message.tool_calls[0].function
|
||||||
|
and delta_message.tool_calls[0].function.arguments is not None
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user