#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/model_runner.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import dataclasses
from typing import Any, Dict, List, Optional, Set, Type

import torch
import torch.distributed
from torch import nn
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import flatten_2d_lists, make_tensor_with_pad
from vllm.worker.model_runner import (ModelInputForGPU,
                                      ModelInputForGPUBuilder,
                                      ModelInputForGPUWithSamplingMetadata,
                                      ModelRunner)

logger = init_logger(__name__)

# LoRA rank used for the dummy adapters registered during profile_run.
LORA_WARMUP_RANK = 8


class ModelInputForNPUBuilder(ModelInputForGPUBuilder):
    """Build ModelInputForGPU from SequenceGroupMetadata."""

    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and create on-device tensors.

        The layout is chosen from the first sequence group only:

        * prefill (``is_prompt``): per-group token and position lists are
          zero-padded to a common length with ``make_tensor_with_pad`` and
          flattened, and every entry of ``seq_lens`` is replaced by
          ``max(seq_lens)``;
        * decode: tokens and positions are flattened without padding.

        No graph padding is applied on NPU (``cuda_graph_pad_size`` is -1).

        Returns:
            A ``self.model_input_cls`` instance. If there are no sequence
            groups, an empty ``self.model_input_cls()`` is returned.
        """
        # Combine and flatten intermediate data.
        input_tokens = [
            flatten_2d_lists(inter_data.input_tokens)
            for inter_data in self.inter_data_list
        ]
        if not input_tokens:
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        mrope_input_positions: Optional[List[List[int]]] = None
        if any(inter_data.mrope_input_positions is not None
               for inter_data in self.inter_data_list):
            # One flattened position list per M-RoPE section (3 sections).
            mrope_input_positions = [[] for _ in range(3)]
            # calculate max position length for padding
            input_position_lens = [
                len(inter_data.input_positions[0])
                for inter_data in self.inter_data_list
            ]
            max_pos_len = max(input_position_lens)

            for idx in range(3):
                for inter_data in self.inter_data_list:
                    msections = inter_data.mrope_input_positions
                    if msections is None:
                        # Groups without M-RoPE data reuse their 1-D
                        # positions for every section. The padding mutates
                        # inter_data.input_positions in place.
                        for _seq_input_positions in inter_data.input_positions:
                            # zero pad
                            _seq_input_positions.extend(
                                [0] * (max_pos_len - len(_seq_input_positions)))
                            mrope_input_positions[idx].extend(
                                _seq_input_positions)
                    else:
                        for _seq_mrope_input_positions in msections:
                            # zero pad
                            _seq_mrope_input_positions[idx].extend(
                                [0] * (max_pos_len -
                                       len(_seq_mrope_input_positions[idx])))
                            mrope_input_positions[idx].extend(
                                _seq_mrope_input_positions[idx])
            input_positions = None
        else:
            input_positions = [
                flatten_2d_lists(inter_data.input_positions)
                for inter_data in self.inter_data_list
            ]

        # Sequence and query lengths, flattened across all groups.
        seq_lens = flatten_2d_lists(
            [inter_data.seq_lens for inter_data in self.inter_data_list])
        query_lens = flatten_2d_lists(
            [inter_data.query_lens for inter_data in self.inter_data_list])
        # Mapping from request IDs to sequence IDs. Used for Jamba models
        # that manages the cache by itself.
        request_ids_to_seq_ids = {
            data.request_id: data.seq_ids
            for data in self.inter_data_list
        }

        batch_size = len(input_tokens)

        # Graph padding is not used on NPU. With a negative pad size every
        # ``[x] * cuda_graph_pad_size`` extension below is empty; the value
        # is also forwarded to attn_metadata_builder.build unchanged.
        cuda_graph_pad_size = -1

        if self.inter_data_list[0].is_prompt:
            # Prefill: zero-pad the per-group lists to a common length and
            # flatten them.
            input_tokens_tensor = make_tensor_with_pad(
                input_tokens, 0, dtype=torch.int, device=self.runner.device)
            input_tokens_tensor = torch.flatten(input_tokens_tensor)
            if mrope_input_positions is not None:
                mrope_input_positions_tensor = make_tensor_with_pad(
                    mrope_input_positions,
                    0,
                    dtype=torch.int,
                    device=self.runner.device)
                # Convert the existing tensor with ``.to`` instead of
                # re-wrapping it with ``torch.tensor(...)``, which PyTorch
                # discourages for tensor inputs.
                input_positions_tensor = mrope_input_positions_tensor.to(
                    dtype=torch.long)
            else:
                input_positions_tensor = make_tensor_with_pad(
                    input_positions,
                    0,
                    dtype=torch.int,
                    device=self.runner.device)
                input_positions_tensor = torch.flatten(input_positions_tensor)

            max_seq_len = max(seq_lens)
            seq_lens = len(seq_lens) * [max_seq_len]
        else:
            # Decode: flatten without padding.
            input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens),
                                               dtype=torch.long,
                                               device=self.runner.device)
            if mrope_input_positions is not None:
                input_positions_tensor = torch.tensor(
                    mrope_input_positions,
                    dtype=torch.long,
                    device=self.runner.device)
            else:
                input_positions_tensor = torch.tensor(
                    flatten_2d_lists(input_positions),
                    dtype=torch.long,
                    device=self.runner.device)

        # Sequence and query lengths (no-op while cuda_graph_pad_size < 0).
        seq_lens.extend([1] * cuda_graph_pad_size)

        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(
            seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # LoRA data.
        lora_requests: Set[LoRARequest] = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = {
                r
                for data in self.inter_data_list for r in data.lora_requests
            }
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_index_mapping)
                for inter_data in self.inter_data_list
            ])
            lora_index_mapping.extend([0] * cuda_graph_pad_size)
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_prompt_mapping)
                for inter_data in self.inter_data_list
            ])
            lora_mapping = LoRAMapping(
                index_mapping=lora_index_mapping,
                prompt_mapping=lora_prompt_mapping,
                is_prefill=not self.decode_only,
            )

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = {
                data.prompt_adapter_request
                for data in self.inter_data_list
                if data.prompt_adapter_request is not None
            }
            prompt_adapter_index_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_index_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
            prompt_adapter_prompt_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_prompt_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data.
        multi_modal_kwargs_list = [
            data.multi_modal_kwargs for data in self.inter_data_list
            if data.multi_modal_kwargs is not None
        ]
        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)

    # Note: ideally we would be using a dataclass(kw_only=True)
    # here, so that this can be subclassed easily,
    # but kw_only is not supported in python<3.10.
    class InterDataForSeqGroup:
        """Intermediate data for the current sequence group.

        ``__init__`` doubles as a re-initializer: when called with
        ``reinit=True`` on an already-initialized instance, it clears or
        overwrites the existing per-sequence lists in place instead of
        allocating new ones.
        """

        def simple_reinit(self):
            """Reset per-sequence state, touching only index 0.

            ``__init__`` calls this only when ``len(self.seq_ids) == 1`` and
            ``reinit_use_defaults`` is set.
            """
            self.input_tokens[0].clear()  # type: ignore
            self.input_positions[0].clear()  # type: ignore
            self.token_types[0].clear()  # type: ignore
            self.mrope_input_positions = None  # type: ignore
            self.seq_lens[0] = 0  # type: ignore
            self.orig_seq_lens[0] = 0  # type: ignore
            self.query_lens[0] = 0  # type: ignore
            self.context_lens[0] = 0  # type: ignore
            self.curr_sliding_window_blocks[0] = 0  # type: ignore
            self.lora_index_mapping.clear()  # type: ignore
            self.lora_prompt_mapping.clear()  # type: ignore
            self.lora_requests.clear()  # type: ignore
            self.prompt_adapter_index_mapping.clear()  # type: ignore
            self.prompt_adapter_prompt_mapping.clear()  # type: ignore

        def __init__(
            self,
            *,
            # From sequence group metadata.
            request_id: str,
            seq_ids: List[int],
            is_prompt: bool,
            block_tables: Optional[Dict[int, List[int]]],
            computed_block_nums: List[int],
            n_seqs: int = 0,

            # Input tokens and positions.
            input_tokens: Optional[List[List[int]]] = None,
            input_positions: Optional[List[List[int]]] = None,
            token_types: Optional[List[List[int]]] = None,
            mrope_input_positions: Optional[List[List[List[int]]]] = None,

            # The sequence length (may be capped to the sliding window).
            seq_lens: Optional[List[int]] = None,
            # The original sequence length (before applying sliding window).
            # This is used to compute slot mapping.
            orig_seq_lens: Optional[List[int]] = None,
            # The query length.
            query_lens: Optional[List[int]] = None,
            # The number of tokens that are already computed.
            context_lens: Optional[List[int]] = None,
            # The current sliding window block.
            curr_sliding_window_blocks: Optional[List[int]] = None,

            # LoRA inputs.
            lora_index_mapping: Optional[List[List[int]]] = None,
            lora_prompt_mapping: Optional[List[List[int]]] = None,
            lora_requests: Optional[Set[LoRARequest]] = None,

            # Prompt adapter inputs.
            prompt_adapter_index_mapping: Optional[List[int]] = None,
            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,

            # Multi-modal inputs.
            multi_modal_kwargs: Optional[MultiModalKwargs] = None,
            multi_modal_placeholder_maps: Optional[Dict[
                str, MultiModalPlaceholderMap]] = None,

            # Whether the prefix cache is hit (prefill only).
            prefix_cache_hit: bool = False,
            reinit: bool = False,
            reinit_use_defaults: bool = False,
            encoder_seq_len: int = 0,
        ):
            """Initialize, or re-initialize in place when ``reinit`` is True.

            With ``reinit=True`` the instance must already be initialized
            with the same number of ``seq_ids`` (asserted). Falsy list
            arguments then clear or zero the existing per-sequence entries.
            With ``reinit=False``, ``__post_init__`` runs last and replaces
            the token/position/length lists and the two LoRA mapping lists
            with fresh empty containers, so values passed for those
            arguments are discarded.
            """
            if reinit:
                assert len(self.seq_ids) == len(seq_ids)  # type: ignore
                for i, seq_id in enumerate(seq_ids):
                    self.seq_ids[i] = seq_id  # type: ignore
            else:
                self.seq_ids = seq_ids

            self.request_id = request_id
            self.is_prompt = is_prompt
            self.block_tables = block_tables
            self.computed_block_nums = computed_block_nums
            self.n_seqs = n_seqs
            self.encoder_seq_len = encoder_seq_len

            if reinit:
                if len(self.seq_ids) == 1 and reinit_use_defaults:
                    self.simple_reinit()
                else:
                    if input_tokens:
                        self.input_tokens = input_tokens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.input_tokens[seq_id].clear()

                    if input_positions:
                        self.input_positions = input_positions
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.input_positions[seq_id].clear()

                    if token_types:
                        self.token_types = token_types
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.token_types[seq_id].clear()

                    # The mrope_input_positions argument is ignored on reinit.
                    self.mrope_input_positions = None

                    if seq_lens:
                        self.seq_lens = seq_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.seq_lens[seq_id] = 0

                    if orig_seq_lens:
                        self.orig_seq_lens = orig_seq_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.orig_seq_lens[seq_id] = 0

                    if query_lens:
                        self.query_lens = query_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.query_lens[seq_id] = 0

                    if context_lens:
                        self.context_lens = context_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.context_lens[seq_id] = 0

                    if curr_sliding_window_blocks:
                        self.curr_sliding_window_blocks = \
                            curr_sliding_window_blocks
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.curr_sliding_window_blocks[seq_id] = 0

                    if lora_index_mapping:
                        self.lora_index_mapping = lora_index_mapping
                    else:
                        self.lora_index_mapping.clear()

                    if lora_prompt_mapping:
                        self.lora_prompt_mapping = lora_prompt_mapping
                    else:
                        self.lora_prompt_mapping.clear()

                    if lora_requests:
                        self.lora_requests = lora_requests
                    else:
                        self.lora_requests.clear()

                    if prompt_adapter_index_mapping:
                        self.prompt_adapter_index_mapping = \
                            prompt_adapter_index_mapping
                    else:
                        self.prompt_adapter_index_mapping.clear()

                    if prompt_adapter_prompt_mapping:
                        self.prompt_adapter_prompt_mapping = \
                            prompt_adapter_prompt_mapping
                    else:
                        self.prompt_adapter_prompt_mapping.clear()

            else:
                self.input_tokens = input_tokens or []
                self.input_positions = input_positions or []
                self.token_types = token_types or []
                self.mrope_input_positions = mrope_input_positions or None
                self.seq_lens = seq_lens or []
                self.orig_seq_lens = orig_seq_lens or []
                self.query_lens = query_lens or []
                self.context_lens = context_lens or []
                self.curr_sliding_window_blocks = \
                    curr_sliding_window_blocks or []

                self.lora_index_mapping = lora_index_mapping or []
                self.lora_prompt_mapping = lora_prompt_mapping or []
                self.lora_requests = lora_requests or set()

                self.prompt_adapter_index_mapping = (
                    prompt_adapter_index_mapping or [])
                self.prompt_adapter_prompt_mapping = (
                    prompt_adapter_prompt_mapping or [])

            self.prompt_adapter_request = prompt_adapter_request
            self.multi_modal_kwargs = multi_modal_kwargs
            self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
            self.prefix_cache_hit = prefix_cache_hit

            self.n_seqs = len(self.seq_ids)

            if not reinit:
                self.__post_init__()

        def __post_init__(self):
            """Allocate fresh per-sequence containers sized by ``seq_ids``."""
            self.n_seqs = len(self.seq_ids)

            self.input_tokens = [[] for _ in range(self.n_seqs)]
            self.input_positions = [[] for _ in range(self.n_seqs)]
            self.token_types = [[] for _ in range(self.n_seqs)]
            self.mrope_input_positions = None
            self.seq_lens = [0] * self.n_seqs
            self.orig_seq_lens = [0] * self.n_seqs
            self.query_lens = [0] * self.n_seqs
            self.context_lens = [0] * self.n_seqs
            self.curr_sliding_window_blocks = [0] * self.n_seqs

            self.lora_index_mapping = []
            self.lora_prompt_mapping = []


class NPUModelRunner(ModelRunner):
    """
    NPU model runner with sampling step.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Delegates to ``ModelInputForGPUWithSamplingMetadata`` using this
        runner's attention backend.
        """
        model_input = \
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
        return model_input

    @current_platform.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on dummy prefill inputs to profile memory.

        Builds ``max_num_seqs`` dummy prompt sequence groups whose lengths
        sum to ``max_num_batched_tokens``. The group count is reduced for
        multi-modal models, and dummy LoRA adapters are registered when LoRA
        is configured. The model is then executed with empty KV caches and
        the platform is synchronized.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        # This represents the maximum number of different requests
        # that will have unique loras, and therefore the max amount of memory
        # consumption. Create dummy lora request copies from the lora request
        # passed in, which contains a lora from the lora warmup path.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                # Assign the dummy adapters to sequences round-robin.
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional device memory may be needed for multi-modal encoding,
        # which needs to be accounted for when calculating the device blocks
        # for the vLLM block manager.
        # To exercise the worst scenario for device memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.

        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            # Spread max_num_batched_tokens as evenly as possible; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_data.multi_modal_data,
                multi_modal_placeholders=dummy_data.multi_modal_placeholders,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # use an empty tensor instead of `None` to force Dynamo to pass
        # it by reference, rather by specializing on the value `None`.
        # the `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        # it is important to create tensors inside the loop, rather than
        # multiplying the list, to avoid Dynamo from treating them as
        # tensor aliasing.
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        self.execute_model(model_input, kv_caches, intermediate_tensors)
        current_platform.synchronize()
        return

    @current_platform.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """NPU graph capture.

        TODO: not supported yet; this method is intentionally a no-op.
        """

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        Note: ``ModelInputForNPUBuilder.build`` picks the padded-prefill or
        decode tensor layout from the first sequence group only, and applies
        no graph padding (NPU graph capture is not supported yet).

        Sampling metadata is built only on the last pipeline-parallel rank;
        other ranks get ``sampling_metadata=None``.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        if get_pp_group().is_last_rank:
            # Sampling metadata is only required for the final pp group
            generators = self.get_generators(finished_requests_ids)
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list,
                model_input.seq_lens,
                model_input.query_lens,
                self.device,
                self.pin_memory,
                generators,
                self.sampling_metadata_cache,
                # TODO (cmq): enable this after supported in vllm
                # pad_for_invariant_seq_len=True,
            )
        else:
            sampling_metadata = None
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    def get_model(self) -> nn.Module:
        """Return the loaded model module."""
        return self.model