111 lines
3.5 KiB
Python
111 lines
3.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
import asyncio
|
|
import json
|
|
from http import HTTPStatus
|
|
|
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
from vllm.engine.protocol import EngineClient
|
|
from vllm.entrypoints.openai.api_server import validate_json_request
|
|
from vllm.entrypoints.openai.protocol import (
|
|
ErrorResponse,
|
|
)
|
|
from vllm.entrypoints.serve.disagg.protocol import (
|
|
GenerateRequest,
|
|
GenerateResponse,
|
|
)
|
|
from vllm.entrypoints.serve.disagg.serving import (
|
|
ServingTokens,
|
|
)
|
|
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
|
|
from vllm.entrypoints.utils import (
|
|
load_aware_call,
|
|
with_cancellation,
|
|
)
|
|
from vllm.logger import init_logger
|
|
|
|
# Logger for this module, obtained from vLLM's init_logger with the module name.
logger = init_logger(__name__)
|
|
|
|
|
|
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Fetch the tokenization serving object kept on the application state.

    Args:
        request: Incoming request; only ``request.app.state`` is read.

    Returns:
        The value stored as ``openai_serving_tokenization`` on the app state.
    """
    app_state = request.app.state
    return app_state.openai_serving_tokenization
|
|
|
|
|
|
def generate_tokens(request: Request) -> ServingTokens | None:
    """Fetch the token-generation serving object kept on the application state.

    Args:
        request: Incoming request; only ``request.app.state`` is read.

    Returns:
        The value stored as ``serving_tokens`` on the app state. The return
        annotation allows ``None``, so callers must check before using it.
    """
    app_state = request.app.state
    return app_state.serving_tokens
|
|
|
|
|
|
def engine_client(request: Request) -> EngineClient:
    """Fetch the engine client kept on the application state.

    Args:
        request: Incoming request; only ``request.app.state`` is read.

    Returns:
        The value stored as ``engine_client`` on the app state.
    """
    app_state = request.app.state
    return app_state.engine_client
|
|
|
|
|
|
# Router shared by this module: endpoints are registered on it with
# @router.post, and attach_router() includes it into the FastAPI app.
router = APIRouter()
|
|
|
|
|
|
@router.post(
    "/inference/v1/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
    """Handle ``POST /inference/v1/generate``.

    Looks up the ``serving_tokens`` handler on the app state and delegates to
    its ``serve_tokens`` coroutine, then maps the result to a response:

    * no handler configured -> the error object built by the tokenization
      handler's ``create_error_response`` is returned as-is;
    * ``ErrorResponse`` -> ``JSONResponse`` using ``error.code`` as status;
    * ``GenerateResponse`` -> ``JSONResponse`` with the dumped model;
    * anything else -> streamed back as ``text/event-stream``.

    Args:
        request: Parsed request body.
        raw_request: The underlying request, used to reach the app state and
            passed through to the handler.

    Raises:
        HTTPException: 500, with the stringified cause as detail, when
            ``serve_tokens`` raises any ``Exception``.
    """
    serving = generate_tokens(raw_request)
    if serving is None:
        return tokenization(raw_request).create_error_response(
            message="The model does not support generate tokens API"
        )

    try:
        outcome = await serving.serve_tokens(request, raw_request)
    except Exception as exc:
        # Surface any handler failure to the client as a 500.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(outcome, ErrorResponse):
        error_payload = outcome.model_dump()
        return JSONResponse(content=error_payload, status_code=outcome.error.code)

    if isinstance(outcome, GenerateResponse):
        return JSONResponse(content=outcome.model_dump())

    # Neither an error nor a complete response: stream it to the client.
    return StreamingResponse(content=outcome, media_type="text/event-stream")
|
|
|
|
|
|
def attach_router(app: FastAPI):
    """Include this module's router into ``app``.

    When ``app.state.args`` has a truthy ``tokens_only`` attribute, a
    ``POST /abort_requests`` route is registered on the router before it is
    included.

    Args:
        app: The FastAPI application to attach the routes to.
    """
    if getattr(app.state.args, "tokens_only", False):
        # Strong references to in-flight abort tasks. The event loop keeps
        # only weak references to tasks, so a task created without a saved
        # reference may be garbage-collected before it completes.
        pending_aborts: set[asyncio.Task] = set()

        @router.post("/abort_requests")
        async def abort_requests(raw_request: Request):
            """
            Abort one or more requests. To be used in a
            Disaggregated Everything setup.

            Expects a JSON object body with a ``request_ids`` entry. Responds
            with 200 as soon as the abort has been scheduled; it does not
            wait for the abort to finish.

            Raises:
                HTTPException: 400 when the body is not valid JSON, is not a
                    JSON object, or has no ``request_ids`` entry.
            """
            try:
                body = await raw_request.json()
            except json.JSONDecodeError as e:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"JSON decode error: {e}",
                ) from e
            # Valid JSON that is not an object (list, string, number, ...)
            # has no ``.get``; reject it as a client error instead of
            # letting an AttributeError become a 500.
            if not isinstance(body, dict):
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Request body must be a JSON object",
                )
            request_ids = body.get("request_ids")
            if request_ids is None:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Missing 'request_ids' in request body",
                )
            # Abort requests in background, holding a reference until done.
            abort_task = asyncio.create_task(
                engine_client(raw_request).abort(request_ids)
            )
            pending_aborts.add(abort_task)
            abort_task.add_done_callback(pending_aborts.discard)
            return Response(status_code=200)

    app.include_router(router)
|