# # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # Copyright 2023 The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # This file is a part of the vllm-ascend project. # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # import gc import os from typing import TYPE_CHECKING, Dict, List, Optional, Union import numpy as np import numpy.typing as npt import torch import torch.nn as nn from vllm.attention import AttentionType from vllm.attention.layer import Attention from vllm.config import VllmConfig from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend, AscendMetadata) 
from vllm_ascend.platform import NPUPlatform

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

# Additive mask value written to attention positions that must be ignored.
NPU_PAGED_ATTENTION_MASK_VALUE = -10000


class NPUModelRunner:
    """Model runner for vLLM V1 on Ascend NPU devices.

    The runner keeps a persistent batch of request states, turns a
    scheduler output into model inputs (token ids, positions, slot mapping,
    attention mask), runs the forward pass and samples the next tokens.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        """Allocate the persistent batch and the reusable input buffers.

        Args:
            vllm_config: Full vLLM configuration; the model, LoRA, scheduler
                and cache sub-configs are read from it.
            device: Device on which the device-side buffers are allocated.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.lora_config = vllm_config.lora_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device = device
        self.is_multimodal_model = self.model_config.is_multimodal_model
        self.block_size = vllm_config.cache_config.block_size
        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                           self.block_size)
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_num_reqs = self.scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
            vllm_config.parallel_config, LayerBlockType.attention)
        self.hidden_size = self.model_config.get_hidden_size()

        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = self.model_config.uses_mrope

        (self.max_num_encoder_input_tokens,
         self.encoder_cache_size) = compute_encoder_budget(
             model_config=self.model_config,
             scheduler_config=self.scheduler_config,
             mm_registry=self.mm_registry)

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: List[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.model_config.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=True,
            vocab_size=self.model_config.get_vocab_size(),
        )

        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=True)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.model_config.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        self.arange_np: npt.NDArray[np.int32] = np.arange(max(
            self.max_num_reqs + 1, self.model_config.max_model_len,
            self.max_num_tokens),
                                                          dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=True)
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=True)
        self.positions_np = self.positions_cpu.numpy()

        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=True)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=True)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        self.input_positions_cpu = torch.arange(0,
                                                self.max_num_tokens,
                                                device="cpu")

        # NOTE: Pre-construct a mask matrix to improve the efficiency of
        # attention mask construction during inference.
        # Note that the length of the matrix needs to be carefully balanced: a
        # matrix that is too large will consume excessive VRAM, while a matrix
        # that is too small will require dynamic concatenation during inference,
        # leading to performance degradation.
        # Therefore, an environment variable is added here to dynamically set
        # the size of the pre-constructed mask matrix based on requirements.
        mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
        self.attn_mask_len = min(self.model_config.max_model_len,
                                 int(mask_len))
        self.attn_mask_npu = torch.full(
            (self.attn_mask_len, self.attn_mask_len),
            NPU_PAGED_ATTENTION_MASK_VALUE,
            device=self.device,
            dtype=self.vllm_config.model_config.dtype)
        # Zero the lower triangle (diagonal included) so that only positions
        # above the diagonal keep the mask value.
        self.attn_mask_npu.masked_fill_(
            self.attn_mask_npu.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The SamplingMetadata is updated and copied to the NPU if there is a
        new/resumed/paused/finished request in the batch.

        Args:
            scheduler_output: Scheduler decision for the current step.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: List[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: List[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            num_computed_tokens = req_data.num_computed_tokens
            req_state.num_computed_tokens = num_computed_tokens
            # Add the sampled token(s) from the previous step (if any).
            # This doesn't include "unverified" tokens like spec decode tokens.
            num_new_tokens = (num_computed_tokens +
                              len(req_data.new_token_ids) -
                              req_state.num_tokens)
            if num_new_tokens == 1:
                # Avoid slicing list in most common case.
                req_state.output_token_ids.append(req_data.new_token_ids[-1])
            elif num_new_tokens > 0:
                req_state.output_token_ids.extend(
                    req_data.new_token_ids[-num_new_tokens:])
            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)
            # Add new_token_ids to token_ids_cpu.
            start_token_index = num_computed_tokens
            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
            self.input_batch.token_ids_cpu[
                req_index,
                start_token_index:end_token_index] = req_data.new_token_ids
            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
            if spec_token_ids:
                start_index = end_token_index
                end_token_index += len(spec_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            self.input_batch.num_tokens[req_index] = end_token_index

        # Check if the batch has changed. If not, we can skip copying the
        # sampling metadata from CPU to GPU.
        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        if batch_changed:
            self.input_batch.refresh_sampling_metadata()

    def get_model(self) -> nn.Module:
        """Return the model stored in `self.model` (assigned by `load_model`)."""
        return self.model

    def _make_attention_mask(self, seq_lens, query_lens,
                             position) -> torch.Tensor:
        """Build the attention mask for the scheduled tokens.

        When the longest sequence fits in the pre-built mask, the rows for
        the given positions are selected from it. Otherwise a mask is built
        on the CPU request by request and copied to the device.

        Args:
            seq_lens: Per-request total sequence lengths (any iterable of
                integer-like values).
            query_lens: Per-request numbers of query tokens, in the same
                order as `seq_lens`.
            position: Index tensor used to select rows of the pre-built mask.

        Returns:
            A mask with one row per query token and one column per key
            position up to the longest sequence; ignored positions hold
            `NPU_PAGED_ATTENTION_MASK_VALUE`, the others hold 0.
        """
        # Plain Python ints keep the shape and slice arithmetic below
        # independent of whether tensors or numpy arrays were passed in.
        seq_len_list = [int(n) for n in seq_lens]
        query_len_list = [int(n) for n in query_lens]

        max_seq_len = max(seq_len_list, default=0)
        if max_seq_len <= self.attn_mask_len:
            return torch.index_select(self.attn_mask_npu,
                                      dim=0,
                                      index=position)[:, :max_seq_len]

        total_q_len = sum(query_len_list)
        attn_mask = torch.zeros((total_q_len, max_seq_len),
                                dtype=self.vllm_config.model_config.dtype,
                                device="cpu")

        current_row = 0
        for seq_len, q_len in zip(seq_len_list, query_len_list):
            context_len = seq_len - q_len
            assert context_len >= 0
            rows = slice(current_row, current_row + q_len)
            # Mask everything after the already-computed context ...
            attn_mask[rows, context_len:] = NPU_PAGED_ATTENTION_MASK_VALUE
            # ... then re-open the lower triangle of the query block. The
            # slice is a view, so the in-place fill updates `attn_mask`.
            right_tensor = attn_mask[rows, context_len:seq_len]
            right_tensor.masked_fill_(
                right_tensor.tril() == NPU_PAGED_ATTENTION_MASK_VALUE, 0)
            current_row += q_len

        return attn_mask.to(self.device, non_blocking=True)

    def _process_reqs(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Prepare the model inputs for this step and run the forward pass.

        Args:
            scheduler_output: Scheduler decision for the current step; must
                schedule at least one token.
            intermediate_tensors: Forwarded unchanged to the model call.

        Returns:
            The hidden states at the last scheduled token of each request.
        """
        # Check input valid
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Copy the blocks from CPU to NPU.
        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        # TODO: The Python loop can be slow. Optimize.
        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
        max_num_scheduled_tokens = 0
        for i, req_id in enumerate(self.input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_scheduled_tokens[i] = num_tokens
            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                           num_tokens)

        # Prepare positions
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        self.positions[:total_num_scheduled_tokens].copy_(
            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
        positions = self.positions[:total_num_scheduled_tokens]

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)
        seq_lens = self.seq_lens_cpu[:num_reqs]

        query_lens = torch.from_numpy(num_scheduled_tokens)

        # Map every scheduled token to its slot: block number from the
        # flattened block table, plus the offset inside that block.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])
        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
            self.device, non_blocking=True)

        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
                                              query_lens=num_scheduled_tokens,
                                              position=positions)

        # NOTE(review): `seq_lens` receives the per-request query lengths and
        # `context_lens` the total sequence lengths - confirm against the
        # field semantics of AscendMetadata.
        attn_metadata = AscendMetadata(
            seq_lens=query_lens,
            context_lens=seq_lens,
            slot_mapping=slot_mapping,
            block_tables=(
                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
            attn_mask=attn_mask,
        )

        # Prepare input_ids
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
        # Copy the tensors to the NPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        input_ids = self.input_ids[:total_num_scheduled_tokens]

        # Run forward pass
        with set_forward_context(attn_metadata, self.vllm_config):
            assert self.model is not None
            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=None,
            )

        return hidden_states[cu_num_tokens - 1]

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, torch.Tensor]:
        """Run one scheduling step: update states, forward pass, sampling.

        Args:
            scheduler_output: Scheduler decision for the current step.
            intermediate_tensors: Forwarded to the forward pass.

        Returns:
            `EMPTY_MODEL_RUNNER_OUTPUT` when no token is scheduled, otherwise
            a `ModelRunnerOutput` with the sampled token ids and logprobs.

        Raises:
            NotImplementedError: If the sampler returns more than one token
                per request; only single-token outputs are handled here.
        """
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOuptut if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT

        hidden_states = self._process_reqs(scheduler_output,
                                           intermediate_tensors)
        logits = self.model.compute_logits(hidden_states, None)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        sampler_output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)

        # NOTE: NPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len != 1:
            # Fail with a clear message instead of an UnboundLocalError
            # further down.
            raise NotImplementedError(
                "NPUModelRunner only supports one sampled token per request, "
                f"but the sampler returned {max_gen_len} tokens per request.")
        # No spec decode tokens.
        valid_sampled_token_ids = sampled_token_ids.tolist()

        model_runner_output = ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=None,
            logprobs=logprobs_lists,
            prompt_logprobs_dict={},
        )
        return model_runner_output

    def _profile_multimodal(self) -> None:
        """Run the multimodal encoder on dummy inputs and cache its outputs.

        Does nothing for non-multimodal models or when the encoder budget or
        the encoder cache size is not positive. The dummy outputs are stored
        in `self.encoder_cache` under the key "tmp".

        Raises:
            RuntimeError: If the dummy multimodal data is not a
                `MultiModalKwargs` instance.
        """
        # TODO: handle encoder-decoder models once we support them.
        # NOTE: Currently model is profiled with a single non-text
        # modality with the max possible input tokens even when
        # it supports multiple.
        if (not self.is_multimodal_model
                or self.max_num_encoder_input_tokens <= 0
                or self.encoder_cache_size <= 0):
            return

        max_tokens_by_modality_dict = (
            MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
                self.model_config))
        dummy_data_modality, max_tokens_per_mm_item = max(
            max_tokens_by_modality_dict.items(), key=lambda item: item[1])

        # Check how many items of this modality can be supported by
        # the encoder budget.
        encoder_budget = min(self.max_num_encoder_input_tokens,
                             self.encoder_cache_size)

        max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                               max_tokens_per_mm_item)

        # Check how many items of this modality can be supported by
        # the decoder budget.
        max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
            self.model_config)[dummy_data_modality]

        # NOTE: We do not consider max_num_batched_tokens on purpose
        # because the multimodal embeddings can be generated in advance
        # and chunked prefilled.
        max_num_mm_items_decoder_budget = self.max_num_reqs * \
            max_mm_items_per_req

        max_num_mm_items = min(max_num_mm_items_encoder_budget,
                               max_num_mm_items_decoder_budget)

        logger.info(
            "Encoder cache will be initialized with a budget of %s tokens,"
            " and profiled with %s %s items of the maximum feature size.",
            encoder_budget, max_num_mm_items, dummy_data_modality)

        # Create dummy batch of multimodal inputs.
        dummy_request_data = self.input_registry.dummy_data_for_profiling(
            model_config=self.model_config,
            seq_len=self.max_num_tokens,
            mm_registry=self.mm_registry,
        )
        dummy_mm_data = dummy_request_data.multi_modal_data

        if not isinstance(dummy_mm_data, MultiModalKwargs):
            # TODO: Delete this check once input mapper is fully removed.
            raise RuntimeError("Legacy input mapper is not supported in V1")

        # Dummy data definition in V0 may contain multiple multimodal items
        # (e.g, multiple images) for a single request, therefore here we
        # always replicate first item by max_num_mm_items times since in V1
        # they are scheduled to be processed separately.
        dummy_mm_item = dummy_mm_data.get_item(modality=dummy_data_modality,
                                               item_index=0)
        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])

        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
                                                         max_num_mm_items)
        batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
            batched_dummy_mm_inputs, device=self.device)

        # Run multimodal encoder.
        dummy_encoder_outputs = self.model.get_multimodal_embeddings(
            **batched_dummy_mm_inputs)
        assert len(dummy_encoder_outputs) == max_num_mm_items, (
            "Expected dimension 0 of encoder outputs to match the number "
            f"of multimodal data items: {max_num_mm_items}, got "
            f"{len(dummy_encoder_outputs)=} instead. This is most likely "
            "due to the 'get_multimodal_embeddings' method of the model "
            "not implemented correctly.")

        # Cache the dummy encoder outputs.
        self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

    @torch.inference_mode()
    def _dummy_run(self) -> torch.Tensor:
        """Run one forward pass over `max_num_tokens` buffer entries.

        Multimodal models are fed `inputs_embeds`, other models `input_ids`.
        On non-first pipeline ranks, empty intermediate tensors are created
        on first use and reused afterwards.

        Returns:
            The hidden states returned by the model.
        """
        model = self.model
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = self.inputs_embeds[:self.max_num_tokens]
        else:
            input_ids = self.input_ids[:self.max_num_tokens]
            inputs_embeds = None

        if self.uses_mrope:
            positions = self.mrope_positions[:, :self.max_num_tokens]
        else:
            positions = self.input_positions_cpu[:self.max_num_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            if self.intermediate_tensors is None:
                self.intermediate_tensors = (
                    self.model.make_empty_intermediate_tensors(
                        batch_size=self.max_num_tokens,
                        dtype=self.model_config.dtype,
                        device=self.device))
            intermediate_tensors = IntermediateTensors({
                k: v[:self.max_num_tokens]
                for k, v in self.intermediate_tensors.items()
            })

        with set_forward_context(None, self.vllm_config):
            hidden_states = model(input_ids=input_ids,
                                  positions=positions.to(self.device),
                                  intermediate_tensors=intermediate_tensors,
                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def profile_run(self) -> None:
        """Run a dummy maximum-size step, then release the temporary state.

        Runs the multimodal profiling, a dummy forward pass over
        `max_num_tokens` tokens and (on the last pipeline rank) the logits
        computation, synchronizes the device, then clears the encoder cache
        and triggers garbage collection.
        """
        # Profile with multimodal encoder & encoder cache.
        self._profile_multimodal()

        # For profile, have maximum num_reqs and that collectively have
        # maximum num_tokens.
        num_reqs = self.scheduler_config.max_num_seqs
        num_tokens = self.max_num_tokens
        min_tokens_per_req = num_tokens // num_reqs

        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs

        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)
        logit_indices = np.cumsum(num_scheduled_tokens) - 1

        # assert self.lora_manager is not None, "LoRA is not enabled"
        # TODO: call maybe_profile_with_lora()

        # NOTE(review): these empty per-layer tensors are never passed to the
        # model in this method; they are only created and deleted below.
        dummy_kv_caches = [
            torch.tensor((), dtype=torch.float32, device=self.device)
            for _ in range(self.num_attn_layers)
        ]

        # Trigger compilation for general shape.
        hidden_states = self._dummy_run()

        if get_pp_group().is_last_rank:
            hidden_states = hidden_states[logit_indices]
            logits = self.model.compute_logits(hidden_states, None)
        else:
            logits = None

        NPUPlatform.synchronize()
        del hidden_states, logits, dummy_kv_caches
        self.encoder_cache.clear()
        gc.collect()

    def load_model(self) -> None:
        """Load the model into `self.model` and log the memory it consumed.

        Raises:
            ValueError: If a LoRA config is set; LoRA is not supported on
                NPU. The check runs before any loading work is done.
        """
        # Fail fast: reject the unsupported configuration before loading.
        if self.lora_config:
            raise ValueError("LoRA model is not supported on NPU now.")

        logger.info("Starting to load model %s...", self.model_config.model)

        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            self.model = get_model(vllm_config=self.vllm_config)

        logger.info("Loading model weights took %.4f GB",
                    m.consumed_memory / float(2**30))

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer

        Raises:
            ValueError: If a KV cache group uses a spec other than
                `FullAttentionSpec`.
        """
        kv_caches: Dict[str, torch.Tensor] = {}

        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_config = kv_cache_config.tensors[layer_name]
                assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
                # `num_blocks` is the number of blocks the model runner can use.
                # `kv_cache_config.num_blocks` is the number of blocks that
                # KVCacheManager may allocate.
                # Since different GPUs may have different number of layers and
                # different memory capacities, `num_blocks` can be different on
                # different GPUs, and `kv_cache_config.num_blocks` is set to
                # the min of all `num_blocks`. Verify it here.
                assert num_blocks >= kv_cache_config.num_blocks
                if isinstance(kv_cache_spec, FullAttentionSpec):
                    kv_cache_shape = AscendAttentionBackend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
                    kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                        dtype=dtype,
                                                        device=self.device)
                else:
                    # TODO: add new branches when introducing more types of
                    # KV cache specs.
                    raise ValueError("Unknown KV cache spec type.")

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.

        Raises:
            NotImplementedError: For encoder-decoder attention layers.
            ValueError: For an attention type not handled here.
        """

        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in forward_ctx.items():
            # FusedMoE entries in the forward context are skipped here.
            if isinstance(attn_module, FusedMoE):
                continue

            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=attn_module.dtype,
                    use_mla=use_mla)
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec