# NOTE: extraction metadata (was: "185 lines / 8.1 KiB / Python", duplicated)
|
|
|
|
import time
|
|
from collections.abc import Mapping
|
|
from typing import Any, Literal, Optional, Union
|
|
|
|
import torch
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
|
from vllm.inputs.parse import split_enc_dec_inputs
|
|
from vllm.inputs.preprocess import InputPreprocessor
|
|
from vllm.logger import init_logger
|
|
from vllm.lora.request import LoRARequest
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
|
from vllm.multimodal.cache import processor_cache_from_config
|
|
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
|
|
from vllm.multimodal.processing import EncDecMultiModalProcessor
|
|
from vllm.multimodal.utils import argsort_mm_positions
|
|
from vllm.pooling_params import PoolingParams
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
|
from vllm.v1.engine import EngineCoreRequest
|
|
from vllm.v1.structured_output.backend_guidance import (
|
|
validate_guidance_grammar)
|
|
from vllm.v1.structured_output.backend_lm_format_enforcer import (
|
|
validate_structured_output_request_lm_format_enforcer)
|
|
from vllm.v1.structured_output.backend_outlines import (
|
|
validate_structured_output_request_outlines)
|
|
from vllm.v1.structured_output.backend_xgrammar import (
|
|
validate_xgrammar_grammar)
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class Processor:
|
|
|
|
    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:
        """Validate and preprocess a prompt into an ``EngineCoreRequest``.

        This is the front-door conversion from user-facing inputs to the
        engine-core request format: it validates the LoRA request and
        sampling/pooling params, tokenizes (and, for multimodal models,
        preprocesses) the prompt, and assembles the final request object.

        Args:
            request_id: Unique identifier for this request.
            prompt: The user prompt (text, token ids, embeds, or a dict with
                multimodal data).
            params: Either ``SamplingParams`` (generation) or
                ``PoolingParams`` (pooling); exactly one of the two is
                forwarded on the resulting request.
            arrival_time: Request arrival timestamp; defaults to
                ``time.time()`` when not provided.
            lora_request: Optional LoRA adapter to apply during tokenization.
            tokenization_kwargs: Extra kwargs forwarded to the tokenizer via
                the input preprocessor.
            trace_headers: Optional tracing headers propagated onto the
                request.
            priority: Scheduling priority (passed through unchanged).
            data_parallel_rank: Optional DP rank to pin the request to; must
                lie in ``[0, data_parallel_size)`` when given.

        Returns:
            A tuple ``(prompt_str, request)`` where ``prompt_str`` is the
            decoded prompt text (``None`` for embeds inputs or when absent)
            and ``request`` is the fully populated ``EngineCoreRequest``.

        Raises:
            ValueError: If ``data_parallel_rank`` is out of range (other
                validation helpers may raise as well).
        """

        # TODO(woosuk): Support pooling models.
        self._validate_lora(lora_request)
        self._validate_params(params)

        # Reject an explicit DP rank that does not exist in this deployment.
        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        # Optionally generate multimodal hash overrides to avoid hashing
        # multimodal data items by their content as their identifiers.

        # NOTE: when users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # request id-modality-index as multimodal hash overrides.
        if (self.model_config.multimodal_config and
                self.model_config.multimodal_config.mm_processor_cache_gb == 0
                and not self.cache_config.enable_prefix_caching):
            mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
        else:
            # Otherwise, use user-provided uuids as multimodal hash overrides
            # if provided.
            self._validate_multi_modal_uuids(prompt)
            if isinstance(prompt, dict):
                mm_uuids = prompt.get("multi_modal_uuids")
            else:
                mm_uuids = None

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #    multimodal data and expand prompt token ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        # Platform-specific validation runs on the *processed* inputs, so it
        # must come after preprocessing. Imported lazily to avoid a hard
        # dependency at module import time.
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id()

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
        self._validate_model_inputs(encoder_inputs, decoder_inputs)

        # Mypy does not always properly infer the types of some elements of
        # discriminated unions of TypedDicts, because of how it handles
        # inheritance of TypedDict. If we explicitly extract the items we want
        # we can avoid type errors from using `dict.get` later in the method.
        # The "embeds" variant carries embeddings instead of text/token ids,
        # so each field below is populated only for the variant that has it.
        prompt_str: Optional[str] = None if decoder_inputs[
            "type"] == "embeds" else decoder_inputs.get("prompt")
        prompt_token_ids = decoder_inputs[
            "prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None
        prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[
            "type"] == "embeds" else None
        deepstack_input_embeds = decoder_inputs["deepstack_input_embeds"] if decoder_inputs[
            "type"] == "embeds" else None

        # for deepstack_input_embeds in llm.generate method
        # When supplied as a mapping of tensors, stack them into one tensor
        # along a new leading dim (one slot per entry).
        # NOTE(review): entries that are not tensors are silently skipped, and
        # if no tensor values are found the dict is passed through unchanged —
        # presumably intentional; confirm against downstream consumers.
        if isinstance(deepstack_input_embeds, dict):
            all_tensors = []
            for key in deepstack_input_embeds:
                if isinstance(deepstack_input_embeds[key], torch.Tensor):
                    all_tensors.append(deepstack_input_embeds[key].unsqueeze(0))
            if len(all_tensors) > 0:
                deepstack_input_embeds = torch.concatenate(all_tensors, 0)

        # Exactly one of sampling_params / pooling_params is set on the
        # resulting request, depending on the type of `params`.
        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                seq_len = length_from_prompt_token_ids_or_embeds(
                    prompt_token_ids, prompt_embeds)
                sampling_params.max_tokens = \
                    self.model_config.max_model_len - seq_len
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id)
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()

        # Multimodal related.
        mm_features: Optional[list[MultiModalFeatureSpec]] = None

        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            mm_features = []
            for modality, idx in sorted_mm_idxs:
                mm_features.append(
                    MultiModalFeatureSpec(
                        data=decoder_mm_inputs[modality][idx],
                        modality=modality,
                        identifier=decoder_mm_hashes[modality][idx],
                        mm_position=decoder_mm_positions[modality][idx]))

        return prompt_str, EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=prompt_token_ids,
            prompt_embeds=prompt_embeds,
            deepstack_input_embeds=deepstack_input_embeds,
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
        )
|