# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import json
from http import HTTPStatus

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import validate_json_request
from vllm.entrypoints.openai.protocol import (
    ErrorResponse,
)
from vllm.entrypoints.serve.disagg.protocol import (
    GenerateRequest,
    GenerateResponse,
)
from vllm.entrypoints.serve.disagg.serving import (
    ServingTokens,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
    load_aware_call,
    with_cancellation,
)
from vllm.logger import init_logger

logger = init_logger(__name__)

# Strong references to fire-and-forget tasks. The event loop itself only keeps
# weak references, so a task whose result is dropped may be garbage-collected
# before it finishes.
_background_tasks: set[asyncio.Task] = set()


def _on_background_task_done(task: asyncio.Task) -> None:
    """Release a finished background task and log its failure, if any."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    # Retrieving the exception here also keeps asyncio from reporting
    # "Task exception was never retrieved" when the task is destroyed.
    error = task.exception()
    if error is not None:
        logger.error("Background abort task failed: %s", error)


def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the tokenization handler stored on the application state."""
    return request.app.state.openai_serving_tokenization


def generate_tokens(request: Request) -> ServingTokens | None:
    """Return the token-generation handler stored on the application state.

    The value may be ``None``; ``generate`` treats that as "not supported".
    """
    return request.app.state.serving_tokens


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on the application state."""
    return request.app.state.engine_client


router = APIRouter()


# Handler for POST /inference/v1/generate.
#
# Delegates to the ServingTokens handler and maps its result onto a response:
#   * ErrorResponse    -> JSON body, HTTP status taken from ``error.code``
#   * GenerateResponse -> JSON body
#   * anything else    -> streamed back as ``text/event-stream``
# Any exception raised by the handler is converted to an HTTP 500 whose detail
# is the exception text.
#
# Documented with comments rather than a docstring because FastAPI publishes
# route-handler docstrings as the operation description in the OpenAPI schema.
@router.post(
    "/inference/v1/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
    handler = generate_tokens(raw_request)
    if handler is None:
        # No token-generation handler is configured; the tokenization handler
        # is only used here to build the error payload.
        return tokenization(raw_request).create_error_response(
            message="The model does not support generate tokens API"
        )
    try:
        generator = await handler.serve_tokens(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, GenerateResponse):
        return JSONResponse(content=generator.model_dump())
    return StreamingResponse(content=generator, media_type="text/event-stream")


def attach_router(app: FastAPI):
    """Mount this module's routes on ``app``.

    ``/abort_requests`` is registered only when ``app.state.args.tokens_only``
    is truthy; ``/inference/v1/generate`` is always included.

    NOTE(review): the abort route is added to the module-level ``router``, so
    every call with ``tokens_only`` set registers it again on that shared
    router. Confirm this function is called once per process.
    """
    if getattr(app.state.args, "tokens_only", False):

        @router.post("/abort_requests")
        async def abort_requests(raw_request: Request):
            """
            Abort one or more requests. To be used in a
            Disaggregated Everything setup.
            """
            try:
                body = await raw_request.json()
            except json.JSONDecodeError as e:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"JSON decode error: {e}",
                ) from e
            # Valid JSON need not be an object (e.g. `[]` or `"x"`); treat any
            # non-object body as missing the key instead of failing with an
            # AttributeError that would surface as a 500.
            request_ids = body.get("request_ids") if isinstance(body, dict) else None
            if request_ids is None:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Missing 'request_ids' in request body",
                )
            # Abort requests in background. The task is held in
            # _background_tasks until it completes so it cannot be
            # garbage-collected mid-flight.
            abort_task = asyncio.create_task(
                engine_client(raw_request).abort(request_ids)
            )
            _background_tasks.add(abort_task)
            abort_task.add_done_callback(_on_background_task_done)
            return Response(status_code=200)

    app.include_router(router)