import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional, Union, cast

import jinja2
from fastapi import Request
from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionRequest,
                                              CompletionResponse,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
                                              ErrorResponse,
                                              RequestResponseMetadata,
                                              UsageInfo)
from vllm.entrypoints.openai.serving_engine import (
    EmbedsPrompt as ServingEngineEmbedsPrompt)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                    TextTokensPrompt,
                                                    clamp_prompt_logprobs,
                                                    is_text_tokens_prompt)
# yapf: enable
from vllm.entrypoints.openai.serving_completion import (
    OpenAIServingCompletion, logger)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                              is_tokens_prompt)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, make_async,
                        merge_async_iterators, random_uuid)

from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN

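
# NOTE: This class reuses the name of vLLM's upstream OpenAIServingCompletion
# (imported above); the definition below shadows that import within this
# module. Relative to the upstream class, it adds a prefill-length guard
# based on LLM_MAX_PREFILL_SEQ_LEN and an optional "strict batch barrier"
# that holds multi-prompt requests until all items can be scheduled together.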
class OpenAIServingCompletion(OpenAIServing):

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
        enable_strict_batch_barrier: bool = True,
        log_error_stack: bool = False,
    ):
        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        self.enable_force_include_usage = enable_force_include_usage

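        # Tokenization helpers: a single-thread executor for offloading
        # blocking tokenizer calls, plus a per-tokenizer cache of
        # AsyncMicrobatchTokenizer instances. (Assumed to be consumed by the
        # tokenization/detokenization helpers inherited from OpenAIServing or
        # defined later in this file.)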
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self._async_tokenizer_pool: dict[AnyTokenizer,
                                         AsyncMicrobatchTokenizer] = {}
        self.log_error_stack = log_error_stack

        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            source = self.model_config.generation_config
            source = "model" if source == "auto" else source
            logger.info(
                "Using default completion sampling params from %s: %s",
                source, self.default_sampling_params)
        self.enable_strict_batch_barrier = enable_strict_batch_barrier

    async def create_completion(
        self,
        request: CompletionRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
              suffix)
        """
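        # Illustrative request body handled by this endpoint (field names
        # follow the OpenAI Completions API / CompletionRequest; values are
        # placeholders only):
        #   {"model": "my-model", "prompt": "Hello, my name is",
        #    "max_tokens": 32, "n": 1, "stream": true}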
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        if request.echo and request.prompt_embeds is not None:
            return self.create_error_response(
                "Echo is unsupported with prompt embeds.")

        if (request.prompt_logprobs is not None
                and request.prompt_embeds is not None):
            return self.create_error_response(
                "prompt_logprobs is not compatible with prompt embeds.")

        request_id = (
            f"cmpl-"
            f"{self._base_request_id(raw_request, request.request_id)}")
        created_time = int(time.time())

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)

            if self.model_config.skip_tokenizer_init:
                tokenizer = None
            else:
                tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            engine_prompts = await renderer.render_prompt_and_embeds(
                prompt_or_prompts=request.prompt,
                prompt_embeds=request.prompt_embeds,
                deepstack_input_embeds=getattr(
                    request, "deepstack_input_embeds", None),
                config=self._build_render_config(request),
            )
        except (ValueError, TypeError, RuntimeError,
                jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: list[AsyncGenerator[RequestOutput, None]] = []
        try:
            total_num_prompts = len(engine_prompts)
            for i, engine_prompt in enumerate(engine_prompts):
                sampling_params: Union[SamplingParams, BeamSearchParams]
                # Mypy does not infer that engine_prompt will have only one of
                # "prompt_token_ids" or "prompt_embeds" defined, and both of
                # these as Union[object, the expected type], where it infers
                # object if engine_prompt is a subclass of one of the
                # typeddicts that defines both keys. Worse, because of
                # https://github.com/python/mypy/issues/8586, mypy does not
                # infer the type of engine_prompt correctly because of the
                # enumerate. So we need an unnecessary cast here.
                engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
                                     engine_prompt)
                if is_embeds_prompt(engine_prompt):
                    input_length = len(engine_prompt["prompt_embeds"])
                elif is_tokens_prompt(engine_prompt):
                    input_length = len(engine_prompt["prompt_token_ids"])
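                    # Guard specific to this fork: reject token prompts longer
                    # than LLM_MAX_PREFILL_SEQ_LEN (imported from vllm_vacc).
                    # Embeds prompts are not length-checked here.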
                    if input_length > LLM_MAX_PREFILL_SEQ_LEN:
                        raise ValueError(
                            f"This model's maximum input sequence length is "
                            f"{LLM_MAX_PREFILL_SEQ_LEN} tokens, but the "
                            f"request contains {input_length} input tokens. "
                            f"Please reduce the length of the input.")
                else:
                    assert_never(engine_prompt)

                if self.default_sampling_params is None:
                    self.default_sampling_params = {}

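                # max_tokens is effectively min(requested max_tokens,
                # server-side default, max_model_len - input_length); see
                # vllm.entrypoints.utils.get_max_tokens for the exact rule.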
                max_tokens = get_max_tokens(
                    max_model_len=self.max_model_len,
                    request=request,
                    input_length=input_length,
                    default_sampling_params=self.default_sampling_params,
                )

                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
                        max_tokens,
                        self.model_config.logits_processor_pattern,
                        self.default_sampling_params,
                    )

                # Inject strict batch barrier metadata so this batch is held
                # until all items are ready, then scheduled together.
                if (self.enable_strict_batch_barrier
                        and total_num_prompts > 1
                        and isinstance(sampling_params, SamplingParams)):
                    if sampling_params.extra_args is None:
                        sampling_params.extra_args = {}
                    sampling_params.extra_args.setdefault(
                        "barrier_group_id", request_id)
                    sampling_params.extra_args.setdefault(
                        "barrier_group_size", total_num_prompts)
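                    # For example, a three-prompt request ends up with
                    # extra_args like:
                    #   {"barrier_group_id": "cmpl-<request id>",
                    #    "barrier_group_size": 3}
                    # setdefault() keeps any caller-provided values intact.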

                request_id_item = f"{request_id}-{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                # Mypy inconsistently requires this second cast in different
                # environments. It shouldn't be necessary (redundant from
                # above) but pre-commit in CI fails without it.
                engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt],
                                     engine_prompt)
                if isinstance(sampling_params, BeamSearchParams):
                    generator = self.engine_client.beam_search(
                        prompt=engine_prompt,
                        request_id=request_id,
                        params=sampling_params,
                        lora_request=lora_request,
                    )
                else:
                    generator = self.engine_client.generate(
                        engine_prompt,
                        sampling_params,
                        request_id_item,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        priority=request.priority,
                    )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            logger.error(e)
            return self.create_error_response(str(e))

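        # merge_async_iterators interleaves the per-prompt result generators
        # into a single stream that yields (prompt_index, RequestOutput)
        # pairs, which is how the batch is reassembled below.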
        result_generator = merge_async_iterators(*generators)

        model_name = self.models.model_name(lora_request)
        num_prompts = len(engine_prompts)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. Note that best_of is only supported in V0. In addition,
        # we do not stream the results when beam search is used.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return self.completion_stream_generator(
                request,
                engine_prompts,
                result_generator,
                request_id,
                created_time,
                model_name,
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
                enable_force_include_usage=self.enable_force_include_usage,
            )

        # Non-streaming response
        final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            for i, final_res in enumerate(final_res_batch):
                assert final_res is not None

                # The output should contain the input text. We did not pass it
                # into the vLLM engine to avoid being redundant with the input
                # token IDs.
                if final_res.prompt is None:
                    engine_prompt = engine_prompts[i]
                    final_res.prompt = None if is_embeds_prompt(
                        engine_prompt) else engine_prompt.get("prompt")

            final_res_batch_checked = cast(list[RequestOutput],
                                           final_res_batch)

            response = self.request_output_to_completion_response(
                final_res_batch_checked,
                request,
                request_id,
                created_time,
                model_name,
                tokenizer,
                request_metadata,
            )
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        # When the user requests streaming but we don't stream, we still need
        # to return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response