# SPDX-License-Identifier: Apache-2.0 import bisect import time from typing import TYPE_CHECKING, Optional, cast from unittest.mock import patch import numpy as np import torch import torch.distributed import torch.nn as nn # TPU XLA related import torch_xla.core.xla_model as xm import torch_xla.runtime as xr import vllm.envs as envs from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, PallasMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput, SamplerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from .utils import sanity_check_mm_encoder_outputs if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput logger = init_logger(__name__) # Here we utilize the behavior that out-of-bound index is ignored. # FIXME(woosuk): Find a more reliable way to prevent possible bugs. 
_PAD_SLOT_ID = 1_000_000_000
# Marker for "no token" entries in the sampler output; filtered out in
# `execute_model` when more than one token per request is returned.
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8


class TPUModelRunner:
    """Model runner for TPU (XLA) devices.

    Keeps a persistent host-side batch of request state, builds padded
    inputs for each scheduler step, runs the `torch.compile`d (openxla)
    model and samples one token per request.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache config sub-objects and allocate host staging buffers.

        The staging buffers are sized by `max_num_batched_tokens`. The
        token-count padding buckets are also computed here.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.device_config = vllm_config.device_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
        self.enforce_eager = model_config.enforce_eager

        self.num_xla_graphs = 0
        self._update_num_xla_graphs("init")

        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # Overwritten with the real model output dtype in `_dummy_run`.
        self._hidden_states_dtype = self.dtype

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        # InputBatch needs to work with sampling tensors greater than padding
        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
        self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        # TODO: Support M-RoPE (e.g, Qwen2-VL)
        assert not self.uses_mrope, "TPU does not support M-RoPE yet."

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        # Cached torch/numpy tensor
        # The pytorch tensor and numpy array share the same buffer.
        # Sometimes the numpy op is faster so we create both.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu")
        self.input_ids_np = self.input_ids_cpu.numpy()

        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu")
        self.positions_np = self.positions_cpu.numpy()

        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int64,
                                            device="cpu")
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.block_table_cpu = torch.zeros(
            (self.max_num_tokens, self.max_num_blocks_per_req),
            dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
            device="cpu")

        self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Range tensor with values [0 .. self.max_num_tokens - 1].
        # Used to initialize positions / context_lens / seq_lens
        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
        self.num_tokens_paddings = _get_paddings(
            min_token_size=16,
            max_token_size=self.max_num_tokens,
            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)

    def _update_num_xla_graphs(self, case_str):
        """Record XLA graphs compiled since the last call.

        This is a no-op unless recompilation checking is enabled and eager
        mode is off. `case_str` is only used in the log message.
        """
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        total_cached_graphs = xr.get_num_cached_compilation_graph()
        new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
        if new_compiled_graphs == 0:
            return

        logger.info("Add new %d compiled XLA graphs due to %s",
                    new_compiled_graphs, case_str)
        self.num_xla_graphs += new_compiled_graphs

    def _verify_num_xla_graphs(self, case_str):
        """Assert that no XLA graph was compiled since the last recorded count.

        This is a no-op unless recompilation checking is enabled and eager
        mode is off.
        """
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:
            return

        curr_cached_graph = xr.get_num_cached_compilation_graph()
        assert self.num_xla_graphs == curr_cached_graph, (
            "Recompilation after warm up is detected during {}."
            " num_xla_graphs = {} curr_cached_graph = {}".format(
                case_str, self.num_xla_graphs, curr_cached_graph))

    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input device tensors for the model.

        Returns:
            True if there is a new/resumed/paused/finished request, i.e. the
            set of requests in the persistent batch changed.
            If False, we can skip copying SamplingMetadata to the device.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)

        # Capture this before `removed_req_indices` is consumed below. It
        # covers finished and unscheduled removals as well as additions.
        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)
        return batch_changed

    def get_model(self) -> nn.Module:
        """Return the compiled model wrapper installed by `load_model`."""
        # `self.model` is only assigned in `load_model`, so use getattr to make
        # the assertion meaningful instead of raising AttributeError.
        model = getattr(self, "model", None)
        assert model is not None, "load_model() must be called first."
        return model

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.
        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.
        Raises:
            NotImplementedError: for encoder-decoder attention layers.
            ValueError: for an unrecognized attention type.
        """

        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in forward_ctx.items():
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                if attn_module.sliding_window is not None:
                    kv_cache_spec[layer_name] = SlidingWindowSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=attn_module.dtype,
                        sliding_window=attn_module.sliding_window,
                        use_mla=False,
                    )
                else:
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=attn_module.dtype,
                        use_mla=False,
                    )
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec

    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
        """Build the padded device inputs for one scheduler step.

        Side effects: sets `self.input_ids` and `self.position_ids`. Both are
        padded to a bucket from `self.num_tokens_paddings`.

        Returns:
            A tuple `(attn_metadata, logits_indices)`. `logits_indices` holds
            the index of each request's last scheduled token. It is padded to
            a power-of-two request count capped at `max_num_reqs`.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Get the number of scheduled tokens for each request.
        num_scheduled_tokens_per_req = []
        max_num_scheduled_tokens_all_reqs = 0
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_scheduled_tokens_per_req.append(num_tokens)
            max_num_scheduled_tokens_all_reqs = max(
                max_num_scheduled_tokens_all_reqs, num_tokens)
        num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
                                                dtype=np.int32)
        assert max_num_scheduled_tokens_all_reqs > 0

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        # For each scheduled token, what are the corresponding req index.
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens_per_req)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # For each scheduled token, what is its position in corresponding req.
        arange = np.concatenate(
            [self.arange_np[:n] for n in num_scheduled_tokens_per_req])

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        # Look up each token's block number in the flattened block table,
        # whose row stride is max_num_blocks_per_req.
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens_per_req,
                  out=self.query_start_loc_np[1:num_reqs + 1])
        # Padded entries are set to 1, so the padded `logits_indices` computed
        # below as `query_start_loc - 1` point at token 0.
        self.query_start_loc_np[num_reqs + 1:] = 1

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens_per_req)

        # Do the padding and copy the tensors to the TPU.
        num_padded_tokens = _get_padded_token_len(self.num_tokens_paddings,
                                                  total_num_scheduled_tokens)
        # Zero out to avoid spurious values from prev iteration (last cp chunk)
        self.input_ids_cpu[total_num_scheduled_tokens:num_padded_tokens] = 0
        self.input_ids = self.input_ids_cpu[:num_padded_tokens].to(
            self.device)
        self.position_ids = self.positions_cpu[:num_padded_tokens].to(
            self.device)
        # Padding tokens get an out-of-range slot. This relies on out-of-bound
        # slot indices being ignored (see the note on _PAD_SLOT_ID).
        self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
        slot_mapping = self.slot_mapping_cpu[:num_padded_tokens].to(
            self.device)
        block_tables = self.block_table_cpu[:self.max_num_reqs]
        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
            self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
        block_tables = block_tables.to(self.device)
        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
            self.device)
        seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)

        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=seq_lens,
            query_start_loc=query_start_loc,
            num_seqs=torch.tensor([num_reqs],
                                  dtype=torch.int32,
                                  device=self.device),
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
            num_reqs, self.max_num_reqs)
        # Indices at which we sample (positions of last token in the sequence).
        # Padded to avoid recompiling when `num_reqs` varies.
        logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
        logits_indices = logits_indices.to(self.device)
        return attn_metadata, logits_indices

    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
        """Run the multimodal encoder for this step's scheduled inputs.

        Outputs are stored in `self.encoder_cache[req_id][input_id]`.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_inputs: list[MultiModalKwargs] = []
        req_input_ids: list[tuple[str, int]] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[input_id])
                req_input_ids.append((req_id, input_id))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

            # Iterating yields one output per item in either case above.
            encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs.
        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
            self.encoder_cache.setdefault(req_id, {})[input_id] = output

    def _gather_encoder_outputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> list[torch.Tensor]:
        """Collect cached encoder-output slices needed by this step.

        Each returned slice overlaps the token range scheduled for its
        request in this step.
        """
        encoder_outputs: list[torch.Tensor] = []
        for req_id in self.input_batch.req_ids:
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info["offset"]
                num_encoder_tokens = pos_info["length"]

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
                encoder_outputs.append(encoder_output[start_idx:end_idx])
        return encoder_outputs

    @torch.no_grad()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> ModelRunnerOutput:
        """Run one scheduler step: update state, run the model, sample.

        `intermediate_tensors` is accepted for interface compatibility and is
        not used here. Returns `EMPTY_MODEL_RUNNER_OUTPUT` when no tokens are
        scheduled.
        """
        # Update cached state
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT

        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_encoder(scheduler_output)
            encoder_outputs = self._gather_encoder_outputs(scheduler_output)
        else:
            encoder_outputs = []

        # Prepare inputs
        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            if encoder_outputs:
                inputs_embeds = self.model.get_input_embeddings(
                    self.input_ids, encoder_outputs)
            else:
                inputs_embeds = self.model.get_input_embeddings(self.input_ids)
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the compiled graph.
            input_ids = self.input_ids
            inputs_embeds = None
        num_reqs = self.input_batch.num_reqs

        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
        # are copied to device in chunks of pre-compiled padded shape to
        # avoid recompilations.
        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
            from_input_batch(self.input_batch, logits_indices)
        # Run the decoder
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=self.position_ids,
                kv_caches=self.kv_caches,
                inputs_embeds=inputs_embeds,
            )
        selected_token_ids = self.model.sample_from_hidden(
            hidden_states, tpu_sampling_metadata)
        # Remove padding on cpu and keep dynamic op outside of xla graph.
        selected_token_ids = selected_token_ids.cpu()[:num_reqs]

        # Update the cache state concurrently. Code above will not block until
        # we use `selected_token_ids`. Add mark_step if post-processing changes
        request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
        discard_sampled_tokens_req_indices = []
        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
            assert req_id is not None
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                request_seq_lens.append((i, req_state, seq_len))
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)

                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        assert all(
            req_id is not None for req_id in
            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
        req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])

        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
        for req_id in self.input_batch.req_ids[:num_reqs]:
            prompt_logprobs_dict[req_id] = None

        max_gen_len = selected_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_sampled_token_ids = selected_token_ids.tolist()

            # Mask out the sampled tokens that should not be sampled.
            # TODO: Keep in sync with gpu_model_runner.py, in particular
            #       the "else" case here
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()

            # Append sampled tokens
            for i, req_state, seq_len in request_seq_lens:
                token_id = valid_sampled_token_ids[i][0]
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                req_state.output_token_ids.append(token_id)
                self.input_batch.num_tokens[i] += 1

        else:
            valid_mask = selected_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            valid_sampled_token_ids = [
                seq.tolist()
                for seq in selected_token_ids[valid_mask].split(gen_lens)
            ]
            self.input_batch.num_tokens[:num_reqs] += gen_lens
            for i, req_state, seq_len in request_seq_lens:
                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
                self.input_batch.token_ids_cpu[
                    i, target_slice] = valid_sampled_token_ids[i]
                req_state.output_token_ids.extend(valid_sampled_token_ids[i])

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )

        # Check there are no new graphs compiled - all the graphs should be
        # captured and compiled during warm up.
        self._verify_num_xla_graphs("execute_model")

        return model_runner_output

    def load_model(self) -> None:
        """Load the model, wrap it in ModelWrapperV1 and compile it.

        Compilation uses `torch.compile` with the openxla backend and
        `fullgraph=True`, `dynamic=False`.
        """
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally assigned
        # by the xm runtime. Therefore, there is a mismatch in the rank
        # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
        # This is not a problem in linear layers because all-reduce is
        # rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(vllm_config=self.vllm_config)
        model = model.eval()
        xm.mark_step()
        xm.wait_device_ops()
        model = ModelWrapperV1(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)

    @torch.no_grad()
    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
        """Trace the model once for `num_tokens` with zero-filled inputs.

        Also records the model's output dtype in `self._hidden_states_dtype`.
        """
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
                                        dtype=self.dtype,
                                        device=self.device)
        else:
            input_ids = torch.zeros((num_tokens),
                                    dtype=torch.int32,
                                    device=self.device)
            inputs_embeds = None
        actual_num_reqs = min(num_tokens, self.max_num_reqs)
        position_ids = torch.zeros(num_tokens,
                                   dtype=torch.int32,
                                   device=self.device)
        slot_mapping = torch.zeros(num_tokens,
                                   dtype=torch.int64,
                                   device=self.device)
        block_tables = torch.zeros(
            (self.max_num_reqs, self.block_table_cpu.shape[1]),
            dtype=torch.int32,
            device=self.device)
        query_lens = [1] * self.max_num_reqs
        query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                                    dtype=torch.int32),
                                       dim=0,
                                       dtype=torch.int32).to(self.device)
        context_lens = torch.ones((self.max_num_reqs, ),
                                  dtype=torch.int32,
                                  device=self.device)
        num_seqs = torch.tensor([actual_num_reqs],
                                dtype=torch.int32,
                                device=self.device)
        attn_metadata = PallasMetadata(
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
            query_start_loc=query_start_loc,
            num_seqs=num_seqs,
        )

        if self.is_multimodal_model:
            torch._dynamo.mark_dynamic(inputs_embeds, 0)
        else:
            torch._dynamo.mark_dynamic(input_ids, 0)
        torch._dynamo.mark_dynamic(position_ids, 0)
        torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)

        with set_forward_context(attn_metadata, self.vllm_config, 0):
            out = self.model(input_ids=input_ids,
                             positions=position_ids,
                             kv_caches=kv_caches,
                             inputs_embeds=inputs_embeds)
        self._hidden_states_dtype = out.dtype

    def capture_model(self) -> None:
        """Compile the model.

        Traces the forward pass for every token bucket in
        `self.num_tokens_paddings`. It then traces `sample_from_hidden` for
        every (token bucket, padded request count) pair that
        `_prepare_inputs` can produce.
        """

        logger.info("Compiling the model with different input shapes.")

        start = time.perf_counter()
        for num_tokens in self.num_tokens_paddings:
            logger.info(" -- num_tokens: %d", num_tokens)
            self._dummy_run(self.kv_caches, num_tokens)
            xm.mark_step()
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("model")

        logger.info("Compiling sampling with different input shapes.")
        start = time.perf_counter()
        hsize = self.model_config.get_hidden_size()
        device = self.device
        # Compile sampling step for different model+sampler outputs in bucketed
        # n_tokens x max_num_reqs. Graph is really small so this is fine.
        for num_tokens in self.num_tokens_paddings:
            num_reqs_to_sample = MIN_NUM_SEQS
            dummy_hidden = torch.randn((num_tokens, hsize),
                                       device=device,
                                       dtype=self._hidden_states_dtype)
            # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
            while True:
                indices = torch.zeros(
                    num_reqs_to_sample,
                    dtype=torch.int32,
                    device=device,
                )
                xm.mark_step()
                sampling_meta = TPUSupportedSamplingMetadata.\
                    from_input_batch(self.input_batch, indices)
                logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                            num_reqs_to_sample)
                out = self.model.sample_from_hidden(dummy_hidden,
                                                    sampling_meta)
                out = out.cpu()
                # Requests can't be more than tokens. But do compile for the
                # next bigger value in case num_tokens uses bucketed padding.
                if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
                    break
                # Make sure to compile the `max_num_reqs` upper-limit case
                num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
                    num_reqs_to_sample + 1, self.max_num_reqs)
        xm.wait_device_ops()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)
        self._update_num_xla_graphs("sampling")

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.
        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer
        Raises:
            NotImplementedError: for more than one KV cache group, or for a
            spec that is not a FullAttentionSpec.
        """
        if len(kv_cache_config.kv_cache_groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: dict[str, torch.Tensor] = {}

        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_config = kv_cache_config.tensors[layer_name]
                assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
                if isinstance(kv_cache_spec, FullAttentionSpec):
                    kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype

                    tpu_kv_cache = torch.zeros(kv_cache_shape,
                                               dtype=dtype,
                                               device=self.device)

                    kv_caches[layer_name] = tpu_kv_cache
                else:
                    raise NotImplementedError

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)


class ModelWrapperV1(nn.Module):
    """Wrap a loaded model together with a TPU sampler.

    The forward pass and sampling are exposed as separate entry points
    (`forward` and `sample_from_hidden`), so each can be traced on its own.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.sampler = TPUSampler()

    def sample(
            self, logits: torch.Tensor,
            sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
        """Run the TPU sampler on `logits`."""
        sampler_out = self.sampler(logits, sampling_metadata)
        return sampler_out

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: list[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Executes the forward pass of the model.

        Args:
            input_ids: The input token IDs of shape [num_tokens].
            positions: The input position IDs of shape [num_tokens].
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization. Not forwarded to the
                wrapped model here.
            inputs_embeds: The input embeddings of shape [num_tokens,
                hidden_size]. It is used for multimodal models.
        """

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def sample_from_hidden(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: TPUSupportedSamplingMetadata,
    ) -> torch.Tensor:
        """
        Sample with xla-friendly function. This function is to be traced
        separately from `forward` for lighter compilation overhead.
        """
        # Tensor `sample_hidden_states` is of fixed pre-compiled size.
        sample_hidden_states = \
            hidden_states[sampling_metadata.indices_do_sample]
        logits = self.compute_logits(sample_hidden_states)
        # Optimized greedy sampling branch, tracing both paths in a single pass
        # NOTE all_greedy is a scalar, this is just an optimized if/else.
        out_tokens = torch.where(sampling_metadata.all_greedy,
                        torch.argmax(logits, dim=-1, keepdim=True),
                        self.sample(logits, sampling_metadata)\
                                            .sampled_token_ids)
        return out_tokens

    def compute_logits(self,
                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
        """Compute logits via the wrapped model (no SamplingMetadata pruning)."""
        # SamplingMetadata here for pruning output in LogitsProcessor, disabled
        logits = self.model.compute_logits(hidden_states, None)
        return logits

    def get_multimodal_embeddings(self, *args, **kwargs):
        """Delegate to the wrapped model's `get_multimodal_embeddings`."""
        return self.model.get_multimodal_embeddings(*args, **kwargs)

    def get_input_embeddings(self, *args, **kwargs):
        """Delegate to the wrapped model's `get_input_embeddings`."""
        return self.model.get_input_embeddings(*args, **kwargs)


def _get_padded_number(n: int, multiple: int) -> int:
    """Round `n` up to the nearest multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple


def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
    """Pad a request count to a compiled bucket size.

    The result is `MIN_NUM_SEQS` when `x <= MIN_NUM_SEQS`, otherwise the
    next power of two `>= x`. Either way it is capped at `upper_limit`.
    """
    res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
    return min(res, upper_limit)


def _get_paddings(min_token_size: int, max_token_size: int,
                  padding_gap: int) -> list[int]:
    """Generate an ascending list of token-count padding buckets.

    The list starts around `min_token_size`. Its last element is always
    `>= max_token_size`, so `_get_padded_token_len` can map every token
    count up to `max_token_size` to a bucket.

    If padding_gap == 0 then:
        increase 2X each time (exponential)
    else:
        first increase the size to twice (while <= padding_gap),
        then increase the padding size by padding_gap.

    Raises:
        ValueError: if `min_token_size <= 0` or `padding_gap < 0`. Either
            would otherwise make the loops below never terminate.
    """
    if min_token_size <= 0:
        raise ValueError(
            f"min_token_size must be positive, got {min_token_size}")
    if padding_gap < 0:
        raise ValueError(
            f"padding_gap must be non-negative, got {padding_gap}")

    paddings = []
    num = min_token_size

    if padding_gap == 0:
        logger.info("Using exponential paddings:")
        while True:
            logger.info(" %d", num)
            paddings.append(num)
            # Stop only once the last bucket covers max_token_size. A plain
            # `while num <= max_token_size` stops one bucket short whenever
            # max_token_size is not min_token_size * 2**k.
            if num >= max_token_size:
                break
            num *= 2
    else:
        logger.info("Using incremental paddings:")
        while num <= padding_gap:
            logger.info(" %d", num)
            paddings.append(num)
            num *= 2
        num //= 2
        while num < max_token_size:
            num += padding_gap
            logger.info(" %d", num)
            paddings.append(num)
        if not paddings:
            # Reachable only when padding_gap < min_token_size and
            # max_token_size <= min_token_size // 2. Neither loop emitted a
            # bucket, so fall back to one bucket of min_token_size.
            logger.info(" %d", min_token_size)
            paddings.append(min_token_size)
    return paddings


def _get_padded_token_len(paddings: list[int], x: int) -> int:
    """Return the first element in paddings list greater or equal to x.

    `paddings` must be sorted ascending, as produced by `_get_paddings`.
    """
    index = bisect.bisect_left(paddings, x)
    assert index < len(paddings), (
        f"Token count {x} exceeds the largest padding bucket "
        f"{paddings[-1] if paddings else None}.")
    return paddings[index]