diff --git a/qwen3_6_scripts/api_server.py b/qwen3_6_scripts/api_server.py new file mode 100644 index 0000000..e12cb7f --- /dev/null +++ b/qwen3_6_scripts/api_server.py @@ -0,0 +1,595 @@ +import asyncio +import importlib +import inspect +import multiprocessing +import os +import regex as re +import signal +import socket +import tempfile +from argparse import Namespace +from contextlib import asynccontextmanager +from functools import partial +from http import HTTPStatus +from typing import AsyncIterator, Set + +import uvloop +from fastapi import APIRouter, FastAPI, Request +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.datastructures import State +from starlette.routing import Mount +from typing_extensions import assert_never + +import vllm.envs as envs +from vllm.config import ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + DetokenizeResponse, + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, + LoadLoraAdapterRequest, + TokenizeRequest, + TokenizeResponse, + UnloadLoraAdapterRequest) +# yapf: enable +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.reasoning import ReasoningParserManager +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path +from vllm.version import __version__ as VLLM_VERSION + +TIMEOUT_KEEP_ALIVE = 5 # seconds + +prometheus_multiproc_dir: tempfile.TemporaryDirectory + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger('vllm.entrypoints.openai.api_server') + +_running_tasks: Set[asyncio.Task] = set() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + try: + if app.state.log_stats: + engine_client: EngineClient = app.state.engine_client + + async def _force_log(): + while True: + await asyncio.sleep(10.) + await engine_client.do_log_stats() + + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) + else: + task = None + try: + yield + finally: + if task is not None: + task.cancel() + finally: + # Ensure app state including engine ref is gc'd + del app.state + + +@asynccontextmanager +async def build_async_engine_client( + args: Namespace) -> AsyncIterator[EngineClient]: + + # Context manager to handle engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + engine_args = AsyncEngineArgs.from_cli_args(args) + + async with build_async_engine_client_from_engine_args( + engine_args, args.disable_frontend_multiprocessing) as engine: + yield engine + + +@asynccontextmanager +async def build_async_engine_client_from_engine_args( + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> AsyncIterator[EngineClient]: + """ + Create EngineClient, either: + - in-process using the AsyncLLMEngine Directly + - multiprocess using AsyncLLMEngine RPC + + Returns the Client or None if the creation failed. + """ + + # Fall back + # TODO: fill out feature matrix. + if (MQLLMEngineClient.is_unsupported_config(engine_args) + or disable_frontend_multiprocessing): + engine_config = engine_args.create_engine_config() + uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), + "uses_ray", False) + + build_engine = partial(AsyncLLMEngine.from_engine_args, + engine_args=engine_args, + engine_config=engine_config, + usage_context=UsageContext.OPENAI_API_SERVER) + if uses_ray: + # Must run in main thread with ray for its signal handlers to work + engine_client = build_engine() + else: + engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_engine) + + yield engine_client + return + + # Otherwise, use the multiprocessing AsyncLLMEngine. + else: + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + global prometheus_multiproc_dir + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ[ + "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + else: + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + + # Select random path for IPC. + ipc_path = get_open_zmq_ipc_path() + logger.info("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) + + # Start RPCServer in separate process (holds the LLMEngine). + # the current process might have CUDA context, + # so we need to spawn a new process + context = multiprocessing.get_context("spawn") + + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path)) + engine_process.start() + logger.info("Started engine process with PID %d", engine_process.pid) + + # Build RPCClient, which conforms to EngineClient Protocol. + # NOTE: Actually, this is not true yet. We still need to support + # embedding models via RPC (see TODO above) + engine_config = engine_args.create_engine_config() + mp_engine_client = MQLLMEngineClient(ipc_path, engine_config) + + try: + while True: + try: + await mp_engine_client.setup() + break + except TimeoutError: + if not engine_process.is_alive(): + raise RuntimeError( + "Engine process failed to start") from None + + yield mp_engine_client # type: ignore[misc] + finally: + # Ensure rpc server process was terminated + engine_process.terminate() + + # Close all open connections to the backend + mp_engine_client.close() + + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() + + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import multiprocess + multiprocess.mark_process_dead(engine_process.pid) + + +router = APIRouter() + + +def mount_metrics(app: FastAPI): + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import (CollectorRegistry, make_asgi_app, + multiprocess) + + prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) + if prometheus_multiproc_dir_path is not None: + logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", + prometheus_multiproc_dir_path) + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + else: + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app()) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") + app.routes.append(metrics_route) + + +def chat(request: Request) -> OpenAIServingChat: + return request.app.state.openai_serving_chat + + +def completion(request: Request) -> OpenAIServingCompletion: + return request.app.state.openai_serving_completion + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def embedding(request: Request) -> OpenAIServingEmbedding: + return request.app.state.openai_serving_embedding + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health") +async def health(raw_request: Request) -> Response: + """Health check.""" + await engine_client(raw_request).check_health() + return Response(status_code=200) + + +@router.post("/tokenize") +async def tokenize(request: TokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_tokenize(request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, TokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/detokenize") +async def detokenize(request: DetokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_detokenize(request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, DetokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.get("/v1/models") +async def show_available_models(raw_request: Request): + models = await completion(raw_request).show_available_models() + return JSONResponse(content=models.model_dump()) + + +@router.get("/version") +async def show_version(): + ver = {"version": VLLM_VERSION} + return JSONResponse(content=ver) + + +@router.post("/v1/chat/completions") +async def create_chat_completion(request: ChatCompletionRequest, + raw_request: Request): + + generator = await chat(raw_request).create_chat_completion( + request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, ChatCompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/completions") +async def create_completion(request: CompletionRequest, raw_request: Request): + generator = await completion(raw_request).create_completion( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, CompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/embeddings") +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + generator = await embedding(raw_request).create_embedding( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, EmbeddingResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +if envs.VLLM_TORCH_PROFILER_DIR: + logger.warning( + "Torch Profiler is enabled in the API server. This should ONLY be " + "used for local development!") + + @router.post("/start_profile") + async def start_profile(raw_request: Request): + logger.info("Starting profiler...") + await engine_client(raw_request).start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + @router.post("/stop_profile") + async def stop_profile(raw_request: Request): + logger.info("Stopping profiler...") + await engine_client(raw_request).stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + +if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "Lora dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!") + + @router.post("/v1/load_lora_adapter") + async def load_lora_adapter(request: LoadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await completion(raw_request).load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + @router.post("/v1/unload_lora_adapter") + async def unload_lora_adapter(request: UnloadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await completion(raw_request).unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + +def build_app(args: Namespace) -> FastAPI: + if args.disable_fastapi_docs: + app = FastAPI(openapi_url=None, + docs_url=None, + redoc_url=None, + lifespan=lifespan) + else: + app = FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path + + mount_metrics(app) + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(_, exc): + chat = app.state.openai_serving_chat + err = chat.create_error_response(message=str(exc)) + return JSONResponse(err.model_dump(), + status_code=HTTPStatus.BAD_REQUEST) + + if token := envs.VLLM_API_KEY or args.api_key: + + @app.middleware("http") + async def authentication(request: Request, call_next): + root_path = "" if args.root_path is None else args.root_path + if request.method == "OPTIONS": + return await call_next(request) + if not request.url.path.startswith(f"{root_path}/v1"): + return await call_next(request) + if request.headers.get("Authorization") != "Bearer " + token: + return JSONResponse(content={"error": "Unauthorized"}, + status_code=401) + return await call_next(request) + + for middleware in args.middleware: + module_path, object_name = middleware.rsplit(".", 1) + imported = getattr(importlib.import_module(module_path), object_name) + if inspect.isclass(imported): + app.add_middleware(imported) + elif inspect.iscoroutinefunction(imported): + app.middleware("http")(imported) + else: + raise ValueError(f"Invalid middleware {middleware}. " + f"Must be a function or a class.") + + return app + + +def init_app_state( + engine_client: EngineClient, + model_config: ModelConfig, + state: State, + args: Namespace, +) -> None: + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats + + state.openai_serving_chat = OpenAIServingChat( + engine_client, + model_config, + base_model_paths, + args.response_role, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + chat_template=args.chat_template, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + reasoning_parser=getattr(args, 'reasoning_parser', None)) + state.openai_serving_completion = OpenAIServingCompletion( + engine_client, + model_config, + base_model_paths, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + ) + state.openai_serving_embedding = OpenAIServingEmbedding( + engine_client, + model_config, + base_model_paths, + request_logger=request_logger, + ) + state.openai_serving_tokenization = OpenAIServingTokenization( + engine_client, + model_config, + base_model_paths, + lora_modules=args.lora_modules, + request_logger=request_logger, + chat_template=args.chat_template, + ) + + +async def run_server(args, **uvicorn_kwargs) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + valide_tool_parses = ToolParserManager.tool_parsers.keys() + if args.enable_auto_tool_choice \ + and args.tool_call_parser not in valide_tool_parses: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valide_tool_parses)} }})") + + reasoning_parser = getattr(args, 'reasoning_parser', None) + if reasoning_parser: + valid_reasoning = ReasoningParserManager.list_registered() + if reasoning_parser not in valid_reasoning: + raise KeyError( + f"invalid reasoning parser: {reasoning_parser} " + f"(chose from {{ {','.join(valid_reasoning)} }})") + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. + # see https://github.com/vllm-project/vllm/issues/8204 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("", args.port)) + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + async with build_async_engine_client(args) as engine_client: + app = build_app(args) + + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + fd=sock.fileno(), + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + await shutdown_task + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/scripts.py for CLI entrypoints. + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser = make_arg_parser(parser) + args = parser.parse_args() + validate_parsed_serve_args(args) + + uvloop.run(run_server(args)) diff --git a/qwen3_6_scripts/chat_utils.py b/qwen3_6_scripts/chat_utils.py new file mode 100644 index 0000000..76d92fb --- /dev/null +++ b/qwen3_6_scripts/chat_utils.py @@ -0,0 +1,598 @@ +import asyncio +import codecs +import json +from abc import ABC, abstractmethod +from collections import defaultdict +from functools import lru_cache, partial +from pathlib import Path +from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal, + Mapping, Optional, Tuple, TypeVar, Union, cast) + +# yapf conflicts with isort for this block +# yapf: disable +from openai.types.chat import (ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) +from openai.types.chat import (ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam) +# yapf: enable +# pydantic needs the TypedDict from typing_extensions +from pydantic import ConfigDict +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from typing_extensions import Required, TypeAlias, TypedDict + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.utils import (async_get_and_parse_audio, + async_get_and_parse_image, + get_and_parse_audio, get_and_parse_image) +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + + +class AudioURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the audio or a data URL with base64 encoded audio data. + """ + + +class ChatCompletionContentPartAudioParam(TypedDict, total=False): + audio_url: Required[AudioURL] + + type: Required[Literal["audio_url"]] + """The type of the content part.""" + + +class CustomChatCompletionContentPartParam(TypedDict, total=False): + __pydantic_config__ = ConfigDict(extra="allow") # type: ignore + + type: Required[str] + """The type of the content part.""" + + +ChatCompletionContentPartParam: TypeAlias = Union[ + OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + ChatCompletionContentPartRefusalParam, + CustomChatCompletionContentPartParam] + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + reasoning_content: Optional[str] + """Reasoning / thinking content for assistant messages (vLLM extension). + When present in a previous assistant turn, it is rendered as + ... before the main content so the model sees its own + chain-of-thought in subsequent turns.""" + + +ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, + CustomChatCompletionMessageParam] + + +# TODO: Make fields ReadOnly once mypy supports it +class ConversationMessage(TypedDict, total=False): + role: Required[str] + """The role of the message's author.""" + + content: Optional[str] + """The contents of the message""" + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + name: Optional[str] + """The name of the function to call""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + +ModalityStr = Literal["image", "audio", "video"] +_T = TypeVar("_T") + + +class BaseMultiModalItemTracker(ABC, Generic[_T]): + """ + Tracks multi-modal items in a given request and ensures that the number + of multi-modal items in a given request does not exceed the configured + maximum per prompt. + """ + + def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): + super().__init__() + + self._model_config = model_config + self._tokenizer = tokenizer + self._allowed_items = (model_config.multimodal_config.limit_per_prompt + if model_config.multimodal_config else {}) + self._consumed_items = {k: 0 for k in self._allowed_items} + + self._items: List[_T] = [] + + @staticmethod + @lru_cache(maxsize=None) + def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: + return tokenizer.decode(token_index) + + def _placeholder_str(self, modality: ModalityStr, + current_count: int) -> Optional[str]: + # TODO: Let user specify how to insert image tokens into prompt + # (similar to chat template) + hf_config = self._model_config.hf_config + model_type = hf_config.model_type + + if modality == "image": + if model_type == "phi3_v": + # Workaround since this token is not defined in the tokenizer + return f"<|image_{current_count}|>" + if model_type == "minicpmv": + return "(./)" + if model_type in ("blip-2", "chatglm", "fuyu", "paligemma", + "pixtral"): + # These models do not use image tokens in the prompt + return None + if model_type == "qwen": + return f"Picture {current_count}: " + if model_type.startswith("llava"): + return self._cached_token_str(self._tokenizer, + hf_config.image_token_index) + if model_type in ("chameleon", "internvl_chat", "NVLM_D"): + return "" + if model_type == "mllama": + return "<|image|>" + if model_type in ("qwen2_vl","qwen2_5_vl"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if model_type == "molmo": + return "" + + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "audio": + if model_type == "ultravox": + return "<|reserved_special_token_0|>" + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "video": + if model_type in ("qwen2_vl","qwen2_5_vl"): + return "<|vision_start|><|video_pad|><|vision_end|>" + raise TypeError(f"Unknown model type: {model_type}") + else: + raise TypeError(f"Unknown modality: {modality}") + + @staticmethod + def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict: + mm_lists: Mapping[str, List[object]] = defaultdict(list) + + # Merge all the multi-modal items + for single_mm_data in items: + for mm_key, mm_item in single_mm_data.items(): + if isinstance(mm_item, list): + mm_lists[mm_key].extend(mm_item) + else: + mm_lists[mm_key].append(mm_item) + + # Unpack any single item lists for models that don't expect multiple. + return { + mm_key: mm_list[0] if len(mm_list) == 1 else mm_list + for mm_key, mm_list in mm_lists.items() + } + + def add(self, modality: ModalityStr, item: _T) -> Optional[str]: + """ + Add a multi-modal item to the current prompt and returns the + placeholder string to use, if any. + """ + allowed_count = self._allowed_items.get(modality, 1) + current_count = self._consumed_items.get(modality, 0) + 1 + if current_count > allowed_count: + raise ValueError( + f"At most {allowed_count} {modality}(s) may be provided in " + "one request.") + + self._consumed_items[modality] = current_count + self._items.append(item) + + return self._placeholder_str(modality, current_count) + + @abstractmethod + def create_parser(self) -> "BaseMultiModalContentParser": + raise NotImplementedError + + +class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]): + + def all_mm_data(self) -> Optional[MultiModalDataDict]: + return self._combine(self._items) if self._items else None + + def create_parser(self) -> "BaseMultiModalContentParser": + return MultiModalContentParser(self) + + +class AsyncMultiModalItemTracker( + BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]): + + async def all_mm_data(self) -> Optional[MultiModalDataDict]: + if self._items: + items = await asyncio.gather(*self._items) + return self._combine(items) + + return None + + def create_parser(self) -> "BaseMultiModalContentParser": + return AsyncMultiModalContentParser(self) + + +class BaseMultiModalContentParser(ABC): + + def __init__(self) -> None: + super().__init__() + + # multimodal placeholder_string : count + self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0) + + def _add_placeholder(self, placeholder: Optional[str]): + if placeholder: + self._placeholder_counts[placeholder] += 1 + + def mm_placeholder_counts(self) -> Dict[str, int]: + return dict(self._placeholder_counts) + + @abstractmethod + def parse_image(self, image_url: str) -> None: + raise NotImplementedError + + @abstractmethod + def parse_audio(self, audio_url: str) -> None: + raise NotImplementedError + + +class MultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: MultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + + def parse_image(self, image_url: str) -> None: + image = get_and_parse_image(image_url) + + placeholder = self._tracker.add("image", image) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio = get_and_parse_audio(audio_url) + + placeholder = self._tracker.add("audio", audio) + self._add_placeholder(placeholder) + + +class AsyncMultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + + def parse_image(self, image_url: str) -> None: + image_coro = async_get_and_parse_image(image_url) + + placeholder = self._tracker.add("image", image_coro) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio_coro = async_get_and_parse_audio(audio_url) + + placeholder = self._tracker.add("audio", audio_coro) + self._add_placeholder(placeholder) + + +def validate_chat_template(chat_template: Optional[Union[Path, str]]): + """Raises if the provided chat template appears invalid.""" + if chat_template is None: + return + + elif isinstance(chat_template, Path) and not chat_template.exists(): + raise FileNotFoundError( + "the supplied chat template path doesn't exist") + + elif isinstance(chat_template, str): + JINJA_CHARS = "{}\n" + if not any(c in chat_template + for c in JINJA_CHARS) and not Path(chat_template).exists(): + raise ValueError( + f"The supplied chat template string ({chat_template}) " + f"appears path-like, but doesn't exist!") + + else: + raise TypeError( + f"{type(chat_template)} is not a valid chat template type") + + +def load_chat_template( + chat_template: Optional[Union[Path, str]]) -> Optional[str]: + if chat_template is None: + return None + try: + with open(chat_template, "r") as f: + resolved_chat_template = f.read() + except OSError as e: + if isinstance(chat_template, Path): + raise + + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + resolved_chat_template = codecs.decode(chat_template, "unicode_escape") + + logger.info("Using supplied chat template:\n%s", resolved_chat_template) + return resolved_chat_template + + +# TODO: Let user specify how to insert multimodal tokens into prompt +# (similar to chat template) +def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], + text_prompt: str) -> str: + """Combine multimodal prompts for a multimodal language model.""" + + # Look through the text prompt to check for missing placeholders + missing_placeholders: List[str] = [] + for placeholder in placeholder_counts: + + # For any existing placeholder in the text prompt, we leave it as is + placeholder_counts[placeholder] -= text_prompt.count(placeholder) + + if placeholder_counts[placeholder] < 0: + raise ValueError( + f"Found more '{placeholder}' placeholders in input prompt than " + "actual multimodal data items.") + + missing_placeholders.extend([placeholder] * + placeholder_counts[placeholder]) + + # NOTE: For now we always add missing placeholders at the front of + # the prompt. This may change to be customizable in the future. + return "\n".join(missing_placeholders + [text_prompt]) + + +# No need to validate using Pydantic again +_TextParser = partial(cast, ChatCompletionContentPartTextParam) +_ImageParser = partial(cast, ChatCompletionContentPartImageParam) +_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) +_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) +MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} + + +def _parse_chat_message_content_parts( + role: str, + parts: Iterable[ChatCompletionContentPartParam], + mm_tracker: BaseMultiModalItemTracker, +) -> List[ConversationMessage]: + texts: List[str] = [] + + mm_parser = mm_tracker.create_parser() + keep_multimodal_content = \ + mm_tracker._model_config.hf_config.model_type in \ + MODEL_KEEP_MULTI_MODAL_CONTENT + + has_image = False + for part in parts: + part_type = part["type"] + if part_type == "text": + text = _TextParser(part)["text"] + texts.append(text) + elif part_type == "image_url": + image_url = _ImageParser(part)["image_url"] + + if image_url.get("detail", "auto") != "auto": + logger.warning( + "'image_url.detail' is currently not supported and " + "will be ignored.") + + mm_parser.parse_image(image_url["url"]) + has_image = True + elif part_type == "audio_url": + audio_url = _AudioParser(part)["audio_url"] + + mm_parser.parse_audio(audio_url["url"]) + elif part_type == "refusal": + text = _RefusalParser(part)["refusal"] + texts.append(text) + else: + raise NotImplementedError(f"Unknown part type: {part_type}") + + text_prompt = "\n".join(texts) + if keep_multimodal_content: + text_prompt = "\n".join(texts) + role_content = [{'type': 'text', 'text': text_prompt}] + + if has_image: + role_content = [{'type': 'image'}] + role_content + return [ConversationMessage(role=role, + content=role_content)] # type: ignore + else: + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_counts, text_prompt) + return [ConversationMessage(role=role, content=text_prompt)] + + +# No need to validate using Pydantic again +_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam) +_ToolParser = partial(cast, ChatCompletionToolMessageParam) + + +def _parse_chat_message_content( + message: ChatCompletionMessageParam, + mm_tracker: BaseMultiModalItemTracker, +) -> List[ConversationMessage]: + role = message["role"] + content = message.get("content") + + if content is None: + content = [] + elif isinstance(content, str): + content = [ + ChatCompletionContentPartTextParam(type="text", text=content) + ] + + result = _parse_chat_message_content_parts( + role, + content, # type: ignore + mm_tracker, + ) + + for result_msg in result: + if role == 'assistant': + parsed_msg = _AssistantParser(message) + + if "tool_calls" in parsed_msg: + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + + # Prepend reasoning_content as ... so the model + # sees its own chain-of-thought in multi-turn conversations. + reasoning = message.get("reasoning_content") # type: ignore[arg-type] + if reasoning and isinstance(reasoning, str): + existing = result_msg.get("content") or "" + result_msg["content"] = ( + f"{reasoning}\n\n{existing}" + if existing else f"{reasoning}" + ) + + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result + + +def _postprocess_messages(messages: List[ConversationMessage]) -> None: + # per the Transformers docs & maintainers, tool call arguments in + # assistant-role messages with tool_calls need to be dicts not JSON str - + # this is how tool-use chat templates will expect them moving forwards + # so, for messages that have tool_calls, parse the string (which we get + # from openAI format) to dict + for message in messages: + if (message["role"] == "assistant" and "tool_calls" in message + and isinstance(message["tool_calls"], list)): + + for item in message["tool_calls"]: + item["function"]["arguments"] = json.loads( + item["function"]["arguments"]) + + +def parse_chat_messages( + messages: List[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, +) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]: + conversation: List[ConversationMessage] = [] + mm_tracker = MultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content(msg, mm_tracker) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + return conversation, mm_tracker.all_mm_data() + + +def parse_chat_messages_futures( + messages: List[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, +) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: + conversation: List[ConversationMessage] = [] + mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content(msg, mm_tracker) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + return conversation, mm_tracker.all_mm_data() + + +def apply_hf_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + conversation: List[ConversationMessage], + chat_template: Optional[str], + *, + tokenize: bool = False, # Different from HF's default + **kwargs: Any, +) -> str: + if chat_template is None and tokenizer.chat_template is None: + raise ValueError( + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one.") + + return tokenizer.apply_chat_template( + conversation=conversation, # type: ignore[arg-type] + chat_template=chat_template, + tokenize=tokenize, + **kwargs, + ) + + +def apply_mistral_chat_template( + tokenizer: MistralTokenizer, + messages: List[ChatCompletionMessageParam], + chat_template: Optional[str] = None, + **kwargs: Any, +) -> List[int]: + if chat_template is not None: + logger.warning( + "'chat_template' cannot be overridden for mistral tokenizer.") + if "add_generation_prompt" in kwargs: + logger.warning( + "'add_generation_prompt' is not supported for mistral tokenizer, " + "so it will be ignored.") + if "continue_final_message" in kwargs: + logger.warning( + "'continue_final_message' is not supported for mistral tokenizer, " + "so it will be ignored.") + + return tokenizer.apply_chat_template( + messages=messages, + **kwargs, + ) diff --git a/qwen3_6_scripts/cli_args.py b/qwen3_6_scripts/cli_args.py new file mode 100644 index 0000000..ad0698d --- /dev/null +++ b/qwen3_6_scripts/cli_args.py @@ -0,0 +1,261 @@ +""" +This file contains the command line arguments for the vLLM's +OpenAI-compatible server. It is kept in a separate file for documentation +purposes. +""" + +import argparse +import json +import ssl +from typing import List, Optional, Sequence, Union + +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.entrypoints.chat_utils import validate_chat_template +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + PromptAdapterPath) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.utils import FlexibleArgumentParser + + +class LoRAParserAction(argparse.Action): + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + lora_list: List[LoRAModulePath] = [] + for item in values: + if item in [None, '']: # Skip if item is None or empty string + continue + if '=' in item and ',' not in item: # Old format: name=path + name, path = item.split('=') + lora_list.append(LoRAModulePath(name, path)) + else: # Assume JSON format + try: + lora_dict = json.loads(item) + lora = LoRAModulePath(**lora_dict) + lora_list.append(lora) + except json.JSONDecodeError: + parser.error( + f"Invalid JSON format for --lora-modules: {item}") + except TypeError as e: + parser.error( + f"Invalid fields for --lora-modules: {item} - {str(e)}" + ) + setattr(namespace, self.dest, lora_list) + + +class PromptAdapterParserAction(argparse.Action): + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + adapter_list: List[PromptAdapterPath] = [] + for item in values: + name, path = item.split('=') + adapter_list.append(PromptAdapterPath(name, path)) + setattr(namespace, self.dest, adapter_list) + + +def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument("--host", + type=nullable_str, + default=None, + help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--uvicorn-log-level", + type=str, + default="info", + choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], + help="log level for uvicorn") + parser.add_argument("--allow-credentials", + action="store_true", + help="allow credentials") + parser.add_argument("--allowed-origins", + type=json.loads, + default=["*"], + help="allowed origins") + parser.add_argument("--allowed-methods", + type=json.loads, + default=["*"], + help="allowed methods") + parser.add_argument("--allowed-headers", + type=json.loads, + default=["*"], + help="allowed headers") + parser.add_argument("--api-key", + type=nullable_str, + default=None, + help="If provided, the server will require this key " + "to be presented in the header.") + parser.add_argument( + "--lora-modules", + type=nullable_str, + default=None, + nargs='+', + action=LoRAParserAction, + help="LoRA module configurations in either 'name=path' format" + "or JSON format. " + "Example (old format): 'name=path' " + "Example (new format): " + "'{\"name\": \"name\", \"local_path\": \"path\", " + "\"base_model_name\": \"id\"}'") + parser.add_argument( + "--prompt-adapters", + type=nullable_str, + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. " + "Multiple adapters can be specified.") + parser.add_argument("--chat-template", + type=nullable_str, + default=None, + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model") + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=true`.") + parser.add_argument("--ssl-keyfile", + type=nullable_str, + default=None, + help="The file path to the SSL key file") + parser.add_argument("--ssl-certfile", + type=nullable_str, + default=None, + help="The file path to the SSL cert file") + parser.add_argument("--ssl-ca-certs", + type=nullable_str, + default=None, + help="The CA certificates file") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) + parser.add_argument( + "--root-path", + type=nullable_str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument( + "--middleware", + type=nullable_str, + action="append", + default=[], + help="Additional ASGI middleware to apply to the app. " + "We accept multiple --middleware arguments. " + "The value should be an import path. " + "If a function is provided, vLLM will add it to the server " + "using @app.middleware('http'). " + "If a class is provided, vLLM will add it to the server " + "using app.add_middleware(). ") + parser.add_argument( + "--return-tokens-as-token-ids", + action="store_true", + help="When --max-logprobs is specified, represents single tokens as " + "strings of the form 'token_id:{token_id}' so that tokens that " + "are not JSON-encodable can be identified.") + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "process as the model serving engine.") + + parser.add_argument( + "--enable-auto-tool-choice", + action="store_true", + default=False, + help= + "Enable auto tool choice for supported models. Use --tool-call-parser" + "to specify which parser to use") + + valid_tool_parsers = ToolParserManager.tool_parsers.keys() + parser.add_argument( + "--tool-call-parser", + type=str, + metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in " + "--tool-parser-plugin", + default=None, + help= + "Select the tool call parser depending on the model that you're using." + " This is used to parse the model-generated tool call into OpenAI API " + "format. Required for --enable-auto-tool-choice.") + + parser.add_argument( + "--tool-parser-plugin", + type=str, + default="", + help= + "Special the tool parser plugin write to parse the model-generated tool" + " into OpenAI API format, the name register in this plugin can be used " + "in --tool-call-parser.") + + parser.add_argument( + "--reasoning-parser", + type=str, + default=None, + help= + "Select the reasoning parser to split ... content into " + "reasoning_content vs content in the response. " + "Supported: qwen3") + + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') + + parser.add_argument( + "--disable-fastapi-docs", + action='store_true', + default=False, + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint" + ) + + return parser + + +def validate_parsed_serve_args(args: argparse.Namespace): + """Quick checks for model serve args that raise prior to loading.""" + if hasattr(args, "subparser") and args.subparser != "serve": + return + + # Ensure that the chat template is valid; raises if it likely isn't + validate_chat_template(args.chat_template) + + # Enable auto tool needs a tool call parser to be valid + if args.enable_auto_tool_choice and not args.tool_call_parser: + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") + + +def create_parser_for_docs() -> FlexibleArgumentParser: + parser_for_docs = FlexibleArgumentParser( + prog="-m vllm.entrypoints.openai.api_server") + return make_arg_parser(parser_for_docs) diff --git a/qwen3_6_scripts/patch_ops.sh b/qwen3_6_scripts/patch_ops.sh index 7d8c188..213f031 100755 --- a/qwen3_6_scripts/patch_ops.sh +++ b/qwen3_6_scripts/patch_ops.sh @@ -58,3 +58,14 @@ python3 ./patch_xformers_sdpa_seq.py # Use at server start: --tool-call-parser qwen3_coder --enable-auto-tool-choice cp ./qwen3coder_tool_parser.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/tool_parsers/ python3 ./patch_vllm_tool_parser.py + +# --- reasoning parser: Qwen3 ... split ------------------------ +# Adds --reasoning-parser qwen3 support. +# Routes thinking tokens to reasoning_content, rest to content in the delta. +# Works together with --tool-call-parser qwen3_coder (think → tool call flow). +cp -r ./reasoning /usr/local/corex/lib/python3/dist-packages/vllm/ +cp ./protocol.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/protocol.py +cp ./cli_args.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/cli_args.py +cp ./serving_chat.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/serving_chat.py +cp ./api_server.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/openai/api_server.py +cp ./chat_utils.py /usr/local/corex/lib/python3/dist-packages/vllm/entrypoints/chat_utils.py diff --git a/qwen3_6_scripts/protocol.py b/qwen3_6_scripts/protocol.py new file mode 100644 index 0000000..17e495b --- /dev/null +++ b/qwen3_6_scripts/protocol.py @@ -0,0 +1,995 @@ +# Adapted from +# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import time +from argparse import Namespace +from typing import Any, Dict, List, Literal, Optional, Union + +import torch +from openai.types.chat import ChatCompletionContentPartParam +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Annotated, Required, TypedDict + +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) +from vllm.sequence import Logprob +from vllm.utils import random_uuid + +# torch is mocked during docs generation, +# so we have to provide the values as literals +_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) +_LONG_INFO: Union["torch.iinfo", Namespace] + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if isinstance(torch, _MockModule): + _LONG_INFO = _MOCK_LONG_INFO + else: + _LONG_INFO = torch.iinfo(torch.long) +except ModuleNotFoundError: + _LONG_INFO = torch.iinfo(torch.long) + +assert _LONG_INFO.min == _MOCK_LONG_INFO.min +assert _LONG_INFO.max == _MOCK_LONG_INFO.max + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + tool_call_id: Optional[str] + + tool_calls: Optional[List[dict]] + + +class OpenAIBaseModel(BaseModel): + # OpenAI API does not allow extra fields + model_config = ConfigDict(extra="forbid") + + +class ErrorResponse(OpenAIBaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class ModelPermission(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: bool = False + + +class ModelCard(OpenAIBaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "vllm" + root: Optional[str] = None + parent: Optional[str] = None + max_model_len: Optional[int] = None + permission: List[ModelPermission] = Field(default_factory=list) + + +class ModelList(OpenAIBaseModel): + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class UsageInfo(OpenAIBaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + reasoning_tokens: Optional[int] = None + + +class RequestResponseMetadata(BaseModel): + request_id: str + final_usage_info: Optional[UsageInfo] = None + + +class JsonSchemaResponseFormat(OpenAIBaseModel): + name: str + description: Optional[str] = None + # schema is the field in openai but that causes conflicts with pydantic so + # instead use json_schema with an alias + json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema') + strict: Optional[bool] = None + + +class ResponseFormat(OpenAIBaseModel): + # type must be "json_schema", "json_object" or "text" + type: Literal["text", "json_object", "json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None + + +class StreamOptions(OpenAIBaseModel): + include_usage: Optional[bool] = True + continuous_usage_stats: Optional[bool] = True + + +class FunctionDefinition(OpenAIBaseModel): + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + + +class ChatCompletionToolsParam(OpenAIBaseModel): + type: Literal["function"] = "function" + function: FunctionDefinition + + +class ChatCompletionNamedFunction(OpenAIBaseModel): + name: str + + +class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): + function: ChatCompletionNamedFunction + type: Literal["function"] = "function" + + +class ChatCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create + messages: List[ChatCompletionMessageParam] + model: str + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = 0 + max_tokens: Optional[int] = None + n: Optional[int] = 1 + presence_penalty: Optional[float] = 0.0 + response_format: Optional[ResponseFormat] = None + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + tools: Optional[List[ChatCompletionToolsParam]] = None + tool_choice: Optional[Union[Literal["none"], Literal["auto"], + ChatCompletionNamedToolChoiceParam]] = "none" + + # NOTE this will be ignored by VLLM -- the model determines the behavior + parallel_tool_calls: Optional[bool] = False + user: Optional[str] = None + + # doc: begin-chat-completion-sampling-params + best_of: Optional[int] = None + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + prompt_logprobs: Optional[int] = None + # doc: end-chat-completion-sampling-params + + # doc: begin-chat-completion-extra-params + echo: bool = Field( + default=False, + description=( + "If true, the new message will be prepended with the last message " + "if they belong to the same role."), + ) + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + documents: Optional[List[Dict[str, str]]] = Field( + default=None, + description= + ("A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. We recommend that each document should be a dict containing " + "\"title\" and \"text\" keys."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + # doc: end-chat-completion-extra-params + + def to_beam_search_params(self, + default_max_tokens: int) -> BeamSearchParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + + n = self.n if self.n is not None else 1 + temperature = self.temperature if self.temperature is not None else 0.0 + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + length_penalty=self.length_penalty, + ) + + def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.top_logprobs + + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams.from_optional( + json=self._get_guided_json_from_tool() or self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern) + + return SamplingParams.from_optional( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.top_logprobs if self.logprobs else None, + prompt_logprobs=prompt_logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens, + min_tokens=self.min_tokens, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias) + + def _get_guided_json_from_tool( + self) -> Optional[Union[str, dict, BaseModel]]: + # user has chosen to not use any tool + if self.tool_choice == "none" or self.tools is None: + return None + + # user has chosen to use a named tool + if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_name = self.tool_choice.function.name + tools = {tool.function.name: tool.function for tool in self.tools} + if tool_name not in tools: + raise ValueError( + f"Tool '{tool_name}' has not been passed in `tools`.") + tool = tools[tool_name] + return tool.parameters + + return None + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (top_logprobs := data.get("top_logprobs")) is not None: + if top_logprobs < 0: + raise ValueError("`top_logprobs` must be a positive value.") + + if not data.get("logprobs"): + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + if isinstance(data, ValueError): + raise data + + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + # you can only use one kind of guided decoding + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + # you can only either use guided decoding or tools, not both + if guide_count > 1 and data.get("tool_choice", + "none") not in ("none", "auto"): + raise ValueError( + "You can only either use guided decoding or tools, not both.") + return data + + @model_validator(mode="before") + @classmethod + def check_tool_usage(cls, data): + + # if "tool_choice" is not specified but tools are provided, + # default to "auto" tool_choice + if "tool_choice" not in data and data.get("tools"): + data["tool_choice"] = "auto" + + # if "tool_choice" is specified -- validation + if "tool_choice" in data: + + # ensure that if "tool choice" is specified, tools are present + if "tools" not in data or data["tools"] is None: + raise ValueError( + "When using `tool_choice`, `tools` must be set.") + + # make sure that tool choice is either a named tool + # OR that it's set to "auto" + if data["tool_choice"] != "auto" and not isinstance( + data["tool_choice"], dict): + raise ValueError( + "`tool_choice` must either be a named tool or \"auto\". " + "`tool_choice=\"none\" is not supported.") + + # ensure that if "tool_choice" is specified as an object, + # it matches a valid tool + if isinstance(data["tool_choice"], dict): + valid_tool = False + specified_function = data["tool_choice"]["function"] + if not specified_function: + raise ValueError( + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\"," + " \"function\": {\"name\": \"my_function\"}}`") + specified_function_name = specified_function["name"] + if not specified_function_name: + raise ValueError( + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\", " + "\"function\": {\"name\": \"my_function\"}}`") + for tool in data["tools"]: + if tool["function"]["name"] == specified_function_name: + valid_tool = True + break + if not valid_tool: + raise ValueError( + "The tool specified in `tool_choice` does not match any" + " of the specified `tools`") + return data + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + +class CompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/completions/create + model: str + prompt: Union[List[int], List[List[int]], str, List[str]] + best_of: Optional[int] = None + echo: Optional[bool] = False + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[int] = None + max_tokens: Optional[int] = 16 + n: int = 1 + presence_penalty: Optional[float] = 0.0 + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + suffix: Optional[str] = None + temperature: Optional[float] = 1.0 + top_p: Optional[float] = 1.0 + user: Optional[str] = None + + # doc: begin-completion-sampling-params + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + allowed_token_ids: Optional[List[int]] = None + prompt_logprobs: Optional[int] = None + # doc: end-completion-sampling-params + + # doc: begin-completion-extra-params + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + response_format: Optional[ResponseFormat] = Field( + default=None, + description= + ("Similar to chat completion, this parameter specifies the format of " + "output. Only {'type': 'json_object'} or {'type': 'text' } is " + "supported."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description="If specified, the output will follow the JSON schema.", + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + # doc: end-completion-extra-params + + def to_beam_search_params(self, + default_max_tokens: int) -> BeamSearchParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + + n = self.n if self.n is not None else 1 + temperature = self.temperature if self.temperature is not None else 0.0 + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + length_penalty=self.length_penalty, + ) + + def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.logprobs + + echo_without_generation = self.echo and self.max_tokens == 0 + + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams.from_optional( + json=self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern) + + return SamplingParams.from_optional( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens if not echo_without_generation else 1, + min_tokens=self.min_tokens, + prompt_logprobs=prompt_logprobs, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias, + allowed_token_ids=self.allowed_token_ids) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (logprobs := data.get("logprobs")) is not None and logprobs < 0: + raise ValueError("`logprobs` must be a positive value.") + + return data + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + +class EmbeddingRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: str + input: Union[List[int], List[List[int]], str, List[str]] + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-embedding-pooling-params + additional_data: Optional[Any] = None + + # doc: end-embedding-pooling-params + + # doc: begin-embedding-extra-params + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + # doc: end-embedding-extra-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class CompletionLogProbs(OpenAIBaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, + float]]] = Field(default_factory=list) + + +class CompletionResponseChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None + + +class CompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class EmbeddingResponseData(OpenAIBaseModel): + index: int + object: str = "embedding" + embedding: Union[List[float], str] + + +class EmbeddingResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[EmbeddingResponseData] + usage: UsageInfo + + +class FunctionCall(OpenAIBaseModel): + name: str + arguments: str + + +class ToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + function: FunctionCall + + +class DeltaFunctionCall(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +# a tool call delta where everything is optional +class DeltaToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int + function: Optional[DeltaFunctionCall] = None + + +class ExtractedToolCallInformation(BaseModel): + # indicate if tools were called + tools_called: bool + + # extracted tool calls + tool_calls: List[ToolCall] + + # content - per OpenAI spec, content AND tool calls can be returned rarely + # But some models will do this intentionally + content: Optional[str] = None + + +class ChatMessage(OpenAIBaseModel): + role: str + reasoning_content: Optional[str] = None + content: Optional[str] = None + tool_calls: List[ToolCall] = Field(default_factory=list) + + +class ChatCompletionLogProb(OpenAIBaseModel): + token: str + logprob: float = -9999.0 + bytes: Optional[List[int]] = None + + +class ChatCompletionLogProbsContent(ChatCompletionLogProb): + top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: Optional[List[ChatCompletionLogProbsContent]] = None + + +class ChatCompletionResponseChoice(OpenAIBaseModel): + index: int + message: ChatMessage + logprobs: Optional[ChatCompletionLogProbs] = None + # per OpenAI spec this is the default + finish_reason: Optional[str] = "stop" + # not part of the OpenAI spec but included in vLLM for legacy reasons + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None + + +class DeltaMessage(OpenAIBaseModel): + role: Optional[str] = None + reasoning_content: Optional[str] = None + content: Optional[str] = None + tool_calls: List[DeltaToolCall] = Field(default_factory=list) + + +class ChatCompletionResponseStreamChoice(OpenAIBaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class BatchRequestInput(OpenAIBaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameters of the request. + body: Union[ChatCompletionRequest, EmbeddingRequest] + + +class BatchResponseData(OpenAIBaseModel): + # HTTP status code of the response. + status_code: int = 200 + + # An unique identifier for the API request. + request_id: str + + # The body of the response. + body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None + + +class BatchRequestOutput(OpenAIBaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[BatchResponseData] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. + error: Optional[Any] + + +class TokenizeCompletionRequest(OpenAIBaseModel): + model: str + prompt: str + + add_special_tokens: bool = Field(default=True) + + +class TokenizeChatRequest(OpenAIBaseModel): + model: str + messages: List[ChatCompletionMessageParam] + + add_generation_prompt: bool = Field(default=True) + continue_final_message: bool = Field(default=False) + add_special_tokens: bool = Field(default=False) + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + +TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] + + +class TokenizeResponse(OpenAIBaseModel): + count: int + max_model_len: int + tokens: List[int] + + +class DetokenizeRequest(OpenAIBaseModel): + model: str + tokens: List[int] + + +class DetokenizeResponse(OpenAIBaseModel): + prompt: str + + +class LoadLoraAdapterRequest(BaseModel): + lora_name: str + lora_path: str + + +class UnloadLoraAdapterRequest(BaseModel): + lora_name: str + lora_int_id: Optional[int] = Field(default=None) diff --git a/qwen3_6_scripts/reasoning/__init__.py b/qwen3_6_scripts/reasoning/__init__.py new file mode 100644 index 0000000..5f2e50d --- /dev/null +++ b/qwen3_6_scripts/reasoning/__init__.py @@ -0,0 +1,16 @@ +""" +Reasoning parser module for vLLM 0.6.3 (BI-V100 / Qwen3.6-27B adaptation). + +Usage: --reasoning-parser qwen3 +""" + +from vllm.reasoning.abs_reasoning_parsers import ReasoningParser, ReasoningParserManager + +__all__ = ["ReasoningParser", "ReasoningParserManager"] + +# Lazy-register Qwen3 parser; imported on first get_reasoning_parser("qwen3"). +ReasoningParserManager.register_lazy( + "qwen3", + "vllm.reasoning.qwen3_reasoning_parser", + "Qwen3ReasoningParser", +) diff --git a/qwen3_6_scripts/reasoning/abs_reasoning_parsers.py b/qwen3_6_scripts/reasoning/abs_reasoning_parsers.py new file mode 100644 index 0000000..c614107 --- /dev/null +++ b/qwen3_6_scripts/reasoning/abs_reasoning_parsers.py @@ -0,0 +1,243 @@ +""" +Abstract reasoning parser base classes for vLLM 0.6.3. +Adapted from vllm-original/vllm/reasoning/abs_reasoning_parsers.py: + - Removed vllm.entrypoints.mcp, vllm.utils.collection_utils, import_utils + - DeltaMessage from vllm 0.6.3 protocol path + - TokenizerLike -> AnyTokenizer + - ReasoningParserManager: simplified eager + lazy registration +""" + +import importlib +from abc import abstractmethod +from collections.abc import Iterable, Sequence +from functools import cached_property +from typing import Any, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from vllm.entrypoints.openai.protocol import DeltaMessage + from vllm.transformers_utils.tokenizer import AnyTokenizer +else: + DeltaMessage = Any + AnyTokenizer = Any + + +class ReasoningParser: + """Abstract base for all reasoning parsers.""" + + def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs): + self.model_tokenizer = tokenizer + + @cached_property + def vocab(self) -> dict: + return self.model_tokenizer.get_vocab() + + @abstractmethod + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + """Return True once the reasoning block has closed in input_ids.""" + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Iterable[int] + ) -> bool: + return self.is_reasoning_end(input_ids) + + @abstractmethod + def extract_content_ids(self, input_ids: list) -> list: + """Return token ids that belong to the content (post-reasoning) part.""" + + def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int: + return 0 + + @abstractmethod + def extract_reasoning( + self, model_output: str, request: Any + ) -> "tuple[Optional[str], Optional[str]]": + """ + Split a complete model output into (reasoning_text, content_text). + Either part may be None. + """ + + @abstractmethod + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Optional["DeltaMessage"]: + """ + Extract reasoning from a streaming delta. + Returns a DeltaMessage with reasoning_content and/or content set, + or None if this delta should be suppressed (control token). + """ + + +class BaseThinkingReasoningParser(ReasoningParser): + """ + Base for parsers that use ... delimiters. + Subclasses define start_token / end_token properties. + """ + + @property + @abstractmethod + def start_token(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def end_token(self) -> str: + raise NotImplementedError + + def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + if not self.model_tokenizer: + raise ValueError("Tokenizer must be passed to ReasoningParser.") + if not self.start_token or not self.end_token: + raise ValueError("start_token and end_token must be defined.") + + self.start_token_id: Optional[int] = self.vocab.get(self.start_token) + self.end_token_id: Optional[int] = self.vocab.get(self.end_token) + if self.start_token_id is None or self.end_token_id is None: + raise RuntimeError( + f"{self.__class__.__name__}: could not find think tokens " + f"'{self.start_token}'/'{self.end_token}' in tokenizer vocab." + ) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + for token_id in reversed(input_ids): + if token_id == self.start_token_id: + return False + if token_id == self.end_token_id: + return True + return False + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Iterable[int] + ) -> bool: + return self.end_token_id in delta_ids + + def extract_content_ids(self, input_ids: list) -> list: + if self.end_token_id not in input_ids[:-1]: + return [] + return input_ids[input_ids.index(self.end_token_id) + 1:] + + def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int: + count = 0 + depth = 0 + for tid in token_ids: + if tid == self.start_token_id: + depth += 1 + elif tid == self.end_token_id: + if depth > 0: + depth -= 1 + elif depth > 0: + count += 1 + return count + + def extract_reasoning( + self, model_output: str, request: Any + ) -> "tuple[Optional[str], Optional[str]]": + # Strip if the model generated it (old-style template). + parts = model_output.partition(self.start_token) + model_output = parts[2] if parts[1] else parts[0] + + if self.end_token not in model_output: + return model_output, None + reasoning, _, content = model_output.partition(self.end_token) + return reasoning, content or None + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Optional["DeltaMessage"]: + from vllm.entrypoints.openai.protocol import DeltaMessage as _DeltaMessage + + # Suppress lone control tokens. + if len(delta_token_ids) == 1 and delta_token_ids[0] in ( + self.start_token_id, self.end_token_id + ): + return None + + start_in_prev = self.start_token_id in previous_token_ids + start_in_delta = self.start_token_id in delta_token_ids + end_in_prev = self.end_token_id in previous_token_ids + end_in_delta = self.end_token_id in delta_token_ids + + if start_in_prev: + if end_in_delta: + end_idx = delta_text.find(self.end_token) + reasoning = delta_text[:end_idx] if end_idx >= 0 else "" + content = delta_text[end_idx + len(self.end_token):] if end_idx >= 0 else None + return _DeltaMessage( + reasoning_content=reasoning or None, + content=content or None, + ) + elif end_in_prev: + return _DeltaMessage(content=delta_text) + else: + return _DeltaMessage(reasoning_content=delta_text) + + elif start_in_delta: + if end_in_delta: + start_idx = delta_text.find(self.start_token) + end_idx = delta_text.find(self.end_token) + reasoning = delta_text[start_idx + len(self.start_token):end_idx] + content = delta_text[end_idx + len(self.end_token):] + return _DeltaMessage( + reasoning_content=reasoning or None, + content=content or None, + ) + else: + return _DeltaMessage(reasoning_content=delta_text) + + else: + return _DeltaMessage(content=delta_text) + + +class ReasoningParserManager: + """ + Registry for ReasoningParser implementations. + Supports eager and lazy registration. + """ + + _parsers: dict = {} # name -> class (eager) + _lazy: dict = {} # name -> (module_path, class_name) + + @classmethod + def register_module(cls, name: str, parser_cls: type) -> None: + """Eagerly register a ReasoningParser class.""" + if not issubclass(parser_cls, ReasoningParser): + raise TypeError(f"{parser_cls} is not a ReasoningParser subclass.") + cls._parsers[name] = parser_cls + + @classmethod + def register_lazy(cls, name: str, module_path: str, class_name: str) -> None: + """Register a parser for deferred import.""" + cls._lazy[name] = (module_path, class_name) + + @classmethod + def get_reasoning_parser(cls, name: str) -> type: + if name in cls._parsers: + return cls._parsers[name] + if name in cls._lazy: + module_path, class_name = cls._lazy[name] + mod = importlib.import_module(module_path) + parser_cls = getattr(mod, class_name) + cls._parsers[name] = parser_cls + return parser_cls + registered = sorted(set(cls._parsers) | set(cls._lazy)) + raise KeyError( + f"Reasoning parser '{name}' not found. " + f"Available: {registered}" + ) + + @classmethod + def list_registered(cls) -> list: + return sorted(set(cls._parsers) | set(cls._lazy)) diff --git a/qwen3_6_scripts/reasoning/qwen3_reasoning_parser.py b/qwen3_6_scripts/reasoning/qwen3_reasoning_parser.py new file mode 100644 index 0000000..d90fc91 --- /dev/null +++ b/qwen3_6_scripts/reasoning/qwen3_reasoning_parser.py @@ -0,0 +1,108 @@ +""" +Reasoning parser for Qwen3 / Qwen3.5 / Qwen3.6 model family. +Adapted from vllm-original/vllm/reasoning/qwen3_reasoning_parser.py. + +The model uses ... to wrap chain-of-thought output. +For Qwen3.5+ the chat template injects into the prompt, so only + appears in the generated tokens; older templates generate +themselves. Both styles are handled. +""" + +from typing import Optional, Sequence, Any + +from vllm.reasoning.abs_reasoning_parsers import ( + BaseThinkingReasoningParser, + ReasoningParserManager, +) + + +class Qwen3ReasoningParser(BaseThinkingReasoningParser): + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + self.thinking_enabled = chat_kwargs.get("enable_thinking", True) + + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + def extract_reasoning( + self, model_output: str, request: Any + ) -> "tuple[Optional[str], Optional[str]]": + # Strip if the model generated it (old template / edge case). + parts = model_output.partition(self.start_token) + model_output = parts[2] if parts[1] else parts[0] + + if self.end_token not in model_output: + if not self.thinking_enabled: + return None, model_output + # Thinking enabled but output truncated before . + return model_output, None + + reasoning, _, content = model_output.partition(self.end_token) + return reasoning, content or None + + def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int: + token_ids = list(token_ids) + if self.start_token_id in token_ids: + # Old-style template: model generates itself. + # Use depth-counting from the base class. + return super().count_reasoning_tokens(token_ids) + elif self.end_token_id in token_ids: + # New-style template (Qwen3.5+): is injected into the + # prompt, so output starts already inside the thinking block. + # Every token before is a reasoning token. + return token_ids.index(self.end_token_id) + else: + # No in output: either truncated (all reasoning) + # or thinking disabled (none). + return len(token_ids) if self.thinking_enabled else 0 + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ): + from vllm.entrypoints.openai.protocol import DeltaMessage + + if not self.thinking_enabled: + return DeltaMessage(content=delta_text) if delta_text else None + + # Strip from delta if the model generates it itself. + if self.start_token_id in delta_token_ids: + start_idx = delta_text.find(self.start_token) + if start_idx >= 0: + delta_text = delta_text[start_idx + len(self.start_token):] + + if self.end_token_id in delta_token_ids: + end_idx = delta_text.find(self.end_token) + if end_idx >= 0: + reasoning = delta_text[:end_idx] + content = delta_text[end_idx + len(self.end_token):] + if not reasoning and not content: + return None + return DeltaMessage( + reasoning_content=reasoning or None, + content=content or None, + ) + return None + + if not delta_text: + return None + elif self.end_token_id in previous_token_ids: + return DeltaMessage(content=delta_text) + else: + return DeltaMessage(reasoning_content=delta_text) + + +# Register immediately when this module is imported. +ReasoningParserManager.register_module("qwen3", Qwen3ReasoningParser) diff --git a/qwen3_6_scripts/serving_chat.py b/qwen3_6_scripts/serving_chat.py new file mode 100644 index 0000000..2c91959 --- /dev/null +++ b/qwen3_6_scripts/serving_chat.py @@ -0,0 +1,994 @@ +import asyncio +import json +import time +from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List, + Optional) +from typing import Sequence as GenericSequence +from typing import Union + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + load_chat_template, + parse_chat_messages_futures) +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProb, ChatCompletionLogProbs, + ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, + DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata, + ToolCall, UsageInfo) +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, + OpenAIServing, + PromptAdapterPath, + TextTokensPrompt) +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.inputs import TokensPrompt +from vllm.logger import init_logger +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.tracing import (contains_trace_headers, extract_trace_headers, + log_tracing_disabled_warning) +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import iterate_with_cancellation, random_uuid + +logger = init_logger(__name__) + + +class OpenAIServingChat(OpenAIServing): + + def __init__(self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + response_role: str, + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + return_tokens_as_token_ids: bool = False, + enable_auto_tools: bool = False, + tool_parser: Optional[str] = None, + reasoning_parser: Optional[str] = None): + super().__init__(engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=lora_modules, + prompt_adapters=prompt_adapters, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + self.response_role = response_role + self.use_tool_use_model_template = False + self.chat_template = load_chat_template(chat_template) + + # set up tool use + self.enable_auto_tools: bool = enable_auto_tools + if self.enable_auto_tools: + logger.info( + "\"auto\" tool choice has been enabled please note that while" + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored.") + + self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None + if self.enable_auto_tools: + try: + self.tool_parser = ToolParserManager.get_tool_parser( + tool_parser) + except Exception as e: + raise TypeError("Error: --enable-auto-tool-choice requires " + f"tool_parser:'{tool_parser}' which has not " + "been registered") from e + + # set up reasoning parser + self.reasoning_parser_cls = None + if reasoning_parser: + try: + from vllm.reasoning import ReasoningParserManager + self.reasoning_parser_cls = \ + ReasoningParserManager.get_reasoning_parser(reasoning_parser) + logger.info("Reasoning parser '%s' enabled.", reasoning_parser) + except Exception as e: + raise TypeError( + f"Error: --reasoning-parser '{reasoning_parser}' could not " + "be loaded. Make sure vllm/reasoning/ is installed." + ) from e + + async def create_chat_completion( + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, + ErrorResponse]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI + ChatCompletion API. + + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + model_config = self.model_config + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + conversation, mm_data_future = parse_chat_messages_futures( + request.messages, model_config, tokenizer) + + tool_dicts = None if request.tools is None else [ + tool.model_dump() for tool in request.tools + ] + + prompt: Union[str, List[int]] + is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) + if is_mistral_tokenizer: + prompt = apply_mistral_chat_template( + tokenizer, + messages=request.messages, + chat_template=request.chat_template or self.chat_template, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tools=tool_dicts, + documents=request.documents, + **(request.chat_template_kwargs or {}), + ) + else: + prompt = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=request.chat_template or self.chat_template, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tools=tool_dicts, + documents=request.documents, + **(request.chat_template_kwargs or {}), + ) + except Exception as e: + logger.exception("Error in applying chat template from request") + return self.create_error_response(str(e)) + + try: + mm_data = await mm_data_future + except Exception as e: + logger.exception("Error in loading multi-modal data") + return self.create_error_response(str(e)) + + # validation for OpenAI tools + # tool_choice = "required" is not supported + if request.tool_choice == "required": + return self.create_error_response( + "tool_choice = \"required\" is not supported!") + + if not is_mistral_tokenizer and request.tool_choice == "auto" and not ( + self.enable_auto_tools and self.tool_parser is not None): + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser + return self.create_error_response( + "\"auto\" tool choice requires " + "--enable-auto-tool-choice and --tool-call-parser to be set") + + request_id = f"chat-{random_uuid()}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + if self.enable_auto_tools and self.tool_parser: + request = self.tool_parser(tokenizer).adjust_request( + request=request) + + if isinstance(prompt, str): + prompt_inputs = self._tokenize_prompt_input( + request, + tokenizer, + prompt, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + assert isinstance(prompt, list) and isinstance( + prompt[0], int + ), "Prompt has to be either a string or a list of token ids" + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(prompt), prompt_token_ids=prompt) + + assert prompt_inputs is not None + + sampling_params: Union[SamplingParams, BeamSearchParams] + default_max_tokens = self.max_model_len - len( + prompt_inputs["prompt_token_ids"]) + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + default_max_tokens) + else: + sampling_params = request.to_sampling_params( + default_max_tokens) + + self._log_inputs(request_id, + prompt_inputs, + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + engine_inputs = TokensPrompt( + prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if mm_data is not None: + engine_inputs["multi_modal_data"] = mm_data + + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) + trace_headers = None + if is_tracing_enabled and raw_request: + trace_headers = extract_trace_headers(raw_request.headers) + if (not is_tracing_enabled and raw_request + and contains_trace_headers(raw_request.headers)): + log_tracing_disabled_warning() + + if isinstance(sampling_params, BeamSearchParams): + assert isinstance(self.engine_client, + (AsyncLLMEngine, + MQLLMEngineClient)), \ + "Beam search is only supported with" \ + "AsyncLLMEngine and MQLLMEngineClient." + result_generator = self.engine_client.beam_search( + engine_inputs['prompt_token_ids'], + request_id, + sampling_params, + ) + else: + result_generator = self.engine_client.generate( + engine_inputs, + sampling_params, + request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=request.priority, + ) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + if raw_request: + result_generator = iterate_with_cancellation( + result_generator, raw_request.is_disconnected) + + # Streaming response + if request.stream: + return self.chat_completion_stream_generator( + request, result_generator, request_id, conversation, tokenizer, + request_metadata) + + try: + return await self.chat_completion_full_generator( + request, result_generator, request_id, conversation, tokenizer, + request_metadata) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + def get_chat_request_role(self, request: ChatCompletionRequest) -> str: + if request.add_generation_prompt: + return self.response_role + return request.messages[-1]["role"] + + async def chat_completion_stream_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> AsyncGenerator[str, None]: + model_name = self.base_model_paths[0].name + created_time = int(time.time()) + chunk_object_type: Final = "chat.completion.chunk" + first_iteration = True + + # Send response for each token for each request.n (index) + num_choices = 1 if request.n is None else request.n + previous_num_tokens = [0] * num_choices + finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): + tool_choice_function_name = request.tool_choice.function.name + else: + tool_choice_function_name = None + + # Determine whether tools are in use with "auto" tool choice + tool_choice_auto = ( + not tool_choice_function_name + and self._should_stream_with_auto_tool_parsing(request)) + + use_reasoning = self.reasoning_parser_cls is not None + + all_previous_token_ids: Optional[List[List[int]]] + # previous_texts / all_previous_token_ids are needed for both tool + # parsing and reasoning parsing (both require full-history context). + if tool_choice_auto or use_reasoning: + previous_texts = [""] * num_choices + all_previous_token_ids = [[]] * num_choices + else: + previous_texts, all_previous_token_ids = None, None + + # Prepare the tool parser if it's needed + try: + if tool_choice_auto and self.tool_parser: + tool_parsers: List[Optional[ToolParser]] = [ + self.tool_parser(tokenizer) + ] * num_choices + else: + tool_parsers = [None] * num_choices + except RuntimeError as e: + logger.error("Error in tool parser creation: %s", e) + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + + # Prepare reasoning parsers (one instance per choice for state isolation) + reasoning_parsers: List[Optional[object]] = [None] * num_choices + reasoning_end_arr: List[bool] = [False] * num_choices + reasoning_token_counts: List[int] = [0] * num_choices + if use_reasoning: + try: + reasoning_parsers = [ + self.reasoning_parser_cls( + tokenizer, + chat_template_kwargs=request.chat_template_kwargs) + for _ in range(num_choices) + ] + # If thinking is disabled per-request, mark reasoning as + # already ended so the tool-auto branch is reachable. + for idx, rp in enumerate(reasoning_parsers): + if hasattr(rp, 'thinking_enabled') and not rp.thinking_enabled: + reasoning_end_arr[idx] = True + except RuntimeError as e: + logger.error("Error in reasoning parser creation: %s", e) + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + + try: + async for res in result_generator: + if res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) + + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + if first_iteration: + # Send first response for each request.n (index) with + # the role + role = self.get_chat_request_role(request) + + # NOTE num_choices defaults to 1 so this usually executes + # once per request + for i in range(num_choices): + tool_parser = tool_parsers[i] + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + role=role, + content="", + ), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # if usage should be included + if (request.stream_options + and request.stream_options.include_usage): + # if continuous usage stats are requested, add it + if request.stream_options.continuous_usage_stats: + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) + chunk.usage = usage + # otherwise don't + else: + chunk.usage = None + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the + # last message + if request.echo or request.continue_final_message: + last_msg_content: str = "" + if conversation and "content" in conversation[ + -1] and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + + if last_msg_content: + for i in range(num_choices): + choice_data = ( + ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + content=last_msg_content), + logprobs=None, + finish_reason=None)) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if (request.stream_options and + request.stream_options.include_usage): + if (request.stream_options. + continuous_usage_stats): + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) + chunk.usage = usage + else: + chunk.usage = None + + data = chunk.model_dump_json( + exclude_unset=True) + yield f"data: {data}\n\n" + first_iteration = False + + for output in res.outputs: + i = output.index + tool_parser = tool_parsers[i] + + if finish_reason_sent[i]: + continue + + if request.logprobs and request.top_logprobs is not None: + assert output.logprobs is not None, ( + "Did not output logprobs") + logprobs = self._create_chat_logprobs( + token_ids=output.token_ids, + top_logprobs=output.logprobs, + tokenizer=tokenizer, + num_output_top_logprobs=request.top_logprobs, + ) + else: + logprobs = None + + delta_text = output.text + delta_message: Optional[DeltaMessage] + + # Maintain text/token history when either reasoning or + # auto-tool parsing is active. + assert previous_texts is not None or not ( + tool_choice_auto or use_reasoning) + if previous_texts is not None: + assert all_previous_token_ids is not None + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + current_token_ids = previous_token_ids + list( + output.token_ids) + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids + else: + previous_text = "" + previous_token_ids = [] + current_text = delta_text + current_token_ids = list(output.token_ids) + + # handle streaming deltas for tools with named tool_choice + if tool_choice_function_name: + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(function=DeltaFunctionCall( + name=tool_choice_function_name, + arguments=delta_text), + index=i) + ]) + + # handle reasoning: route through reasoning parser while + # has not yet been seen. + elif use_reasoning and not reasoning_end_arr[i]: + r_parser = reasoning_parsers[i] + delta_message = r_parser.extract_reasoning_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids, + ) + # Mark reasoning as ended when end token appears. + if r_parser.end_token_id in current_token_ids: + reasoning_end_arr[i] = True + + # handle streaming deltas for tools with "auto" tool choice + # (only reached after reasoning block, if any, has ended) + elif tool_choice_auto: + assert tool_parser is not None + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids, + request=request)) + + # handle streaming just a content delta + else: + delta_message = DeltaMessage(content=delta_text) + + # set the previous values for the next iteration + previous_num_tokens[i] += len(output.token_ids) + + # if the message delta is None (e.g. because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then + # get the next token without streaming a chunk + if delta_message is None: + continue + + if output.finish_reason is None: + # Send token-by-token response for each request.n + + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # handle usage stats if requested & if continuous + if (request.stream_options + and request.stream_options.include_usage): + if request.stream_options.continuous_usage_stats: + completion_tokens = len(output.token_ids) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens, + ) + chunk.usage = usage + else: + chunk.usage = None + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # if the model is finished generating + else: + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + # only happens if we are NOT using guided decoding + auto_tools_called = False + if tool_parser: + auto_tools_called = len( + tool_parser.prev_tool_call_arr) > 0 + index = len(tool_parser.prev_tool_call_arr + ) - 1 if auto_tools_called else 0 + else: + index = 0 + + if self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser: + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[index].get( + "arguments", {})) + + # get what we've streamed so far for arguments + # for the current tool + actual_call = tool_parser.streamed_args_for_tool[ + index] + + # check to see if there's anything left to stream + remaining_call = expected_call.replace( + actual_call, "", 1) + + # set that as a delta message + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(index=index, + function=DeltaFunctionCall( + arguments=remaining_call). + model_dump(exclude_none=True)) + ]) + + # Count reasoning tokens for this choice at finish time. + if use_reasoning and all_previous_token_ids is not None: + r_parser = reasoning_parsers[i] + reasoning_token_counts[i] = \ + r_parser.count_reasoning_tokens( + all_previous_token_ids[i]) + + # Send the finish response for each request.n only once + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=output.finish_reason + if not auto_tools_called else "tool_calls", + stop_reason=output.stop_reason) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if (request.stream_options + and request.stream_options.include_usage): + if request.stream_options.continuous_usage_stats: + completion_tokens = len(output.token_ids) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens, + ) + chunk.usage = usage + else: + chunk.usage = None + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True + + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage + if (request.stream_options + and request.stream_options.include_usage): + completion_tokens = previous_num_tokens[i] + total_reasoning = sum(reasoning_token_counts) if use_reasoning else None + final_usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + reasoning_tokens=total_reasoning, + ) + + final_usage_chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + num_completion_tokens = sum(previous_num_tokens) + total_reasoning = sum(reasoning_token_counts) if use_reasoning else None + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_prompt_tokens + num_completion_tokens, + reasoning_tokens=total_reasoning) + + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + logger.error("error in chat completion stream generator: %s", e) + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> Union[ErrorResponse, ChatCompletionResponse]: + + model_name = self.base_model_paths[0].name + created_time = int(time.time()) + final_res: Optional[RequestOutput] = None + + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + + assert final_res is not None + + choices: List[ChatCompletionResponseChoice] = [] + + role = self.get_chat_request_role(request) + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + + if request.logprobs and request.top_logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.top_logprobs, + tokenizer=tokenizer, + ) + else: + logprobs = None + + # In the OpenAI API the finish_reason is "tools_called" + # if the tool choice is auto and the model produced a tool + # call. The same is not true for named function calls + auto_tools_called = False + + # Extract reasoning content if parser is configured. + # output_text is what remains after stripping .... + reasoning_text: Optional[str] = None + output_text: str = output.text + if self.reasoning_parser_cls: + r_parser = self.reasoning_parser_cls( + tokenizer, + chat_template_kwargs=request.chat_template_kwargs) + reasoning_text, extracted = r_parser.extract_reasoning( + output.text, request) + output_text = extracted or "" + + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if (not self.enable_auto_tools + or not self.tool_parser) and not isinstance( + request.tool_choice, + ChatCompletionNamedToolChoiceParam): + message = ChatMessage(role=role, + reasoning_content=reasoning_text, + content=output_text) + + # if the request uses tools and specified a tool choice + elif request.tool_choice and type( + request.tool_choice) is ChatCompletionNamedToolChoiceParam: + + message = ChatMessage( + role=role, + reasoning_content=reasoning_text, + content="", + tool_calls=[ + ToolCall(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=output_text)) + ]) + + # if the request doesn't use tool choice + # OR specifies to not use a tool + elif not request.tool_choice or request.tool_choice == "none": + + message = ChatMessage(role=role, + reasoning_content=reasoning_text, + content=output_text) + + # handle when there are tools and tool choice is auto + elif request.tools and ( + request.tool_choice == "auto" + or request.tool_choice is None) and self.enable_auto_tools \ + and self.tool_parser: + + try: + tool_parser = self.tool_parser(tokenizer) + except RuntimeError as e: + logger.error("Error in tool parser creation: %s", e) + return self.create_error_response(str(e)) + + # Parse tool calls from the post-reasoning content. + tool_call_info = tool_parser.extract_tool_calls( + output_text, request=request) + auto_tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage( + role=role, + reasoning_content=reasoning_text, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls) + else: + message = ChatMessage(role=role, + reasoning_content=reasoning_text, + content=output_text) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. Returning a standard chat " + "completion.") + message = ChatMessage(role=role, + reasoning_content=reasoning_text, + content=output_text) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason="tool_calls" if auto_tools_called else + output.finish_reason if output.finish_reason else "stop", + stop_reason=output.stop_reason) + choices.append(choice_data) + + if request.echo or request.continue_final_message: + last_msg_content = "" + if conversation and "content" in conversation[-1] and conversation[ + -1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + + for choice in choices: + full_message = last_msg_content + (choice.message.content + or "") + choice.message.content = full_message + + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + total_reasoning_tokens: Optional[int] = None + if self.reasoning_parser_cls: + rp = self.reasoning_parser_cls( + tokenizer, + chat_template_kwargs=request.chat_template_kwargs) + total_reasoning_tokens = sum( + rp.count_reasoning_tokens(list(output.token_ids)) + for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + reasoning_tokens=total_reasoning_tokens, + ) + + request_metadata.final_usage_info = usage + + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + prompt_logprobs=final_res.prompt_logprobs, + ) + + return response + + def _get_top_logprobs( + self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], + tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]: + return [ + ChatCompletionLogProb(token=(token := self._get_decoded_token( + p[1], + p[0], + tokenizer, + return_as_token_id=self.return_tokens_as_token_ids)), + logprob=max(p[1].logprob, -9999.0), + bytes=list( + token.encode("utf-8", errors="replace"))) + for i, p in enumerate(logprobs.items()) + if top_logprobs and i < top_logprobs + ] + + def _create_chat_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + tokenizer: AnyTokenizer, + num_output_top_logprobs: Optional[int] = None, + ) -> ChatCompletionLogProbs: + """Create OpenAI-style logprobs.""" + logprobs_content: List[ChatCompletionLogProbsContent] = [] + + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = tokenizer.decode(token_id) + if self.return_tokens_as_token_ids: + token = f"token_id:{token_id}" + + logprobs_content.append( + ChatCompletionLogProbsContent( + token=token, + bytes=list(token.encode("utf-8", errors="replace")), + )) + else: + step_token = step_top_logprobs[token_id] + step_decoded = step_token.decoded_token + + logprobs_content.append( + ChatCompletionLogProbsContent( + token=self._get_decoded_token( + step_token, + token_id, + tokenizer, + self.return_tokens_as_token_ids, + ), + logprob=max(step_token.logprob, -9999.0), + bytes=None if step_decoded is None else list( + step_decoded.encode("utf-8", errors="replace")), + top_logprobs=self._get_top_logprobs( + step_top_logprobs, + num_output_top_logprobs, + tokenizer, + ), + )) + + return ChatCompletionLogProbs(content=logprobs_content) + + def _should_stream_with_auto_tool_parsing(self, + request: ChatCompletionRequest): + """ + Utility function to check if streamed tokens should go through the tool + call parser that was configured. + + We only want to do this IF user-provided tools are set, a tool parser + is configured, "auto" tool choice is enabled, and the request's tool + choice field indicates that "auto" tool choice should be used. + """ + return (request.tools and self.tool_parser and self.enable_auto_tools + and request.tool_choice in ['auto', None]) + + def _should_check_for_unstreamed_tool_arg_tokens( + self, + delta_message: Optional[DeltaMessage], + output: CompletionOutput, + ) -> bool: + """ + Check to see if we should check for unstreamed tool arguments tokens. + This is only applicable when auto tool parsing is enabled, the delta + is a tool call with arguments. + """ + + # yapf: disable + return bool( + # if there is a delta message that includes tool calls which + # include a function that has arguments + output.finish_reason is not None + and self.enable_auto_tools and self.tool_parser and delta_message + and delta_message.tool_calls and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function + and delta_message.tool_calls[0].function.arguments is not None + )