import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional, Union, cast

import jinja2
from fastapi import Request
from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
RequestResponseMetadata,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import (
EmbedsPrompt as ServingEngineEmbedsPrompt)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
TextTokensPrompt,
clamp_prompt_logprobs,
is_text_tokens_prompt)
# yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, make_async,
                        merge_async_iterators, random_uuid)
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN

logger = init_logger(__name__)

class OpenAIServingCompletion(OpenAIServing):
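    """OpenAI-compatible `/v1/completions` serving layer.

    Mirrors upstream vLLM completion serving, adding a backend-specific
    prefill length guard (LLM_MAX_PREFILL_SEQ_LEN) and an optional strict
    batch barrier that holds multi-prompt requests until every item in the
    batch can be scheduled together.
    """
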
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
enable_strict_batch_barrier: bool = True,
log_error_stack: bool = False,
):
self.engine_client = engine_client
self.model_config = model_config
self.max_model_len = model_config.max_model_len
self.models = models
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.enable_force_include_usage = enable_force_include_usage
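        # Tokenization runs on a dedicated single-thread executor; async
        # micro-batching tokenizers are created lazily, one per tokenizer.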
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self._async_tokenizer_pool: dict[AnyTokenizer,
AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info("Using default completion sampling params from %s: %s",
source, self.default_sampling_params)
        self.enable_strict_batch_barrier = enable_strict_batch_barrier

async def create_completion(
self,
request: CompletionRequest,
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
if request.echo and request.prompt_embeds is not None:
return self.create_error_response(
"Echo is unsupported with prompt embeds.")
if (request.prompt_logprobs is not None
and request.prompt_embeds is not None):
return self.create_error_response(
"prompt_logprobs is not compatible with prompt embeds.")
        request_id = (
            f"cmpl-{self._base_request_id(raw_request, request.request_id)}")
created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
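            # The renderer converts raw prompt text, token IDs, or prompt
            # embeddings into engine-ready prompt inputs.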
renderer = self._get_renderer(tokenizer)
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
prompt_embeds=request.prompt_embeds,
                deepstack_input_embeds=getattr(
                    request, "deepstack_input_embeds", None),
config=self._build_render_config(request),
)
        except (ValueError, TypeError, RuntimeError,
                jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
total_num_prompts = len(engine_prompts)
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
# Mypy does not infer that engine_prompt will have only one of
# "prompt_token_ids" or "prompt_embeds" defined, and both of
# these as Union[object, the expected type], where it infers
# object if engine_prompt is a subclass of one of the
# typeddicts that defines both keys. Worse, because of
# https://github.com/python/mypy/issues/8586, mypy does not
# infer the type of engine_prompt correctly because of the
# enumerate. So we need an unnecessary cast here.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
if is_embeds_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_embeds"])
elif is_tokens_prompt(engine_prompt):
input_length = len(engine_prompt["prompt_token_ids"])
                    if input_length > LLM_MAX_PREFILL_SEQ_LEN:
                        raise ValueError(
                            "This model's maximum input sequence length is "
                            f"{LLM_MAX_PREFILL_SEQ_LEN} tokens, but the "
                            f"request contains {input_length} input tokens. "
                            "Please reduce the length of the input.")
else:
assert_never(engine_prompt)
if self.default_sampling_params is None:
self.default_sampling_params = {}
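                # Cap the completion budget so the prompt plus the generated
                # tokens fit within the model context window.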
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
default_sampling_params=self.default_sampling_params,
)
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.model_config.logits_processor_pattern,
self.default_sampling_params,
)
# Inject strict batch barrier metadata so this batch is held
# until all items are ready, then scheduled together.
if (self.enable_strict_batch_barrier
and total_num_prompts > 1
and isinstance(sampling_params, SamplingParams)):
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args.setdefault("barrier_group_id",
request_id)
sampling_params.extra_args.setdefault("barrier_group_size",
total_num_prompts)
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
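                # Propagate trace headers from the incoming HTTP request when
                # tracing is enabled.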
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
engine_prompt)
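                # Beam search takes a dedicated engine path; all other
                # requests go through the regular generate() API.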
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
)
else:
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
logger.error(e)
return self.create_error_response(str(e))
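        # Interleave the per-prompt output streams into a single iterator
        # that yields (prompt_index, RequestOutput) pairs.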
result_generator = merge_async_iterators(*generators)
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. Note that best_of is only supported in V0. In addition,
        # we do not stream the results when using beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
# Streaming response
if stream:
return self.completion_stream_generator(
request,
engine_prompts,
result_generator,
request_id,
created_time,
model_name,
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata,
enable_force_include_usage=self.enable_force_include_usage,
)
# Non-streaming response
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
try:
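            # Collect each prompt's final output at its original index.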
async for i, res in result_generator:
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
                # The output should contain the input text. We did not pass
                # it into the vLLM engine, to avoid being redundant with the
                # input token IDs.
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = None if is_embeds_prompt(
engine_prompt) else engine_prompt.get("prompt")
final_res_batch_checked = cast(list[RequestOutput],
final_res_batch)
response = self.request_output_to_completion_response(
final_res_batch_checked,
request,
request_id,
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return fake_stream_generator()
return response