# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from http import HTTPStatus
from typing import Annotated

from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
from fastapi.responses import JSONResponse

import vllm.envs as envs
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
from vllm.v1.engine import PauseMode

logger = init_logger(__name__)


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on the application state."""
    return request.app.state.engine_client


async def _read_required_field(raw_request: Request, field: str):
    """Parse the JSON request body and return the value stored under ``field``.

    Args:
        raw_request: Incoming request whose body is decoded as JSON.
        field: Name of the top-level key that must be present.

    Returns:
        The (non-null) value found under ``field``.

    Raises:
        HTTPException: With status 400 when the body is not valid JSON, is
            not a JSON object, or has no (or a null) value for ``field``.
    """
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Invalid JSON format",
        ) from e

    # Valid JSON is not necessarily an object: a list, string, number or
    # null has no ``.get`` and would otherwise surface as a 500.
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Request body must be a JSON object",
        )

    value = body.get(field)
    if value is None:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail=f"Missing '{field}' in request body",
        )
    return value


router = APIRouter()


@router.post("/pause")
async def pause_generation(
    raw_request: Request,
    mode: Annotated[PauseMode, Query()] = "abort",
    wait_for_inflight_requests: bool = Query(False),
    clear_cache: Annotated[bool, Query()] = True,
) -> JSONResponse:
    """Pause generation requests to allow weight updates.

    Args:
        mode: How to handle in-flight requests:
            - ``"abort"``: Abort all in-flight requests immediately (default).
            - ``"wait"``: Wait for in-flight requests to complete.
            - ``"keep"``: Freeze requests in queue; they resume on /resume.
        wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
        clear_cache: DEPRECATED. Whether to clear KV/prefix caches after
            draining. Ignored when mode="keep".

    Returns:
        ``{"status": "paused"}`` with 200 on success, ``{"error": ...}`` with
        400 when the engine raises ``ValueError``, or with 500 on any other
        failure.
    """
    engine = engine_client(raw_request)

    try:
        await engine.pause_generation(
            mode=mode,
            clear_cache=clear_cache,
            wait_for_inflight_requests=wait_for_inflight_requests,
        )
        return JSONResponse(
            content={"status": "paused"},
            status_code=HTTPStatus.OK.value,
        )
    except ValueError as err:
        return JSONResponse(
            content={"error": str(err)},
            status_code=HTTPStatus.BAD_REQUEST.value,
        )
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to pause generation")
        return JSONResponse(
            content={"error": f"Failed to pause generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )


@router.post("/resume")
async def resume_generation(raw_request: Request) -> JSONResponse:
    """Resume generation after a pause.

    Returns:
        ``{"status": "resumed"}`` with 200 on success, or ``{"error": ...}``
        with 500 if the engine raises.
    """
    engine = engine_client(raw_request)

    try:
        await engine.resume_generation()
        return JSONResponse(
            content={"status": "resumed"},
            status_code=HTTPStatus.OK.value,
        )
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to resume generation")
        return JSONResponse(
            content={"error": f"Failed to resume generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )


@router.get("/is_paused")
async def is_paused(raw_request: Request) -> JSONResponse:
    """Return the current pause status.

    Returns:
        ``{"is_paused": <value>}`` on success, or ``{"error": ...}`` with 500
        if the engine raises.
    """
    engine = engine_client(raw_request)

    try:
        paused = await engine.is_paused()
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to fetch pause status")
        return JSONResponse(
            content={"error": f"Failed to fetch pause status: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )

    return JSONResponse(content={"is_paused": paused})


@router.post("/init_weight_transfer_engine")
async def init_weight_transfer_engine(raw_request: Request):
    """Initialize the weight transfer engine from the request body.

    The body must be a JSON object with an ``init_info`` entry, which is
    wrapped in a ``WeightTransferInitRequest`` and handed to the engine.

    Raises:
        HTTPException: With status 400 when the body is not valid JSON, is
            not a JSON object, or lacks ``init_info``.
    """
    init_info = await _read_required_field(raw_request, "init_info")
    await engine_client(raw_request).init_weight_transfer_engine(
        WeightTransferInitRequest(init_info=init_info)
    )
    return JSONResponse(content={"message": "Weight transfer initialized"})


@router.post("/update_weights")
async def update_weights(raw_request: Request):
    """Ask the engine to update weights using the request body.

    The body must be a JSON object with an ``update_info`` entry, which is
    wrapped in a ``WeightTransferUpdateRequest`` and handed to the engine.

    Raises:
        HTTPException: With status 400 when the body is not valid JSON, is
            not a JSON object, or lacks ``update_info``.
    """
    update_info = await _read_required_field(raw_request, "update_info")
    await engine_client(raw_request).update_weights(
        request=WeightTransferUpdateRequest(update_info=update_info)
    )
    return JSONResponse(content={"message": "Weights updated"})


@router.get("/get_world_size")
async def get_world_size(
    raw_request: Request,
    include_dp: bool = Query(True),
):
    """Get the world size from the parallel config.

    Args:
        include_dp: If True (default), returns the world size including
            data parallelism (TP * PP * DP). If False, returns the world
            size without data parallelism (TP * PP).
    """
    parallel_config = engine_client(raw_request).vllm_config.parallel_config
    if include_dp:
        world_size = parallel_config.world_size_across_dp
    else:
        world_size = parallel_config.world_size
    return JSONResponse(content={"world_size": world_size})


def attach_router(app: FastAPI):
    """Register these routes on ``app`` when ``VLLM_SERVER_DEV_MODE`` is set.

    Does nothing when ``envs.VLLM_SERVER_DEV_MODE`` is falsy.
    """
    if not envs.VLLM_SERVER_DEV_MODE:
        return
    app.include_router(router)