Add minimal vLLM 0.16.1 build repo for BI-V150
57  vllm/entrypoints/serve/__init__.py  Normal file
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from fastapi import FastAPI

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)


def register_vllm_serve_api_routers(app: FastAPI):
    if envs.VLLM_SERVER_DEV_MODE:
        logger.warning(
            "SECURITY WARNING: Development endpoints are enabled! "
            "This should NOT be used in production!"
        )

    from vllm.entrypoints.serve.lora.api_router import (
        attach_router as attach_lora_router,
    )

    attach_lora_router(app)

    from vllm.entrypoints.serve.profile.api_router import (
        attach_router as attach_profile_router,
    )

    attach_profile_router(app)

    from vllm.entrypoints.serve.sleep.api_router import (
        attach_router as attach_sleep_router,
    )

    attach_sleep_router(app)

    from vllm.entrypoints.serve.rpc.api_router import (
        attach_router as attach_rpc_router,
    )

    attach_rpc_router(app)

    from vllm.entrypoints.serve.cache.api_router import (
        attach_router as attach_cache_router,
    )

    attach_cache_router(app)

    from vllm.entrypoints.serve.tokenize.api_router import (
        attach_router as attach_tokenize_router,
    )

    attach_tokenize_router(app)

    from .instrumentator import register_instrumentator_api_routers

    register_instrumentator_api_routers(app)
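A minimal wiring sketch (not part of the diff) of how these routers might be registered on a FastAPI app. It assumes the surrounding server bootstrap code has already populated app.state (engine_client, args, openai_serving_tokenization, ...) as the handlers below expect.

from fastapi import FastAPI

from vllm.entrypoints.serve import register_vllm_serve_api_routers

app = FastAPI()
# app.state.engine_client, app.state.args, etc. are assumed to be set
# elsewhere by the server bootstrap code before these routes are hit.
register_vllm_serve_api_routers(app)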
0  vllm/entrypoints/serve/cache/__init__.py  vendored  Normal file

72  vllm/entrypoints/serve/cache/api_router.py  vendored  Normal file
@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from fastapi import APIRouter, FastAPI, Query, Request
from fastapi.responses import Response

import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)

router = APIRouter()


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@router.post("/reset_prefix_cache")
async def reset_prefix_cache(
    raw_request: Request,
    reset_running_requests: bool = Query(default=False),
    reset_external: bool = Query(default=False),
):
    """
    Reset the local prefix cache.

    Optionally, if the query parameter `reset_external=true` is set,
    also reset the external (connector-managed) prefix cache.

    Note that we currently do not check if the prefix cache
    is successfully reset in the API server.

    Example:
    POST /reset_prefix_cache?reset_external=true
    """
    logger.info("Resetting prefix cache...")

    await engine_client(raw_request).reset_prefix_cache(
        reset_running_requests, reset_external
    )
    return Response(status_code=200)


@router.post("/reset_mm_cache")
async def reset_mm_cache(raw_request: Request):
    """
    Reset the multi-modal cache. Note that we currently do not check if the
    multi-modal cache is successfully reset in the API server.
    """
    logger.info("Resetting multi-modal cache...")
    await engine_client(raw_request).reset_mm_cache()
    return Response(status_code=200)


@router.post("/reset_encoder_cache")
async def reset_encoder_cache(raw_request: Request):
    """
    Reset the encoder cache. Note that we currently do not check if the
    encoder cache is successfully reset in the API server.
    """
    logger.info("Resetting encoder cache...")
    await engine_client(raw_request).reset_encoder_cache()
    return Response(status_code=200)


def attach_router(app: FastAPI):
    if not envs.VLLM_SERVER_DEV_MODE:
        return
    app.include_router(router)
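A hedged client sketch for the cache endpoints above, assuming VLLM_SERVER_DEV_MODE is enabled and the server listens on http://localhost:8000 (hypothetical host/port); the requests library is used purely for illustration.

import requests

BASE = "http://localhost:8000"  # assumed address
# Reset the local prefix cache and, optionally, the external one.
requests.post(f"{BASE}/reset_prefix_cache", params={"reset_external": "true"})
# Reset the multi-modal and encoder caches.
requests.post(f"{BASE}/reset_mm_cache")
requests.post(f"{BASE}/reset_encoder_cache")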
0  vllm/entrypoints/serve/disagg/__init__.py  Normal file

109  vllm/entrypoints/serve/disagg/api_router.py  Normal file
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import asyncio
import json
from http import HTTPStatus

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
)
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.serve.disagg.protocol import (
    GenerateRequest,
    GenerateResponse,
)
from vllm.entrypoints.serve.disagg.serving import (
    ServingTokens,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
    load_aware_call,
    with_cancellation,
)
from vllm.logger import init_logger

logger = init_logger(__name__)


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization


def generate_tokens(request: Request) -> ServingTokens | None:
    return request.app.state.serving_tokens


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


router = APIRouter()


@router.post(
    "/inference/v1/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
    handler = generate_tokens(raw_request)
    if handler is None:
        return tokenization(raw_request).create_error_response(
            message="The model does not support generate tokens API"
        )
    try:
        generator = await handler.serve_tokens(request, raw_request)
    except Exception as e:
        generator = handler.create_error_response(e)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    elif isinstance(generator, GenerateResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


def attach_router(app: FastAPI):
    if getattr(app.state.args, "tokens_only", False):

        @router.post("/abort_requests")
        async def abort_requests(raw_request: Request):
            """
            Abort one or more requests. To be used in a
            Disaggregated Everything setup.
            """
            try:
                body = await raw_request.json()
            except json.JSONDecodeError as e:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"JSON decode error: {e}",
                ) from e
            request_ids = body.get("request_ids")
            if request_ids is None:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Missing 'request_ids' in request body",
                )
            # Abort requests in background
            asyncio.create_task(engine_client(raw_request).abort(request_ids))
            return Response(status_code=200)

    app.include_router(router)
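A client sketch for the tokens-in/tokens-out endpoint above, assuming a server at http://localhost:8000 with tokens-only serving configured; the sampling_params field names are illustrative assumptions, not confirmed by this diff.

import requests

payload = {
    "token_ids": [1, 2, 3, 4],                # prompt token ids
    "sampling_params": {"max_tokens": 16},    # assumed field name
    "stream": False,
}
resp = requests.post("http://localhost:8000/inference/v1/generate", json=payload)
print(resp.json()["choices"][0]["token_ids"])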
98  vllm/entrypoints/serve/disagg/protocol.py  Normal file
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from pydantic import BaseModel, Field

from vllm.config import ModelConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs
from vllm.entrypoints.openai.engine.protocol import (
    SamplingParams,
    StreamOptions,
)
from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid


####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "throughout the inference process and returned in the response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker from guessing prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bits)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        return TokenizeParams(
            max_total_tokens=None,
            max_output_tokens=0,
        )


class GenerateResponseChoice(BaseModel):
    index: int
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    token_ids: list[int] | None = None


class GenerateResponse(BaseModel):
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "throughout the inference process and returned in the response."
        ),
    )
    choices: list[GenerateResponseChoice]

    prompt_logprobs: list[dict[int, Logprob] | None] | None = None

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
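The cache_salt description above recommends roughly 43 base64 characters (256 bits). A small sketch of producing such a salt with the standard library only:

import secrets

# token_urlsafe(32) encodes 32 random bytes (256 bits) as ~43 URL-safe
# base64 characters, matching the recommendation in the field description.
cache_salt = secrets.token_urlsafe(32)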
286  vllm/entrypoints/serve/disagg/serving.py  Normal file
@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import asyncio
import time
from collections.abc import AsyncGenerator
from collections.abc import Sequence as GenericSequence

from fastapi import Request

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionLogProb,
    ChatCompletionLogProbs,
    ChatCompletionLogProbsContent,
)
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.disagg.protocol import (
    GenerateRequest,
    GenerateResponse,
    GenerateResponseChoice,
)
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils.collection_utils import as_list

logger = init_logger(__name__)


class ServingTokens(OpenAIServing):
    """Provides Tokens IN <> Tokens OUT functionality to vLLM API."""

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        force_no_detokenize: bool = False,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_log_outputs: bool = False,
    ):
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_log_outputs = enable_log_outputs
        self.force_no_detokenize = force_no_detokenize
        if force_no_detokenize:
            logger.info(
                "Tokens-only mode is enabled, skipping detokenization "
                "step for incoming requests."
            )

    async def serve_tokens(
        self,
        request: GenerateRequest,
        raw_request: Request | None = None,
    ) -> GenerateResponse | ErrorResponse:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        lora_request = None
        lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True)

        model_name = self.models.model_name(lora_request)

        request_id = (
            f"generate-tokens-{self._base_request_id(raw_request, request.request_id)}"
        )

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        engine_prompts = await self._preprocess_completion(
            request,
            prompt_input=request.token_ids,
            prompt_embeds=None,
        )
        assert len(engine_prompts) == 1
        engine_prompt = engine_prompts[0]

        # Schedule the request and get the result generator.
        result_generator: AsyncGenerator[RequestOutput, None] | None = None
        try:
            sampling_params = request.sampling_params
            if self.force_no_detokenize:
                sampling_params.detokenize = False

            self._log_inputs(
                request_id,
                engine_prompt,
                params=sampling_params,
                lora_request=lora_request,
            )

            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            result_generator = self.engine_client.generate(
                engine_prompt,
                sampling_params,
                request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )

        except ValueError as e:
            return self.create_error_response(str(e))

        # TODO(NickLucche): Implement streaming response

        try:
            assert result_generator is not None
            return await self.serve_tokens_full_generator(
                request, result_generator, request_id, model_name, request_metadata
            )
        except ValueError as e:
            return self.create_error_response(str(e))

    async def serve_tokens_full_generator(
        self,
        request: GenerateRequest,
        result_generator: AsyncGenerator[RequestOutput, None],
        request_id: str,
        model_name: str,
        request_metadata: RequestResponseMetadata,
    ) -> ErrorResponse | GenerateResponse:
        created_time = int(time.time())
        final_res: RequestOutput | None = None
        sampling_params: SamplingParams = request.sampling_params

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(str(e))

        assert final_res is not None

        choices: list[GenerateResponseChoice] = []
        num_generated_tokens = 0
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            # This is top_logprobs in completions API
            if sampling_params.logprobs is not None:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_tokens_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=sampling_params.logprobs,
                )
            else:
                logprobs = None

            choice_data = GenerateResponseChoice(
                index=output.index,
                logprobs=logprobs,
                finish_reason=output.finish_reason if output.finish_reason else "stop",
                token_ids=as_list(output.token_ids),
            )

            choices.append(choice_data)
            num_generated_tokens += len(output.token_ids)

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            # This info is not available at the /coordinator level
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        response = GenerateResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
            kv_transfer_params=final_res.kv_transfer_params,
        )

        # Log complete response if output logging is enabled
        if self.enable_log_outputs and self.request_logger:
            for choice in choices:
                # Get the corresponding output token IDs
                output_token_ids = None
                if choice.index < len(final_res.outputs):
                    output_token_ids = final_res.outputs[choice.index].token_ids

                if output_token_ids:
                    # Log token_ids only.
                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs="",
                        output_token_ids=output_token_ids,
                        finish_reason=choice.finish_reason,
                        is_streaming=False,
                        delta=False,
                    )

        return response

    def _create_tokens_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int | None = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs."""
        logprobs_content: list[ChatCompletionLogProbsContent] = []

        for i, token_id in enumerate(token_ids):
            token = f"token_id:{token_id}"
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None or step_top_logprobs.get(token_id) is None:
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                    )
                )
            else:
                step_token = step_top_logprobs[token_id]

                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        logprob=max(step_token.logprob, -9999.0),
                        top_logprobs=[
                            ChatCompletionLogProb(
                                token=token,
                                logprob=max(p[1].logprob, -9999.0),
                            )
                            for i, p in enumerate(step_top_logprobs.items())
                            if num_output_top_logprobs is not None
                            and i < max(num_output_top_logprobs, 1)
                        ],
                    )
                )

        return ChatCompletionLogProbs(content=logprobs_content)
0  vllm/entrypoints/serve/elastic_ep/__init__.py  Normal file

96  vllm/entrypoints/serve/elastic_ep/api_router.py  Normal file
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import json
from http import HTTPStatus

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
)
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.serve.elastic_ep.middleware import (
    get_scaling_elastic_ep,
    set_scaling_elastic_ep,
)
from vllm.logger import init_logger

logger = init_logger(__name__)


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


router = APIRouter()


@router.post(
    "/scale_elastic_ep",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"model": dict},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def scale_elastic_ep(raw_request: Request):
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e

    new_data_parallel_size = body.get("new_data_parallel_size")
    drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes

    if new_data_parallel_size is None:
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size is required"
        )

    if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0:
        raise HTTPException(
            status_code=400,
            detail="new_data_parallel_size must be a positive integer",
        )

    if not isinstance(drain_timeout, int) or drain_timeout <= 0:
        raise HTTPException(
            status_code=400, detail="drain_timeout must be a positive integer"
        )

    # Set scaling flag to prevent new requests
    set_scaling_elastic_ep(True)
    client = engine_client(raw_request)
    try:
        await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
        return JSONResponse(
            {
                "message": f"Scaled to {new_data_parallel_size} data parallel engines",
            }
        )
    except TimeoutError as e:
        raise HTTPException(
            status_code=408,
            detail="Scale failed due to request drain timeout "
            f"after {drain_timeout} seconds",
        ) from e
    except Exception as e:
        logger.error("Scale failed: %s", e)
        raise HTTPException(status_code=500, detail="Scale failed") from e
    finally:
        set_scaling_elastic_ep(False)


@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
    return JSONResponse({"is_scaling_elastic_ep": get_scaling_elastic_ep()})


def attach_router(app: FastAPI):
    app.include_router(router)
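A client sketch for elastic scaling, assuming http://localhost:8000; the body keys match what scale_elastic_ep reads from the request, the values are placeholders.

import requests

resp = requests.post(
    "http://localhost:8000/scale_elastic_ep",
    json={"new_data_parallel_size": 4, "drain_timeout": 120},
)
print(resp.json())
# Poll the scaling flag (note this route is registered as a POST above).
print(requests.post("http://localhost:8000/is_scaling_elastic_ep").json())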
49  vllm/entrypoints/serve/elastic_ep/middleware.py  Normal file
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Awaitable

from fastapi.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send

# Global variable to track scaling state
_scaling_elastic_ep = False


def get_scaling_elastic_ep():
    return _scaling_elastic_ep


def set_scaling_elastic_ep(value):
    global _scaling_elastic_ep
    _scaling_elastic_ep = value


class ScalingMiddleware:
    """
    Middleware that checks if the model is currently scaling and
    returns a 503 Service Unavailable response if it is.

    This middleware applies to all HTTP requests and prevents
    processing when the model is in a scaling state.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] != "http":
            return self.app(scope, receive, send)

        # Check global scaling state
        if get_scaling_elastic_ep():
            # Return 503 Service Unavailable response
            response = JSONResponse(
                content={
                    "error": "The model is currently scaling. Please try again later."
                },
                status_code=503,
            )
            return response(scope, receive, send)

        return self.app(scope, receive, send)
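A sketch of installing the middleware on an app (assumed wiring, not shown in this diff): ScalingMiddleware follows the plain ASGI constructor convention, so FastAPI's add_middleware can wrap it directly.

from fastapi import FastAPI

from vllm.entrypoints.serve.elastic_ep.middleware import ScalingMiddleware

app = FastAPI()
# Rejects HTTP requests with a 503 JSON body while a scale operation is running.
app.add_middleware(ScalingMiddleware)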
29  vllm/entrypoints/serve/instrumentator/__init__.py  Normal file
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from fastapi import FastAPI

from vllm import envs


def register_instrumentator_api_routers(app: FastAPI):
    from .basic import router as basic_router

    app.include_router(basic_router)

    from .health import router as health_router

    app.include_router(health_router)

    from .metrics import attach_router as metrics_attach_router

    metrics_attach_router(app)

    from .offline_docs import attach_router as offline_docs_attach_router

    offline_docs_attach_router(app)

    if envs.VLLM_SERVER_DEV_MODE:
        from .server_info import router as server_info_router

        app.include_router(server_info_router)
57  vllm/entrypoints/serve/instrumentator/basic.py  Normal file
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.logger import init_logger
from vllm.version import __version__ as VLLM_VERSION

router = APIRouter()

logger = init_logger(__name__)


def base(request: Request) -> OpenAIServing:
    # Reuse the existing instance
    return tokenization(request)


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@router.get("/load")
async def get_server_load_metrics(request: Request):
    # This endpoint returns the current server load metrics.
    # It tracks requests utilizing the GPU from the following routes:
    # - /v1/responses
    # - /v1/responses/{response_id}
    # - /v1/responses/{response_id}/cancel
    # - /v1/messages
    # - /v1/chat/completions
    # - /v1/completions
    # - /v1/audio/transcriptions
    # - /v1/audio/translations
    # - /v1/embeddings
    # - /pooling
    # - /classify
    # - /score
    # - /v1/score
    # - /rerank
    # - /v1/rerank
    # - /v2/rerank
    return JSONResponse(content={"server_load": request.app.state.server_load_metrics})


@router.get("/version")
async def show_version():
    ver = {"version": VLLM_VERSION}
    return JSONResponse(content=ver)
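A client sketch for the basic instrumentation routes above, assuming http://localhost:8000 and that the server tracks server_load_metrics on app.state.

import requests

print(requests.get("http://localhost:8000/version").json())  # {"version": ...}
print(requests.get("http://localhost:8000/load").json())      # {"server_load": ...}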
29  vllm/entrypoints/serve/instrumentator/health.py  Normal file
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from fastapi import APIRouter, Request
from fastapi.responses import Response

from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
from vllm.v1.engine.exceptions import EngineDeadError

logger = init_logger(__name__)


router = APIRouter()


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check."""
    try:
        await engine_client(raw_request).check_health()
        return Response(status_code=200)
    except EngineDeadError:
        return Response(status_code=503)
45  vllm/entrypoints/serve/instrumentator/metrics.py  Normal file
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import prometheus_client
import regex as re
from fastapi import FastAPI, Response
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.routing import Mount

from vllm.v1.metrics.prometheus import get_prometheus_registry


class PrometheusResponse(Response):
    media_type = prometheus_client.CONTENT_TYPE_LATEST


def attach_router(app: FastAPI):
    """Mount prometheus metrics to a FastAPI app."""

    registry = get_prometheus_registry()

    # `response_class=PrometheusResponse` is needed to return an HTTP response
    # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
    # instead of the default "application/json" which is incorrect.
    # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
    Instrumentator(
        excluded_handlers=[
            "/metrics",
            "/health",
            "/load",
            "/ping",
            "/version",
            "/server_info",
        ],
        registry=registry,
    ).add().instrument(app).expose(app, response_class=PrometheusResponse)

    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
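A scrape sketch for the Prometheus endpoint mounted above, assuming http://localhost:8000; per the comment in attach_router, the response should carry the text exposition content type rather than JSON.

import requests

resp = requests.get("http://localhost:8000/metrics")
print(resp.headers.get("Content-Type"))  # text/plain; version=0.0.4; charset=utf-8
print(resp.text.splitlines()[:5])        # first few metric lines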
50  vllm/entrypoints/serve/instrumentator/offline_docs.py  Normal file
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Offline FastAPI documentation support for air-gapped environments."""

import pathlib

from fastapi import FastAPI
from fastapi.openapi.docs import (
    get_swagger_ui_html,
    get_swagger_ui_oauth2_redirect_html,
)
from fastapi.staticfiles import StaticFiles

from vllm.logger import init_logger

logger = init_logger(__name__)


def attach_router(app: FastAPI) -> None:
    """Attach offline docs router if enabled via args."""
    args = getattr(app.state, "args", None)
    if args is None or not getattr(args, "enable_offline_docs", False):
        return

    static_dir = pathlib.Path(__file__).parent / "static"

    if not static_dir.exists():
        logger.warning(
            "Static directory not found at %s. Offline docs will not be available.",
            static_dir,
        )
        return

    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")

    @app.get("/docs", include_in_schema=False)
    async def custom_swagger_ui_html():
        return get_swagger_ui_html(
            openapi_url=app.openapi_url,
            title=app.title + " - Swagger UI",
            oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
            swagger_js_url="/static/swagger-ui-bundle.js",
            swagger_css_url="/static/swagger-ui.css",
        )

    @app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False)
    async def swagger_ui_redirect():
        return get_swagger_ui_oauth2_redirect_html()

    logger.info("Offline documentation enabled with vendored static assets")
59  vllm/entrypoints/serve/instrumentator/server_info.py  Normal file
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import asyncio
import functools
from typing import Annotated, Literal

import pydantic
from fastapi import APIRouter, Query, Request
from fastapi.responses import JSONResponse

import vllm.envs as envs
from vllm.collect_env import get_env_info
from vllm.config import VllmConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


router = APIRouter()
PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)


def _get_vllm_env_vars():
    from vllm.config.utils import normalize_value

    vllm_envs = {}
    for key in dir(envs):
        if key.startswith("VLLM_") and "KEY" not in key:
            value = getattr(envs, key, None)
            if value is not None:
                value = normalize_value(value)
            vllm_envs[key] = value
    return vllm_envs


@functools.lru_cache(maxsize=1)
def _get_system_env_info_cached():
    return get_env_info()._asdict()


@router.get("/server_info")
async def show_server_info(
    raw_request: Request,
    config_format: Annotated[Literal["text", "json"], Query()] = "text",
):
    vllm_config: VllmConfig = raw_request.app.state.vllm_config
    server_info = {
        "vllm_config": (
            str(vllm_config)
            if config_format == "text"
            else PydanticVllmConfig.dump_python(vllm_config, mode="json", fallback=str)
        ),
        # fallback=str is needed to handle e.g. torch.dtype
        "vllm_env": _get_vllm_env_vars(),
        "system_env": await asyncio.to_thread(_get_system_env_info_cached),
    }
    return JSONResponse(content=server_info)
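A client sketch for /server_info (registered only in dev mode), assuming http://localhost:8000; config_format selects the text or JSON rendering of the vLLM config.

import requests

info = requests.get(
    "http://localhost:8000/server_info", params={"config_format": "json"}
).json()
print(sorted(info.keys()))  # ['system_env', 'vllm_config', 'vllm_env']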
File diff suppressed because one or more lines are too long
0  vllm/entrypoints/serve/lora/__init__.py  Normal file

74  vllm/entrypoints/serve/lora/api_router.py  Normal file
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import model_hosting_container_standards.sagemaker as sagemaker_standards
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, Response

from vllm import envs
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
)
from vllm.entrypoints.openai.models.api_router import models
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.serve.lora.protocol import (
    LoadLoRAAdapterRequest,
    UnloadLoRAAdapterRequest,
)
from vllm.logger import init_logger

logger = init_logger(__name__)
router = APIRouter()


def attach_router(app: FastAPI):
    if not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
        # If LoRA dynamic loading & unloading is not enabled, do nothing.
        return
    logger.warning(
        "LoRA dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!"
    )

    @sagemaker_standards.register_load_adapter_handler(
        request_shape={
            "lora_name": "body.name",
            "lora_path": "body.src",
            "load_inplace": "body.load_inplace || `false`",
        },
    )
    @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
    async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
        handler: OpenAIServingModels = models(raw_request)
        response = await handler.load_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(
                content=response.model_dump(), status_code=response.error.code
            )

        return Response(status_code=200, content=response)

    @sagemaker_standards.register_unload_adapter_handler(
        request_shape={
            "lora_name": "path_params.adapter_name",
        }
    )
    @router.post(
        "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]
    )
    async def unload_lora_adapter(
        request: UnloadLoRAAdapterRequest, raw_request: Request
    ):
        handler: OpenAIServingModels = models(raw_request)
        response = await handler.unload_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(
                content=response.model_dump(), status_code=response.error.code
            )

        return Response(status_code=200, content=response)

    # register the router
    app.include_router(router)
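A client sketch for runtime LoRA management, assuming VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled and a server at http://localhost:8000; the adapter name and path are placeholders.

import requests

BASE = "http://localhost:8000"
requests.post(
    f"{BASE}/v1/load_lora_adapter",
    json={"lora_name": "my_adapter", "lora_path": "/path/to/adapter"},
)
requests.post(f"{BASE}/v1/unload_lora_adapter", json={"lora_name": "my_adapter"})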
15  vllm/entrypoints/serve/lora/protocol.py  Normal file
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from pydantic import BaseModel, Field


class LoadLoRAAdapterRequest(BaseModel):
    lora_name: str
    lora_path: str
    load_inplace: bool = False


class UnloadLoRAAdapterRequest(BaseModel):
    lora_name: str
    lora_int_id: int | None = Field(default=None)
0  vllm/entrypoints/serve/profile/__init__.py  Normal file

46  vllm/entrypoints/serve/profile/api_router.py  Normal file
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import Response

from vllm.config import ProfilerConfig
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)

router = APIRouter()


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@router.post("/start_profile")
async def start_profile(raw_request: Request):
    logger.info("Starting profiler...")
    await engine_client(raw_request).start_profile()
    logger.info("Profiler started.")
    return Response(status_code=200)


@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
    logger.info("Stopping profiler...")
    await engine_client(raw_request).stop_profile()
    logger.info("Profiler stopped.")
    return Response(status_code=200)


def attach_router(app: FastAPI):
    profiler_config = getattr(app.state.args, "profiler_config", None)
    assert profiler_config is None or isinstance(profiler_config, ProfilerConfig)
    if profiler_config is not None and profiler_config.profiler is not None:
        logger.warning_once(
            "Profiler with mode '%s' is enabled in the "
            "API server. This should ONLY be used for local development!",
            profiler_config.profiler,
        )
        app.include_router(router)
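A client sketch for profiling a workload, assuming a profiler is configured on the server at http://localhost:8000; the traffic to capture between the two calls is up to you.

import requests

BASE = "http://localhost:8000"
requests.post(f"{BASE}/start_profile")
# ... send the requests you want captured by the profiler here ...
requests.post(f"{BASE}/stop_profile")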
0  vllm/entrypoints/serve/rlhf/__init__.py  Normal file

172  vllm/entrypoints/serve/rlhf/api_router.py  Normal file
@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from http import HTTPStatus
from typing import Annotated

from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
from fastapi.responses import JSONResponse

import vllm.envs as envs
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
from vllm.v1.engine import PauseMode

logger = init_logger(__name__)


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


router = APIRouter()


@router.post("/pause")
async def pause_generation(
    raw_request: Request,
    mode: Annotated[PauseMode, Query()] = "abort",
    wait_for_inflight_requests: bool = Query(False),
    clear_cache: Annotated[bool, Query()] = True,
) -> JSONResponse:
    """Pause generation requests to allow weight updates.

    Args:
        mode: How to handle in-flight requests:
            - ``"abort"``: Abort all in-flight requests immediately (default).
            - ``"wait"``: Wait for in-flight requests to complete.
            - ``"keep"``: Freeze requests in queue; they resume on /resume.
        wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
        clear_cache: DEPRECATED. Whether to clear KV/prefix caches after
            draining. Ignored when mode="keep".
    """

    engine = engine_client(raw_request)

    try:
        await engine.pause_generation(
            mode=mode,
            clear_cache=clear_cache,
            wait_for_inflight_requests=wait_for_inflight_requests,
        )
        return JSONResponse(
            content={"status": "paused"},
            status_code=HTTPStatus.OK.value,
        )

    except ValueError as err:
        return JSONResponse(
            content={"error": str(err)},
            status_code=HTTPStatus.BAD_REQUEST.value,
        )
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to pause generation")
        return JSONResponse(
            content={"error": f"Failed to pause generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )


@router.post("/resume")
async def resume_generation(raw_request: Request) -> JSONResponse:
    """Resume generation after a pause."""

    engine = engine_client(raw_request)

    try:
        await engine.resume_generation()
        return JSONResponse(
            content={"status": "resumed"},
            status_code=HTTPStatus.OK.value,
        )
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to resume generation")
        return JSONResponse(
            content={"error": f"Failed to resume generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )


@router.get("/is_paused")
async def is_paused(raw_request: Request) -> JSONResponse:
    """Return the current pause status."""

    engine = engine_client(raw_request)

    try:
        paused = await engine.is_paused()
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to fetch pause status")
        return JSONResponse(
            content={"error": f"Failed to fetch pause status: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )

    return JSONResponse(content={"is_paused": paused})


@router.post("/init_weight_transfer_engine")
async def init_weight_transfer_engine(raw_request: Request):
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e  # noqa: B904
    init_info = body.get("init_info")
    if init_info is None:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Missing 'init_info' in request body",
        )
    await engine_client(raw_request).init_weight_transfer_engine(
        WeightTransferInitRequest(init_info=init_info)
    )
    return JSONResponse(content={"message": "Weight transfer initialized"})


@router.post("/update_weights")
async def update_weights(raw_request: Request):
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e  # noqa: B904
    update_info = body.get("update_info")
    if update_info is None:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Missing 'update_info' in request body",
        )
    await engine_client(raw_request).update_weights(
        request=WeightTransferUpdateRequest(update_info=update_info)
    )
    return JSONResponse(content={"message": "Weights updated"})


@router.get("/get_world_size")
async def get_world_size(
    raw_request: Request,
    include_dp: bool = Query(True),
):
    """Get the world size from the parallel config.

    Args:
        include_dp: If True (default), returns the world size including
            data parallelism (TP * PP * DP). If False, returns the world
            size without data parallelism (TP * PP).
    """
    parallel_config = engine_client(raw_request).vllm_config.parallel_config
    if include_dp:
        world_size = parallel_config.world_size_across_dp
    else:
        world_size = parallel_config.world_size
    return JSONResponse(content={"world_size": world_size})


def attach_router(app: FastAPI):
    if not envs.VLLM_SERVER_DEV_MODE:
        return
    app.include_router(router)
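A client sketch of the pause/update/resume flow for weight updates, assuming VLLM_SERVER_DEV_MODE and http://localhost:8000; the contents of update_info are backend-specific and only illustrated here.

import requests

BASE = "http://localhost:8000"
requests.post(f"{BASE}/pause", params={"mode": "wait"})
requests.post(f"{BASE}/update_weights", json={"update_info": {"step": 123}})  # assumed payload
requests.post(f"{BASE}/resume")
print(requests.get(f"{BASE}/is_paused").json())  # {"is_paused": false}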
0  vllm/entrypoints/serve/rpc/__init__.py  Normal file

61  vllm/entrypoints/serve/rpc/api_router.py  Normal file
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from http import HTTPStatus
from typing import Any

from fastapi import APIRouter, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response

import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)

router = APIRouter()


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@router.post("/collective_rpc")
async def collective_rpc(raw_request: Request):
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail=f"JSON decode error: {e}",
        ) from e
    method = body.get("method")
    if method is None:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Missing 'method' in request body",
        )
    # For security reasons, only serialized string args/kwargs are passed.
    # User-defined `method` is responsible for deserialization if needed.
    args: list[str] = body.get("args", [])
    kwargs: dict[str, str] = body.get("kwargs", {})
    timeout: float | None = body.get("timeout")
    results = await engine_client(raw_request).collective_rpc(
        method=method, timeout=timeout, args=tuple(args), kwargs=kwargs
    )
    if results is None:
        return Response(status_code=200)
    response: list[Any] = []
    for result in results:
        if result is None or isinstance(result, dict | list):
            response.append(result)
        else:
            response.append(str(result))
    return JSONResponse(content={"results": response})


def attach_router(app: FastAPI):
    if not envs.VLLM_SERVER_DEV_MODE:
        return
    app.include_router(router)
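A client sketch for /collective_rpc (dev mode only), assuming http://localhost:8000; the method name must be something the workers actually expose, so the one used here is purely hypothetical.

import requests

resp = requests.post(
    "http://localhost:8000/collective_rpc",
    json={"method": "report_device_id", "args": [], "kwargs": {}},  # hypothetical method
)
print(resp.json())  # {"results": [...]}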
0  vllm/entrypoints/serve/sleep/__init__.py  Normal file

57  vllm/entrypoints/serve/sleep/api_router.py  Normal file
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse, Response

import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


router = APIRouter()


@router.post("/sleep")
async def sleep(raw_request: Request):
    # get POST params
    level = raw_request.query_params.get("level", "1")
    mode = raw_request.query_params.get("mode", "abort")
    await engine_client(raw_request).sleep(int(level), mode)
    # FIXME: in v0 with frontend multiprocessing, the sleep command
    # is sent but does not finish yet when we return a response.
    return Response(status_code=200)


@router.post("/wake_up")
async def wake_up(raw_request: Request):
    tags = raw_request.query_params.getlist("tags")
    if tags == []:
        # set to None to wake up all tags if no tags are provided
        tags = None
    logger.info("wake up the engine with tags: %s", tags)
    await engine_client(raw_request).wake_up(tags)
    # FIXME: in v0 with frontend multiprocessing, the wake-up command
    # is sent but does not finish yet when we return a response.
    return Response(status_code=200)


@router.get("/is_sleeping")
async def is_sleeping(raw_request: Request):
    logger.info("check whether the engine is sleeping")
    is_sleeping = await engine_client(raw_request).is_sleeping()
    return JSONResponse(content={"is_sleeping": is_sleeping})


def attach_router(app: FastAPI):
    if not envs.VLLM_SERVER_DEV_MODE:
        return

    app.include_router(router)
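A client sketch for sleep/wake-up (dev mode only), assuming http://localhost:8000; level and mode are passed as query parameters exactly as the handler reads them.

import requests

BASE = "http://localhost:8000"
requests.post(f"{BASE}/sleep", params={"level": "1", "mode": "abort"})
print(requests.get(f"{BASE}/is_sleeping").json())  # {"is_sleeping": true}
requests.post(f"{BASE}/wake_up")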
0
vllm/entrypoints/serve/tokenize/__init__.py
Normal file
0
vllm/entrypoints/serve/tokenize/__init__.py
Normal file
114
vllm/entrypoints/serve/tokenize/api_router.py
Normal file
114
vllm/entrypoints/serve/tokenize/api_router.py
Normal file
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from http import HTTPStatus

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
)
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    DetokenizeResponse,
    TokenizeRequest,
    TokenizeResponse,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
    with_cancellation,
)
from vllm.logger import init_logger

logger = init_logger(__name__)


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization


router = APIRouter()


@router.post(
    "/tokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)

    try:
        generator = await handler.create_tokenize(request, raw_request)
    except Exception as e:
        generator = handler.create_error_response(e)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, TokenizeResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.post(
    "/detokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)

    try:
        generator = await handler.create_detokenize(request, raw_request)
    except OverflowError as e:
        raise RequestValidationError(errors=[str(e)]) from e
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, DetokenizeResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


def attach_router(app: FastAPI):
    if getattr(app.state.args, "enable_tokenizer_info_endpoint", False):
        # Conditionally register the tokenizer info endpoint if enabled.

        @router.get("/tokenizer_info")
        async def get_tokenizer_info(raw_request: Request):
            """Get comprehensive tokenizer information."""
            result = await tokenization(raw_request).get_tokenizer_info()
            return JSONResponse(
                content=result.model_dump(),
                status_code=result.error.code
                if isinstance(result, ErrorResponse)
                else 200,
            )

    app.include_router(router)
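The /tokenize and /detokenize routes accept the request schemas defined in protocol.py below and return plain JSON. A minimal round-trip sketch (illustrative only, not part of this commit), assuming a server on localhost:8000 with a model already loaded:

import requests  # third-party HTTP client, assumed available

BASE = "http://localhost:8000"  # placeholder host/port

# Completion-style tokenization (TokenizeCompletionRequest fields).
resp = requests.post(
    f"{BASE}/tokenize",
    json={
        "prompt": "Hello, world!",
        "add_special_tokens": True,
        "return_token_strs": True,
    },
)
body = resp.json()
print(body["count"], body["tokens"], body.get("token_strs"))

# Feed the ids back through /detokenize (DetokenizeRequest fields).
resp = requests.post(f"{BASE}/detokenize", json={"tokens": body["tokens"]})
print(resp.json()["prompt"])

If the server was started with the tokenizer-info endpoint enabled (the attach_router above checks an enable_tokenizer_info_endpoint attribute on app.state.args; the corresponding CLI flag name is assumed), GET /tokenizer_info returns the tokenizer configuration as JSON.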
183
vllm/entrypoints/serve/tokenize/protocol.py
Normal file
@@ -0,0 +1,183 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from typing import Annotated, Any, TypeAlias

from pydantic import ConfigDict, Field, model_validator

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
    OpenAIBaseModel,
)
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str | None = None
    prompt: str

    add_special_tokens: bool = Field(
        default=True,
        description=(
            "If true (the default), special tokens (e.g. BOS) will be added to "
            "the prompt."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        return TokenizeParams(
            max_total_tokens=None,
            max_output_tokens=0,
            add_special_tokens=self.add_special_tokens,
        )


class TokenizeChatRequest(OpenAIBaseModel):
    model: str | None = None
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    return_token_strs: bool | None = Field(
        default=False,
        description=(
            "If true, also return the token strings corresponding to the token ids."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )
    mm_processor_kwargs: dict[str, Any] | None = Field(
        default=None,
        description="Additional kwargs to pass to the HF processor.",
    )
    tools: list[ChatCompletionToolsParam] | None = Field(
        default=None,
        description="A list of tools the model may call.",
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

    def build_chat_params(
        self,
        default_template: str | None,
        default_template_content_format: ChatTemplateContentFormatOption,
    ) -> ChatParams:
        return ChatParams(
            chat_template=self.chat_template or default_template,
            chat_template_content_format=default_template_content_format,
            chat_template_kwargs=merge_kwargs(
                self.chat_template_kwargs,
                dict(
                    add_generation_prompt=self.add_generation_prompt,
                    continue_final_message=self.continue_final_message,
                ),
            ),
        )

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        return TokenizeParams(
            max_total_tokens=None,
            max_output_tokens=0,
            add_special_tokens=self.add_special_tokens,
        )


TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest


class TokenizeResponse(OpenAIBaseModel):
    count: int
    max_model_len: int
    tokens: list[int]
    token_strs: list[str] | None = None


class DetokenizeRequest(OpenAIBaseModel):
    model: str | None = None
    # TODO: Factor `torch.iinfo` out. `torch.iinfo` pulls torch into a
    # Pydantic protocol file that currently has no torch dependency.
    # See: https://github.com/vllm-project/vllm/pull/34468#discussion_r2801173630
    tokens: list[Annotated[int, Field(ge=0, le=2**63 - 1)]]

    def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
        return TokenizeParams(
            max_total_tokens=None,
            max_output_tokens=0,
            needs_detokenization=True,
        )


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


class TokenizerInfoResponse(OpenAIBaseModel):
    """
    Response containing tokenizer configuration
    equivalent to tokenizer_config.json
    """

    model_config = ConfigDict(extra="allow")
    tokenizer_class: str
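The before-mode validator on TokenizeChatRequest makes add_generation_prompt and continue_final_message mutually exclusive. A small sketch of that behavior (assumes this commit's package layout is importable; the message dict is a minimal illustrative payload):

from pydantic import ValidationError

from vllm.entrypoints.serve.tokenize.protocol import TokenizeChatRequest

messages = [{"role": "user", "content": "Hi"}]

# Default case: add_generation_prompt=True, continue_final_message=False.
req = TokenizeChatRequest(messages=messages)

# Asking for both at once is rejected by check_generation_prompt.
try:
    TokenizeChatRequest(
        messages=messages,
        add_generation_prompt=True,
        continue_final_message=True,
    )
except ValidationError as e:
    # Pydantic wraps the ValueError raised by the validator.
    print(e)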
195
vllm/entrypoints/serve/tokenize/serving.py
Normal file
@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, Final

import jinja2
from fastapi import Request

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.tokenize.protocol import (
    DetokenizeRequest,
    DetokenizeResponse,
    TokenizeChatRequest,
    TokenizeRequest,
    TokenizeResponse,
    TokenizerInfoResponse,
)
from vllm.inputs import TokensPrompt, token_inputs
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike

logger = init_logger(__name__)


class OpenAIServingTokenization(OpenAIServing):
    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        self.trust_request_chat_template = trust_request_chat_template

    async def create_tokenize(
        self,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> TokenizeResponse | ErrorResponse:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokenize-{self._base_request_id(raw_request)}"

        try:
            lora_request = self._maybe_get_adapters(request)

            if isinstance(request, TokenizeChatRequest):
                tool_dicts = (
                    None
                    if request.tools is None
                    else [tool.model_dump() for tool in request.tools]
                )
                error_check_ret = self._validate_chat_template(
                    request_chat_template=request.chat_template,
                    chat_template_kwargs=request.chat_template_kwargs,
                    trust_request_chat_template=self.trust_request_chat_template,
                )
                if error_check_ret is not None:
                    return error_check_ret

                _, engine_prompts = await self._preprocess_chat(
                    request,
                    request.messages,
                    default_template=self.chat_template,
                    default_template_content_format=self.chat_template_content_format,
                    default_template_kwargs=None,
                    tool_dicts=tool_dicts,
                )
            else:
                engine_prompts = await self._preprocess_completion(
                    request,
                    prompt_input=request.prompt,
                    prompt_embeds=None,
                )
        except (ValueError, TypeError, jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(f"{e} {e.__cause__}")

        input_ids: list[int] = []
        for engine_prompt in engine_prompts:
            self._log_inputs(
                request_id,
                engine_prompt,
                params=None,
                lora_request=lora_request,
            )

            if "prompt_token_ids" in engine_prompt:
                input_ids.extend(engine_prompt["prompt_token_ids"])  # type: ignore[typeddict-item]

        token_strs = None
        if request.return_token_strs:
            tokenizer = self.renderer.get_tokenizer()
            token_strs = tokenizer.convert_ids_to_tokens(input_ids)

        return TokenizeResponse(
            tokens=input_ids,
            token_strs=token_strs,
            count=len(input_ids),
            max_model_len=self.model_config.max_model_len,
        )

    async def create_detokenize(
        self,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> DetokenizeResponse | ErrorResponse:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        request_id = f"tokenize-{self._base_request_id(raw_request)}"

        lora_request = self._maybe_get_adapters(request)

        self._log_inputs(
            request_id,
            token_inputs(request.tokens),
            params=None,
            lora_request=lora_request,
        )

        engine_prompt = await self.renderer.tokenize_prompt_async(
            TokensPrompt(prompt_token_ids=request.tokens),
            request.build_tok_params(self.model_config),
        )
        prompt_text = engine_prompt["prompt"]  # type: ignore[typeddict-item]

        return DetokenizeResponse(prompt=prompt_text)

    async def get_tokenizer_info(
        self,
    ) -> TokenizerInfoResponse | ErrorResponse:
        """Get comprehensive tokenizer information."""
        try:
            tokenizer = self.renderer.get_tokenizer()
            info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
            return TokenizerInfoResponse(**info)
        except Exception as e:
            return self.create_error_response(f"Failed to get tokenizer info: {str(e)}")


@dataclass
class TokenizerInfo:
    tokenizer: TokenizerLike
    chat_template: str | None

    def to_dict(self) -> dict[str, Any]:
        """Return the tokenizer configuration."""
        return self._get_tokenizer_config()

    def _get_tokenizer_config(self) -> dict[str, Any]:
        """Get tokenizer configuration directly from the tokenizer object."""
        config = dict(getattr(self.tokenizer, "init_kwargs", None) or {})

        # Remove file path fields
        config.pop("vocab_file", None)
        config.pop("merges_file", None)

        config = self._make_json_serializable(config)
        config["tokenizer_class"] = type(self.tokenizer).__name__
        if self.chat_template:
            config["chat_template"] = self.chat_template
        return config

    def _make_json_serializable(self, obj):
        """Convert any non-JSON-serializable objects to serializable format."""
        if hasattr(obj, "content"):
            return obj.content
        elif isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._make_json_serializable(item) for item in obj]
        else:
            return obj
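TokenizerInfo._get_tokenizer_config copies the tokenizer's init_kwargs and relies on _make_json_serializable to flatten objects such as HF AddedToken (anything exposing a .content attribute) into plain strings, recursing through dicts and lists. A self-contained sketch that mirrors that helper (the names below are stand-ins for illustration, not part of this commit):

from typing import Any


def make_json_serializable(obj: Any) -> Any:
    # Same shape as TokenizerInfo._make_json_serializable: unwrap `.content`,
    # recurse into containers, and pass everything else through unchanged.
    if hasattr(obj, "content"):
        return obj.content
    if isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [make_json_serializable(v) for v in obj]
    return obj


class FakeAddedToken:
    # Stand-in for tokenizers.AddedToken; only the `.content` attribute matters here.
    def __init__(self, content: str) -> None:
        self.content = content


config = {
    "bos_token": FakeAddedToken("<s>"),
    "additional_special_tokens": [FakeAddedToken("<pad>"), "</s>"],
    "model_max_length": 4096,
}
print(make_json_serializable(config))
# {'bos_token': '<s>', 'additional_special_tokens': ['<pad>', '</s>'], 'model_max_length': 4096}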