Upgrade to vLLM 0.17.0 corex v4.1 overlay
@@ -1,116 +1,57 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Final, TypeAlias
+from typing import TypeAlias
 
-import jinja2
-import numpy as np
-from fastapi import Request
-
-from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
-from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
-from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
-from vllm.entrypoints.openai.models.serving import OpenAIServingModels
-from vllm.entrypoints.pooling.classify.protocol import (
-    ClassificationChatRequest,
-    ClassificationCompletionRequest,
-    ClassificationData,
-    ClassificationRequest,
-    ClassificationResponse,
-)
-from vllm.logger import init_logger
-from vllm.outputs import ClassificationOutput
+from vllm import ClassificationOutput
+from vllm.config import ModelConfig
+from vllm.entrypoints.chat_utils import ChatTemplateConfig
+from vllm.entrypoints.openai.engine.protocol import UsageInfo
+from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
+from vllm.logger import init_logger
+from vllm.renderers import BaseRenderer
+
+from .io_processor import ClassifyIOProcessor
+from .protocol import (
+    ClassificationData,
+    ClassificationRequest,
+    ClassificationResponse,
+)
 
 logger = init_logger(__name__)
 
-ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
+ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
 
 
-class ServingClassification(OpenAIServing):
+class ServingClassification(PoolingServing):
     request_id_prefix = "classify"
 
-    def __init__(
+    def init_io_processor(
         self,
-        engine_client: EngineClient,
-        models: OpenAIServingModels,
-        *,
-        request_logger: RequestLogger | None,
-        chat_template: str | None = None,
-        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
-        trust_request_chat_template: bool = False,
-        log_error_stack: bool = False,
-    ) -> None:
-        super().__init__(
-            engine_client=engine_client,
-            models=models,
-            request_logger=request_logger,
-            log_error_stack=log_error_stack,
-        )
-
-        self.chat_template = chat_template
-        self.chat_template_content_format: Final = chat_template_content_format
-        self.trust_request_chat_template = trust_request_chat_template
+        model_config: ModelConfig,
+        renderer: BaseRenderer,
+        chat_template_config: ChatTemplateConfig,
+    ) -> ClassifyIOProcessor:
+        return ClassifyIOProcessor(
+            model_config=model_config,
+            renderer=renderer,
+            chat_template_config=chat_template_config,
+        )
 
-    async def _preprocess(
+    async def _build_response(
         self,
         ctx: ClassificationServeContext,
-    ) -> ErrorResponse | None:
-        """
-        Process classification inputs: tokenize text, resolve adapters,
-        and prepare model-specific inputs.
-        """
-        try:
-            ctx.lora_request = self._maybe_get_adapters(ctx.request)
-
-            if isinstance(ctx.request, ClassificationChatRequest):
-                error_check_ret = self._validate_chat_template(
-                    request_chat_template=ctx.request.chat_template,
-                    chat_template_kwargs=ctx.request.chat_template_kwargs,
-                    trust_request_chat_template=self.trust_request_chat_template,
-                )
-                if error_check_ret:
-                    return error_check_ret
-
-                _, ctx.engine_prompts = await self._preprocess_chat(
-                    ctx.request,
-                    ctx.request.messages,
-                    default_template=self.chat_template,
-                    default_template_content_format=self.chat_template_content_format,
-                    default_template_kwargs=None,
-                )
-            elif isinstance(ctx.request, ClassificationCompletionRequest):
-                ctx.engine_prompts = await self._preprocess_completion(
-                    ctx.request,
-                    prompt_input=ctx.request.input,
-                    prompt_embeds=None,
-                )
-            else:
-                return self.create_error_response("Invalid classification request type")
-
-            return None
-
-        except (ValueError, TypeError, jinja2.TemplateError) as e:
-            logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(str(e))
-
-    def _build_response(
-        self,
-        ctx: ClassificationServeContext,
-    ) -> ClassificationResponse | ErrorResponse:
-        """
-        Convert model outputs to a formatted classification response
-        with probabilities and labels.
-        """
-        id2label = getattr(self.model_config.hf_config, "id2label", {})
-
-        items: list[ClassificationData] = []
-        num_prompt_tokens = 0
-
-        final_res_batch_checked = ctx.final_res_batch
-
+    ) -> ClassificationResponse:
+        final_res_batch_checked = await self.io_processor.post_process_async(
+            ctx.final_res_batch
+        )
+
+        items: list[ClassificationData] = []
         for idx, final_res in enumerate(final_res_batch_checked):
            classify_res = ClassificationOutput.from_base(final_res.outputs)
 
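Review note: the first hunk replaces construction-time wiring (the old `__init__` taking `engine_client`, `models`, and the chat-template options) with a single `init_io_processor` hook, so chat templating, tokenization, and the `jinja2`/`ValueError` error handling in `_preprocess` move out of the serving class and into `ClassifyIOProcessor`. `PoolingServing` appears to follow a template-method shape; the sketch below is an assumption about the base class, since only the subclass side is visible in this diff:

    # Hypothetical stand-in for the PoolingServing base class; only
    # init_io_processor (the subclass side) is confirmed by the hunk above.
    class PoolingServingSketch:
        def __init__(self, model_config, renderer, chat_template_config):
            # The base class is assumed to ask the subclass for its
            # endpoint-specific IO processor once, at construction time.
            self.io_processor = self.init_io_processor(
                model_config, renderer, chat_template_config
            )

        def init_io_processor(self, model_config, renderer, chat_template_config):
            # Overridden by subclasses such as ServingClassification above.
            raise NotImplementedError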
@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing):
             data=items,
             usage=usage,
         )
-
-    async def create_classify(
-        self,
-        request: ClassificationRequest,
-        raw_request: Request,
-    ) -> ClassificationResponse | ErrorResponse:
-        model_name = self.models.model_name()
-        request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
-
-        ctx = ClassificationServeContext(
-            request=request,
-            raw_request=raw_request,
-            model_name=model_name,
-            request_id=request_id,
-        )
-
-        return await self.handle(ctx)  # type: ignore[return-value]
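Review note: the second hunk deletes `create_classify`, whose body was pure boilerplate around the generic `handle()` path, and with it the last uses of `fastapi.Request` and `numpy` in this module. The `id2label` mapping that the old `_build_response` read from `self.model_config.hf_config` is now presumably owned by `ClassifyIOProcessor.post_process_async`; the helper below only reconstructs that old inline step for illustration and is not the processor's actual code:

    import numpy as np

    # Illustrative reconstruction of the label-mapping step the old
    # _build_response performed inline via
    # id2label = getattr(self.model_config.hf_config, "id2label", {}).
    def label_for(probs: np.ndarray, id2label: dict[int, str]) -> str:
        predicted = int(np.argmax(probs))  # index of the highest-probability class
        return id2label.get(predicted, str(predicted))  # fall back to the raw index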
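The external API is unchanged by this refactor: requests still carry an `input` field (see `prompt_input=ctx.request.input` in the removed `_preprocess`) and responses still return `data` plus `usage`. A minimal client sketch, assuming the server exposes the classification route at `/classify` (the route itself is outside this diff):

    import requests

    # Hypothetical call against a locally running server; the model name
    # and port are placeholders, and the /classify path is an assumption.
    resp = requests.post(
        "http://localhost:8000/classify",
        json={"model": "my-classifier", "input": "vLLM is fast."},
    )
    print(resp.json())  # expected shape: {"data": [...], "usage": {...}}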