init
vllm_vacc/vllm/v1/engine/processor.py (new file, 184 lines)
@@ -0,0 +1,184 @@
import time
from collections.abc import Mapping
from typing import Any, Literal, Optional, Union

import torch

from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer)
from vllm.v1.structured_output.backend_outlines import (
    validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)

logger = init_logger(__name__)


class Processor:

    def process_inputs(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
        data_parallel_rank: Optional[int] = None,
    ) -> tuple[Optional[str], EngineCoreRequest]:
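        """Validate and preprocess a single request.

        Returns the prompt text (when available) together with the
        fully-populated EngineCoreRequest handed to the engine core.
        """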

        # TODO(woosuk): Support pooling models.
        self._validate_lora(lora_request)
        self._validate_params(params)
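
        # A caller may pin this request to a specific data-parallel rank;
        # reject values outside [0, data_parallel_size) up front.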
        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
                                                   data_parallel_size):
            raise ValueError(f"data_parallel_rank {data_parallel_rank} "
                             f"is out of range [0, {data_parallel_size}).")

        if arrival_time is None:
            arrival_time = time.time()

        # Optionally generate multimodal hash overrides to avoid hashing
        # multimodal data items by their content as their identifiers.

        # NOTE: when users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # request id-modality-index as multimodal hash overrides.
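        # (Illustrative only: such an override would look roughly like
        # f"{request_id}-{modality}-{index}"; the exact format is defined by
        # self._maybe_build_mm_uuids.)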
        if (self.model_config.multimodal_config and
                self.model_config.multimodal_config.mm_processor_cache_gb == 0
                and not self.cache_config.enable_prefix_caching):
            mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
        else:
            # Otherwise, use user-provided uuids as multimodal hash overrides
            # if provided.
            self._validate_multi_modal_uuids(prompt)
            if isinstance(prompt, dict):
                mm_uuids = prompt.get("multi_modal_uuids")
            else:
                mm_uuids = None

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
        # 2. For multimodal models with a merged preprocessor, preprocess
        #    multimodal data and expand prompt token ids accordingly.
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
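        # Let the current platform validate the request; it may raise for
        # configurations the hardware backend cannot serve.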
        from vllm.platforms import current_platform
        current_platform.validate_request(
            prompt=prompt,
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id()

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
        self._validate_model_inputs(encoder_inputs, decoder_inputs)

        # Mypy does not always properly infer the types of some elements of
        # discriminated unions of TypedDicts, because of how it handles
        # inheritance of TypedDict. If we explicitly extract the items we want
        # we can avoid type errors from using `dict.get` later in the method.
        prompt_str: Optional[str] = None if decoder_inputs[
            "type"] == "embeds" else decoder_inputs.get("prompt")
        prompt_token_ids = decoder_inputs[
            "prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None
        prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[
            "type"] == "embeds" else None
        deepstack_input_embeds = decoder_inputs[
            "deepstack_input_embeds"] if decoder_inputs[
                "type"] == "embeds" else None

        # deepstack_input_embeds passed via LLM.generate() may arrive as a
        # dict of per-layer tensors; stack them into a single tensor before
        # building the request.
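        # (For example, a dict of N tensors each shaped
        # (num_tokens, hidden_size) becomes one tensor shaped
        # (N, num_tokens, hidden_size).)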
        if isinstance(deepstack_input_embeds, dict):
            all_tensors = []
            for key in deepstack_input_embeds:
                if isinstance(deepstack_input_embeds[key], torch.Tensor):
                    all_tensors.append(
                        deepstack_input_embeds[key].unsqueeze(0))
            if len(all_tensors) > 0:
                deepstack_input_embeds = torch.concatenate(all_tensors, 0)

        sampling_params = None
        pooling_params = None
        if isinstance(params, SamplingParams):
            # TODO: can we avoid cloning here in multiproc case?
            sampling_params = params.clone()
            # If unset max tokens, then generate up to the max_model_len.
            if sampling_params.max_tokens is None:
                seq_len = length_from_prompt_token_ids_or_embeds(
                    prompt_token_ids, prompt_embeds)
                sampling_params.max_tokens = \
                    self.model_config.max_model_len - seq_len
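                # (e.g. max_model_len=4096 with a 1000-token prompt leaves
                # max_tokens=3096.)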
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id)
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()

        # Multimodal related.
        mm_features: Optional[list[MultiModalFeatureSpec]] = None

        if decoder_inputs["type"] == "multimodal":
            decoder_mm_inputs = decoder_inputs["mm_kwargs"]
            decoder_mm_positions = decoder_inputs["mm_placeholders"]
            decoder_mm_hashes = decoder_inputs["mm_hashes"]

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's
            # position in the input sequence.
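            # (Illustratively: placeholders at offsets
            # {"image": [32], "audio": [8]} sort to
            # [("audio", 0), ("image", 0)].)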
            sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

            mm_features = []
            for modality, idx in sorted_mm_idxs:
                mm_features.append(
                    MultiModalFeatureSpec(
                        data=decoder_mm_inputs[modality][idx],
                        modality=modality,
                        identifier=decoder_mm_hashes[modality][idx],
                        mm_position=decoder_mm_positions[modality][idx]))

        return prompt_str, EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=prompt_token_ids,
            prompt_embeds=prompt_embeds,
            deepstack_input_embeds=deepstack_input_embeds,
            mm_features=mm_features,
            sampling_params=sampling_params,
            pooling_params=pooling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
            cache_salt=decoder_inputs.get("cache_salt"),
            priority=priority,
            data_parallel_rank=data_parallel_rank,
            trace_headers=trace_headers,
        )