173 lines
5.6 KiB
Python
173 lines
5.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import json
|
|
from http import HTTPStatus
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
|
|
from fastapi.responses import JSONResponse
|
|
|
|
import vllm.envs as envs
|
|
from vllm.distributed.weight_transfer.base import (
|
|
WeightTransferInitRequest,
|
|
WeightTransferUpdateRequest,
|
|
)
|
|
from vllm.engine.protocol import EngineClient
|
|
from vllm.logger import init_logger
|
|
from vllm.v1.engine import PauseMode
|
|
|
|
# Module-level logger shared by every endpoint handler in this file.
logger = init_logger(__name__)
|
|
|
|
|
|
def engine_client(request: Request) -> EngineClient:
    """Return the shared ``EngineClient`` stored on the FastAPI app state."""
    client: EngineClient = request.app.state.engine_client
    return client
|
|
|
|
|
|
# Dev-mode control-plane router; mounted onto the app by `attach_router` below.
router = APIRouter()
|
|
|
|
|
|
@router.post("/pause")
async def pause_generation(
    raw_request: Request,
    mode: Annotated[PauseMode, Query()] = "abort",
    wait_for_inflight_requests: bool = Query(False),
    clear_cache: Annotated[bool, Query()] = True,
) -> JSONResponse:
    """Pause generation requests so weights can be updated safely.

    Args:
        mode: How to handle in-flight requests:
            - ``"abort"``: Abort all in-flight requests immediately (default).
            - ``"wait"``: Wait for in-flight requests to complete.
            - ``"keep"``: Freeze requests in queue; they resume on /resume.
        wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
        clear_cache: DEPRECATED. Whether to clear KV/prefix caches after
            draining. Ignored when mode="keep".

    Returns:
        200 with ``{"status": "paused"}`` on success, 400 on invalid
        arguments, 500 on unexpected failures.
    """
    engine = engine_client(raw_request)
    try:
        await engine.pause_generation(
            mode=mode,
            clear_cache=clear_cache,
            wait_for_inflight_requests=wait_for_inflight_requests,
        )
    except ValueError as err:
        # Invalid parameter combinations surface as a client error.
        return JSONResponse(
            content={"error": str(err)},
            status_code=HTTPStatus.BAD_REQUEST.value,
        )
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to pause generation")
        return JSONResponse(
            content={"error": f"Failed to pause generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )
    # Success path: the engine acknowledged the pause.
    return JSONResponse(
        content={"status": "paused"},
        status_code=HTTPStatus.OK.value,
    )
|
|
|
|
|
|
@router.post("/resume")
async def resume_generation(raw_request: Request) -> JSONResponse:
    """Lift a previously issued pause so generation requests flow again."""
    engine = engine_client(raw_request)
    try:
        await engine.resume_generation()
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to resume generation")
        return JSONResponse(
            content={"error": f"Failed to resume generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )
    return JSONResponse(
        content={"status": "resumed"},
        status_code=HTTPStatus.OK.value,
    )
|
|
|
|
|
|
@router.get("/is_paused")
async def is_paused(raw_request: Request) -> JSONResponse:
    """Report whether generation is currently paused."""
    engine = engine_client(raw_request)
    try:
        status = await engine.is_paused()
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to fetch pause status")
        return JSONResponse(
            content={"error": f"Failed to fetch pause status: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )
    else:
        return JSONResponse(content={"is_paused": status})
|
|
|
|
|
|
@router.post("/init_weight_transfer_engine")
async def init_weight_transfer_engine(raw_request: Request):
    """Initialize the weight transfer engine from the JSON request body.

    Expects a JSON object containing an ``init_info`` key, which is
    forwarded to the engine wrapped in a ``WeightTransferInitRequest``.

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not a
            JSON object, or lacks ``init_info``.
    """
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        # Use HTTPStatus for consistency with the rest of this module.
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Invalid JSON format",
        ) from e
    # A syntactically valid JSON body may still be a list/scalar; reject it
    # here instead of letting `body.get` raise AttributeError (a 500).
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Request body must be a JSON object",
        )
    init_info = body.get("init_info")
    if init_info is None:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Missing 'init_info' in request body",
        )
    await engine_client(raw_request).init_weight_transfer_engine(
        WeightTransferInitRequest(init_info=init_info)
    )
    return JSONResponse(content={"message": "Weight transfer initialized"})
|
|
|
|
|
|
@router.post("/update_weights")
async def update_weights(raw_request: Request):
    """Apply a weight update described by the JSON request body.

    Expects a JSON object containing an ``update_info`` key, which is
    forwarded to the engine wrapped in a ``WeightTransferUpdateRequest``.

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not a
            JSON object, or lacks ``update_info``.
    """
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        # Use HTTPStatus for consistency with the rest of this module.
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Invalid JSON format",
        ) from e
    # A syntactically valid JSON body may still be a list/scalar; reject it
    # here instead of letting `body.get` raise AttributeError (a 500).
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Request body must be a JSON object",
        )
    update_info = body.get("update_info")
    if update_info is None:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Missing 'update_info' in request body",
        )
    await engine_client(raw_request).update_weights(
        request=WeightTransferUpdateRequest(update_info=update_info)
    )
    return JSONResponse(content={"message": "Weights updated"})
|
|
|
|
|
|
@router.get("/get_world_size")
async def get_world_size(
    raw_request: Request,
    include_dp: bool = Query(True),
):
    """Get the world size from the parallel config.

    Args:
        include_dp: If True (default), returns the world size including
            data parallelism (TP * PP * DP). If False, returns the world
            size without data parallelism (TP * PP).
    """
    parallel_config = engine_client(raw_request).vllm_config.parallel_config
    world_size = (
        parallel_config.world_size_across_dp
        if include_dp
        else parallel_config.world_size
    )
    return JSONResponse(content={"world_size": world_size})
|
|
|
|
|
|
def attach_router(app: FastAPI):
    """Mount the control-plane router on *app* — dev-mode servers only."""
    if envs.VLLM_SERVER_DEV_MODE:
        app.include_router(router)
|