# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
|
||
|
|
from collections.abc import Mapping
|
||
|
|
from typing import Any, overload
|
||
|
|
|
||
|
|
from typing_extensions import assert_never
|
||
|
|
|
||
|
|
from vllm.config import VllmConfig
|
||
|
|
from vllm.inputs.data import build_enc_dec_inputs
|
||
|
|
from vllm.logger import init_logger
|
||
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||
|
|
from vllm.multimodal.inputs import (
|
||
|
|
MultiModalDataDict,
|
||
|
|
MultiModalInputs,
|
||
|
|
MultiModalUUIDDict,
|
||
|
|
)
|
||
|
|
from vllm.renderers import BaseRenderer, renderer_from_config
|
||
|
|
from vllm.renderers.inputs import (
|
||
|
|
DecoderDictPrompt,
|
||
|
|
DecoderOnlyDictPrompt,
|
||
|
|
EncoderDecoderDictPrompt,
|
||
|
|
EncoderDictPrompt,
|
||
|
|
SingletonDictPrompt,
|
||
|
|
)
|
||
|
|
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
|
||
|
|
from vllm.tokenizers import TokenizerLike
|
||
|
|
|
||
|
|
from .data import (
|
||
|
|
DecoderInputs,
|
||
|
|
DecoderOnlyInputs,
|
||
|
|
EmbedsInputs,
|
||
|
|
EmbedsPrompt,
|
||
|
|
EncoderDecoderInputs,
|
||
|
|
EncoderInputs,
|
||
|
|
ProcessorInputs,
|
||
|
|
PromptType,
|
||
|
|
SingletonInputs,
|
||
|
|
TextPrompt,
|
||
|
|
TokenInputs,
|
||
|
|
TokensPrompt,
|
||
|
|
token_inputs,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Module-level logger, one per module per the vLLM logging convention.
logger = init_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class InputPreprocessor:
    """Turn raw user prompts into engine-ready ``ProcessorInputs``.

    Dispatches between decoder-only and encoder/decoder preprocessing and
    delegates tokenization and multi-modal processing to the renderer.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        renderer: BaseRenderer | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        """Create a preprocessor bound to ``vllm_config``.

        Args:
            vllm_config: Engine-wide configuration; only ``model_config``
                is read directly here.
            renderer: Prompt renderer to delegate to. When not supplied,
                one is built from the config.
            mm_registry: Registry used for multi-modal processing.
        """
        super().__init__()

        # Fall back to a renderer derived from the config when none is given.
        self.renderer = renderer or renderer_from_config(vllm_config)
        self.model_config = vllm_config.model_config
        self.mm_registry = mm_registry
@property
|
||
|
|
def tokenizer(self) -> TokenizerLike | None:
|
||
|
|
return self.renderer.tokenizer
|
||
|
|
|
||
|
|
def get_tokenizer(self) -> TokenizerLike:
|
||
|
|
return self.renderer.get_tokenizer()
|
||
|
|
|
||
|
|
def _tokenize_prompt(
|
||
|
|
self,
|
||
|
|
prompt: str,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
) -> list[int]:
|
||
|
|
"""
|
||
|
|
Apply the model's tokenizer to a text prompt, returning the
|
||
|
|
corresponding token IDs.
|
||
|
|
"""
|
||
|
|
renderer = self.renderer
|
||
|
|
|
||
|
|
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
|
||
|
|
**(tokenization_kwargs or {})
|
||
|
|
)
|
||
|
|
|
||
|
|
tok_prompt = renderer._tokenize_singleton_prompt(
|
||
|
|
TextPrompt(prompt=prompt),
|
||
|
|
tok_params,
|
||
|
|
)
|
||
|
|
|
||
|
|
return tok_prompt["prompt_token_ids"]
|
||
|
|
|
||
|
|
def _process_multimodal(
|
||
|
|
self,
|
||
|
|
prompt: str | list[int],
|
||
|
|
mm_data: MultiModalDataDict,
|
||
|
|
mm_processor_kwargs: Mapping[str, object] | None = None,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
*,
|
||
|
|
mm_uuids: MultiModalUUIDDict | None = None,
|
||
|
|
) -> MultiModalInputs:
|
||
|
|
"""
|
||
|
|
Apply the model's multi-modal processor to a multi-modal prompt,
|
||
|
|
returning the corresponding token IDs and metadata.
|
||
|
|
"""
|
||
|
|
return self.renderer._process_multimodal(
|
||
|
|
prompt,
|
||
|
|
mm_data,
|
||
|
|
mm_uuids=mm_uuids,
|
||
|
|
mm_processor_kwargs=mm_processor_kwargs,
|
||
|
|
tokenization_kwargs=tokenization_kwargs,
|
||
|
|
)
|
||
|
|
|
||
|
|
def _process_embeds(
|
||
|
|
self,
|
||
|
|
parsed_content: EmbedsPrompt,
|
||
|
|
) -> EmbedsInputs:
|
||
|
|
return self.renderer._process_embeds(parsed_content)
|
||
|
|
|
||
|
|
def _truncate_inputs(
|
||
|
|
self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
|
||
|
|
) -> list[int]:
|
||
|
|
renderer = self.renderer
|
||
|
|
|
||
|
|
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
|
||
|
|
**(tokenization_kwargs or {})
|
||
|
|
)
|
||
|
|
|
||
|
|
tok_prompt = renderer._tokenize_singleton_prompt(
|
||
|
|
TokensPrompt(prompt_token_ids=inputs),
|
||
|
|
tok_params,
|
||
|
|
)
|
||
|
|
|
||
|
|
return tok_prompt["prompt_token_ids"]
|
||
|
|
|
||
|
|
def _process_tokens(
|
||
|
|
self,
|
||
|
|
parsed_content: TokensPrompt,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
) -> TokenInputs | MultiModalInputs:
|
||
|
|
prompt_token_ids = self._truncate_inputs(
|
||
|
|
parsed_content["prompt_token_ids"], tokenization_kwargs
|
||
|
|
)
|
||
|
|
|
||
|
|
inputs: TokenInputs | MultiModalInputs
|
||
|
|
if multi_modal_data := parsed_content.get("multi_modal_data"):
|
||
|
|
inputs = self._process_multimodal(
|
||
|
|
prompt_token_ids,
|
||
|
|
multi_modal_data,
|
||
|
|
parsed_content.get("mm_processor_kwargs"),
|
||
|
|
tokenization_kwargs=tokenization_kwargs,
|
||
|
|
mm_uuids=parsed_content.get("multi_modal_uuids"),
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
inputs = token_inputs(prompt_token_ids)
|
||
|
|
|
||
|
|
if prompt_text := parsed_content.get("prompt"):
|
||
|
|
inputs["prompt"] = prompt_text
|
||
|
|
if cache_salt := parsed_content.get("cache_salt"):
|
||
|
|
inputs["cache_salt"] = cache_salt
|
||
|
|
|
||
|
|
return inputs
|
||
|
|
|
||
|
|
def _process_text(
|
||
|
|
self,
|
||
|
|
parsed_content: TextPrompt,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
) -> TokenInputs | MultiModalInputs:
|
||
|
|
prompt_text = parsed_content["prompt"]
|
||
|
|
|
||
|
|
inputs: TokenInputs | MultiModalInputs
|
||
|
|
if multi_modal_data := parsed_content.get("multi_modal_data"):
|
||
|
|
inputs = self._process_multimodal(
|
||
|
|
prompt_text,
|
||
|
|
multi_modal_data,
|
||
|
|
parsed_content.get("mm_processor_kwargs") or {},
|
||
|
|
tokenization_kwargs=tokenization_kwargs,
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
prompt_token_ids = self._tokenize_prompt(
|
||
|
|
prompt_text,
|
||
|
|
tokenization_kwargs=tokenization_kwargs,
|
||
|
|
)
|
||
|
|
inputs = token_inputs(prompt_token_ids)
|
||
|
|
|
||
|
|
inputs["prompt"] = prompt_text
|
||
|
|
|
||
|
|
if cache_salt := parsed_content.get("cache_salt"):
|
||
|
|
inputs["cache_salt"] = cache_salt
|
||
|
|
|
||
|
|
return inputs
|
||
|
|
|
||
|
|
@overload
|
||
|
|
def _prompt_to_llm_inputs(
|
||
|
|
self,
|
||
|
|
prompt: EncoderDictPrompt,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
) -> EncoderInputs: ...
|
||
|
|
|
||
|
|
@overload
|
||
|
|
def _prompt_to_llm_inputs( # type: ignore[misc]
|
||
|
|
self,
|
||
|
|
prompt: DecoderDictPrompt,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
) -> DecoderInputs: ...
|
||
|
|
|
||
|
|
@overload
|
||
|
|
def _prompt_to_llm_inputs( # type: ignore[misc]
|
||
|
|
self,
|
||
|
|
prompt: DecoderOnlyDictPrompt,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
) -> DecoderOnlyInputs: ...
|
||
|
|
|
||
|
|
def _prompt_to_llm_inputs(
|
||
|
|
self,
|
||
|
|
prompt: SingletonDictPrompt,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
) -> SingletonInputs:
|
||
|
|
"""
|
||
|
|
Extract the singleton inputs from a prompt.
|
||
|
|
|
||
|
|
Arguments:
|
||
|
|
|
||
|
|
* prompt: single encoder or decoder input prompt
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
|
||
|
|
* [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
|
||
|
|
"""
|
||
|
|
if "prompt_embeds" in prompt:
|
||
|
|
return self._process_embeds(prompt) # type: ignore[arg-type]
|
||
|
|
|
||
|
|
if "prompt_token_ids" in prompt:
|
||
|
|
return self._process_tokens(prompt) # type: ignore[arg-type]
|
||
|
|
|
||
|
|
if "prompt" in prompt:
|
||
|
|
return self._process_text(
|
||
|
|
prompt, # type: ignore[arg-type]
|
||
|
|
tokenization_kwargs=tokenization_kwargs,
|
||
|
|
)
|
||
|
|
|
||
|
|
assert_never(prompt) # type: ignore[arg-type]
|
||
|
|
|
||
|
|
def _process_encoder_decoder_prompt(
|
||
|
|
self,
|
||
|
|
prompt: EncoderDecoderDictPrompt,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
) -> EncoderDecoderInputs:
|
||
|
|
"""
|
||
|
|
For encoder/decoder models only:
|
||
|
|
Process an input prompt into an
|
||
|
|
[`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
|
||
|
|
instance.
|
||
|
|
|
||
|
|
Arguments:
|
||
|
|
|
||
|
|
* prompt: an input prompt
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
|
||
|
|
* [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
|
||
|
|
instance
|
||
|
|
"""
|
||
|
|
encoder_prompt = prompt["encoder_prompt"]
|
||
|
|
decoder_prompt = prompt["decoder_prompt"]
|
||
|
|
|
||
|
|
return build_enc_dec_inputs(
|
||
|
|
encoder_inputs=self._prompt_to_llm_inputs(
|
||
|
|
encoder_prompt,
|
||
|
|
tokenization_kwargs=tokenization_kwargs,
|
||
|
|
),
|
||
|
|
decoder_inputs=(
|
||
|
|
None
|
||
|
|
if decoder_prompt is None
|
||
|
|
else self._prompt_to_llm_inputs(
|
||
|
|
decoder_prompt,
|
||
|
|
tokenization_kwargs=tokenization_kwargs,
|
||
|
|
)
|
||
|
|
),
|
||
|
|
decoder_start_token_id=self.renderer.get_dec_start_token_id(),
|
||
|
|
)
|
||
|
|
|
||
|
|
def _process_decoder_only_prompt(
|
||
|
|
self,
|
||
|
|
prompt: DecoderOnlyDictPrompt,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
) -> DecoderOnlyInputs:
|
||
|
|
"""
|
||
|
|
For decoder-only models:
|
||
|
|
Process an input prompt into a
|
||
|
|
[`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.
|
||
|
|
|
||
|
|
Arguments:
|
||
|
|
|
||
|
|
* prompt: input prompt
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
|
||
|
|
* [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
|
||
|
|
"""
|
||
|
|
return self._prompt_to_llm_inputs(
|
||
|
|
prompt,
|
||
|
|
tokenization_kwargs=tokenization_kwargs,
|
||
|
|
)
|
||
|
|
|
||
|
|
def preprocess(
|
||
|
|
self,
|
||
|
|
prompt: PromptType,
|
||
|
|
tokenization_kwargs: dict[str, Any] | None = None,
|
||
|
|
) -> ProcessorInputs:
|
||
|
|
"""Preprocess the input prompt."""
|
||
|
|
if self.model_config.is_encoder_decoder:
|
||
|
|
# Encoder-decoder model requires special mapping of
|
||
|
|
# input prompts to encoder & decoder.
|
||
|
|
return self._process_encoder_decoder_prompt(
|
||
|
|
parse_enc_dec_prompt(prompt),
|
||
|
|
tokenization_kwargs,
|
||
|
|
)
|
||
|
|
|
||
|
|
return self._process_decoder_only_prompt(
|
||
|
|
parse_dec_only_prompt(prompt),
|
||
|
|
tokenization_kwargs=tokenization_kwargs,
|
||
|
|
)
|