Sync from v0.13
This commit is contained in:
4
vllm/entrypoints/sagemaker/__init__.py
Normal file
4
vllm/entrypoints/sagemaker/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""SageMaker-specific integration for vLLM."""
|
||||
118
vllm/entrypoints/sagemaker/routes.py
Normal file
118
vllm/entrypoints/sagemaker/routes.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
base,
|
||||
chat,
|
||||
completion,
|
||||
create_chat_completion,
|
||||
create_completion,
|
||||
validate_json_request,
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
ErrorResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.pooling.classify.api_router import classify, create_classify
|
||||
from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
|
||||
from vllm.entrypoints.pooling.embed.api_router import create_embedding, embedding
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
|
||||
from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
|
||||
from vllm.entrypoints.pooling.score.api_router import (
|
||||
create_score,
|
||||
do_rerank,
|
||||
rerank,
|
||||
score,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
|
||||
from vllm.entrypoints.serve.instrumentator.health import health
|
||||
|
||||
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
|
||||
# (requires typing_extensions >= 4.13)
|
||||
RequestType = Any
|
||||
GetHandlerFn = Callable[[Request], OpenAIServing | None]
|
||||
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
|
||||
|
||||
# NOTE: Items defined earlier take higher priority
|
||||
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
|
||||
(ChatCompletionRequest, (chat, create_chat_completion)),
|
||||
(CompletionRequest, (completion, create_completion)),
|
||||
(EmbeddingRequest, (embedding, create_embedding)),
|
||||
(ClassificationRequest, (classify, create_classify)),
|
||||
(ScoreRequest, (score, create_score)),
|
||||
(RerankRequest, (rerank, do_rerank)),
|
||||
(PoolingRequest, (pooling, create_pooling)),
|
||||
]
|
||||
|
||||
# NOTE: Construct the TypeAdapters only once
|
||||
INVOCATION_VALIDATORS = [
|
||||
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
|
||||
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
|
||||
]
|
||||
|
||||
|
||||
def register_sagemaker_routes(router: APIRouter):
|
||||
@router.post("/ping", response_class=Response)
|
||||
@router.get("/ping", response_class=Response)
|
||||
@sagemaker_standards.register_ping_handler
|
||||
async def ping(raw_request: Request) -> Response:
|
||||
"""Ping check. Endpoint required for SageMaker"""
|
||||
return await health(raw_request)
|
||||
|
||||
@router.post(
|
||||
"/invocations",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@sagemaker_standards.register_invocation_handler
|
||||
@sagemaker_standards.stateful_session_manager()
|
||||
@sagemaker_standards.inject_adapter_id(adapter_path="model")
|
||||
async def invocations(raw_request: Request):
|
||||
"""For SageMaker, routes requests based on the request type."""
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=f"JSON decode error: {e}",
|
||||
) from e
|
||||
|
||||
valid_endpoints = [
|
||||
(validator, endpoint)
|
||||
for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS
|
||||
if get_handler(raw_request) is not None
|
||||
]
|
||||
|
||||
for request_validator, endpoint in valid_endpoints:
|
||||
try:
|
||||
request = request_validator.validate_python(body)
|
||||
except pydantic.ValidationError:
|
||||
continue
|
||||
|
||||
return await endpoint(request, raw_request)
|
||||
|
||||
type_names = [
|
||||
t.__name__ if isinstance(t := validator._type, type) else str(t)
|
||||
for validator, _ in valid_endpoints
|
||||
]
|
||||
msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
|
||||
res = base(raw_request).create_error_response(message=msg)
|
||||
return JSONResponse(content=res.model_dump(), status_code=res.error.code)
|
||||
|
||||
return router
|
||||
Reference in New Issue
Block a user