import gc
import inspect
import itertools
import time
import weakref
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Dict, List, Optional, Set, Tuple, Type, Union)

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                             MultiModalRegistry)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
    LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (GiB_bytes, PyObjectCache, async_tensor_h2d,
                        flatten_2d_lists, is_pin_memory_available)
from vllm.worker.model_runner_base import (ModelRunnerBase,
                                           dump_input_when_exception)
from vllm.worker.model_runner import (
    TModelInputForGPU, ModelInputForGPU,
    ModelInputForGPUWithSamplingMetadata,
    ModelInputForGPUBuilder, GPUModelRunnerBase,
    ModelRunner, CUDAGraphRunner, ModelRunnerBase,
    LORA_WARMUP_RANK, _NUM_WARMUP_ITERS,
    _get_max_graph_batch_size, _BATCH_SIZES_TO_CAPTURE
)

logger = init_logger(__name__)


@dataclass
class MLUGraphCaptureContext:
    """Carries the stream on which MLU graphs are captured."""
    # Stream that `mlu_graph_capture` makes current while capturing.
    stream: torch.mlu.Stream


@contextmanager
def mlu_graph_capture(
        graph_capture_context: Optional[MLUGraphCaptureContext] = None):
    """Run the enclosed code on a dedicated MLU capture stream.

    Args:
        graph_capture_context: Existing context whose stream should be
            reused. When ``None`` a new stream and context are created.

    Yields:
        The ``MLUGraphCaptureContext`` whose stream is current inside the
        ``with`` body.
    """
    if graph_capture_context is None:
        stream = torch.mlu.Stream()
        graph_capture_context = MLUGraphCaptureContext(stream)
    else:
        stream = graph_capture_context.stream

    # ensure all initialization operations complete before attempting to
    # capture the graph on another stream
    curr_stream = torch.mlu.current_stream()
    if curr_stream != stream:
        stream.wait_stream(curr_stream)

    with torch.mlu.stream(stream):
        yield graph_capture_context


class ModelInputForMLUBuilder(ModelInputForGPUBuilder):
    """Build ModelInputForGPU from SequenceGroupMetadata (MLU variant)."""

    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and create on-device
        tensors.

        Returns:
            An instance of ``self.model_input_cls``. It is constructed with
            no arguments when there are no input tokens at all.
        """
        # Combine and flatten intermediate data.
        input_tokens = []
        for inter_data in self.inter_data_list:
            for cur_input_tokens in inter_data.input_tokens:
                input_tokens.extend(cur_input_tokens)

        if not input_tokens:
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        mrope_input_positions: Optional[List[List[int]]] = None
        if any(inter_data.mrope_input_positions is not None
               for inter_data in self.inter_data_list):
            # One flattened position list per mrope section (3 sections).
            mrope_input_positions = [[] for _ in range(3)]
            for idx in range(3):
                for inter_data in self.inter_data_list:
                    msections = inter_data.mrope_input_positions
                    if msections is None:
                        # No mrope positions for this entry: reuse its plain
                        # positions for every section.
                        for _seq_input_positions in inter_data.input_positions:
                            mrope_input_positions[idx].extend(
                                _seq_input_positions)
                    else:
                        for _seq_mrope_input_positions in msections:
                            mrope_input_positions[idx].extend(
                                _seq_mrope_input_positions[idx])
            input_positions = None
        else:
            input_positions = []
            for inter_data in self.inter_data_list:
                for cur_input_positions in inter_data.input_positions:
                    input_positions.extend(cur_input_positions)

        seq_lens = []
        query_lens = []
        max_decode_seq_len = 0
        max_encoder_seq_len = 0
        for inter_data in self.inter_data_list:
            seq_lens.extend(inter_data.seq_lens)
            query_lens.extend(inter_data.query_lens)
            if not inter_data.is_prompt:
                max_decode_seq_len = max(max_decode_seq_len,
                                         max(inter_data.seq_lens))
                if self.runner.model_config.is_encoder_decoder:
                    max_encoder_seq_len = max(max_encoder_seq_len,
                                              inter_data.encoder_seq_len)

        # Mapping from request IDs to sequence IDs. Used for Jamba models
        # that manages the cache by itself.
        request_ids_to_seq_ids = {
            data.request_id: data.seq_ids
            for data in self.inter_data_list
        }

        cuda_graph_pad_size = self._get_cuda_graph_pad_size(
            num_seqs=len(seq_lens),
            max_decode_seq_len=max_decode_seq_len,
            max_encoder_seq_len=max_encoder_seq_len)

        batch_size = len(input_tokens)
        if cuda_graph_pad_size != -1:
            # If cuda graph can be used, pad tensors accordingly.
            # See `capture_model` API for more details.
            # vLLM uses cuda graph only for decoding requests.
            batch_size += cuda_graph_pad_size

        # `cuda_graph_pad_size` is -1 when no graph is used (see the check
        # above). Normalise it to a non-negative padding count so the padding
        # code below does not depend on -1 being truthy or on
        # `itertools.repeat` treating a negative count as zero.
        num_pad = max(cuda_graph_pad_size, 0)

        # Tokens and positions.
        if num_pad:
            input_tokens.extend(itertools.repeat(0, num_pad))
        assert self.runner.device is not None
        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                               self.runner.device,
                                               self.runner.pin_memory)
        if mrope_input_positions is not None:
            for idx in range(3):
                mrope_input_positions[idx].extend(
                    itertools.repeat(0, num_pad))
            input_positions_tensor = async_tensor_h2d(mrope_input_positions,
                                                      torch.int32,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
        else:
            input_positions.extend(itertools.repeat(0, num_pad))
            input_positions_tensor = async_tensor_h2d(input_positions,
                                                      torch.int32,
                                                      self.runner.device,
                                                      self.runner.pin_memory)

        # Sequence and query lengths.
        if num_pad:
            seq_lens.extend(itertools.repeat(1, num_pad))

        # Attention metadata.
        # The builder receives the raw pad size, including the -1 sentinel.
        attn_metadata = self.attn_metadata_builder.build(
            seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # LoRA data.
        lora_requests = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = set(r for data in self.inter_data_list
                                for r in data.lora_requests)
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_index_mapping)
                for inter_data in self.inter_data_list
            ])
            if num_pad:
                lora_index_mapping.extend(itertools.repeat(0, num_pad))
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_prompt_mapping)
                for inter_data in self.inter_data_list
            ])

            lora_mapping = LoRAMapping(index_mapping=lora_index_mapping,
                                       prompt_mapping=lora_prompt_mapping,
                                       is_prefill=not self.decode_only)

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = set(
                data.prompt_adapter_request for data in self.inter_data_list
                if data.prompt_adapter_request is not None)
            prompt_adapter_index_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_index_mapping
                for inter_data in self.inter_data_list
            ])
            if num_pad:
                prompt_adapter_index_mapping.extend(
                    itertools.repeat(0, num_pad))
            prompt_adapter_prompt_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_prompt_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data.
        multi_modal_kwargs_list = [
            data.multi_modal_kwargs for data in self.inter_data_list
            if data.multi_modal_kwargs is not None
        ]
        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)


class MLUModelRunnerBase(GPUModelRunnerBase[TModelInputForGPU]):
    """
    Helper class for shared methods between MLU model runners.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Initialise runner state; the model itself is loaded later.

        Args:
            vllm_config: Engine configuration passed to
                ``ModelRunnerBase.__init__``.
            kv_cache_dtype: KV cache dtype forwarded to the attention
                backend selection.
            is_driver_worker: Stored on the instance; consulted by
                ``execute_model`` in subclasses.
            return_hidden_states: Stored on the instance; consulted by
                ``execute_model`` in subclasses.
            input_registry: Registry used to build dummy profiling data.
            mm_registry: Multi-modal registry used for input mapping and
                profiling limits.
        """
        # NOTE: deliberately calls ModelRunnerBase.__init__ directly instead
        # of super().__init__(), so GPUModelRunnerBase.__init__ is not run.
        ModelRunnerBase.__init__(self, vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config

        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
        self.max_batchsize_to_capture = _get_max_graph_batch_size(
            self.scheduler_config.max_num_seqs)

        # One {batch_size: runner} dict per pipeline-parallel virtual engine.
        self.graph_runners: List[Dict[int, MLUGraphRunner]] = [
            {} for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.

        self.has_inner_state = model_config.has_inner_state

        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max seq len to capture / block size).
        self.graph_block_tables = np.zeros(
            (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
            dtype=np.int32)

        # Attention-free but stateful models like Mamba need a placeholder attn
        # backend, as the attention metadata is needed to manage internal state.
        # However we must bypass attention selection altogether for some models
        # used for speculative decoding to avoid a divide-by-zero in
        # model_config.get_head_size()
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        needs_attn_backend = (num_attn_heads != 0
                              or self.model_config.is_attention_free)

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        ) if needs_attn_backend else None
        if self.attn_backend:
            self.attn_state = self.attn_backend.get_state_cls()(
                weakref.proxy(self))
        else:
            self.attn_state = CommonAttentionState(weakref.proxy(self))

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry \
            .create_input_mapper(model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        self.prompt_adapter_manager: Optional[
            LRUCacheWorkerPromptAdapterManager] = None

        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        # Used to cache python objects
        self.inter_data_cache: Dict[int, PyObjectCache] = {}

        # Using the PythonizationCache in Pipeline-Parallel clobbers the
        # SequenceGroupToSample object. In Pipeline-Parallel, we have
        # more than 1 Scheduler, resulting in a potential back-to-back
        # prepare_model_inputs() call. This clobbers the cached
        # SequenceGroupToSample objects, as we reset the cache during
        # every prepare_model_inputs() call.
        self.sampling_metadata_cache: Optional[SamplingMetadataCache] = \
            SamplingMetadataCache() \
            if self.parallel_config.pipeline_parallel_size == 1 else None

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Execute the model once on dummy inputs for memory profiling.

        Builds ``max_num_seqs`` dummy prompt sequence groups whose lengths
        sum to ``max_num_batched_tokens``, runs ``execute_model`` on them with
        empty KV caches, and synchronizes the MLU.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        # This represents the maximum number of different requests
        # that will have unique loras, and therefore the max amount of memory
        # consumption. Create dummy lora request copies from the lora request
        # passed in, which contains a lora from the lora warmup path.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                # Assign the dummy LoRAs to sequences round-robin.
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.

        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_data.multi_modal_data,
                multi_modal_placeholders=dummy_data.multi_modal_placeholders,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # use an empty tensor instead of `None` to force Dynamo to pass
        # it by reference, rather by specializing on the value `None`.
        # the `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        # it is important to create tensors inside the loop, rather than
        # multiplying the list, to avoid Dynamo from treating them as
        # tensor aliasing.
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        graph_batch_size = self.max_batchsize_to_capture
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]
        if self.model_config.enforce_eager:
            batch_size_capture_list = []
        with set_compile_context(batch_size_capture_list):
            self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.mlu.synchronize()

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        Args:
            kv_caches: Indexed by virtual engine; ``kv_caches[ve]`` is handed
                to the graph runner for that engine.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for MLU graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("MLU graphs can take additional 1~3 GiB memory per MLU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()
        start_free_gpu_memory = torch.mlu.mem_get_info()[0]

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = self.max_batchsize_to_capture
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).mlu()
        input_positions = torch.zeros(max_batch_size, dtype=torch.int32).mlu()
        if self.model_config.uses_mrope:
            input_positions = torch.tile(input_positions, (3, 1))
        # Prepare dummy previous_hidden_states only if needed by the model.
        # This is used by draft models such as EAGLE.
        previous_hidden_states = None
        if "previous_hidden_states" in inspect.signature(
                self.model.forward).parameters:
            previous_hidden_states = torch.empty(
                [max_batch_size,
                 self.model_config.get_hidden_size()],
                dtype=self.model_config.dtype,
                device=self.device)

        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(
                batch_size=max_batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        graph_batch_size = self.max_batchsize_to_capture
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        with self.attn_state.graph_capture(
                max_batch_size), mlu_graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for virtual_engine in range(
                    self.parallel_config.pipeline_parallel_size):
                for batch_size in reversed(batch_size_capture_list):
                    attn_metadata = (
                        self.attn_state.graph_capture_get_metadata_for_batch(
                            batch_size,
                            is_encoder_decoder_model=self.model_config.
                            is_encoder_decoder))

                    if self.lora_config:
                        lora_mapping = LoRAMapping(
                            index_mapping=[0] * batch_size,
                            prompt_mapping=[0] * batch_size,
                            is_prefill=False)
                        self.set_active_loras(set(), lora_mapping)

                    if self.prompt_adapter_config:
                        prompt_adapter_mapping = PromptAdapterMapping(
                            [-1] * batch_size,
                            [-1] * batch_size,
                        )
                        self.set_active_prompt_adapters(
                            set(), prompt_adapter_mapping)
                    # NOTE(review): `self.attn_backend` is None when
                    # `__init__` decided no backend is needed; this call
                    # would then raise AttributeError. Confirm that graph
                    # capture is never requested for such models.
                    graph_runner = MLUGraphRunner(
                        self.model, self.attn_backend.get_name(),
                        self.attn_state.graph_clone(batch_size),
                        self.model_config.is_encoder_decoder)

                    # Slices of the shared dummy buffers for this batch size.
                    capture_inputs = {
                        "input_ids":
                        input_tokens[:batch_size],
                        "positions":
                        input_positions[..., :batch_size],
                        "intermediate_inputs":
                        intermediate_inputs[:batch_size]
                        if intermediate_inputs is not None else None,
                        "kv_caches":
                        kv_caches[virtual_engine],
                        "attn_metadata":
                        attn_metadata,
                        "memory_pool":
                        self.graph_memory_pool,
                        "stream":
                        graph_capture_context.stream
                    }
                    if previous_hidden_states is not None:
                        capture_inputs[
                            "previous_hidden_states"] = previous_hidden_states[:
                                                                               batch_size]

                    if self.has_inner_state:
                        # Only used by Mamba-based models CUDA graph atm (Jamba)
                        capture_inputs.update({
                            "seqlen_agnostic_capture_inputs":
                            self.model.get_seqlen_agnostic_capture_inputs(
                                batch_size)
                        })
                    if self.model_config.is_encoder_decoder:
                        # add the additional inputs to capture for
                        # encoder-decoder models.
                        self._update_inputs_to_capture_for_enc_dec_model(
                            capture_inputs)

                    with set_forward_context(attn_metadata):
                        graph_runner.capture(**capture_inputs)
                    # Later captures are given the pool of the graph that was
                    # just captured.
                    self.graph_memory_pool = graph_runner.graph.pool()
                    self.graph_runners[virtual_engine][batch_size] = (
                        graph_runner)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.mlu.mem_get_info()[0]
        elapsed_time = end_time - start_time
        mlu_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, mlu_graph_size / GiB_bytes)


class MLUModelRunner(MLUModelRunnerBase, ModelRunner):
    """
    MLU model runner with sampling step.
    """
    _builder_cls: Type[ModelInputForMLUBuilder] = ModelInputForMLUBuilder

    @torch.inference_mode()
    @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Run one forward pass and, on the last pipeline rank, sample.

        Args:
            model_input: Prepared inputs including attention and sampling
                metadata.
            kv_caches: KV cache tensors passed through to the model.
            intermediate_tensors: Tensors received from the previous pipeline
                stage, if any.
            num_steps: Must be 1.

        Returns:
            The model's hidden/intermediate states when this is not the last
            pipeline rank; an empty list on a non-driver worker of the last
            rank; otherwise a single-element list with the sampler output.

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        self.attn_state.begin_forward(model_input)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}

        # Evaluate once: every use of the timing events below is guarded by
        # this flag, so they are always created before being read.
        collect_forward_time = (
            self.observability_config is not None
            and self.observability_config.collect_model_forward_time)
        if collect_forward_time:
            model_forward_start = torch.mlu.Event(enable_timing=True)
            model_forward_end = torch.mlu.Event(enable_timing=True)
            model_forward_start.record()

        with set_forward_context(model_input.attn_metadata):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
                **seqlen_agnostic_kwargs)

        if collect_forward_time:
            model_forward_end.record()

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            if (collect_forward_time and self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                # Pass the accumulated forward time on to the next stage.
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if collect_forward_time and output is not None:
            model_forward_end.synchronize()
            model_forward_time = model_forward_start.elapsed_time(
                model_forward_end)
            orig_model_forward_time = 0.0
            if intermediate_tensors is not None:
                orig_model_forward_time = intermediate_tensors.tensors.get(
                    "model_forward_time", torch.tensor(0.0)).item()
            # If there are multiple workers, we are still tracking the latency
            # from the start time of the driver worker to the end time of the
            # driver worker. The model forward time will then end up covering
            # the communication time as well.
            output.model_forward_time = (orig_model_forward_time +
                                         model_forward_time)

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
                output.prefill_hidden_states = hidden_or_intermediate_states
            elif decode_meta.use_cuda_graph:
                # Keep only the first len(indices) rows of the graph output.
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]


class MLUGraphRunner(CUDAGraphRunner):
    """Captures the model into an MLU graph and replays it on later calls."""

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.mlu.Stream,
        **kwargs,
    ):
        """Warm up the model, capture one forward pass, and keep its buffers.

        Args:
            input_ids: Input token buffer; kept as a replay input buffer.
            positions: Position buffer; kept as a replay input buffer.
            intermediate_inputs: Pipeline-parallel inputs, or ``None``.
            kv_caches: KV cache tensors passed to the model.
            attn_metadata: Attention metadata used for the captured run.
            memory_pool: Passed as ``pool`` to ``torch.mlu.graph``.
            stream: Stream on which the graph is captured.
            **kwargs: Extra model inputs; they are also stored in the input
                buffers.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )
        # Wait for the warm up operations to finish before proceeding with
        # Graph Capture.
        torch.mlu.synchronize()
        # Capture the graph.
        self._graph = torch.mlu.MLUGraph()
        with torch.mlu.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )
            hidden_or_intermediate_states = (
                output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_or_intermediate_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.mlu.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids":
            input_ids,
            "positions":
            positions,
            "kv_caches":
            kv_caches,
            **self.attn_state.get_graph_input_buffers(
                attn_metadata, self._is_encoder_decoder_model),
            **kwargs,
        }
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> torch.Tensor:
        """Copy the inputs into the captured buffers and replay the graph.

        Returns:
            ``self.output_buffers["hidden_states"]`` on the last pipeline
            rank; otherwise ``self.output_buffers`` itself (whatever the
            model returned during capture, despite the annotation).
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)

        if self.backend_name != "NO_ATTENTION":
            self.input_buffers["slot_mapping"].copy_(
                attn_metadata.slot_mapping, non_blocking=True)

        self.attn_state.prepare_graph_input_buffers(
            self.input_buffers, attn_metadata, self._is_encoder_decoder_model)

        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)

        if "previous_hidden_states" in self.input_buffers:
            self.input_buffers["previous_hidden_states"].copy_(
                kwargs["previous_hidden_states"], non_blocking=True)

        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                # The timing entries are skipped; every other key is copied
                # into its input buffer.
                if key != "model_execute_time" and key != "model_forward_time":
                    self.input_buffers[key].copy_(intermediate_tensors[key],
                                                  non_blocking=True)
        if self._is_encoder_decoder_model:
            self.input_buffers["encoder_input_ids"].copy_(
                kwargs['encoder_input_ids'], non_blocking=True)
            self.input_buffers["encoder_positions"].copy_(
                kwargs['encoder_positions'], non_blocking=True)

        # Run the graph.
        self.graph.replay()
        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers