Files
2026-01-19 10:38:50 +08:00

286 lines
10 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator
from collections.abc import Sequence as GenericSequence
from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb,
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
GenerateResponse,
GenerateResponseChoice,
)
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils.collection_utils import as_list
# Module-level logger, named after this module via vLLM's init_logger.
logger = init_logger(__name__)
class ServingTokens(OpenAIServing):
    """Provides Tokens IN <> Tokens OUT functionality to vLLM API.

    Requests carry prompt token ids and sampling params directly; responses
    carry generated token ids, with tokens in logprobs rendered as
    ``token_id:<id>`` strings rather than decoded text.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        force_no_detokenize: bool = False,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_log_outputs: bool = False,
    ):
        """Initialize the tokens serving handler.

        Args:
            engine_client: Client used to submit generation requests.
            models: Served-model registry used for model checks and naming.
            request_logger: Optional logger for request inputs/outputs.
            force_no_detokenize: When True, ``detokenize`` is set to False on
                every incoming request's sampling params.
            return_tokens_as_token_ids: Forwarded to ``OpenAIServing``.
            log_error_stack: Forwarded to ``OpenAIServing``.
            enable_prompt_tokens_details: When True, cached-token counts are
                reported in ``usage.prompt_tokens_details`` (if non-zero).
            enable_log_outputs: When True (and a request logger is set), the
                generated token ids of each choice are logged.
        """
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_log_outputs = enable_log_outputs
        self.force_no_detokenize = force_no_detokenize
        if force_no_detokenize:
            logger.info(
                "Tokens-only mode is enabled, skipping detokenization "
                "step for incoming requests."
            )

    async def serve_tokens(
        self,
        request: GenerateRequest,
        raw_request: Request | None = None,
    ) -> GenerateResponse | ErrorResponse:
        """Run a (non-streaming) tokens-in/tokens-out generation request.

        Args:
            request: The generation request (prompt token ids plus sampling
                params).
            raw_request: The originating HTTP request, if any. When given, the
                request metadata is attached to ``raw_request.state`` and its
                headers are used to derive trace headers.

        Returns:
            A ``GenerateResponse`` on success, or an ``ErrorResponse`` when the
            model check fails or a ``ValueError`` is raised while scheduling
            or collecting the generation.

        Raises:
            The engine client's ``dead_error`` if the engine has errored.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True)
        model_name = self.models.model_name(lora_request)

        request_id = (
            f"generate-tokens-{self._base_request_id(raw_request, request.request_id)}"
        )
        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
        # completed
        engine_prompt = TokensPrompt(prompt_token_ids=request.token_ids)
        if request.features is not None:
            # NOTE(review): features are not forwarded yet; multi_modal_data
            # is only set to None here -- confirm this placeholder is intended.
            engine_prompt["multi_modal_data"] = None
        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        # Schedule the request and get the result generator.
        try:
            sampling_params = request.sampling_params
            if self.force_no_detokenize:
                # This mutates the request's own SamplingParams object.
                sampling_params.detokenize = False

            # Log a plain tokens prompt (without the extra engine_prompt keys).
            self._log_inputs(
                request_id,
                TokensPrompt(prompt_token_ids=request.token_ids),
                params=sampling_params,
                lora_request=lora_request,
            )

            trace_headers = (
                None
                if raw_request is None
                else await self._get_trace_headers(raw_request.headers)
            )

            result_generator: AsyncGenerator[RequestOutput, None] = (
                self.engine_client.generate(
                    engine_prompt,
                    sampling_params,
                    request_id,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )
            )
        except ValueError as e:
            return self.create_error_response(str(e))

        # TODO(NickLucche): Implement streaming response
        try:
            return await self.serve_tokens_full_generator(
                request, result_generator, request_id, model_name, request_metadata
            )
        except ValueError as e:
            return self.create_error_response(str(e))

    async def serve_tokens_full_generator(
        self,
        request: GenerateRequest,
        result_generator: AsyncGenerator[RequestOutput, None],
        request_id: str,
        model_name: str,
        request_metadata: RequestResponseMetadata,
    ) -> ErrorResponse | GenerateResponse:
        """Drain ``result_generator`` and build one ``GenerateResponse``.

        Only the last ``RequestOutput`` yielded by the generator is used.
        The computed usage is also stored on ``request_metadata``.

        Returns:
            The assembled ``GenerateResponse``, or an ``ErrorResponse`` if the
            generator is cancelled (client disconnect) or raises
            ``ValueError``.
        """
        created_time = int(time.time())
        final_res: RequestOutput | None = None
        sampling_params: SamplingParams = request.sampling_params

        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(str(e))

        assert final_res is not None

        # ``SamplingParams.logprobs`` is None when logprobs are disabled; 0
        # still requests the sampled token's logprob, so test for None
        # rather than truthiness.
        want_logprobs = sampling_params.logprobs is not None

        choices: list[GenerateResponseChoice] = []
        num_generated_tokens = 0
        for output in final_res.outputs:
            token_ids = output.token_ids
            logprobs = None
            # This is top_logprobs in completions API
            if want_logprobs:
                assert output.logprobs is not None, "Did not output logprobs"
                logprobs = self._create_tokens_logprobs(
                    token_ids=token_ids,
                    top_logprobs=output.logprobs,
                    num_output_top_logprobs=sampling_params.logprobs,
                )

            choices.append(
                GenerateResponseChoice(
                    index=output.index,
                    logprobs=logprobs,
                    finish_reason=output.finish_reason or "stop",
                    token_ids=as_list(token_ids),
                )
            )
            num_generated_tokens += len(token_ids)

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            # This info is not available at the /coordinator level
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens
            )

        request_metadata.final_usage_info = usage

        response = GenerateResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
            kv_transfer_params=final_res.kv_transfer_params,
        )

        # Log complete response if output logging is enabled
        if self.enable_log_outputs and self.request_logger:
            # ``choices`` was built 1:1 from ``final_res.outputs``, so pair
            # them positionally; ``choice.index`` comes from ``output.index``
            # and is not guaranteed here to equal the list position.
            for choice, output in zip(choices, final_res.outputs):
                if output.token_ids:
                    # Log token_ids only.
                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs="",
                        output_token_ids=output.token_ids,
                        finish_reason=choice.finish_reason,
                        is_streaming=False,
                        delta=False,
                    )

        return response

    def _create_tokens_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int | None = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs with tokens rendered as ``token_id:<id>``.

        Args:
            token_ids: Generated token ids, one per step.
            top_logprobs: Per-step mapping of token id -> ``Logprob`` (or None
                for steps without logprobs), indexed in lockstep with
                ``token_ids``.
            num_output_top_logprobs: How many entries of each step's mapping
                to report as ``top_logprobs``: None/0 reports none, -1 reports
                all, N > 0 reports the first N.

        Returns:
            ``ChatCompletionLogProbs`` with one content entry per token. Steps
            whose mapping is missing or lacks the sampled token get an entry
            with only ``token`` set. Logprobs are clamped to >= -9999.0.
        """
        logprobs_content: list[ChatCompletionLogProbsContent] = []
        for step, token_id in enumerate(token_ids):
            token = f"token_id:{token_id}"
            step_top_logprobs = top_logprobs[step]
            if step_top_logprobs is None or step_top_logprobs.get(token_id) is None:
                logprobs_content.append(ChatCompletionLogProbsContent(token=token))
                continue

            step_token = step_top_logprobs[token_id]
            logprobs_content.append(
                ChatCompletionLogProbsContent(
                    token=token,
                    logprob=max(step_token.logprob, -9999.0),
                    top_logprobs=self._tokens_top_logprobs(
                        step_top_logprobs, num_output_top_logprobs
                    ),
                )
            )

        return ChatCompletionLogProbs(content=logprobs_content)

    @staticmethod
    def _tokens_top_logprobs(
        step_top_logprobs: dict[int, Logprob],
        num_output_top_logprobs: int | None,
    ) -> list[ChatCompletionLogProb]:
        """Render one step's candidate logprobs, each labelled by its own id.

        Keeps the first ``num_output_top_logprobs`` candidates in mapping
        order; -1 keeps all of them and None/0 keeps none.
        """
        return [
            ChatCompletionLogProb(
                # Label each candidate with its own id, not the sampled token.
                token=f"token_id:{candidate_id}",
                logprob=max(candidate.logprob, -9999.0),
            )
            for rank, (candidate_id, candidate) in enumerate(step_top_logprobs.items())
            if num_output_top_logprobs == -1
            or (num_output_top_logprobs and rank < num_output_top_logprobs)
        ]