Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -3,16 +3,17 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from starlette.responses import JSONResponse
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||
from vllm.entrypoints.utils import (
|
||||
create_error_response,
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None:
|
||||
@router.post("/classify", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_classify(request: ClassificationRequest, raw_request: Request):
|
||||
async def create_classify(
|
||||
request: ClassificationRequest, raw_request: Request
|
||||
) -> JSONResponse:
|
||||
handler = classify(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
return base_server.create_error_response(
|
||||
error_response = create_error_response(
|
||||
message="The model does not support Classification API"
|
||||
)
|
||||
|
||||
try:
|
||||
generator = await handler.create_classify(request, raw_request)
|
||||
except Exception as e:
|
||||
generator = handler.create_error_response(e)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
content=error_response.model_dump(),
|
||||
status_code=error_response.error.code,
|
||||
)
|
||||
|
||||
elif isinstance(generator, ClassificationResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
return await handler(request, raw_request)
|
||||
|
||||
50
vllm/entrypoints/pooling/classify/io_processor.py
Normal file
50
vllm/entrypoints/pooling/classify/io_processor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from vllm import PromptType
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
)
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
|
||||
|
||||
class ClassifyIOProcessor(PoolingIOProcessor):
|
||||
def pre_process_online(
|
||||
self, request: ClassificationCompletionRequest | ClassificationChatRequest
|
||||
) -> list[TokPrompt] | None:
|
||||
if isinstance(request, ClassificationChatRequest):
|
||||
self._validate_chat_template(
|
||||
request_chat_template=request.chat_template,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
trust_request_chat_template=self.trust_request_chat_template,
|
||||
)
|
||||
_, engine_prompts = self._preprocess_chat_online(
|
||||
request,
|
||||
request.messages,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
)
|
||||
elif isinstance(request, ClassificationCompletionRequest):
|
||||
engine_prompts = self._preprocess_completion_online(
|
||||
request,
|
||||
prompt_input=request.input,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid classification request type")
|
||||
return engine_prompts
|
||||
|
||||
def pre_process_offline(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> Sequence[ProcessorInputs]:
|
||||
return self._preprocess_completion_offline(
|
||||
prompts=prompts, tokenization_kwargs=tokenization_kwargs
|
||||
)
|
||||
@@ -40,7 +40,6 @@ class ClassificationCompletionRequest(
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
task="classify",
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=self.use_activation,
|
||||
)
|
||||
|
||||
@@ -63,7 +62,6 @@ class ClassificationChatRequest(
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
task="classify",
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
use_activation=self.use_activation,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,116 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Final, TypeAlias
|
||||
from typing import TypeAlias
|
||||
|
||||
import jinja2
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
from vllm import ClassificationOutput
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.renderers import BaseRenderer
|
||||
|
||||
from .io_processor import ClassifyIOProcessor
|
||||
from .protocol import (
|
||||
ClassificationData,
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import ClassificationOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
|
||||
ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
|
||||
|
||||
|
||||
class ServingClassification(OpenAIServing):
|
||||
class ServingClassification(PoolingServing):
|
||||
request_id_prefix = "classify"
|
||||
|
||||
def __init__(
|
||||
def init_io_processor(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
trust_request_chat_template: bool = False,
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
engine_client=engine_client,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack,
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
chat_template_config: ChatTemplateConfig,
|
||||
) -> ClassifyIOProcessor:
|
||||
return ClassifyIOProcessor(
|
||||
model_config=model_config,
|
||||
renderer=renderer,
|
||||
chat_template_config=chat_template_config,
|
||||
)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
|
||||
async def _preprocess(
|
||||
async def _build_response(
|
||||
self,
|
||||
ctx: ClassificationServeContext,
|
||||
) -> ErrorResponse | None:
|
||||
"""
|
||||
Process classification inputs: tokenize text, resolve adapters,
|
||||
and prepare model-specific inputs.
|
||||
"""
|
||||
try:
|
||||
ctx.lora_request = self._maybe_get_adapters(ctx.request)
|
||||
) -> ClassificationResponse:
|
||||
final_res_batch_checked = await self.io_processor.post_process_async(
|
||||
ctx.final_res_batch
|
||||
)
|
||||
|
||||
if isinstance(ctx.request, ClassificationChatRequest):
|
||||
error_check_ret = self._validate_chat_template(
|
||||
request_chat_template=ctx.request.chat_template,
|
||||
chat_template_kwargs=ctx.request.chat_template_kwargs,
|
||||
trust_request_chat_template=self.trust_request_chat_template,
|
||||
)
|
||||
if error_check_ret:
|
||||
return error_check_ret
|
||||
|
||||
_, ctx.engine_prompts = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
ctx.request.messages,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
)
|
||||
elif isinstance(ctx.request, ClassificationCompletionRequest):
|
||||
ctx.engine_prompts = await self._preprocess_completion(
|
||||
ctx.request,
|
||||
prompt_input=ctx.request.input,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
else:
|
||||
return self.create_error_response("Invalid classification request type")
|
||||
|
||||
return None
|
||||
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: ClassificationServeContext,
|
||||
) -> ClassificationResponse | ErrorResponse:
|
||||
"""
|
||||
Convert model outputs to a formatted classification response
|
||||
with probabilities and labels.
|
||||
"""
|
||||
id2label = getattr(self.model_config.hf_config, "id2label", {})
|
||||
|
||||
items: list[ClassificationData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = ctx.final_res_batch
|
||||
|
||||
items: list[ClassificationData] = []
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
classify_res = ClassificationOutput.from_base(final_res.outputs)
|
||||
|
||||
@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing):
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def create_classify(
|
||||
self,
|
||||
request: ClassificationRequest,
|
||||
raw_request: Request,
|
||||
) -> ClassificationResponse | ErrorResponse:
|
||||
model_name = self.models.model_name()
|
||||
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
|
||||
|
||||
ctx = ClassificationServeContext(
|
||||
request=request,
|
||||
raw_request=raw_request,
|
||||
model_name=model_name,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
return await self.handle(ctx) # type: ignore[return-value]
|
||||
|
||||
Reference in New Issue
Block a user