# Source: enginex-mlu370-vllm/vllm-v0.6.2/vllm/worker/mlu_model_runner.py
# Timestamp: 2026-02-04 17:22:39 +08:00 (863 lines, 38 KiB, Python)

import gc
import inspect
import itertools
import time
import weakref
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (Dict, List, Optional, Set, Tuple, Type, Union)
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (GiB_bytes, PyObjectCache,
async_tensor_h2d, flatten_2d_lists,
is_pin_memory_available)
from vllm.worker.model_runner_base import (ModelRunnerBase,
dump_input_when_exception)
from vllm.worker.model_runner import (
TModelInputForGPU, ModelInputForGPU,
ModelInputForGPUWithSamplingMetadata,
ModelInputForGPUBuilder, GPUModelRunnerBase,
ModelRunner, CUDAGraphRunner,
ModelRunnerBase, LORA_WARMUP_RANK,
_NUM_WARMUP_ITERS, _get_max_graph_batch_size,
_BATCH_SIZES_TO_CAPTURE
)
# Module-level logger for this file.
logger = init_logger(__name__)


@dataclass
class MLUGraphCaptureContext:
    """Carries the MLU stream used for graph capture.

    ``mlu_graph_capture`` creates one of these when the caller does not
    pass an existing context.
    """
    # Side stream on which graph capture is performed.
    stream: torch.mlu.Stream
@contextmanager
def mlu_graph_capture(
    graph_capture_context: Optional[MLUGraphCaptureContext] = None,
):
    """Make a dedicated MLU stream current for the duration of graph capture.

    Args:
        graph_capture_context: Context whose ``stream`` should be reused.
            When ``None``, a new ``torch.mlu.Stream`` is created and wrapped
            in a fresh ``MLUGraphCaptureContext``.

    Yields:
        The capture context; its stream is the current MLU stream inside
        the ``with`` block.
    """
    if graph_capture_context is not None:
        capture_stream = graph_capture_context.stream
    else:
        capture_stream = torch.mlu.Stream()
        graph_capture_context = MLUGraphCaptureContext(capture_stream)

    # Anything already queued on the launching stream (buffer initialization,
    # warm-up work, ...) must complete before work on the capture stream runs.
    launching_stream = torch.mlu.current_stream()
    if launching_stream != capture_stream:
        capture_stream.wait_stream(launching_stream)

    with torch.mlu.stream(capture_stream):
        yield graph_capture_context
class ModelInputForMLUBuilder(ModelInputForGPUBuilder):
    """Build ModelInputForGPU from SequenceGroupMetadata.

    MLU variant of ``ModelInputForGPUBuilder``; only ``build`` is overridden.
    Token ids are materialized as ``torch.long`` and positions as
    ``torch.int32``.
    """

    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and
        create on-device tensors.

        Flattens the per-sequence-group data gathered in
        ``self.inter_data_list``, pads it when ``_get_cuda_graph_pad_size``
        reports that graph replay can be used, copies tokens and positions
        to ``self.runner.device``, and assembles attention, LoRA,
        prompt-adapter and multi-modal inputs.

        Returns:
            A ``self.model_input_cls`` instance; an empty one when there are
            no input tokens at all.
        """
        # Combine and flatten intermediate data.
        input_tokens = []
        for inter_data in self.inter_data_list:
            for cur_input_tokens in inter_data.input_tokens:
                input_tokens.extend(cur_input_tokens)

        if not input_tokens:
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        # M-RoPE: three position streams (indices 0..2). Groups without
        # M-RoPE data contribute their plain positions to every stream.
        mrope_input_positions: Optional[List[List[int]]] = None
        if any(inter_data.mrope_input_positions is not None
               for inter_data in self.inter_data_list):
            mrope_input_positions = [[] for _ in range(3)]
            for idx in range(3):
                for inter_data in self.inter_data_list:
                    msections = inter_data.mrope_input_positions
                    if msections is None:
                        for _seq_input_positions in inter_data.input_positions:
                            mrope_input_positions[idx].extend(
                                _seq_input_positions)
                    else:
                        for _seq_mrope_input_positions in msections:
                            mrope_input_positions[idx].extend(
                                _seq_mrope_input_positions[idx])
            input_positions = None
        else:
            input_positions = []
            for inter_data in self.inter_data_list:
                for cur_input_positions in inter_data.input_positions:
                    input_positions.extend(cur_input_positions)

        seq_lens = []
        query_lens = []
        max_decode_seq_len = 0
        max_encoder_seq_len = 0
        for inter_data in self.inter_data_list:
            seq_lens.extend(inter_data.seq_lens)
            query_lens.extend(inter_data.query_lens)
            if not inter_data.is_prompt:
                # Only decode groups contribute to the maxima used by the
                # graph-padding decision below.
                max_decode_seq_len = max(max_decode_seq_len,
                                         max(inter_data.seq_lens))
                if self.runner.model_config.is_encoder_decoder:
                    max_encoder_seq_len = max(max_encoder_seq_len,
                                              inter_data.encoder_seq_len)

        # Mapping from request IDs to sequence IDs. Used for Jamba models
        # that manage the cache by themselves.
        request_ids_to_seq_ids = {
            data.request_id: data.seq_ids
            for data in self.inter_data_list
        }

        # -1 signals that graph replay cannot be used for this batch;
        # otherwise this is the number of dummy entries to append.
        cuda_graph_pad_size = self._get_cuda_graph_pad_size(
            num_seqs=len(seq_lens),
            max_decode_seq_len=max_decode_seq_len,
            max_encoder_seq_len=max_encoder_seq_len)

        batch_size = len(input_tokens)
        if cuda_graph_pad_size != -1:
            # If the graph can be used, pad tensors accordingly.
            # See `capture_model` API for more details.
            # vLLM uses graphs only for decoding requests.
            batch_size += cuda_graph_pad_size

        # Tokens and positions.
        # NOTE: a pad size of -1 is truthy, but itertools.repeat(x, -1)
        # yields nothing, so every padding `extend` below is then a no-op.
        if cuda_graph_pad_size:
            input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
        assert self.runner.device is not None
        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                               self.runner.device,
                                               self.runner.pin_memory)
        # Positions are materialized as int32 here (tokens stay int64).
        if mrope_input_positions is not None:
            for idx in range(3):
                mrope_input_positions[idx].extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            input_positions_tensor = async_tensor_h2d(mrope_input_positions,
                                                      torch.int32,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
        else:
            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
            input_positions_tensor = async_tensor_h2d(input_positions,
                                                      torch.int32,
                                                      self.runner.device,
                                                      self.runner.pin_memory)

        # Sequence and query lengths.
        # Padding (dummy) sequences are given length 1.
        if cuda_graph_pad_size:
            seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))

        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(
            seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # LoRA data.
        lora_requests = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = set(r for data in self.inter_data_list
                                for r in data.lora_requests)
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_index_mapping)
                for inter_data in self.inter_data_list
            ])
            if cuda_graph_pad_size:
                lora_index_mapping.extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_prompt_mapping)
                for inter_data in self.inter_data_list
            ])

            lora_mapping = LoRAMapping(
                **dict(index_mapping=lora_index_mapping,
                       prompt_mapping=lora_prompt_mapping,
                       is_prefill=not self.decode_only))

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = set(
                data.prompt_adapter_request for data in self.inter_data_list
                if data.prompt_adapter_request is not None)
            prompt_adapter_index_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_index_mapping
                for inter_data in self.inter_data_list
            ])
            if cuda_graph_pad_size:
                prompt_adapter_index_mapping.extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            prompt_adapter_prompt_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_prompt_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data.
        multi_modal_kwargs_list = [
            data.multi_modal_kwargs for data in self.inter_data_list
            if data.multi_modal_kwargs is not None
        ]
        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)
class MLUModelRunnerBase(GPUModelRunnerBase[TModelInputForGPU]):
    """
    Helper class for shared methods between MLU model runners.

    Overrides ``__init__``, ``profile_run`` and ``capture_model`` so that
    they use the ``torch.mlu`` device APIs, ``mlu_graph_capture`` and
    ``MLUGraphRunner`` instead of their CUDA counterparts.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        # ModelRunnerBase.__init__ is called directly (not super().__init__),
        # so any GPUModelRunnerBase initializer is skipped; the remaining
        # runner state is set up explicitly below.
        ModelRunnerBase.__init__(self, vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config

        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
        self.max_batchsize_to_capture = _get_max_graph_batch_size(
            self.scheduler_config.max_num_seqs)

        # One {batch_size: MLUGraphRunner} map per pipeline-parallel virtual
        # engine; populated by capture_model().
        self.graph_runners: List[Dict[int, MLUGraphRunner]] = [
            {} for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.

        self.has_inner_state = model_config.has_inner_state

        # When using graphs, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max seq len to capture / block size).
        self.graph_block_tables = np.zeros(
            (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
            dtype=np.int32)

        # Attention-free but stateful models like Mamba need a placeholder attn
        # backend, as the attention metadata is needed to manage internal state.
        # However we must bypass attention selection altogether for some models
        # used for speculative decoding to avoid a divide-by-zero in
        # model_config.get_head_size()
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        needs_attn_backend = (num_attn_heads != 0
                              or self.model_config.is_attention_free)

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        ) if needs_attn_backend else None
        if self.attn_backend:
            self.attn_state = self.attn_backend.get_state_cls()(
                weakref.proxy(self))
        else:
            self.attn_state = CommonAttentionState(weakref.proxy(self))

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry \
            .create_input_mapper(model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        self.prompt_adapter_manager: Optional[
            LRUCacheWorkerPromptAdapterManager] = None

        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        # Used to cache python objects
        self.inter_data_cache: Dict[int, PyObjectCache] = {}

        # Using the PythonizationCache in Pipeline-Parallel clobbers the
        # SequenceGroupToSample object. In Pipeline-Parallel, we have
        # more than 1 Scheduler, resulting in a potential back-to-back
        # prepare_model_inputs() call. This clobbers the cached
        # SequenceGroupToSample objects, as we reset the cache during
        # every prepare_model_inputs() call.
        self.sampling_metadata_cache: Optional[SamplingMetadataCache] = \
            SamplingMetadataCache() \
            if self.parallel_config.pipeline_parallel_size == 1 else None

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass over worst-case dummy prefill inputs.

        Builds up to ``max_num_seqs`` dummy prompt sequences whose lengths
        sum to ``max_num_batched_tokens``. The count is reduced for
        multi-modal models so that the number of multi-modal items is
        maximized. Runs ``execute_model`` on them with empty per-layer KV
        caches, then calls ``torch.mlu.synchronize()``, presumably so the
        caller can measure peak device memory afterwards.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        # This represents the maximum number of different requests
        # that will have unique loras, and therefore the max amount of memory
        # consumption. Create dummy lora request copies from the lora request
        # passed in, which contains a lora from the lora warmup path.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                # Round-robin the dummy LoRAs over all sequences.
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional device memory may be needed for multi-modal encoding,
        # which needs to be accounted for when calculating the device blocks
        # for the vLLM block manager.
        # To exercise the worst scenario for device memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            # Split max_num_batched_tokens as evenly as possible: the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_data.multi_modal_data,
                multi_modal_placeholders=dummy_data.multi_modal_placeholders,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # use an empty tensor instead of ``None`` to force Dynamo to pass
        # it by reference, rather by specializing on the value ``None``.
        # the `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        # it is important to create tensors inside the loop, rather than
        # multiplying the list, to avoid Dynamo from treating them as
        # tensor aliasing.
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        # Batch sizes passed to the compile context: the capturable sizes,
        # or none at all in eager mode.
        graph_batch_size = self.max_batchsize_to_capture
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]
        if self.model_config.enforce_eager:
            batch_size_capture_list = []
        with set_compile_context(batch_size_capture_list):
            self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.mlu.synchronize()
        return

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Capture MLU graphs of the model for decode-only batches.

        Note that the graph's performance gain is negligible if the number
        of batched tokens is larger than 200. And since graph capture
        requires fixed sized tensors, supporting large/variable batch
        size requires high device memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One ``MLUGraphRunner`` is captured per (virtual engine, batch size)
        for every size in ``_BATCH_SIZES_TO_CAPTURE`` not exceeding
        ``self.max_batchsize_to_capture``, largest first. All of them share
        ``self.graph_memory_pool``.

        Args:
            kv_caches: KV caches indexed by virtual engine;
                ``kv_caches[ve]`` is passed to the graphs captured for
                virtual engine ``ve``.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for MLU graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("MLU graphs can take additional 1~3 GiB memory per MLU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()
        start_free_gpu_memory = torch.mlu.mem_get_info()[0]

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = self.max_batchsize_to_capture
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).mlu()
        input_positions = torch.zeros(max_batch_size, dtype=torch.int32).mlu()
        if self.model_config.uses_mrope:
            # Shape becomes (3, max_batch_size); sliced on the last dim below.
            input_positions = torch.tile(input_positions, (3, 1))
        # Prepare dummy previous_hidden_states only if needed by the model.
        # This is used by draft models such as EAGLE.
        previous_hidden_states = None
        if "previous_hidden_states" in inspect.signature(
                self.model.forward).parameters:
            previous_hidden_states = torch.empty(
                [max_batch_size,
                 self.model_config.get_hidden_size()],
                dtype=self.model_config.dtype,
                device=self.device)

        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(
                batch_size=max_batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        graph_batch_size = self.max_batchsize_to_capture
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        with self.attn_state.graph_capture(
                max_batch_size), mlu_graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of the graphs.
            for virtual_engine in range(
                    self.parallel_config.pipeline_parallel_size):
                for batch_size in reversed(batch_size_capture_list):
                    attn_metadata = (
                        self.attn_state.graph_capture_get_metadata_for_batch(
                            batch_size,
                            is_encoder_decoder_model=self.model_config.
                            is_encoder_decoder))

                    if self.lora_config:
                        lora_mapping = LoRAMapping(
                            **dict(index_mapping=[0] * batch_size,
                                   prompt_mapping=[0] * batch_size,
                                   is_prefill=False))
                        self.set_active_loras(set(), lora_mapping)

                    if self.prompt_adapter_config:
                        prompt_adapter_mapping = PromptAdapterMapping(
                            [-1] * batch_size,
                            [-1] * batch_size,
                        )
                        self.set_active_prompt_adapters(
                            set(), prompt_adapter_mapping)
                    # NOTE(review): assumes self.attn_backend is not None
                    # (i.e. the model needed an attention backend).
                    graph_runner = MLUGraphRunner(
                        self.model, self.attn_backend.get_name(),
                        self.attn_state.graph_clone(batch_size),
                        self.model_config.is_encoder_decoder)

                    # Every input is a view into the shared max-size dummy
                    # buffers prepared above.
                    capture_inputs = {
                        "input_ids":
                        input_tokens[:batch_size],
                        "positions":
                        input_positions[..., :batch_size],
                        "intermediate_inputs":
                        intermediate_inputs[:batch_size]
                        if intermediate_inputs is not None else None,
                        "kv_caches":
                        kv_caches[virtual_engine],
                        "attn_metadata":
                        attn_metadata,
                        "memory_pool":
                        self.graph_memory_pool,
                        "stream":
                        graph_capture_context.stream
                    }
                    if previous_hidden_states is not None:
                        capture_inputs[
                            "previous_hidden_states"] = previous_hidden_states[:
                                                                               batch_size]

                    if self.has_inner_state:
                        # Only used by Mamba-based models graphs atm (Jamba)
                        capture_inputs.update({
                            "seqlen_agnostic_capture_inputs":
                            self.model.get_seqlen_agnostic_capture_inputs(
                                batch_size)
                        })
                    if self.model_config.is_encoder_decoder:
                        # add the additional inputs to capture for
                        # encoder-decoder models.
                        self._update_inputs_to_capture_for_enc_dec_model(
                            capture_inputs)

                    with set_forward_context(attn_metadata):
                        graph_runner.capture(**capture_inputs)
                    # Later captures reuse the pool of the first graph.
                    self.graph_memory_pool = graph_runner.graph.pool()
                    self.graph_runners[virtual_engine][batch_size] = (
                        graph_runner)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.mlu.mem_get_info()[0]
        elapsed_time = end_time - start_time
        mlu_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, mlu_graph_size / GiB_bytes)
class MLUModelRunner(MLUModelRunnerBase, ModelRunner):
    """
    MLU model runner with sampling step.
    """
    # Input builder class; presumably instantiated by the inherited
    # input-preparation path (TODO confirm in ModelRunner/GPUModelRunnerBase).
    _builder_cls: Type[ModelInputForMLUBuilder] = ModelInputForMLUBuilder

    @torch.inference_mode()
    @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Run one forward step; on the last pipeline stage, also sample.

        Decode-only batches flagged with ``use_cuda_graph`` replay the
        ``MLUGraphRunner`` captured for their (padded) batch size. All
        other batches run ``self.model`` eagerly.

        Args:
            model_input: Prepared model input, including attention and
                sampling metadata.
            kv_caches: KV caches for this virtual engine.
            intermediate_tensors: Activations received from the previous
                pipeline stage, if any.
            num_steps: Must be 1.

        Returns:
            On non-last pipeline ranks, the model's output (intermediate
            tensors for the next stage). On the last rank: ``[]`` for
            non-driver workers, otherwise ``[SamplerOutput]``.

        Raises:
            ValueError: If ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        self.attn_state.begin_forward(model_input)

        # Currently graph replay is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            # The padded token count selects the captured graph.
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_inner_state else {}
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            # Timing events bracket the forward call below.
            model_forward_start = torch.mlu.Event(enable_timing=True)
            model_forward_end = torch.mlu.Event(enable_timing=True)
            model_forward_start.record()

        with set_forward_context(model_input.attn_metadata):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
                **seqlen_agnostic_kwargs)

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end.record()

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            if (self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)
                    and self.observability_config is not None
                    and self.observability_config.collect_model_forward_time):
                # Accumulate this stage's forward time onto the time
                # reported by earlier stages and forward it downstream.
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time
                and output is not None):
            model_forward_end.synchronize()
            model_forward_time = model_forward_start.elapsed_time(
                model_forward_end)
            orig_model_forward_time = 0.0
            if intermediate_tensors is not None:
                orig_model_forward_time = intermediate_tensors.tensors.get(
                    "model_forward_time", torch.tensor(0.0)).item()
            # If there are multiple workers, we are still tracking the latency
            # from the start time of the driver worker to the end time of the
            # driver worker. The model forward time will then end up covering
            # the communication time as well.
            output.model_forward_time = (orig_model_forward_time +
                                         model_forward_time)

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
                output.prefill_hidden_states = hidden_or_intermediate_states
            elif decode_meta.use_cuda_graph:
                # Graph output includes padding rows; keep the real ones.
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]
class MLUGraphRunner(CUDAGraphRunner):
    """``CUDAGraphRunner`` variant that captures and replays ``torch.mlu`` graphs.

    Only ``capture`` and ``forward`` are overridden. Attributes such as
    ``self.model``, ``self._graph``, ``self.backend_name``,
    ``self.attn_state`` and ``self._is_encoder_decoder_model`` are
    presumably initialized by ``CUDAGraphRunner.__init__`` (TODO confirm).
    """

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.mlu.Stream,
        **kwargs,
    ):
        """Warm up the model, then record one forward pass as an MLU graph.

        The tensors passed in become the graph's static input buffers and
        the captured output becomes its static output buffers; both are
        stored on ``self.input_buffers`` / ``self.output_buffers`` for
        ``forward``.

        Args:
            input_ids: Static token-id buffer for this batch size.
            positions: Static position buffer for this batch size.
            intermediate_inputs: Static buffers for activations from the
                previous pipeline stage, or ``None`` on the first rank.
            kv_caches: KV caches used by the captured graph.
            attn_metadata: Attention metadata used during capture.
            memory_pool: Graph memory pool to share, or ``None``.
            stream: MLU stream on which capture is performed.
            **kwargs: Extra model inputs; also stored as input buffers.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )
        # Wait for the warm up operations to finish before proceeding with
        # Graph Capture.
        torch.mlu.synchronize()
        # Capture the graph.
        self._graph = torch.mlu.MLUGraph()
        with torch.mlu.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )

            hidden_or_intermediate_states = (
                output_hidden_or_intermediate_states)

            # NOTE(review): `hidden_or_intermediate_states` still references
            # the same object, so this `del` only drops the name. The
            # output tensor(s) stay alive as the graph's output buffers
            # (stored in `self.output_buffers` below).
            del output_hidden_or_intermediate_states
            # make sure `output_hidden_or_intermediate_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.mlu.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids":
            input_ids,
            "positions":
            positions,
            "kv_caches":
            kv_caches,
            **self.attn_state.get_graph_input_buffers(
                attn_metadata, self._is_encoder_decoder_model),
            **kwargs,
        }
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Copy the new inputs into the captured buffers and replay the graph.

        Returns:
            On the last pipeline rank, the captured hidden-states output
            buffer. Otherwise, the captured intermediate-tensor outputs.
            These are the same static buffer objects on every call, so
            their contents are presumably overwritten by the next replay.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)

        if self.backend_name != "NO_ATTENTION":
            self.input_buffers["slot_mapping"].copy_(
                attn_metadata.slot_mapping, non_blocking=True)

        self.attn_state.prepare_graph_input_buffers(
            self.input_buffers, attn_metadata, self._is_encoder_decoder_model)

        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)

        if "previous_hidden_states" in self.input_buffers:
            self.input_buffers["previous_hidden_states"].copy_(
                kwargs["previous_hidden_states"], non_blocking=True)

        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                # Timing entries are metadata, not captured graph inputs.
                if key != "model_execute_time" and key != "model_forward_time":
                    self.input_buffers[key].copy_(intermediate_tensors[key],
                                                  non_blocking=True)
        if self._is_encoder_decoder_model:
            self.input_buffers["encoder_input_ids"].copy_(
                kwargs['encoder_input_ids'], non_blocking=True)
            self.input_buffers["encoder_positions"].copy_(
                kwargs['encoder_positions'], non_blocking=True)

        # Run the graph.
        self.graph.replay()
        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]
        return self.output_buffers