import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalInputs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                           SequenceGroupMetadata)
from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
                                      ModelInputForGPUBuilder)

logger = init_logger(__name__)


@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """Used by the EmbeddingModelRunner."""
    pooling_metadata: Optional["PoolingMetadata"] = None


class EmbeddingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    """Model runner for embedding models.

    Executes the model without a real KV cache (empty placeholder tensors
    are passed instead) and applies the model's pooler to the final hidden
    states, returning ``PoolerOutput``s rather than sampled tokens.
    """

    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        observability_config: Optional[ObservabilityConfig] = None,
    ):
        super().__init__(model_config,
                         parallel_config,
                         scheduler_config,
                         device_config,
                         cache_config,
                         load_config,
                         lora_config=lora_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker,
                         prompt_adapter_config=prompt_adapter_config,
                         observability_config=observability_config)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        if num_steps > 1:
            raise ValueError(
                "EmbeddingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently, CUDA graphs are only supported in the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # Use an empty tensor instead of ``None`` to force Dynamo to pass
        # it by reference, rather than by specializing on the value ``None``.
        # The ``dtype`` argument does not matter, and we use ``float32`` as
        # a placeholder (it has wide hardware support).
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_start = torch.cuda.Event(enable_timing=True)
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()

        with set_forward_context(model_input.attn_metadata):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                             device=self.device))

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end.record()

        # Only perform pooling in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            if (self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)
                    and self.observability_config is not None
                    and self.observability_config.collect_model_forward_time):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_or_intermediate_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithPoolingMetadata:
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            pooling_params = seq_group_metadata.pooling_params
            seq_groups.append((seq_ids, pooling_params))

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        pooling_metadata = PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )

        return pooling_metadata