621 lines
26 KiB
Python
621 lines
26 KiB
Python
|
|
#
|
||
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||
|
|
# This file is a part of the vllm-ascend project.
|
||
|
|
# Adapted from vllm-project/vllm/vllm/worker/model_runner.py
|
||
|
|
# Copyright 2023 The vLLM team.
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
#
|
||
|
|
|
||
|
|
import dataclasses
|
||
|
|
from typing import Any, Dict, List, Optional, Set, Type
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch.distributed
|
||
|
|
from torch import nn
|
||
|
|
from vllm.distributed import get_pp_group
|
||
|
|
from vllm.logger import init_logger
|
||
|
|
from vllm.lora.layers import LoRAMapping
|
||
|
|
from vllm.lora.request import LoRARequest
|
||
|
|
from vllm.model_executor import SamplingMetadata
|
||
|
|
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderMap
|
||
|
|
from vllm.platforms import current_platform
|
||
|
|
from vllm.prompt_adapter.layers import PromptAdapterMapping
|
||
|
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||
|
|
from vllm.sampling_params import SamplingParams
|
||
|
|
from vllm.sequence import SequenceGroupMetadata
|
||
|
|
from vllm.utils import flatten_2d_lists, make_tensor_with_pad
|
||
|
|
from vllm.worker.model_runner import (ModelInputForGPU,
|
||
|
|
ModelInputForGPUBuilder,
|
||
|
|
ModelInputForGPUWithSamplingMetadata,
|
||
|
|
ModelRunner)
|
||
|
|
|
||
|
|
# Rank of the dummy LoRA adapters that NPUModelRunner.profile_run()
# registers (via lora_manager.add_dummy_lora) while profiling memory usage.
LORA_WARMUP_RANK = 8

logger = init_logger(__name__)


class ModelInputForNPUBuilder(ModelInputForGPUBuilder):
    """Build ModelInputForGPU from SequenceGroupMetadata.

    NPU variant of ``ModelInputForGPUBuilder``. Prefill batches are padded to
    a uniform per-group length before being turned into tensors, and graph
    padding is disabled (``cuda_graph_pad_size`` is always -1).
    """

    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and create on-device tensors.

        If the first sequence group is a prompt, the batch is built as a
        prefill batch. Each group's tokens and positions are right-padded
        with 0 to the longest group and flattened. Every ``seq_lens`` entry
        is replaced by the batch maximum. Otherwise tokens and positions are
        flattened into tensors without padding.

        Returns:
            An instance of ``self.model_input_cls``. A default-constructed
            (empty) instance is returned when no group has any token to run.
        """
        # One flat token list per sequence group.
        input_tokens = [
            flatten_2d_lists(inter_data.input_tokens)
            for inter_data in self.inter_data_list
        ]
        # ``input_tokens`` holds one (possibly empty) list per group. Check
        # whether any group contributes a token, not merely whether the outer
        # list is empty; otherwise the case below is never detected.
        if not any(input_tokens):
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        mrope_input_positions: Optional[List[List[int]]] = None
        input_positions: Optional[List[List[int]]] = None
        if any(inter_data.mrope_input_positions is not None
               for inter_data in self.inter_data_list):
            # M-RoPE: build three position rows. Every sequence is
            # zero-padded to the longest first-sequence position list so the
            # rows line up. NOTE: the padding is applied in place to the
            # intermediate data lists.
            mrope_input_positions = [[] for _ in range(3)]
            max_pos_len = max(
                len(inter_data.input_positions[0])
                for inter_data in self.inter_data_list)

            for idx in range(3):
                for inter_data in self.inter_data_list:
                    msections = inter_data.mrope_input_positions
                    if msections is None:
                        # Group without M-RoPE positions: its 1-D positions
                        # are reused for every row.
                        for _seq_input_positions in inter_data.input_positions:
                            # zero pad
                            _seq_input_positions.extend(
                                [0] *
                                (max_pos_len - len(_seq_input_positions)))
                            mrope_input_positions[idx].extend(
                                _seq_input_positions)
                    else:
                        for _seq_mrope_input_positions in msections:
                            # zero pad
                            _seq_mrope_input_positions[idx].extend(
                                [0] * (max_pos_len -
                                       len(_seq_mrope_input_positions[idx])))
                            mrope_input_positions[idx].extend(
                                _seq_mrope_input_positions[idx])
        else:
            input_positions = [
                flatten_2d_lists(inter_data.input_positions)
                for inter_data in self.inter_data_list
            ]

        # Fresh flat lists; ``seq_lens`` is extended/replaced below.
        seq_lens = flatten_2d_lists(
            [inter_data.seq_lens for inter_data in self.inter_data_list])
        query_lens = flatten_2d_lists(
            [inter_data.query_lens for inter_data in self.inter_data_list])
        # Mapping from request IDs to sequence IDs. Used for Jamba models
        # that manage the cache by themselves.
        request_ids_to_seq_ids = {
            data.request_id: data.seq_ids
            for data in self.inter_data_list
        }

        # Number of sequence groups (``input_tokens`` has one entry per
        # group).
        batch_size = len(input_tokens)

        # Graph capture is not implemented on NPU, so graph padding is
        # disabled: every ``[x] * cuda_graph_pad_size`` below is empty.
        cuda_graph_pad_size = -1

        device = self.runner.device
        if self.inter_data_list[0].is_prompt:
            # Prefill: pad every group to the longest group, then flatten.
            input_tokens_tensor = torch.flatten(
                make_tensor_with_pad(input_tokens,
                                     0,
                                     dtype=torch.int,
                                     device=device))
            if mrope_input_positions is not None:
                mrope_input_positions_tensor = make_tensor_with_pad(
                    mrope_input_positions,
                    0,
                    dtype=torch.int,
                    device=device)
                # Cast the existing tensor rather than re-wrapping it with
                # torch.tensor(), which copy-constructs and warns.
                input_positions_tensor = mrope_input_positions_tensor.to(
                    dtype=torch.long)
            else:
                input_positions_tensor = torch.flatten(
                    make_tensor_with_pad(input_positions,
                                         0,
                                         dtype=torch.int,
                                         device=device))

            # Every sequence is treated as padded to the batch maximum.
            max_seq_len = max(seq_lens)
            seq_lens = len(seq_lens) * [max_seq_len]
        else:
            input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens),
                                               dtype=torch.long,
                                               device=device)
            if mrope_input_positions is not None:
                input_positions_tensor = torch.tensor(mrope_input_positions,
                                                      dtype=torch.long,
                                                      device=device)
            else:
                input_positions_tensor = torch.tensor(
                    flatten_2d_lists(input_positions),
                    dtype=torch.long,
                    device=device)

        # Sequence and query lengths.
        seq_lens.extend([1] * cuda_graph_pad_size)

        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(
            seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # LoRA data.
        lora_requests: Set[LoRARequest] = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = {
                r
                for data in self.inter_data_list for r in data.lora_requests
            }
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_index_mapping)
                for inter_data in self.inter_data_list
            ])
            lora_index_mapping.extend([0] * cuda_graph_pad_size)
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_prompt_mapping)
                for inter_data in self.inter_data_list
            ])
            lora_mapping = LoRAMapping(index_mapping=lora_index_mapping,
                                       prompt_mapping=lora_prompt_mapping,
                                       is_prefill=not self.decode_only)

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = {
                data.prompt_adapter_request
                for data in self.inter_data_list
                if data.prompt_adapter_request is not None
            }
            prompt_adapter_index_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_index_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
            prompt_adapter_prompt_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_prompt_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data.
        multi_modal_kwargs_list = [
            data.multi_modal_kwargs for data in self.inter_data_list
            if data.multi_modal_kwargs is not None
        ]
        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)

    # Note: ideally we would be using a dataclass(kw_only=True)
    # here, so that this can be subclassed easily,
    # but kw_only is not supported in python<3.10.
    class InterDataForSeqGroup:
        """Intermediate data for the current sequence group.

        Per-sequence lists (``input_tokens``, ``seq_lens``, ...) are indexed
        by the position of the sequence in ``seq_ids``. Calling ``__init__``
        with ``reinit=True`` on an existing instance resets it in place and
        reuses its previously allocated lists where possible.
        """

        def simple_reinit(self):
            """Reset a single-sequence instance in place.

            Clears the index-0 per-sequence lists and counters and the
            LoRA/prompt-adapter mappings, and drops the M-RoPE positions.
            """
            self.input_tokens[0].clear()  # type: ignore
            self.input_positions[0].clear()  # type: ignore
            self.token_types[0].clear()  # type: ignore
            self.mrope_input_positions = None  # type: ignore
            self.seq_lens[0] = 0  # type: ignore
            self.orig_seq_lens[0] = 0  # type: ignore
            self.query_lens[0] = 0  # type: ignore
            self.context_lens[0] = 0  # type: ignore
            self.curr_sliding_window_blocks[0] = 0  # type: ignore
            self.lora_index_mapping.clear()  # type: ignore
            self.lora_prompt_mapping.clear()  # type: ignore
            self.lora_requests.clear()  # type: ignore
            self.prompt_adapter_index_mapping.clear()  # type: ignore
            self.prompt_adapter_prompt_mapping.clear()  # type: ignore

        def __init__(
            self,
            *,
            # From sequence group metadata.
            request_id: str,
            seq_ids: List[int],
            is_prompt: bool,
            block_tables: Optional[Dict[int, List[int]]],
            computed_block_nums: List[int],
            n_seqs: int = 0,

            # Input tokens and positions.
            input_tokens: Optional[List[List[int]]] = None,
            input_positions: Optional[List[List[int]]] = None,
            token_types: Optional[List[List[int]]] = None,
            mrope_input_positions: Optional[List[List[List[int]]]] = None,

            # The sequence length (may be capped to the sliding window).
            seq_lens: Optional[List[int]] = None,
            # The original sequence length (before applying sliding window).
            # This is used to compute slot mapping.
            orig_seq_lens: Optional[List[int]] = None,
            # The query length.
            query_lens: Optional[List[int]] = None,
            # The number of tokens that are already computed.
            context_lens: Optional[List[int]] = None,
            # The current sliding window block.
            curr_sliding_window_blocks: Optional[List[int]] = None,

            # LoRA inputs.
            lora_index_mapping: Optional[List[List[int]]] = None,
            lora_prompt_mapping: Optional[List[List[int]]] = None,
            lora_requests: Optional[Set[LoRARequest]] = None,

            # Prompt adapter inputs.
            prompt_adapter_index_mapping: Optional[List[int]] = None,
            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,

            # Multi-modal inputs.
            multi_modal_kwargs: Optional[MultiModalKwargs] = None,
            multi_modal_placeholder_maps: Optional[Dict[
                str, MultiModalPlaceholderMap]] = None,

            # Whether the prefix cache is hit (prefill only).
            prefix_cache_hit: bool = False,
            reinit: bool = False,
            reinit_use_defaults: bool = False,
            encoder_seq_len: int = 0,
        ):
            """Initialize, or re-initialize in place when ``reinit`` is True.

            When ``reinit`` is True the instance must already exist with the
            same number of sequences. ``seq_ids`` is copied into the existing
            list, and each per-sequence field is either replaced by the
            (truthy) argument or cleared/zeroed in place. With
            ``reinit_use_defaults`` and a single sequence,
            ``simple_reinit()`` is used instead.

            When ``reinit`` is False, ``__post_init__`` runs at the end and
            replaces the token/position/length lists, ``mrope_input_positions``
            and the LoRA index/prompt mappings with fresh empty containers.
            Values passed for those arguments are therefore not kept.

            ``n_seqs`` is always overwritten with ``len(self.seq_ids)``.
            """
            if reinit:
                assert len(self.seq_ids) == len(seq_ids)  # type: ignore
                for i, seq_id in enumerate(seq_ids):
                    self.seq_ids[i] = seq_id  # type: ignore
            else:
                self.seq_ids = seq_ids

            self.request_id = request_id
            self.is_prompt = is_prompt
            self.block_tables = block_tables
            self.computed_block_nums = computed_block_nums
            self.n_seqs = n_seqs
            self.encoder_seq_len = encoder_seq_len

            if reinit:
                if len(self.seq_ids) == 1 and reinit_use_defaults:
                    self.simple_reinit()
                else:
                    if input_tokens:
                        self.input_tokens = input_tokens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.input_tokens[seq_id].clear()

                    if input_positions:
                        self.input_positions = input_positions
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.input_positions[seq_id].clear()

                    if token_types:
                        self.token_types = token_types
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.token_types[seq_id].clear()

                    # The mrope_input_positions argument is ignored on reinit.
                    self.mrope_input_positions = None

                    if seq_lens:
                        self.seq_lens = seq_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.seq_lens[seq_id] = 0

                    if orig_seq_lens:
                        self.orig_seq_lens = orig_seq_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.orig_seq_lens[seq_id] = 0

                    if query_lens:
                        self.query_lens = query_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.query_lens[seq_id] = 0

                    if context_lens:
                        self.context_lens = context_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.context_lens[seq_id] = 0

                    if curr_sliding_window_blocks:
                        self.curr_sliding_window_blocks = \
                            curr_sliding_window_blocks
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.curr_sliding_window_blocks[seq_id] = 0

                    if lora_index_mapping:
                        self.lora_index_mapping = lora_index_mapping
                    else:
                        self.lora_index_mapping.clear()

                    if lora_prompt_mapping:
                        self.lora_prompt_mapping = lora_prompt_mapping
                    else:
                        self.lora_prompt_mapping.clear()

                    if lora_requests:
                        self.lora_requests = lora_requests
                    else:
                        self.lora_requests.clear()

                    if prompt_adapter_index_mapping:
                        self.prompt_adapter_index_mapping = \
                            prompt_adapter_index_mapping
                    else:
                        self.prompt_adapter_index_mapping.clear()

                    if prompt_adapter_prompt_mapping:
                        self.prompt_adapter_prompt_mapping = \
                            prompt_adapter_prompt_mapping
                    else:
                        self.prompt_adapter_prompt_mapping.clear()

            else:
                self.input_tokens = input_tokens or []
                self.input_positions = input_positions or []
                self.token_types = token_types or []
                self.mrope_input_positions = mrope_input_positions or None
                self.seq_lens = seq_lens or []
                self.orig_seq_lens = orig_seq_lens or []
                self.query_lens = query_lens or []
                self.context_lens = context_lens or []
                self.curr_sliding_window_blocks = \
                    curr_sliding_window_blocks or []

                self.lora_index_mapping = lora_index_mapping or []
                self.lora_prompt_mapping = lora_prompt_mapping or []
                self.lora_requests = lora_requests or set()

                self.prompt_adapter_index_mapping = (
                    prompt_adapter_index_mapping or [])
                self.prompt_adapter_prompt_mapping = (
                    prompt_adapter_prompt_mapping or [])

            self.prompt_adapter_request = prompt_adapter_request
            self.multi_modal_kwargs = multi_modal_kwargs
            self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
            self.prefix_cache_hit = prefix_cache_hit

            self.n_seqs = len(self.seq_ids)

            if not reinit:
                self.__post_init__()

        def __post_init__(self):
            """Allocate empty per-sequence containers sized by ``seq_ids``."""
            self.n_seqs = len(self.seq_ids)

            self.input_tokens = [[] for _ in range(self.n_seqs)]
            self.input_positions = [[] for _ in range(self.n_seqs)]
            self.token_types = [[] for _ in range(self.n_seqs)]
            self.mrope_input_positions = None
            self.seq_lens = [0] * self.n_seqs
            self.orig_seq_lens = [0] * self.n_seqs
            self.query_lens = [0] * self.n_seqs
            self.context_lens = [0] * self.n_seqs
            self.curr_sliding_window_blocks = [0] * self.n_seqs

            self.lora_index_mapping = []
            self.lora_prompt_mapping = []


class NPUModelRunner(ModelRunner):
    """
    NPU model runner with sampling step.

    Reuses vLLM's ``ModelRunner`` but builds inputs with
    ``ModelInputForNPUBuilder``; graph capture is not implemented.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Args:
            tensor_dict: Dict produced by the sending side's broadcast.

        Returns:
            An instance of ``self._model_input_cls`` bound to this runner's
            attention backend.
        """
        # Go through the class attribute so subclasses that override
        # ``_model_input_cls`` get their own input type back.
        return self._model_input_cls.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    @current_platform.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on dummy prompt inputs to profile memory.

        Builds up to ``max_num_seqs`` dummy prompt sequence groups whose
        lengths sum to ``max_num_batched_tokens`` (fewer groups when the
        model accepts multi-modal input), executes them with empty KV
        caches, and then synchronizes the device.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        # This represents the maximum number of different requests
        # that will have unique loras, and therefore the max amount of memory
        # consumption. Create dummy lora request copies from the lora request
        # passed in, which contains a lora from the lora warmup path.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                # Assign the dummy LoRAs to sequences round-robin.
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional device memory may be needed for multi-modal encoding,
        # which needs to be accounted for when calculating the blocks for
        # the vLLM block manager.
        # To exercise the worst scenario for memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.

        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            # Split max_num_batched_tokens as evenly as possible; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_data.multi_modal_data,
                multi_modal_placeholders=dummy_data.multi_modal_placeholders,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        # Use an empty tensor instead of ``None`` to force Dynamo to pass
        # it by reference, rather than by specializing on the value ``None``.
        # The ``dtype`` argument does not matter, and we use ``float32`` as
        # a placeholder (it has wide hardware support).
        # It is important to create tensors inside the loop, rather than
        # multiplying the list, to avoid Dynamo from treating them as
        # tensor aliasing.
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        current_platform.synchronize()

    @current_platform.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """NPU graph capture a model.

        TODO: not supported yet; this is intentionally a no-op.
        """

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        NPU graph capture is not implemented (``capture_model`` is a no-op).

        Sampling metadata is only built on the last pipeline-parallel rank;
        other ranks get ``sampling_metadata=None``.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        if get_pp_group().is_last_rank:
            # Sampling metadata is only required for the final pp group
            generators = self.get_generators(finished_requests_ids)
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list,
                model_input.seq_lens,
                model_input.query_lens,
                self.device,
                self.pin_memory,
                generators,
                self.sampling_metadata_cache,
                # TODO (cmq): enable this after supported in vllm
                # pad_for_invariant_seq_len=True,
            )
        else:
            sampling_metadata = None
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    def get_model(self) -> nn.Module:
        """Return the underlying model module."""
        return self.model