Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
172
vllm/entrypoints/serve/rlhf/api_router.py
Normal file
172
vllm/entrypoints/serve/rlhf/api_router.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferInitRequest,
|
||||
WeightTransferUpdateRequest,
|
||||
)
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.engine import PauseMode
|
||||
|
||||
# Module-level logger keyed by this module's dotted name (vllm.logger.init_logger).
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
    """Return the engine client stashed on the FastAPI application state."""
    app_state = request.app.state
    return app_state.engine_client
|
||||
|
||||
|
||||
# Router carrying the RLHF control-plane endpoints (pause/resume/weight transfer).
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/pause")
|
||||
async def pause_generation(
|
||||
raw_request: Request,
|
||||
mode: Annotated[PauseMode, Query()] = "abort",
|
||||
wait_for_inflight_requests: bool = Query(False),
|
||||
clear_cache: Annotated[bool, Query()] = True,
|
||||
) -> JSONResponse:
|
||||
"""Pause generation requests to allow weight updates.
|
||||
|
||||
Args:
|
||||
mode: How to handle in-flight requests:
|
||||
- ``"abort"``: Abort all in-flight requests immediately (default).
|
||||
- ``"wait"``: Wait for in-flight requests to complete.
|
||||
- ``"keep"``: Freeze requests in queue; they resume on /resume.
|
||||
wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
|
||||
clear_cache: DEPRECATED. Whether to clear KV/prefix caches after
|
||||
draining. Ignored when mode="keep".
|
||||
"""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.pause_generation(
|
||||
mode=mode,
|
||||
clear_cache=clear_cache,
|
||||
wait_for_inflight_requests=wait_for_inflight_requests,
|
||||
)
|
||||
return JSONResponse(
|
||||
content={"status": "paused"},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
|
||||
except ValueError as err:
|
||||
return JSONResponse(
|
||||
content={"error": str(err)},
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
)
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to pause generation")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to pause generation: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/resume")
|
||||
async def resume_generation(raw_request: Request) -> JSONResponse:
|
||||
"""Resume generation after a pause."""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.resume_generation()
|
||||
return JSONResponse(
|
||||
content={"status": "resumed"},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to resume generation")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to resume generation: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/is_paused")
|
||||
async def is_paused(raw_request: Request) -> JSONResponse:
|
||||
"""Return the current pause status."""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
paused = await engine.is_paused()
|
||||
except Exception as err: # pragma: no cover - defensive
|
||||
logger.exception("Failed to fetch pause status")
|
||||
return JSONResponse(
|
||||
content={"error": f"Failed to fetch pause status: {err}"},
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
)
|
||||
|
||||
return JSONResponse(content={"is_paused": paused})
|
||||
|
||||
|
||||
@router.post("/init_weight_transfer_engine")
|
||||
async def init_weight_transfer_engine(raw_request: Request):
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
|
||||
init_info = body.get("init_info")
|
||||
if init_info is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail="Missing 'init_info' in request body",
|
||||
)
|
||||
await engine_client(raw_request).init_weight_transfer_engine(
|
||||
WeightTransferInitRequest(init_info=init_info)
|
||||
)
|
||||
return JSONResponse(content={"message": "Weight transfer initialized"})
|
||||
|
||||
|
||||
@router.post("/update_weights")
|
||||
async def update_weights(raw_request: Request):
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
|
||||
update_info = body.get("update_info")
|
||||
if update_info is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail="Missing 'update_info' in request body",
|
||||
)
|
||||
await engine_client(raw_request).update_weights(
|
||||
request=WeightTransferUpdateRequest(update_info=update_info)
|
||||
)
|
||||
return JSONResponse(content={"message": "Weights updated"})
|
||||
|
||||
|
||||
@router.get("/get_world_size")
|
||||
async def get_world_size(
|
||||
raw_request: Request,
|
||||
include_dp: bool = Query(True),
|
||||
):
|
||||
"""Get the world size from the parallel config.
|
||||
|
||||
Args:
|
||||
include_dp: If True (default), returns the world size including
|
||||
data parallelism (TP * PP * DP). If False, returns the world
|
||||
size without data parallelism (TP * PP).
|
||||
"""
|
||||
parallel_config = engine_client(raw_request).vllm_config.parallel_config
|
||||
if include_dp:
|
||||
world_size = parallel_config.world_size_across_dp
|
||||
else:
|
||||
world_size = parallel_config.world_size
|
||||
return JSONResponse(content={"world_size": world_size})
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
    """Register the RLHF router on *app* — only when dev-server mode is on."""
    if envs.VLLM_SERVER_DEV_MODE:
        app.include_router(router)
|
||||
Reference in New Issue
Block a user