# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator
from collections.abc import Sequence as GenericSequence

from fastapi import Request

# yapf: disable
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    ChatCompletionLogProb,
    ChatCompletionLogProbs,
    ChatCompletionLogProbsContent,
    ErrorResponse,
    GenerateRequest,
    GenerateResponse,
    GenerateResponseChoice,
    PromptTokenUsageInfo,
    RequestResponseMetadata,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils.collection_utils import as_list

logger = init_logger(__name__)


class ServingTokens(OpenAIServing):
    """Provides Tokens IN <> Tokens OUT functionality to the vLLM API."""

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        force_no_detokenize: bool = False,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_log_outputs: bool = False,
    ):
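        """Set up the tokens-in/tokens-out serving handler.

        When ``force_no_detokenize`` is True, detokenization is disabled for
        every request handled here (tokens-only mode).
        """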
        super().__init__(engine_client=engine_client,
                         models=models,
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids,
                         log_error_stack=log_error_stack)
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.enable_log_outputs = enable_log_outputs
        self.force_no_detokenize = force_no_detokenize
        if force_no_detokenize:
            logger.info("Tokens-only mode is enabled, skipping detokenization "
                        "step for incoming requests.")

    async def serve_tokens(
        self,
        request: GenerateRequest,
        raw_request: Request | None = None
    ) -> GenerateResponse | ErrorResponse:
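        """Serve a tokens-in/tokens-out generation request.

        The prompt is passed to the engine as raw token IDs and the full,
        non-streaming response is returned once generation completes.
        """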
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            logger.error("Error with model %s", error_check_ret)
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        lora_request = self._maybe_get_adapters(request,
                                                supports_default_mm_loras=True)

        model_name = self.models.model_name(lora_request)

        request_id = "generate-tokens-" \
            f"{self._base_request_id(raw_request, request.request_id)}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
        # completed
        engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids)
        if request.features is not None:
            engine_prompt["multi_modal_data"] = None

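        # Forward the prefix-cache salt when the client provides one.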
        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        # Schedule the request and get the result generator.
        result_generator: AsyncGenerator[RequestOutput, None] | None = None
        try:
            sampling_params = request.sampling_params
            if self.force_no_detokenize:
                sampling_params.detokenize = False

            self._log_inputs(request_id,
                             request.token_ids,
                             params=sampling_params,
                             lora_request=lora_request)

            trace_headers = (None if raw_request is None else await
                             self._get_trace_headers(raw_request.headers))

            result_generator = self.engine_client.generate(
                engine_prompt,
                sampling_params,
                request_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        except ValueError as e:
            return self.create_error_response(str(e))

        # TODO(NickLucche): Implement streaming response

        try:
            assert result_generator is not None
            return await self.serve_tokens_full_generator(
                request, result_generator, request_id, model_name,
                request_metadata)
        except ValueError as e:
            return self.create_error_response(str(e))

    async def serve_tokens_full_generator(
        self,
        request: GenerateRequest,
        result_generator: AsyncGenerator[RequestOutput, None],
        request_id: str,
        model_name: str,
        request_metadata: RequestResponseMetadata,
    ) -> ErrorResponse | GenerateResponse:
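        """Consume the result generator and assemble the final response."""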
        created_time = int(time.time())
        final_res: RequestOutput | None = None
        sampling_params: SamplingParams = request.sampling_params

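        # Drain the generator; only the last RequestOutput carries the
        # complete generation result.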
        try:
            async for res in result_generator:
                final_res = res
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(str(e))

        assert final_res is not None

        choices: list[GenerateResponseChoice] = []
        num_generated_tokens = 0
        for output in final_res.outputs:
            token_ids = output.token_ids
            out_logprobs = output.logprobs

            # This is top_logprobs in completions API
            if sampling_params.logprobs:
                assert out_logprobs is not None, "Did not output logprobs"
                logprobs = self._create_tokens_logprobs(
                    token_ids=token_ids,
                    top_logprobs=out_logprobs,
                    num_output_top_logprobs=sampling_params.logprobs,
                )
            else:
                logprobs = None

            choice_data = GenerateResponseChoice(
                index=output.index,
                logprobs=logprobs,
                finish_reason=output.finish_reason
                if output.finish_reason else "stop",
                token_ids=as_list(output.token_ids))

            choices.append(choice_data)
            num_generated_tokens += len(output.token_ids)

        assert final_res.prompt_token_ids is not None
        num_prompt_tokens = len(final_res.prompt_token_ids)
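        # Encoder-decoder models also contribute encoder prompt tokens.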
        if final_res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(final_res.encoder_prompt_token_ids)

        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
                          completion_tokens=num_generated_tokens,
                          total_tokens=num_prompt_tokens +
                          num_generated_tokens)
        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            # This info is not available at the /coordinator level
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)

        request_metadata.final_usage_info = usage

        response = GenerateResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
            prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
            kv_transfer_params=final_res.kv_transfer_params,
        )

        # Log complete response if output logging is enabled
        if self.enable_log_outputs and self.request_logger:
            for choice in choices:
                # Get the corresponding output token IDs
                output_token_ids = None
                if choice.index < len(final_res.outputs):
                    output_token_ids = final_res.outputs[
                        choice.index].token_ids

                if output_token_ids:
                    # Log token_ids only.
                    self.request_logger.log_outputs(
                        request_id=request_id,
                        outputs="",
                        output_token_ids=output_token_ids,
                        finish_reason=choice.finish_reason,
                        is_streaming=False,
                        delta=False,
                    )

        return response

    def _create_tokens_logprobs(
        self,
        token_ids: GenericSequence[int],
        top_logprobs: GenericSequence[dict[int, Logprob] | None],
        num_output_top_logprobs: int | None = None,
    ) -> ChatCompletionLogProbs:
        """Create OpenAI-style logprobs."""
        logprobs_content: list[ChatCompletionLogProbsContent] = []

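        # Tokens are rendered as "token_id:<id>" placeholders, since the text
        # may never be detokenized in tokens-only mode.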
        for i, token_id in enumerate(token_ids):
            token = f"token_id:{token_id}"
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None or step_top_logprobs.get(
                    token_id) is None:
                logprobs_content.append(
                    ChatCompletionLogProbsContent(token=token, ))
            else:
                step_token = step_top_logprobs[token_id]

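                # Clamp very large negative logprobs (e.g. -inf), which cannot
                # be represented in standard JSON.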
                logprobs_content.append(
                    ChatCompletionLogProbsContent(
                        token=token,
                        logprob=max(step_token.logprob, -9999.0),
                        top_logprobs=[
                            ChatCompletionLogProb(
                                token=f"token_id:{p[0]}",
                                logprob=max(p[1].logprob, -9999.0),
                            ) for i, p in enumerate(step_top_logprobs.items())
                            if num_output_top_logprobs
                            and i < num_output_top_logprobs
                        ]))


        return ChatCompletionLogProbs(content=logprobs_content)