Files
2026-01-19 10:38:50 +08:00

119 lines
4.5 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Awaitable, Callable
from http import HTTPStatus
from typing import Any
import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.api_server import (
base,
chat,
completion,
create_chat_completion,
create_completion,
validate_json_request,
)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
ErrorResponse,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.pooling.classify.api_router import classify, create_classify
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
from vllm.entrypoints.pooling.embed.api_router import create_embedding, embedding
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
from vllm.entrypoints.pooling.score.api_router import (
create_score,
do_rerank,
rerank,
score,
)
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
from vllm.entrypoints.serve.instrumentator.health import health
# TODO: switch RequestType to TypeForm[BaseModel] once type checkers
# recognize it (requires typing_extensions >= 4.13)
RequestType = Any

# Given the incoming request, returns the serving object for one API,
# or None.
GetHandlerFn = Callable[[Request], OpenAIServing | None]

# Async endpoint invoked with the validated request model and the raw request.
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]

# Request model -> (handler getter, endpoint) dispatch table.
# NOTE: Order matters: entries listed first take higher priority.
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
    (ChatCompletionRequest, (chat, create_chat_completion)),
    (CompletionRequest, (completion, create_completion)),
    (EmbeddingRequest, (embedding, create_embedding)),
    (ClassificationRequest, (classify, create_classify)),
    (ScoreRequest, (score, create_score)),
    (RerankRequest, (rerank, do_rerank)),
    (PoolingRequest, (pooling, create_pooling)),
]

# NOTE: The TypeAdapters are built a single time at import, in the same
# order as INVOCATION_TYPES, so they are not reconstructed per request.
INVOCATION_VALIDATORS = [
    (pydantic.TypeAdapter(model_cls), handler_pair)
    for model_cls, handler_pair in INVOCATION_TYPES
]
def register_sagemaker_routes(router: APIRouter) -> APIRouter:
    """Register the SageMaker ``/ping`` and ``/invocations`` routes.

    Args:
        router: Router on which both endpoints are registered.

    Returns:
        The same ``router`` object that was passed in.
    """

    @router.post("/ping", response_class=Response)
    @router.get("/ping", response_class=Response)
    @sagemaker_standards.register_ping_handler
    async def ping(raw_request: Request) -> Response:
        """Ping check. Endpoint required for SageMaker"""
        return await health(raw_request)

    @router.post(
        "/invocations",
        dependencies=[Depends(validate_json_request)],
        responses={
            HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
            HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        },
    )
    @sagemaker_standards.register_invocation_handler
    @sagemaker_standards.stateful_session_manager()
    @sagemaker_standards.inject_adapter_id(adapter_path="model")
    async def invocations(raw_request: Request):
        """For SageMaker, routes requests based on the request type.

        The JSON body is validated against each entry of
        ``INVOCATION_VALIDATORS`` (in order) whose handler getter returns a
        non-``None`` value for this request; the first model that validates
        decides which endpoint is awaited. If none validates, an error
        response listing the candidate request types is returned.

        Raises:
            HTTPException: With status 400 when the body cannot be decoded
                as JSON.
        """
        try:
            body = await raw_request.json()
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # json.loads() decodes raw bytes before parsing; a payload that is
            # not valid UTF-8/16/32 raises UnicodeDecodeError, which is not a
            # JSONDecodeError. Treat both as a client error instead of letting
            # the former surface as a 500.
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"JSON decode error: {e}",
            ) from e

        # Keep only the request types whose handler getter returns something
        # for this request.
        valid_endpoints = [
            (validator, endpoint)
            for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS
            if get_handler(raw_request) is not None
        ]

        # First successful validation wins, so list order defines priority.
        for request_validator, endpoint in valid_endpoints:
            try:
                request = request_validator.validate_python(body)
            except pydantic.ValidationError:
                continue

            return await endpoint(request, raw_request)

        # NOTE(review): `_type` is a private attribute of pydantic.TypeAdapter;
        # confirm it is still present when upgrading pydantic.
        type_names = [
            t.__name__ if isinstance(t := validator._type, type) else str(t)
            for validator, _ in valid_endpoints
        ]
        msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
        res = base(raw_request).create_error_response(message=msg)
        return JSONResponse(content=res.model_dump(), status_code=res.error.code)

    return router