"""Model runner that executes a vLLM model on OpenVINO devices.

Builds flattened token/position batches and OpenVINO PagedAttention metadata
(prefix-sum style ``subsequence_begins`` / ``block_indices_begins``) from
scheduler-provided ``SequenceGroupMetadata``, then runs the model and samples.
"""
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Tuple

import openvino as ov
import torch
from torch import nn

from vllm.attention import get_attn_backend
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerBase

logger = init_logger(__name__)


class ModelInput(NamedTuple):
    """Batched model inputs produced by ``_prepare_model_input``.

    Tokens/positions are flattened across all sequences in
    prefill -> decode order (see ``_prepare_model_input`` docstring).
    """
    # Flattened token ids for all sequences in the batch.
    input_tokens: torch.Tensor
    # Position index of each token within its own sequence.
    input_positions: torch.Tensor
    # None only for the empty batch (see ``empty``).
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    # Total (context + new) length per sequence.
    seq_lens: List[int]
    # Number of tokens computed in this step per sequence.
    query_lens: List[int]
    multi_modal_kwargs: BatchedTensorInputs

    @classmethod
    def empty(cls, device):
        """Return a ModelInput representing an empty batch on *device*."""
        return ModelInput(input_tokens=torch.empty(0, device=device),
                          input_positions=torch.empty(0, device=device),
                          attn_metadata=None,
                          seq_lens=[],
                          query_lens=[],
                          multi_modal_kwargs={})


class OpenVINOModelRunner(ModelRunnerBase):
    """ModelRunnerBase implementation for the OpenVINO backend."""

    def __init__(
        self,
        ov_core: ov.Core,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        """Initialize the runner.

        Args:
            ov_core: OpenVINO ``Core`` handle, forwarded to model loading.
            vllm_config: Aggregated vLLM config; ``ModelRunnerBase.__init__``
                unpacks it into ``self.model_config``, ``self.cache_config``,
                ``self.device_config``, ``self.scheduler_config``, etc.
            kv_cache_dtype: KV-cache dtype selector; "auto" defers to the
                model dtype.
            is_driver_worker: Whether this worker drives sampling/broadcast.
        """
        self.ov_core = ov_core
        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        cache_config = self.cache_config
        model_config = self.model_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.multi_modal_input_mapper = self.mm_registry \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set by load_model().

    def load_model(self) -> None:
        """Load the OpenVINO-compiled model into ``self.model``."""
        self.model = get_model(model_config=self.model_config,
                               device_config=self.device_config,
                               kv_cache_dtype=self.kv_cache_dtype,
                               ov_core=self.ov_core)

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        Also builds the prefix-sum arrays (``subsequence_begins``,
        ``block_indices_begins``) consumed by the OpenVINO attention
        metadata, and collects multi-modal kwargs/placeholder maps for
        any prefill positions that intersect multi-modal items.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []

        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        multi_modal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

        subsequence_begins: List[int] = []
        block_indices: List[int] = []
        block_indices_begins: List[int] = []

        # initialize beginning of prefix sums
        subsequence_begins.append(0)
        block_indices_begins.append(0)

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt

            for seq_id in seq_ids:
                computed_block_nums = seq_group_metadata.computed_block_nums
                # Chunked prefill + prefix caching together are rejected here.
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    computed_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    computed_len = seq_data.get_len() - 1

                # Cap this step's end position by the scheduler's chunk size.
                seq_len = min(
                    seq_data.get_len(),
                    computed_len + seq_group_metadata.token_chunk_size,
                )
                if is_prompt:
                    tokens = seq_data.get_token_ids()[computed_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                # NOTE(review): this subscript happens unconditionally, yet
                # the elif branch below guards ``block_tables is not None``
                # (memory profiling). Presumably profiling never reaches this
                # line with a None table — confirm against the profiling path.
                block_table = seq_group_metadata.block_tables[seq_id]
                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    # Skip tokens already covered by fully computed blocks.
                    computed_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[computed_len:]
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # chunked prefill or decode
                        block_table = seq_group_metadata.block_tables[seq_id]
                        if self.sliding_window is not None:
                            # chunked prefill doesn't support sliding window.
                            assert not self.scheduler_config.chunked_prefill_enabled  # noqa: E501
                            # Keep only the blocks inside the window.
                            sliding_window_blocks = (self.sliding_window //
                                                     self.block_size)
                            block_table = block_table[-sliding_window_blocks:]
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # prompt phase w/o prefix_caching, chunked_prefill
                    pass

                block_indices.extend(block_table)
                block_indices_begins.append(block_indices_begins[-1] +
                                            len(block_table))

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                if self.sliding_window is not None and not is_prompt:
                    seq_len = min(seq_len, self.sliding_window)
                    computed_len = seq_len - 1

                seq_lens.append(seq_len)

                # Tokens actually computed this step.
                query_len = seq_len - computed_len
                query_lens.append(query_len)

                input_tokens.extend(tokens)
                positions_range = range(computed_len, seq_len)
                input_positions.extend(list(positions_range))

                past_lens.append(computed_len)
                subsequence_begins.append(subsequence_begins[-1] + query_len)

                if is_prompt:
                    # Prompt sequence groups carry exactly one sequence.
                    assert len(seq_ids) == 1
                else:
                    # Decode emits exactly one token per sequence per step.
                    assert (
                        query_len == 1
                    ), "seq_len: {}, computed_len: {}, query_len: {}".format(
                        seq_len, computed_len, query_len)

                if seq_group_metadata.multi_modal_data:
                    # NOTE: mm_data only includes the subset of multi-modal
                    # items that intersect with the current prefill positions.
                    mm_data, placeholder_maps = MultiModalPlaceholderMap \
                        .from_seq_group(seq_group_metadata, positions_range)

                    # With a registered processor, mm_data is already in
                    # kwargs form; otherwise map it explicitly.
                    if self.mm_registry.has_processor(self.model_config):
                        mm_kwargs = mm_data
                    else:
                        mm_kwargs = self.multi_modal_input_mapper(
                            mm_data,
                            seq_group_metadata.mm_processor_kwargs,
                        )

                    multi_modal_kwargs_list.append(mm_kwargs)

                    for modality, placeholder_map in placeholder_maps.items():
                        multi_modal_placeholder_maps[modality].extend(
                            placeholder_map, )

        max_query_len = max(query_lens)
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        # Names below are rebound from List[int] to tensors, hence the
        # ``# type: ignore`` markers.
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore

        past_lens_tensor = torch.tensor(past_lens,
                                        dtype=torch.int32,
                                        device=self.device)  # type: ignore
        subsequence_begins_tensor = torch.tensor(
            subsequence_begins, dtype=torch.int32,
            device=self.device)  # type: ignore
        block_indices_tensor = torch.tensor(block_indices,
                                            dtype=torch.int32,
                                            device=self.device)  # type: ignore
        block_indices_begins_tensor = torch.tensor(
            block_indices_begins, dtype=torch.int32,
            device=self.device)  # type: ignore

        max_context_len = max(seq_lens)
        max_context_len_tensor = torch.tensor(
            max_context_len, dtype=torch.int32,
            device=self.device)  # type: ignore

        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            multi_modal_placeholder_maps.items()
        }

        attn_metadata = self.attn_backend.make_openvino_metadata(
            past_lens=past_lens_tensor,
            subsequence_begins=subsequence_begins_tensor,
            block_indices=block_indices_tensor,
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
        )

        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return ModelInput(
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
               SamplingMetadata, BatchedTensorInputs]:
        """Build model inputs plus SamplingMetadata for one execution step.

        Wraps ``_prepare_model_input`` and derives sampling metadata from
        the same seq/query lengths.
        """
        # Prepare input tensors.
        (
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs,
        ) = self._prepare_model_input(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens,
            self.device,
            pin_memory=False,
        )

        return (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and sample next tokens.

        Args:
            seq_group_metadata_list: Scheduled sequence groups for this step.
            kv_caches: Per-layer (key, value) OpenVINO tensors.

        Returns:
            The SamplerOutput produced by ``self.model.sample``.
        """
        (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
                                         device=self.device),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    # The broadcast-based ModelRunnerBase entry points are not used by this
    # backend; callers go through execute_model directly.
    def prepare_model_input(self, *args, **kwargs):
        raise NotImplementedError

    def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
        raise NotImplementedError