#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import gc
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from vllm.attention import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
                                                AscendMetadata)

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)


class NPUModelRunner:
    """Model runner that executes vLLM V1 scheduler output on an NPU device.

    The runner keeps a persistent `InputBatch` plus a cache of per-request
    state, builds the model inputs and `AscendMetadata` for every step,
    runs the forward pass and samples the next tokens.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        """Cache config values and pre-allocate the per-step buffers.

        Args:
            vllm_config: The full vLLM configuration; its sub-configs are
                stored as attributes of the runner.
            device: The device on which the input buffers are allocated.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config

        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: List[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # Set up speculative decoding.
        self.use_spec_decode = False
        if self.speculative_config:
            self.use_spec_decode = True
            self.rejection_sampler = RejectionSampler()
            # TODO: find a better way to check if we are using ngram.
            assert self.speculative_config.ngram_prompt_lookup_min, \
                "Currently, only ngram spec decode is supported in V1."
            if get_pp_group().is_last_rank:
                self.drafter = NgramProposer()
                # Trigger Numba JIT compilation for N-gram proposer.
                # This usually takes less than 1 second.
                self.drafter.propose(
                    np.zeros(1024, dtype=np.int32),
                    self.speculative_config.ngram_prompt_lookup_min,
                    self.speculative_config.num_speculative_tokens,
                )

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()

        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()

        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        self.input_positions_cpu = torch.arange(0,
                                                self.max_num_tokens,
                                                device="cpu")

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The SamplingMetadata is updated and copied to the NPU if there is a
        new/resumed/paused/finished request in the batch.
        """
        # Remove the finished requests from both the cached states and the
        # persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: List[int] = []
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: List[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            num_computed_tokens = req_data.num_computed_tokens
            req_state.num_computed_tokens = num_computed_tokens
            # Add the sampled token(s) from the previous step (if any).
            # This doesn't include "unverified" tokens like spec decode tokens.
            num_new_tokens = (num_computed_tokens +
                              len(req_data.new_token_ids) -
                              req_state.num_tokens)
            if num_new_tokens == 1:
                # Avoid slicing list in most common case.
                req_state.output_token_ids.append(req_data.new_token_ids[-1])
            elif num_new_tokens > 0:
                req_state.output_token_ids.extend(
                    req_data.new_token_ids[-num_new_tokens:])
            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)
            # Add new_token_ids to token_ids_cpu.
            start_token_index = num_computed_tokens
            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
            self.input_batch.token_ids_cpu[
                req_index,
                start_token_index:end_token_index] = req_data.new_token_ids
            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
            if spec_token_ids:
                start_index = end_token_index
                end_token_index += len(spec_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            self.input_batch.num_tokens[req_index] = end_token_index

        # Check if the batch has changed. If not, we can skip copying the
        # sampling metadata from CPU to GPU.
        batch_changed = bool(removed_req_indices or req_ids_to_add)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        if batch_changed:
            self.input_batch.refresh_sampling_metadata()

    def get_model(self) -> nn.Module:
        """Return the model stored by `load_model`."""
        return self.model

    @staticmethod
    def make_attention_mask(kv_dtype, kv_device, max_seq_len, seq_lens,
                            query_lens):
        """Build the additive attention mask used for paged attention.

        For request ``i`` a block with ``query_lens[i]`` rows and
        ``max_seq_len`` columns is created. The first
        ``seq_lens[i] - query_lens[i]`` columns are 0; the remaining columns
        are lower-triangular: 0 on and below the diagonal, -10000 above it.
        The per-request blocks are stacked along dimension 0.

        Args:
            kv_dtype: Torch dtype the mask is converted to.
            kv_device: Device the mask is moved to.
            max_seq_len: Number of columns of the mask.
            seq_lens: Per-request sequence lengths.
            query_lens: Per-request query lengths, indexed like `seq_lens`.

        Returns:
            A 2-D torch tensor on `kv_device` with dtype `kv_dtype`.
        """
        # Lengths may arrive as numpy or torch scalars; use plain ints for the
        # shape arithmetic below.
        max_seq_len = int(max_seq_len)
        mask_value = -10000
        # Start from an empty float64 array so that the result has the same
        # dtype and shape as concatenating onto an empty array, including
        # when there are no requests. The blocks are concatenated once at the
        # end instead of re-copying the growing mask for every request.
        blocks = [np.zeros([0, max_seq_len])]
        for i, context_length in enumerate(seq_lens):
            q_len = int(query_lens[i])
            ones_len = int(context_length) - q_len
            ones = np.ones((q_len, ones_len), dtype=np.float16)
            bias_cache = np.tril(
                np.ones((q_len, max_seq_len - ones_len), dtype=np.float16))
            bias_cache = np.concatenate((ones, bias_cache), axis=1)
            bias_cache[bias_cache == 0] = mask_value
            bias_cache[bias_cache == 1] = 0
            blocks.append(bias_cache)

        atten_mask = np.concatenate(blocks, axis=0)
        atten_mask = torch.from_numpy(atten_mask).to(kv_dtype).to(kv_device)
        return atten_mask

    def _process_reqs(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Prepare the inputs for the scheduled tokens and run the model.

        Args:
            scheduler_output: The scheduler output for this step; it must
                schedule at least one token.
            intermediate_tensors: Forwarded unchanged to the model call.

        Returns:
            The hidden states at the last scheduled token of each request.
        """
        # check input valid
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        # TODO: The Python loop can be slow. Optimize.
        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
        for i, req_id in enumerate(self.input_batch.req_ids):
            num_scheduled_tokens[i] = (
                scheduler_output.num_scheduled_tokens[req_id])

        # prepare positions
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        self.positions[:total_num_scheduled_tokens].copy_(
            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
        positions = self.positions[:total_num_scheduled_tokens]

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)
        seq_lens = self.seq_lens_cpu[:num_reqs]
        # `seq_lens_np` is the numpy view of `seq_lens_cpu`, so it holds the
        # same values without iterating a torch tensor in Python.
        seq_lens_np = self.seq_lens_np[:num_reqs]

        query_lens = torch.from_numpy(num_scheduled_tokens)

        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
            self.device, non_blocking=True)

        # num_reqs > 0 is asserted above, so the max is well defined.
        attn_mask = self.make_attention_mask(
            self.vllm_config.model_config.dtype, self.device,
            int(seq_lens_np.max()), seq_lens_np, num_scheduled_tokens)

        attn_metadata = AscendMetadata(
            seq_lens=query_lens,
            context_lens=seq_lens,
            slot_mapping=slot_mapping,
            block_tables=(
                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
            attn_mask=attn_mask,
        )

        # prepare input_ids
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
        # Copy the tensors to the NPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        input_ids = self.input_ids[:total_num_scheduled_tokens]

        # Run forward pass
        with set_forward_context(attn_metadata, self.vllm_config):
            assert self.model is not None
            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=None,
            )

        return hidden_states[cu_num_tokens - 1]

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, torch.Tensor]:
        """Run one scheduling step: update state, forward, and sample.

        Args:
            scheduler_output: The scheduler output for this step.
            intermediate_tensors: Forwarded to `_process_reqs`.

        Returns:
            A `ModelRunnerOutput` holding the valid sampled token ids and,
            if the sampler produced them, the logprobs.
        """
        self._update_states(scheduler_output)
        hidden_states = self._process_reqs(scheduler_output,
                                           intermediate_tensors)
        logits = self.model.compute_logits(hidden_states, None)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        sampler_output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)

        # NOTE: NPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_mask = sampled_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            # TODO(woosuk): Optimize this.
            valid_sampled_token_ids = [
                seq.tolist()
                for seq in sampled_token_ids[valid_mask].split(gen_lens)
            ]

        model_runner_output = ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=None,
            logprobs=logprobs_lists,
            prompt_logprobs_dict={},
        )
        return model_runner_output

    def _profile_multimodal(self) -> None:
        """Run the multimodal encoder on dummy data and cache its outputs.

        Does nothing for non-multimodal models or when the encoder budget or
        the encoder cache size is not positive. Otherwise the encoder outputs
        are stored in ``self.encoder_cache["tmp"]``.
        """
        # TODO: handle encoder-decoder models once we support them.
        # NOTE: Currently model is profiled with a single non-text
        # modality with the max possible input tokens even when
        # it supports multiple.
        if (not self.is_multimodal_model
                or self.max_num_encoder_input_tokens <= 0
                or self.encoder_cache_size <= 0):
            return

        max_tokens_by_modality_dict = (
            MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
                self.model_config))
        dummy_data_modality, max_tokens_per_mm_item = max(
            max_tokens_by_modality_dict.items(), key=lambda item: item[1])

        # Check how many items of this modality can be supported by
        # the encoder budget.
        encoder_budget = min(self.max_num_encoder_input_tokens,
                             self.encoder_cache_size)

        max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                               max_tokens_per_mm_item)

        # Check how many items of this modality can be supported by
        # the decoder budget.
        max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
            self.model_config)[dummy_data_modality]

        # NOTE: We do not consider max_num_batched_tokens on purpose
        # because the multimodal embeddings can be generated in advance
        # and chunked prefilled.
        max_num_mm_items_decoder_budget = self.max_num_reqs * \
            max_mm_items_per_req

        max_num_mm_items = min(max_num_mm_items_encoder_budget,
                               max_num_mm_items_decoder_budget)

        logger.info(
            "Encoder cache will be initialized with a budget of %s tokens,"
            " and profiled with %s %s items of the maximum feature size.",
            encoder_budget, max_num_mm_items, dummy_data_modality)

        # Create dummy batch of multimodal inputs.
        dummy_request_data = self.input_registry.dummy_data_for_profiling(
            model_config=self.model_config,
            seq_len=self.max_num_tokens,
            mm_registry=self.mm_registry,
        )
        dummy_mm_data = dummy_request_data.multi_modal_data

        if not isinstance(dummy_mm_data, MultiModalKwargs):
            # TODO: Delete this check once input mapper is fully removed.
            raise RuntimeError("Legacy input mapper is not supported in V1")

        # Dummy data definition in V0 may contain multiple multimodal items
        # (e.g, multiple images) for a single request, therefore here we
        # always replicate first item by max_num_mm_items times since in V1
        # they are scheduled to be processed separately.
        dummy_mm_item = dummy_mm_data.get_item(modality=dummy_data_modality,
                                               item_index=0)
        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])

        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
                                                         max_num_mm_items)
        batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
            batched_dummy_mm_inputs, device=self.device)

        # Run multimodal encoder.
        dummy_encoder_outputs = self.model.get_multimodal_embeddings(
            **batched_dummy_mm_inputs)
        assert len(dummy_encoder_outputs) == max_num_mm_items, (
            "Expected dimension 0 of encoder outputs to match the number "
            f"of multimodal data items: {max_num_mm_items}, got "
            f"{len(dummy_encoder_outputs)=} instead. This is most likely "
            "due to the 'get_multimodal_embeddings' method of the model "
            "not implemented correctly.")

        # Cache the dummy encoder outputs.
        self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
    ) -> torch.Tensor:
        """Run a forward pass over the first `num_tokens` buffer entries.

        Args:
            num_tokens: Number of leading entries of the pre-allocated input
                buffers to feed to the model.

        Returns:
            The hidden states returned by the model.
        """
        model = self.model
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_tokens]
        else:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = None

        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_tokens]
        else:
            positions = self.input_positions_cpu[:num_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            # Allocate the intermediate tensors once and reuse them.
            if self.intermediate_tensors is None:
                self.intermediate_tensors = (
                    self.model.make_empty_intermediate_tensors(
                        batch_size=self.max_num_tokens,
                        dtype=self.model_config.dtype,
                        device=self.device))
            intermediate_tensors = IntermediateTensors({
                k: v[:num_tokens]
                for k, v in self.intermediate_tensors.items()
            })

        with set_forward_context(None, self.vllm_config):
            hidden_states = model(input_ids=input_ids,
                                  positions=positions.to(self.device),
                                  intermediate_tensors=intermediate_tensors,
                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def profile_run(self) -> None:
        """Run the model once with the maximum number of tokens.

        The multimodal encoder is profiled first, then a dummy forward pass
        over `max_num_tokens` tokens is executed. On the last pipeline rank
        the logits are also computed. Afterwards the device is synchronized,
        the encoder cache is cleared and garbage collection is triggered.
        """
        # Profile with multimodal encoder & encoder cache.
        self._profile_multimodal()

        # For profile, have maximum num_reqs and that collectively have
        # maximum num_tokens.
        num_reqs = self.scheduler_config.max_num_seqs
        num_tokens = self.max_num_tokens
        min_tokens_per_req = num_tokens // num_reqs

        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs

        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)
        logit_indices = np.cumsum(num_scheduled_tokens) - 1

        # assert self.lora_manager is not None, "LoRA is not enabled"
        # TODO: call maybe_profile_with_lora()

        # Trigger compilation for general shape.
        hidden_states = self._dummy_run(self.max_num_tokens)

        if get_pp_group().is_last_rank:
            hidden_states = hidden_states[logit_indices]
            logits = self.model.compute_logits(hidden_states, None)
        else:
            logits = None

        current_platform.synchronize()
        del hidden_states, logits
        self.encoder_cache.clear()
        gc.collect()

    def generate_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
    ) -> list[list[int]]:
        """Propose draft tokens for every request with the n-gram drafter.

        The sampled tokens are first written into the persistent batch's
        `token_ids_cpu` so that the drafter sees them.

        Args:
            sampled_token_ids: Per-request sampled token ids, indexed like
                the persistent batch.

        Returns:
            Per-request draft token ids; an empty list for requests without
            sampled tokens or without a proposal.
        """
        # TODO(woosuk): Optimize.
        draft_token_ids: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding.
                draft_token_ids.append([])
                continue

            # Add sampled_token_ids to token_ids_cpu.
            start_idx = self.input_batch.num_tokens_no_spec[i]
            end_idx = start_idx + num_sampled_ids
            self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
            drafter_output = self.drafter.propose(
                self.input_batch.token_ids_cpu[i, :end_idx],
                self.speculative_config.ngram_prompt_lookup_min,
                self.speculative_config.num_speculative_tokens,
            )
            if drafter_output is None or len(drafter_output) == 0:
                draft_token_ids.append([])
            else:
                draft_token_ids.append(drafter_output.tolist())
        return draft_token_ids

    def load_model(self) -> None:
        """Load the model and record the memory consumed by loading it.

        Raises:
            ValueError: If a LoRA config is set; LoRA is not supported.
        """
        # Fail fast: reject the unsupported LoRA config before spending time
        # and device memory on loading the weights.
        if self.lora_config:
            raise ValueError("LoRA model is not supported on NPU now.")

        logger.info("Starting to load model %s...", self.model_config.model)

        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            self.model = get_model(vllm_config=self.vllm_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer

        Raises:
            NotImplementedError: If there is more than one KV cache group or
                a layer spec is not a `FullAttentionSpec`.
        """
        if len(kv_cache_config.groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: Dict[str, torch.Tensor] = {}

        for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
            tensor_config = kv_cache_config.tensors[layer_name]
            assert tensor_config.size % layer_spec.page_size_bytes == 0
            num_blocks = tensor_config.size // layer_spec.page_size_bytes
            if isinstance(layer_spec, FullAttentionSpec):
                kv_cache_shape = AscendAttentionBackend.get_kv_cache_shape(
                    num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
                    layer_spec.head_size)
                dtype = layer_spec.dtype
                kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                    dtype=dtype,
                                                    device=self.device)
            else:
                raise NotImplementedError(
                    f"Unsupported KV cache spec for layer {layer_name}: "
                    f"{type(layer_spec).__name__}")

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.
        """
        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: KVCacheSpec = {}
        for layer_name, attn_module in forward_ctx.items():
            if isinstance(attn_module, FusedMoE):
                continue

            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=attn_module.dtype,
                    use_mla=use_mla)
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError(
                    "Encoder-decoder attention is not supported yet.")
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec