import dataclasses
import weakref
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
                    Type, TypeVar, Union)

import torch
from torch import nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
                                                get_sampler)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

from vllm_vacc.vllm.model_executor.models.vars import (
    BLOCK_GROUP_SIZE as env_blk_grp_size)

from ..model_executor.models.vars import update_seqence_length

# from vacc_tools.trace_logger import get_trace_api
# trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
#     get_trace_api("deepseek")
# )

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

TModelInputForVACC = TypeVar('TModelInputForVACC', bound="ModelInputForVACC")

_PAD_SLOT_ID = -1


def check_block_table(block_tables, block_num_per_group):
    """Check that block ids are contiguous within each block group."""
    is_nested = any(isinstance(item, list)
                    for item in block_tables) if block_tables else False
    if is_nested:
        for bs, block_ids in block_tables:
            if len(block_ids) > 1 and \
                (block_ids[-1] - block_ids[-2] != 1
                 and block_ids[-1] % block_num_per_group != 0):
                logger.error(
                    f"block id not contiguous bs:{bs}, block_ids: {block_ids}")
                return False
    else:
        for i in range(1, len(block_tables)):
            # Each adjacent pair must either be contiguous or start a new
            # block group.
            if block_tables[i] - block_tables[i - 1] != 1 and \
                    block_tables[i] % block_num_per_group != 0:
                logger.error(
                    f"block id not contiguous block_ids: {block_tables}")
                return False
    return True


@dataclass(frozen=True)
class ModelInputForVACC(ModelRunnerInputBase):
    """
    Base class containing metadata needed for the base model forward pass
    on VACC.
    """
    input_tokens: Optional[torch.Tensor] = None
    inputs_embeds: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    token_type_ids: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    async_callback: Optional[Callable] = None
    previous_hidden_states: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "inputs_embeds": self.inputs_embeds,
            "input_positions": self.input_positions,
            "token_type_ids": self.token_type_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForVACC],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None
    ) -> TModelInputForVACC:
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)

    # Exclude `async_callback` to be able to pickle this object
    def __getstate__(self):
        state = self.__dict__.copy()
        del state["async_callback"]
        return state

    # TODO: What happens when we depickle this object?
    # How can we update this callback to properly pass it to the engine?
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.__dict__.update({'async_callback': None})


@dataclass(frozen=True)
class ModelInputForVACCWithSamplingMetadata(ModelInputForVACC):
    """
    Used by the ModelRunner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Used for speculative decoding. We do not broadcast it because it is
    # only used by the driver worker.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "inputs_embeds": self.inputs_embeds,
            "input_positions": self.input_positions,
            "token_type_ids": self.token_type_ids,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForVACCWithSamplingMetadata":
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class ModelInputForVACCBuilder(ModelRunnerInputBuilderBase[ModelInputForVACC]):

    class ModelInputData:

        def __init__(self, use_mrope: bool):
            self.use_mrope = use_mrope
            self.input_tokens: List[int] = []
            self.inputs_embeds = None  # type: ignore
            self.input_positions: List[int] = []
            self.token_type_ids: Optional[List[int]] = []
            self.seq_lens: List[int] = []
            self.query_lens: List[int] = []
            self.prefill_block_tables: List[List[int]] = []
            self.decode_block_tables: List[List[int]] = []
            # self.max_decode_seq_len: int = 0
            self.num_prefills: int = 0
            self.num_prefill_tokens: int = 0
            self.num_decode_tokens: int = 0
            self.slot_mapping: List[int] = []
            self.context_lens: List[int] = []
            self.multi_modal_inputs_list: List[MultiModalKwargs] = []
            self.multi_modal_placeholder_maps: Dict[
                str, MultiModalPlaceholderMap] = defaultdict(
                    MultiModalPlaceholderMap)
            self.input_mrope_positions: List[List[int]] = [[]
                                                           for _ in range(3)]

    def __init__(self,
                 runner: "VACCModelRunnerBase",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        super().__init__()
        self.runner = runner
        self.chunked_prefill = (
            runner.scheduler_config.chunked_prefill_enabled
            or runner.cache_config.enable_prefix_caching)
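        # NOTE: prefix caching is handled through the chunked-prefill input
        # path in this builder, so either flag marks the run as chunked.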
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
        self.enable_lora = self.runner.lora_config is not None
        if self.runner.attn_backend is not None:
            # Speculative decoding (e.g. Medusa) does not have an attention
            # backend.
            attn_backend = self.runner.attn_backend
            self.att_metadata_builder = attn_backend.get_builder_cls()(self)
        self.block_num_per_group = (env_blk_grp_size //
                                    runner.cache_config.block_size)

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.input_data = ModelInputForVACCBuilder.ModelInputData(
            self.runner.model_config.uses_mrope)
        self.att_metadata_builder.prepare()

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        self.seq_group_metadata_list.append(seq_group_metadata)

    def set_seq_group_list(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        self.seq_group_metadata_list = seq_group_metadata_list

    def build(self) -> ModelInputForVACC:
        self._build_input_data()

        input_data = self.input_data
        input_tokens = torch.tensor(input_data.input_tokens,
                                    dtype=torch.int,
                                    device=self.runner.device)

        inputs_embeds = None
        if input_data.inputs_embeds is not None:
            inputs_embeds = input_data.inputs_embeds.to(
                dtype=self.runner.model_config.dtype,
                device=self.runner.device)

        if not input_data.input_tokens and inputs_embeds is None:
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        input_positions = torch.tensor(
            input_data.input_positions
            if not any(input_data.input_mrope_positions) else
            input_data.input_mrope_positions,
            dtype=torch.int,
            device=self.runner.device)

        token_type_ids = torch.tensor(input_data.token_type_ids,
                                      dtype=torch.int,
                                      device=self.runner.device) \
            if input_data.token_type_ids else None

        # For multi-modal models
        multi_modal_kwargs = None
        if len(input_data.multi_modal_inputs_list) != 0:
            multi_modal_kwargs = MultiModalKwargs.batch(
                input_data.multi_modal_inputs_list)

        attn_metadata = self.att_metadata_builder.build(
            input_data.seq_lens, input_data.query_lens, -1, -1)

        is_prompt = (self.seq_group_metadata_list[0].is_prompt
                     if self.seq_group_metadata_list else None)

        # LoRA data.
        lora_requests = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = set(seq.lora_request
                                for seq in self.seq_group_metadata_list
                                if seq.lora_request is not None)
            lora_mapping = self._prepare_lora_input(
                self.seq_group_metadata_list, is_prompt)

        return self.model_input_cls(input_tokens=input_tokens,
                                    inputs_embeds=inputs_embeds,
                                    input_positions=input_positions,
                                    token_type_ids=token_type_ids,
                                    seq_lens=input_data.seq_lens,
                                    query_lens=input_data.query_lens,
                                    attn_metadata=attn_metadata,
                                    multi_modal_kwargs=multi_modal_kwargs,
                                    lora_mapping=lora_mapping,
                                    lora_requests=lora_requests)

    def _build_input_data(self):
        for seq_group_metadata in self.seq_group_metadata_list:
            for seq_id, seq_data in seq_group_metadata.seq_data.items():
                if seq_group_metadata.is_prompt:
                    self._compute_prompt_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)
                    if seq_group_metadata.multi_modal_data:
                        self._compute_multi_modal_input(
                            seq_group_metadata, seq_data)
                else:
                    self._compute_decode_input_tokens(self.input_data,
                                                      seq_group_metadata,
                                                      seq_data, seq_id)

    def _compute_decode_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute decode input tokens, positions, block table and slot mapping.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()

        tokens = seq_data.get_last_token_id()
        inputs_embeds = seq_data.get_token_embeddings()
        token_positions = seq_len - 1
        block_number = block_table[token_positions // block_size]
        block_offset = token_positions % block_size
        slot = block_number * block_size + block_offset

        # For paged_attention kernel
        if self.runner.sliding_window:
            start_idx = max(0, seq_len - self.runner.sliding_window)
            start_block = start_idx // block_size
            start_idx = start_block * block_size
            seq_len = seq_len - start_idx
            block_table = block_table[start_block:]
        data.context_lens.append(context_len)

        # For MRotaryEmbedding
        if seq_data.mrope_position_delta is not None:
            next_pos = MRotaryEmbedding.get_next_input_positions(
                seq_data.mrope_position_delta,
                context_len,
                seq_len,
            )
            for idx in range(3):
                data.input_mrope_positions[idx].extend(  # type: ignore
                    next_pos[idx])
        else:
            data.input_positions.append(token_positions)  # type: ignore

        # Update fields
        data.input_tokens.append(tokens)
        if data.inputs_embeds is not None:
            if inputs_embeds is not None:
                data.inputs_embeds = torch.cat(
                    [data.inputs_embeds, inputs_embeds], 0)
        else:
            if inputs_embeds is not None:
                data.inputs_embeds = inputs_embeds
        # data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
        data.num_decode_tokens += 1
        data.slot_mapping.append(slot)
        assert check_block_table(block_table, self.block_num_per_group)
        data.decode_block_tables.append(
            block_table[::self.block_num_per_group])
        data.query_lens.append(1)
        data.seq_lens.append(seq_len)

    def _compute_prompt_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Compute prompt input tokens, positions, block table and slot mapping.
""" token_chunk_size = seq_group_metadata.token_chunk_size block_size = self.runner.block_size # block_table = seq_group_metadata.block_tables[seq_id] seq_len = seq_data.get_len() context_len = seq_data.get_num_computed_tokens() seq_len = min(seq_len, context_len + token_chunk_size) # For prefix caching if seq_group_metadata.computed_block_nums is None: prefix_cache_block_num =0 else: prefix_cache_block_num = len(seq_group_metadata.computed_block_nums) if prefix_cache_block_num > 0: prefix_cache_len = (prefix_cache_block_num * self.runner.block_size) if prefix_cache_len <= context_len: # We already passed the cache hit region, # so do normal computation. pass elif context_len < prefix_cache_len < seq_len: # Partial hit. Compute the missing part. context_len = prefix_cache_len token_chunk_size = seq_len - context_len elif seq_len <= prefix_cache_len: # Full hit. Only compute the last token to avoid # erroneous behavior. FIXME: Ideally we should directly # mark all tokens as computed in the scheduler and do not # schedule this sequence, so this case should not happen. context_len = seq_len - 1 token_chunk_size = 1 # tokens = seq_data.get_token_ids() # tokens = tokens[context_len:seq_len] if seq_data.prompt_embeds is None: tokens = seq_data.get_token_ids()[context_len:seq_len] prompt_embeds = None else: tokens = [0] * (seq_len - context_len) prompt_embeds = seq_data.get_token_embeddings( )[context_len:seq_len] token_positions = range(context_len, seq_len) token_types = seq_group_metadata.token_type_ids # For encoder-only models, the block_table is None, # and there is no need to initialize the slot_mapping. if seq_group_metadata.block_tables is not None: block_table = seq_group_metadata.block_tables[seq_id] if block_table is not None: slot_mapping = [_PAD_SLOT_ID] * len(token_positions) for i, pos in enumerate(token_positions): block_number = block_table[pos // block_size] block_offset = pos % block_size slot = block_number * block_size + block_offset slot_mapping[i] = slot data.slot_mapping.extend(slot_mapping) else: block_table=[] # The MROPE positions are prepared in _compute_multi_modal_input data.input_positions.extend(token_positions) if data.token_type_ids is not None: data.token_type_ids.extend(token_types if token_types else []) # Update fields data.input_tokens.extend(tokens) if data.inputs_embeds is not None: if prompt_embeds is not None: data.inputs_embeds = torch.cat([data.inputs_embeds, prompt_embeds], 0) else: if prompt_embeds is not None: data.inputs_embeds = prompt_embeds data.num_prefills += 1 data.num_prefill_tokens += len(tokens) data.query_lens.append(len(tokens)) assert check_block_table(block_table, self.block_num_per_group) data.prefill_block_tables.append(block_table[::self.block_num_per_group]) data.seq_lens.append(seq_len) def _compute_multi_modal_input(self, seq_group_metadata: SequenceGroupMetadata, seq_data: SequenceData): computed_len = seq_data.get_num_computed_tokens() seq_len = self.input_data.seq_lens[-1] # NOTE: mm_data only includes the subset of multi-modal items that # intersect with the current prefill positions. mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( seq_group_metadata, range(computed_len, seq_len)) if not mm_data: return if self.runner.mm_registry.has_processor(self.runner.model_config): mm_kwargs = mm_data else: mm_kwargs = self.multi_modal_input_mapper( mm_data, seq_group_metadata.mm_processor_kwargs, ) # special processing for mrope position deltas. 
        if self.runner.model_config.uses_mrope:
            assert not self.chunked_prefill, \
                "MROPE on VACC does not support chunked-prefill."
            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            assert image_grid_thw is not None or video_grid_thw is not None, (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw'.")

            hf_config = self.runner.model_config.hf_config
            token_ids = seq_data.get_token_ids()

            mrope_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    token_ids,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    image_token_id=hf_config.image_token_id,
                    video_token_id=hf_config.video_token_id,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    spatial_merge_size=hf_config.vision_config.
                    spatial_merge_size,
                    context_len=computed_len,
                )
            seq_data.mrope_position_delta = mrope_position_delta
            for i in range(3):
                self.input_data.input_mrope_positions[  # type: ignore
                    i].extend(mrope_positions[i])

        self.input_data.multi_modal_inputs_list.append(mm_kwargs)
        for modality, placeholder_map in placeholder_maps.items():
            self.input_data.multi_modal_placeholder_maps[modality].extend(
                placeholder_map)

    def _prepare_lora_input(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            is_prefill: bool) -> LoRAMapping:
        index_mapping = []
        prompt_mapping = []
        for seq in seq_group_metadata_list:
            lora_id = seq.lora_int_id
            query_len = seq.token_chunk_size

            index_mapping += [lora_id] * query_len
            prompt_mapping += [lora_id] * (
                query_len if seq.sampling_params
                and seq.sampling_params.prompt_logprobs is not None else 1)

        return LoRAMapping(index_mapping=tuple(index_mapping),
                           prompt_mapping=tuple(prompt_mapping),
                           is_prefill=is_prefill)


class VACCModelRunnerBase(ModelRunnerBase[TModelInputForVACC]):
    """
    Helper class for shared methods between VACC model runners.
    """
    _model_input_cls: Type[TModelInputForVACC]
    _builder_cls: Type[ModelInputForVACCBuilder]

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        *args,
        **kwargs,
    ):
        ModelRunnerBase.__init__(self, vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config

        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        needs_attn_backend = (num_attn_heads != 0
                              or self.model_config.is_attention_free)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        ) if needs_attn_backend else None

        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY  # used by _dummy_run profiling
        self.mm_registry = MULTIMODAL_REGISTRY
        self.multi_modal_input_mapper = self.mm_registry \
            .create_input_mapper(self.model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization: set after load_model().
        self.model: nn.Module
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

        self.sampler = get_sampler()
        if hasattr(self, "_builder_cls"):
            # multi-step model runner does not have `_builder_cls`
            self.builder = self._builder_cls(weakref.proxy(self))

    def load_model(self) -> None:
        self.model = get_model(vllm_config=self.vllm_config)
        # register_module_trace(self.model)
        # register_module_trace(self.model.sampler)

        if self.lora_config:
            assert supports_lora(
                self.model
            ), f"{self.model.__class__.__name__} does not support LoRA yet."

            if supports_multimodal(self.model):
                logger.warning("Regarding multimodal models, vLLM currently "
                               "only supports adding LoRA to language model.")

            # It's necessary to distinguish between the
            # max_position_embeddings of VLMs and LLMs.
            if hasattr(self.model.config, "max_position_embeddings"):
                max_pos_embeddings = self.model.config.max_position_embeddings
            else:
                max_pos_embeddings = (
                    self.model.config.text_config.max_position_embeddings)

            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=max_pos_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

    def get_model(self) -> nn.Module:
        return self.model

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForVACC:
        """Helper method to prepare the model input based on a given
        sequence group.

        Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.
        """
        self.builder.prepare(finished_requests_ids)
        self.builder.set_seq_group_list(seq_group_metadata_list)

        return self.builder.build()  # type: ignore

    # sampler property will be used by spec_decode_worker
    # @property
    # def sampler(self):
    #     return self.model.sampler

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    @contextmanager
    def set_in_profile_run(self):
        self.in_profile_run = True
        try:
            yield
        finally:
            self.in_profile_run = False

    @torch.inference_mode()
    def profile_run(self) -> None:
        max_num_batched_tokens = \
            self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        self._dummy_run(max_num_batched_tokens, max_num_seqs)

    def _dummy_run(self,
                   max_num_batched_tokens: int,
                   max_num_seqs: int = 1) -> None:
        with self.set_in_profile_run():
            # Enable top-k sampling to reflect the accurate memory usage.
            sampling_params = \
                SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)

            # This represents the maximum number of different requests that
            # will have unique loras, and therefore the max amount of memory
            # consumption. Create dummy lora request copies from the lora
            # request passed in, which contains a lora from the lora warmup
            # path.
            dummy_lora_requests: List[LoRARequest] = []
            dummy_lora_requests_per_seq: List[LoRARequest] = []
            if self.lora_config:
                assert self.lora_manager is not None
                with self.lora_manager.dummy_lora_cache():
                    for idx in range(self.lora_config.max_loras):
                        lora_id = idx + 1
                        dummy_lora_request = LoRARequest(
                            lora_name=f"warmup_{lora_id}",
                            lora_int_id=lora_id,
                            lora_path="/not/a/real/path",
                        )
                        self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                         rank=8)
                        dummy_lora_requests.append(dummy_lora_request)
                    dummy_lora_requests_per_seq = [
                        dummy_lora_requests[idx % len(dummy_lora_requests)]
                        for idx in range(max_num_seqs)
                    ]

            # Profile memory usage with max_num_sequences sequences and the
            # total number of tokens equal to max_num_batched_tokens.
            seqs: List[SequenceGroupMetadata] = []
            # Additional GPU memory may be needed for multi-modal encoding,
            # which needs to be accounted for when calculating the GPU blocks
            # for the vLLM block manager.
            # To exercise the worst scenario for GPU memory consumption,
            # the number of seqs (batch_size) is chosen to maximize the number
            # of images processed.
            max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
                self.model_config)
            if max_mm_tokens > 0:
                max_num_seqs_orig = max_num_seqs
                max_num_seqs = min(max_num_seqs,
                                   max_num_batched_tokens // max_mm_tokens)
                if max_num_seqs < 1:
                    expr = (f"min({max_num_seqs_orig}, "
                            f"{max_num_batched_tokens} // {max_mm_tokens})")
                    logger.warning(
                        "Computed max_num_seqs (%s) to be less than 1. "
                        "Setting it to the minimum value of 1.", expr)
                    max_num_seqs = 1

            block_tables_lst: List[List[int]] = []
            max_num_blocks_per_seq = (self.model_config.max_model_len //
                                      self.block_size)
            # for _ in range(max_num_seqs):
            #     block_table = [0 for _ in range(max_num_blocks_per_seq)]
            #     block_tables_lst.append(block_table)
            # block_tables = torch.tensor(block_tables_lst,
            #                             dtype=torch.int32,
            #                             device=self.device)

            batch_size = 0
            for group_id in range(max_num_seqs):
                seq_len = (max_num_batched_tokens // max_num_seqs +
                           (group_id < max_num_batched_tokens % max_num_seqs))
                batch_size += seq_len

                dummy_data = self.input_registry \
                    .dummy_data_for_profiling(self.model_config, seq_len,
                                              self.mm_registry)

                seq = SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: dummy_data.seq_data},
                    sampling_params=sampling_params,
                    block_tables=None,
                    lora_request=dummy_lora_requests_per_seq[group_id]
                    if dummy_lora_requests_per_seq else None,
                    multi_modal_data=dummy_data.multi_modal_data,
                    multi_modal_placeholders=dummy_data.
                    multi_modal_placeholders,
                )
                seqs.append(seq)
                # break

            # Run the model with the dummy inputs.
            num_layers = self.model_config.get_num_layers(
                self.parallel_config)
            # Use an empty tensor instead of `None` to force Dynamo to pass
            # it by reference, rather than specializing on the value `None`.
            # The `dtype` argument does not matter, and we use `float32` as
            # a placeholder (it has wide hardware support).
            # It is important to create tensors inside the loop, rather than
            # multiplying the list, to avoid Dynamo treating them as
            # tensor aliasing.
            kv_caches = [
                torch.tensor([], dtype=torch.float32, device=self.device)
                for _ in range(num_layers)
            ]
            finished_requests_ids = [seq.request_id for seq in seqs]
            model_input = self.prepare_model_input(
                seqs, finished_requests_ids=finished_requests_ids)
            intermediate_tensors = None
            if not get_pp_group().is_first_rank:
                intermediate_tensors = \
                    self.model.make_empty_intermediate_tensors(
                        batch_size=batch_size,
                        dtype=self.model_config.dtype,
                        device=self.device)

            # Disable KV scale calculation for dummy data during profile run
            if model_input.attn_metadata is not None:
                model_input.attn_metadata.enable_kv_scales_calculation = False

            self.execute_model(model_input, kv_caches, intermediate_tensors)
            torch.vacc.synchronize()

            if self.lora_config:
                # Remove dummy loras.
                assert self.lora_manager is not None
                self.remove_all_loras()
            return

    def remove_all_loras(self):
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_adapters()

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_adapter(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_adapter(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_adapter(lora_id)

    def list_loras(self) -> Set[int]:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_adapters()


class VACCModelRunner(
        VACCModelRunnerBase[ModelInputForVACCWithSamplingMetadata]):

    _model_input_cls: Type[ModelInputForVACCWithSamplingMetadata] = (
        ModelInputForVACCWithSamplingMetadata)
    _builder_cls: Type[ModelInputForVACCBuilder] = ModelInputForVACCBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForVACCWithSamplingMetadata:
        return ModelInputForVACCWithSamplingMetadata.from_broadcasted_tensor_dict(  # noqa: E501
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForVACCWithSamplingMetadata:
        """Prepare the model input based on a given sequence group,
        including metadata for the sampling step.
""" model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) # Sampling metadata is only required for the final pp group generators = self.get_generators(finished_requests_ids) sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, model_input.seq_lens, model_input.query_lens, self.device, pin_memory=self.pin_memory, generators=generators) is_prompt = (seq_group_metadata_list[0].is_prompt if seq_group_metadata_list else None) return dataclasses.replace(model_input, sampling_metadata=sampling_metadata, virtual_engine=virtual_engine, is_prompt=is_prompt) @torch.no_grad() def execute_model( self, model_input: ModelInputForVACCWithSamplingMetadata, kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, previous_hidden_states: Optional[torch.Tensor] = None, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( "CPU worker does not support multi-step execution.") if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) model_executable = self.model if model_input.attn_metadata.prefill_metadata is not None: try: torch.vacc.empty_cache() except Exception as e: logger.warn("vacc empty cache skiping...") else: update_seqence_length(model_input.attn_metadata.decode_metadata.seq_lens) multimodal_kwargs = {} if model_input.multi_modal_kwargs is not None: multimodal_kwargs = MultiModalKwargs.as_kwargs( model_input.multi_modal_kwargs, device=self.device) execute_model_kwargs = {} if previous_hidden_states is not None: execute_model_kwargs.update( {"previous_hidden_states": previous_hidden_states}) with set_forward_context(model_input.attn_metadata, self.vllm_config, model_input.virtual_engine): hidden_states = model_executable( input_ids=model_input.input_tokens, inputs_embeds=model_input.inputs_embeds, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, **execute_model_kwargs, **multimodal_kwargs, ) if not get_pp_group().is_last_rank: return hidden_states # Compute the logits. logits = self.model.compute_logits(hidden_states, model_input.sampling_metadata) if self.is_driver_worker: if model_input.async_callback is not None: model_input.async_callback() # Sample the next token. assert isinstance(self.sampler, Sampler) orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor if model_input.inputs_embeds is not None: self.sampler.include_gpu_probs_tensor = True output = self.sampler( logits=logits, sampling_metadata=model_input.sampling_metadata, ) if model_input.inputs_embeds is not None: if self.is_driver_worker: sampled = broadcast_tensor_dict( {"token_ids": output.sampled_token_ids}) # sampled = broadcast_tensor_dict( # {"token_ids": output.outputs[0].samples[0].output_token}) else: sampled = broadcast_tensor_dict() if sampled["token_ids"] is not None: sampled_token_embeds = self.model.get_input_embeddings( sampled["token_ids"]) if self.is_driver_worker: self.sampler.include_gpu_probs_tensor = \ orig_include_gpu_probs output.sampled_token_embeds = sampled_token_embeds for token_embed, sequence_group_output in zip( output.sampled_token_embeds, output.outputs): assert len(sequence_group_output.samples) == 1 sequence_group_output.samples[ 0].output_embed = token_embed # model_input.inputs_embeds = sampled_token_embeds # Only perform sampling in the driver worker. 
        if not self.is_driver_worker:
            return []

        # if model_input.async_callback is not None:
        #     model_input.async_callback()

        # Sample the next token.
        # output = self.sampler(
        #     logits=logits,
        #     sampling_metadata=model_input.sampling_metadata,
        # )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            hidden_or_intermediate_states = hidden_states
            if model_input.is_prompt:
                assert model_input.sampling_metadata is not None
                indices = model_input.sampling_metadata.selected_token_indices
                # align with gpu model-runner
                hidden_or_intermediate_states = hidden_states.index_select(
                    0, indices)
                output.prefill_hidden_states = hidden_states
            output.hidden_states = hidden_or_intermediate_states

        return [output]

    def generate_proposals(self, *args, **kwargs):
        return self.model.generate_proposals(*args, **kwargs)
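

# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the runner): how the slot mapping and the
# grouped block tables built above fit together. The block size, group size
# and block ids below are made-up example values, not configuration read from
# vLLM or from the VACC cache config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_block_size = 16   # tokens per KV-cache block (assumed)
    example_group_size = 64   # tokens per contiguous block group (assumed)
    block_num_per_group = example_group_size // example_block_size

    # Blocks are contiguous within each group of `block_num_per_group`
    # blocks, which is what check_block_table() asserts during input building.
    example_block_table = [8, 9, 10, 11, 20, 21, 22, 23]
    assert check_block_table(example_block_table, block_num_per_group)

    # Slot of the last token of a 100-token sequence, mirroring the decode
    # path in ModelInputForVACCBuilder._compute_decode_input_tokens().
    token_position = 100 - 1
    block_number = example_block_table[token_position // example_block_size]
    block_offset = token_position % example_block_size
    print("slot =", block_number * example_block_size + block_offset)

    # The builder keeps only the first block of each group, mirroring
    # `block_table[::self.block_num_per_group]`.
    print("grouped block table =",
          example_block_table[::block_num_per_group])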