### What this PR does / why we need it?
The vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on
the Ascend NPU.
This plugin is the recommended approach for supporting the Ascend
backend within the vLLM community. It adheres to the principles outlined
in the [RFC]: Hardware pluggable, providing a hardware-pluggable
interface that decouples the integration of the Ascend NPU from vLLM.
This patch also includes changes to make CI work and uses caching to speed up
e2e tests, including:
1. Change the push (post-merge CI) and pull_request (PR CI) trigger branches
to main
2. Make mypy work by ignoring base_communicator and clearing unused deps
3. Several improvements for vllm_ascend_test:
- use caches (pip, ms, hf) to speed up e2e tests (25 mins --> 5 mins)
- switch the `git clone` command to `actions/checkout` to speed up checkout
- Enable `-sv` for pytest for a more detailed info dump
- Remove the host network mode to resolve `docker: conflicting options: cannot
attach both user-defined and non-user-defined network-modes`, which is a
problem on docker 1.45 but not on 1.39.
4. Adapt MLA decode optimizations:
cabaf4eff3
### Does this PR introduce _any_ user-facing change?
Yes, this is the initial PR.
### How was this patch tested?
- This is the first PR to make the Ascend NPU work with vLLM. All code is
tested on Ascend with the vLLM V0 Engine.
- CI passed
---------
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: wangshuai09 <391746016@qq.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
621 lines
26 KiB
Python
621 lines
26 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm-project/vllm/vllm/worker/model_runner.py
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import dataclasses
|
|
from typing import Any, Dict, List, Optional, Set, Type
|
|
|
|
import torch
|
|
import torch.distributed
|
|
from torch import nn
|
|
from vllm.distributed import get_pp_group
|
|
from vllm.logger import init_logger
|
|
from vllm.lora.layers import LoRAMapping
|
|
from vllm.lora.request import LoRARequest
|
|
from vllm.model_executor import SamplingMetadata
|
|
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderMap
|
|
from vllm.platforms import current_platform
|
|
from vllm.prompt_adapter.layers import PromptAdapterMapping
|
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.sequence import SequenceGroupMetadata
|
|
from vllm.utils import flatten_2d_lists, make_tensor_with_pad
|
|
from vllm.worker.model_runner import (ModelInputForGPU,
|
|
ModelInputForGPUBuilder,
|
|
ModelInputForGPUWithSamplingMetadata,
|
|
ModelRunner)
|
|
|
|
logger = init_logger(__name__)

# LoRA rank passed to `add_dummy_lora` for the warm-up adapters that
# `NPUModelRunner.profile_run` registers while profiling memory.
LORA_WARMUP_RANK = 8
|
|
|
|
|
|
class ModelInputForNPUBuilder(ModelInputForGPUBuilder):
    """Build ModelInputForGPU from SequenceGroupMetadata for the Ascend NPU."""

    # Note: ideally we would be using a dataclass(kw_only=True)
    # here, so that this can be subclassed easily,
    # but kw_only is not supported in python<3.10.
    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and create on-device tensors.

        If the first sequence group is a prompt (prefill), tokens and
        positions are zero-padded with ``make_tensor_with_pad`` and
        flattened, and every entry of ``seq_lens`` is replaced by the batch
        maximum. Otherwise (decode), tokens and positions are concatenated
        without padding.

        When any group carries mrope positions, a three-row position table
        is built instead of flat positions. Groups without mrope positions
        contribute their 1-D positions to every row. The per-group position
        lists are zero-padded in place while doing so.

        Returns:
            An instance of ``self.model_input_cls``. An empty one is
            returned when there are no sequence groups to build from.
        """
        # Combine and flatten intermediate data.
        input_tokens = [
            flatten_2d_lists(inter_data.input_tokens)
            for inter_data in self.inter_data_list
        ]
        if not input_tokens:
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        mrope_input_positions: Optional[List[List[int]]] = None
        input_positions: Optional[List[List[int]]] = None
        if any(inter_data.mrope_input_positions is not None
               for inter_data in self.inter_data_list):
            mrope_input_positions = [[] for _ in range(3)]
            # Pad to the longest first-sequence position list so that the
            # three mrope rows stay aligned with each other.
            max_pos_len = max(
                len(inter_data.input_positions[0])
                for inter_data in self.inter_data_list)

            for idx in range(3):
                for inter_data in self.inter_data_list:
                    msections = inter_data.mrope_input_positions
                    if msections is None:
                        # No mrope data: reuse the 1-D positions for this row.
                        for seq_positions in inter_data.input_positions:
                            seq_positions.extend(
                                [0] * (max_pos_len - len(seq_positions)))
                            mrope_input_positions[idx].extend(seq_positions)
                    else:
                        for seq_mrope_positions in msections:
                            row = seq_mrope_positions[idx]
                            row.extend([0] * (max_pos_len - len(row)))
                            mrope_input_positions[idx].extend(row)
        else:
            input_positions = [
                flatten_2d_lists(inter_data.input_positions)
                for inter_data in self.inter_data_list
            ]

        seq_lens: List[int] = []
        for inter_data in self.inter_data_list:
            seq_lens.extend(inter_data.seq_lens)
        query_lens = flatten_2d_lists(
            [inter_data.query_lens for inter_data in self.inter_data_list])
        # Mapping from request IDs to sequence IDs. Used for Jamba models
        # that manage the cache by themselves.
        request_ids_to_seq_ids = {
            data.request_id: data.seq_ids
            for data in self.inter_data_list
        }

        batch_size = len(input_tokens)

        # Graph capture is not supported on NPU yet (see
        # `NPUModelRunner.capture_model`), so -1 is passed to mean "no graph".
        # As a list multiplier it also makes the padding extensions below
        # no-ops.
        cuda_graph_pad_size = -1

        if self.inter_data_list[0].is_prompt:
            input_tokens_tensor = torch.flatten(
                make_tensor_with_pad(input_tokens,
                                     0,
                                     dtype=torch.int,
                                     device=self.runner.device))
            if mrope_input_positions is not None:
                # Convert with `.to` instead of re-wrapping the tensor in
                # `torch.tensor`, which copy-constructs and emits a warning.
                input_positions_tensor = make_tensor_with_pad(
                    mrope_input_positions,
                    0,
                    dtype=torch.int,
                    device=self.runner.device).to(torch.long)
            else:
                input_positions_tensor = torch.flatten(
                    make_tensor_with_pad(input_positions,
                                         0,
                                         dtype=torch.int,
                                         device=self.runner.device))

            # Every prefill sequence is reported at the padded length.
            max_seq_len = max(seq_lens)
            seq_lens = len(seq_lens) * [max_seq_len]
        else:
            input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens),
                                               dtype=torch.long,
                                               device=self.runner.device)
            if mrope_input_positions is not None:
                input_positions_tensor = torch.tensor(
                    mrope_input_positions,
                    dtype=torch.long,
                    device=self.runner.device)
            else:
                input_positions_tensor = torch.tensor(
                    flatten_2d_lists(input_positions),
                    dtype=torch.long,
                    device=self.runner.device)

        # Sequence and query lengths.
        seq_lens.extend([1] * cuda_graph_pad_size)

        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(
            seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # LoRA data.
        lora_requests: Set[LoRARequest] = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = {
                r
                for data in self.inter_data_list
                for r in data.lora_requests
            }
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_index_mapping)
                for inter_data in self.inter_data_list
            ])
            lora_index_mapping.extend([0] * cuda_graph_pad_size)
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_prompt_mapping)
                for inter_data in self.inter_data_list
            ])
            lora_mapping = LoRAMapping(
                **dict(index_mapping=lora_index_mapping,
                       prompt_mapping=lora_prompt_mapping,
                       is_prefill=not self.decode_only))

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = set(
                data.prompt_adapter_request for data in self.inter_data_list
                if data.prompt_adapter_request is not None)
            prompt_adapter_index_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_index_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
            prompt_adapter_prompt_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_prompt_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data.
        multi_modal_kwargs_list = [
            data.multi_modal_kwargs for data in self.inter_data_list
            if data.multi_modal_kwargs is not None
        ]
        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)
|
|
|
|
class InterDataForSeqGroup:
    """Intermediate data for the current sequence group.

    An instance can be recycled by calling ``__init__`` again with
    ``reinit=True``. The existing containers are then emptied (mostly in
    place) or replaced by the supplied values, instead of being
    reallocated.
    """

    def simple_reinit(self):
        """Reset a single-sequence instance in place.

        Empties index 0 of the per-sequence token/position lists, zeroes
        index 0 of the per-sequence length counters, and clears the LoRA
        and prompt-adapter containers.
        """
        for name in ("input_tokens", "input_positions", "token_types"):
            getattr(self, name)[0].clear()
        self.mrope_input_positions = None  # type: ignore
        for name in ("seq_lens", "orig_seq_lens", "query_lens",
                     "context_lens", "curr_sliding_window_blocks"):
            getattr(self, name)[0] = 0
        for name in ("lora_index_mapping", "lora_prompt_mapping",
                     "lora_requests", "prompt_adapter_index_mapping",
                     "prompt_adapter_prompt_mapping"):
            getattr(self, name).clear()

    def __init__(
        self,
        *,
        # From sequence group metadata.
        request_id: str,
        seq_ids: List[int],
        is_prompt: bool,
        block_tables: Optional[Dict[int, List[int]]],
        computed_block_nums: List[int],
        n_seqs: int = 0,

        # Input tokens and positions.
        input_tokens: Optional[List[List[int]]] = None,
        input_positions: Optional[List[List[int]]] = None,
        token_types: Optional[List[List[int]]] = None,
        mrope_input_positions: Optional[List[List[List[int]]]] = None,

        # The sequence length (may be capped to the sliding window).
        seq_lens: Optional[List[int]] = None,
        # The original sequence length (before applying sliding window).
        # This is used to compute slot mapping.
        orig_seq_lens: Optional[List[int]] = None,
        # The query length.
        query_lens: Optional[List[int]] = None,
        # The number of tokens that are already computed.
        context_lens: Optional[List[int]] = None,
        # The current sliding window block.
        curr_sliding_window_blocks: Optional[List[int]] = None,

        # LoRA inputs.
        lora_index_mapping: Optional[List[List[int]]] = None,
        lora_prompt_mapping: Optional[List[List[int]]] = None,
        lora_requests: Optional[Set[LoRARequest]] = None,

        # Prompt adapter inputs.
        prompt_adapter_index_mapping: Optional[List[int]] = None,
        prompt_adapter_prompt_mapping: Optional[List[int]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,

        # Multi-modal inputs.
        multi_modal_kwargs: Optional[MultiModalKwargs] = None,
        multi_modal_placeholder_maps: Optional[Dict[
            str, MultiModalPlaceholderMap]] = None,

        # Whether the prefix cache is hit (prefill only).
        prefix_cache_hit: bool = False,
        reinit: bool = False,
        reinit_use_defaults: bool = False,
        encoder_seq_len: int = 0,
    ):
        if reinit:
            # Recycle the existing seq_ids list instead of rebinding it.
            assert len(self.seq_ids) == len(seq_ids)  # type: ignore
            for pos, new_id in enumerate(seq_ids):
                self.seq_ids[pos] = new_id  # type: ignore
        else:
            self.seq_ids = seq_ids

        self.request_id = request_id
        self.is_prompt = is_prompt
        self.block_tables = block_tables
        self.computed_block_nums = computed_block_nums
        self.n_seqs = n_seqs
        self.encoder_seq_len = encoder_seq_len

        if not reinit:
            self.input_tokens = input_tokens or []
            self.input_positions = input_positions or []
            self.token_types = token_types or []
            self.mrope_input_positions = mrope_input_positions or None
            self.seq_lens = seq_lens or []
            self.orig_seq_lens = orig_seq_lens or []
            self.query_lens = query_lens or []
            self.context_lens = context_lens or []
            self.curr_sliding_window_blocks = \
                curr_sliding_window_blocks or []

            self.lora_index_mapping = lora_index_mapping or []
            self.lora_prompt_mapping = lora_prompt_mapping or []
            self.lora_requests = lora_requests or set()

            self.prompt_adapter_index_mapping = (
                prompt_adapter_index_mapping or [])
            self.prompt_adapter_prompt_mapping = (
                prompt_adapter_prompt_mapping or [])
        elif len(self.seq_ids) == 1 and reinit_use_defaults:
            self.simple_reinit()
        else:
            num_seqs = len(self.seq_ids)

            # Per-sequence lists: take the new value if one was given,
            # otherwise empty each existing per-sequence list in place.
            for name, given in (("input_tokens", input_tokens),
                                ("input_positions", input_positions),
                                ("token_types", token_types)):
                if given:
                    setattr(self, name, given)
                else:
                    per_seq = getattr(self, name)
                    for i in range(num_seqs):
                        per_seq[i].clear()

            self.mrope_input_positions = None

            # Per-sequence counters: take the new value if one was given,
            # otherwise zero each existing entry.
            for name, given in (
                ("seq_lens", seq_lens),
                ("orig_seq_lens", orig_seq_lens),
                ("query_lens", query_lens),
                ("context_lens", context_lens),
                ("curr_sliding_window_blocks", curr_sliding_window_blocks),
            ):
                if given:
                    setattr(self, name, given)
                else:
                    counters = getattr(self, name)
                    for i in range(num_seqs):
                        counters[i] = 0

            # Flat containers: take the new value if one was given,
            # otherwise clear the existing container in place.
            for name, given in (
                ("lora_index_mapping", lora_index_mapping),
                ("lora_prompt_mapping", lora_prompt_mapping),
                ("lora_requests", lora_requests),
                ("prompt_adapter_index_mapping",
                 prompt_adapter_index_mapping),
                ("prompt_adapter_prompt_mapping",
                 prompt_adapter_prompt_mapping),
            ):
                if given:
                    setattr(self, name, given)
                else:
                    getattr(self, name).clear()

        self.prompt_adapter_request = prompt_adapter_request
        self.multi_modal_kwargs = multi_modal_kwargs
        self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
        self.prefix_cache_hit = prefix_cache_hit

        self.n_seqs = len(self.seq_ids)

        if not reinit:
            # NOTE: this replaces the per-sequence containers assigned above
            # with freshly allocated, empty ones.
            self.__post_init__()

    def __post_init__(self):
        """Allocate fresh, empty per-sequence containers for ``seq_ids``."""
        self.n_seqs = len(self.seq_ids)
        count = self.n_seqs

        self.input_tokens = [[] for _ in range(count)]
        self.input_positions = [[] for _ in range(count)]
        self.token_types = [[] for _ in range(count)]
        self.mrope_input_positions = None
        self.seq_lens = [0] * count
        self.orig_seq_lens = [0] * count
        self.query_lens = [0] * count
        self.context_lens = [0] * count
        self.curr_sliding_window_blocks = [0] * count

        self.lora_index_mapping = []
        self.lora_prompt_mapping = []
|
|
|
|
|
|
class NPUModelRunner(ModelRunner):
    """Ascend NPU model runner that also prepares sampling metadata."""

    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForNPUBuilder] = ModelInputForNPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcasted tensor dict.

        Delegates to ``ModelInputForGPUWithSamplingMetadata`` using this
        runner's attention backend.
        """
        input_cls = ModelInputForGPUWithSamplingMetadata
        return input_cls.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    @current_platform.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on worst-case dummy prompts to profile memory.

        ``max_num_batched_tokens`` tokens are split as evenly as possible
        across up to ``max_num_seqs`` prompt sequences. The sequence count
        is reduced for multi-modal models. When LoRA is enabled, the dummy
        sequences are assigned ``max_loras`` dummy adapters round-robin.
        """
        # top-k sampling is enabled so the sampler's memory use is counted.
        sampling_params = SamplingParams(top_p=0.99,
                                         top_k=self.vocab_size - 1)
        scheduler_cfg = self.scheduler_config
        max_num_batched_tokens = scheduler_cfg.max_num_batched_tokens
        max_num_seqs = scheduler_cfg.max_num_seqs

        # Register one dummy LoRA per allowed slot (the maximum number of
        # distinct adapters, hence the worst case for memory) and hand them
        # out to the dummy sequences in turn.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for lora_id in range(1, self.lora_config.max_loras + 1):
                    warmup_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(warmup_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(warmup_request)
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[i % len(dummy_lora_requests)]
                    for i in range(max_num_seqs)
                ]

        # Multi-modal encoding may need extra device memory that must be
        # accounted for when sizing the KV-cache blocks. To exercise the
        # worst case, the number of sequences (batch size) is chosen to
        # maximize the number of multi-modal items processed.
        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            requested_num_seqs = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({requested_num_seqs}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        seqs: List[SequenceGroupMetadata] = []
        batch_size = 0
        for group_id in range(max_num_seqs):
            # The first `extra` groups get one token more than the rest.
            base_len, extra = divmod(max_num_batched_tokens, max_num_seqs)
            seq_len = base_len + int(group_id < extra)
            batch_size += seq_len

            dummy_data = self.input_registry.dummy_data_for_profiling(
                self.model_config, seq_len, self.mm_registry)
            lora_request = (dummy_lora_requests_per_seq[group_id]
                            if dummy_lora_requests_per_seq else None)
            placeholders = dummy_data.multi_modal_placeholders
            seqs.append(
                SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: dummy_data.seq_data},
                    sampling_params=sampling_params,
                    block_tables=None,
                    lora_request=lora_request,
                    multi_modal_data=dummy_data.multi_modal_data,
                    multi_modal_placeholders=placeholders,
                ))

        # Each layer gets its own empty placeholder tensor instead of `None`
        # so that Dynamo passes it by reference rather than specializing on
        # `None`. Tensors are created per layer (not by multiplying a list)
        # so Dynamo does not treat them as aliases. The float32 dtype is
        # only a placeholder.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(num_layers)
        ]
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = (
            None if get_pp_group().is_first_rank else
            self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device))
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        current_platform.synchronize()

    @current_platform.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Capture the model as an NPU graph.

        Not supported yet: this method does nothing.
        """
        # TODO: implement NPU graph capture.
        return None

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input, plus sampling metadata on the last PP rank.

        The API assumes ``seq_group_metadata_list`` is sorted
        prefill -> decode, and the result keeps that order, e.g.
        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        ``is_prompt`` is taken from the first group, or is ``None`` when the
        list is empty.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        sampling_metadata: Optional[SamplingMetadata] = None
        if get_pp_group().is_last_rank:
            # Only the final pipeline-parallel stage samples.
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list,
                model_input.seq_lens,
                model_input.query_lens,
                self.device,
                self.pin_memory,
                self.get_generators(finished_requests_ids),
                self.sampling_metadata_cache,
                # TODO (cmq): enable this after supported in vllm
                # pad_for_invariant_seq_len=True,
            )

        is_prompt: Optional[bool] = None
        if seq_group_metadata_list:
            is_prompt = seq_group_metadata_list[0].is_prompt

        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    def get_model(self) -> nn.Module:
        """Return the underlying model module."""
        return self.model
|