diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py deleted file mode 100644 index 48c5d4b..0000000 --- a/vllm_ascend/worker/model_runner.py +++ /dev/null @@ -1,1607 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# Adapted from vllm-project/vllm/vllm/worker/model_runner.py -# - -import dataclasses -import itertools -import weakref -from contextlib import contextmanager -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, - Type, TypeVar, Union) - -import numpy as np -import torch -import torch.nn as nn -import vllm.envs as envs -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.attention.backends.utils import CommonAttentionState -from vllm.config import VllmConfig -from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import broadcast_tensor_dict, get_dp_group, get_pp_group -from vllm.distributed.kv_transfer import get_kv_transfer_group -from vllm.forward_context import set_forward_context -from vllm.inputs import INPUT_REGISTRY, InputRegistry -from vllm.logger import logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor import SamplingMetadata, SamplingMetadataCache -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, - get_sampler) -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.model_executor.models import supports_lora, supports_multimodal -from vllm.model_executor.models.utils import set_cpu_offload_max_bytes -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs, MultiModalPlaceholderMap, - MultiModalRegistry) -from vllm.prompt_adapter.layers import PromptAdapterMapping -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists, - is_pin_memory_available) -from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, - _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, - _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) - -from vllm_ascend.ascend_config import get_ascend_config - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -TModelInputForNPU = TypeVar('TModelInputForNPU', bound="ModelInputForNPU") -ENCODER_NUM = 0 -# if True, allow tensor initialization and casting with internal format 
(e.g., NZ) -torch.npu.config.allow_internal_format = True - - -@dataclass(frozen=True) -class ModelInputForNPU(ModelRunnerInputBase): - """ - This base class contains metadata needed for the base model forward pass - but not metadata for possible additional steps, e.g., sampling. Model - runners that run additional steps should subclass this method to add - additional fields. - """ - input_tokens: Optional[torch.Tensor] = None - inputs_embeds: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - token_types: Optional[torch.Tensor] = None - seq_lens: Optional[List[int]] = None - query_lens: Optional[List[int]] = None - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None - attn_metadata: Optional["AttentionMetadata"] = None - multi_modal_kwargs: Optional[BatchedTensorInputs] = None - request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None - finished_requests_ids: Optional[List[str]] = None - virtual_engine: int = 0 - async_callback: Optional[Callable] = None - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None - scheduler_outputs: Optional[SchedulerOutputs] = None - previous_hidden_states: Optional[torch.Tensor] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "inputs_embeds": self.inputs_embeds, - "input_positions": self.input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - "multi_modal_kwargs": self.multi_modal_kwargs, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type[TModelInputForNPU], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> TModelInputForNPU: - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - # Exclude `async_callback` to be able to pickle this object - def __getstate__(self): - state = self.__dict__.copy() - del state["async_callback"] - return state - - # TODO: What happens when we depickle this object? - # How can we update this callback to properly pass it to the engine? - def __setstate__(self, state): - self.__dict__.update(state) - self.__dict__.update({'async_callback': None}) - - -@dataclass(frozen=True) -class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU): - """ - Used by the ModelRunner. - """ - sampling_metadata: Optional["SamplingMetadata"] = None - # Used for speculative decoding. We do not broadcast it because it is only - # used by the driver worker. 
- is_prompt: Optional[bool] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "inputs_embeds": self.inputs_embeds, - "input_positions": self.input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - "multi_modal_kwargs": self.multi_modal_kwargs, - "virtual_engine": self.virtual_engine, - "request_ids_to_seq_ids": self.request_ids_to_seq_ids, - "finished_requests_ids": self.finished_requests_ids, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForNPUWithSamplingMetadata": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]): - """Build ModelInputForNPU from SequenceGroupMetadata.""" - - # Note: ideally we would be using a dataclass(kw_only=True) - # here, so that this can be subclassed easily, - # but kw_only is not supported in python<3.10. - class InterDataForSeqGroup: - """Intermediate data for the current sequence group.""" - - def simple_reinit(self): - self.input_tokens[0].clear() # type: ignore - self.inputs_embeds = None # type: ignore - self.input_positions[0].clear() # type: ignore - self.token_types[0].clear() # type: ignore - self.mrope_input_positions = None # type: ignore - self.seq_lens[0] = 0 # type: ignore - self.orig_seq_lens[0] = 0 # type: ignore - self.query_lens[0] = 0 # type: ignore - self.context_lens[0] = 0 # type: ignore - self.curr_sliding_window_blocks[0] = 0 # type: ignore - self.lora_index_mapping.clear() # type: ignore - self.lora_prompt_mapping.clear() # type: ignore - self.lora_requests.clear() # type: ignore - - def __init__( - self, - *, - # From sequence group metadata. - request_id: str, - seq_ids: List[int], - is_prompt: bool, - block_tables: Optional[Dict[int, List[int]]], - computed_block_nums: List[int], - n_seqs: int = 0, - - # Input tokens and positions. - input_tokens: Optional[List[List[int]]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - input_positions: Optional[List[List[int]]] = None, - token_types: Optional[List[List[int]]] = None, - mrope_input_positions: Optional[List[List[List[int]]]] = None, - - # The sequence length (may be capped to the sliding window). - seq_lens: Optional[List[int]] = None, - # The original sequence length (before applying sliding window). - # This is used to compute slot mapping. - orig_seq_lens: Optional[List[int]] = None, - # The query length. - query_lens: Optional[List[int]] = None, - # The number of tokens that are already computed. - context_lens: Optional[List[int]] = None, - # The current sliding window block. - curr_sliding_window_blocks: Optional[List[int]] = None, - - # LoRA inputs. - lora_index_mapping: Optional[List[List[int]]] = None, - lora_prompt_mapping: Optional[List[List[int]]] = None, - lora_requests: Optional[Set[LoRARequest]] = None, - - # Multi-modal inputs. 
- multi_modal_kwargs: Optional[MultiModalKwargs] = None, - multi_modal_placeholder_maps: Optional[Dict[ - str, MultiModalPlaceholderMap]] = None, - - # Whether the prefix cache is hit (prefill only). - prefix_cache_hit: bool = False, - reinit: bool = False, - reinit_use_defaults: bool = False, - encoder_seq_len: int = 0, - ): - if reinit: - assert len(self.seq_ids) == len(seq_ids) # type: ignore - for i, seq_id in enumerate(seq_ids): - self.seq_ids[i] = seq_id # type: ignore - else: - self.seq_ids = seq_ids - - self.request_id = request_id - self.is_prompt = is_prompt - self.block_tables = block_tables - self.computed_block_nums = computed_block_nums - self.n_seqs = n_seqs - self.encoder_seq_len = encoder_seq_len - - if reinit: - if len(self.seq_ids) == 1 and reinit_use_defaults: - self.simple_reinit() - else: - if input_tokens: - self.input_tokens = input_tokens - else: - for seq_id in range(len(self.seq_ids)): - self.input_tokens[seq_id].clear() - self.inputs_embeds = inputs_embeds - - if input_positions: - self.input_positions = input_positions - else: - for seq_id in range(len(self.seq_ids)): - self.input_positions[seq_id].clear() - - if token_types: - self.token_types = token_types - else: - for seq_id in range(len(self.seq_ids)): - self.token_types[seq_id].clear() - - self.mrope_input_positions = None - - if seq_lens: - self.seq_lens = seq_lens - else: - for seq_id in range(len(self.seq_ids)): - self.seq_lens[seq_id] = 0 - - if orig_seq_lens: - self.orig_seq_lens = orig_seq_lens - else: - for seq_id in range(len(self.seq_ids)): - self.orig_seq_lens[seq_id] = 0 - - if query_lens: - self.query_lens = query_lens - else: - for seq_id in range(len(self.seq_ids)): - self.query_lens[seq_id] = 0 - - if context_lens: - self.context_lens = context_lens - else: - for seq_id in range(len(self.seq_ids)): - self.context_lens[seq_id] = 0 - - if curr_sliding_window_blocks: - self.curr_sliding_window_blocks = \ - curr_sliding_window_blocks - else: - for seq_id in range(len(self.seq_ids)): - self.curr_sliding_window_blocks[seq_id] = 0 - - if lora_index_mapping: - self.lora_index_mapping = lora_index_mapping - else: - self.lora_index_mapping.clear() - if lora_prompt_mapping: - self.lora_prompt_mapping = lora_prompt_mapping - else: - self.lora_prompt_mapping.clear() - if lora_requests: - self.lora_requests = lora_requests - else: - self.lora_requests.clear() - - else: - self.input_tokens = input_tokens or [] - self.inputs_embeds = inputs_embeds - self.input_positions = input_positions or [] - self.token_types = token_types or [] - self.mrope_input_positions = mrope_input_positions or None - self.seq_lens = seq_lens or [] - self.orig_seq_lens = orig_seq_lens or [] - self.query_lens = query_lens or [] - self.context_lens = context_lens or [] - self.curr_sliding_window_blocks = \ - curr_sliding_window_blocks or [] - - self.lora_index_mapping = lora_index_mapping or [] - self.lora_prompt_mapping = lora_prompt_mapping or [] - self.lora_requests = lora_requests or set() - - self.multi_modal_kwargs = multi_modal_kwargs - self.multi_modal_placeholder_maps = multi_modal_placeholder_maps - self.prefix_cache_hit = prefix_cache_hit - - self.n_seqs = len(self.seq_ids) - - if not reinit: - self.__post_init__() - - def __post_init__(self): - self.n_seqs = len(self.seq_ids) - - self.input_tokens = [[] for _ in range(self.n_seqs)] - self.input_positions = [[] for _ in range(self.n_seqs)] - self.token_types = [[] for _ in range(self.n_seqs)] - self.mrope_input_positions = None - self.seq_lens = [0] * self.n_seqs - 
self.orig_seq_lens = [0] * self.n_seqs - self.query_lens = [0] * self.n_seqs - self.context_lens = [0] * self.n_seqs - self.curr_sliding_window_blocks = [0] * self.n_seqs - - self.lora_index_mapping = [] - self.lora_prompt_mapping = [] - - def __repr__(self) -> str: - return (f"InterDataForSeqGroup(" - f"request_id={self.request_id}, " - f"seq_ids={self.seq_ids}, " - f"is_prompt={self.is_prompt}, " - f"block_tables={self.block_tables}, " - f"computed_block_nums={self.computed_block_nums}, " - f"n_seqs={self.n_seqs}, " - f"input_tokens={self.input_tokens}, " - f"inputs_embeds.shape=" - f"{getattr(self.inputs_embeds, 'shape', None)}, " - f"input_positions={self.input_positions}, " - f"token_types={self.token_types}, " - f"mrope_input_positions={self.mrope_input_positions}, " - f"seq_lens={self.seq_lens}, " - f"orig_seq_lens={self.orig_seq_lens}, " - f"query_lens={self.query_lens}, " - f"context_lens={self.context_lens}, " - f"multi_modal_kwargs={self.multi_modal_kwargs}") - - def __init__(self, - runner, - finished_requests_ids: Optional[List[str]] = None): - super().__init__() - # Compute functions for each sequence in a sequence group. - # WARNING: The order of the functions matters! - self.per_seq_compute_fns = [ - self._compute_lens, - self._compute_for_prefix_cache_hit, - self._compute_for_sliding_window, - self._compute_lora_input, - ] - # Compute functions for each sequence group. - # WARNING: The order of the functions matters! - self.per_seq_group_compute_fns = [ - self._compute_multi_modal_input, - ] - - self.runner = runner - self.model_input_cls = self.runner._model_input_cls - self.attn_backend = self.runner.attn_backend - self.scheduler_config = self.runner.scheduler_config - self.sliding_window = self.runner.sliding_window - self.block_size = self.runner.block_size - self.enable_lora = self.runner.lora_config is not None - self.finished_requests_ids = finished_requests_ids - self.decode_only = True - self.is_encoder_decoder = self.runner.model_config.is_encoder_decoder - - # Attention metadata inputs. - self.attn_metadata_builder = self.attn_backend.make_metadata_builder( - weakref.proxy(self)) - - # Engine/Model configurations. - self.chunked_prefill_enabled = ( - self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled) - if self.sliding_window is not None: - self.sliding_window_blocks = ( - self.sliding_window + self.block_size - 1) // self.block_size - self.block_aligned_sliding_window = \ - self.sliding_window_blocks * self.block_size - - def prepare(self, - finished_requests_ids: Optional[List[str]] = None) -> None: - self.finished_requests_ids = finished_requests_ids - - # if the current batch is decode-only. - # will be set to False if there is any non-decode request. - self.decode_only = True - - # Intermediate data (data in CPU before going to NPU) for - # the current sequence group. 
- self.inter_data_list: List[ - ModelInputForNPUBuilder.InterDataForSeqGroup] = [] - - self.attn_metadata_builder.prepare() - - def gen_inter_data_builder(self, num_seqs: int): - return lambda: ModelInputForNPUBuilder.InterDataForSeqGroup( - request_id="", - seq_ids=[0] * num_seqs, - is_prompt=True, - block_tables=None, - computed_block_nums=[]) - - def init_cached_inter_data(self, *args, **kwargs): - assert len(args) == 0 - assert "seq_ids" in kwargs - seq_ids = kwargs["seq_ids"] - num_seqs = len(seq_ids) - - # The inter-data cache is per model_runner - inter_data_cache = self.runner.inter_data_cache - if num_seqs not in inter_data_cache: - inter_data_cache[num_seqs] = PyObjectCache( - self.gen_inter_data_builder(num_seqs)) - - obj = inter_data_cache[num_seqs].get_object() - obj.__init__(*args, **kwargs) - return obj - - def reset_cached_inter_data(self): - for cache in self.runner.inter_data_cache.values(): - cache.reset() - - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): - """Add a sequence group to the builder.""" - seq_ids = seq_group_metadata.seq_data.keys() - n_seqs = len(seq_ids) - is_prompt = seq_group_metadata.is_prompt - - if is_prompt: - assert n_seqs == 1 - self.decode_only = False - - encoder_seq_len = 0 - - if self.is_encoder_decoder: - encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() - - inter_data = self.init_cached_inter_data( - request_id=seq_group_metadata.request_id, - seq_ids=seq_ids, - is_prompt=is_prompt, - block_tables=seq_group_metadata.block_tables, - computed_block_nums=seq_group_metadata.computed_block_nums, - reinit=True, - reinit_use_defaults=True, - encoder_seq_len=encoder_seq_len) - - self.inter_data_list.append(inter_data) - - for seq_idx in range(n_seqs): - for per_seq_fn in self.per_seq_compute_fns: - per_seq_fn(inter_data, seq_idx, seq_group_metadata) - for per_seq_group_fn in self.per_seq_group_compute_fns: - per_seq_group_fn(inter_data, seq_group_metadata) - - def build(self) -> ModelInputForNPU: - """Finalize the builder intermediate data and - create on-device tensors. - """ - # Combine and flatten intermediate data. - input_tokens = list[int]() - inputs_embeds_list = list[torch.Tensor]() - token_types = list[int]() - for inter_data in self.inter_data_list: - for cur_input_tokens in inter_data.input_tokens: - input_tokens.extend(cur_input_tokens) - for cur_token_types in inter_data.token_types: - token_types.extend(cur_token_types) - if inter_data.inputs_embeds is not None: - inputs_embeds_list.append( - inter_data.inputs_embeds.to( - dtype=self.runner.model_config.dtype, - device=self.runner.device)) - - inputs_embeds: Optional[torch.Tensor] - if len(inputs_embeds_list) == 0: - inputs_embeds = None - else: - inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to( - dtype=self.runner.model_config.dtype, - device=self.runner.device) - assert len(inputs_embeds) == len(input_tokens) - - if not input_tokens and inputs_embeds is None: - # This may happen when all prefill requests hit - # prefix caching and there is no decode request. 
- return self.model_input_cls() - - mrope_input_positions: Optional[List[List[int]]] = None - if any(inter_data.mrope_input_positions is not None - for inter_data in self.inter_data_list): - mrope_input_positions = [[] for _ in range(3)] - - for idx in range(3): - for inter_data in self.inter_data_list: - msections = inter_data.mrope_input_positions - if msections is None: - for _seq_input_positions in inter_data.input_positions: - mrope_input_positions[idx].extend( - _seq_input_positions) - else: - for _seq_mrope_input_positions in msections: - mrope_input_positions[idx].extend( - _seq_mrope_input_positions[idx]) - input_positions = None - else: - input_positions = [ - flatten_2d_lists(inter_data.input_positions) - for inter_data in self.inter_data_list - ] - - seq_lens = [] - max_decode_seq_len = 0 - is_prompt = self.inter_data_list[0].is_prompt - for inter_data in self.inter_data_list: - seq_lens.extend(inter_data.seq_lens) - if not inter_data.is_prompt: - max_decode_seq_len = max(max_decode_seq_len, - max(inter_data.seq_lens)) - query_lens = flatten_2d_lists( - [inter_data.query_lens for inter_data in self.inter_data_list]) - # Mapping from request IDs to sequence IDs. Used for Jamba models - # that manages the cache by itself. - request_ids_to_seq_ids = { - data.request_id: data.seq_ids - for data in self.inter_data_list - } - - # Add graph_pad_size here - if self.runner.torchair_graph_enabled: - graph_pad_size = self.runner.scheduler_config.max_num_seqs - len( - seq_lens) - else: - graph_pad_size = -1 - - if input_positions: - input_positions = flatten_2d_lists(input_positions) - if graph_pad_size != -1 and not is_prompt: - input_tokens.extend(itertools.repeat(0, graph_pad_size)) - input_positions.extend( # type: ignore - itertools.repeat(0, graph_pad_size)) - seq_lens.extend(itertools.repeat(1, graph_pad_size)) - query_lens.extend(itertools.repeat(1, graph_pad_size)) - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.runner.device) - token_types_tensor = torch.tensor(token_types, - dtype=torch.long, - device=self.runner.device) \ - if token_types else None - if mrope_input_positions is not None: - input_positions_tensor = torch.tensor(mrope_input_positions, - dtype=torch.long, - device=self.runner.device) - else: - input_positions_tensor = torch.tensor(input_positions, - dtype=torch.long, - device=self.runner.device) - #print(f"after tensor input_tokens_tensor: {input_tokens_tensor}") - #print(f"after tensor input_positions_tensor: {input_positions_tensor}") - #print(f"after list seq_lens: {seq_lens}") - - # Attention metadata. - attn_metadata = self.attn_metadata_builder.build( - seq_lens, query_lens, graph_pad_size) - - # LoRA data. - lora_requests = set() - lora_mapping = None - if self.enable_lora: - lora_requests = set(r for data in self.inter_data_list - for r in data.lora_requests) - lora_index_mapping = flatten_2d_lists([ - flatten_2d_lists(inter_data.lora_index_mapping) - for inter_data in self.inter_data_list - ]) - lora_prompt_mapping = flatten_2d_lists([ - flatten_2d_lists(inter_data.lora_prompt_mapping) - for inter_data in self.inter_data_list - ]) - lora_mapping = LoRAMapping( - **dict(index_mapping=lora_index_mapping, - prompt_mapping=lora_prompt_mapping, - is_prefill=not self.decode_only)) - - # Multi-modal data. 
- multi_modal_kwargs_list = [ - data.multi_modal_kwargs for data in self.inter_data_list - if data.multi_modal_kwargs is not None - ] - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - if self.runner.torchair_graph_enabled: - torch._dynamo.mark_static(input_tokens_tensor) - torch._dynamo.mark_static(input_positions_tensor) - torch._dynamo.mark_static(attn_metadata.block_tables) - torch._dynamo.mark_static(attn_metadata.slot_mapping) - - return self.model_input_cls( - input_tokens=input_tokens_tensor, - inputs_embeds=inputs_embeds, - token_types=token_types_tensor, - input_positions=input_positions_tensor, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_mapping=lora_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_requests_ids=self.finished_requests_ids) - - def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Compute context length, sequence length and tokens - for the given sequence data. - """ - seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]] - token_chunk_size = seq_group_metadata.token_chunk_size - - # Compute context length (the number of tokens that are - # already computed) and sequence length (total number of tokens). - - seq_len = seq_data.get_len() - if inter_data.is_prompt: - context_len = seq_data.get_num_computed_tokens() - seq_len = min(seq_len, context_len + token_chunk_size) - elif self.runner.scheduler_config.is_multi_step or \ - self.is_encoder_decoder: - context_len = seq_len - 1 - else: - context_len = seq_data.get_num_computed_tokens() - - # Compute tokens. - # Fixme: this is for version compatibility; remove this once vllm v0.8.5 is no longer supported. - if not hasattr(seq_data, - "prompt_embeds") or seq_data.prompt_embeds is None: - tokens = seq_data.get_token_ids()[context_len:seq_len] - prompt_embeds = None - else: - tokens = [0] * (seq_len - context_len) - prompt_embeds = seq_data.get_token_embeddings( - )[context_len:seq_len] - - token_types = seq_group_metadata.token_type_ids - - inter_data.seq_lens[seq_idx] = seq_len - inter_data.orig_seq_lens[seq_idx] = seq_len - inter_data.context_lens[seq_idx] = context_len - inter_data.input_tokens[seq_idx].extend(tokens) - inter_data.inputs_embeds = prompt_embeds - inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) - inter_data.token_types[seq_idx].extend( - token_types if token_types else []) - inter_data.query_lens[seq_idx] = seq_len - context_len - - if seq_data.mrope_position_delta is not None: - if inter_data.mrope_input_positions is None: - inter_data.mrope_input_positions = [None] * inter_data.n_seqs - - inter_data.mrope_input_positions[ - seq_idx] = MRotaryEmbedding.get_next_input_positions( - seq_data.mrope_position_delta, - context_len, - seq_len, - ) - - def _compute_for_prefix_cache_hit( - self, inter_data: InterDataForSeqGroup, seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Check if hit prefix cache (i.e., some blocks are already computed). - If hit, update input tokens and positions to only compute the - remaining blocks. - """ - computed_block_nums = inter_data.computed_block_nums - - # Note that prefix caching does not support sliding window. 
- prefix_cache_hit = (computed_block_nums is not None - and len(computed_block_nums) > 0 - and self.sliding_window is None - and inter_data.is_prompt) - inter_data.prefix_cache_hit = prefix_cache_hit - - if not prefix_cache_hit: - return - - assert computed_block_nums is not None - # The cache hit prompt tokens in this sequence. Note that - # this may be larger than the sequence length if chunked - # prefill is enabled. - prefix_cache_len = len(computed_block_nums) * self.block_size - - # The total number of prompt tokens in this sequence. - # When chunked prefill is enabled, this is the token number of - # computed chunks + current chunk. - seq_len = inter_data.seq_lens[seq_idx] - - # When full hit, compute the last block rather than the last token, - # due to the requirements of prefix operator. - if seq_len <= prefix_cache_len: - prefix_cache_len -= self.block_size - - seq_group_metadata.seq_data[inter_data.seq_ids[ - seq_idx]].update_num_cached_tokens(prefix_cache_len) - - # The number of so far computed prompt tokens in this sequence. - context_len = inter_data.context_lens[seq_idx] - - if prefix_cache_len <= context_len: - # We already passed the cache hit region, - # so do normal computation. - pass - elif context_len < prefix_cache_len < seq_len: - # Partial hit. Compute the missing part. - uncomputed_start = prefix_cache_len - context_len - inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ - seq_idx][uncomputed_start:] - inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][uncomputed_start:] - inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ - uncomputed_start:] - context_len = prefix_cache_len - - inter_data.context_lens[seq_idx] = context_len - inter_data.query_lens[ - seq_idx] = inter_data.seq_lens[seq_idx] - context_len - elif seq_len <= prefix_cache_len: - # Full hit. Only compute the last token to avoid - # erroneous behavior. FIXME: Ideally we should directly - # mark all tokens as computed in the scheduler and do not - # schedule this sequence, so this case should not happen. - inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ - seq_idx][-1:] - inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][-1:] - inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ - -1:] - inter_data.query_lens[seq_idx] = 1 - inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 - - def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, - seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """Update seq_len and curr_sliding_window_block for the given - sequence data (only required by decoding) if sliding window is enabled. - """ - curr_sliding_window_block = 0 - sliding_seq_len = inter_data.seq_lens[seq_idx] - if not inter_data.is_prompt and self.sliding_window is not None: - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle sliding window attn. 
- curr_sliding_window_block = self.sliding_window_blocks - # number of elements in last block - suff_len = inter_data.seq_lens[seq_idx] % self.block_size - sliding_seq_len = min(inter_data.seq_lens[seq_idx], - self.block_aligned_sliding_window + suff_len) - if suff_len > 0: - curr_sliding_window_block += 1 - - inter_data.curr_sliding_window_blocks[ - seq_idx] = curr_sliding_window_block - inter_data.seq_lens[seq_idx] = sliding_seq_len - - def _compute_lora_input(self, inter_data: InterDataForSeqGroup, - seq_idx: int, - seq_group_metadata: SequenceGroupMetadata): - """If LoRA is enabled, compute LoRA index and prompt mapping.""" - if not self.enable_lora: - return - lora_id = seq_group_metadata.lora_int_id - if lora_id > 0: - inter_data.lora_requests.add(seq_group_metadata.lora_request) - query_len = inter_data.query_lens[seq_idx] - inter_data.lora_index_mapping.append([lora_id] * query_len) - sampling_params = seq_group_metadata.sampling_params - if sampling_params and sampling_params.prompt_logprobs is not None: - inter_data.lora_prompt_mapping.append([lora_id] * query_len) - elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample: - inter_data.lora_prompt_mapping.append([lora_id]) - else: - inter_data.lora_prompt_mapping.append([]) - - def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, - seq_group_metadata: SequenceGroupMetadata): - """If multi-modal data is given, add it to the input.""" - # NOTE: mm_kwargs only includes the subset of multi-modal items that - # intersect with the current prefill positions. - positions = inter_data.input_positions[0] - mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( - seq_group_metadata, - range(positions[0], positions[0] + len(positions))) - if not mm_kwargs: - return - - inter_data.multi_modal_kwargs = mm_kwargs - inter_data.multi_modal_placeholder_maps = placeholder_maps - - # special processing for mrope position deltas. - if self.runner.model_config.uses_mrope: - image_grid_thw = mm_kwargs.get("image_grid_thw", None) - video_grid_thw = mm_kwargs.get("video_grid_thw", None) - assert image_grid_thw is not None or video_grid_thw is not None, ( - "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 'video_grid_thw'.") - second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) - - hf_config = self.runner.model_config.hf_config - - inter_data.mrope_input_positions = [None] * inter_data.n_seqs - for seq_idx in range(inter_data.n_seqs): - seq_data = seq_group_metadata.seq_data[ - inter_data.seq_ids[seq_idx]] - token_ids = seq_data.get_token_ids() - - mrope_input_positions, mrope_position_delta = \ - MRotaryEmbedding.get_input_positions( - token_ids, - hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=inter_data.context_lens[seq_idx], - seq_len=inter_data.seq_lens[seq_idx], - ) - - seq_data.mrope_position_delta = mrope_position_delta - inter_data.mrope_input_positions[ - seq_idx] = mrope_input_positions - - -class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]): - """ - Helper class for shared methods between NPU model runners. 
- """ - _model_input_cls: Type[TModelInputForNPU] - _builder_cls: Type[ModelInputForNPUBuilder] - - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - return_hidden_states: bool = False, - input_registry: InputRegistry = INPUT_REGISTRY, - mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - - ModelRunnerBase.__init__(self, vllm_config) - model_config = self.model_config - cache_config = self.cache_config - - self.is_driver_worker = is_driver_worker - self.return_hidden_states = return_hidden_states - - self.device = self.device_config.device - self.pin_memory = is_pin_memory_available() - - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.max_batchsize_to_capture = \ - self.vllm_config.compilation_config.max_capture_size - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph - - self.has_inner_state = model_config.has_inner_state - - self.in_profile_run = False - - self.graph_block_tables = np.zeros( - (self.vllm_config.scheduler_config.max_num_seqs, - (model_config.max_model_len + self.block_size - 1) // - self.block_size), - dtype=np.int32) - - # Attention-free but stateful models like Mamba need a placeholder attn - # backend, as the attention metadata is needed to manage internal state. - # However we must bypass attention selection altogether for some models - # used for speculative decoding to avoid a divide-by-zero in - # model_config.get_head_size() - num_attn_heads = self.model_config.get_num_attention_heads( - self.parallel_config) - needs_attn_backend = (num_attn_heads != 0 - or self.model_config.is_attention_free) - - self.attn_backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - ) if needs_attn_backend else None - if self.attn_backend: - self.attn_state = self.attn_backend.get_state_cls()( - weakref.proxy(self)) - else: - self.attn_state = CommonAttentionState(weakref.proxy(self)) - - # Multi-modal data support - self.input_registry = input_registry - self.mm_registry = mm_registry - - # Lazy initialization - self.model: nn.Module # Set after load_model - # Set after load_model. - self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - - set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) - - # Used to cache python objects - self.inter_data_cache: Dict[int, PyObjectCache] = {} - - # Using the PythonizationCache in Pipeline-Parallel clobbers the - # SequenceGroupToSample object. In Pipeline-Parallel, we have - # more than 1 Scheduler, resulting in a potential back-to-back - # prepare_model_inputs() call. This clobbers the cached - # SequenceGroupToSample objects, as we reset the cache during - # every prepare_model_inputs() call. 
- self.sampling_metadata_cache: SamplingMetadataCache = \ - SamplingMetadataCache() \ - if self.parallel_config.pipeline_parallel_size == 1 else None - self.sampler = get_sampler() - - def get_model(self) -> nn.Module: - return self.model - - def load_model(self) -> None: - logger.info("Starting to load model %s...", self.model_config.model) - with DeviceMemoryProfiler() as m: - self.model = get_model(vllm_config=self.vllm_config) - - self.model_memory_usage = m.consumed_memory - logger.info("Loading model weights took %.4f GB", - self.model_memory_usage / float(2**30)) - - if self.lora_config: - assert supports_lora( - self.model - ), f"{self.model.__class__.__name__} does not support LoRA yet." - if supports_multimodal(self.model): - logger.warning("Regarding multimodal models, vLLM currently " - "only supports adding LoRA to language model.") - # It's necessary to distinguish between the max_position_embeddings - # of VLMs and LLMs. - if hasattr(self.model.config, "max_position_embeddings"): - max_pos_embeddings = self.model.config.max_position_embeddings - else: - max_pos_embeddings = ( - self.model.config.text_config.max_position_embeddings) - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.vocab_size, - self.lora_config, - self.device, - self.model.embedding_modules, - self.model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, - ) - self.model = self.lora_manager.create_lora_manager(self.model) - - # adapt torch.compile with npu_backend - if self.torchair_graph_enabled: - import torchair # type: ignore - from torchair import patch_for_hcom # type: ignore - - # capture communication ops into the graph - patch_for_hcom() - # set the npu config; if no config is given, the defaults can be used by setting npu_backend="npu" - config = torchair.CompilerConfig() - config.experimental_config.frozen_parameter = True - config.experimental_config.tiling_schedule_optimize = True - config.experimental_config.enable_view_optimize = \ - get_ascend_config().torchair_graph_config.enable_view_optimize - torch.npu.set_compile_mode(jit_compile=False) - if not self.use_cached_npu_graph: - npu_backend = torchair.get_npu_backend(compiler_config=config) - self.compile_model = torch.compile( - self.model, - dynamic=True, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=npu_backend) - else: - self.compile_model = torchair.inference.cache_compile( - self.model.forward, - dynamic=True, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - config=config, - ge_cache=False) - - def save_sharded_state( - self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None, - ) -> None: - - from vllm.model_executor.model_loader import ShardedStateLoader - ShardedStateLoader.save_model( - self.model, - path, - pattern=pattern, - max_size=max_size, - ) - - def save_tensorized_model( - self, - tensorizer_config: TensorizerConfig, - ) -> None: - - from vllm.model_executor.model_loader import \ - TensorizerLoader # type: ignore # noqa - TensorizerLoader.save_model( - self.model, - tensorizer_config=tensorizer_config, - ) - - def get_max_block_per_batch(self) -> int: - block_size = self.block_size - return (self.max_seq_len_to_capture + block_size - 1) // block_size - - def _prepare_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - finished_requests_ids: Optional[List[str]] = None - ) -> TModelInputForNPU: - """Helper method to prepare the model input based on a given sequence - group. 
Prepares metadata needed for the base model forward pass but not - metadata for possible additional steps, e.g., sampling. - - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - - The result tensors and data structure also batches input in prefill - -> decode order. For example, - - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - """ - builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) - builder.prepare(finished_requests_ids) - for seq_group_metadata in seq_group_metadata_list: - builder.add_seq_group(seq_group_metadata) - - builder.reset_cached_inter_data() - - return builder.build() # type: ignore - - @contextmanager - def set_in_profile_run(self): - self.in_profile_run = True - try: - yield - finally: - self.in_profile_run = False - - @torch.inference_mode() - def profile_run(self) -> None: - with self.set_in_profile_run(): - # Enable top-k sampling to reflect the accurate memory usage. - sampling_params = \ - SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - max_num_batched_tokens = \ - self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs - - # Profile memory usage with max_num_sequences sequences and the - # total number of tokens equal to max_num_batched_tokens. - seqs: List[SequenceGroupMetadata] = [] - # Additional GPU memory may be needed for multi-modal encoding, - # which needs to be accounted for when calculating the GPU blocks - # for the vLLM block manager. - # To exercise the worst scenario for GPU memory consumption, - # the number of seqs (batch_size) is chosen to maximize the number - # of images processed. - - max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( - self.model_config) - if max_mm_tokens > 0: - max_num_seqs_orig = max_num_seqs - max_num_seqs = min(max_num_seqs, - max_num_batched_tokens // max_mm_tokens) - if max_num_seqs < 1: - expr = (f"min({max_num_seqs_orig}, " - f"{max_num_batched_tokens} // {max_mm_tokens})") - logger.warning( - "Computed max_num_seqs (%s) to be less than 1. " - "Setting it to the minimum value of 1.", expr) - max_num_seqs = 1 - - batch_size = 0 - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - batch_size += seq_len - - dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry) - - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: dummy_data.seq_data}, - sampling_params=sampling_params, - block_tables=None, - lora_request=None, - multi_modal_data=dummy_data.multi_modal_data, - multi_modal_placeholders=dummy_data. - multi_modal_placeholders, - ) - seqs.append(seq) - - # Run the model with the dummy inputs. - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None` to force Dynamo to pass - # it by reference, rather than by specializing on the value `None`. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to prevent Dynamo from treating them as - # tensor aliases. 
- kv_caches = [ - torch.tensor([], dtype=torch.float32, device=self.device) - for _ in range(num_layers) - ] - finished_requests_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) - intermediate_tensors = None - if not get_pp_group().is_first_rank: - intermediate_tensors = \ - self.model.make_empty_intermediate_tensors( - batch_size=batch_size, - dtype=self.model_config.dtype, - device=self.device) - - self.execute_model(model_input, kv_caches, intermediate_tensors) - torch.npu.synchronize() - return - - def remove_all_loras(self): - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_adapters() - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_adapter(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_adapter(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.pin_adapter(lora_id) - - def list_loras(self) -> Set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_adapters() - - def remove_all_prompt_adapters(self): - raise RuntimeError("PromptAdapter is not supported on NPU now.") - - def set_active_prompt_adapters( - self, prompt_adapter_requests: Set[PromptAdapterRequest], - prompt_adapter_mapping: PromptAdapterMapping) -> None: - raise RuntimeError("PromptAdapter is not supported on NPU now.") - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - raise RuntimeError("PromptAdapter is not supported on NPU now.") - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise RuntimeError("PromptAdapter is not supported on NPU now.") - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise RuntimeError("PromptAdapter is not supported on NPU now.") - - def list_prompt_adapters(self) -> Set[int]: - raise RuntimeError("PromptAdapter is not supported on NPU now.") - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() - - -class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]): - """ - NPU model runner with sampling step. 
- """ - _model_input_cls: Type[ModelInputForNPUWithSamplingMetadata] = ( - ModelInputForNPUWithSamplingMetadata) - _builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder - - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, Any], - ) -> ModelInputForNPUWithSamplingMetadata: - model_input = \ - ModelInputForNPUWithSamplingMetadata.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - return model_input - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, - ) -> ModelInputForNPUWithSamplingMetadata: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - The result tensors and data structure also batches input in prefill - -> decode order. For example, - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - """ - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - if get_pp_group().is_last_rank: - # Sampling metadata is only required for the final pp group - generators = self.get_generators(finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - model_input.seq_lens, - model_input.query_lens, - self.device, - self.pin_memory, - generators, - self.sampling_metadata_cache, - # TODO (cmq): enable this after supported in vllm - # pad_for_invariant_seq_len=True, - ) - # Get hash value of request id list to perform sampling param cache in sampler. - request_ids = model_input.request_ids_to_seq_ids.keys( # type: ignore - ) # type: ignore - request_ids_hash = hash("".join(request_ids)) - sampling_metadata.request_ids_hash = request_ids_hash # type: ignore - else: - sampling_metadata = None - is_prompt = (seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForNPUWithSamplingMetadata, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - **kwargs, - ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - if num_steps > 1: - raise ValueError("num_steps > 1 is not supported in ModelRunner") - - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - - self.attn_state.begin_forward(model_input) - - assert model_input.attn_metadata is not None - # TODO(zzzzwwjj): Do we need to do it every time? - if self.torchair_graph_enabled: - torch._dynamo.mark_static(model_input.input_tokens) - torch._dynamo.mark_static(model_input.input_positions) - torch._dynamo.mark_static(model_input.attn_metadata.block_tables) - torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping) - for kv in kv_caches: - if isinstance(kv, tuple): - torch._dynamo.mark_static(kv[0]) - torch._dynamo.mark_static(kv[1]) - - # TODO(andoorve): We can remove this once all - # virtual engines share the same kv cache. 
- virtual_engine = model_input.virtual_engine - prefill_meta = model_input.attn_metadata.prefill_metadata - previous_hidden_states = kwargs.get("previous_hidden_states") - if prefill_meta is None and self.torchair_graph_enabled: - model_executable = self.compile_model - # Note: the graph_batch_size value is not the same as on GPU - graph_batch_size = model_input.input_tokens.shape[ # type: ignore - 0] # type: ignore - # Note: previous_hidden_states may be None here, unlike on GPU - if previous_hidden_states is not None: - previous_hidden_states = torch.cat([ - previous_hidden_states, - torch.empty([ - graph_batch_size - previous_hidden_states.shape[0], - *previous_hidden_states.shape[1:] - ], - dtype=previous_hidden_states.dtype, - device=previous_hidden_states.device) - ]) - else: - model_executable = self.model - - # Receive KV cache in distributed KV cache transfer setting - # In disagg prefill setting, it will also recv hidden states and bypass - # model forwarding - # In KV cache database setting, it will change the model input so that - # we can skip prefilling on tokens that successfully received KV caches - # NOTE: The receive operation is blocking - bypass_model_exec = False - if self.need_recv_kv(model_input, kv_caches): - hidden_or_intermediate_states, bypass_model_exec, model_input = \ - get_kv_transfer_group().recv_kv_caches_and_hidden_states( - # model is used to know which layer the current worker - # is working on, so that we can receive KV for only those - # layers. - model_executable, - model_input, - kv_caches=kv_caches - ) - - if get_dp_group().world_size > 1: - bypass_model_exec_tensor = torch.tensor( - 1, dtype=torch.int32) if bypass_model_exec else torch.tensor( - 0, dtype=torch.int32) - torch.distributed.all_reduce(bypass_model_exec_tensor, - op=torch.distributed.ReduceOp.MIN, - group=get_dp_group().cpu_group) - # If any group has not received the necessary hidden states or kv_cache, force every DP group to execute. - if bypass_model_exec_tensor.item() == 0: - bypass_model_exec = False - - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - seqlen_agnostic_kwargs = { - "finished_requests_ids": model_input.finished_requests_ids, - "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, - } if self.has_inner_state else {} - - if self.torchair_graph_enabled: - model_kwargs: Dict[str, Any] = {"inputs_embeds": None} - else: - model_kwargs = {} - if previous_hidden_states is not None: - model_kwargs["previous_hidden_states"] = previous_hidden_states - - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time): - model_forward_start = torch.npu.Event(enable_timing=True) - model_forward_end = torch.npu.Event(enable_timing=True) - model_forward_start.record() - - if not bypass_model_exec: - with set_forward_context(model_input.attn_metadata, - self.vllm_config, virtual_engine): - if model_input.attn_metadata is not None: - model_input.attn_metadata.input_positions = model_input.input_positions - if self.torchair_graph_enabled: - model_kwargs["kv_caches"] = kv_caches - model_kwargs["attn_metadata"] = model_input.attn_metadata - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - inputs_embeds=model_input.inputs_embeds, - positions=model_input.input_positions, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs, - **model_kwargs) - - # Compute the logits in the last pipeline stage. 
- if not get_pp_group().is_last_rank: - if (self.is_driver_worker - and hidden_or_intermediate_states is not None - and isinstance(hidden_or_intermediate_states, - IntermediateTensors) - and self.observability_config is not None and - self.observability_config.collect_model_forward_time): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - hidden_or_intermediate_states.tensors[ - "model_forward_time"] = ( - torch.tensor(model_forward_time + - orig_model_forward_time)) - return hidden_or_intermediate_states - - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) - - # Sending KV cache in distributed KV cache transfer setting - if self.need_send_kv(model_input, kv_caches): - get_kv_transfer_group().send_kv_caches_and_hidden_states( - # model_executable is used to know which layer the current - # worker is working on, so that we can send KV for only those - # layers. - model_executable, - model_input, - kv_caches, - hidden_or_intermediate_states, - ) - - if self.is_driver_worker: - if model_input.async_callback is not None: - model_input.async_callback() - - # Sample the next token. - assert isinstance(self.sampler, Sampler) - orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor - if model_input.inputs_embeds is not None: - self.sampler.include_gpu_probs_tensor = True - - output: SamplerOutput = self.sampler( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - if (self.observability_config is not None - and self.observability_config.collect_model_forward_time - and output is not None): - model_forward_end.synchronize() - model_forward_time = model_forward_start.elapsed_time( - model_forward_end) - orig_model_forward_time = 0.0 - if intermediate_tensors is not None: - orig_model_forward_time = intermediate_tensors.tensors.get( - "model_forward_time", torch.tensor(0.0)).item() - # If there are multiple workers, we are still tracking the - # latency from the start time of the driver worker to the end - # time of the driver worker. The model forward time will then - # end up covering the communication time as well. 
- output.model_forward_time = (orig_model_forward_time + - model_forward_time) - - if model_input.inputs_embeds is not None: - if self.is_driver_worker: - sampled = broadcast_tensor_dict( - {"token_ids": output.sampled_token_ids}) - else: - sampled = broadcast_tensor_dict() - if sampled["token_ids"] is not None: - sampled_token_embeds = self.model.get_input_embeddings( - sampled["token_ids"].squeeze(1)) - if self.is_driver_worker: - self.sampler.include_gpu_probs_tensor = \ - orig_include_gpu_probs - - output.sampled_token_embeds = sampled_token_embeds - - for token_embed, sequence_group_output in zip( - output.sampled_token_embeds, output.outputs): - assert len(sequence_group_output.samples) == 1 - sequence_group_output.samples[ - 0].output_embed = token_embed - - if not self.is_driver_worker: - return [] - - if self.return_hidden_states: - # we only need to pass hidden states of most recent token - assert model_input.sampling_metadata is not None - indices = model_input.sampling_metadata.selected_token_indices - if model_input.is_prompt: - hidden_states = hidden_or_intermediate_states.index_select( - 0, indices) - output.prefill_hidden_states = hidden_or_intermediate_states - elif self.torchair_graph_enabled: - hidden_states = hidden_or_intermediate_states[:len(indices)] - else: - hidden_states = hidden_or_intermediate_states - - output.hidden_states = hidden_states - - return [output] - - def need_recv_kv(self, model_input, kv_caches) -> bool: - """Check if we need to receive kv-cache from the other worker. - We need to receive KV when - 1. current vLLM instance is KV cache consumer/decode vLLM instance - 2. this batch is not a profiling run - 3. this batch is a prefill run - - Args: - model_input: input to the model executable - kv_caches: vLLM's paged memory - """ - - if self.vllm_config.kv_transfer_config is None: - return False - - prefill_meta = model_input.attn_metadata.prefill_metadata - - # check if the current run is profiling - is_profile_run = (kv_caches[0].numel() == 0) - # check if the current run is prefill - is_prefill_run = prefill_meta is not None - - return self.vllm_config.kv_transfer_config.is_kv_consumer and ( - not is_profile_run) and is_prefill_run - - def need_send_kv(self, model_input, kv_caches) -> bool: - """Check if we need to send kv-cache to the other worker. - We need to send KV when - 1. current vLLM instance is KV cache producer/prefill vLLM instance - 2. this batch is not a profiling run - 3. this batch is a prefill run - - Args: - model_input: input to the model executable - kv_caches: vLLM's paged memory - """ - - if self.vllm_config.kv_transfer_config is None: - return False - - prefill_meta = model_input.attn_metadata.prefill_metadata - - # check if the current run is profiling - is_profile_run = (kv_caches[0].numel() == 0) - # check if the current run is prefill - is_prefill_run = prefill_meta is not None - - return self.vllm_config.kv_transfer_config.is_kv_producer and ( - not is_profile_run) and is_prefill_run