### What this PR does / why we need it?

This PR enables vllm-ascend to use the piecewise_graph feature provided by the v1 engine.

1. Register `unified_ascend_attention_with_output` as the splitting op so the piecewise graph is split at attention boundaries.
2. Support NPUGraph to accelerate kernel launch.

### Does this PR introduce _any_ user-facing change?

NPUGraph is now enabled by default; users can disable it by setting `enforce_eager`. This places corresponding requirements on the torch_npu and CANN versions: both need to support graph capture.

### How was this patch tested?

Covered by existing CI, since NPUGraph is now the default path.

---------

Signed-off-by: Bug Hunter Yan <yanpq@zju.edu.cn>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
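A minimal sketch of the user-facing switch (assuming the standard vLLM `LLM` entry point; the model name is just a placeholder):

```python
from vllm import LLM

# Default: NPUGraph (piecewise graph) is used when supported.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")

# Opt out of graph capture and fall back to eager execution.
llm_eager = LLM(model="Qwen/Qwen2.5-7B-Instruct", enforce_eager=True)
```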
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#

import gc
import os
import time
import weakref
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LayerBlockType, LazyLoader, cdiv)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

from vllm_ascend.attention.attention import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.platform import NPUPlatform

if TYPE_CHECKING:
    import xgrammar as xgr  # type: ignore[import-untyped]
    from vllm.v1.core.sched.output import SchedulerOutput
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")


@dataclass
class GraphCaptureContext:
    stream: torch.npu.Stream


@contextmanager
def graph_capture(device: torch.device):
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the NPU graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph is
    replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current NPU stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This ensures that the
    graph capture runs on a separate stream from the default stream, in
    order to explicitly distinguish the kernels to capture from other
    kernels possibly launched in the background on the default stream.
    """
    graph_capture_context = GraphCaptureContext(
        torch.npu.Stream(device=device))
    stream = graph_capture_context.stream

    # we use nullcontext now
    maybe_ca_context = nullcontext()

    # ensure all initialization operations complete before attempting to
    # capture the graph on another stream
    curr_stream = torch.npu.current_stream()
    if curr_stream != stream:
        stream.wait_stream(curr_stream)

    with torch.npu.stream(stream), maybe_ca_context:
        yield graph_capture_context
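

# NOTE: Illustrative usage sketch (an assumption, not called by the runner):
# `graph_capture` hands the caller a dedicated side stream so the kernels to
# be recorded are clearly separated from work on the default stream. The
# actual capture path in this file goes through
# `NPUModelRunner.capture_model` -> `_dummy_run`, with the graph recording
# itself handled by the piecewise compilation backend.
#
#   with graph_capture(device=torch.device("npu:0")) as capture_ctx:
#       assert torch.npu.current_stream() == capture_ctx.stream
#       ...  # launch the kernels to record on the side stream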


class NPUModelRunner:

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.lora_config = vllm_config.lora_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device = device
        self.is_multimodal_model = self.model_config.is_multimodal_model
        self.block_size = vllm_config.cache_config.block_size
        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                           self.block_size)
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_num_reqs = self.scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
            vllm_config.parallel_config, LayerBlockType.attention)
        self.hidden_size = self.model_config.get_hidden_size()
        self.dtype = self.model_config.dtype
        cache_config = vllm_config.cache_config
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.head_size = self.model_config.get_head_size()
        self.attn_backend = get_attn_backend(
            self.head_size,
            self.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )
        if self.attn_backend is None:
            error_msg = (
                f"Error with get_attn_backend: {self.head_size=}, "
                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
                f"{self.model_config.is_attention_free=}, "
                f"{self.model_config.use_mla=}")
            logger.error(error_msg)
            raise NotImplementedError(
                "Non-Attention backend is not supported by V1 NPUModelRunner.")

        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            weakref.proxy(self))

        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = self.model_config.uses_mrope

        self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            mm_registry=self.mm_registry)

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: List[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.model_config.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=True,
            vocab_size=self.model_config.get_vocab_size(),
        )

        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=True)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        self.arange_np: npt.NDArray[np.int32] = np.arange(max(
            self.max_num_reqs + 1, self.model_config.max_model_len,
            self.max_num_tokens),
                                                          dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=True)
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=True)
        self.positions_np = self.positions_cpu.numpy()

        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=True)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()

        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=True)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()

        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=True)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        self.input_positions_cpu = torch.arange(0,
                                                self.max_num_tokens,
                                                device="cpu")
        self.attn_mask = None
        self.attn_state = None
        self.use_npu_graph = (self.vllm_config.compilation_config.level
                              == CompilationLevel.PIECEWISE
                              and not self.model_config.enforce_eager)
        self.npugraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))
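        # NOTE: vLLM's compilation config keeps `cudagraph_capture_sizes`
        # sorted in descending order, so the reversed copy above is
        # ascending; `capture_model` iterates it reversed once more so the
        # largest shape is captured first and smaller shapes can reuse its
        # memory pool.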

        # NOTE: Pre-construct a mask matrix to improve the efficiency of
        # attention mask construction during inference.
        # Note that the length of the matrix needs to be carefully balanced: a
        # matrix that is too large will consume excessive VRAM, while a matrix
        # that is too small will require dynamic concatenation during
        # inference, leading to performance degradation.
        # Therefore, an environment variable is added here to dynamically set
        # the size of the pre-constructed mask matrix based on requirements.
        mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
        self.attn_mask_len = min(self.model_config.max_model_len,
                                 int(mask_len))
        self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
            self.attn_mask_len, self.dtype)
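
        # NOTE: Illustrative usage (assumption: the variable must be set
        # before the engine starts, since it is read once here):
        #
        #   import os
        #   os.environ["PAGED_ATTENTION_MASK_LEN"] = "4096"
        #   # ... then construct the engine as usual; the pre-built mask is
        #   # capped at min(max_model_len, 4096) positions.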

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the
        scheduler output.

        The SamplingMetadata is updated and copied to the NPU if there is a
        new/resumed/paused/finished request in the batch.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted
        # and then resubmitted with the same ID. In this case, we treat them
        # as two distinct requests - clearing the cached states for the first
        # request and handling the second as a new request.
        removed_req_indices: List[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: List[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            num_computed_tokens = req_data.num_computed_tokens
            req_state.num_computed_tokens = num_computed_tokens
            # Add the sampled token(s) from the previous step (if any).
            # This doesn't include "unverified" tokens like spec decode
            # tokens.
            num_new_tokens = (num_computed_tokens +
                              len(req_data.new_token_ids) -
                              req_state.num_tokens)
            if num_new_tokens == 1:
                # Avoid slicing list in most common case.
                req_state.output_token_ids.append(req_data.new_token_ids[-1])
            elif num_new_tokens > 0:
                req_state.output_token_ids.extend(
                    req_data.new_token_ids[-num_new_tokens:])
            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was
                # not scheduled in the previous step and needs to be added
                # again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)

            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)
            # Add new_token_ids to token_ids_cpu.
            start_token_index = num_computed_tokens
            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
            self.input_batch.token_ids_cpu[
                req_index,
                start_token_index:end_token_index] = req_data.new_token_ids
            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
            if spec_token_ids:
                start_index = end_token_index
                end_token_index += len(spec_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            self.input_batch.num_tokens[req_index] = end_token_index

        # Check if the batch has changed. If not, we can skip copying the
        # sampling metadata from CPU to NPU.
        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        if batch_changed:
            self.input_batch.refresh_sampling_metadata()

    def get_model(self) -> nn.Module:
        return self.model

    def _make_attention_mask(self, seq_lens, query_lens, position,
                             attn_state) -> torch.Tensor:
        # Chunked prefill situation.
        if attn_state == AscendAttentionState.ChunkedPrefill:
            return self.attn_mask_builder.get_splitfuse_attn_mask(
                seq_lens, query_lens, position, self.dtype, self.device)
        # Prefill-only situation.
        elif attn_state == AscendAttentionState.PrefillOnly:
            max_seq_len = max(seq_lens, default=0)
            return self.attn_mask_builder.get_attn_mask(
                max_seq_len, self.dtype, self.device)
        # Decode-only situation.
        else:
            return None

    def _process_reqs(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        # Check input validity.
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        modified_batch = self.attn_metadata_builder.reorder_batch(
            self.input_batch, scheduler_output)
        if modified_batch:
            self.input_batch.refresh_sampling_metadata()

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        # TODO: The Python loop can be slow. Optimize.
        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
        max_num_scheduled_tokens = 0
        for i, req_id in enumerate(self.input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_scheduled_tokens[i] = num_tokens
            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                           num_tokens)

        # Prepare positions
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        sample_indices = cu_num_tokens - 1
        sample_indices = torch.from_numpy(sample_indices).to(self.device,
                                                             non_blocking=True)
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        self.positions[:total_num_scheduled_tokens].copy_(
            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
        positions = self.positions[:total_num_scheduled_tokens]
        self.query_lens = torch.from_numpy(num_scheduled_tokens)

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)
        seq_lens = self.seq_lens_cpu[:num_reqs]

        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])
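
        # NOTE: Worked example (illustrative): with block_size=16, a token at
        # position 35 of a request whose block table row is [7, 3, 9] falls
        # in logical block 35 // 16 = 2, i.e. physical block 9, at offset
        # 35 % 16 = 3, so its slot is 9 * 16 + 3 = 147.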

        attn_state = AscendAttentionState.ChunkedPrefill
        if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
            attn_state = AscendAttentionState.PrefillOnly
        elif np.all(num_scheduled_tokens == 1):
            attn_state = AscendAttentionState.DecodeOnly
        else:
            attn_state = AscendAttentionState.ChunkedPrefill
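        # Illustrative classification: seq_len == num_scheduled_tokens for
        # every request means nothing was computed before this step
        # (PrefillOnly); exactly one scheduled token per request is pure
        # decode (DecodeOnly); any mix of partial prefills and decodes is
        # ChunkedPrefill.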

        attn_mask = self._make_attention_mask(seq_lens=seq_lens,
                                              query_lens=num_scheduled_tokens,
                                              position=positions,
                                              attn_state=attn_state)
        self.attn_mask = attn_mask
        self.attn_state = attn_state  # type: ignore

        attn_metadata = self.attn_metadata_builder.build(  # type: ignore
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            common_prefix_len=None,
        )

        # Prepare input_ids
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
        # Copy the tensors to the NPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        input_ids = self.input_ids[:total_num_scheduled_tokens]

        # Run forward pass
        with set_forward_context(attn_metadata, self.vllm_config):
            assert self.model is not None
            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=None,
            )

        return hidden_states[sample_indices]

    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        # Serialization of np.ndarray is much more efficient than a tensor,
        # so we receive it in that format.
        grammar_bitmask = scheduler_output.grammar_bitmask
        if grammar_bitmask is None:
            return None

        # We receive the structured output bitmask from the scheduler, but the
        # indices of the requests in the batch may not match the indices of
        # the bitmask since the scheduler doesn't know how this runner is
        # ordering the requests in the batch. We need to sort the bitmask to
        # match the order of the requests used here.
        struct_out_req_batch_indices: dict[str, int] = {}
        indices_match = True
        for req_id in self.input_batch.req_ids:
            mask_index = scheduler_output.structured_output_request_ids.get(
                req_id)
            if mask_index is None:
                # Not a structured output request.
                continue
            batch_index = self.input_batch.req_id_to_index[req_id]
            if batch_index != mask_index:
                indices_match = False
            struct_out_req_batch_indices[req_id] = batch_index

        if not indices_match:
            # Sort the bitmask to match the order of the requests.
            sorted_bitmask = np.zeros_like(grammar_bitmask)
            for req_id, batch_index in struct_out_req_batch_indices.items():
                orig_index = scheduler_output.structured_output_request_ids[
                    req_id]
                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
            grammar_bitmask = sorted_bitmask

        grammar_bitmask = torch.from_numpy(grammar_bitmask)

        # TODO: compatibility with spec decode.
        # NOTE:
        # 1. XGrammar bitmask applying only supports CPU and GPU.
        # 2. The logits and bitmask should be on the same device.
        # 3. XGrammar logits on CPU only supports float32 dtype.
        logits_dtype = logits.dtype
        logits = logits.to("cpu").float()
        xgr.apply_token_bitmask_inplace(
            logits,
            grammar_bitmask,
            indices=list(struct_out_req_batch_indices.values()),
        )
        return logits.to(self.device).to(logits_dtype)
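
    # NOTE: Illustrative example for the reordering in
    # `apply_grammar_bitmask`: if the scheduler emitted bitmask rows as
    # {req_a: 0, req_b: 1} but the persistent batch stores req_b at index 0
    # and req_a at index 1, the two rows must be swapped before
    # `xgr.apply_token_bitmask_inplace` runs, or each request would be
    # constrained by the other request's grammar.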

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, torch.Tensor]:
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT
        hidden_states = self._process_reqs(scheduler_output,
                                           intermediate_tensors)
        logits = self.model.compute_logits(hidden_states, None)

        # Apply structured output bitmasks if present.
        if scheduler_output.grammar_bitmask is not None:
            logits = self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        sampler_output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)

        # NOTE: NPU -> CPU sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Spec decode tokens are not produced by this runner yet, so a
            # wider sampled_token_ids tensor indicates a bug upstream.
            raise NotImplementedError(
                "Speculative decoding is not supported on NPU yet.")

        model_runner_output = ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=None,
            logprobs=logprobs_lists,
            prompt_logprobs_dict={},
        )
        return model_runner_output

    def _profile_multimodal(self) -> None:
        # TODO: handle encoder-decoder models once we support them.
        # NOTE: Currently the model is profiled with a single non-text
        # modality with the max possible input tokens even when
        # it supports multiple.

        if (not self.is_multimodal_model
                or self.max_num_encoder_input_tokens <= 0
                or self.encoder_cache_size <= 0):
            return

        max_tokens_by_modality_dict = (
            MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
                self.model_config))
        dummy_data_modality, max_tokens_per_mm_item = max(
            max_tokens_by_modality_dict.items(), key=lambda item: item[1])

        # Check how many items of this modality can be supported by
        # the encoder budget.
        encoder_budget = min(self.max_num_encoder_input_tokens,
                             self.encoder_cache_size)

        max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                               max_tokens_per_mm_item)

        # Check how many items of this modality can be supported by
        # the decoder budget.
        max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
            self.model_config)[dummy_data_modality]

        # NOTE: We do not consider max_num_batched_tokens on purpose
        # because the multimodal embeddings can be generated in advance
        # and chunked prefilled.
        max_num_mm_items_decoder_budget = self.max_num_reqs * \
            max_mm_items_per_req

        max_num_mm_items = min(max_num_mm_items_encoder_budget,
                               max_num_mm_items_decoder_budget)
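
        # NOTE: Worked example (illustrative): with an encoder budget of
        # 8192 tokens and 1024 tokens per image, the encoder side allows
        # cdiv(8192, 1024) = 8 items; with max_num_reqs = 4 and one image
        # allowed per request, the decoder side allows 4, so
        # max_num_mm_items = min(8, 4) = 4.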

        logger.info(
            "Encoder cache will be initialized with a budget of %s tokens,"
            " and profiled with %s %s items of the maximum feature size.",
            encoder_budget, max_num_mm_items, dummy_data_modality)

        # Create dummy batch of multimodal inputs.
        dummy_request_data = self.input_registry.dummy_data_for_profiling(
            model_config=self.model_config,
            seq_len=self.max_num_tokens,
            mm_registry=self.mm_registry,
        )
        dummy_mm_data = dummy_request_data.multi_modal_data

        if not isinstance(dummy_mm_data, MultiModalKwargs):
            # TODO: Delete this check once input mapper is fully removed.
            raise RuntimeError("Legacy input mapper is not supported in V1")

        # Dummy data definition in V0 may contain multiple multimodal items
        # (e.g., multiple images) for a single request, therefore here we
        # always replicate the first item max_num_mm_items times since in V1
        # they are scheduled to be processed separately.

        dummy_mm_item = dummy_mm_data.get_item(modality=dummy_data_modality,
                                               item_index=0)
        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])

        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
                                                         max_num_mm_items)
        batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
            batched_dummy_mm_inputs, device=self.device)

        # Run multimodal encoder.
        dummy_encoder_outputs = self.model.get_multimodal_embeddings(
            **batched_dummy_mm_inputs)
        assert len(dummy_encoder_outputs) == max_num_mm_items, (
            "Expected dimension 0 of encoder outputs to match the number "
            f"of multimodal data items: {max_num_mm_items}, got "
            f"{len(dummy_encoder_outputs)=} instead. This is most likely "
            "due to the 'get_multimodal_embeddings' method of the model "
            "not being implemented correctly.")

        # Cache the dummy encoder outputs.
        self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

    @torch.inference_mode()
    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
        model = self.model
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_tokens]
        else:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = None

        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_tokens]
        else:
            positions = self.positions[:num_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            if self.intermediate_tensors is None:
                self.intermediate_tensors = (
                    self.model.make_empty_intermediate_tensors(
                        batch_size=num_tokens,
                        dtype=self.dtype,
                        device=self.device))
            intermediate_tensors = IntermediateTensors({
                k: v[:num_tokens]
                for k, v in self.intermediate_tensors.items()
            })

        with set_forward_context(None, self.vllm_config):
            hidden_states = model(input_ids=input_ids,
                                  positions=positions,
                                  intermediate_tensors=intermediate_tensors,
                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def profile_run(self) -> None:
        # Profile with multimodal encoder & encoder cache.
        self._profile_multimodal()

        # For the profile run, use the maximum number of requests that
        # collectively schedule the maximum number of tokens.
        num_reqs = self.scheduler_config.max_num_seqs
        num_tokens = self.max_num_tokens
        min_tokens_per_req = num_tokens // num_reqs

        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs

        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)
        logit_indices = np.cumsum(num_scheduled_tokens) - 1

        # assert self.lora_manager is not None, "LoRA is not enabled"
        # TODO: call maybe_profile_with_lora()

        dummy_kv_caches = [
            torch.tensor((), dtype=torch.float32, device=self.device)
            for _ in range(self.num_attn_layers)
        ]

        # Trigger compilation for the general (dynamic) shape.
        hidden_states = self._dummy_run(self.max_num_tokens)

        if get_pp_group().is_last_rank:
            hidden_states = hidden_states[logit_indices]
            logits = self.model.compute_logits(hidden_states, None)
        else:
            logits = None

        NPUPlatform.synchronize()
        del hidden_states, logits, dummy_kv_caches
        self.encoder_cache.clear()
        gc.collect()

    def load_model(self) -> None:
        logger.info("Starting to load model %s...", self.model_config.model)

        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                raise ValueError("LoRA model is not supported on NPU now.")
        logger.info("Loading model weights took %.4f GB",
                    m.consumed_memory / float(2**30))

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the KV cache based on `kv_cache_config`.
        Args:
            kv_cache_config: Configuration for the KV cache, including the
                KV cache size of each layer.
        """
        import torch_npu
        kv_caches: Dict[str, torch.Tensor] = {}
        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_config = kv_cache_config.tensors[layer_name]
                assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
                # `num_blocks` is the number of blocks the model runner can
                # use. `kv_cache_config.num_blocks` is the number of blocks
                # that KVCacheManager may allocate.
                # Since different NPUs may have different numbers of layers
                # and different memory capacities, `num_blocks` can differ
                # across devices, and `kv_cache_config.num_blocks` is set to
                # the min of all `num_blocks`. Verify it here.
                assert num_blocks >= kv_cache_config.num_blocks
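                # NOTE: Worked example (illustrative, assuming the standard
                # full-attention page size formula): for fp16 with
                # block_size=16, num_kv_heads=8 and head_size=128,
                # page_size_bytes = 2 (K and V) * 16 * 8 * 128 * 2 bytes
                # = 65536, so a 512 MiB per-layer tensor yields
                # 512 * 2**20 // 65536 = 8192 blocks.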
                # TODO: remove this after the OOM issue is located and fixed;
                # otherwise, some models may encounter OOM issues.
                if isinstance(kv_cache_spec, FullAttentionSpec):
                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
                    kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                        dtype=dtype,
                                                        device=self.device)
                    torch_npu.npu_format_cast(kv_caches[layer_name], 2)
                else:
                    # TODO: add new branches when introducing more types of
                    # KV cache specs.
                    raise ValueError("Unknown KV cache spec type.")

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the KV cache format from each
        Attention module in the static forward context.
        Returns:
            A dictionary mapping layer names to their KV cache format.
            Layers that do not need KV cache are not included.
        """

        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in forward_ctx.items():
            if isinstance(attn_module, FusedMoE):
                continue

            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention.
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=attn_module.dtype,
                    use_mla=use_mla)
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # Encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec

    def capture_model(self) -> None:
        if not self.use_npu_graph:
            logger.warning(
                "Skipping NPU graph capture. Please add "
                "-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
            return

        start_time = time.perf_counter()
        start_free_npu_memory = torch.npu.mem_get_info()[0]

        # Trigger NPU graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            for num_tokens in reversed(self.npugraph_batch_sizes):
                for _ in range(self.vllm_config.compilation_config.
                               cudagraph_num_of_warmups):
                    self._dummy_run(num_tokens)
                self._dummy_run(num_tokens)

        end_time = time.perf_counter()
        end_free_npu_memory = torch.npu.mem_get_info()[0]
        elapsed_time = end_time - start_time
        npu_graph_size = start_free_npu_memory - end_free_npu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, npu_graph_size / (1 << 30))
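

# NOTE: Illustrative end-to-end sketch (an assumption, not part of this
# module): how a user would exercise the capture path above. Compilation
# level 3 (CompilationLevel.PIECEWISE) enables the piecewise graph mode that
# `use_npu_graph` checks for, and `cudagraph_capture_sizes` feeds
# `npugraph_batch_sizes`. The model name is a placeholder.
#
#   from vllm import LLM
#
#   llm = LLM(
#       model="Qwen/Qwen2.5-7B-Instruct",
#       compilation_config={
#           "level": 3,  # CompilationLevel.PIECEWISE
#           "cudagraph_capture_sizes": [1, 2, 4, 8],
#       },
#   )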