forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
862
vllm-v0.6.2/vllm/worker/mlu_model_runner.py
Normal file
862
vllm-v0.6.2/vllm/worker/mlu_model_runner.py
Normal file
@@ -0,0 +1,862 @@
|
||||
import gc
|
||||
import inspect
|
||||
import itertools
|
||||
import time
|
||||
import weakref
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import (Dict, List, Optional, Set, Tuple, Type, Union)
|
||||
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.model_executor import SamplingMetadataCache
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||
MultiModalRegistry)
|
||||
from vllm.prompt_adapter.layers import PromptAdapterMapping
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.prompt_adapter.worker_manager import (
|
||||
LRUCacheWorkerPromptAdapterManager)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
from vllm.utils import (GiB_bytes, PyObjectCache,
|
||||
async_tensor_h2d, flatten_2d_lists,
|
||||
is_pin_memory_available)
|
||||
from vllm.worker.model_runner_base import (ModelRunnerBase,
|
||||
dump_input_when_exception)
|
||||
from vllm.worker.model_runner import (
|
||||
TModelInputForGPU, ModelInputForGPU,
|
||||
ModelInputForGPUWithSamplingMetadata,
|
||||
ModelInputForGPUBuilder, GPUModelRunnerBase,
|
||||
ModelRunner, CUDAGraphRunner,
|
||||
ModelRunnerBase, LORA_WARMUP_RANK,
|
||||
_NUM_WARMUP_ITERS, _get_max_graph_batch_size,
|
||||
_BATCH_SIZES_TO_CAPTURE
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLUGraphCaptureContext:
|
||||
stream: torch.mlu.Stream
|
||||
|
||||
|
||||
@contextmanager
|
||||
def mlu_graph_capture(graph_capture_context: Optional[MLUGraphCaptureContext] = None):
|
||||
if graph_capture_context is None:
|
||||
stream = torch.mlu.Stream()
|
||||
graph_capture_context = MLUGraphCaptureContext(stream)
|
||||
else:
|
||||
stream = graph_capture_context.stream
|
||||
|
||||
# ensure all initialization operations complete before attempting to
|
||||
# capture the graph on another stream
|
||||
curr_stream = torch.mlu.current_stream()
|
||||
if curr_stream != stream:
|
||||
stream.wait_stream(curr_stream)
|
||||
|
||||
with torch.mlu.stream(stream):
|
||||
yield graph_capture_context
|
||||
|
||||
|
||||
class ModelInputForMLUBuilder(ModelInputForGPUBuilder):
|
||||
"""Build ModelInputForGPU from SequenceGroupMetadata."""
|
||||
|
||||
def build(self) -> ModelInputForGPU:
|
||||
"""Finalize the builder intermediate data and
|
||||
create on-device tensors.
|
||||
"""
|
||||
# Combine and flatten intermediate data.
|
||||
input_tokens = []
|
||||
for inter_data in self.inter_data_list:
|
||||
for cur_input_tokens in inter_data.input_tokens:
|
||||
input_tokens.extend(cur_input_tokens)
|
||||
|
||||
if not input_tokens:
|
||||
# This may happen when all prefill requests hit
|
||||
# prefix caching and there is no decode request.
|
||||
return self.model_input_cls()
|
||||
|
||||
mrope_input_positions: Optional[List[List[int]]] = None
|
||||
if any(inter_data.mrope_input_positions is not None
|
||||
for inter_data in self.inter_data_list):
|
||||
mrope_input_positions = [[] for _ in range(3)]
|
||||
for idx in range(3):
|
||||
for inter_data in self.inter_data_list:
|
||||
msections = inter_data.mrope_input_positions
|
||||
if msections is None:
|
||||
for _seq_input_positions in inter_data.input_positions:
|
||||
mrope_input_positions[idx].extend(
|
||||
_seq_input_positions)
|
||||
else:
|
||||
for _seq_mrope_input_positions in msections:
|
||||
mrope_input_positions[idx].extend(
|
||||
_seq_mrope_input_positions[idx])
|
||||
input_positions = None
|
||||
else:
|
||||
input_positions = []
|
||||
for inter_data in self.inter_data_list:
|
||||
for cur_input_positions in inter_data.input_positions:
|
||||
input_positions.extend(cur_input_positions)
|
||||
|
||||
seq_lens = []
|
||||
query_lens = []
|
||||
max_decode_seq_len = 0
|
||||
max_encoder_seq_len = 0
|
||||
for inter_data in self.inter_data_list:
|
||||
seq_lens.extend(inter_data.seq_lens)
|
||||
query_lens.extend(inter_data.query_lens)
|
||||
if not inter_data.is_prompt:
|
||||
max_decode_seq_len = max(max_decode_seq_len,
|
||||
max(inter_data.seq_lens))
|
||||
if self.runner.model_config.is_encoder_decoder:
|
||||
max_encoder_seq_len = max(max_encoder_seq_len,
|
||||
inter_data.encoder_seq_len)
|
||||
|
||||
# Mapping from request IDs to sequence IDs. Used for Jamba models
|
||||
# that manages the cache by itself.
|
||||
request_ids_to_seq_ids = {
|
||||
data.request_id: data.seq_ids
|
||||
for data in self.inter_data_list
|
||||
}
|
||||
|
||||
cuda_graph_pad_size = self._get_cuda_graph_pad_size(
|
||||
num_seqs=len(seq_lens),
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
max_encoder_seq_len=max_encoder_seq_len)
|
||||
|
||||
batch_size = len(input_tokens)
|
||||
if cuda_graph_pad_size != -1:
|
||||
# If cuda graph can be used, pad tensors accordingly.
|
||||
# See `capture_model` API for more details.
|
||||
# vLLM uses cuda graph only for decoding requests.
|
||||
batch_size += cuda_graph_pad_size
|
||||
|
||||
# Tokens and positions.
|
||||
if cuda_graph_pad_size:
|
||||
input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
|
||||
assert self.runner.device is not None
|
||||
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
|
||||
self.runner.device,
|
||||
self.runner.pin_memory)
|
||||
if mrope_input_positions is not None:
|
||||
for idx in range(3):
|
||||
mrope_input_positions[idx].extend(
|
||||
itertools.repeat(0, cuda_graph_pad_size))
|
||||
input_positions_tensor = async_tensor_h2d(mrope_input_positions,
|
||||
torch.int32,
|
||||
self.runner.device,
|
||||
self.runner.pin_memory)
|
||||
else:
|
||||
input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
|
||||
input_positions_tensor = async_tensor_h2d(input_positions,
|
||||
torch.int32,
|
||||
self.runner.device,
|
||||
self.runner.pin_memory)
|
||||
# Sequence and query lengths.
|
||||
if cuda_graph_pad_size:
|
||||
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
|
||||
|
||||
# Attention metadata.
|
||||
attn_metadata = self.attn_metadata_builder.build(
|
||||
seq_lens, query_lens, cuda_graph_pad_size, batch_size)
|
||||
|
||||
# LoRA data.
|
||||
lora_requests = set()
|
||||
lora_mapping = None
|
||||
if self.enable_lora:
|
||||
lora_requests = set(r for data in self.inter_data_list
|
||||
for r in data.lora_requests)
|
||||
lora_index_mapping = flatten_2d_lists([
|
||||
flatten_2d_lists(inter_data.lora_index_mapping)
|
||||
for inter_data in self.inter_data_list
|
||||
])
|
||||
if cuda_graph_pad_size:
|
||||
lora_index_mapping.extend(
|
||||
itertools.repeat(0, cuda_graph_pad_size))
|
||||
lora_prompt_mapping = flatten_2d_lists([
|
||||
flatten_2d_lists(inter_data.lora_prompt_mapping)
|
||||
for inter_data in self.inter_data_list
|
||||
])
|
||||
|
||||
lora_mapping = LoRAMapping(
|
||||
**dict(index_mapping=lora_index_mapping,
|
||||
prompt_mapping=lora_prompt_mapping,
|
||||
is_prefill=not self.decode_only))
|
||||
|
||||
# Prompt adapter data.
|
||||
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
|
||||
prompt_adapter_mapping = None
|
||||
if self.enable_prompt_adapter:
|
||||
prompt_adapter_requests = set(
|
||||
data.prompt_adapter_request for data in self.inter_data_list
|
||||
if data.prompt_adapter_request is not None)
|
||||
prompt_adapter_index_mapping = flatten_2d_lists([
|
||||
inter_data.prompt_adapter_index_mapping
|
||||
for inter_data in self.inter_data_list
|
||||
])
|
||||
if cuda_graph_pad_size:
|
||||
prompt_adapter_index_mapping.extend(
|
||||
itertools.repeat(0, cuda_graph_pad_size))
|
||||
prompt_adapter_prompt_mapping = flatten_2d_lists([
|
||||
inter_data.prompt_adapter_prompt_mapping
|
||||
for inter_data in self.inter_data_list
|
||||
])
|
||||
prompt_adapter_mapping = PromptAdapterMapping(
|
||||
prompt_adapter_index_mapping,
|
||||
prompt_adapter_prompt_mapping,
|
||||
)
|
||||
|
||||
# Multi-modal data.
|
||||
multi_modal_kwargs_list = [
|
||||
data.multi_modal_kwargs for data in self.inter_data_list
|
||||
if data.multi_modal_kwargs is not None
|
||||
]
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens_tensor,
|
||||
input_positions=input_positions_tensor,
|
||||
attn_metadata=attn_metadata,
|
||||
seq_lens=seq_lens,
|
||||
query_lens=query_lens,
|
||||
lora_mapping=lora_mapping,
|
||||
lora_requests=lora_requests,
|
||||
multi_modal_kwargs=multi_modal_kwargs,
|
||||
request_ids_to_seq_ids=request_ids_to_seq_ids,
|
||||
finished_requests_ids=self.finished_requests_ids,
|
||||
prompt_adapter_mapping=prompt_adapter_mapping,
|
||||
prompt_adapter_requests=prompt_adapter_requests)
|
||||
|
||||
|
||||
class MLUModelRunnerBase(GPUModelRunnerBase[TModelInputForGPU]):
|
||||
"""
|
||||
Helper class for shared methods between MLU model runners.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
return_hidden_states: bool = False,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
):
|
||||
|
||||
ModelRunnerBase.__init__(self, vllm_config)
|
||||
model_config = self.model_config
|
||||
cache_config = self.cache_config
|
||||
|
||||
self.is_driver_worker = is_driver_worker
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
self.device = self.device_config.device
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.sliding_window = model_config.get_sliding_window()
|
||||
self.block_size = cache_config.block_size
|
||||
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
|
||||
self.max_batchsize_to_capture = _get_max_graph_batch_size(
|
||||
self.scheduler_config.max_num_seqs)
|
||||
|
||||
self.graph_runners: List[Dict[int, MLUGraphRunner]] = [
|
||||
{} for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
self.graph_memory_pool: Optional[Tuple[
|
||||
int, int]] = None # Set during graph capture.
|
||||
|
||||
self.has_inner_state = model_config.has_inner_state
|
||||
|
||||
# When using CUDA graph, the input block tables must be padded to
|
||||
# max_seq_len_to_capture. However, creating the block table in
|
||||
# Python can be expensive. To optimize this, we cache the block table
|
||||
# in numpy and only copy the actual input content at every iteration.
|
||||
# The shape of the cached block table will be
|
||||
# (max batch size to capture, max seq len to capture / block size).
|
||||
self.graph_block_tables = np.zeros(
|
||||
(self.max_batchsize_to_capture, self.get_max_block_per_batch()),
|
||||
dtype=np.int32)
|
||||
|
||||
# Attention-free but stateful models like Mamba need a placeholder attn
|
||||
# backend, as the attention metadata is needed to manage internal state.
|
||||
# However we must bypass attention selection altogether for some models
|
||||
# used for speculative decoding to avoid a divide-by-zero in
|
||||
# model_config.get_head_size()
|
||||
num_attn_heads = self.model_config.get_num_attention_heads(
|
||||
self.parallel_config)
|
||||
needs_attn_backend = (num_attn_heads != 0
|
||||
or self.model_config.is_attention_free)
|
||||
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
) if needs_attn_backend else None
|
||||
if self.attn_backend:
|
||||
self.attn_state = self.attn_backend.get_state_cls()(
|
||||
weakref.proxy(self))
|
||||
else:
|
||||
self.attn_state = CommonAttentionState(weakref.proxy(self))
|
||||
|
||||
# Multi-modal data support
|
||||
self.input_registry = input_registry
|
||||
self.mm_registry = mm_registry
|
||||
self.multi_modal_input_mapper = mm_registry \
|
||||
.create_input_mapper(model_config)
|
||||
self.mm_registry.init_mm_limits_per_prompt(self.model_config)
|
||||
|
||||
# Lazy initialization
|
||||
self.model: nn.Module # Set after load_model
|
||||
# Set after load_model.
|
||||
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
|
||||
self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
|
||||
|
||||
set_cpu_offload_max_bytes(
|
||||
int(self.cache_config.cpu_offload_gb * 1024**3))
|
||||
|
||||
# Used to cache python objects
|
||||
self.inter_data_cache: Dict[int, PyObjectCache] = {}
|
||||
|
||||
# Using the PythonizationCache in Pipeline-Parallel clobbers the
|
||||
# SequenceGroupToSample object. In Pipeline-Parallel, we have
|
||||
# more than 1 Scheduler, resulting in a potential back-to-back
|
||||
# prepare_model_inputs() call. This clobbers the cached
|
||||
# SequenceGroupToSample objects, as we reset the cache during
|
||||
# every prepare_model_inputs() call.
|
||||
self.sampling_metadata_cache: SamplingMetadataCache = \
|
||||
SamplingMetadataCache() \
|
||||
if self.parallel_config.pipeline_parallel_size == 1 else None
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
# Enable top-k sampling to reflect the accurate memory usage.
|
||||
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
|
||||
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
max_num_seqs = self.scheduler_config.max_num_seqs
|
||||
# This represents the maximum number of different requests
|
||||
# that will have unique loras, an therefore the max amount of memory
|
||||
# consumption create dummy lora request copies from the lora request
|
||||
# passed in, which contains a lora from the lora warmup path.
|
||||
dummy_lora_requests: List[LoRARequest] = []
|
||||
dummy_lora_requests_per_seq: List[LoRARequest] = []
|
||||
if self.lora_config:
|
||||
assert self.lora_manager is not None
|
||||
with self.lora_manager.dummy_lora_cache():
|
||||
for idx in range(self.lora_config.max_loras):
|
||||
lora_id = idx + 1
|
||||
dummy_lora_request = LoRARequest(
|
||||
lora_name=f"warmup_{lora_id}",
|
||||
lora_int_id=lora_id,
|
||||
lora_path="/not/a/real/path",
|
||||
)
|
||||
self.lora_manager.add_dummy_lora(dummy_lora_request,
|
||||
rank=LORA_WARMUP_RANK)
|
||||
dummy_lora_requests.append(dummy_lora_request)
|
||||
dummy_lora_requests_per_seq = [
|
||||
dummy_lora_requests[idx % len(dummy_lora_requests)]
|
||||
for idx in range(max_num_seqs)
|
||||
]
|
||||
|
||||
# Profile memory usage with max_num_sequences sequences and the total
|
||||
# number of tokens equal to max_num_batched_tokens.
|
||||
seqs: List[SequenceGroupMetadata] = []
|
||||
# Additional GPU memory may be needed for multi-modal encoding, which
|
||||
# needs to be accounted for when calculating the GPU blocks for
|
||||
# vLLM blocker manager.
|
||||
# To exercise the worst scenario for GPU memory consumption,
|
||||
# the number of seqs (batch_size) is chosen to maximize the number
|
||||
# of images processed.
|
||||
|
||||
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
|
||||
self.model_config)
|
||||
if max_mm_tokens > 0:
|
||||
max_num_seqs_orig = max_num_seqs
|
||||
max_num_seqs = min(max_num_seqs,
|
||||
max_num_batched_tokens // max_mm_tokens)
|
||||
if max_num_seqs < 1:
|
||||
expr = (f"min({max_num_seqs_orig}, "
|
||||
f"{max_num_batched_tokens} // {max_mm_tokens})")
|
||||
logger.warning(
|
||||
"Computed max_num_seqs (%s) to be less than 1. "
|
||||
"Setting it to the minimum value of 1.", expr)
|
||||
max_num_seqs = 1
|
||||
|
||||
batch_size = 0
|
||||
for group_id in range(max_num_seqs):
|
||||
seq_len = (max_num_batched_tokens // max_num_seqs +
|
||||
(group_id < max_num_batched_tokens % max_num_seqs))
|
||||
batch_size += seq_len
|
||||
|
||||
dummy_data = self.input_registry \
|
||||
.dummy_data_for_profiling(self.model_config,
|
||||
seq_len,
|
||||
self.mm_registry)
|
||||
|
||||
seq = SequenceGroupMetadata(
|
||||
request_id=str(group_id),
|
||||
is_prompt=True,
|
||||
seq_data={group_id: dummy_data.seq_data},
|
||||
sampling_params=sampling_params,
|
||||
block_tables=None,
|
||||
lora_request=dummy_lora_requests_per_seq[group_id]
|
||||
if dummy_lora_requests_per_seq else None,
|
||||
multi_modal_data=dummy_data.multi_modal_data,
|
||||
multi_modal_placeholders=dummy_data.multi_modal_placeholders,
|
||||
)
|
||||
seqs.append(seq)
|
||||
|
||||
# Run the model with the dummy inputs.
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value ``None``.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
# it is important to create tensors inside the loop, rather than
|
||||
# multiplying the list, to avoid Dynamo from treating them as
|
||||
# tensor aliasing.
|
||||
kv_caches = [
|
||||
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
finished_requests_ids = [seq.request_id for seq in seqs]
|
||||
model_input = self.prepare_model_input(
|
||||
seqs, finished_requests_ids=finished_requests_ids)
|
||||
intermediate_tensors = None
|
||||
if not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = self.model.make_empty_intermediate_tensors(
|
||||
batch_size=batch_size,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device)
|
||||
|
||||
graph_batch_size = self.max_batchsize_to_capture
|
||||
batch_size_capture_list = [
|
||||
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
|
||||
]
|
||||
if self.model_config.enforce_eager:
|
||||
batch_size_capture_list = []
|
||||
with set_compile_context(batch_size_capture_list):
|
||||
self.execute_model(model_input, kv_caches, intermediate_tensors)
|
||||
torch.mlu.synchronize()
|
||||
return
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
|
||||
"""Cuda graph capture a model.
|
||||
|
||||
Note that CUDA graph's performance gain is negligible if number
|
||||
of batched tokens are larger than 200. And since CUDA graph
|
||||
requires fixed sized tensors, supporting large/variable batch
|
||||
size requires high GPU memory overhead. Thus, vLLM only captures
|
||||
decoding requests. Mixed batch (chunked prefill + decoding) or
|
||||
prefill requests are not captured.
|
||||
|
||||
Since it is used for decoding-only, it assumes there's only 1 token
|
||||
per sequence in the batch.
|
||||
"""
|
||||
assert not self.model_config.enforce_eager
|
||||
logger.info("Capturing the model for MLU graphs. This may lead to "
|
||||
"unexpected consequences if the model is not static. To "
|
||||
"run the model in eager mode, set 'enforce_eager=True' or "
|
||||
"use '--enforce-eager' in the CLI.")
|
||||
logger.info("MLU graphs can take additional 1~3 GiB memory per MLU. "
|
||||
"If you are running out of memory, consider decreasing "
|
||||
"`gpu_memory_utilization` or enforcing eager mode. "
|
||||
"You can also reduce the `max_num_seqs` as needed "
|
||||
"to decrease memory usage.")
|
||||
start_time = time.perf_counter()
|
||||
start_free_gpu_memory = torch.mlu.mem_get_info()[0]
|
||||
|
||||
# Prepare dummy inputs. These will be reused for all batch sizes.
|
||||
max_batch_size = self.max_batchsize_to_capture
|
||||
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).mlu()
|
||||
input_positions = torch.zeros(max_batch_size, dtype=torch.int32).mlu()
|
||||
if self.model_config.uses_mrope:
|
||||
input_positions = torch.tile(input_positions, (3, 1))
|
||||
# Prepare dummy previous_hidden_states only if needed by the model.
|
||||
# This is used by draft models such as EAGLE.
|
||||
previous_hidden_states = None
|
||||
if "previous_hidden_states" in inspect.signature(
|
||||
self.model.forward).parameters:
|
||||
previous_hidden_states = torch.empty(
|
||||
[max_batch_size,
|
||||
self.model_config.get_hidden_size()],
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device)
|
||||
|
||||
intermediate_inputs = None
|
||||
if not get_pp_group().is_first_rank:
|
||||
intermediate_inputs = self.model.make_empty_intermediate_tensors(
|
||||
batch_size=max_batch_size,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device)
|
||||
|
||||
graph_batch_size = self.max_batchsize_to_capture
|
||||
batch_size_capture_list = [
|
||||
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
|
||||
]
|
||||
|
||||
with self.attn_state.graph_capture(
|
||||
max_batch_size), mlu_graph_capture() as graph_capture_context:
|
||||
# NOTE: Capturing the largest batch size first may help reduce the
|
||||
# memory usage of CUDA graph.
|
||||
for virtual_engine in range(
|
||||
self.parallel_config.pipeline_parallel_size):
|
||||
for batch_size in reversed(batch_size_capture_list):
|
||||
attn_metadata = (
|
||||
self.attn_state.graph_capture_get_metadata_for_batch(
|
||||
batch_size,
|
||||
is_encoder_decoder_model=self.model_config.
|
||||
is_encoder_decoder))
|
||||
|
||||
if self.lora_config:
|
||||
lora_mapping = LoRAMapping(
|
||||
**dict(index_mapping=[0] * batch_size,
|
||||
prompt_mapping=[0] * batch_size,
|
||||
is_prefill=False))
|
||||
self.set_active_loras(set(), lora_mapping)
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
prompt_adapter_mapping = PromptAdapterMapping(
|
||||
[-1] * batch_size,
|
||||
[-1] * batch_size,
|
||||
)
|
||||
self.set_active_prompt_adapters(
|
||||
set(), prompt_adapter_mapping)
|
||||
|
||||
graph_runner = MLUGraphRunner(
|
||||
self.model, self.attn_backend.get_name(),
|
||||
self.attn_state.graph_clone(batch_size),
|
||||
self.model_config.is_encoder_decoder)
|
||||
|
||||
capture_inputs = {
|
||||
"input_ids":
|
||||
input_tokens[:batch_size],
|
||||
"positions":
|
||||
input_positions[..., :batch_size],
|
||||
"intermediate_inputs":
|
||||
intermediate_inputs[:batch_size]
|
||||
if intermediate_inputs is not None else None,
|
||||
"kv_caches":
|
||||
kv_caches[virtual_engine],
|
||||
"attn_metadata":
|
||||
attn_metadata,
|
||||
"memory_pool":
|
||||
self.graph_memory_pool,
|
||||
"stream":
|
||||
graph_capture_context.stream
|
||||
}
|
||||
if previous_hidden_states is not None:
|
||||
capture_inputs[
|
||||
"previous_hidden_states"] = previous_hidden_states[:
|
||||
batch_size]
|
||||
|
||||
if self.has_inner_state:
|
||||
# Only used by Mamba-based models CUDA graph atm (Jamba)
|
||||
capture_inputs.update({
|
||||
"seqlen_agnostic_capture_inputs":
|
||||
self.model.get_seqlen_agnostic_capture_inputs(
|
||||
batch_size)
|
||||
})
|
||||
if self.model_config.is_encoder_decoder:
|
||||
# add the additional inputs to capture for
|
||||
# encoder-decoder models.
|
||||
self._update_inputs_to_capture_for_enc_dec_model(
|
||||
capture_inputs)
|
||||
|
||||
with set_forward_context(attn_metadata):
|
||||
graph_runner.capture(**capture_inputs)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[virtual_engine][batch_size] = (
|
||||
graph_runner)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
end_free_gpu_memory = torch.mlu.mem_get_info()[0]
|
||||
elapsed_time = end_time - start_time
|
||||
mlu_graph_size = start_free_gpu_memory - end_free_gpu_memory
|
||||
# This usually takes < 10 seconds.
|
||||
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
|
||||
elapsed_time, mlu_graph_size / GiB_bytes)
|
||||
|
||||
|
||||
class MLUModelRunner(MLUModelRunnerBase, ModelRunner):
|
||||
"""
|
||||
MLU model runner with sampling step.
|
||||
"""
|
||||
_builder_cls: Type[ModelInputForMLUBuilder] = ModelInputForMLUBuilder
|
||||
|
||||
@torch.inference_mode()
|
||||
@dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelInputForGPUWithSamplingMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError("num_steps > 1 is not supported in ModelRunner")
|
||||
|
||||
if self.lora_config:
|
||||
assert model_input.lora_requests is not None
|
||||
assert model_input.lora_mapping is not None
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
assert model_input.prompt_adapter_requests is not None
|
||||
assert model_input.prompt_adapter_mapping is not None
|
||||
self.set_active_prompt_adapters(
|
||||
model_input.prompt_adapter_requests,
|
||||
model_input.prompt_adapter_mapping)
|
||||
|
||||
self.attn_state.begin_forward(model_input)
|
||||
|
||||
# Currently cuda graph is only supported by the decode phase.
|
||||
assert model_input.attn_metadata is not None
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
decode_meta = model_input.attn_metadata.decode_metadata
|
||||
# TODO(andoorve): We can remove this once all
|
||||
# virtual engines share the same kv cache.
|
||||
virtual_engine = model_input.virtual_engine
|
||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[virtual_engine][
|
||||
graph_batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
seqlen_agnostic_kwargs = {
|
||||
"finished_requests_ids": model_input.finished_requests_ids,
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_inner_state else {}
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_start = torch.mlu.Event(enable_timing=True)
|
||||
model_forward_end = torch.mlu.Event(enable_timing=True)
|
||||
model_forward_start.record()
|
||||
|
||||
with set_forward_context(model_input.attn_metadata):
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**seqlen_agnostic_kwargs)
|
||||
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.record()
|
||||
|
||||
# Compute the logits in the last pipeline stage.
|
||||
if not get_pp_group().is_last_rank:
|
||||
if (self.is_driver_worker
|
||||
and hidden_or_intermediate_states is not None
|
||||
and isinstance(hidden_or_intermediate_states,
|
||||
IntermediateTensors)
|
||||
and self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
hidden_or_intermediate_states.tensors["model_forward_time"] = (
|
||||
torch.tensor(model_forward_time + orig_model_forward_time))
|
||||
return hidden_or_intermediate_states
|
||||
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
|
||||
# Sample the next token.
|
||||
output: SamplerOutput = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time
|
||||
and output is not None):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
# If there are multiple workers, we are still tracking the latency
|
||||
# from the start time of the driver worker to the end time of the
|
||||
# driver worker. The model forward time will then end up covering
|
||||
# the communication time as well.
|
||||
output.model_forward_time = (orig_model_forward_time +
|
||||
model_forward_time)
|
||||
|
||||
if self.return_hidden_states:
|
||||
# we only need to pass hidden states of most recent token
|
||||
assert model_input.sampling_metadata is not None
|
||||
indices = model_input.sampling_metadata.selected_token_indices
|
||||
if model_input.is_prompt:
|
||||
hidden_states = hidden_or_intermediate_states.index_select(
|
||||
0, indices)
|
||||
output.prefill_hidden_states = hidden_or_intermediate_states
|
||||
elif decode_meta.use_cuda_graph:
|
||||
hidden_states = hidden_or_intermediate_states[:len(indices)]
|
||||
else:
|
||||
hidden_states = hidden_or_intermediate_states
|
||||
|
||||
output.hidden_states = hidden_states
|
||||
|
||||
return [output]
|
||||
|
||||
|
||||
class MLUGraphRunner(CUDAGraphRunner):
|
||||
|
||||
def capture(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_inputs: Optional[IntermediateTensors],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
memory_pool: Optional[Tuple[int, int]],
|
||||
stream: torch.cuda.Stream,
|
||||
**kwargs,
|
||||
):
|
||||
assert self._graph is None
|
||||
# Run the model a few times without capturing the graph.
|
||||
# This is to make sure that the captured graph does not include the
|
||||
# kernel launches for initial benchmarking (e.g., Triton autotune).
|
||||
# Note one iteration is not enough for torch.jit.script
|
||||
for _ in range(_NUM_WARMUP_ITERS):
|
||||
self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
# Wait for the warm up operations to finish before proceeding with
|
||||
# Graph Capture.
|
||||
torch.mlu.synchronize()
|
||||
# Capture the graph.
|
||||
self._graph = torch.mlu.MLUGraph()
|
||||
with torch.mlu.graph(self._graph, pool=memory_pool, stream=stream):
|
||||
output_hidden_or_intermediate_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
intermediate_tensors=intermediate_inputs,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_or_intermediate_states = (
|
||||
output_hidden_or_intermediate_states)
|
||||
|
||||
del output_hidden_or_intermediate_states
|
||||
# make sure `output_hidden_or_intermediate_states` is deleted
|
||||
# in the graph's memory pool
|
||||
gc.collect()
|
||||
torch.mlu.synchronize()
|
||||
|
||||
# Save the input and output buffers.
|
||||
self.input_buffers = {
|
||||
"input_ids":
|
||||
input_ids,
|
||||
"positions":
|
||||
positions,
|
||||
"kv_caches":
|
||||
kv_caches,
|
||||
**self.attn_state.get_graph_input_buffers(
|
||||
attn_metadata, self._is_encoder_decoder_model),
|
||||
**kwargs,
|
||||
}
|
||||
if intermediate_inputs is not None:
|
||||
self.input_buffers.update(intermediate_inputs.tensors)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.output_buffers = {
|
||||
"hidden_states": hidden_or_intermediate_states
|
||||
}
|
||||
else:
|
||||
self.output_buffers = hidden_or_intermediate_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# KV caches are fixed tensors, so we don't need to copy them.
|
||||
del kv_caches
|
||||
|
||||
# Copy the input tensors to the input buffers.
|
||||
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
|
||||
self.input_buffers["positions"].copy_(positions, non_blocking=True)
|
||||
|
||||
if self.backend_name != "NO_ATTENTION":
|
||||
self.input_buffers["slot_mapping"].copy_(
|
||||
attn_metadata.slot_mapping, non_blocking=True)
|
||||
|
||||
self.attn_state.prepare_graph_input_buffers(
|
||||
self.input_buffers, attn_metadata, self._is_encoder_decoder_model)
|
||||
|
||||
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
|
||||
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
|
||||
**kwargs)
|
||||
|
||||
if "previous_hidden_states" in self.input_buffers:
|
||||
self.input_buffers["previous_hidden_states"].copy_(
|
||||
kwargs["previous_hidden_states"], non_blocking=True)
|
||||
|
||||
if intermediate_tensors is not None:
|
||||
for key in intermediate_tensors.tensors:
|
||||
if key != "model_execute_time" and key != "model_forward_time":
|
||||
self.input_buffers[key].copy_(intermediate_tensors[key],
|
||||
non_blocking=True)
|
||||
if self._is_encoder_decoder_model:
|
||||
self.input_buffers["encoder_input_ids"].copy_(
|
||||
kwargs['encoder_input_ids'], non_blocking=True)
|
||||
self.input_buffers["encoder_positions"].copy_(
|
||||
kwargs['encoder_positions'], non_blocking=True)
|
||||
|
||||
# Run the graph.
|
||||
self.graph.replay()
|
||||
# Return the output tensor.
|
||||
if get_pp_group().is_last_rank:
|
||||
return self.output_buffers["hidden_states"]
|
||||
|
||||
return self.output_buffers
|
||||
Reference in New Issue
Block a user