enginex-c_series-vllm/vllm/entrypoints/openai/serving_embedding.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
from typing import Final, Literal, Optional, Union, cast

import numpy as np
from fastapi import Request
from typing_extensions import assert_never, override

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                              EmbeddingRequest,
                                              EmbeddingResponse,
                                              EmbeddingResponseData,
                                              ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
                                                    OpenAIServing,
                                                    ServeContext)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput)

logger = init_logger(__name__)

def _get_embedding(
    output: EmbeddingOutput,
    encoding_format: Literal["float", "base64"],
) -> Union[list[float], str]:
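    """Serialize a single embedding either as a list of floats or as a
    base64-encoded string of its float32 bytes."""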
    if encoding_format == "float":
        return output.embedding
    elif encoding_format == "base64":
        # Force to use float32 for base64 encoding
        # to match the OpenAI python client behavior
        embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
        return base64.b64encode(embedding_bytes).decode("utf-8")

    assert_never(encoding_format)

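# A minimal sketch of how a client could recover the vector from the base64
# payload produced above (assumes numpy on the client side; the variable
# names are illustrative only):
#
#     raw = base64.b64decode(item.embedding)
#     vector = np.frombuffer(raw, dtype=np.float32).tolist()
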
class EmbeddingMixin(OpenAIServing):
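    """Shared preprocessing and response building for embedding requests."""
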
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
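        """Resolve adapters and the tokenizer, then convert the request into
        engine prompts stored on the serve context."""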
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            (
                ctx.lora_request,
                ctx.prompt_adapter_request,
            ) = self._maybe_get_adapters(ctx.request)

            tokenizer = await self.engine_client.get_tokenizer(
                ctx.lora_request)

            if ctx.prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")
            if isinstance(ctx.request, EmbeddingChatRequest):
                (
                    _,
                    ctx.request_prompts,
                    ctx.engine_prompts,
                ) = await self._preprocess_chat(
                    ctx.request,
                    tokenizer,
                    ctx.request.messages,
                    chat_template=ctx.request.chat_template
                    or ctx.chat_template,
                    chat_template_content_format=ctx.
                    chat_template_content_format,
                    # In embedding requests, we are not generating tokens,
                    # so there is no need to append extra tokens to the input
                    add_generation_prompt=False,
                    continue_final_message=False,
                    truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                    add_special_tokens=ctx.request.add_special_tokens,
                )
            else:
                (ctx.request_prompts,
                 ctx.engine_prompts) = await self._preprocess_completion(
                     ctx.request,
                     tokenizer,
                     ctx.request.input,
                     truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                     add_special_tokens=ctx.request.add_special_tokens,
                 )

            return None
        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
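        """Convert the pooled engine outputs into an OpenAI-style
        EmbeddingResponse, accumulating prompt-token usage."""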
        items: list[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = cast(list[PoolingRequestOutput],
                                       ctx.final_res_batch)

        for idx, final_res in enumerate(final_res_batch_checked):
            embedding_res = EmbeddingRequestOutput.from_base(final_res)

            item = EmbeddingResponseData(
                index=idx,
                embedding=_get_embedding(embedding_res.outputs,
                                         ctx.request.encoding_format),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

class OpenAIServingEmbedding(EmbeddingMixin):
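    """Handler for OpenAI-compatible embedding requests."""
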
    request_id_prefix = "embd"

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[EmbeddingResponse, ErrorResponse]:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        model_name = self._get_model_name(request.model)
        request_id = (f"{self.request_id_prefix}-"
                      f"{self._base_request_id(raw_request)}")

        ctx = EmbeddingServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
        )

        return await super().handle(ctx)  # type: ignore

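    # A minimal sketch of how a client might call this endpoint once the
    # server is running (the base URL and model name below are assumptions,
    # not part of this module):
    #
    #     from openai import OpenAI
    #     client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    #     out = client.embeddings.create(model="intfloat/e5-small",
    #                                    input="Hello world",
    #                                    encoding_format="float")
    #     print(len(out.data[0].embedding))
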
    @override
    def _validate_request(
        self,
        ctx: ServeContext[EmbeddingRequest],
    ) -> Optional[ErrorResponse]:
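        """Run the base validation, then check embedding-specific truncation
        and pooling parameters."""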
        if error := super()._validate_request(ctx):
            return error

        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens

        pooling_params = ctx.request.to_pooling_params()

        try:
            pooling_params.verify(self.model_config)
        except ValueError as e:
            return self.create_error_response(str(e))

        return None