# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
import json
|
|
from http import HTTPStatus
|
|
|
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from vllm.engine.protocol import EngineClient
|
|
from vllm.entrypoints.openai.api_server import validate_json_request
|
|
from vllm.entrypoints.openai.protocol import (
|
|
ErrorResponse,
|
|
)
|
|
from vllm.entrypoints.serve.elastic_ep.middleware import (
|
|
get_scaling_elastic_ep,
|
|
set_scaling_elastic_ep,
|
|
)
|
|
from vllm.logger import init_logger
|
|
|
|
# Module-level logger namespaced to this module via vLLM's logging helper.
logger = init_logger(__name__)
|
|
|
|
|
|
def engine_client(request: Request) -> EngineClient:
    """Return the engine client stashed on the FastAPI application state.

    Routes in this module reach the engine through ``app.state`` rather
    than a module-level global.
    """
    app_state = request.app.state
    return app_state.engine_client
|
|
|
|
|
|
# Router holding the elastic-EP scaling endpoints; mounted onto an app by attach_router().
router = APIRouter()
|
|
|
|
|
|
@router.post(
    "/scale_elastic_ep",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"model": dict},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def scale_elastic_ep(raw_request: Request):
    """Scale the engine to a new data-parallel size.

    Expects a JSON object body::

        {"new_data_parallel_size": <positive int>,
         "drain_timeout": <positive int, seconds, default 120>}

    While the scale operation runs, the module-level scaling flag (see
    ``set_scaling_elastic_ep``) is raised so that new requests can be
    rejected; the ``finally`` block guarantees it is cleared again.

    Raises:
        HTTPException(400): malformed JSON body or invalid parameters.
        HTTPException(408): request drain did not finish within
            ``drain_timeout`` seconds.
        HTTPException(500): any other failure during scaling.
    """
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e

    # The body must be a JSON object: a top-level list/string/number would
    # otherwise make the .get() calls below raise AttributeError (a 500).
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    new_data_parallel_size = body.get("new_data_parallel_size")
    drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes

    if new_data_parallel_size is None:
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size is required"
        )

    # bool is a subclass of int, so JSON `true` would otherwise slip
    # through these checks as 1 -- reject it explicitly.
    if (
        isinstance(new_data_parallel_size, bool)
        or not isinstance(new_data_parallel_size, int)
        or new_data_parallel_size <= 0
    ):
        raise HTTPException(
            status_code=400,
            detail="new_data_parallel_size must be a positive integer",
        )

    if (
        isinstance(drain_timeout, bool)
        or not isinstance(drain_timeout, int)
        or drain_timeout <= 0
    ):
        raise HTTPException(
            status_code=400, detail="drain_timeout must be a positive integer"
        )

    # Set scaling flag to prevent new requests
    set_scaling_elastic_ep(True)
    client = engine_client(raw_request)
    try:
        await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
        return JSONResponse(
            {
                "message": f"Scaled to {new_data_parallel_size} data parallel engines",
            }
        )
    except TimeoutError as e:
        raise HTTPException(
            status_code=408,
            detail="Scale failed due to request drain timeout "
            f"after {drain_timeout} seconds",
        ) from e
    except Exception as e:
        logger.error("Scale failed: %s", e)
        raise HTTPException(status_code=500, detail="Scale failed") from e
    finally:
        # Always clear the flag, even on failure, so the server does not
        # stay wedged in "scaling" mode.
        set_scaling_elastic_ep(False)
|
|
|
|
|
|
@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
    """Report whether an elastic-EP scaling operation is currently in flight."""
    scaling_now = get_scaling_elastic_ep()
    return JSONResponse({"is_scaling_elastic_ep": scaling_now})
|
|
|
|
|
|
def attach_router(app: FastAPI):
    """Mount this module's elastic-EP scaling routes onto *app*."""
    app.include_router(router)
|