Sync from v0.13
This commit is contained in:
96
vllm/entrypoints/serve/elastic_ep/api_router.py
Normal file
96
vllm/entrypoints/serve/elastic_ep/api_router.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.api_server import validate_json_request
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ErrorResponse,
|
||||
)
|
||||
from vllm.entrypoints.serve.elastic_ep.middleware import (
|
||||
get_scaling_elastic_ep,
|
||||
set_scaling_elastic_ep,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# Module-level logger created via vLLM's standard init_logger helper.
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
    """Return the EngineClient stored on the FastAPI application state.

    Used as a lightweight accessor (and FastAPI dependency) to reach the
    engine client that the server placed on ``app.state`` at startup.
    """
    app_state = request.app.state
    return app_state.engine_client
|
||||
|
||||
|
||||
# Router collecting the elastic expert-parallel scaling endpoints;
# mounted onto an application with attach_router().
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/scale_elastic_ep",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"model": dict},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def scale_elastic_ep(raw_request: Request):
    """Scale the engine to a new data-parallel size.

    Expects a JSON body with:
        new_data_parallel_size (int, required): target number of data
            parallel engines; must be a positive integer.
        drain_timeout (int, optional): seconds to wait for in-flight
            requests to drain before scaling. Defaults to 120.

    While the scale operation runs, a process-wide "scaling" flag is set
    (via set_scaling_elastic_ep) so the middleware can reject new requests;
    the flag is always cleared in the finally block.

    Raises:
        HTTPException: 400 for malformed JSON or invalid parameters,
            408 if draining times out, 500 for any other scaling failure.
    """
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e  # noqa: B904

    new_data_parallel_size = body.get("new_data_parallel_size")
    drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes

    if new_data_parallel_size is None:
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size is required"
        )

    # NOTE: bool is a subclass of int, so a bare isinstance(x, int) check
    # would let JSON `true` through as 1 — reject booleans explicitly.
    if (
        isinstance(new_data_parallel_size, bool)
        or not isinstance(new_data_parallel_size, int)
        or new_data_parallel_size <= 0
    ):
        raise HTTPException(
            status_code=400,
            detail="new_data_parallel_size must be a positive integer",
        )

    if (
        isinstance(drain_timeout, bool)
        or not isinstance(drain_timeout, int)
        or drain_timeout <= 0
    ):
        raise HTTPException(
            status_code=400, detail="drain_timeout must be a positive integer"
        )

    # Resolve the engine client *before* flipping the scaling flag: if this
    # lookup raised after the flag was set, nothing would clear it and the
    # server would reject requests forever.
    client = engine_client(raw_request)
    # Set scaling flag to prevent new requests
    set_scaling_elastic_ep(True)
    try:
        await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
        return JSONResponse(
            {
                "message": f"Scaled to {new_data_parallel_size} data parallel engines",
            }
        )
    except TimeoutError as e:
        raise HTTPException(
            status_code=408,
            detail="Scale failed due to request drain timeout "
            f"after {drain_timeout} seconds",
        ) from e
    except Exception as e:
        logger.error("Scale failed: %s", e)
        raise HTTPException(status_code=500, detail="Scale failed") from e
    finally:
        # Always clear the flag, whether scaling succeeded or failed.
        set_scaling_elastic_ep(False)
|
||||
|
||||
|
||||
@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
    """Report whether an elastic-EP scaling operation is currently running."""
    currently_scaling = get_scaling_elastic_ep()
    return JSONResponse({"is_scaling_elastic_ep": currently_scaling})
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
    """Mount the elastic-EP scaling endpoints onto the given FastAPI app."""
    app.include_router(router)
|
||||
Reference in New Issue
Block a user