import dataclasses
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
                    Union)

import torch
from torch import nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, PromptAdapterConfig,
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

# Sentinel slot index for tokens that must not be written to the KV cache
# (e.g. prompt tokens masked out by the sliding window).
_PAD_SLOT_ID = -1


@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
    """
    Base class contains metadata needed for the base model forward pass
    on CPU
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Flatten this input into a dict that can be broadcast to workers."""
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForCPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None
    ) -> "ModelInputForCPU":
        """Inverse of as_broadcastable_tensor_dict: rebuild the input.

        If an attention backend is given, its metadata class is
        reconstructed from the flattened dict entries first.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    """
    Used by the ModelRunner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # FIX: the previous version omitted "multi_modal_kwargs", silently
        # dropping multi-modal inputs when this variant was broadcast to
        # non-driver workers; keep the dict consistent with the base class.
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForCPUWithSamplingMetadata":
        """Rebuild the input (sampling metadata first, then attn metadata)."""
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
    """Accumulates sequence groups and builds a ModelInputForCPU batch."""

    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        super().__init__()
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.runner = runner
        # Cache frequently used runner attributes locally.
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Queue one sequence group for the next build() call."""
        self.seq_group_metadata_list.append(seq_group_metadata)
ModelInputForCPU: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = self.seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) = self._prepare_prompt( self.seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode( self.seq_group_metadata_list) seq_lens = None return self.model_input_cls( input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, multi_modal_kwargs=multi_modal_kwargs, # query_lens is not needed if chunked prefill is not # supported. Since CPU worker doesn't support chunked prefill # just use seq_lens instead. seq_lens=seq_lens, query_lens=seq_lens, ) def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, computed_len: int, mm_processor_kwargs: Dict[str, Any]): mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) # special processing for mrope position deltas. mrope_positions = None if self.runner.model_is_mrope: image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) assert image_grid_thw is not None or video_grid_thw is not None, ( "mrope embedding type requires multi-modal input mapper " "returns 'image_grid_thw' or 'video_grid_thw'.") hf_config = self.runner.model_config.hf_config token_ids = seq_data.get_token_ids() mrope_positions, mrope_position_delta = \ MRotaryEmbedding.get_input_positions( token_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, image_token_id=hf_config.image_token_id, video_token_id=hf_config.video_token_id, vision_start_token_id=hf_config.vision_start_token_id, vision_end_token_id=hf_config.vision_end_token_id, spatial_merge_size=hf_config.vision_config. 
spatial_merge_size, context_len=computed_len, ) seq_data.mrope_position_delta = mrope_position_delta return mm_kwargs, mrope_positions def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], BatchedTensorInputs]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] input_mrope_positions: List[List[int]] = [[] for _ in range(3)] slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 seq_id = seq_ids[0] seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() computed_len = seq_data.get_num_computed_tokens() seq_len = len(prompt_tokens) seq_lens.append(seq_len) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids mrope_positions = None if (mm_data := seq_group_metadata.multi_modal_data): mm_kwargs, mrope_positions = self._compute_multi_modal_input( seq_data, mm_data, computed_len, seq_group_metadata.mm_processor_kwargs) multi_modal_inputs_list.append(mm_kwargs) # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. if mrope_positions: for idx in range(3): input_mrope_positions[idx].extend(mrope_positions[idx]) else: input_positions.extend(list(range(computed_len, seq_len))) # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
start_idx = 0 if self.sliding_window is not None: start_idx = max(0, seq_len - self.sliding_window) for i in range(computed_len, seq_len): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue block_number = block_table[i // self.block_size] # type: ignore block_offset = i % self.block_size # type: ignore slot = block_number * self.block_size + block_offset slot_mapping.append(slot) if any(input_mrope_positions): input_positions = None # type: ignore else: input_mrope_positions = None # type: ignore num_prompt_tokens = len(input_tokens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) # type: ignore input_positions = torch.tensor(input_positions or input_mrope_positions, dtype=torch.long, device=self.device) # type: ignore slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) # type: ignore attn_metadata = self.attn_backend.make_metadata( is_prompt=True, seq_lens=seq_lens, seq_lens_tensor=torch.tensor([]), max_decode_seq_len=0, num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, block_tables=torch.tensor([]), slot_mapping=slot_mapping, ) multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_kwargs) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] input_mrope_positions: List[List[int]] = [[] for _ in range(3)] slot_mapping: List[int] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt assert seq_group_metadata.token_chunk_size == 1 seq_ids = list(seq_group_metadata.seq_data.keys()) for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = 
seq_data.get_last_token_id() input_tokens.append(generation_token) seq_len = seq_data.get_len() position = seq_len - 1 if seq_data.mrope_position_delta is not None: context_len = seq_data.get_num_computed_tokens() next_pos = MRotaryEmbedding.get_next_input_positions( seq_data.mrope_position_delta, context_len, seq_len, ) for idx in range(3): input_mrope_positions[idx].extend(next_pos[idx]) else: input_positions.append(position) seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append(slot) if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window // self.block_size) block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) if any(input_mrope_positions): input_positions = None # type: ignore else: input_mrope_positions = None # type: ignore max_decode_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) input_positions = torch.tensor(input_positions or input_mrope_positions, dtype=torch.long, device=self.device) slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=self.device) block_tables = make_tensor_with_pad( block_tables, pad=0, dtype=torch.int, device=self.device, ) attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_decode_seq_len=max_decode_seq_len, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), num_prefills=0, block_tables=block_tables, ) return ( input_tokens, input_positions, attn_metadata, ) class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): _model_input_cls: 
class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
    """Model runner executing the forward pass and sampling on CPU."""

    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
        ModelInputForCPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        # Currently, CPU worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.prompt_adapter_config = prompt_adapter_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.multi_modal_input_mapper = self.mm_registry \
            .create_input_mapper(self.model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model.

    @property
    def model_is_mrope(self) -> bool:
        """Detect if the model has "mrope" rope_scaling type.
        mrope requires keep "rope_deltas" between prompt and decoding phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        # FIX: HF configs spell the key "type" (legacy) or "rope_type"
        # (newer transformers releases); accept either so mrope models with
        # new-style configs are detected.
        return rope_scaling.get("type",
                                rope_scaling.get("rope_type")) == "mrope"

    def load_model(self) -> None:
        """Load the model weights onto the CPU device."""
        self.model = get_model(model_config=self.model_config,
                               load_config=self.load_config,
                               device_config=self.device_config,
                               lora_config=self.lora_config,
                               parallel_config=self.parallel_config,
                               scheduler_config=self.scheduler_config,
                               cache_config=self.cache_config)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Rebuild a model input from a dict broadcast by the driver."""
        return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict(  # noqa: E501
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Helper method to prepare the model input based on a given
        sequence group. Prepares metadata needed for the base model forward
        pass but not metadata for possible additional steps, e.g., sampling.
        """
        # weakref avoids a reference cycle between runner and builder.
        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)

        return builder.build()  # type: ignore

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group,
        including metadata for the sampling step.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     pin_memory=False,
                                                     generators=generators)

        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and (on the driver) sample the next token.

        Raises:
            ValueError: if num_steps > 1 (multi-step is unsupported on CPU).
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
            "intermediate_tensors": intermediate_tensors,
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]