Files
2026-01-19 10:38:50 +08:00

78 lines
2.5 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, TypeVar
from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
@abstractmethod
def pre_process(
self,
prompt: IOProcessorInput,
request_id: str | None = None,
**kwargs,
) -> PromptType | Sequence[PromptType]:
raise NotImplementedError
async def pre_process_async(
self,
prompt: IOProcessorInput,
request_id: str | None = None,
**kwargs,
) -> PromptType | Sequence[PromptType]:
return self.pre_process(prompt, request_id, **kwargs)
@abstractmethod
def post_process(
self,
model_output: Sequence[PoolingRequestOutput],
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
raise NotImplementedError
async def post_process_async(
self,
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
# We cannot guarantee outputs are returned in the same order they were
# fed to vLLM.
# Let's sort them by id before post_processing
sorted_output = sorted(
[(i, item) async for i, item in model_output], key=lambda output: output[0]
)
collected_output = [output[1] for output in sorted_output]
return self.post_process(collected_output, request_id, **kwargs)
@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
def validate_or_generate_params(
self, params: SamplingParams | PoolingParams | None = None
) -> SamplingParams | PoolingParams:
return params or PoolingParams()
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
raise NotImplementedError